PyTorch Lightning

Data Pipeline

LightningDataModule

  • Contains data loaders for training, validation, and test sets
  • As an example, see the PASCAL VOC data module
  • The optional train_transforms, val_transforms, and test_transforms arguments are passed to the LightningDataModule superclass, which lets you decouple the data from its transforms

DataLoader

  • An iterator for sampling batches from a dataset
  • Loads the data using multiple workers
  • Takes care of batching and shuffling the samples
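
To make the batching and shuffling concrete, here is a minimal pure-Python sketch of the idea. It is not the DataLoader implementation: iterate_batches is a hypothetical helper, and the real DataLoader additionally collates samples into tensors and can load them in worker processes.

```python
import random


def iterate_batches(dataset, batch_size, shuffle=False, seed=None):
    """Yield lists of samples; a toy stand-in for iterating a DataLoader."""
    indices = list(range(len(dataset)))
    if shuffle:
        # Shuffle the sample order, not the dataset itself.
        random.Random(seed).shuffle(indices)
    # Cut the (possibly shuffled) index list into consecutive batches.
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]


# Anything that supports len() and indexing can play the role of a dataset.
dataset = list(range(10))
batches = list(iterate_batches(dataset, batch_size=4, shuffle=True, seed=0))
batch_sizes = [len(batch) for batch in batches]
```

The last batch is smaller than the others because 10 samples do not divide evenly into batches of 4.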

Dataset

  • Defines how to read and process the data samples
  • Often downloads the data package over the network first
  • torchvision contains a collection of standard datasets
  • As an example, see the PASCAL VOC dataset
  • The constructor downloads and extracts the tar archive
  • The __getitem__ method parses the XML annotations, applies any transforms, and returns a data point
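
The __getitem__ pattern described above can be sketched as follows. This is a simplified illustration, not the actual PASCAL VOC dataset: ToyVOCDataset, ANNOTATIONS, CLASS_INDEX, and encode_labels are hypothetical names, the annotations are kept in memory instead of being downloaded, and the class is written in plain Python so that it runs without PyTorch. In real code it would typically subclass torch.utils.data.Dataset.

```python
import xml.etree.ElementTree as ET

# Hypothetical in-memory annotations standing in for the extracted XML files.
ANNOTATIONS = [
    "<annotation><filename>a.jpg</filename>"
    "<object><name>dog</name></object></annotation>",
    "<annotation><filename>b.jpg</filename>"
    "<object><name>cat</name></object>"
    "<object><name>person</name></object></annotation>",
]

# Hypothetical mapping from class names to integer labels.
CLASS_INDEX = {"cat": 0, "dog": 1, "person": 2}


class ToyVOCDataset:
    """Map-style dataset: supports len(dataset) and dataset[index]."""

    def __init__(self, annotations, transform=None):
        self.annotations = annotations
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        # Parse the XML annotation of the requested sample.
        root = ET.fromstring(self.annotations[index])
        sample = {
            "filename": root.findtext("filename"),
            "labels": [obj.findtext("name") for obj in root.iter("object")],
        }
        # Apply the optional transform before returning the data point.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample


def encode_labels(sample):
    """Example transform: replace class names with integer labels."""
    return {**sample, "labels": [CLASS_INDEX[name] for name in sample["labels"]]}


dataset = ToyVOCDataset(ANNOTATIONS, transform=encode_labels)
first = dataset[0]
```

Because the transform is passed in through the constructor, the same dataset class can be reused with different preprocessing.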

Transform

  • Defines how a data point is transformed
  • Can be used to normalize the samples or to augment the dataset with randomly altered variants of them
  • torchvision contains a collection of standard transforms
  • For example, Resize and RandomCrop
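
Conceptually, a transform is a callable that maps one data point to another, and several transforms can be chained. The sketch below shows this with toy classes that operate on lists of floats. Normalize and Compose mirror the names of torchvision transforms but are simplified reimplementations written for this example, and RandomFlip is a hypothetical augmentation.

```python
import random


class Normalize:
    """Deterministic transform: shift and scale every value."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, values):
        return [(v - self.mean) / self.std for v in values]


class RandomFlip:
    """Augmentation: reverse the sample with probability p."""

    def __init__(self, p=0.5, seed=None):
        self.p = p
        self.rng = random.Random(seed)

    def __call__(self, values):
        return values[::-1] if self.rng.random() < self.p else values


class Compose:
    """Apply a list of transforms one after another."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, values):
        for transform in self.transforms:
            values = transform(values)
        return values


# p=1.0 makes the flip always happen, so the result is reproducible.
pipeline = Compose([Normalize(mean=2.0, std=2.0), RandomFlip(p=1.0)])
result = pipeline([0.0, 2.0, 4.0])
```

With p below 1.0, the same sample can come out differently each time it is read, which is how random transforms augment a dataset without storing extra samples.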

Model

LightningModule

Training

Trainer

  • The fit method trains a model on given data
  • Training is controlled using constructor arguments such as max_epochs, precision, gradient_clip_val, gpus, and accelerator
  • By default, model checkpoints are written to lightning_logs/version_N/checkpoints/, where N is the experiment version
  • To resume training from a model checkpoint, pass a file name or URL in the resume_from_checkpoint argument

Parsing Command Line Arguments

PyTorch Lightning provides a mechanism for easily mapping command line arguments to constructor arguments. For example, a Trainer can be constructed in the following way:

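from argparse import ArgumentParser
from pytorch_lightning import Trainer
# Register every Trainer constructor argument as a command line option.
parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Build the Trainer from the parsed options.
trainer = Trainer.from_argparse_args(args)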

The add_argparse_args method adds the constructor arguments to an argparse parser, and the from_argparse_args method constructs a Trainer from the parsed command line arguments. You can give any of your own classes the same functionality by adding methods that call the add_argparse_args and from_argparse_args functions. These functions parse the constructor's docstring and type hints, so both must be present.
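
To illustrate why the type hints are needed, here is a minimal sketch of the same idea using only the standard library. It does not call Lightning's functions and is not their implementation: ArgparseMixin and ToyDataModule are hypothetical names, the sketch reads only the type hints (not the docstring), and it assumes simple hints such as int and str that argparse can use directly as converters.

```python
import inspect
from argparse import ArgumentParser, Namespace


class ArgparseMixin:
    """Hypothetical helper for this example, not part of PyTorch Lightning."""

    @classmethod
    def add_argparse_args(cls, parser: ArgumentParser) -> ArgumentParser:
        # Turn every type-hinted constructor parameter into a --option.
        for name, param in inspect.signature(cls.__init__).parameters.items():
            if name == "self" or param.annotation is inspect.Parameter.empty:
                continue
            if param.default is inspect.Parameter.empty:
                parser.add_argument(f"--{name}", type=param.annotation, required=True)
            else:
                parser.add_argument(f"--{name}", type=param.annotation, default=param.default)
        return parser

    @classmethod
    def from_argparse_args(cls, args: Namespace):
        # Keep only the parsed values that the constructor accepts.
        names = set(inspect.signature(cls.__init__).parameters) - {"self"}
        kwargs = {key: value for key, value in vars(args).items() if key in names}
        return cls(**kwargs)


class ToyDataModule(ArgparseMixin):
    def __init__(self, data_dir: str = "data", batch_size: int = 32):
        self.data_dir = data_dir
        self.batch_size = batch_size


parser = ToyDataModule.add_argparse_args(ArgumentParser())
# An explicit argument list is parsed here instead of sys.argv.
args = parser.parse_args(["--batch_size", "8"])
data_module = ToyDataModule.from_argparse_args(args)
```

Without the int hint on batch_size, the parser would have no way to know that the string "8" from the command line should be converted to a number.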