Using Augraphy in the augmentation process of a TensorFlow deep learning pipeline
In this example, a classifier is trained in the TensorFlow deep learning framework to distinguish clean document images from dirty ones.
Augraphy is used to augment the clean images and generate the dirty set of images.
Transfer learning is applied to a ResNet model: a new classification head is trained on top of its frozen pretrained layers to classify the images as “dirty” or “clean”.
The image dataset is taken from a Kaggle document denoising competition: Denoising ShabbyPages.
The notebook for this example can be downloaded at this link.
The first step is to download the image dataset. Here we use `gdown` to download the data.

```python
# download and unzip document image data
!gdown --id 1uJPavzL7K3FFr9MEfZbdX3SNa1bGEdPu
!unzip shabby_small.zip
```
Next, we install the latest version of Augraphy from its repository.

```python
# Install Augraphy, the main image augmentation library
!pip install git+https://github.com/sparkfish/augraphy
```
Then, we import the basic and TensorFlow-related libraries.

```python
# import libraries
import glob
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.callbacks import EarlyStopping
from matplotlib import pyplot as plt

AUTOTUNE = tf.data.AUTOTUNE
```
The next step is to create an Augraphy augmentation pipeline. We can use `from augraphy import *` to import all the necessary functions and modules.

An Augraphy pipeline consists of three phases (ink, paper, and post), and we use `AugraphyPipeline(ink_phase, paper_phase, post_phase)` to initialize the pipeline instance. Each of `ink_phase`, `paper_phase`, and `post_phase` is a list containing the augmentation instances for that phase, and `OneOf` randomly selects one augmentation from its list. In this example, we use a simple pipeline in which every augmentation keeps its default parameters except for its probability `p` of being applied.
```python
# create an Augraphy augmentation pipeline
from augraphy import *

ink_phase = [Dithering(p=0.5),
             InkBleed(p=0.5),
             OneOf([LowInkRandomLines(p=1), LowInkPeriodicLines(p=1)]),
             ]
paper_phase = [ColorPaper(p=0.5)]
post_phase = [Markup(p=0.25),
              DirtyRollers(p=0.25),
              Scribbles(p=0.25),
              BindingsAndFasteners(p=0.25),
              BadPhotoCopy(p=0.25),
              DirtyDrum(p=0.25),
              ]
augmentation_pipeline = AugraphyPipeline(ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase)
```
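To get a feel for how often the post phase actually changes an image, the arithmetic below assumes that every augmentation with probability `p` is applied independently with that probability. This is an assumption for illustration only; `OneOf` and Augraphy's internal phase handling are not modelled. Under that assumption, the six post-phase augmentations with `p=0.25` add about 1.5 effects per image on average, yet roughly 18% of the dirty images receive no post-phase effect at all.

```python
import math

def phase_stats(probabilities):
    """Return (expected number applied, probability that none is applied),
    assuming each augmentation fires independently with its own probability."""
    expected = sum(probabilities)
    p_none = math.prod(1.0 - p for p in probabilities)
    return expected, p_none

# post_phase: Markup, DirtyRollers, Scribbles, BindingsAndFasteners, BadPhotoCopy, DirtyDrum
post_probs = [0.25] * 6
expected, p_none = phase_stats(post_probs)
print(f"expected post-phase augmentations per image: {expected:.2f}")  # 1.50
print(f"probability of no post-phase augmentation: {p_none:.3f}")     # 0.178
```

Raising the post-phase probabilities, or adding more augmentations to the phase, lowers that fraction.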
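Before training, we define the training parameters `batch_size`, `epochs`, and `image_size`.

```python
# training parameters
batch_size = 32   # images per batch
epochs = 60       # maximum number of epochs (early stopping may end training sooner)
image_size = 400  # images are resized to image_size x image_size pixels
```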
We use `glob` to retrieve the paths of all images.

```python
# define data
clean_train_path = "/content/shabby_small/train/train_clean/clean"
clean_validate_path = "/content/shabby_small/validate/validate_clean/clean"
clean_test_path = "/content/shabby_small/test/test_clean/clean"

# sort the paths so the file order is reproducible
train_files = sorted(glob.glob(clean_train_path + "/*.png"))
validate_files = sorted(glob.glob(clean_validate_path + "/*.png"))
test_files = sorted(glob.glob(clean_test_path + "/*.png"))
```
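`glob.glob` returns an empty list, rather than raising an error, when a path is wrong, and the problem then only surfaces later inside the `tf.data` pipeline. A small helper such as the hypothetical `find_pngs` below, which sorts the matches and fails fast when nothing is found, is one way to guard against that. The demo runs on a throwaway folder so it is self-contained.

```python
import glob
import os
import tempfile

def find_pngs(folder):
    """Return the .png paths in `folder` in sorted order; raise if none are found."""
    files = sorted(glob.glob(os.path.join(folder, "*.png")))
    if not files:
        raise FileNotFoundError(f"no .png files found in {folder!r}")
    return files

# self-contained demo; in the notebook one would call, e.g.,
# find_pngs("/content/shabby_small/train/train_clean/clean")
with tempfile.TemporaryDirectory() as tmp:
    for name in ("b.png", "a.png", "notes.txt"):
        open(os.path.join(tmp, name), "w").close()
    print([os.path.basename(p) for p in find_pngs(tmp)])  # ['a.png', 'b.png']
```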
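We use `tf.data.Dataset.from_tensor_slices` to create datasets of image paths and labels (0 for clean, 1 for dirty). Each dataset is mapped with `read_image` to read the image from its path, and the dirty datasets are additionally mapped with `augment_image` to apply the Augraphy augmentations.
Because `augment_image` runs inside `tf.py_function`, TensorFlow cannot infer the shape of its output, so the dataset is then mapped with `tf.ensure_shape` to restore the known shape. Since `.cache()` comes after the augmentation, each dirty image is augmented once, on the first pass, and the same dirty version is reused in every later epoch.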
```python
# create datasets
def read_image(image_path, label):
    # read and decode a PNG file (the dataset images are .png)
    image = tf.io.read_file(image_path)
    image = tf.io.decode_png(image, channels=3)
    # tf.image.resize returns float32 but keeps the 0-255 value range
    image = tf.image.resize(image, (image_size, image_size))
    image.set_shape((image_size, image_size, 3))
    return image, label

def augment_image(image, label):
    # runs eagerly inside tf.py_function, so .numpy() is available;
    # Augraphy works on uint8 images
    augmented_image = augmentation_pipeline(image.numpy().astype(np.uint8))
    # a grayscale (2-D) result is converted back to 3 channels
    if augmented_image.ndim == 2:
        augmented_image = np.stack([augmented_image] * 3, axis=-1)
    # drop any extra (e.g. alpha) channel and make sure the size is unchanged
    augmented_image = tf.cast(augmented_image[..., :3], tf.float32)
    augmented_image = tf.image.resize(augmented_image, (image_size, image_size))
    return augmented_image, label

def make_dataset(files, label, batch, augment=False):
    # (path, label) pairs; shuffled on the first pass only, because of the later .cache()
    labels = tf.fill([len(files)], float(label))
    dataset = (tf.data.Dataset.from_tensor_slices((files, labels))
               .shuffle(len(files), seed=42)
               .map(read_image, num_parallel_calls=AUTOTUNE))
    if augment:
        dataset = (dataset
                   .map(lambda x, y: tf.py_function(augment_image, [x, y], [tf.float32, tf.float32]))
                   # tf.py_function loses static shapes, so restore them explicitly
                   .map(lambda x, y: (tf.ensure_shape(x, (image_size, image_size, 3)),
                                      tf.ensure_shape(y, ()))))
    return dataset.cache().batch(batch).prefetch(AUTOTUNE)

# clean images are labelled 0, their Augraphy-augmented (dirty) copies are labelled 1;
# merge dirty and clean to create the final datasets
train_images = (make_dataset(train_files, 0, batch_size)
                .concatenate(make_dataset(train_files, 1, batch_size, augment=True)))
validate_images = (make_dataset(validate_files, 0, batch_size)
                   .concatenate(make_dataset(validate_files, 1, batch_size, augment=True)))
test_images = (make_dataset(test_files, 0, 1)
               .concatenate(make_dataset(test_files, 1, 1, augment=True)))
```
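Because the clean and dirty halves are batched separately before being concatenated, one pass over a dataset built from `n` source images takes `2 * ceil(n / batch_size)` steps, and each half can end with its own partial batch. The hypothetical helper below computes this. For instance, at batch size 1 it accounts for the 372 steps in the `model.predict` log further down (186 clean plus 186 augmented test images), and the 42 steps per training epoch are consistent with between 641 and 672 training images.

```python
import math

def steps_per_pass(n_images, batch_size):
    """Steps for one pass over clean + dirty halves that are batched separately, then concatenated."""
    return 2 * math.ceil(n_images / batch_size)

print(steps_per_pass(186, 1))   # 372
print(steps_per_pass(672, 32))  # 42
print(steps_per_pass(641, 32))  # 42
```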
Next, we visually inspect some of the images to make sure they are loaded and augmented correctly.
```python
# display some of the training images
plt.rcParams["figure.figsize"] = (20, 12)

def show_batch(images, title, max_images=15):
    # pixel values are floats in the 0-255 range, so scale them to 0-1 for imshow
    plt.figure()
    for i in range(min(max_images, images.shape[0])):
        plt.subplot(3, 5, i + 1)
        plt.imshow(images[i].numpy() / 255)
        plt.title(title)
    plt.show()

# clean: the clean batches come first in train_images
for images, labels in train_images.take(1):
    show_batch(images, "clean image")

# dirty: find the first batch labelled 1 (every batch holds a single class)
for images, labels in train_images:
    if labels[0] == 1:
        show_batch(images, "dirty image")
        break
```
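The dirty-image loop above only needs to check the first label of each batch because, with the clean and dirty datasets batched separately and then concatenated, every batch contains a single class. During training, this means the optimizer sees a run of all-clean batches followed by a run of all-dirty batches, which may contribute to the unstable first epochs in the training log. One common alternative, sketched here in NumPy with hypothetical function names, is to build a single shuffled list of (path, label) pairs before creating the `tf.data` pipeline, so that batches mix both classes. The resulting arrays could be passed to `tf.data.Dataset.from_tensor_slices`, but the augmentation would then have to be applied only to elements with label 1, which is not shown here.

```python
import numpy as np

def mixed_file_list(files, seed=42):
    """Pair every file with label 0 (clean) and label 1 (to be augmented), then shuffle both together."""
    paths = np.array(list(files) * 2)
    labels = np.array([0.0] * len(files) + [1.0] * len(files), dtype=np.float32)
    order = np.random.default_rng(seed).permutation(len(paths))
    return paths[order], labels[order]

def single_class_fraction(labels, batch_size):
    """Fraction of batches whose labels are all identical."""
    batches = [labels[i:i + batch_size] for i in range(0, len(labels), batch_size)]
    return sum(np.all(b == b[0]) for b in batches) / len(batches)

files = [f"page_{i}.png" for i in range(64)]
sequential = np.array([0.0] * 64 + [1.0] * 64)  # clean block, then dirty block
_, shuffled = mixed_file_list(files)
print(single_class_fraction(sequential, 32))    # 1.0
print(single_class_fraction(shuffled, 32))      # much lower than 1.0
```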
We use ResNet for transfer learning. In the TensorFlow framework, ResNet152 can be imported with `tf.keras.applications.ResNet152`. The pretrained layers are frozen by setting `layer.trainable = False`, and a new fully connected head is added on top and trained on the new data.
```python
# transfer learning using ResNet152 pretrained on ImageNet, without its classification top
pretrained_model = tf.keras.applications.ResNet152(weights="imagenet", include_top=False,
                                                   input_shape=(image_size, image_size, 3))
# freeze the pretrained layers so that only the new head is trained
for layer in pretrained_model.layers:
    layer.trainable = False

# new fully connected head with one sigmoid output: 0 = clean, 1 = dirty
x = layers.Flatten()(pretrained_model.output)
x = layers.Dense(1000, activation="relu")(x)
predictions = layers.Dense(1, activation="sigmoid")(x)
model = Model(inputs=pretrained_model.input, outputs=predictions)

optimizer = tf.keras.optimizers.Adam(learning_rate=0.00005)
model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])
```
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet152_weights_tf_dim_ordering_tf_kernels_notop.h5
234698864/234698864 [==============================] - 1s 0us/step
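The new head is much larger than it looks. Assuming the standard Keras ResNet152 layout (a stride-2 7x7 stem convolution after 3-pixel zero padding, a stride-2 3x3 max pool after 1-pixel padding, and stride-2 1x1 convolutions at the start of stages 3, 4 and 5), a 400x400 input yields a 13x13x2048 feature map, which `pretrained_model.output_shape` should confirm. `Flatten` therefore produces 346,112 features, and `Dense(1000)` alone holds about 346 million trainable weights. The arithmetic below works this out; the `GlobalAveragePooling2D` figure is included only for comparison as a common alternative head, not what this notebook uses.

```python
def conv_out(size, kernel, stride, pad=0):
    """Output length of a convolution or pooling layer along one spatial axis."""
    return (size + 2 * pad - kernel) // stride + 1

def resnet_feature_size(size):
    size = conv_out(size, 7, 2, pad=3)  # conv1: 7x7, stride 2, after 3-pixel zero padding
    size = conv_out(size, 3, 2, pad=1)  # max pool: 3x3, stride 2, after 1-pixel zero padding
    for _ in range(3):                  # conv3_x, conv4_x, conv5_x each start with a stride-2 1x1 conv
        size = conv_out(size, 1, 2)
    return size

def dense_params(n_in, n_out):
    return n_in * n_out + n_out         # weights + biases

side = resnet_feature_size(400)
flatten_features = side * side * 2048
flatten_head = dense_params(flatten_features, 1000) + dense_params(1000, 1)
gap_head = dense_params(2048, 1000) + dense_params(1000, 1)
print(side, flatten_features)  # 13 346112
print(flatten_head)            # 346114001
print(gap_head)                # 2050001
```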
Then, we create an `EarlyStopping` callback and add it to the training process. Early stopping helps prevent overfitting by monitoring the validation loss: with `patience=10`, training stops once the validation loss has gone 10 consecutive epochs without improving on its best value.
```python
# train model
callback = EarlyStopping(monitor="val_loss", patience=10)
train_history = model.fit(train_images, validation_data=validate_images, epochs=epochs,
                          callbacks=[callback], verbose=1)
```
Epoch 1/60
42/42 [==============================] - 351s 8s/step - loss: 46.1014 - accuracy: 0.8290 - val_loss: 14.3259 - val_accuracy: 0.5000
Epoch 2/60
42/42 [==============================] - 43s 1s/step - loss: 6.9163 - accuracy: 0.1588 - val_loss: 0.6928 - val_accuracy: 0.5034
Epoch 3/60
42/42 [==============================] - 45s 1s/step - loss: 0.6914 - accuracy: 0.5107 - val_loss: 0.6905 - val_accuracy: 0.5172
Epoch 4/60
42/42 [==============================] - 44s 1s/step - loss: 0.6821 - accuracy: 0.5389 - val_loss: 0.6735 - val_accuracy: 0.5638
Epoch 5/60
42/42 [==============================] - 52s 1s/step - loss: 0.6383 - accuracy: 0.6359 - val_loss: 0.6023 - val_accuracy: 0.7138
Epoch 6/60
42/42 [==============================] - 52s 1s/step - loss: 0.5016 - accuracy: 0.8733 - val_loss: 0.4955 - val_accuracy: 0.8793
Epoch 7/60
42/42 [==============================] - 45s 1s/step - loss: 0.4923 - accuracy: 0.8725 - val_loss: 0.4547 - val_accuracy: 0.8966
Epoch 8/60
42/42 [==============================] - 44s 1s/step - loss: 0.4336 - accuracy: 0.9145 - val_loss: 0.4436 - val_accuracy: 0.8793
Epoch 9/60
42/42 [==============================] - 45s 1s/step - loss: 0.4167 - accuracy: 0.9198 - val_loss: 0.4315 - val_accuracy: 0.8793
Epoch 10/60
42/42 [==============================] - 45s 1s/step - loss: 0.4030 - accuracy: 0.9305 - val_loss: 0.4236 - val_accuracy: 0.8914
Epoch 11/60
42/42 [==============================] - 44s 1s/step - loss: 0.3894 - accuracy: 0.9489 - val_loss: 0.4258 - val_accuracy: 0.8724
Epoch 12/60
42/42 [==============================] - 52s 1s/step - loss: 0.3854 - accuracy: 0.9489 - val_loss: 0.4202 - val_accuracy: 0.8879
Epoch 13/60
42/42 [==============================] - 43s 1s/step - loss: 0.3816 - accuracy: 0.9534 - val_loss: 0.4170 - val_accuracy: 0.8966
Epoch 14/60
42/42 [==============================] - 44s 1s/step - loss: 0.3767 - accuracy: 0.9603 - val_loss: 0.4173 - val_accuracy: 0.8879
Epoch 15/60
42/42 [==============================] - 44s 1s/step - loss: 0.3728 - accuracy: 0.9611 - val_loss: 0.4151 - val_accuracy: 0.8931
Epoch 16/60
42/42 [==============================] - 44s 1s/step - loss: 0.3703 - accuracy: 0.9649 - val_loss: 0.4127 - val_accuracy: 0.8931
Epoch 17/60
42/42 [==============================] - 43s 1s/step - loss: 0.3672 - accuracy: 0.9679 - val_loss: 0.4128 - val_accuracy: 0.8948
Epoch 18/60
42/42 [==============================] - 44s 1s/step - loss: 0.3637 - accuracy: 0.9725 - val_loss: 0.4142 - val_accuracy: 0.8879
Epoch 19/60
42/42 [==============================] - 51s 1s/step - loss: 0.3620 - accuracy: 0.9718 - val_loss: 0.4123 - val_accuracy: 0.8931
Epoch 20/60
42/42 [==============================] - 45s 1s/step - loss: 0.3607 - accuracy: 0.9740 - val_loss: 0.4088 - val_accuracy: 0.9000
Epoch 21/60
42/42 [==============================] - 45s 1s/step - loss: 0.3578 - accuracy: 0.9809 - val_loss: 0.4133 - val_accuracy: 0.8914
Epoch 22/60
42/42 [==============================] - 44s 1s/step - loss: 0.3555 - accuracy: 0.9840 - val_loss: 0.4190 - val_accuracy: 0.8707
Epoch 23/60
42/42 [==============================] - 45s 1s/step - loss: 0.3520 - accuracy: 0.9908 - val_loss: 0.4444 - val_accuracy: 0.7966
Epoch 24/60
42/42 [==============================] - 52s 1s/step - loss: 0.3563 - accuracy: 0.9748 - val_loss: 0.4128 - val_accuracy: 0.8828
Epoch 25/60
42/42 [==============================] - 43s 1s/step - loss: 0.3516 - accuracy: 0.9863 - val_loss: 0.4283 - val_accuracy: 0.8379
Epoch 26/60
42/42 [==============================] - 44s 1s/step - loss: 0.3505 - accuracy: 0.9878 - val_loss: 0.4294 - val_accuracy: 0.8466
Epoch 27/60
42/42 [==============================] - 51s 1s/step - loss: 0.3515 - accuracy: 0.9847 - val_loss: 0.4139 - val_accuracy: 0.8862
Epoch 28/60
42/42 [==============================] - 53s 1s/step - loss: 0.3515 - accuracy: 0.9855 - val_loss: 0.4072 - val_accuracy: 0.8983
Epoch 29/60
42/42 [==============================] - 52s 1s/step - loss: 0.3515 - accuracy: 0.9832 - val_loss: 0.4013 - val_accuracy: 0.9138
Epoch 30/60
42/42 [==============================] - 44s 1s/step - loss: 0.3502 - accuracy: 0.9832 - val_loss: 0.3981 - val_accuracy: 0.9190
Epoch 31/60
42/42 [==============================] - 45s 1s/step - loss: 0.3483 - accuracy: 0.9893 - val_loss: 0.3991 - val_accuracy: 0.9207
Epoch 32/60
42/42 [==============================] - 44s 1s/step - loss: 0.3456 - accuracy: 0.9931 - val_loss: 0.4128 - val_accuracy: 0.8897
Epoch 33/60
42/42 [==============================] - 45s 1s/step - loss: 0.3444 - accuracy: 0.9969 - val_loss: 0.4242 - val_accuracy: 0.8638
Epoch 34/60
42/42 [==============================] - 44s 1s/step - loss: 0.3433 - accuracy: 0.9962 - val_loss: 0.4407 - val_accuracy: 0.8276
Epoch 35/60
42/42 [==============================] - 44s 1s/step - loss: 0.3445 - accuracy: 0.9901 - val_loss: 0.4166 - val_accuracy: 0.8793
Epoch 36/60
42/42 [==============================] - 45s 1s/step - loss: 0.3434 - accuracy: 0.9924 - val_loss: 0.4197 - val_accuracy: 0.8724
Epoch 37/60
42/42 [==============================] - 52s 1s/step - loss: 0.3425 - accuracy: 0.9954 - val_loss: 0.4200 - val_accuracy: 0.8690
Epoch 38/60
42/42 [==============================] - 44s 1s/step - loss: 0.3423 - accuracy: 0.9947 - val_loss: 0.4150 - val_accuracy: 0.8897
Epoch 39/60
42/42 [==============================] - 44s 1s/step - loss: 0.3414 - accuracy: 0.9985 - val_loss: 0.4184 - val_accuracy: 0.8724
Epoch 40/60
42/42 [==============================] - 44s 1s/step - loss: 0.3412 - accuracy: 0.9977 - val_loss: 0.4118 - val_accuracy: 0.8897
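The stopping point can be checked by replaying the patience rule on the `val_loss` values printed in the log: the best value, 0.3981, occurs at epoch 30, and none of the next ten epochs beats it, so training ends after epoch 40 instead of running all 60. The function below is a simplified re-implementation of that rule (default `min_delta=0`, lower is better), not the Keras source. Note also that `EarlyStopping` defaults to `restore_best_weights=False`, so the model used afterwards holds the epoch-40 weights rather than the epoch-30 ones; passing `restore_best_weights=True` would change that.

```python
def early_stop_epoch(val_losses, patience, min_delta=0.0):
    """Replay a patience-based early-stopping rule on a list of validation losses.

    Returns (1-based epoch at which training stops, or None if it never stops; best loss).
    """
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:  # improvement: remember it and reset the counter
            best, wait = loss, 0
        else:                        # no improvement: count one more epoch
            wait += 1
            if wait >= patience:
                return epoch, best
    return None, best

# val_loss values as printed (rounded) in the training log, epochs 1-40
val_losses = [14.3259, 0.6928, 0.6905, 0.6735, 0.6023, 0.4955, 0.4547, 0.4436, 0.4315, 0.4236,
              0.4258, 0.4202, 0.4170, 0.4173, 0.4151, 0.4127, 0.4128, 0.4142, 0.4123, 0.4088,
              0.4133, 0.4190, 0.4444, 0.4128, 0.4283, 0.4294, 0.4139, 0.4072, 0.4013, 0.3981,
              0.3991, 0.4128, 0.4242, 0.4407, 0.4166, 0.4197, 0.4200, 0.4150, 0.4184, 0.4118]
print(early_stop_epoch(val_losses, patience=10))  # (40, 0.3981)
```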
Once training is done, the training and validation loss curves are plotted.
```python
# plot training curves
plt.rcParams["figure.figsize"] = (10, 10)
train_losses = train_history.history["loss"]
validate_losses = train_history.history["val_loss"]
epochs_run = range(1, len(train_losses) + 1)

plt.figure()
plt.grid()
plt.plot(epochs_run, train_losses, "red", label="train loss")
plt.plot(epochs_run, validate_losses, "blue", label="validate loss")
plt.title("Training and validation losses")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend(loc="upper right")
plt.ylim(0, 1)
plt.show()
```
The last step is to test the trained model. The decision threshold is 0.5: a score of 0.5 or higher is classified as a dirty image, and a score below 0.5 as a clean image. The validation accuracy in the training log peaks at about 0.92 (epoch 31), and the sample predictions below give a qualitative check of the model on the test set.
```python
# predict testing images
predicted_labels = model.predict(test_images)
```
372/372 [==============================] - 59s 160ms/step
```python
# display predicted results
plt.rcParams["figure.figsize"] = (20, 5)

def show_predictions(true_label, title, n_images=5):
    # show the first n_images test images with the given true label (0 = clean, 1 = dirty);
    # test_images is cached, so it is iterated in the same order as during model.predict
    fig = plt.figure()
    fig.suptitle(title)
    n_shown = 0
    for i, (image, label) in enumerate(test_images):
        if label[0] != true_label:
            continue
        n_shown += 1
        plt.subplot(1, n_images, n_shown)
        plt.imshow(image[0].numpy() / 255)
        predicted_label = "Clean" if predicted_labels[i][0] < 0.5 else "Dirty"
        plt.title("Predicted: " + predicted_label)
        if n_shown >= n_images:
            break
    plt.show()

# classify clean images
show_predictions(0, "Classify Test Set Clean Images")

# classify dirty images
show_predictions(1, "Classify Test Set Dirty Images")
```
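The figures above are only a qualitative check. A quantitative test accuracy could be computed from `predicted_labels` and the true labels, which in the notebook could be gathered with something like `np.concatenate([y.numpy() for _, y in test_images])`. The NumPy sketch below applies the same 0.5 threshold (scores of 0.5 or higher count as dirty) and reports accuracy and confusion counts; the function name and the toy scores are illustrative only.

```python
import numpy as np

def evaluate(scores, labels, threshold=0.5):
    """Threshold sigmoid scores (>= threshold -> dirty = 1) and compare them with 0/1 labels."""
    preds = (np.asarray(scores).ravel() >= threshold).astype(int)
    labels = np.asarray(labels).ravel().astype(int)
    return {
        "accuracy": float(np.mean(preds == labels)),
        "clean_as_clean": int(np.sum((preds == 0) & (labels == 0))),
        "clean_as_dirty": int(np.sum((preds == 1) & (labels == 0))),
        "dirty_as_clean": int(np.sum((preds == 0) & (labels == 1))),
        "dirty_as_dirty": int(np.sum((preds == 1) & (labels == 1))),
    }

# toy example: model.predict returns an (n, 1) array of sigmoid scores
scores = np.array([[0.1], [0.7], [0.4], [0.9]])
labels = np.array([0.0, 0.0, 1.0, 1.0])
print(evaluate(scores, labels))  # accuracy 0.5, one image in each confusion cell
```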