Using Augraphy in a TensorFlow Deep Learning Augmentation Pipeline

In this example, a classifier will be trained with the TensorFlow deep learning framework to classify document images as clean or dirty.

Augraphy will be used to augment the clean images and generate the dirty set of images.

Transfer learning is applied to a ResNet model to fine-tune it for classifying the images as “dirty” or “clean”.

The image dataset is extracted from a Kaggle document denoising competition: Denoising ShabbyPages.

The notebook for this example can be downloaded at this link.

  1. The first step is to download the image dataset. We are using gdown to download the data here.

[ ]:
# download and unzip document image data

!gdown --id 1uJPavzL7K3FFr9MEfZbdX3SNa1bGEdPu
!unzip shabby_small.zip
  2. Next, we install the latest version of Augraphy from its repository.

[ ]:
# Install Augraphy, the main image augmentation library

!pip install git+https://github.com/sparkfish/augraphy
  3. Then, we import some basic and TensorFlow-related libraries.

[ ]:
# import libraries

import sys
import glob
import cv2
import numpy as np

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers, Model
from tensorflow.keras.callbacks import EarlyStopping

from tqdm import tqdm


from matplotlib import pyplot as plt
from time import time

AUTOTUNE = tf.data.AUTOTUNE

  4. The next step is to create an Augraphy augmentation pipeline. We can use from augraphy import * to import all the necessary functions and modules.

An Augraphy pipeline consists of 3 phases (ink, paper and post), and we use AugraphyPipeline(ink_phase, paper_phase, post_phase) to initialize the pipeline instance.

Each of ink_phase, paper_phase and post_phase is a list containing augmentation instances. In this example, we will use a simple augmentation pipeline and keep each augmentation's default parameters.

[ ]:
# create an Augraphy augmentation pipeline

from augraphy import *

ink_phase = [Dithering(p=0.5),
             InkBleed(p=0.5),
             OneOf([LowInkRandomLines(p=1), LowInkPeriodicLines(p=1)]),
            ]

paper_phase = [ColorPaper(p=0.5)]

post_phase = [Markup(p=0.25),
              DirtyRollers(p=0.25),
              Scribbles(p=0.25),
              BindingsAndFasteners(p=0.25),
              BadPhotoCopy(p=0.25),
              DirtyDrum(p=0.25),
              ]

augmentation_pipeline = AugraphyPipeline(ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase)
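
Before wiring the pipeline into tf.data, we can optionally preview it on a single image. Below is a minimal sketch; the file name is a placeholder, so substitute any image from the clean training set after the data has been unzipped.

[ ]:
# optional: preview the pipeline on one clean image (placeholder file name)
image = cv2.imread("/content/shabby_small/train/train_clean/clean/some_image.png")
augmented = augmentation_pipeline(image)

plt.subplot(1, 2, 1)
plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
plt.title("clean")
plt.subplot(1, 2, 2)
if augmented.ndim == 2:  # some augmentations return a single-channel image
    plt.imshow(augmented, cmap="gray")
else:
    plt.imshow(cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB))
plt.title("augmented")
plt.show()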

  5. Before training, we need to define some training parameters such as batch_size, epochs, and image_size.

[ ]:
# training parameters

batch_size = 32
epochs = 60
image_size = 400
  6. glob is used to retrieve the paths of all images.

[ ]:
# define data

clean_train_path = "/content/shabby_small/train/train_clean/clean/"
clean_validate_path = "/content/shabby_small/validate/validate_clean/clean/"
clean_test_path = "/content/shabby_small/test/test_clean/clean/"

train_files = glob.glob(clean_train_path + "*.png")
validate_files = glob.glob(clean_validate_path + "*.png")
test_files = glob.glob(clean_test_path + "*.png")
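
Before building the datasets, it is worth a quick check that the globs actually matched files:

[ ]:
# sanity check: make sure the image paths were found
print(len(train_files), len(validate_files), len(test_files))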
  7. We use tf.data.Dataset.from_tensor_slices to create datasets of image paths and labels. Each dataset is mapped with read_image to load the image from its path. The dirty datasets are additionally mapped with augment_image to apply the augmentations.

Because the augmentation runs inside tf.py_function, the output tensors lose their static shape, so after the augmentation it is important to map the dataset with tf.ensure_shape so that the element shapes are known.

[ ]:
# create datasets

def read_image(image_path, label):
    image = tf.io.read_file(image_path)
    # the dataset images are PNG files
    image = tf.image.decode_png(image, channels=3)
    image = tf.image.resize(image, (image_size, image_size))
    # resize already returns float32, so this cast is a no-op and
    # pixel values stay in the [0, 255] range
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image.set_shape((image_size, image_size, 3))
    return image, label

def augment_image(image, label):
    image_single = tf.reshape(image, [image_size, image_size, 3])
    augmented_image = image_single.numpy()
    augmented_image = augmentation_pipeline(augmented_image)
    # some augmentations return a single-channel image; convert it back to 3 channels
    if len(augmented_image.shape) < 3:
        augmented_image = tf.image.grayscale_to_rgb(tf.expand_dims(augmented_image, -1))
    augmented_image = tf.cast(augmented_image, tf.float32)
    augmented_image.set_shape((image_size, image_size, 3))
    return augmented_image, label


# create datasets from image paths and labels (0 = clean, 1 = dirty)
labels_clean = tf.zeros(len(train_files), tf.float32)
labels_dirty = tf.ones(len(train_files), tf.float32)
train_dataset_clean = tf.data.Dataset.from_tensor_slices((train_files, labels_clean))
train_dataset_dirty = tf.data.Dataset.from_tensor_slices((train_files, labels_dirty))

labels_clean = tf.zeros(len(validate_files), tf.float32)
labels_dirty = tf.ones(len(validate_files), tf.float32)
validate_dataset_clean = tf.data.Dataset.from_tensor_slices((validate_files, labels_clean))
validate_dataset_dirty = tf.data.Dataset.from_tensor_slices((validate_files, labels_dirty))

labels_clean = tf.zeros(len(test_files), tf.float32)
labels_dirty = tf.ones(len(test_files), tf.float32)
test_dataset_clean = tf.data.Dataset.from_tensor_slices((test_files, labels_clean))
test_dataset_dirty = tf.data.Dataset.from_tensor_slices((test_files, labels_dirty))


# add preprocessing and augmentation functions into dataset
train_images_clean = (train_dataset_clean.shuffle(len(train_files), seed=42)
                      .map(read_image, num_parallel_calls=AUTOTUNE)
                      .cache()
                      .batch(batch_size)
                      .prefetch(AUTOTUNE)
                      )
train_images_dirty = (train_dataset_dirty
                      .shuffle(len(train_files), seed=42)
                      .map(read_image, num_parallel_calls=AUTOTUNE)
                      .map(lambda x, y: tf.py_function(augment_image, [x, y], [tf.float32,tf.float32]))
                      .map(lambda x, y: (tf.ensure_shape(x, (image_size, image_size, 3)), tf.ensure_shape(y, ())))
                      .cache()
                      .batch(batch_size)
                      .prefetch(AUTOTUNE)
                      )

validate_images_clean = (validate_dataset_clean.shuffle(len(validate_files), seed=42)
                        .map(read_image, num_parallel_calls=AUTOTUNE)
                        .cache()
                        .batch(batch_size)
                        .prefetch(AUTOTUNE)
                        )
validate_images_dirty = (validate_dataset_dirty
                        .shuffle(len(validate_files), seed=42)
                        .map(read_image, num_parallel_calls=AUTOTUNE)
                        .map(lambda x, y: tf.py_function(augment_image, [x, y], [tf.float32,tf.float32]))
                        .map(lambda x, y: (tf.ensure_shape(x, (image_size, image_size, 3)), tf.ensure_shape(y, ())))
                        .cache()
                        .batch(batch_size)
                        .prefetch(AUTOTUNE)
                        )

test_images_clean = (test_dataset_clean
                    .shuffle(len(test_files), seed=42)
                    .map(read_image, num_parallel_calls=AUTOTUNE)
                    .cache()
                    .batch(1)
                    .prefetch(AUTOTUNE)
                    )
test_images_dirty = (test_dataset_dirty
                    .shuffle(len(test_files), seed=42)
                    .map(read_image, num_parallel_calls=AUTOTUNE)
                    .map(lambda x, y: tf.py_function(augment_image, [x, y], [tf.float32,tf.float32]))
                    .map(lambda x, y: (tf.ensure_shape(x, (image_size, image_size, 3)), tf.ensure_shape(y, ())))
                    .cache()
                    .batch(1)
                    .prefetch(AUTOTUNE)
                    )

# merge dirty and clean, create the final datasets
train_images = train_images_clean.concatenate(train_images_dirty)
validate_images = validate_images_clean.concatenate(validate_images_dirty)
test_images = test_images_clean.concatenate(test_images_dirty)
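
As a quick sanity check that tf.ensure_shape did its job, we can inspect the element spec of the final training dataset; the image component should report a static (None, 400, 400, 3) shape:

[ ]:
# the dataset should report static shapes after ensure_shape
print(train_images.element_spec)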
  8. After that, we visually inspect some of the images to make sure they are loaded and augmented correctly.

[ ]:
# display some of the training images

plt.rcParams["figure.figsize"] = (20,12)
# clean
for image in train_images:
    plt.figure()
    for i in range(image[0].shape[0]):
        plt.subplot(3,5,i+1)
        plt.imshow(image[0][i]/255)
        plt.title("clean image")
        if i>=14:
            break
    break

# dirty
for image in train_images:
    if image[1][0] == 1:
        plt.figure()
        for i in range(image[0].shape[0]):
            plt.subplot(3,5,i+1)
            plt.imshow(image[0][i]/255)
            plt.title("dirty image")
            if i>=14:
                break
        break
[Image: a grid of training images titled "clean image"]
[Image: a grid of training images titled "dirty image"]
  9. We use ResNet for the transfer learning. In TensorFlow, we can import it with tf.keras.applications.ResNet152. The pretrained parameters are frozen by setting layer.trainable = False. Then, a new fully connected head is added and trained on the new data.

[ ]:
# transfer learning using resnet

pretrained_model = tf.keras.applications.ResNet152(weights='imagenet', include_top=False, input_shape=(image_size, image_size, 3))

for layer in pretrained_model.layers:
    layer.trainable = False

x = layers.Flatten()(pretrained_model.output)
x = layers.Dense(1000, activation='relu')(x)
predictions = layers.Dense(1, activation='sigmoid')(x)

model = Model(inputs=pretrained_model.input, outputs=predictions)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.00005)
model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet152_weights_tf_dim_ordering_tf_kernels_notop.h5
234698864/234698864 [==============================] - 1s 0us/step
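
Optionally, model.summary() confirms that the freeze worked: the non-trainable parameter count should cover the entire ResNet152 backbone, leaving only the new dense head trainable.

[ ]:
# inspect trainable vs. non-trainable parameter counts
model.summary()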
  10. Then, we create an early stopping callback and add it to the training process. Early stopping prevents overfitting by monitoring the validation loss: training stops when the validation loss has not improved for 10 consecutive epochs.

[ ]:
# train model
callback = EarlyStopping(monitor='val_loss', patience=10)
train_history = model.fit(train_images, validation_data = validate_images, epochs=epochs, callbacks=[callback], verbose=1)
Epoch 1/60
42/42 [==============================] - 351s 8s/step - loss: 46.1014 - accuracy: 0.8290 - val_loss: 14.3259 - val_accuracy: 0.5000
Epoch 2/60
42/42 [==============================] - 43s 1s/step - loss: 6.9163 - accuracy: 0.1588 - val_loss: 0.6928 - val_accuracy: 0.5034
Epoch 3/60
42/42 [==============================] - 45s 1s/step - loss: 0.6914 - accuracy: 0.5107 - val_loss: 0.6905 - val_accuracy: 0.5172
Epoch 4/60
42/42 [==============================] - 44s 1s/step - loss: 0.6821 - accuracy: 0.5389 - val_loss: 0.6735 - val_accuracy: 0.5638
Epoch 5/60
42/42 [==============================] - 52s 1s/step - loss: 0.6383 - accuracy: 0.6359 - val_loss: 0.6023 - val_accuracy: 0.7138
Epoch 6/60
42/42 [==============================] - 52s 1s/step - loss: 0.5016 - accuracy: 0.8733 - val_loss: 0.4955 - val_accuracy: 0.8793
Epoch 7/60
42/42 [==============================] - 45s 1s/step - loss: 0.4923 - accuracy: 0.8725 - val_loss: 0.4547 - val_accuracy: 0.8966
Epoch 8/60
42/42 [==============================] - 44s 1s/step - loss: 0.4336 - accuracy: 0.9145 - val_loss: 0.4436 - val_accuracy: 0.8793
Epoch 9/60
42/42 [==============================] - 45s 1s/step - loss: 0.4167 - accuracy: 0.9198 - val_loss: 0.4315 - val_accuracy: 0.8793
Epoch 10/60
42/42 [==============================] - 45s 1s/step - loss: 0.4030 - accuracy: 0.9305 - val_loss: 0.4236 - val_accuracy: 0.8914
Epoch 11/60
42/42 [==============================] - 44s 1s/step - loss: 0.3894 - accuracy: 0.9489 - val_loss: 0.4258 - val_accuracy: 0.8724
Epoch 12/60
42/42 [==============================] - 52s 1s/step - loss: 0.3854 - accuracy: 0.9489 - val_loss: 0.4202 - val_accuracy: 0.8879
Epoch 13/60
42/42 [==============================] - 43s 1s/step - loss: 0.3816 - accuracy: 0.9534 - val_loss: 0.4170 - val_accuracy: 0.8966
Epoch 14/60
42/42 [==============================] - 44s 1s/step - loss: 0.3767 - accuracy: 0.9603 - val_loss: 0.4173 - val_accuracy: 0.8879
Epoch 15/60
42/42 [==============================] - 44s 1s/step - loss: 0.3728 - accuracy: 0.9611 - val_loss: 0.4151 - val_accuracy: 0.8931
Epoch 16/60
42/42 [==============================] - 44s 1s/step - loss: 0.3703 - accuracy: 0.9649 - val_loss: 0.4127 - val_accuracy: 0.8931
Epoch 17/60
42/42 [==============================] - 43s 1s/step - loss: 0.3672 - accuracy: 0.9679 - val_loss: 0.4128 - val_accuracy: 0.8948
Epoch 18/60
42/42 [==============================] - 44s 1s/step - loss: 0.3637 - accuracy: 0.9725 - val_loss: 0.4142 - val_accuracy: 0.8879
Epoch 19/60
42/42 [==============================] - 51s 1s/step - loss: 0.3620 - accuracy: 0.9718 - val_loss: 0.4123 - val_accuracy: 0.8931
Epoch 20/60
42/42 [==============================] - 45s 1s/step - loss: 0.3607 - accuracy: 0.9740 - val_loss: 0.4088 - val_accuracy: 0.9000
Epoch 21/60
42/42 [==============================] - 45s 1s/step - loss: 0.3578 - accuracy: 0.9809 - val_loss: 0.4133 - val_accuracy: 0.8914
Epoch 22/60
42/42 [==============================] - 44s 1s/step - loss: 0.3555 - accuracy: 0.9840 - val_loss: 0.4190 - val_accuracy: 0.8707
Epoch 23/60
42/42 [==============================] - 45s 1s/step - loss: 0.3520 - accuracy: 0.9908 - val_loss: 0.4444 - val_accuracy: 0.7966
Epoch 24/60
42/42 [==============================] - 52s 1s/step - loss: 0.3563 - accuracy: 0.9748 - val_loss: 0.4128 - val_accuracy: 0.8828
Epoch 25/60
42/42 [==============================] - 43s 1s/step - loss: 0.3516 - accuracy: 0.9863 - val_loss: 0.4283 - val_accuracy: 0.8379
Epoch 26/60
42/42 [==============================] - 44s 1s/step - loss: 0.3505 - accuracy: 0.9878 - val_loss: 0.4294 - val_accuracy: 0.8466
Epoch 27/60
42/42 [==============================] - 51s 1s/step - loss: 0.3515 - accuracy: 0.9847 - val_loss: 0.4139 - val_accuracy: 0.8862
Epoch 28/60
42/42 [==============================] - 53s 1s/step - loss: 0.3515 - accuracy: 0.9855 - val_loss: 0.4072 - val_accuracy: 0.8983
Epoch 29/60
42/42 [==============================] - 52s 1s/step - loss: 0.3515 - accuracy: 0.9832 - val_loss: 0.4013 - val_accuracy: 0.9138
Epoch 30/60
42/42 [==============================] - 44s 1s/step - loss: 0.3502 - accuracy: 0.9832 - val_loss: 0.3981 - val_accuracy: 0.9190
Epoch 31/60
42/42 [==============================] - 45s 1s/step - loss: 0.3483 - accuracy: 0.9893 - val_loss: 0.3991 - val_accuracy: 0.9207
Epoch 32/60
42/42 [==============================] - 44s 1s/step - loss: 0.3456 - accuracy: 0.9931 - val_loss: 0.4128 - val_accuracy: 0.8897
Epoch 33/60
42/42 [==============================] - 45s 1s/step - loss: 0.3444 - accuracy: 0.9969 - val_loss: 0.4242 - val_accuracy: 0.8638
Epoch 34/60
42/42 [==============================] - 44s 1s/step - loss: 0.3433 - accuracy: 0.9962 - val_loss: 0.4407 - val_accuracy: 0.8276
Epoch 35/60
42/42 [==============================] - 44s 1s/step - loss: 0.3445 - accuracy: 0.9901 - val_loss: 0.4166 - val_accuracy: 0.8793
Epoch 36/60
42/42 [==============================] - 45s 1s/step - loss: 0.3434 - accuracy: 0.9924 - val_loss: 0.4197 - val_accuracy: 0.8724
Epoch 37/60
42/42 [==============================] - 52s 1s/step - loss: 0.3425 - accuracy: 0.9954 - val_loss: 0.4200 - val_accuracy: 0.8690
Epoch 38/60
42/42 [==============================] - 44s 1s/step - loss: 0.3423 - accuracy: 0.9947 - val_loss: 0.4150 - val_accuracy: 0.8897
Epoch 39/60
42/42 [==============================] - 44s 1s/step - loss: 0.3414 - accuracy: 0.9985 - val_loss: 0.4184 - val_accuracy: 0.8724
Epoch 40/60
42/42 [==============================] - 44s 1s/step - loss: 0.3412 - accuracy: 0.9977 - val_loss: 0.4118 - val_accuracy: 0.8897
  11. Once training is done, the loss curves are plotted.

[ ]:
# plot training curve

plt.rcParams["figure.figsize"] = (10,10)
train_losses = train_history.history['loss']
validate_losses = train_history.history['val_loss']
accuracy = train_history.history['accuracy']

x_train = [i for i in range(len(train_losses))]
x_validate = [i for i in range(len(validate_losses))]

plt.figure()
plt.grid()
plt.plot(x_train, train_losses, "red", label="train loss")
plt.plot(x_validate, validate_losses, "blue", label="validation loss")
plt.title("Training and validation losses")
plt.legend(loc="upper right")
plt.ylim(0, 1)
(0.0, 1.0)
[Image: training and validation loss curves]
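
The accuracy history collected above can be plotted the same way; a minimal sketch:

[ ]:
# plot training and validation accuracy over epochs
plt.figure()
plt.grid()
plt.plot(accuracy, "red", label="train accuracy")
plt.plot(train_history.history['val_accuracy'], "blue", label="validation accuracy")
plt.title("Training and validation accuracy")
plt.legend(loc="lower right")
plt.ylim(0, 1)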
  12. The last step is to test the trained model. The prediction threshold is set at 0.5, so any score > 0.5 is classified as a dirty image, while any score < 0.5 is classified as a clean image. From the results, we can see the model is doing pretty well.

[ ]:
# predict testing images

predicted_labels = model.predict(test_images)
372/372 [==============================] - 59s 160ms/step
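
With the 0.5 threshold described above, we can also compute the overall test accuracy by comparing the thresholded predictions against the true labels. A short sketch (the label order matches the predictions because the test datasets are cached and so replay in a fixed order):

[ ]:
# compute overall test accuracy at a 0.5 decision threshold
true_labels = np.array([label.numpy()[0] for _, label in test_images])
predicted_classes = (predicted_labels.flatten() > 0.5).astype(np.float32)
print("test accuracy:", np.mean(predicted_classes == true_labels))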
[ ]:
# Display predicted results

plt.rcParams["figure.figsize"] = (20,5)

# Classify clean images
fig=plt.figure()
fig.suptitle("Classify Test Set Clean Images")
n_clean = 0
for i, image in enumerate(test_images):
    if image[1] == 0:
        n_clean += 1
        plt.subplot(1,5,n_clean)
        plt.imshow(image[0][0]/255)
        if predicted_labels[i] < 0.5:
            predicted_label = "Clean"
        else:
            predicted_label = "Dirty"
        plt.title("Predicted: "+predicted_label)
        if n_clean >=5:
            break


[Image: five clean test images with their predicted labels]
[ ]:
# Classify Dirty images

fig=plt.figure()
fig.suptitle("Classify Test Set Dirty Images")
n_dirty = 0
for i, image in enumerate(test_images):
    if image[1] == 1:
        n_dirty += 1
        plt.subplot(1,5,n_dirty)
        plt.imshow(image[0][0]/255)
        if predicted_labels[i] < 0.5:
            predicted_label = "Clean"
        else:
            predicted_label = "Dirty"
        plt.title("Predicted: "+predicted_label)
        if n_dirty >=5:
            break
[Image: five dirty test images with their predicted labels]