In this tutorial, we will implement a UNet to solve Kaggle's 2018 Data Science Bowl Competition. The challenge asks participants to find the location of nuclei from images of cells. The source of this tutorial and instructions to reproduce this analysis can be found at the thomasjpfan/ml-journal repo.
Exploring the Data¶
We can now define the training and validation datasets:
from pathlib import Path
from sklearn.model_selection import train_test_split

# Each sample lives in its own directory under data/cells/
samples_dirs = [d for d in Path('data/cells/').iterdir() if d.is_dir()]
train_dirs, valid_dirs = train_test_split(
    samples_dirs, test_size=0.2, random_state=42)

train_cell_ds = CellsDataset(train_dirs)
valid_cell_ds = CellsDataset(valid_dirs)
Overall, the cell images come in different sizes and fall into three different categories:
Most of the data is of Type 2. Training a single model to find the nuclei across all types may not be the best option, but we will give it a try! For reference, here are the corresponding masks for the above three types:
In order to train a neural net, each image we feed in must be the same size. For our dataset, we break the images up into 256x256 patches. The UNet architecture typically has a hard time dealing with objects on the edge of an image. To deal with this issue, we pad our images by 16 pixels using reflection. The image augmentation is handled by PatchedDataset. Its implementation can be found in dataset.py.
train_ds = PatchedDataset(
    train_cell_ds, patch_size=(256, 256), padding=16, random_flips=True)
val_ds = PatchedDataset(
    valid_cell_ds, patch_size=(256, 256), padding=16, random_flips=False)
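As a quick sanity check (a sketch, assuming PatchedDataset yields (image, mask) tensor pairs), every patch should measure 256 + 2 * 16 = 288 pixels per side after padding:
# Each patch should be 256 + 2 * 16 = 288 pixels on a side.
img, mask = train_ds[0]
print(img.shape, mask.shape)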
Defining the Module¶
Now we define the UNet module with the pretrained VGG16_bn as a feature encoder. The details of this module can be found in model.py:
module = UNet(pretrained=True)
Freezer¶
The features generated by VGG16_bn are prefixed with conv. These weights will be frozen, which restricts training to only our decoder layers.
from skorch.callbacks import Freezer
freezer = Freezer('conv*')
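To confirm which weights the pattern will match (a quick sketch, relying only on the conv prefix noted above):
# Encoder parameters are prefixed with 'conv' and match the 'conv*' pattern.
print([name for name, _ in module.named_parameters()
       if name.startswith('conv')][:5])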
Learning Rate Scheduler¶
We use a Cyclic Learning Rate scheduler to train our neural network.
from skorch.callbacks import LRScheduler
from skorch.callbacks.lr_scheduler import CyclicLR
cyclicLR = LRScheduler(policy=CyclicLR,
                       base_lr=0.002,
                       max_lr=0.2,
                       step_size_up=540,
                       step_size_down=540)
Why is step_size_up 540?
Since we are using a batch size of 32, each epoch will have about 54 (len(train_ds)//32) training iterations. We are also setting max_epochs to 20, which gives a total of 1080 (max_epochs*54) training iterations. We construct our Cyclic Learning Rate policy to peak at the 10th epoch by setting step_size_up to 540. This can be shown with a plot of the learning rate:
import matplotlib.pyplot as plt

_, ax = plt.subplots(figsize=(10, 5))
ax.set_title('Cyclic Learning Rate Scheduler')
ax.set_xlabel('Training iteration')
ax.set_ylabel('Learning Rate')
ax.plot(cyclicLR.simulate(1080, 0.002));
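The same arithmetic in code (a sketch using the values above):
iters_per_epoch = len(train_ds) // 32  # about 54 iterations per epoch
step_size_up = iters_per_epoch * 10    # peak at epoch 10 of 20 -> 540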
Checkpoint¶
A checkpoint is used to save the model weights whenever the validation loss reaches a new best:
from skorch.callbacks import Checkpoint
checkpoint = Checkpoint(dirname='unet')
Custom Loss Module¶
Since we have padded our images and masks, the loss function will need to ignore the padding when calculating the binary log loss. We define a BCEWithLogitsLossPadding to filter out the padding:
import torch.nn as nn
from torch.nn.functional import binary_cross_entropy_with_logits


class BCEWithLogitsLossPadding(nn.Module):
    def __init__(self, padding=16):
        super().__init__()
        self.padding = padding

    def forward(self, input, target):
        # Drop the channel dimension and slice off the reflection padding
        # so the loss only covers the original image region.
        input = input.squeeze(dim=1)[
            :, self.padding:-self.padding, self.padding:-self.padding]
        target = target.squeeze(dim=1)[
            :, self.padding:-self.padding, self.padding:-self.padding]
        return binary_cross_entropy_with_logits(input, target)
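A quick check on random tensors (a sketch, assuming the 288x288 padded patches from above):
import torch

criterion = BCEWithLogitsLossPadding(padding=16)
logits = torch.randn(2, 1, 288, 288)
targets = torch.randint(0, 2, (2, 1, 288, 288)).float()
print(criterion(logits, targets))  # scalar loss over the inner 256x256 region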
Training the skorch NeuralNet¶
Now we can define the skorch NeuralNet to train our UNet!
from skorch.net import NeuralNet
from skorch.helper import predefined_split
net = NeuralNet(
    module,
    criterion=BCEWithLogitsLossPadding,
    criterion__padding=16,
    batch_size=32,
    max_epochs=20,
    optimizer__momentum=0.9,
    iterator_train__shuffle=True,
    iterator_train__num_workers=4,
    iterator_valid__shuffle=False,
    iterator_valid__num_workers=4,
    train_split=predefined_split(val_ds),
    callbacks=[('freezer', freezer),
               ('cycleLR', cyclicLR),
               ('checkpoint', checkpoint)],
    device='cuda',
)
Let's highlight some parameters in our NeuralNet:
- criterion__padding=16 - Passes the padding to our BCEWithLogitsLossPadding initializer.
- train_split=predefined_split(val_ds) - Sets val_ds to be the validation set during training.
- callbacks=[('freezer', freezer), ('cycleLR', cyclicLR), ('checkpoint', checkpoint)] - Registers our callbacks; the checkpoint saves the best model parameters into the unet directory.
Next we train our UNet with the training dataset:
net.fit(train_ds);
Before we evaluate our model, we load the checkpoint with the best weights into the net object:
net.load_params(checkpoint=checkpoint)
Evaluating our model¶
Now that we have trained our model, let's see how it does on the three types presented at the beginning of this tutorial. Since our UNet module is designed to output logits, we must convert these values to probabilities:
import numpy as np

val_masks = net.predict(val_ds).squeeze(1)
val_prob_masks = 1 / (1 + np.exp(-val_masks))  # sigmoid: logits -> probabilities
We plot the predicted mask with its corresponding true mask and original image:
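A minimal sketch for such a plot, assuming val_ds yields (image, mask) tensor pairs and that the predictions still include the 16-pixel padding:
def plot_sample(idx, pad=16):
    # Show the image, the true mask, and the predicted probability mask
    # side by side, cropping off the reflection padding.
    img, true_mask = val_ds[idx]
    _, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(img.numpy().transpose(1, 2, 0)[pad:-pad, pad:-pad])
    axes[0].set_title('Image')
    axes[1].imshow(true_mask.numpy().squeeze()[pad:-pad, pad:-pad], cmap='gray')
    axes[1].set_title('True mask')
    axes[2].imshow(val_prob_masks[idx][pad:-pad, pad:-pad], cmap='gray')
    axes[2].set_title('Predicted mask')
    for ax in axes:
        ax.axis('off')

plot_sample(0)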
Our UNet is able to predict the location of the nuclei for all three types of cell images!
What's next?¶
In this tutorial, we used skorch to train a UNet to predict the location of nuclei in an image. There are still areas where our solution could be improved:
- Since there are three types of images in our dataset, we could improve our results by training a separate UNet model for each type.
- We could use traditional image processing to fill in the holes that our UNet produces; see the sketch after this list.
- Our loss function could include a term analogous to the competition's metric of intersection over union.
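For the hole filling, here is a minimal post-processing sketch using scipy.ndimage.binary_fill_holes (the 0.5 threshold is an assumption):
from scipy.ndimage import binary_fill_holes

# Threshold the probability masks, then fill enclosed holes in each
# predicted nucleus region.
binary_masks = val_prob_masks > 0.5
filled_masks = np.stack([binary_fill_holes(m) for m in binary_masks])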