Getting Started With PyTorch Lightning
PyTorch Lightning Guide
**Below is a shortened version of the PyTorch Lightning documentation. These are the methods I use most in my projects.**
The general pattern is that each loop (training, validation, test loop) has 3 methods:
- `___step`
- `___step_end`
- `___epoch_end`
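The methods in the LightningModule are called in this order:
1. `__init__()`
2. `prepare_data()`
3. `configure_optimizers()`
4. `train_dataloader()`
If you define a validation loop, then:
5. `val_dataloader()`
And if you define a test loop:
6. `test_dataloader()`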
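In every epoch, the loop methods are called with this frequency:
- training_step(), validation_step(), test_step() (and their _step_end() variants): called every batch
- training_epoch_end(), validation_epoch_end(), test_epoch_end(): called every epoch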
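To make the per-batch versus per-epoch frequency concrete, here is a minimal sketch in plain Python. It is not Lightning's actual loop: `ToyModule` and `run_validation_epoch` are hypothetical names I made up, and plain lists of 0/1 values stand in for batches.

```python
class ToyModule:
    """Stand-in for a LightningModule; it records which hooks run."""

    def __init__(self):
        self.calls = []

    def validation_step(self, batch, batch_idx):
        # Runs once per batch.
        self.calls.append("validation_step")
        return {"val_acc": sum(batch) / len(batch)}

    def validation_epoch_end(self, outputs):
        # Runs once per epoch with the list of all step outputs.
        self.calls.append("validation_epoch_end")
        mean_acc = sum(o["val_acc"] for o in outputs) / len(outputs)
        return {"log": {"val_acc": mean_acc}}


def run_validation_epoch(module, dataloader):
    """Hypothetical driver that mimics the step / step_end / epoch_end pattern."""
    outputs = []
    for batch_idx, batch in enumerate(dataloader):
        out = module.validation_step(batch, batch_idx)
        # The optional *_step_end hook sees each step output first, if defined.
        if hasattr(module, "validation_step_end"):
            out = module.validation_step_end(out)
        outputs.append(out)
    return module.validation_epoch_end(outputs)


batches = [[1, 0, 1, 1], [0, 0, 1, 1], [1, 1, 1, 1]]
module = ToyModule()
result = run_validation_epoch(module, batches)
print(module.calls)
print(result)
```

The step hook fires three times (once per batch) and the epoch-end hook fires once, receiving everything the steps returned.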
- def __init__(self):
Define Model Architecture
- def forward(self, x):
Forward pass our data
- def training_step(self, batch, batch_idx): (REQUIRED)
- batch: The Output of your DataLoader. A tensor, tuple or list.
- batch_idx (int): Integer displaying index of this batch
- optimizer_idx (int): The index of the current optimizer. Only used when you have multiple optimizers.
- hiddens (Tensor): Passed in if truncated_bptt_steps > 0
Returns: Dict with loss key and optional log or progress bar keys.
- loss: tensor scalar (required)
- progress_bar: Dict for progress bar display (Tensor)
- log: Dict of metrics to add to the logger
```python
output = {
    'loss': loss,  # required
    'progress_bar': {'training_loss': loss},  # optional, must be tensors
    'log': {'losses': logger_loss},  # optional
}
return output
```
- def training_step_end(batch_parts_outputs): (OPTIONAL)
Use this when training with dp or ddp2, because training_step() will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss.
Parameters: batch_parts_outputs: What you return in training_step() for each batch part.
Return: Dict with loss key and optional log or progress bar keys.
- loss: tensor scalar (required)
- progress_bar: Dict for progress bar display. Must have only tensors.
- log: Dict of metrics to add to logger. Must have only tensors (no images, etc.).
- def training_epoch_end(outputs): (OPTIONAL)
Called at the end of the training epoch with the outputs of all training steps.
- outputs: List of outputs you defined in training_step() or if there are multiple dataloaders, a list containing a list of outputs for each dataloader.
Returns: Dict or OrderedDict. May contain the following optional keys:
- log (metrics to be added to the logger; only tensors)
- any metric used in a callback (e.g. early stopping).
The outputs here are strictly for logging or progress bar. If you don’t need to display anything, don’t return anything.
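Going back to training_step_end() for a moment: the note that it is "only needed for things like softmax or NCE loss" makes more sense with numbers. The sketch below is plain Python with made-up logits and a made-up two-way split; it assumes the softmax normalises across the samples in the batch. Normalising each batch part on its own gives a different answer than normalising the whole batch, which is why the parts have to be gathered first.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up scores for a batch of four samples, split across two devices.
full_batch = [2.0, 1.0, 0.5, 0.1]
part_a, part_b = full_batch[:2], full_batch[2:]

# What each device would compute if it only saw its own part.
per_part = softmax(part_a) + softmax(part_b)

# What you get once the parts are gathered and normalised together.
whole = softmax(full_batch)

print("per part:", [round(p, 3) for p in per_part])
print("whole   :", [round(p, 3) for p in whole])
```

Each part sums to 1 on its own, so the per-part result sums to 2 across the batch, while the gathered version sums to 1.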
- def configure_optimizers(self): (REQUIRED)
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple.
Returns:
- Single optimizer
- List or Tuple: list of optimizers
```python
from torch.optim import Adam

# most cases: a single optimizer
def configure_optimizers(self):
    opt = Adam(self.parameters(), lr=1e-3)
    return opt

# multiple optimizers (e.g. a GAN)
def configure_optimizers(self):
    generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
    discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
    return generator_opt, discriminator_opt
```
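When two optimizers are returned, the optimizer_idx argument of training_step() tells you which one the current call belongs to. The sketch below is a torch-free illustration of that idea, as I understand it, and not Lightning's real training loop: `ToyGAN` and `run_epoch` are hypothetical, strings stand in for optimizers, and the "losses" are placeholders.

```python
class ToyGAN:
    """Stand-in for a LightningModule with two optimizers (no torch needed)."""

    def configure_optimizers(self):
        # Plain strings stand in for real optimizer objects.
        return "generator_opt", "discriminator_opt"

    def training_step(self, batch, batch_idx, optimizer_idx):
        if optimizer_idx == 0:
            loss = sum(batch)    # placeholder "generator" loss
        else:
            loss = -sum(batch)   # placeholder "discriminator" loss
        return {"loss": loss}


def run_epoch(module, dataloader):
    """Hypothetical driver: one training_step call per optimizer per batch."""
    optimizers = module.configure_optimizers()
    history = []
    for batch_idx, batch in enumerate(dataloader):
        for optimizer_idx, optimizer in enumerate(optimizers):
            out = module.training_step(batch, batch_idx, optimizer_idx)
            history.append((batch_idx, optimizer, out["loss"]))
    return history


history = run_epoch(ToyGAN(), [[1, 2], [3, 4]])
for row in history:
    print(row)
```

Branching on optimizer_idx inside training_step() is what lets one method serve both the generator and the discriminator update.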
- def validation_step(batch, batch_idx, dataloader_idx): (OPTIONAL)
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, like accuracy.
- batch: The output of your DataLoader.
- batch_idx (int): The index of this batch
- dataloader_idx (int): The index of the dataloader that produced this batch (only if multiple val datasets used)
Return: Dict or OrderedDict, passed to validation_epoch_end(). If you defined validation_step_end(), it will go to that first.
- def validation_step_end(batch_parts_outputs): (OPTIONAL)
Use this when validating with dp or ddp2, because validation_step() will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss.
Parameters: batch_parts_outputs: What you return in validation_step() for each batch part.
Return: Dict or OrderedDict, passed to the validation_epoch_end() method.
- def validation_epoch_end(outputs): (OPTIONAL)
Called at the end of the validation epoch with the outputs of all validation steps.
- outputs: List of outputs you defined in validation_step(), or if there are multiple dataloaders, a list containing a list of outputs for each dataloader.
Returns: Dict or OrderedDict. May have the following optional keys:
- progress_bar (dict for progress bar display; only tensors)
- log (dict of metrics to add to logger; only tensors).
```python
def validation_epoch_end(self, outputs):
    # average the per-batch accuracy returned by validation_step()
    val_acc_mean = 0
    for output in outputs:
        val_acc_mean += output['val_acc']
    val_acc_mean /= len(outputs)
    tqdm_dict = {'val_acc': val_acc_mean.item()}

    # show val_acc in the progress bar and send it to the logger
    results = {
        'progress_bar': tqdm_dict,
        'log': {'val_acc': val_acc_mean.item()}
    }
    return results
```
- def test_step(batch, batch_idx, dataloader_idx): (OPTIONAL)
Operates on a single batch of data from the test set. In this step you’d normally generate examples or calculate anything of interest such as accuracy.
- batch: The output of your DataLoader.
- batch_idx (int): The index of this batch
- dataloader_idx (int): The index of the dataloader that produced this batch (only if multiple test datasets used).
Return: Dict or OrderedDict, passed to the test_epoch_end() method. If you defined test_step_end(), it will go to that first.
- def test_step_end(batch_parts_outputs): (OPTIONAL)
Use this when testing with dp or ddp2, because test_step() will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss.
Parameters: batch_parts_outputs: What you return in test_step() for each batch part.
Return: Dict or OrderedDict, passed to test_epoch_end().
- def test_epoch_end(outputs): (OPTIONAL)
Called at the end of a test epoch with the output of all test steps.
- outputs: List of outputs you defined in test_step(), or if there are multiple dataloaders, a list containing a list of outputs for each dataloader.
Returns: Dict or OrderedDict. May have the following optional keys:
- progress_bar (dict for progress bar display; only tensors)
- log (dict of metrics to add to logger; only tensors).
```python
def test_epoch_end(self, outputs):
    # average the per-batch accuracy returned by test_step()
    test_acc_mean = 0
    for output in outputs:
        test_acc_mean += output['test_acc']
    test_acc_mean /= len(outputs)
    tqdm_dict = {'test_acc': test_acc_mean.item()}

    # show test_acc in the progress bar and send it to the logger
    results = {
        'progress_bar': tqdm_dict,
        'log': {'test_acc': test_acc_mean.item()}
    }
    return results
```
- def prepare_data(self): (OPTIONAL)
Use this to download and prepare data. In distributed (GPU, TPU), this will only be called once. This is called before requesting the dataloaders:
```python
def prepare_data(self):
    download_imagenet()
    clean_imagenet()
    cache_imagenet()
```
- def train_dataloader(self): (Required)
Implement a PyTorch DataLoader for training.
- Single PyTorch DataLoader
- def val_dataloader(self): (Optional)
Implement a PyTorch DataLoader for validation.
- Single PyTorch DataLoader
- def test_dataloader(self): (Optional)
Implement a PyTorch DataLoader for Testing.
- Single PyTorch DataLoader
For more info visit:
```python
from pytorch_lightning import Trainer

model = LitMNIST()
trainer = Trainer(gpus=1)
```
max_epochs: Stop training once this number of epochs is reached. Default: 1000.
min_epochs: Force training for at least this many epochs.
max_steps: Stop training after this number of steps. Disabled by default (None).
min_steps: Force training for at least this number of steps. Disabled by default (None).
logger: Logger (or iterable collection of loggers) for experiment tracking.
checkpoint_callback: Callback for checkpointing.
early_stop_callback: Callback for early stopping (pytorch_lightning.callbacks.EarlyStopping).
callbacks: Add a list of callbacks.
default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed
gradient_clip_val: 0 means don’t clip.
process_position: Orders the progress bar when running multiple models on the same machine.
num_nodes: Number of GPU nodes for distributed training.
gpus: Which GPUs to train on.
auto_select_gpus: If enabled and gpus is an integer, pick available GPUs automatically. This is especially useful when GPUs are configured to be in “exclusive mode”, such that only one process at a time can access them.
num_tpu_cores: How many TPU cores to train on (1 or 8).
progress_bar_refresh_rate: How often to refresh the progress bar (in steps). A value of 0 disables the progress bar. Ignored when a custom callback is passed to callbacks.
check_val_every_n_epoch: Check val every n train epochs.
fast_dev_run: Runs 1 batch of train, test and val to find any bugs (i.e. a sort of unit test).
accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.
train_percent_check: How much of training dataset to check.
val_percent_check: How much of validation dataset to check.
test_percent_check: How much of test dataset to check.
val_check_interval: How often within one training epoch to check the validation set
log_save_interval: Writes logs to disk this often
row_log_interval: How often to add logging rows (does not write to disk)
distributed_backend: The distributed backend to use.
precision: Full precision (32), half precision (16).
weights_summary: Prints a summary of the weights when training begins.
weights_save_path: Where to save weights, if specified. Overrides default_root_dir for checkpoints only. Use this if, for whatever reason, you need the checkpoints stored in a different place than the logs written in default_root_dir.
amp_level: The optimization level to use (O1, O2, etc…).
num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.
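A few of these flags are easier to reason about as arithmetic. The numbers below are made up, and the rounding is my assumption (truncation, with val_check_interval applied to the reduced batch count), so treat this as a back-of-the-envelope sketch and not as the Trainer's exact internals.

```python
# All numbers below are made up for illustration.
num_train_batches = 1000       # batches the train DataLoader would yield
batch_size = 32                # samples per batch

train_percent_check = 0.25     # use a quarter of the training batches
val_check_interval = 0.5       # check validation twice per training epoch
accumulate_grad_batches = 4    # accumulate gradients over 4 batches

# Assumption: fractions are truncated to whole batches.
batches_used = int(num_train_batches * train_percent_check)
val_every_n_batches = int(batches_used * val_check_interval)

# Gradients from 4 batches are summed before each optimizer step,
# so each step effectively sees 4 batches' worth of samples.
effective_batch_size = batch_size * accumulate_grad_batches

print(batches_used, val_every_n_batches, effective_batch_size)
```

With these values, an epoch covers 250 batches, validation runs every 125 batches, and each optimizer step reflects 128 samples.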
Early Stopping
Stop training when a monitored quantity has stopped improving.
- monitor (str): quantity to be monitored. Default: ‘val_loss’.
- min_delta (float) – minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta, will count as no improvement. Default: 0.
- patience (int) – number of epochs with no improvement after which training will be stopped. Default: 0.
- verbose (bool) – verbosity mode. Default: False.
- mode (str) – one of {auto, min, max}. In min mode, training will stop when the quantity monitored has stopped decreasing; in max mode it will stop when the quantity monitored has stopped increasing; in auto mode, the direction is automatically inferred from the name of the monitored quantity. Default: ‘auto’.
- strict (bool) – whether to crash the training if monitor is not found in the metrics. Default: True.
```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

early_stopping = EarlyStopping('val_loss')
trainer = Trainer(early_stop_callback=early_stopping)
```
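How min_delta, patience and mode interact can be sketched in plain Python. This is my own minimal reading of the parameter descriptions above, not the library's source: `EarlyStoppingSketch` is a hypothetical class, it leaves out the auto mode, and the loss values are invented.

```python
class EarlyStoppingSketch:
    """Hypothetical re-implementation of the stopping rule described above."""

    def __init__(self, min_delta=0.0, patience=0, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError("this sketch only supports 'min' and 'max'")
        self.min_delta = min_delta
        self.patience = patience
        self.mode = mode
        self.best = None   # best monitored value seen so far
        self.wait = 0      # consecutive epochs without improvement

    def _improved(self, current):
        if self.best is None:
            return True
        if self.mode == "min":
            return current < self.best - self.min_delta
        return current > self.best + self.min_delta

    def should_stop(self, current):
        """Call once per epoch with the monitored value."""
        if self._improved(current):
            self.best = current
            self.wait = 0
            return False
        self.wait += 1
        return self.wait >= self.patience


stopper = EarlyStoppingSketch(min_delta=0.05, patience=2, mode="min")
val_losses = [1.0, 0.8, 0.79, 0.81, 0.7]  # invented values

stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    if stopper.should_stop(val_loss):
        stopped_at = epoch
        break

print("stopped at epoch:", stopped_at)
```

Here 0.79 and 0.81 do not beat the best value (0.8) by at least min_delta, so after two such epochs the loop stops before it ever sees 0.7.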
Model Checkpointing
Automatically save model checkpoints during training.
pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint(filepath=None, monitor='val_loss', verbose=False, save_top_k=1, save_weights_only=False, mode='auto', period=1, prefix='')
- filepath (Optional[str]): Path to save the model file. Can contain named formatting options to be auto-filled.
```python
# saves a file like: my/path/epoch=2-val_loss=0.2_other_metric=0.3.ckpt
checkpoint_callback = ModelCheckpoint(
    filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}'
)
```
- monitor (str): quantity to monitor.
- verbose (bool): verbosity mode. Default: False.
- save_top_k (int): if save_top_k == k, the best k models according to the quantity monitored will be saved. if save_top_k == 0, no models are saved. if save_top_k == -1, all models are saved.
- mode (str): one of {auto, min, max}. If save_top_k != 0, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For val_acc, this should be max, for val_loss this should be min, etc. In auto mode, the direction is automatically inferred from the name of the monitored quantity.
- save_weights_only (bool): if True, then only the model’s weights will be saved (model.save_weights(filepath)), else the full model is saved.
- period (int): Interval (number of epochs) between checkpoints.
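The save_top_k and mode rules can be pictured as keeping a small leaderboard of checkpoints. The sketch below is plain Python and only tracks file names and scores; `TopKTracker` is a hypothetical helper of mine, it does not write any files, and it is not how the callback is actually implemented.

```python
class TopKTracker:
    """Hypothetical bookkeeping for the save_top_k / mode rules above."""

    def __init__(self, save_top_k=1, mode="min"):
        self.save_top_k = save_top_k
        self.mode = mode
        self.kept = {}  # checkpoint path -> monitored score

    def update(self, path, score):
        """Return True if this checkpoint would be kept."""
        if self.save_top_k == 0:       # save nothing
            return False
        if self.save_top_k == -1 or len(self.kept) < self.save_top_k:
            self.kept[path] = score    # save everything / still room
            return True
        # Find the worst checkpoint currently kept.
        pick_worst = max if self.mode == "min" else min
        worst_path = pick_worst(self.kept, key=self.kept.get)
        worst_score = self.kept[worst_path]
        if self.mode == "min":
            better = score < worst_score
        else:
            better = score > worst_score
        if not better:
            return False
        # Replace the worst kept checkpoint with the new one.
        del self.kept[worst_path]
        self.kept[path] = score
        return True


tracker = TopKTracker(save_top_k=2, mode="min")
val_losses = [0.9, 0.7, 0.8, 0.95, 0.6]  # invented values
for epoch, val_loss in enumerate(val_losses):
    tracker.update(f"epoch={epoch}.ckpt", val_loss)

print(sorted(tracker.kept))
```

With save_top_k=2 and mode min, only the two lowest-loss epochs survive; switching to mode max would keep the two highest instead.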
```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# saves checkpoints to 'my/path/' whenever 'val_loss' has a new min
checkpoint_callback = ModelCheckpoint(filepath='my/path/')
trainer = Trainer(checkpoint_callback=checkpoint_callback)
```
# save epoch and val_loss in name
# saves a file like: my/path/sample-mnist_epoch=02_val_loss=0.32.ckpt
checkpoint_callback = ModelCheckpoint(
from pytorch_lightning.callbacks import ModelCheckpoint
# DEFAULTS used by the Trainer
checkpoint_callback = ModelCheckpoint(
trainer = Trainer(checkpoint_callback=checkpoint_callback)
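To get a feel for the named formatting options in filepath, here is a small stand-alone sketch of how a template could be expanded into a file name. `format_checkpoint_name` is a hypothetical helper, the template string is my own guess chosen to reproduce the sample-mnist file name mentioned above, and the metric values are invented. The real callback may format names differently.

```python
import re


def format_checkpoint_name(template, metrics, extension=".ckpt"):
    """Hypothetical helper: expand '{name:spec}' into 'name=<value>'."""
    # Turn '{epoch:02d}' into 'epoch={epoch:02d}' so the metric name
    # ends up in the file name next to its formatted value.
    labelled = re.sub(r"\{(\w+)", r"\1={\1", template)
    return labelled.format(**metrics) + extension


name = format_checkpoint_name(
    "my/path/sample-mnist_{epoch:02d}_{val_loss:.2f}",  # assumed template
    {"epoch": 2, "val_loss": 0.3241},                   # invented metrics
)
print(name)
```

The format specs do the usual Python work here: `02d` zero-pads the epoch and `.2f` rounds the loss to two decimals.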
Restoring Training State
If you don’t just want to load weights, but instead restore the full training, do the following:
```python
model = LitModel()
trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')
# automatically restores model, epoch, step, LR schedulers, apex, etc...
```