Using Augraphy for augmentation in a PyTorch deep learning pipeline

In this example, a classifier is trained with the PyTorch deep learning framework to distinguish clean document images from dirty ones.

Augraphy is used to augment the clean images and generate the dirty set of images.

Transfer learning is applied to a ResNet model, which is fine-tuned to classify the images as “dirty” or “clean”.

The image dataset is taken from the Kaggle document denoising competition Denoising ShabbyPages.

The notebook for this example can be downloaded at this link.

  1. The first step is to download the image dataset. We use gdown to download the data.

# download and unzip document image data

!gdown 1uJPavzL7K3FFr9MEfZbdX3SNa1bGEdPu
!unzip shabby_small.zip
  1. Next, we install the latest version of Augraphy from its GitHub repository.

# install Augraphy, the main image augmentation library

!pip install git+https://github.com/sparkfish/augraphy
  1. Then, we import the general-purpose and PyTorch-related libraries.

# import libraries

import cv2
import numpy as np

import torch
import torch.nn as nn
from torchvision import transforms, datasets, models
from torch.nn.modules.loss import BCEWithLogitsLoss

from matplotlib import pyplot as plt
from time import time

  1. The next step is to create an Augraphy augmentation pipeline. The statement from augraphy import * imports all the necessary classes and functions.

An Augraphy pipeline consists of three phases (ink, paper and post), and AugraphyPipeline(ink_phase, paper_phase, post_phase) initializes the pipeline instance.

Each of ink_phase, paper_phase and post_phase is a list containing the augmentation instances for that phase. In this example, we use a simple pipeline in which every augmentation keeps its default parameters, apart from the probability p that it is applied.

# create an Augraphy augmentation pipeline

from augraphy import *

ink_phase = [Dithering(p=0.5),
             InkBleed(p=0.5),
             OneOf([LowInkRandomLines(p=1), LowInkPeriodicLines(p=1)]),
            ]

paper_phase = [ColorPaper(p=0.5)]

post_phase = [Markup(p=0.25),
              DirtyRollers(p=0.25),
              Scribbles(p=0.25),
              BindingsAndFasteners(p=0.25),
              BadPhotoCopy(p=0.25),
              DirtyDrum(p=0.25),
              ]

augmentation_pipeline = AugraphyPipeline(ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase)
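Each augmentation's p parameter is the probability that it is applied to a given image, and OneOf applies one of the augmentations it wraps. As a rough mental model only, the plain-Python toy below mimics these semantics; ToyAugmentation, ToyOneOf and toy_pipeline are hypothetical names, not part of Augraphy. It is a simplification: it runs the three phases in order on a list of effect names, ignores how Augraphy actually combines the ink and paper layers, and lets ToyOneOf pick uniformly.

```python
import random
from typing import List, Sequence


class ToyAugmentation:
    """Hypothetical stand-in for an augmentation applied with probability p."""

    def __init__(self, name: str, p: float = 1.0):
        self.name = name
        self.p = p

    def __call__(self, effects: List[str], rng: random.Random) -> List[str]:
        # rng.random() is in [0, 1), so p=1 always applies and p=0 never does
        if rng.random() < self.p:
            return effects + [self.name]
        return effects


class ToyOneOf:
    """Hypothetical stand-in for OneOf: applies exactly one wrapped augmentation."""

    def __init__(self, augmentations: Sequence[ToyAugmentation], p: float = 1.0):
        self.augmentations = list(augmentations)
        self.p = p

    def __call__(self, effects: List[str], rng: random.Random) -> List[str]:
        if rng.random() < self.p:
            return rng.choice(self.augmentations)(effects, rng)
        return effects


def toy_pipeline(ink_phase, paper_phase, post_phase, rng: random.Random) -> List[str]:
    """Run every phase in order and return the names of the effects that were applied."""
    effects: List[str] = []
    for phase in (ink_phase, paper_phase, post_phase):
        for augmentation in phase:
            effects = augmentation(effects, rng)
    return effects


# usage: a toy version of the pipeline above
rng = random.Random(0)
ink = [ToyAugmentation("Dithering", p=0.5),
       ToyAugmentation("InkBleed", p=0.5),
       ToyOneOf([ToyAugmentation("LowInkRandomLines"), ToyAugmentation("LowInkPeriodicLines")])]
paper = [ToyAugmentation("ColorPaper", p=0.5)]
post = [ToyAugmentation(name, p=0.25) for name in
        ("Markup", "DirtyRollers", "Scribbles", "BindingsAndFasteners", "BadPhotoCopy", "DirtyDrum")]
print(toy_pipeline(ink, paper, post, rng))  # a random subset of effects, different per seed
```

Because every effect is drawn independently, each image passing through the pipeline gets a different combination of effects, which is what makes the generated dirty set varied.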

  1. Each set of images needs its own transformation object. datasets.ImageFolder loads images in PIL format, so for clean images To_BGR first converts each PIL image into a NumPy array with BGR channel order, and To_Tensor then converts that array into a PyTorch tensor with values scaled to the range [0, 1].

For dirty images, augmentation_pipeline is applied between To_BGR and To_Tensor to add the dirty effects to each image.

# create augmentation object

class To_Tensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, image):
        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        #if len(image.shape)>2:
        #    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        if len(image.shape)<3:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

        image = image.reshape(3, image.shape[0], image.shape[1])

        return torch.from_numpy(image.astype("float64")/255)

class To_BGR(object):
    """Convert from PIL RGB Image to numpy array BGR image."""
    def __call__(self, image):
        # convert from PIL RGB to BGR
        image_numpy = np.array(image)
        if len(image_numpy.shape)<3:
            return cv2.cvtColor(image_numpy, cv2.COLOR_GRAY2BGR)
        else:
            return cv2.cvtColor(image_numpy, cv2.COLOR_RGB2BGR)

# augmentations for clean and dirty images
dirty_transforms = transforms.Compose([To_BGR(),
                                       augmentation_pipeline,
                                       To_Tensor()])
clean_transforms = transforms.Compose([To_BGR(), To_Tensor()])
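To_BGR and To_Tensor perform two array manipulations: reversing the channel order (RGB to BGR) and moving the channel axis to the front (the H x W x C to C x H x W swap described in the To_Tensor comment). The NumPy-only sketch below shows both steps on a tiny 2 x 2 image; rgb_to_bgr and hwc_to_chw are hypothetical helper names. np.transpose reorders the axes, whereas reshape keeps the elements in their original memory order and only changes the shape, so for a multi-channel image the two produce different arrays, as the last line shows.

```python
import numpy as np


def rgb_to_bgr(image_rgb: np.ndarray) -> np.ndarray:
    """Reverse the last (channel) axis: RGB -> BGR, like cv2.COLOR_RGB2BGR."""
    return image_rgb[..., ::-1]


def hwc_to_chw(image_hwc: np.ndarray) -> np.ndarray:
    """Move the channel axis to the front and scale uint8 values to [0, 1]."""
    return np.ascontiguousarray(image_hwc.transpose(2, 0, 1)).astype(np.float32) / 255.0


# a tiny 2 x 2 RGB image: red, green / blue, white pixels
image = np.array([[[255, 0, 0], [0, 255, 0]],
                  [[0, 0, 255], [255, 255, 255]]], dtype=np.uint8)

bgr = rgb_to_bgr(image)
chw = hwc_to_chw(bgr)
print(chw.shape)  # (3, 2, 2)
print(chw[0])     # channel 0 is now blue: 1.0 for the blue and white pixels

# reshape does not move the channel axis: the result differs from the transpose
print(np.array_equal(bgr.reshape(3, 2, 2), bgr.transpose(2, 0, 1)))  # False
```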
  1. Before training, we define some training parameters: batch_size, epochs, shuffle and image_size. The effective batch size is twice batch_size, because each iteration loads one batch of clean images and one batch of dirty images and concatenates them before they are passed to the training step.

# training parameters

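# the effective batch size is 2x this value, because a clean and a dirty batch are combined
batch_size = 16
epochs = 60
shuffle = True
# side length of the square dataset images (used when displaying them)
image_size = 400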
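The doubling of the batch size can be seen with tiny NumPy arrays standing in for one dirty batch and one clean batch. The arrays and their 4 x 4 size are illustrative assumptions; the training loop does the same thing with torch.cat on real image tensors.

```python
import numpy as np

batch_size = 16  # per-loader batch size, as in the notebook

# tiny stand-in batches (N x C x H x W); real batches hold image_size x image_size images
dirty_batch = np.zeros((batch_size, 3, 4, 4), dtype=np.float32)
clean_batch = np.ones((batch_size, 3, 4, 4), dtype=np.float32)

# stack the two batches along the batch axis and build matching labels: 1 = dirty, 0 = clean
all_data = np.concatenate([dirty_batch, clean_batch], axis=0)
all_labels = np.concatenate([np.ones(batch_size), np.zeros(batch_size)]).reshape(-1, 1)

print(all_data.shape)    # (32, 3, 4, 4)
print(all_labels.shape)  # (32, 1)
```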
  1. datasets.ImageFolder loads the training, validation and test images from their folders and creates the datasets, using the transform object dirty_transforms for the dirty sets and clean_transforms for the clean sets.

torch.utils.data.DataLoader then creates a data loader for each dataset, which yields the images in batches.

# Create datasets

train_dirty_path = "/content/shabby_small/train/train_dirty/"
train_clean_path = "/content/shabby_small/train/train_clean/"

validate_dirty_path = "/content/shabby_small/validate/validate_dirty/"
validate_clean_path = "/content/shabby_small/validate/validate_clean/"

test_dirty_path = "/content/shabby_small/test/test_dirty/"
test_clean_path = "/content/shabby_small/test/test_clean/"

# create datasets
train_dirty_data = datasets.ImageFolder(train_dirty_path,transform=dirty_transforms)
train_clean_data = datasets.ImageFolder(train_clean_path,transform=clean_transforms)
validate_dirty_data = datasets.ImageFolder(validate_dirty_path,transform=dirty_transforms)
validate_clean_data = datasets.ImageFolder(validate_clean_path,transform=clean_transforms)
test_dirty_data = datasets.ImageFolder(test_dirty_path,transform=dirty_transforms)
test_clean_data = datasets.ImageFolder(test_clean_path,transform=clean_transforms)

# create dataloader
train_dirty_loader = torch.utils.data.DataLoader(train_dirty_data, shuffle = shuffle, batch_size=batch_size)
train_clean_loader = torch.utils.data.DataLoader(train_clean_data, shuffle = shuffle, batch_size=batch_size)
validate_dirty_loader = torch.utils.data.DataLoader(validate_dirty_data, shuffle = shuffle, batch_size=batch_size)
validate_clean_loader = torch.utils.data.DataLoader(validate_clean_data, shuffle = shuffle, batch_size=batch_size)
test_dirty_loader = torch.utils.data.DataLoader(test_dirty_data, shuffle = shuffle, batch_size=batch_size)
test_clean_loader = torch.utils.data.DataLoader(test_clean_data, shuffle = shuffle, batch_size=batch_size)
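datasets.ImageFolder treats every sub-directory of the root folder as a class and assigns class indices in sorted order, so each root such as train_dirty/ must itself contain at least one sub-folder of images. Because each dataset here is built from a single root, the class index ImageFolder returns as the second element of every sample is never used; the training loop assigns its own labels (1 for dirty, 0 for clean). The standard-library sketch below mimics that discovery in simplified form (one directory level, a few extensions only); list_image_folder and the images sub-folder name are hypothetical.

```python
import os
import tempfile

IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")  # a subset of what ImageFolder accepts


def list_image_folder(root):
    """Simplified ImageFolder-style discovery: each sub-directory is a class."""
    with os.scandir(root) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())
    class_to_idx = {name: index for index, name in enumerate(classes)}
    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for file_name in sorted(os.listdir(class_dir)):
            if file_name.lower().endswith(IMAGE_EXTENSIONS):
                samples.append((os.path.join(class_dir, file_name), class_to_idx[name]))
    return class_to_idx, samples


# build a hypothetical layout: <root>/images/<files>
with tempfile.TemporaryDirectory() as root:
    class_dir = os.path.join(root, "images")
    os.makedirs(class_dir)
    for file_name in ("page_1.png", "page_2.png", "notes.txt"):
        open(os.path.join(class_dir, file_name), "w").close()
    class_to_idx, samples = list_image_folder(root)

print(class_to_idx)  # {'images': 0}
print(len(samples))  # 2 (notes.txt is skipped)
```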

  1. Here, we visually inspect some of the images to make sure they are loaded and augmented correctly.

# Display some of the clean and dirty images

plt.rcParams["figure.figsize"] = (20,12)
# clean
for data in train_clean_loader:
    plt.figure()
    for i in range(15):
        plt.subplot(3,5,i+1)
        plt.imshow(data[0][i].reshape(image_size, image_size, 3))
        plt.title("clean")
    break

# dirty
for data in train_dirty_loader:
    plt.figure()
    for i in range(15):
        plt.subplot(3,5,i+1)
        plt.imshow(data[0][i].reshape(image_size, image_size, 3))
        plt.title("dirty")
    break

[Figure: sample clean training images]
[Figure: sample dirty (augmented) training images]
  1. We use ResNet-50 for transfer learning; torchvision provides it with ImageNet-pretrained weights through models.resnet50. We then freeze the existing model parameters and replace the final fully connected layer with a new one that is trained to output a single value. The loss function is BCEWithLogitsLoss, which applies a sigmoid to the model output before computing the binary cross-entropy, as needed for this binary classification problem.

# Define training and model parameters

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# define model
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)  # pretrained=True in torchvision < 0.13
# freeze all pretrained parameters so that only the new final layer is trained
for param in model.parameters():
    param.requires_grad = False
# replace the final layer with a new (trainable) layer that maps its input features to a single logit
model.fc = nn.Linear(model.fc.in_features, 1)
model = model.to(device)

# define optimizer and loss function
criterion = BCEWithLogitsLoss()  # applies the sigmoid internally
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 362MB/s]
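BCEWithLogitsLoss works on raw model outputs (logits). The NumPy sketch below, with hypothetical function names, computes the mean loss in the numerically stable form max(x, 0) - x*y + log(1 + exp(-|x|)), which is algebraically equal to applying a sigmoid and then binary cross-entropy, and checks it against that naive formula. It illustrates the math only; it is not PyTorch's implementation.

```python
import numpy as np


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy computed directly from logits (stable form)."""
    x = np.asarray(logits, dtype=np.float64)
    y = np.asarray(targets, dtype=np.float64)
    # equals -[y*log(s) + (1-y)*log(1-s)] with s = sigmoid(x), without overflow
    per_element = np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x)))
    return per_element.mean()


def naive_bce(logits, targets):
    """Sigmoid followed by binary cross-entropy; breaks down for large |x|."""
    s = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=np.float64)))
    y = np.asarray(targets, dtype=np.float64)
    return -(y * np.log(s) + (1 - y) * np.log(1 - s)).mean()


logits = np.array([-3.0, -0.5, 0.5, 3.0])
labels = np.array([0.0, 1.0, 0.0, 1.0])  # 1 = dirty, 0 = clean
print(bce_with_logits(logits, labels), naive_bce(logits, labels))  # the two values agree
print(bce_with_logits([1000.0], [0.0]))  # 1000.0, where the naive version would fail
```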
  1. The training loop evaluates the model on the validation data after every epoch and includes early stopping. In both the training and validation iterations, each batch of dirty images is concatenated with a batch of clean images before being passed to the model, so that each combined batch contains the same number of clean and dirty images.

# training iterations

train_losses = []
validate_losses = []

# early stopping: stop after more than patient_threshold epochs without improvement
patient_count = 0
patient_threshold = 10

best_validate_loss = float("inf")

for epoch in range(epochs):
    train_loss = 0
    start_time = time()

    # reset to training mode
    model.train()

    for batch_number, (noisy_samples, clean_samples) in enumerate(zip(train_dirty_loader, train_clean_loader)):


        noisy_data = noisy_samples[0].float()
        noisy_label = torch.tensor([1 for _ in range(len(noisy_samples[0]))]).float()
        clean_data = clean_samples[0].float()
        clean_label = torch.tensor([0 for _ in range(len(clean_samples[0]))]).float()

        # combine clean and dirty image
        all_data = torch.cat([noisy_data, clean_data], dim=0).to(device)
        all_labels = torch.cat([noisy_label, clean_label]).reshape(all_data.shape[0], 1).to(device)

        optimizer.zero_grad()

        # forward pass
        predicted_labels = model(all_data)

        # compute loss
        loss = criterion(predicted_labels, all_labels)
        train_loss += loss.item()

        # backward pass
        loss.backward()

        # update parameters
        optimizer.step()

    # average loss per iteration (zip stops at the end of the shorter loader)
    average_train_loss = train_loss / min(len(train_dirty_loader), len(train_clean_loader))
    train_losses.append(average_train_loss)

    # evaluate model:
    model.eval()
    with torch.no_grad():

        validate_loss = 0

        for batch_number, (noisy_samples, clean_samples) in enumerate(zip(validate_dirty_loader, validate_clean_loader)):

            noisy_data = noisy_samples[0].float()
            noisy_label = torch.tensor([1 for _ in range(len(noisy_samples[0]))]).float()
            clean_data = clean_samples[0].float()
            clean_label = torch.tensor([0 for _ in range(len(clean_samples[0]))]).float()

            # combine clean and dirty image
            all_data = torch.cat([noisy_data, clean_data], dim=0).to(device)
            all_labels = torch.cat([noisy_label, clean_label]).reshape(all_data.shape[0], 1).to(device)

            # forward pass
            predicted_labels = model(all_data)

            # compute loss
            loss = criterion(predicted_labels, all_labels)
            validate_loss += loss.item()

        # average validation loss per iteration
        average_validate_loss = validate_loss / min(len(validate_dirty_loader), len(validate_clean_loader))
        validate_losses.append(average_validate_loss)

    # EARLY STOPPING
    if average_validate_loss<=best_validate_loss:
        best_validate_loss = average_validate_loss
        patient_count = 0
    else:
        patient_count += 1
        if patient_count > patient_threshold:
            break

    print("Epoch = "+str(epoch)+", train_loss = "+str(average_train_loss)+", validate_loss = "+str(average_validate_loss)+", training duration = "+str(round(time()-start_time))+" seconds")

Epoch = 0, train_loss = 0.16265885250234022, validate_loss = 0.09095224934188943, training duration = 292 seconds
Epoch = 1, train_loss = 0.07528629535582007, validate_loss = 0.06522136818813651, training duration = 265 seconds
Epoch = 2, train_loss = 0.05403669600988307, validate_loss = 0.060843666133127715, training duration = 241 seconds
Epoch = 3, train_loss = 0.048908831642531764, validate_loss = 0.06135064368381312, training duration = 287 seconds
Epoch = 4, train_loss = 0.04750860836811182, validate_loss = 0.050219189081537094, training duration = 250 seconds
Epoch = 5, train_loss = 0.04251422401426769, validate_loss = 0.04269793512005555, training duration = 273 seconds
Epoch = 6, train_loss = 0.03895130033445794, validate_loss = 0.04299438656552842, training duration = 255 seconds
Epoch = 7, train_loss = 0.03376865359704669, validate_loss = 0.040767080346612555, training duration = 274 seconds
Epoch = 8, train_loss = 0.037427358352011296, validate_loss = 0.03735993959401783, training duration = 295 seconds
Epoch = 9, train_loss = 0.027450435317871048, validate_loss = 0.0340286838380914, training duration = 267 seconds
Epoch = 10, train_loss = 0.030007415280745523, validate_loss = 0.04129516693616384, training duration = 278 seconds
Epoch = 11, train_loss = 0.026311413248682896, validate_loss = 0.04192912134979116, training duration = 286 seconds
Epoch = 12, train_loss = 0.026286867640276507, validate_loss = 0.043772940352363024, training duration = 275 seconds
Epoch = 13, train_loss = 0.02571372339138534, validate_loss = 0.03408439231938437, training duration = 263 seconds
Epoch = 14, train_loss = 0.025083887924599212, validate_loss = 0.031926931997172926, training duration = 272 seconds
Epoch = 15, train_loss = 0.027622007015274792, validate_loss = 0.043157988178886865, training duration = 250 seconds
Epoch = 16, train_loss = 0.02393035617338993, validate_loss = 0.04161530176765824, training duration = 281 seconds
Epoch = 17, train_loss = 0.022220263144046796, validate_loss = 0.038027058355510235, training duration = 248 seconds
Epoch = 18, train_loss = 0.028428362529134242, validate_loss = 0.028791493098986775, training duration = 251 seconds
Epoch = 19, train_loss = 0.022880588873948265, validate_loss = 0.041628423686090266, training duration = 262 seconds
Epoch = 20, train_loss = 0.024806299792012064, validate_loss = 0.03297404432669282, training duration = 272 seconds
Epoch = 21, train_loss = 0.023054723957235495, validate_loss = 0.038020130222369185, training duration = 282 seconds
Epoch = 22, train_loss = 0.024718658873675074, validate_loss = 0.03629387138215335, training duration = 239 seconds
Epoch = 23, train_loss = 0.019438672135016176, validate_loss = 0.03745270794943759, training duration = 278 seconds
Epoch = 24, train_loss = 0.021008325894022496, validate_loss = 0.029737368011602053, training duration = 280 seconds
Epoch = 25, train_loss = 0.022721592334053683, validate_loss = 0.04427085086507233, training duration = 266 seconds
Epoch = 26, train_loss = 0.021167143543319004, validate_loss = 0.0402586883620212, training duration = 270 seconds
Epoch = 27, train_loss = 0.02301945056334683, validate_loss = 0.0231909407792907, training duration = 275 seconds
Epoch = 28, train_loss = 0.021348601735265153, validate_loss = 0.04714697344522727, training duration = 267 seconds
Epoch = 29, train_loss = 0.017056679160038873, validate_loss = 0.0243394438018042, training duration = 286 seconds
Epoch = 30, train_loss = 0.02305366937355025, validate_loss = 0.026760961308977323, training duration = 278 seconds
Epoch = 31, train_loss = 0.01824065751293901, validate_loss = 0.02530191934324409, training duration = 279 seconds
Epoch = 32, train_loss = 0.01787017357376654, validate_loss = 0.03044675217059098, training duration = 279 seconds
Epoch = 33, train_loss = 0.016991370946748137, validate_loss = 0.03152604773027921, training duration = 261 seconds
Epoch = 34, train_loss = 0.018007332896359446, validate_loss = 0.03396627848575774, training duration = 269 seconds
Epoch = 35, train_loss = 0.014609841947875372, validate_loss = 0.027170873455409156, training duration = 259 seconds
Epoch = 36, train_loss = 0.019973909323353592, validate_loss = 0.03745627410611824, training duration = 255 seconds
Epoch = 37, train_loss = 0.017210176370742663, validate_loss = 0.026869589315825386, training duration = 276 seconds
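The early-stopping rule in the loop can be isolated as a small pure-Python function; early_stopping_epoch is a hypothetical name and the losses in the usage example are made up. A loss less than or equal to the best so far resets the counter; otherwise the counter grows, and training stops once it exceeds the patience. With patient_threshold = 10, training therefore stops 11 epochs after the last improvement. This is consistent with the log above: the lowest validation loss is at epoch 27 and the last printed epoch is 37, which suggests the loop broke at epoch 38, before its print statement.

```python
def early_stopping_epoch(validate_losses, patience):
    """Return the epoch index at which training would stop, or None if it never stops."""
    best = float("inf")
    counter = 0
    for epoch, loss in enumerate(validate_losses):
        if loss <= best:
            # improvement (or a tie): remember it and reset the counter
            best = loss
            counter = 0
        else:
            counter += 1
            if counter > patience:
                return epoch
    return None


losses = [0.9, 0.5, 0.4, 0.45, 0.42, 0.41, 0.43]
print(early_stopping_epoch(losses, patience=2))  # 5
```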
  1. Once training is done, we plot the training and validation loss curves.

# Display loss curve

x_train = [n for n in range(len(train_losses))]
x_validate = [n for n in range(len(validate_losses))]
plt.rcParams["figure.figsize"] = (10,10)

plt.figure()
plt.grid()
plt.plot(x_train, train_losses,  color="red", label="Train")
plt.plot(x_validate, validate_losses,  color="blue", label="Validate")
plt.title("Loss Curve")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend(loc="upper right")
plt.ylim(0, 0.2)
[Figure: training and validation loss curves]
  1. The last step is to test the trained model. The classification threshold is 0.5: any image with a score above 0.5 is classified as dirty, and any other image as clean. From the results, we can see that the model performs well.
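Because the model is trained with BCEWithLogitsLoss, its raw outputs are logits rather than probabilities. A sigmoid maps a logit x to a score in (0, 1), and sigmoid(x) > 0.5 exactly when x > 0; comparing raw logits with 0.5 instead amounts to a probability threshold of sigmoid(0.5), about 0.62. The NumPy sketch below applies the 0.5 rule to sigmoid scores; classify is a hypothetical helper and the logits are made-up values.

```python
import numpy as np


def sigmoid(x):
    """Logistic function: maps any real logit to a score in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-np.asarray(x, dtype=np.float64)))


def classify(logits, threshold=0.5):
    """Label each logit by thresholding its sigmoid score."""
    scores = sigmoid(logits)
    return ["Dirty" if score > threshold else "Clean" for score in scores]


logits = [-2.0, -0.1, 0.0, 0.3, 4.0]
print(classify(logits))               # ['Clean', 'Clean', 'Clean', 'Dirty', 'Dirty']
print([x > 0 for x in logits])        # same decisions: [False, False, False, True, True]
print(round(float(sigmoid(0.5)), 4))  # 0.6225
```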

# Test with test set

plt.rcParams["figure.figsize"] = (20,5)

# Classify Clean images
model.eval()
with torch.no_grad():

    for batch_number, (noisy_samples, clean_samples) in enumerate(zip(test_dirty_loader, test_clean_loader)):

        noisy_data = noisy_samples[0].float().to(device)
        clean_data = clean_samples[0].float().to(device)

        # predict: the model outputs logits, so a sigmoid turns them into scores in (0, 1)
        predicted_noisy_labels = torch.sigmoid(model(noisy_data))
        predicted_clean_labels = torch.sigmoid(model(clean_data))

        # Display predictions
        total_images = 5
        fig=plt.figure()
        fig.suptitle("Classify Test Set Clean Images")
        for i in range(total_images):

            plt.subplot(1,total_images,i+1)
            plt.imshow(clean_samples[0][i].reshape(image_size, image_size, 3))
            if predicted_clean_labels[i][0] > 0.5:
                predicted_label = "Dirty"
            else:
                predicted_label = "Clean"
            plt.title("Predicted: "+predicted_label)

        break

[Figure: Classify Test Set Clean Images, five clean images with their predicted labels]
# Classify Dirty images

plt.rcParams["figure.figsize"] = (20,5)

model.eval()
with torch.no_grad():

    for batch_number, (noisy_samples, clean_samples) in enumerate(zip(test_dirty_loader, test_clean_loader)):

        noisy_data = noisy_samples[0].float().to(device)
        clean_data = clean_samples[0].float().to(device)

        # predict: the model outputs logits, so a sigmoid turns them into scores in (0, 1)
        predicted_noisy_labels = torch.sigmoid(model(noisy_data))
        predicted_clean_labels = torch.sigmoid(model(clean_data))

        # Display predictions
        total_images = 5
        fig=plt.figure()
        fig.suptitle("Classify Test Set Dirty Images")
        for i in range(total_images):
            plt.subplot(1,total_images,i+1)
            plt.imshow(noisy_samples[0][i].reshape(image_size, image_size, 3))
            if predicted_noisy_labels[i][0] > 0.5:
                 predicted_label = "Dirty"
            else:
                 predicted_label = "Clean"
            plt.title("Predicted: "+predicted_label)

        break
[Figure: Classify Test Set Dirty Images, five dirty images with their predicted labels]