TensorFlow Object Detection API – Custom Image Classification: Can I detect if a painting is a Nativity painting?

Leonardo da Vinci – Vergine delle Rocce (Louvre)
wikimedia.org

Can I teach an AI model, using TensorFlow, how to identify the theme of a painting?

The only way to find out is by trying! And I think you might be surprised by the outcome of this experiment.

But I am going to narrow down my problem: I want to identify whether a painting is about the Nativity or not, which is a simpler binary classification problem.

But first, let me explain what a Nativity painting is. A Nativity painting is one whose subject is the birth of Jesus Christ, an event deeply revered in Christianity. Over the last two thousand years, many famous artists, Leonardo da Vinci among them, were commissioned to create paintings for churches, so there should be plenty of Nativity paintings to pick from.

The Nativity by Botticelli
wikimedia.org

Before Starting

I created this notebook with code from the TensorFlow tutorial on image classification. You can find the original tutorial at the link below:

https://www.tensorflow.org/tutorials/images/classification

Import TensorFlow and other libraries

import matplotlib.pyplot as plt 
import PIL
import tensorflow as tf
import os
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

import pandas as pd
import requests # to get image from the web
import shutil # to save it locally
import time
import numpy as np

Download Training, Validation and Test Image Data Sets

In order to train an image classifier, we need a training dataset, a validation dataset, and a test dataset.

Since we are training a binary image classifier, we will have images for two different classes:

  • Nativity
  • Others

During the training of our model we will use the training dataset to teach the model how to classify a painting: either a Nativity painting or not (other).

At the end of each training cycle (epoch) we will use the validation dataset to score how well the model is doing by calculating the accuracy and the loss. The accuracy measures how many times our model got the right answer; higher is better. The loss measures the delta, i.e. the difference between the predicted value and the actual value; lower is better.
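To make these two metrics concrete, here is a minimal sketch (with made-up numbers, not values from our model) of how binary accuracy and binary cross-entropy loss are computed for a single batch:

import numpy as np

# Made-up ground-truth labels and model probabilities for four paintings
y_true = np.array([1.0, 0.0, 1.0, 1.0])
y_pred = np.array([0.9, 0.2, 0.4, 0.8])

# Accuracy: fraction of thresholded predictions that match the labels
accuracy = np.mean((y_pred >= 0.5) == y_true)

# Binary cross-entropy: heavily penalises confident wrong answers
loss = -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

print(f"accuracy={accuracy:.2f} loss={loss:.4f}")  # accuracy=0.75 loss=0.3670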

It is important that the validation dataset is separate from the training dataset, because an AI model is very good at cheating. If you don’t separate the two, the model will simply memorise the answers instead of learning the intrinsic characteristics of what we are trying to teach it.

At the end of the training we will also use a test dataset, separate from both the training and validation datasets, to run an independent benchmark of the model’s performance.
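In Keras that final benchmark is essentially a one-liner. As a sketch, assuming the trained model and the test_ds dataset that we build further down:

# Independent benchmark on the held-out test set (sketch: model and test_ds
# are only defined later in this notebook).
test_loss, test_accuracy = model.evaluate(test_ds)
print(f"test accuracy={test_accuracy:.3f}")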

You will notice that we are downloading three files:

  • nativity_dataset.csv – contains all nativity paintings
  • other_dataset.csv – contains many paintings except nativity paintings
  • test_dataset.csv – contains labeled paintings of both classes, reserved for the final benchmark

Wait a moment! Did I not just say that the training dataset should be separate from the validation dataset? So why keep them in the same files?

Yes, but because we are doing data exploration, it is good to have some flexibility. Typically you are advised to use 80% of the data for training and 20% for validation. But this is not a hard and fast rule: we might want to change these percentages and see what gives us better results as part of our experimentation. This is also known as hyperparameter tuning. The test dataset, on the other hand, should stay fixed, so we can compare different models with different architectures in a consistent way.
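As a sketch of what that flexibility could look like (split_dataset is a hypothetical helper, not part of this notebook’s pipeline), treating the split ratio as a tunable knob is straightforward with pandas:

def split_dataset(csv_path, validation_fraction=0.2, seed=123):
  # Shuffle deterministically so experiments are reproducible
  df = pd.read_csv(csv_path).sample(frac=1.0, random_state=seed).reset_index(drop=True)
  n_val = int(len(df) * validation_fraction)
  return df.iloc[n_val:], df.iloc[:n_val]  # (train, validation)

# Try different ratios as part of hyperparameter tuning:
# train_df, val_df = split_dataset("nativity_dataset.csv", validation_fraction=0.2)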

Below are some utility functions to help download the images from our image dataset. Notice that getFileNameFromUrl() does some very basic cleanup and extraction of the filename in the URL.

def getFileNameFromUrl(url):
  firstpos=url.rindex("/")
  lastpos=len(url)
  filename=url[firstpos+1:lastpos]
  print(f"url={url} firstpos={firstpos} lastpos={lastpos} filename={filename}")
  return filename

def downloadImage(imageUrl, destinationFolder):
  filename = getFileNameFromUrl(imageUrl)
  # Open the url image; stream=True returns the stream content.
  r = requests.get(imageUrl, stream = True)

  # Check if the image was retrieved successfully
  if r.status_code == 200:
      # Set decode_content to True, otherwise the downloaded image file's size will be zero.
      r.raw.decode_content = True

      # Open a local file with wb (write binary) permission.
      filePath = os.path.join(destinationFolder, filename)
      if not os.path.exists(filePath):
        with open(filePath, 'wb') as f:
            shutil.copyfileobj(r.raw, f)
        print('Image successfully downloaded: ', filename)
        print("Sleeping for 1 second before attempting next download")
        time.sleep(1)
      else:
        print(f'Skipping image {filename} as it is already downloaded')

  else:
      print(f"Image url={imageUrl} and filename={filename} couldn't be retrieved. HTTP Status={r.status_code}")
df = pd.read_csv("nativity_dataset.csv")

# create directory to which we download if it doesn't exist
destinationFolder = "/content/dataset/nativity"
os.makedirs(destinationFolder, exist_ok=True)

for i, row in df.iterrows():
  print(f"Index: {i}")
  print(f"{row['Image URL']}n")

  downloadImage(row["Image URL"], destinationFolder)
Index: 0
https://d3d00swyhr67nd.cloudfront.net/w1200h1200/collection/LSE/CUMU/LSE_CUMU_TN07034-001.jpg

url=https://d3d00swyhr67nd.cloudfront.net/w1200h1200/collection/LSE/CUMU/LSE_CUMU_TN07034-001.jpg firstpos=68 lastpos=93 filename=LSE_CUMU_TN07034-001.jpg
Image successfully downloaded:  LSE_CUMU_TN07034-001.jpg
Sleeping for 1 second before attempting next download
Index: 1
https://d3d00swyhr67nd.cloudfront.net/w1200h1200/collection/GMIII/MCAG/GMIII_MCAG_1947_188-001.jpg

Resize all images to be no bigger than 90,000 pixels (width × height)

Some of the images in our dataset are over 80MB in size. Resizing them with Python would mean loading each image entirely into memory, which is not a great idea. So we are going to use ImageMagick to do the job super fast.

!apt install imagemagick
Reading package lists... Done
Building dependency tree       
Reading state information... Done
imagemagick is already the newest version (8:6.9.7.4+dfsg-16ubuntu6.9).
0 upgraded, 0 newly installed, 0 to remove and 17 not upgraded.

Now we define the utility function resizeImages to resize images and copy from a sourceFolder to a destinationFolder.

import shlex

def resizeImages(sourceFolder, destinationFolder, maxPixels=1048576):
  os.makedirs(destinationFolder, exist_ok=True)
  for path, subdirs, files in os.walk(sourceFolder):
      relativeDir=path.replace(sourceFolder, "")
      destinationFolderPath = destinationFolder + relativeDir
      os.makedirs(destinationFolderPath,exist_ok=True)
      for fileName in files:
          sourceFilepath=os.path.join(path,fileName)
          destinationFilepath=os.path.join(destinationFolderPath, fileName)
          print(f"sourceFilepath={sourceFilepath} destinationFilepath={destinationFilepath}")
          # The 'N@>' geometry caps the image area at N pixels and only shrinks
          # images that exceed it. Quote the geometry and the paths, otherwise
          # the shell would treat '>' as an output redirect.
          os.system(f"convert {shlex.quote(sourceFilepath)} -resize '{maxPixels}@>' {shlex.quote(destinationFilepath)}")
# resize training images
sourceFolder="/content/dataset"
destinationFolder = "/content/resized/dataset"
resizeImages(sourceFolder, destinationFolder, maxPixels=90000)

# resize testing images
sourceFolder="/content/test_dataset"
destinationFolder = "/content/resized/test_dataset"
resizeImages(sourceFolder, destinationFolder, maxPixels=90000)
sourceFilepath=/content/dataset/others/Quentin_Massys-The_Adoration_of_the_Magi-1526%2CMetropolitan_Museum_of_Art%2CNew_York.jpg destinationFilepath=/content/resized/dataset/others/Quentin_Massys-The_Adoration_of_the_Magi-1526%2CMetropolitan_Museum_of_Art%2CNew_York.jpg
...

Map image labels to numeric values

We are using binary cross-entropy for our classification, so we need to make sure our labels are either 0 or 1: Nativity = 1 and Others = 0.

We will rename the folders to a 0 and a 1 since that is what tf.keras.preprocessing.image_dataset_from_directory uses to create the labels for our data set.

# !rm -fr  /content/resized/dataset/1
# !rm -fr  /content/resized/dataset/0
!mv /content/resized/dataset/nativity /content/resized/dataset/1
!mv /content/resized/dataset/others /content/resized/dataset/0

# !rm -fr  /content/resized/test_dataset/1
# !rm -fr  /content/resized/test_dataset/0
!mv /content/resized/test_dataset/nativity /content/resized/test_dataset/1
!mv /content/resized/test_dataset/others /content/resized/test_dataset/0
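As a side note, the same renames can be done idempotently in Python (a sketch, safe to re-run, unlike the mv commands above, which fail once the folders have been renamed):

# Rename class folders to their numeric labels, skipping any folder that was
# already renamed on a previous run.
for root in ["/content/resized/dataset", "/content/resized/test_dataset"]:
  for oldName, newName in [("nativity", "1"), ("others", "0")]:
    src, dst = os.path.join(root, oldName), os.path.join(root, newName)
    if os.path.isdir(src) and not os.path.isdir(dst):
      os.rename(src, dst)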

After downloading, you should now have a copy of the dataset available. In our run there are 454 files in total:

import pathlib
data_dir = pathlib.Path("/content/resized/dataset")
test_data_dir = pathlib.Path("/content/resized/test_dataset")

image_count = len(list(data_dir.glob('*/*')))
print(image_count)
454

Here are some paintings of the nativity:

nativity_label="1"
nativity = list(data_dir.glob(f'{nativity_label}/*'))

PIL.Image.open(str(nativity[0]))
PIL.Image.open(str(nativity[1]))

And some random paintings:

others_label="0"
others = list(data_dir.glob(f'{others_label}/*'))
PIL.Image.open(str(others[1]))
wikimedia.org
PIL.Image.open(str(others[2]))
wikimedia.org

Load using keras.preprocessing

Keras provides a bunch of really convenient functions to make our life easier when working with TensorFlow. tf.keras.preprocessing.image_dataset_from_directory is one of them: it loads images from files into the tf.data.Dataset format.

batch_size = 32
img_height = 300
img_width = 300

In general it is advised to split data into training data and validation data using an 80%/20% split. Remember, this is not a hard and fast rule.

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size, label_mode='binary')
Found 452 files belonging to 2 classes.
Using 362 files for training.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size,label_mode='binary')
Found 452 files belonging to 2 classes.
Using 90 files for validation.
#Retrieve a batch of images from the test set
test_data_dir = pathlib.Path("/content/resized/test_dataset")
test_batch_size=37
test_ds = tf.keras.preprocessing.image_dataset_from_directory(
  test_data_dir,
  seed=200,
  image_size=(img_height, img_width),
  batch_size=test_batch_size,label_mode='binary')
Found 37 files belonging to 2 classes.

You can find the class names in the class_names attribute on these datasets. These correspond to the directory names in alphabetical order.

class_names = train_ds.class_names
print(class_names)
['0', '1']

Visualize the data

Here are the first 9 images from the training dataset.

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    if labels[i] == 1.0:
      title = "Nativity"
    else:
      title = "Others"

    plt.title(title)
    plt.axis("off")

We inspect the image_batch and labels_batch variables.

The image_batch is a tensor of the shape (32, 300, 300, 3): a batch of 32 images of shape 300x300x3 (the last dimension refers to the RGB color channels). The label_batch is a tensor of the shape (32, 1), the corresponding binary label for each of the 32 images (label_mode='binary' adds the trailing dimension).

You can call .numpy() on the image_batch and labels_batch tensors to convert them to a numpy.ndarray.

for image_batch, labels_batch in train_ds:
  print(image_batch.shape)
  print(labels_batch.shape)
  break
(32, 300, 300, 3)
(32, 1)

Configure the dataset for performance

Let’s make sure to use buffered prefetching so you can yield data from disk without having I/O become blocking. These are two important methods you should use when loading data.

Dataset.cache() keeps the images in memory after they’re loaded off disk during the first epoch. This will ensure the dataset does not become a bottleneck while training your model. If your dataset is too large to fit into memory, you can also use this method to create a performant on-disk cache.

Dataset.prefetch() overlaps data preprocessing and model execution while training.

Interested readers can learn more about both methods, as well as how to cache data to disk in the data performance guide.

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

We define a utility function that runs the model against the test dataset, plots the predictions, and computes the accuracy:

labelMappings={"0":"Others","1":"Nativity",
               0.0:"Others",1.0 :"Nativity"}

def predictWithTestDataset(model):
  image_batch, label_batch = test_ds.as_numpy_iterator().next()
  predictions = model.predict_on_batch(image_batch).flatten()

  predictions = tf.where(predictions < 0.5, 0, 1)

  #print('Predictions:\n', predictions.numpy())
  #print('Labels:\n', label_batch)
  correctPredictions=0
  plt.figure(figsize=(20, 20))
  print(f"number predictions={len(predictions)}")
  for i in range(len(predictions)):
    ax = plt.subplot(8, 5, i +1)
    plt.imshow(image_batch[i].astype("uint8"))
    prediction = class_names[predictions[i]]
    predictionLabel = labelMappings[prediction]
    gtLabel = labelMappings[label_batch[i][0]]
    if gtLabel == predictionLabel:
      correctPredictions += 1
    plt.title(f"P={predictionLabel} GT={gtLabel}")
    plt.axis("off")

  accuracy = correctPredictions/len(predictions)
  print(f"Accuracy:{accuracy}")

Standardize the data

The RGB channel values are in the [0, 255] range. This is not ideal for a neural network; in general you should seek to make your input values small. Here, you will standardize values to be in the [0, 1] range by using a Rescaling layer.

normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)

There are two ways to use this layer. You can apply it to the dataset by calling map:

normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
image_batch, labels_batch = next(iter(normalized_ds))
first_image = image_batch[0]
# Notice the pixels values are now in `[0,1]`.
print(np.min(first_image), np.max(first_image))
0.0 0.8131995

Or, you can include the layer inside your model definition, which can simplify deployment. Let’s use the second approach here.

Create the model

The model consists of three convolution blocks, each with a max pooling layer, topped by a fully connected layer with 128 units and a ReLU activation. This model has not been tuned for high accuracy; the goal of this tutorial is to show a standard approach.


model = Sequential([
  layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(1, activation='sigmoid')
])

Compile the model

For this tutorial, choose the Adam optimizer and the binary cross-entropy loss function. To view training and validation accuracy for each training epoch, pass the metrics argument.

# The final Dense layer already applies a sigmoid, so the loss is computed
# on probabilities (from_logits=False, the default).
model.compile(optimizer='adam',
              loss=keras.losses.BinaryCrossentropy(),
              metrics=[keras.metrics.BinaryAccuracy()])

Model summary

View all the layers of the network using the model’s summary method:

model.summary()
Model: "sequential_31"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
rescaling_27 (Rescaling)     (None, 300, 300, 3)       0         
_________________________________________________________________
conv2d_108 (Conv2D)          (None, 300, 300, 16)      448       
_________________________________________________________________
max_pooling2d_60 (MaxPooling (None, 150, 150, 16)      0         
_________________________________________________________________
conv2d_109 (Conv2D)          (None, 150, 150, 32)      4640      
_________________________________________________________________
max_pooling2d_61 (MaxPooling (None, 75, 75, 32)        0         
_________________________________________________________________
conv2d_110 (Conv2D)          (None, 75, 75, 64)        18496     
_________________________________________________________________
max_pooling2d_62 (MaxPooling (None, 37, 37, 64)        0         
_________________________________________________________________
flatten_20 (Flatten)         (None, 87616)             0         
_________________________________________________________________
dense_52 (Dense)             (None, 128)               11214976  
_________________________________________________________________
dense_53 (Dense)             (None, 1)                 129       
=================================================================
Total params: 11,238,689
Trainable params: 11,238,689
Non-trainable params: 0
_________________________________________________________________

Train the model

epochs=10
history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)
Epoch 1/10
12/12 [==============================] - 1s 85ms/step - loss: 1.9371 - binary_accuracy: 0.5107 - val_loss: 0.7001 - val_binary_accuracy: 0.4444
Epoch 2/10
12/12 [==============================] - 1s 49ms/step - loss: 0.6491 - binary_accuracy: 0.6737 - val_loss: 0.7258 - val_binary_accuracy: 0.4778
Epoch 3/10
12/12 [==============================] - 1s 49ms/step - loss: 0.5943 - binary_accuracy: 0.6958 - val_loss: 0.7169 - val_binary_accuracy: 0.5333
Epoch 4/10
12/12 [==============================] - 1s 49ms/step - loss: 0.5111 - binary_accuracy: 0.7762 - val_loss: 0.7201 - val_binary_accuracy: 0.5667
Epoch 5/10
12/12 [==============================] - 1s 49ms/step - loss: 0.4013 - binary_accuracy: 0.8427 - val_loss: 0.6920 - val_binary_accuracy: 0.5667
Epoch 6/10
12/12 [==============================] - 1s 49ms/step - loss: 0.3027 - binary_accuracy: 0.8921 - val_loss: 0.8354 - val_binary_accuracy: 0.5889
Epoch 7/10
12/12 [==============================] - 1s 50ms/step - loss: 0.2438 - binary_accuracy: 0.9049 - val_loss: 0.8499 - val_binary_accuracy: 0.5778
Epoch 8/10
12/12 [==============================] - 1s 49ms/step - loss: 0.1725 - binary_accuracy: 0.9292 - val_loss: 0.9742 - val_binary_accuracy: 0.5222
Epoch 9/10
12/12 [==============================] - 1s 50ms/step - loss: 0.2792 - binary_accuracy: 0.8878 - val_loss: 0.9390 - val_binary_accuracy: 0.5222
Epoch 10/10
12/12 [==============================] - 1s 50ms/step - loss: 0.1347 - binary_accuracy: 0.9658 - val_loss: 0.9914 - val_binary_accuracy: 0.5889

Visualize training results

print(history.history)
acc = history.history['binary_accuracy']
val_acc = history.history['val_binary_accuracy']
# acc = history.history['accuracy']
# val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
{'loss': [1.5215493440628052, 0.6543726325035095, 0.5897634625434875, 0.5006453990936279, 0.39839598536491394, 0.2903604209423065, 0.22604547441005707, 0.22543807327747345, 0.2558016777038574, 0.14820142090320587], 'binary_accuracy': [0.5580110549926758, 0.6491712927818298, 0.7099447250366211, 0.7845304012298584, 0.8425414562225342, 0.8867403268814087, 0.9116021990776062, 0.9060773253440857, 0.9033148884773254, 0.9558011293411255], 'val_loss': [0.7000985741615295, 0.7257931232452393, 0.7169376611709595, 0.7200638055801392, 0.6920430660247803, 0.8354127407073975, 0.8498525619506836, 0.9741556644439697, 0.9390344619750977, 0.9914490580558777], 'val_binary_accuracy': [0.4444444477558136, 0.47777777910232544, 0.5333333611488342, 0.5666666626930237, 0.5666666626930237, 0.5888888835906982, 0.5777778029441833, 0.5222222208976746, 0.5222222208976746, 0.5888888835906982]}

Looking at the plots, we see a typical sign of overfitting. Overfitting happens when the model fits the training data too closely but does poorly on the validation data. Notice that the training accuracy keeps increasing with the epochs, while the validation accuracy does not, and in this case the validation loss even increases.

predictWithTestDataset(model)
number predictions=37
Accuracy:0.6216216216216216

Data augmentation

Overfitting generally occurs when there are a small number of training examples. Data augmentation takes the approach of generating additional training data from your existing examples by augmenting them using random transformations that yield believable-looking images. This helps expose the model to more aspects of the data and generalize better.

You will implement data augmentation using the layers from tf.keras.layers.experimental.preprocessing. These can be included inside your model like other layers, and run on the GPU.

data_augmentation = keras.Sequential(
  [
    layers.experimental.preprocessing.RandomFlip("horizontal", 
                                                 input_shape=(img_height, 
                                                              img_width,
                                                              3)),
    layers.experimental.preprocessing.RandomRotation(0.1) 
    ]
)

Let’s visualize what a few augmented examples look like by applying data augmentation to the same image several times:

plt.figure(figsize=(10, 10))
for images, _ in train_ds.take(1):
  for i in range(9):
    augmented_images = data_augmentation(images)
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_images[0].numpy().astype("uint8"))
    plt.axis("off")

You will use data augmentation to train a model in a moment.

Dropout

Another technique to reduce overfitting is to introduce Dropout to the network, a form of regularization.

When you apply Dropout to a layer, it randomly drops out (by setting the activation to zero) a number of output units from the layer during the training process. Dropout takes a fractional number as its input value, such as 0.1, 0.2 or 0.4, meaning 10%, 20% or 40% of the output units are randomly dropped from the applied layer.
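To see what that means in practice, here is a tiny standalone sketch (not part of the model) showing that Dropout only acts during training: it zeroes a random subset of units and scales the survivors by 1/(1 - rate) so the expected sum stays unchanged:

drop = tf.keras.layers.Dropout(0.5)
x = tf.ones((1, 10))
print(drop(x, training=True))   # about half the units zeroed, the rest scaled to 2.0
print(drop(x, training=False))  # at inference the input passes through unchanged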

Let’s create a new neural network using layers.Dropout, then train it using augmented images.


model = Sequential([
  data_augmentation,
  layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Dropout(0.2),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(1, activation='sigmoid')
])

Compile and train the model

# The model outputs sigmoid probabilities, so from_logits stays False; note
# that `lr` was renamed to `learning_rate` in recent Keras versions.
model.compile(loss=keras.losses.BinaryCrossentropy(),
              optimizer=keras.optimizers.RMSprop(learning_rate=1e-4),
              metrics=[keras.metrics.BinaryAccuracy()])
model.summary()
Model: "sequential_34"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
sequential_32 (Sequential)   (None, 300, 300, 3)       0         
_________________________________________________________________
rescaling_29 (Rescaling)     (None, 300, 300, 3)       0         
_________________________________________________________________
conv2d_114 (Conv2D)          (None, 300, 300, 16)      448       
_________________________________________________________________
max_pooling2d_66 (MaxPooling (None, 150, 150, 16)      0         
_________________________________________________________________
conv2d_115 (Conv2D)          (None, 150, 150, 32)      4640      
_________________________________________________________________
max_pooling2d_67 (MaxPooling (None, 75, 75, 32)        0         
_________________________________________________________________
conv2d_116 (Conv2D)          (None, 75, 75, 64)        18496     
_________________________________________________________________
max_pooling2d_68 (MaxPooling (None, 37, 37, 64)        0         
_________________________________________________________________
dropout_26 (Dropout)         (None, 37, 37, 64)        0         
_________________________________________________________________
flatten_22 (Flatten)         (None, 87616)             0         
_________________________________________________________________
dense_56 (Dense)             (None, 128)               11214976  
_________________________________________________________________
dense_57 (Dense)             (None, 1)                 129       
=================================================================
Total params: 11,238,689
Trainable params: 11,238,689
Non-trainable params: 0
_________________________________________________________________
epochs = 25
history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)
Epoch 1/25
12/12 [==============================] - 2s 76ms/step - loss: 0.9417 - binary_accuracy: 0.5580 - val_loss: 0.7153 - val_binary_accuracy: 0.5222
Epoch 2/25
12/12 [==============================] - 1s 61ms/step - loss: 0.6869 - binary_accuracy: 0.5338 - val_loss: 0.7236 - val_binary_accuracy: 0.5333
Epoch 3/25
12/12 [==============================] - 1s 61ms/step - loss: 0.6557 - binary_accuracy: 0.5985 - val_loss: 0.8124 - val_binary_accuracy: 0.5222
Epoch 4/25
12/12 [==============================] - 1s 61ms/step - loss: 0.6447 - binary_accuracy: 0.6315 - val_loss: 0.6829 - val_binary_accuracy: 0.5556
Epoch 5/25
12/12 [==============================] - 1s 65ms/step - loss: 0.6482 - binary_accuracy: 0.6273 - val_loss: 0.6708 - val_binary_accuracy: 0.5778
Epoch 6/25
12/12 [==============================] - 1s 61ms/step - loss: 0.6482 - binary_accuracy: 0.6348 - val_loss: 0.6733 - val_binary_accuracy: 0.5556
Epoch 7/25
12/12 [==============================] - 1s 61ms/step - loss: 0.6325 - binary_accuracy: 0.6592 - val_loss: 0.6762 - val_binary_accuracy: 0.5333
Epoch 8/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5994 - binary_accuracy: 0.6680 - val_loss: 0.6587 - val_binary_accuracy: 0.6111
Epoch 9/25
12/12 [==============================] - 1s 61ms/step - loss: 0.6204 - binary_accuracy: 0.6904 - val_loss: 0.7240 - val_binary_accuracy: 0.5333
Epoch 10/25
12/12 [==============================] - 1s 62ms/step - loss: 0.6343 - binary_accuracy: 0.6480 - val_loss: 0.6776 - val_binary_accuracy: 0.5667
Epoch 11/25
12/12 [==============================] - 1s 62ms/step - loss: 0.6439 - binary_accuracy: 0.6107 - val_loss: 0.6811 - val_binary_accuracy: 0.5556
Epoch 12/25
12/12 [==============================] - 1s 62ms/step - loss: 0.6361 - binary_accuracy: 0.6301 - val_loss: 0.6612 - val_binary_accuracy: 0.6222
Epoch 13/25
12/12 [==============================] - 1s 62ms/step - loss: 0.6025 - binary_accuracy: 0.6949 - val_loss: 0.6725 - val_binary_accuracy: 0.5778
Epoch 14/25
12/12 [==============================] - 1s 61ms/step - loss: 0.5977 - binary_accuracy: 0.6868 - val_loss: 0.7521 - val_binary_accuracy: 0.5444
Epoch 15/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5713 - binary_accuracy: 0.6833 - val_loss: 0.6427 - val_binary_accuracy: 0.6444
Epoch 16/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5918 - binary_accuracy: 0.6939 - val_loss: 0.6515 - val_binary_accuracy: 0.6333
Epoch 17/25
12/12 [==============================] - 1s 61ms/step - loss: 0.5831 - binary_accuracy: 0.7253 - val_loss: 0.6556 - val_binary_accuracy: 0.5889
Epoch 18/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5626 - binary_accuracy: 0.7121 - val_loss: 0.6877 - val_binary_accuracy: 0.5667
Epoch 19/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5476 - binary_accuracy: 0.7327 - val_loss: 0.6398 - val_binary_accuracy: 0.6556
Epoch 20/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5551 - binary_accuracy: 0.7283 - val_loss: 0.6465 - val_binary_accuracy: 0.6333
Epoch 21/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5436 - binary_accuracy: 0.7312 - val_loss: 0.7083 - val_binary_accuracy: 0.5667
Epoch 22/25
12/12 [==============================] - 1s 65ms/step - loss: 0.5987 - binary_accuracy: 0.6781 - val_loss: 0.8078 - val_binary_accuracy: 0.5222
Epoch 23/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5534 - binary_accuracy: 0.7139 - val_loss: 0.6705 - val_binary_accuracy: 0.6111
Epoch 24/25
12/12 [==============================] - 1s 85ms/step - loss: 0.5617 - binary_accuracy: 0.7406 - val_loss: 0.6471 - val_binary_accuracy: 0.6111
Epoch 25/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5541 - binary_accuracy: 0.7303 - val_loss: 0.6263 - val_binary_accuracy: 0.7000

Visualize training results

After applying data augmentation and Dropout, there is less overfitting than before, and the training and validation accuracy are more closely aligned.

print(history.history)
acc = history.history['binary_accuracy']
val_acc = history.history['val_binary_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
{'loss': [0.8289327025413513, 0.6810811758041382, 0.6626855731010437, 0.6704213619232178, 0.650795042514801, 0.6398268938064575, 0.6561762690544128, 0.6122907400131226, 0.6228107810020447, 0.6147234439849854, 0.6149318814277649, 0.6190508604049683, 0.607336699962616, 0.5861756801605225, 0.593088686466217, 0.6063793301582336, 0.5983501672744751, 0.5894269347190857, 0.5698645114898682, 0.5585014224052429, 0.5401753783226013, 0.5774908065795898, 0.5512883067131042, 0.5710932016372681, 0.5434897541999817], 'binary_accuracy': [0.5386740565299988, 0.5386740565299988, 0.6077347993850708, 0.6160221099853516, 0.6022099256515503, 0.6325966715812683, 0.6325966715812683, 0.6685082912445068, 0.6961326003074646, 0.6602209806442261, 0.6408839821815491, 0.6574585437774658, 0.6712707281112671, 0.6933701634407043, 0.6574585437774658, 0.6767956018447876, 0.7044199109077454, 0.6933701634407043, 0.7154695987701416, 0.7265193462371826, 0.7265193462371826, 0.6906077265739441, 0.7154695987701416, 0.7071823477745056, 0.7375690340995789], 'val_loss': [0.7153488993644714, 0.7235575318336487, 0.8124216794967651, 0.6829271912574768, 0.6708189249038696, 0.673344612121582, 0.676236629486084, 0.6586815714836121, 0.7239749431610107, 0.677582323551178, 0.6810950636863708, 0.6611502170562744, 0.6725294589996338, 0.7520950436592102, 0.642659068107605, 0.6514749526977539, 0.6556094884872437, 0.687703013420105, 0.639808714389801, 0.6464514136314392, 0.7082778811454773, 0.8077911138534546, 0.670492947101593, 0.6470986008644104, 0.6263118386268616], 'val_binary_accuracy': [0.5222222208976746, 0.5333333611488342, 0.5222222208976746, 0.5555555820465088, 0.5777778029441833, 0.5555555820465088, 0.5333333611488342, 0.6111111044883728, 0.5333333611488342, 0.5666666626930237, 0.5555555820465088, 0.6222222447395325, 0.5777778029441833, 0.5444444417953491, 0.644444465637207, 0.6333333253860474, 0.5888888835906982, 0.5666666626930237, 0.6555555462837219, 0.6333333253860474, 0.5666666626930237, 0.5222222208976746, 0.6111111044883728, 0.6111111044883728, 0.699999988079071]}

Predict on new data

Finally, let’s use our model to classify an image that wasn’t included in the training or validation sets.

predictWithTestDataset(model)
number predictions=37
Accuracy:0.7027027027027027

Transfer Learning

We have already used image augmentation to try to get better results from our model, and I have to say the results were not bad at all: we reached roughly 70% accuracy on the test set. Surely we can do better than that if we collect hundreds, perhaps thousands, more images for our training and validation datasets.

You can certainly do that, but there is another way that doesn’t involve the tedious and expensive process of collecting more training data: Transfer Learning.

With transfer learning, we can borrow a model that has already been trained on thousands of images and re-train it for our use case, with far fewer images than we would need to train a model from scratch.

To do so, we can use Keras to download a model with the Xception architecture, pre-trained on ImageNet.

To perform transfer learning, we need to freeze the weights of the base model and perform the training as we normally would. You will notice that we still apply the image augmentation and the regularization.

base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(img_height, img_width, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(img_height, img_width, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be normalized
# from (0, 255) to a range (-1., +1.), the normalization layer
# does the following, outputs = (inputs - mean) / sqrt(var)
norm_layer = keras.layers.experimental.preprocessing.Normalization()
mean = np.array([127.5] * 3)
var = mean ** 2
# Scale inputs to [-1, +1]
x = norm_layer(x)
norm_layer.set_weights([mean, var])

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)

model.summary()
Model: "model_13"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_28 (InputLayer)        [(None, 300, 300, 3)]     0         
_________________________________________________________________
sequential_32 (Sequential)   (None, 300, 300, 3)       0         
_________________________________________________________________
normalization_13 (Normalizat (None, 300, 300, 3)       7         
_________________________________________________________________
xception (Functional)        (None, 10, 10, 2048)      20861480  
_________________________________________________________________
global_average_pooling2d_13  (None, 2048)              0         
_________________________________________________________________
dropout_28 (Dropout)         (None, 2048)              0         
_________________________________________________________________
dense_59 (Dense)             (None, 1)                 2049      
=================================================================
Total params: 20,863,536
Trainable params: 2,049
Non-trainable params: 20,861,487
_________________________________________________________________
model.compile(
    optimizer=keras.optimizers.Adam(),
    # The sigmoid output layer produces probabilities, so from_logits stays False.
    loss=keras.losses.BinaryCrossentropy(),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 25
history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)
Epoch 1/25
12/12 [==============================] - 7s 376ms/step - loss: 0.6989 - binary_accuracy: 0.5116 - val_loss: 0.6324 - val_binary_accuracy: 0.5333
Epoch 2/25
12/12 [==============================] - 4s 322ms/step - loss: 0.6106 - binary_accuracy: 0.5943 - val_loss: 0.5748 - val_binary_accuracy: 0.6222
Epoch 3/25
12/12 [==============================] - 4s 322ms/step - loss: 0.5557 - binary_accuracy: 0.6647 - val_loss: 0.5378 - val_binary_accuracy: 0.6889
Epoch 4/25
12/12 [==============================] - 4s 326ms/step - loss: 0.5280 - binary_accuracy: 0.6333 - val_loss: 0.5127 - val_binary_accuracy: 0.7222
Epoch 5/25
12/12 [==============================] - 4s 329ms/step - loss: 0.4751 - binary_accuracy: 0.7638 - val_loss: 0.4912 - val_binary_accuracy: 0.7889
Epoch 6/25
12/12 [==============================] - 4s 331ms/step - loss: 0.4586 - binary_accuracy: 0.7535 - val_loss: 0.4775 - val_binary_accuracy: 0.7556
Epoch 7/25
12/12 [==============================] - 4s 335ms/step - loss: 0.4328 - binary_accuracy: 0.7778 - val_loss: 0.4625 - val_binary_accuracy: 0.8111
Epoch 8/25
12/12 [==============================] - 4s 339ms/step - loss: 0.3951 - binary_accuracy: 0.8387 - val_loss: 0.4519 - val_binary_accuracy: 0.8111
Epoch 9/25
12/12 [==============================] - 4s 344ms/step - loss: 0.3745 - binary_accuracy: 0.8427 - val_loss: 0.4435 - val_binary_accuracy: 0.8111
Epoch 10/25
12/12 [==============================] - 4s 348ms/step - loss: 0.3631 - binary_accuracy: 0.8373 - val_loss: 0.4395 - val_binary_accuracy: 0.7889
Epoch 11/25
12/12 [==============================] - 4s 350ms/step - loss: 0.3449 - binary_accuracy: 0.8705 - val_loss: 0.4302 - val_binary_accuracy: 0.8111
Epoch 12/25
12/12 [==============================] - 4s 355ms/step - loss: 0.3409 - binary_accuracy: 0.8623 - val_loss: 0.4249 - val_binary_accuracy: 0.8222
Epoch 13/25
12/12 [==============================] - 4s 356ms/step - loss: 0.3491 - binary_accuracy: 0.8848 - val_loss: 0.4214 - val_binary_accuracy: 0.8333
Epoch 14/25
12/12 [==============================] - 4s 356ms/step - loss: 0.3522 - binary_accuracy: 0.8569 - val_loss: 0.4173 - val_binary_accuracy: 0.8333
Epoch 15/25
12/12 [==============================] - 4s 354ms/step - loss: 0.3106 - binary_accuracy: 0.8641 - val_loss: 0.4120 - val_binary_accuracy: 0.8333
Epoch 16/25
12/12 [==============================] - 4s 348ms/step - loss: 0.3108 - binary_accuracy: 0.8973 - val_loss: 0.4059 - val_binary_accuracy: 0.8333
Epoch 17/25
12/12 [==============================] - 4s 348ms/step - loss: 0.3041 - binary_accuracy: 0.8840 - val_loss: 0.4043 - val_binary_accuracy: 0.8333
Epoch 18/25
12/12 [==============================] - 4s 364ms/step - loss: 0.3106 - binary_accuracy: 0.8548 - val_loss: 0.3994 - val_binary_accuracy: 0.8444
Epoch 19/25
12/12 [==============================] - 4s 343ms/step - loss: 0.3072 - binary_accuracy: 0.8774 - val_loss: 0.4031 - val_binary_accuracy: 0.8333
Epoch 20/25
12/12 [==============================] - 4s 341ms/step - loss: 0.3008 - binary_accuracy: 0.8870 - val_loss: 0.3960 - val_binary_accuracy: 0.8444
Epoch 21/25
12/12 [==============================] - 4s 342ms/step - loss: 0.2959 - binary_accuracy: 0.8738 - val_loss: 0.3969 - val_binary_accuracy: 0.8444
Epoch 22/25
12/12 [==============================] - 4s 340ms/step - loss: 0.2655 - binary_accuracy: 0.8874 - val_loss: 0.3959 - val_binary_accuracy: 0.8444
Epoch 23/25
12/12 [==============================] - 4s 340ms/step - loss: 0.2452 - binary_accuracy: 0.9098 - val_loss: 0.3957 - val_binary_accuracy: 0.8444
Epoch 24/25
12/12 [==============================] - 4s 359ms/step - loss: 0.2532 - binary_accuracy: 0.9214 - val_loss: 0.3906 - val_binary_accuracy: 0.8444
Epoch 25/25
12/12 [==============================] - 4s 340ms/step - loss: 0.2512 - binary_accuracy: 0.9202 - val_loss: 0.3963 - val_binary_accuracy: 0.8556
print(history.history)
acc = history.history['binary_accuracy']
val_acc = history.history['val_binary_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
{'loss': [0.6730840802192688, 0.6024077534675598, 0.5369390249252319, 0.5121317505836487, 0.47049736976623535, 0.43795397877693176, 0.4308575689792633, 0.409035861492157, 0.39621278643608093, 0.38378414511680603, 0.36360302567481995, 0.3390505313873291, 0.33790066838264465, 0.34375235438346863, 0.3241599202156067, 0.3240811824798584, 0.30652204155921936, 0.3121297359466553, 0.29941326379776, 0.31102144718170166, 0.29046544432640076, 0.2721157371997833, 0.2742222845554352, 0.273193895816803, 0.2644941210746765], 'binary_accuracy': [0.541436493396759, 0.580110490322113, 0.7099447250366211, 0.6546961069107056, 0.7762430906295776, 0.7817679643630981, 0.7734806537628174, 0.8121547102928162, 0.8093922734260559, 0.8066298365592957, 0.8563535809516907, 0.8535911440849304, 0.8701657652854919, 0.8618784546852112, 0.8535911440849304, 0.889502763748169, 0.8812154531478882, 0.8729282021522522, 0.8784530162811279, 0.8674033284187317, 0.8839778900146484, 0.8784530162811279, 0.8839778900146484, 0.9171270728111267, 0.8950276374816895], 'val_loss': [0.6324149966239929, 0.5748280882835388, 0.537817120552063, 0.5126944184303284, 0.4911538362503052, 0.47753846645355225, 0.4625410735607147, 0.45192599296569824, 0.44351860880851746, 0.4394892752170563, 0.430193156003952, 0.42491695284843445, 0.42138800024986267, 0.4172518849372864, 0.4119878113269806, 0.40589141845703125, 0.40429064631462097, 0.399353951215744, 0.40307578444480896, 0.39602288603782654, 0.39687782526016235, 0.39587682485580444, 0.39574024081230164, 0.39059093594551086, 0.3963099718093872], 'val_binary_accuracy': [0.5333333611488342, 0.6222222447395325, 0.6888889074325562, 0.7222222089767456, 0.7888888716697693, 0.7555555701255798, 0.8111110925674438, 0.8111110925674438, 0.8111110925674438, 0.7888888716697693, 0.8111110925674438, 0.8222222328186035, 0.8333333134651184, 0.8333333134651184, 0.8333333134651184, 0.8333333134651184, 0.8333333134651184, 0.8444444537162781, 0.8333333134651184, 0.8444444537162781, 0.8444444537162781, 0.8444444537162781, 0.8444444537162781, 0.8444444537162781, 0.855555534362793]}

Test model against test dataset

predictWithTestDataset(model)
WARNING:tensorflow:5 out of the last 12 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f58803470d0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.

number predictions=37
Accuracy:0.7837837837837838
# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate for fine-tuning
    loss=keras.losses.BinaryCrossentropy(),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 2
history = model.fit(train_ds, epochs=epochs, validation_data=val_ds)
Model: "model_13"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_28 (InputLayer)        [(None, 300, 300, 3)]     0         
_________________________________________________________________
sequential_32 (Sequential)   (None, 300, 300, 3)       0         
_________________________________________________________________
normalization_13 (Normalizat (None, 300, 300, 3)       7         
_________________________________________________________________
xception (Functional)        (None, 10, 10, 2048)      20861480  
_________________________________________________________________
global_average_pooling2d_13  (None, 2048)              0         
_________________________________________________________________
dropout_28 (Dropout)         (None, 2048)              0         
_________________________________________________________________
dense_59 (Dense)             (None, 1)                 2049      
=================================================================
Total params: 20,863,536
Trainable params: 20,809,001
Non-trainable params: 54,535
_________________________________________________________________
Epoch 1/2
12/12 [==============================] - 19s 1s/step - loss: 0.1918 - binary_accuracy: 0.9218 - val_loss: 0.3842 - val_binary_accuracy: 0.8556
Epoch 2/2
12/12 [==============================] - 16s 1s/step - loss: 0.1469 - binary_accuracy: 0.9509 - val_loss: 0.3520 - val_binary_accuracy: 0.8556
print(history.history)
acc = history.history['binary_accuracy']
val_acc = history.history['val_binary_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
{'loss': [0.1816166639328003, 0.14842070639133453], 'binary_accuracy': [0.9226519465446472, 0.950276255607605], 'val_loss': [0.38415849208831787, 0.3520398437976837], 'val_binary_accuracy': [0.855555534362793, 0.855555534362793]}
predictWithTestDataset(model)
WARNING:tensorflow:6 out of the last 13 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f58f13847b8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
number predictions=37
Accuracy:0.8378378378378378

Conclusion

We have been able to get some pretty good results with a limited dataset, which is very promising! There are many ways to further improve these results, from gathering more images to experimenting with different image sizes and even trying new model architectures.

I hope you found this notebook useful, and I hope to see you for the next one. Happy coding!