Using Augraphy in the augmentation process of a PyTorch deep learning pipeline
In this example, a classifier is trained with the PyTorch deep learning framework to distinguish clean document images from dirty ones.
Augraphy is used to augment the clean images and generate the set of dirty images.
Transfer learning is applied to a ResNet model, which is fine-tuned to classify the images as “dirty” or “clean”.
The image dataset is extracted from a Kaggle document denoising competition: Denoising ShabbyPages.
The notebook for this example can be downloaded at this link.
The first step is to download the image dataset. We use gdown to download the data.
[ ]:
# download and unzip document image data
# (recent gdown versions accept the file ID directly and no longer need --id)
!gdown 1uJPavzL7K3FFr9MEfZbdX3SNa1bGEdPu
!unzip shabby_small.zip
Next, we install the latest version of Augraphy from its GitHub repository.
[ ]:
# Install Augraphy, the main image augmentation library
!pip install git+https://github.com/sparkfish/augraphy
Then, we import the general-purpose and PyTorch-related libraries.
[ ]:
# import libraries
import sys
import glob
import cv2
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasets, models
from torch.nn.modules.loss import BCEWithLogitsLoss
from torch.optim import lr_scheduler
from torch.autograd import Variable
from matplotlib import pyplot as plt
from time import time
The next step is to create an Augraphy augmentation pipeline. The statement from augraphy import * imports all the necessary functions and modules.
An Augraphy pipeline consists of 3 phases (ink, paper and post), and AugraphyPipeline(ink_phase, paper_phase, post_phase) initializes the pipeline instance.
Each of ink_phase, paper_phase and post_phase is a list that contains the augmentation instances for that phase. In this example, we use a simple pipeline in which every augmentation keeps its default parameters apart from its probability p, and OneOf applies one of the two low-ink augmentations.
[ ]:
# create an Augraphy augmentation pipeline
from augraphy import *
ink_phase = [
    Dithering(p=0.5),
    InkBleed(p=0.5),
    OneOf([LowInkRandomLines(p=1), LowInkPeriodicLines(p=1)]),
]
paper_phase = [ColorPaper(p=0.5)]
post_phase = [
    Markup(p=0.25),
    DirtyRollers(p=0.25),
    Scribbles(p=0.25),
    BindingsAndFasteners(p=0.25),
    BadPhotoCopy(p=0.25),
    DirtyDrum(p=0.25),
]
augmentation_pipeline = AugraphyPipeline(ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase)
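The phase structure described above can be illustrated with a small, self-contained model. This is a conceptual sketch, not Augraphy's implementation: FakeAugmentation, FakeOneOf, tag and run_pipeline are hypothetical stand-ins that record which augmentations ran instead of editing pixels. The sketch assumes that each augmentation fires independently with probability p and that the one-of step picks a member uniformly at random; Augraphy's own OneOf may weight its choice by the members' probabilities.

```python
import random
from dataclasses import dataclass
from typing import Callable, List

# Conceptual model only: the "image" is a list recording which augmentations ran.


@dataclass
class FakeAugmentation:
    """Hypothetical stand-in for an augmentation that fires with probability p."""
    name: str
    fn: Callable[[list], list]
    p: float = 1.0

    def __call__(self, image: list, rng: random.Random) -> list:
        # rng.random() is in [0, 1), so p=1 always fires and p=0 never does
        return self.fn(image) if rng.random() < self.p else image


@dataclass
class FakeOneOf:
    """Hypothetical stand-in: pick one member uniformly at random and apply it."""
    members: List[FakeAugmentation]

    def __call__(self, image: list, rng: random.Random) -> list:
        return rng.choice(self.members)(image, rng)


def tag(name: str, p: float) -> FakeAugmentation:
    """Build a fake augmentation that appends its own name to the record."""
    return FakeAugmentation(name, lambda img: img + [name], p)


def run_pipeline(image, ink_phase, paper_phase, post_phase, seed=0):
    """Apply the phases in order: ink, then paper, then post."""
    rng = random.Random(seed)
    for phase in (ink_phase, paper_phase, post_phase):
        for step in phase:
            image = step(image, rng)
    return image


demo_ink_phase = [
    tag("Dithering", p=1.0),
    tag("InkBleed", p=0.0),
    FakeOneOf([tag("LowInkRandomLines", p=1.0), tag("LowInkPeriodicLines", p=1.0)]),
]
demo_paper_phase = [tag("ColorPaper", p=1.0)]
demo_post_phase = [tag("Markup", p=0.0)]

result = run_pipeline([], demo_ink_phase, demo_paper_phase, demo_post_phase)
# "Dithering", then exactly one of the two low-ink entries, then "ColorPaper"
print(result)
```

Setting p to 1.0 or 0.0 makes the outcome of each step predictable, so only the one-of choice varies with the seed.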
Each set of images needs its own transformation object. datasets.ImageFolder loads images in PIL format, while the rest of this pipeline works on numpy images in BGR channel order. For clean images, To_BGR converts the input image from a PIL image into a numpy BGR image, and To_Tensor converts that numpy image into a PyTorch tensor.
For dirty images, augmentation_pipeline is inserted between To_BGR and To_Tensor to augment the image and generate the dirty effect.
[ ]:
# create augmentation objects
class To_Tensor(object):
    """Convert a numpy image (H x W x C) into a float tensor with values in [0, 1]."""
    def __call__(self, image):
        # make sure the image has 3 channels
        if len(image.shape) < 3:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        # give the array the C x H x W shape that torch expects; reshape keeps the
        # H x W x C memory order instead of permuting axes, which is why the display
        # cells turn the tensors back into images with reshape(image_size, image_size, 3)
        image = image.reshape(3, image.shape[0], image.shape[1])
        return torch.from_numpy(image.astype("float64") / 255)
class To_BGR(object):
    """Convert a PIL RGB image into a numpy BGR image."""
    def __call__(self, image):
        image_numpy = np.array(image)
        # single-channel input: replicate it into 3 BGR channels
        if len(image_numpy.shape) < 3:
            return cv2.cvtColor(image_numpy, cv2.COLOR_GRAY2BGR)
        # RGB input: reverse the channel order
        return cv2.cvtColor(image_numpy, cv2.COLOR_RGB2BGR)
# augmentations for clean and dirty images
dirty_transforms = transforms.Compose([To_BGR(), augmentation_pipeline, To_Tensor()])
clean_transforms = transforms.Compose([To_BGR(), To_Tensor()])
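The conversions performed by To_BGR, plus the scaling in To_Tensor, are simple array operations that can be checked with numpy alone. The sketch below assumes an 8-bit H x W x 3 RGB array, as produced by np.array on a PIL RGB image: converting RGB to BGR reverses the last axis, expanding grayscale to BGR copies the single channel three times (which is what cv2.COLOR_RGB2BGR and cv2.COLOR_GRAY2BGR do), and dividing by 255 maps 8-bit values into [0, 1]. The helper names are hypothetical.

```python
import numpy as np


def rgb_to_bgr(image_rgb: np.ndarray) -> np.ndarray:
    """Reverse the channel axis of an H x W x 3 array (RGB -> BGR)."""
    return image_rgb[..., ::-1].copy()


def gray_to_bgr(image_gray: np.ndarray) -> np.ndarray:
    """Replicate an H x W grayscale array into 3 identical channels."""
    return np.stack([image_gray] * 3, axis=-1)


def to_unit_range(image: np.ndarray) -> np.ndarray:
    """Scale 8-bit pixel values into [0, 1], as To_Tensor does."""
    return image.astype("float64") / 255


# a tiny 1 x 2 RGB image: one pure red pixel and one pure blue pixel
rgb = np.array([[[255, 0, 0], [0, 0, 255]]], dtype=np.uint8)
bgr = rgb_to_bgr(rgb)
print(bgr[0, 0].tolist(), bgr[0, 1].tolist())  # [0, 0, 255] [255, 0, 0]

gray = np.array([[0, 128]], dtype=np.uint8)
print(gray_to_bgr(gray).shape)  # (1, 2, 3)
print(to_unit_range(rgb).max())  # 1.0
```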
Before training, we need to define some training parameters such as batch_size, epochs, shuffle and image_size. The effective batch size is twice batch_size, because each iteration loads one batch of clean images and one batch of dirty images and merges them before passing them to the training step.
[ ]:
# training parameters
# actual batch size will be x2 of this, due to clean+dirty images
batch_size = 16
epochs = 60
shuffle = True
image_size = 400
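The merging of clean and dirty batches described above can be sketched without PyTorch. Assuming each loader yields (images, labels) pairs, the training loop zips the dirty and clean loaders, labels dirty images 1 and clean images 0, and concatenates them, so each step sees up to 2 x batch_size samples. Because zip stops at the shorter iterable, the number of steps per epoch is the smaller of the two loader lengths. fake_loader is a hypothetical helper that only produces arrays of the right shape.

```python
import numpy as np


def fake_loader(num_images: int, batch_size: int, image_shape=(3, 4, 4)):
    """Hypothetical stand-in for a DataLoader: yields (images, labels) batches."""
    for start in range(0, num_images, batch_size):
        n = min(batch_size, num_images - start)
        # the ImageFolder labels are ignored by the training loop, so zeros are fine here
        yield np.zeros((n, *image_shape)), np.zeros(n, dtype=int)


demo_batch_size = 16
dirty_loader = fake_loader(num_images=40, batch_size=demo_batch_size)
clean_loader = fake_loader(num_images=64, batch_size=demo_batch_size)

steps = []
for noisy_samples, clean_samples in zip(dirty_loader, clean_loader):
    noisy_data, clean_data = noisy_samples[0], clean_samples[0]
    # dirty images are labelled 1, clean images 0
    labels = np.concatenate([np.ones(len(noisy_data)), np.zeros(len(clean_data))])
    all_data = np.concatenate([noisy_data, clean_data], axis=0)
    # record (combined batch size, number of dirty samples)
    steps.append((all_data.shape[0], int(labels.sum())))

print(steps)  # [(32, 16), (32, 16), (24, 8)]
```

With these illustrative folder sizes, the last step is unbalanced (8 dirty and 16 clean images) and the remaining clean batch is never used, which is worth keeping in mind when the two folders hold different numbers of images.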
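datasets.ImageFolder loads the images from each folder and creates a dataset. ImageFolder expects the images to sit in subdirectories of the given path (one per class); the class labels it returns are not used here, because the clean and dirty labels are assigned in the training loop. Dirty images use the dirty_transforms transform object and clean images use clean_transforms.
torch.utils.data.DataLoader then wraps each dataset in a data loader for training, validation and testing.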
[ ]:
# dataset paths
train_dirty_path = "/content/shabby_small/train/train_dirty/"
train_clean_path = "/content/shabby_small/train/train_clean/"
validate_dirty_path = "/content/shabby_small/validate/validate_dirty/"
validate_clean_path = "/content/shabby_small/validate/validate_clean/"
test_dirty_path = "/content/shabby_small/test/test_dirty/"
test_clean_path = "/content/shabby_small/test/test_clean/"
# create datasets
train_dirty_data = datasets.ImageFolder(train_dirty_path,transform=dirty_transforms)
train_clean_data = datasets.ImageFolder(train_clean_path,transform=clean_transforms)
validate_dirty_data = datasets.ImageFolder(validate_dirty_path,transform=dirty_transforms)
validate_clean_data = datasets.ImageFolder(validate_clean_path,transform=clean_transforms)
test_dirty_data = datasets.ImageFolder(test_dirty_path,transform=dirty_transforms)
test_clean_data = datasets.ImageFolder(test_clean_path,transform=clean_transforms)
# create dataloader
train_dirty_loader = torch.utils.data.DataLoader(train_dirty_data, shuffle = shuffle, batch_size=batch_size)
train_clean_loader = torch.utils.data.DataLoader(train_clean_data, shuffle = shuffle, batch_size=batch_size)
validate_dirty_loader = torch.utils.data.DataLoader(validate_dirty_data, shuffle = shuffle, batch_size=batch_size)
validate_clean_loader = torch.utils.data.DataLoader(validate_clean_data, shuffle = shuffle, batch_size=batch_size)
test_dirty_loader = torch.utils.data.DataLoader(test_dirty_data, shuffle = shuffle, batch_size=batch_size)
test_clean_loader = torch.utils.data.DataLoader(test_clean_data, shuffle = shuffle, batch_size=batch_size)
Next, we check by visual inspection that the images are loaded and augmented correctly.
[ ]:
# Display some of the clean and dirty images
plt.rcParams["figure.figsize"] = (20,12)
# clean
for data in train_clean_loader:
    plt.figure()
    for i in range(15):
        plt.subplot(3,5,i+1)
        plt.imshow(data[0][i].reshape(image_size, image_size, 3))
        plt.title("clean")
    break
# dirty
for data in train_dirty_loader:
    plt.figure()
    for i in range(15):
        plt.subplot(3,5,i+1)
        plt.imshow(data[0][i].reshape(image_size, image_size, 3))
        plt.title("dirty")
    break
We use ResNet for transfer learning. The torchvision models module provides models.resnet50, which can be loaded with ImageNet-pretrained weights. We then freeze the pretrained parameters and replace the final fully connected layer (model.fc) with a new layer that has a single output, so that only this new layer is trained. The loss function is BCEWithLogitsLoss, which combines a sigmoid layer and binary cross-entropy in a single, numerically stable step; the model therefore outputs raw scores (logits) for this binary classification problem.
[ ]:
# Define training and model parameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# define model: ResNet-50 with ImageNet-pretrained weights
# (the weights argument needs torchvision >= 0.13; older versions use pretrained=True)
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# freeze all pretrained parameters (requires_grad is an attribute; requires_grad_ is a method)
for params in model.parameters():
    params.requires_grad = False
# replace the final layer with a new, trainable layer that outputs a single logit
model.fc = nn.Linear(model.fc.in_features, 1)
model = model.to(device)
# define optimizer and loss function
criterion = BCEWithLogitsLoss()  # applies the sigmoid internally
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 362MB/s]
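As a numerical check of the BCEWithLogitsLoss choice above: for a logit x and a target y in {0, 1}, binary cross-entropy on the sigmoid output is -[y log σ(x) + (1 - y) log(1 - σ(x))], which can be rewritten in the overflow-safe form max(x, 0) - x y + log(1 + e^(-|x|)). The numpy sketch compares both forms, averaged over the batch to match the loss's default mean reduction; it illustrates the formula, not PyTorch's internal code.

```python
import numpy as np


def bce_from_probabilities(logits, targets):
    """Naive form: apply the sigmoid first, then binary cross-entropy (batch mean)."""
    p = 1.0 / (1.0 + np.exp(-logits))
    return np.mean(-(targets * np.log(p) + (1 - targets) * np.log(1 - p)))


def bce_with_logits(logits, targets):
    """Overflow-safe form that works directly on logits (batch mean)."""
    return np.mean(np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits))))


logits = np.array([2.0, -1.0, 0.5, -3.0])
targets = np.array([1.0, 0.0, 1.0, 0.0])  # 1 = dirty, 0 = clean
print(np.isclose(bce_from_probabilities(logits, targets), bce_with_logits(logits, targets)))  # True

# for a large, wrong-signed logit the naive form would compute log(0), because the
# sigmoid rounds to exactly 1.0; the logit form stays finite
big = np.array([40.0])
print(bce_with_logits(big, np.array([0.0])))  # 40.0
```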
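The training loop includes early stopping and an evaluation on the validation data after every epoch. In both the training and validation loops, each batch of dirty images is merged with a batch of clean images before being passed to the model, so that each combined batch contains the same number of clean and dirty images (the last batch can differ if the folder sizes are not multiples of batch_size). Training stops early once the validation loss has failed to improve on its best value for more than patient_threshold consecutive epochs. Note that the accumulated loss is divided by the total number of batches in both loaders, while each iteration processes one batch from each; when the two loaders have the same length, the logged losses are therefore half of the mean per-iteration loss.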
[ ]:
# training iterations
train_losses = []
validate_losses = []
patient_count = 0
patient_threshold = 10
best_validate_loss = 100
for epoch in range(epochs):
    train_loss = 0
    start_time = time()
    # reset to training mode
    model.train()
    for batch_number, (noisy_samples, clean_samples) in enumerate(zip(train_dirty_loader, train_clean_loader)):
        noisy_data = noisy_samples[0].float()
        noisy_label = torch.ones(len(noisy_data))  # dirty = 1
        clean_data = clean_samples[0].float()
        clean_label = torch.zeros(len(clean_data))  # clean = 0
        # combine clean and dirty images
        all_data = torch.cat([noisy_data, clean_data], dim=0).to(device)
        all_labels = torch.cat([noisy_label, clean_label]).reshape(all_data.shape[0], 1).to(device)
        optimizer.zero_grad()
        # forward pass
        predicted_labels = model(all_data)
        # compute loss
        loss = criterion(predicted_labels, all_labels)
        train_loss += loss.item()
        # backward pass
        loss.backward()
        # update parameters
        optimizer.step()
    # append each epoch training loss
    average_train_loss = train_loss / (len(train_dirty_loader) + len(train_clean_loader))
    train_losses.append(average_train_loss)
    # evaluate model
    model.eval()
    with torch.no_grad():
        validate_loss = 0
        for batch_number, (noisy_samples, clean_samples) in enumerate(zip(validate_dirty_loader, validate_clean_loader)):
            noisy_data = noisy_samples[0].float()
            noisy_label = torch.ones(len(noisy_data))  # dirty = 1
            clean_data = clean_samples[0].float()
            clean_label = torch.zeros(len(clean_data))  # clean = 0
            # combine clean and dirty images
            all_data = torch.cat([noisy_data, clean_data], dim=0).to(device)
            all_labels = torch.cat([noisy_label, clean_label]).reshape(all_data.shape[0], 1).to(device)
            # forward pass
            predicted_labels = model(all_data)
            # compute loss
            loss = criterion(predicted_labels, all_labels)
            validate_loss += loss.item()
    # append each epoch validate loss
    average_validate_loss = validate_loss / (len(validate_dirty_loader) + len(validate_clean_loader))
    validate_losses.append(average_validate_loss)
    # early stopping
    if average_validate_loss <= best_validate_loss:
        best_validate_loss = average_validate_loss
        patient_count = 0
    else:
        patient_count += 1
        if patient_count > patient_threshold:
            break
    print("Epoch = "+str(epoch)+", train_loss = "+str(average_train_loss)+", validate_loss = "+str(average_validate_loss)+", training duration = "+str(round(time()-start_time))+" seconds")
Epoch = 0, train_loss = 0.16265885250234022, validate_loss = 0.09095224934188943, training duration = 292 seconds
Epoch = 1, train_loss = 0.07528629535582007, validate_loss = 0.06522136818813651, training duration = 265 seconds
Epoch = 2, train_loss = 0.05403669600988307, validate_loss = 0.060843666133127715, training duration = 241 seconds
Epoch = 3, train_loss = 0.048908831642531764, validate_loss = 0.06135064368381312, training duration = 287 seconds
Epoch = 4, train_loss = 0.04750860836811182, validate_loss = 0.050219189081537094, training duration = 250 seconds
Epoch = 5, train_loss = 0.04251422401426769, validate_loss = 0.04269793512005555, training duration = 273 seconds
Epoch = 6, train_loss = 0.03895130033445794, validate_loss = 0.04299438656552842, training duration = 255 seconds
Epoch = 7, train_loss = 0.03376865359704669, validate_loss = 0.040767080346612555, training duration = 274 seconds
Epoch = 8, train_loss = 0.037427358352011296, validate_loss = 0.03735993959401783, training duration = 295 seconds
Epoch = 9, train_loss = 0.027450435317871048, validate_loss = 0.0340286838380914, training duration = 267 seconds
Epoch = 10, train_loss = 0.030007415280745523, validate_loss = 0.04129516693616384, training duration = 278 seconds
Epoch = 11, train_loss = 0.026311413248682896, validate_loss = 0.04192912134979116, training duration = 286 seconds
Epoch = 12, train_loss = 0.026286867640276507, validate_loss = 0.043772940352363024, training duration = 275 seconds
Epoch = 13, train_loss = 0.02571372339138534, validate_loss = 0.03408439231938437, training duration = 263 seconds
Epoch = 14, train_loss = 0.025083887924599212, validate_loss = 0.031926931997172926, training duration = 272 seconds
Epoch = 15, train_loss = 0.027622007015274792, validate_loss = 0.043157988178886865, training duration = 250 seconds
Epoch = 16, train_loss = 0.02393035617338993, validate_loss = 0.04161530176765824, training duration = 281 seconds
Epoch = 17, train_loss = 0.022220263144046796, validate_loss = 0.038027058355510235, training duration = 248 seconds
Epoch = 18, train_loss = 0.028428362529134242, validate_loss = 0.028791493098986775, training duration = 251 seconds
Epoch = 19, train_loss = 0.022880588873948265, validate_loss = 0.041628423686090266, training duration = 262 seconds
Epoch = 20, train_loss = 0.024806299792012064, validate_loss = 0.03297404432669282, training duration = 272 seconds
Epoch = 21, train_loss = 0.023054723957235495, validate_loss = 0.038020130222369185, training duration = 282 seconds
Epoch = 22, train_loss = 0.024718658873675074, validate_loss = 0.03629387138215335, training duration = 239 seconds
Epoch = 23, train_loss = 0.019438672135016176, validate_loss = 0.03745270794943759, training duration = 278 seconds
Epoch = 24, train_loss = 0.021008325894022496, validate_loss = 0.029737368011602053, training duration = 280 seconds
Epoch = 25, train_loss = 0.022721592334053683, validate_loss = 0.04427085086507233, training duration = 266 seconds
Epoch = 26, train_loss = 0.021167143543319004, validate_loss = 0.0402586883620212, training duration = 270 seconds
Epoch = 27, train_loss = 0.02301945056334683, validate_loss = 0.0231909407792907, training duration = 275 seconds
Epoch = 28, train_loss = 0.021348601735265153, validate_loss = 0.04714697344522727, training duration = 267 seconds
Epoch = 29, train_loss = 0.017056679160038873, validate_loss = 0.0243394438018042, training duration = 286 seconds
Epoch = 30, train_loss = 0.02305366937355025, validate_loss = 0.026760961308977323, training duration = 278 seconds
Epoch = 31, train_loss = 0.01824065751293901, validate_loss = 0.02530191934324409, training duration = 279 seconds
Epoch = 32, train_loss = 0.01787017357376654, validate_loss = 0.03044675217059098, training duration = 279 seconds
Epoch = 33, train_loss = 0.016991370946748137, validate_loss = 0.03152604773027921, training duration = 261 seconds
Epoch = 34, train_loss = 0.018007332896359446, validate_loss = 0.03396627848575774, training duration = 269 seconds
Epoch = 35, train_loss = 0.014609841947875372, validate_loss = 0.027170873455409156, training duration = 259 seconds
Epoch = 36, train_loss = 0.019973909323353592, validate_loss = 0.03745627410611824, training duration = 255 seconds
Epoch = 37, train_loss = 0.017210176370742663, validate_loss = 0.026869589315825386, training duration = 276 seconds
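The early-stopping rule from the training loop can be replayed on the validation losses in the log above. The sketch reimplements the same rule as a plain function (reset the counter when the loss is less than or equal to the best so far, otherwise increment it, and stop once it exceeds the threshold); the function name is hypothetical, and the losses are the logged values rounded to four decimal places.

```python
def replay_early_stopping(losses, patience_threshold=10, initial_best=100.0):
    """Return (best_epoch, best_loss, patience_count, stopped_at) for the loop's rule.

    stopped_at is the epoch at which the loop would break, or None if it never does.
    """
    best_loss, best_epoch, patience_count = initial_best, None, 0
    for epoch, loss in enumerate(losses):
        if loss <= best_loss:
            best_loss, best_epoch, patience_count = loss, epoch, 0
        else:
            patience_count += 1
            if patience_count > patience_threshold:
                return best_epoch, best_loss, patience_count, epoch
    return best_epoch, best_loss, patience_count, None


# validate_loss values for epochs 0-37 from the log, rounded to 4 decimals
validate_losses_from_log = [
    0.0910, 0.0652, 0.0608, 0.0614, 0.0502, 0.0427, 0.0430, 0.0408, 0.0374, 0.0340,
    0.0413, 0.0419, 0.0438, 0.0341, 0.0319, 0.0432, 0.0416, 0.0380, 0.0288, 0.0416,
    0.0330, 0.0380, 0.0363, 0.0375, 0.0297, 0.0443, 0.0403, 0.0232, 0.0471, 0.0243,
    0.0268, 0.0253, 0.0304, 0.0315, 0.0340, 0.0272, 0.0375, 0.0269,
]
print(replay_early_stopping(validate_losses_from_log))  # (27, 0.0232, 10, None)

# a hypothetical further epoch without improvement pushes the counter past the threshold
print(replay_early_stopping(validate_losses_from_log + [0.03])[3])  # 38
```

As written, the training loop keeps the weights from the last epoch it ran rather than from the best epoch; copying model.state_dict() whenever the validation loss improves, and loading it back after the loop, is a common way to keep the best weights.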
Once training is done, the loss curve is plotted. The lowest validation loss in the log, about 0.023, was reached at epoch 27.
[ ]:
# Display loss curve
x_train = [n for n in range(len(train_losses))]
x_validate = [n for n in range(len(validate_losses))]
plt.rcParams["figure.figsize"] = (10,10)
plt.figure()
plt.grid()
plt.plot(x_train, train_losses, color="red", label="Train")
plt.plot(x_validate, validate_losses, color="blue", label="Validate")
plt.title("Loss Curve")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend(loc="upper right")
plt.ylim(0, 0.2)
The last step is to test the trained model on the test set. The model outputs a logit, so a sigmoid turns it into a score between 0 and 1, and the decision threshold is 0.5: any score above 0.5 is classified as a dirty image, and any other score as a clean image. From the results, the model appears to perform well.
[ ]:
# Test with test set
plt.rcParams["figure.figsize"] = (20,5)
# Classify clean images
model.eval()
with torch.no_grad():
    for clean_samples in test_clean_loader:
        clean_data = clean_samples[0].float().to(device)
        # the model outputs logits; the sigmoid turns them into scores in [0, 1]
        predicted_clean_scores = torch.sigmoid(model(clean_data)).cpu()
        # Display predictions
        total_images = 5
        fig = plt.figure()
        fig.suptitle("Classify Test Set Clean Images")
        for i in range(total_images):
            plt.subplot(1, total_images, i+1)
            plt.imshow(clean_samples[0][i].reshape(image_size, image_size, 3))
            if predicted_clean_scores[i][0] > 0.5:
                predicted_label = "Dirty"
            else:
                predicted_label = "Clean"
            plt.title("Predicted: " + predicted_label)
        break
[ ]:
# Classify dirty images
plt.rcParams["figure.figsize"] = (20,5)
model.eval()
with torch.no_grad():
    for noisy_samples in test_dirty_loader:
        noisy_data = noisy_samples[0].float().to(device)
        # the model outputs logits; the sigmoid turns them into scores in [0, 1]
        predicted_noisy_scores = torch.sigmoid(model(noisy_data)).cpu()
        # Display predictions
        total_images = 5
        fig = plt.figure()
        fig.suptitle("Classify Test Set Dirty Images")
        for i in range(total_images):
            plt.subplot(1, total_images, i+1)
            plt.imshow(noisy_samples[0][i].reshape(image_size, image_size, 3))
            if predicted_noisy_scores[i][0] > 0.5:
                predicted_label = "Dirty"
            else:
                predicted_label = "Clean"
            plt.title("Predicted: " + predicted_label)
        break
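Beyond the five sample images per class shown above, a single number for the whole test set can be obtained by thresholding every prediction and counting matches. The numpy sketch below assumes you have collected the raw model outputs (logits) and the 1/0 labels into arrays, for example by looping over test_dirty_loader and test_clean_loader; accuracy_from_logits is a hypothetical helper and the values are illustrative, not results from the trained model. Since σ(x) > 0.5 exactly when x > 0, thresholding the sigmoid score at 0.5 is the same as thresholding the logit at 0.

```python
import numpy as np


def accuracy_from_logits(logits, labels, threshold=0.5):
    """Fraction of predictions that match the labels (1 = dirty, 0 = clean)."""
    probabilities = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=float)))
    predictions = (probabilities > threshold).astype(int)
    return float(np.mean(predictions == np.asarray(labels)))


# illustrative values only, not outputs of the model trained above
example_logits = np.array([3.1, 0.2, -0.4, -2.5, 1.7, -0.1])
example_labels = np.array([1, 1, 0, 0, 0, 0])
print(accuracy_from_logits(example_logits, example_labels))  # 0.8333333333333334
```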