Simplify your PyTorch model using PyTorch Lightning

Code for machine learning can get repetitive. In this blog post, you will learn to combat code repetition by using PyTorch Lightning.

Machine learning code can get highly repetitive. Think, for example, of all the times you had to write a training loop and log the evaluation results. Loading and storing the model is another task that comes up again and again. With PyTorch Lightning, you can get rid of this repetitive code by wrapping your model in a PyTorch Lightning module. At the same time, every part remains fully customizable: you can define your own loading and storing procedures, your own evaluation code, and your own training code.
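For example, to customize how the model is stored and restored, you can override the checkpoint hooks of a LightningModule. The snippet below is a minimal sketch with a made-up module and a hypothetical extra-state entry; on_save_checkpoint and on_load_checkpoint both receive the checkpoint dictionary, so you can add or read any state you need.

import pytorch_lightning as pl
from torch import nn

class TinyModule(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 2)

    def on_save_checkpoint(self, checkpoint):
        # Add custom entries to the checkpoint dictionary before it is written to disk
        checkpoint['my_extra_state'] = {'note': 'anything picklable goes here'}

    def on_load_checkpoint(self, checkpoint):
        # Read the custom entries back when the checkpoint is restored
        extra = checkpoint.get('my_extra_state')
        print('Restored extra state:', extra)

The same idea applies to the training and evaluation code, as the full example below shows.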

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.metrics import classification_report, f1_score
from sklearn.model_selection import train_test_split
from transformers import AdamW, AutoTokenizer, AutoModel
import pytorch_lightning as pl

class LitLegalBertClassifier(pl.LightningModule):

    def __init__(self, encoder, model):
        super().__init__()
        self.encoder = encoder
        self.model = model

    def forward(self, batch):
        # In Lightning, forward defines the prediction/inference behaviour
        x, attention_mask = self.encoder(batch)
        return self.model(x, attention_mask=attention_mask)

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        x, attention_mask = self.encoder(batch)
        y_true = batch['label']
        y_pred = self.model(x, attention_mask=attention_mask)
        loss = F.cross_entropy(y_pred, y_true)
        self.log('train_loss', loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        # Use the wrapped encoder and model rather than global variables
        y_pred_proba = self.model(*self.encoder(batch))
        y_pred = y_pred_proba.argmax(dim=1).cpu().numpy()
        y_true = batch['label'].cpu().numpy()
        # Map class indices back to their label names
        idxs_to_labels = lambda idxs: np.array([ds_test.idx_to_label[idx] for idx in idxs])
        scores = dict()
        # Compute a one-vs-rest F1 score per label and log it
        for label in ds_test.labels:
            scores[label] = f1_score(idxs_to_labels(y_true) == label, idxs_to_labels(y_pred) == label)
            self.log('f1_' + label, scores[label])
        return scores

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=1e-5)
        return optimizer
    
# Load the base model
base_model = LegalBertClassifierv2(legalbert, n_classes=3)
# Wrap it in a PyTorch Lightning module
model = LitLegalBertClassifier(encoder, base_model)
# Here is my train dataloader
train_dataloader = DataLoader(ds_train, batch_size=4, shuffle=True)
# And the test dataloader
test_dataloader = DataLoader(ds_test, batch_size=len(ds_test), shuffle=False)
# Set up a trainer; val_check_interval=1.0 runs the validation step once per training epoch
trainer = pl.Trainer(max_epochs=None, val_check_interval=1.0)
# And fit the model; note that the test dataloader is used as the validation dataloader here
trainer.fit(model, train_dataloader, test_dataloader)
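After training, persisting and restoring the wrapped model takes only a few lines. The snippet below is a minimal sketch; the checkpoint filename is just an example, and load_from_checkpoint forwards the extra keyword arguments (encoder and model) to the module's __init__.

# Save the current state of the Lightning module to a checkpoint file
trainer.save_checkpoint('legalbert_classifier.ckpt')
# Restore it later; extra keyword arguments are forwarded to __init__
restored_model = LitLegalBertClassifier.load_from_checkpoint(
    'legalbert_classifier.ckpt', encoder=encoder, model=base_model
)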

So, this approach will cost you some time in the beginning, but it will save you a lot of time later, for example when refactoring your code. I am currently using it in my research projects. What are your favorite tools and libraries for simplifying ML code?