Simplify your PyTorch model using PyTorch Lightning
Machine learning code can get repetitive. In this blog post, you will learn how to reduce that repetition with PyTorch Lightning.
Some machine learning code gets highly repetitive. Think, for example, of all the times you had to write a training loop or log evaluation results. Loading and saving the model is another task you perform over and over. With PyTorch Lightning, you can get rid of this repetitive code by wrapping your code in a LightningModule. Every part remains fully customizable, though: you can define your own loading and saving procedures, as well as your own training and evaluation code.
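To make the idea concrete, here is a toy, pure-Python sketch of the pattern. It is not Lightning itself: `MiniTrainer` and `LinearModule` are hypothetical names, and the hooks `training_step`, `configure_optimizers` and `log` are only loosely modeled on their Lightning counterparts. The trainer owns the epoch and batch loop, and the module only fills in what is specific to the model. Because there is no autograd here, `training_step` returns the gradient alongside the loss; in real Lightning you return only the loss, and the Trainer runs `backward()` and the optimizer step for you.

```python
# Toy illustration of the Trainer/module split (hypothetical names,
# NOT the PyTorch Lightning API): the trainer owns the loop,
# the module only supplies model-specific hooks.
from typing import Dict, List, Tuple

Batch = List[Tuple[float, float]]


class LinearModule:
    """Fits y = w * x with plain gradient descent on mean squared error."""

    def __init__(self) -> None:
        self.w = 0.0
        self.logged: Dict[str, List[float]] = {}

    def log(self, name: str, value: float) -> None:
        # Collect metrics by name, loosely mirroring LightningModule.log
        self.logged.setdefault(name, []).append(value)

    def configure_optimizers(self) -> float:
        # In this toy, the "optimizer" is just a learning rate
        return 0.05

    def training_step(self, batch: Batch) -> Tuple[float, float]:
        # Mean squared error and its gradient with respect to w
        n = len(batch)
        loss = sum((self.w * x - y) ** 2 for x, y in batch) / n
        grad = sum(2 * (self.w * x - y) * x for x, y in batch) / n
        self.log("train_loss", loss)
        return loss, grad


class MiniTrainer:
    """Owns the epoch/batch loop so model code does not repeat it."""

    def __init__(self, max_epochs: int) -> None:
        self.max_epochs = max_epochs

    def fit(self, module: LinearModule, batches: List[Batch]) -> None:
        lr = module.configure_optimizers()
        for _ in range(self.max_epochs):
            for batch in batches:
                _, grad = module.training_step(batch)
                module.w -= lr * grad  # the "optimizer step"


data = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)]]  # points on y = 2x
module = LinearModule()
MiniTrainer(max_epochs=50).fit(module, data)
print(round(module.w, 3))  # close to 2.0
```

Because the loop lives in one place, a new model only needs new hooks; that division of labor is what the LightningModule/Trainer split gives you.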
The following example wraps a LegalBERT-based classifier in a LightningModule. `LegalBertClassifierv2`, `legalbert`, `encoder`, `ds_train` and `ds_test` are defined elsewhere in the project.

```python
import numpy as np
import pytorch_lightning as pl
import torch
from sklearn.metrics import f1_score
from torch.nn import functional as F
from torch.utils.data import DataLoader


class LitLegalBertClassifier(pl.LightningModule):
    def __init__(self, encoder, model):
        super().__init__()
        # encoder: callable that turns a batch into (input_ids, attention_mask)
        # model: classifier that maps those tensors to class logits
        self.encoder = encoder
        self.model = model

    def forward(self, batch):
        # In Lightning, forward defines the prediction/inference actions
        x, attention_mask = self.encoder(batch)
        return self.model(x, attention_mask=attention_mask)

    def training_step(self, batch, batch_idx):
        # training_step defines one iteration of the train loop
        y_true = batch['label']
        y_pred = self(batch)  # logits, as F.cross_entropy expects
        loss = F.cross_entropy(y_pred, y_true)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # Use the wrapped encoder and model instead of global variables
        y_pred_proba = self(batch)
        y_pred = y_pred_proba.argmax(dim=1).cpu().numpy()
        y_true = batch['label'].cpu().numpy()
        # ds_test (defined elsewhere) maps class indices to label names
        true_labels = np.array([ds_test.idx_to_label[idx] for idx in y_true])
        pred_labels = np.array([ds_test.idx_to_label[idx] for idx in y_pred])
        scores = dict()
        for label in ds_test.labels:
            # One-vs-rest F1 score for this label
            scores[label] = f1_score(true_labels == label, pred_labels == label)
            self.log('f1_' + label, scores[label])
        return scores

    def configure_optimizers(self):
        # torch.optim.AdamW replaces the deprecated transformers.AdamW
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-5)
        return optimizer


# Load the base model
base_model = LegalBertClassifierv2(legalbert, n_classes=3)
# Wrap it in a PyTorch Lightning module
model = LitLegalBertClassifier(encoder, base_model)

# Here is my train dataloader
train_dataloader = DataLoader(ds_train, batch_size=4, shuffle=True)
# And the test dataloader, which serves the whole test set as one batch
test_dataloader = DataLoader(ds_test, batch_size=len(ds_test), shuffle=False)

# Set up a trainer that validates once per training epoch.
# With max_epochs=None and no max_steps, Lightning falls back to its
# default of 1000 epochs (use max_epochs=-1 to train indefinitely).
trainer = pl.Trainer(max_epochs=None, val_check_interval=1.0)
# And fit the model using the train dataloader and test dataloader
trainer.fit(model, train_dataloader, test_dataloader)
```
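The `validation_step` above scores each label one-vs-rest: it binarizes the true and predicted labels into "this label" versus "any other label" and computes a binary F1 score for each. The sketch below spells out that arithmetic in plain Python using F1 = 2*TP / (2*TP + FP + FN). The labels `"a"`, `"b"` and `"c"` are made-up placeholders, and returning 0.0 when a label occurs in neither list is an assumption for that edge case.

```python
# One-vs-rest F1 per label, spelled out in plain Python.
# Labels below are made-up placeholders, not the real dataset labels.
from typing import Dict, List, Sequence


def binary_f1(y_true: Sequence[bool], y_pred: Sequence[bool]) -> float:
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t and p)
    fp = sum(1 for t, p in pairs if not t and p)
    fn = sum(1 for t, p in pairs if t and not p)
    denom = 2 * tp + fp + fn
    # Assumption: report 0.0 when the label occurs in neither list
    return 2 * tp / denom if denom else 0.0


def per_label_f1(y_true: List[str], y_pred: List[str],
                 labels: List[str]) -> Dict[str, float]:
    scores = {}
    for label in labels:
        # Binarize: "is this label" vs "is any other label"
        scores[label] = binary_f1([t == label for t in y_true],
                                  [p == label for p in y_pred])
    return scores


y_true = ["a", "a", "b", "c"]
y_pred = ["a", "b", "b", "c"]
scores = per_label_f1(y_true, y_pred, labels=["a", "b", "c"])
print(scores)  # "a" and "b" score 2/3, "c" scores 1.0
```

Label `"a"` has one true positive and one false negative, and label `"b"` has one true positive and one false positive, so both land at 2/3, while `"c"` is predicted perfectly.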
So, this approach will cost you some time up front, but it will save you a lot of time when refactoring your code. I am currently using it in my research projects. What are your favorite tools and libraries for simplifying ML code?