PyTorch Lightning
Data Pipeline
LightningDataModule
- Contains data loaders for training, validation, and test sets
- As an example, see the PASCAL VOC data module
- The optional train_transforms, val_transforms, and test_transforms arguments are passed to the LightningDataModule super class, allowing you to decouple the data and its transforms
DataLoader
- An iterator for sampling batches from a dataset
- Loads the data using multiple workers
- Takes care of batching and shuffling the samples
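To make batching and shuffling concrete, here is a minimal sketch in plain Python. It is a conceptual stand-in, not how torch.utils.data.DataLoader is implemented: the helper name iterate_batches and its seed parameter are hypothetical, and worker processes are left out entirely.

```python
import random
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iterate_batches(
    dataset: Sequence[T],
    batch_size: int,
    shuffle: bool = False,
    seed: int = 0,
) -> Iterator[List[T]]:
    """Yield lists of samples, optionally in shuffled order (hypothetical helper)."""
    indices = list(range(len(dataset)))
    if shuffle:
        # A fixed seed keeps this sketch reproducible.
        random.Random(seed).shuffle(indices)
    # Walk over the (possibly shuffled) indices in steps of batch_size.
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]


samples = ["a", "b", "c", "d", "e"]
batches = list(iterate_batches(samples, batch_size=2, shuffle=True))
```

With five samples and a batch size of 2, the last batch holds the single remaining sample. The real DataLoader exposes the same ideas through its batch_size, shuffle, and num_workers arguments.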
Dataset
- Defines how to read and process the data samples
- Often downloads the data package over the network first
- torchvision contains a collection of standard datasets
- As an example, see the PASCAL VOC dataset
- The constructor downloads and extracts the tar archive
- The __getitem__ method parses the XML annotations, applies any transforms, and returns a data point
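The pattern described above (parse an XML annotation, apply a transform, return a data point) could look roughly like the following sketch. It uses only the standard library so that it runs standalone. The class ToyDetectionDataset and the inline annotation strings are hypothetical and much simpler than the real PASCAL VOC files. A real dataset would normally subclass torch.utils.data.Dataset and also load the image.

```python
import xml.etree.ElementTree as ET
from typing import Callable, Dict, List, Optional

# Hypothetical, heavily simplified annotations in the style of PASCAL VOC.
ANNOTATIONS: List[str] = [
    "<annotation><filename>a.jpg</filename>"
    "<object><name>dog</name><bndbox><xmin>1</xmin><ymin>2</ymin>"
    "<xmax>30</xmax><ymax>40</ymax></bndbox></object></annotation>",
    "<annotation><filename>b.jpg</filename></annotation>",
]


class ToyDetectionDataset:
    """Map-style dataset: implements __len__ and __getitem__."""

    def __init__(
        self,
        annotations: List[str],
        transform: Optional[Callable[[Dict], Dict]] = None,
    ) -> None:
        self.annotations = annotations
        self.transform = transform

    def __len__(self) -> int:
        return len(self.annotations)

    def __getitem__(self, index: int) -> Dict:
        # Parse one annotation and collect the labelled bounding boxes.
        root = ET.fromstring(self.annotations[index])
        objects = []
        for obj in root.iter("object"):
            box = obj.find("bndbox")
            coords = [int(box.findtext(tag)) for tag in ("xmin", "ymin", "xmax", "ymax")]
            objects.append({"label": obj.findtext("name"), "box": coords})
        sample = {"filename": root.findtext("filename"), "objects": objects}
        # Transforms are applied last, just before the sample is returned.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample


dataset = ToyDetectionDataset(ANNOTATIONS)
first = dataset[0]
```

Because the transform is passed in rather than hard-coded, the same dataset class can be reused with different preprocessing for training and evaluation.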
Transform
- Defines how a data point is transformed
- Can be used to normalize the samples, or augment the dataset with more samples
- torchvision contains a collection of standard transforms
- For example, Resize and RandomCrop
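A transform is essentially a callable that takes a data point and returns a modified one. The sketch below illustrates that idea with plain-Python stand-ins that operate on nested lists of pixel values. The classes are simplified illustrations named after their torchvision counterparts (Compose, Normalize, RandomHorizontalFlip), not the torchvision implementations. The seed parameter is an assumption added for reproducibility.

```python
import random
from typing import Callable, List, Sequence

Image = List[List[float]]  # rows of pixel values, a stand-in for a real image


class Compose:
    """Apply a sequence of transforms in order."""

    def __init__(self, transforms: Sequence[Callable[[Image], Image]]) -> None:
        self.transforms = list(transforms)

    def __call__(self, image: Image) -> Image:
        for transform in self.transforms:
            image = transform(image)
        return image


class Normalize:
    """Shift and scale every pixel: (pixel - mean) / std."""

    def __init__(self, mean: float, std: float) -> None:
        self.mean = mean
        self.std = std

    def __call__(self, image: Image) -> Image:
        return [[(p - self.mean) / self.std for p in row] for row in image]


class RandomHorizontalFlip:
    """Mirror the image left to right with probability p (an augmentation)."""

    def __init__(self, p: float = 0.5, seed: int = 0) -> None:
        self.p = p
        self.rng = random.Random(seed)

    def __call__(self, image: Image) -> Image:
        if self.rng.random() < self.p:
            return [row[::-1] for row in image]
        return image


# p=1.0 makes the flip deterministic for this demonstration.
pipeline = Compose([RandomHorizontalFlip(p=1.0), Normalize(mean=0.5, std=0.5)])
result = pipeline([[0.0, 1.0], [0.5, 0.5]])
```

Normalization changes the values of a sample, while a random flip yields a different variant of the same sample each epoch, which is how augmentation effectively enlarges the dataset.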
Model
LightningModule
- Incorporates the model, optimizers, and training and evaluation steps
- As an example, see the Faster R-CNN model
- The forward method defines the forward pass of the network
- configure_optimizers constructs and returns an optimizer
- training_step runs a forward pass and computes the training loss
- validation_step runs a detection and computes the validation score
Training
Trainer
- The fit method trains a model on given data
- Training is controlled using constructor arguments such as max_epochs, precision, gradient_clip_val, gpus, accelerator
- By default, model checkpoints are written to lightning_logs/version_N/checkpoints/, where N is the experiment version
- To resume training from a model checkpoint, pass a file name or URL in the resume_from_checkpoint argument
Parsing Command Line Arguments
PyTorch Lightning provides a mechanism for easily mapping command line arguments to constructor arguments. For example, a Trainer can be constructed in the following way:
from argparse import ArgumentParser
from pytorch_lightning import Trainer

parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
trainer = Trainer.from_argparse_args(args)
The constructor arguments are added to an argparse parser using the add_argparse_args method, and the command line arguments are used to construct a Trainer using the from_argparse_args method. You can add methods that call the add_argparse_args and from_argparse_args functions in any of your classes to add the same functionality. They parse the constructor docstring and type hints, so these must be present.
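To illustrate why the type hints must be present, here is a minimal standard-library sketch of how constructor arguments can be mapped to command line options. It is not Lightning's implementation. ToyTrainer is a hypothetical class, and the sketch only handles simple typed parameters that have defaults.

```python
import inspect
from argparse import ArgumentParser, Namespace


class ToyTrainer:
    """Hypothetical stand-in for a class with typed constructor arguments."""

    def __init__(self, max_epochs: int = 10, gradient_clip_val: float = 0.0) -> None:
        self.max_epochs = max_epochs
        self.gradient_clip_val = gradient_clip_val

    @classmethod
    def add_argparse_args(cls, parser: ArgumentParser) -> ArgumentParser:
        # Each constructor parameter becomes an option; the type hint
        # tells argparse how to convert the command line string.
        for name, param in inspect.signature(cls.__init__).parameters.items():
            if name == "self":
                continue
            parser.add_argument(f"--{name}", type=param.annotation, default=param.default)
        return parser

    @classmethod
    def from_argparse_args(cls, args: Namespace) -> "ToyTrainer":
        # Pick out only the attributes that the constructor accepts.
        names = [n for n in inspect.signature(cls.__init__).parameters if n != "self"]
        return cls(**{n: getattr(args, n) for n in names})


parser = ToyTrainer.add_argparse_args(ArgumentParser())
args = parser.parse_args(["--max_epochs", "3"])  # simulated command line
trainer = ToyTrainer.from_argparse_args(args)
```

Without the int annotation on max_epochs, argparse would have no conversion function and would leave the value as the string "3".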