Getting Started With PyTorch Lightning
PyTorch Lightning Guide
**Below is documentation from PyTorch Lightning that I've shortened. These are the methods I use most often in my projects.**
The general pattern is that each loop (training, validation, test) has 3 methods:
- ___step
- ___step_end
- ___epoch_end
Lifecycle
The methods in the LightningModule are called in this order: __init__, prepare_data, configure_optimizers, then train_dataloader. If you define a validation loop, val_dataloader is also called, and if you define a test loop, test_dataloader is called as well.
In every epoch, the loop methods are called with this frequency:
- validation_step() is called every batch
- validation_epoch_end() is called every epoch
LightningModule
- def __init__(self):
  Define the model architecture.
- def forward(self, x):
  Forward pass of our data.
- def training_step(self, batch, batch_idx): (REQUIRED)
  Parameters:
  - batch: The output of your DataLoader. A tensor, tuple, or list.
  - batch_idx (int): Integer displaying the index of this batch.
  - optimizer_idx (int): Used when multiple optimizers are configured.
  - hiddens (Tensor): Passed in if truncated_bptt_steps > 0.
  Returns: Dict with a loss key and optional log or progress bar keys.
  - loss: tensor scalar (required)
  - progress_bar: Dict for progress bar display (tensors only)
  - log: Dict of metrics to add to the logger (tensors only)

```python
output = {
    'loss': loss,                              # required
    'progress_bar': {'training_loss': loss},   # optional, must be tensors
    'log': {'losses': logger_loss}             # optional, metrics for the logger
}
return output
```
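To make the pieces above concrete, here is a minimal sketch that ties __init__, forward, and training_step together. It assumes a toy MNIST-style classifier; the layer sizes, loss, and the LitClassifier name are illustrative, not from the original docs.

```python
import torch
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl


class LitClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # define the model architecture in __init__
        self.layer = nn.Linear(28 * 28, 10)

    def forward(self, x):
        # forward pass of our data
        return self.layer(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # return the dict format described above
        return {
            'loss': loss,                             # required
            'progress_bar': {'training_loss': loss},  # optional, tensors only
            'log': {'train_loss': loss},              # optional, tensors only
        }
```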
- def training_step_end(self, batch_parts_outputs): (OPTIONAL)
  Use this when training with dp or ddp2, because training_step() will operate on only part of the batch. This is optional and only needed for things like softmax or NCE loss.
  Parameters:
  - batch_parts_outputs: What you return in training_step() for each batch part.
  Returns: Dict with a loss key and optional log or progress bar keys.
  - loss: tensor scalar (required)
  - progress_bar: Dict for progress bar display (tensors only)
  - log: Dict of metrics to add to the logger (tensors only; no images, etc.)
- def training_epoch_end(self, outputs): (OPTIONAL)
  Called at the end of the training epoch with the outputs of all training steps.
  Parameters:
  - outputs: List of outputs you defined in training_step(), or if there are multiple dataloaders, a list containing a list of outputs for each dataloader.
  Returns: Dict or OrderedDict. May contain the following optional keys:
  - log (metrics to be added to the logger; tensors only)
  - any metric used in a callback (e.g. early stopping)
  The outputs here are strictly for logging or the progress bar. If you don't need to display anything, don't return anything.
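Continuing the hypothetical classifier sketched above (and assuming torch is imported), a training_epoch_end() that averages the per-step losses and logs them might look like:

```python
def training_epoch_end(self, outputs):
    # 'outputs' is the list of dicts returned by training_step over the epoch
    avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
    # return only what you want logged; otherwise return nothing
    return {'log': {'train_loss_epoch': avg_loss}}
```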
- def configure_optimizers(self): (REQUIRED)
  Choose which optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one, but in the case of GANs or similar you might have multiple.
  Returns:
  - Single Optimizer
  - List or Tuple of Optimizers

```python
def configure_optimizers(self):
    # most cases
    opt = Adam(self.parameters(), lr=1e-3)
    return opt
```

```python
def configure_optimizers(self):
    # multiple optimizers (e.g. for a GAN)
    generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
    discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
    return generator_opt, discriminator_opt
```
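The description above also mentions learning-rate schedulers. One common pattern, shown here as a sketch (the scheduler choice and hyperparameters are illustrative), is to return two lists: optimizers first, schedulers second.

```python
import torch
from torch.optim import Adam

def configure_optimizers(self):
    # one optimizer paired with one LR scheduler, returned as two lists
    opt = Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.1)
    return [opt], [scheduler]
```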
- def validation_step(self, batch, batch_idx, dataloader_idx): (OPTIONAL)
  Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, such as accuracy.
  Parameters:
  - batch: The output of your DataLoader.
  - batch_idx (int): The index of this batch.
  - dataloader_idx (int): The index of the dataloader that produced this batch (only if multiple val datasets are used).
  Returns: Dict or OrderedDict - passed to validation_epoch_end(). If you defined validation_step_end(), it will go to that first.
- def validation_step_end(self, batch_parts_outputs): (OPTIONAL)
  Use this when validating with dp or ddp2, because validation_step() will operate on only part of the batch. This is optional and only needed for things like softmax or NCE loss.
  Parameters:
  - batch_parts_outputs: What you return in validation_step() for each batch part.
  Returns: Dict or OrderedDict - passed to the validation_epoch_end() method.
- def validation_epoch_end(self, outputs): (OPTIONAL)
Called at the end of the validation epoch with the outputs of all validation steps.
  Parameters:
  - outputs: List of outputs you defined in validation_step(), or if there are multiple dataloaders, a list containing a list of outputs for each dataloader.
  Returns: Dict or OrderedDict. May have the following optional keys:
  - progress_bar (dict for progress bar display; tensors only)
  - log (dict of metrics to add to the logger; tensors only)

```python
def validation_epoch_end(self, outputs):
    val_acc_mean = 0
    for output in outputs:
        val_acc_mean += output['val_acc']
    val_acc_mean /= len(outputs)
    tqdm_dict = {'val_acc': val_acc_mean.item()}

    # show val_acc in the progress bar and log it
    results = {
        'progress_bar': tqdm_dict,
        'log': {'val_acc': val_acc_mean.item()}
    }
    return results
```
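The example above assumes each validation_step() returns a 'val_acc' key. Continuing the hypothetical classifier from the training_step sketch (F is torch.nn.functional), such a step might look like this:

```python
def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    val_loss = F.cross_entropy(logits, y)
    # accuracy as a tensor so it can be averaged in validation_epoch_end
    val_acc = (logits.argmax(dim=1) == y).float().mean()
    return {'val_loss': val_loss, 'val_acc': val_acc}
```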
- def test_step(self, batch, batch_idx, dataloader_idx): (OPTIONAL)
  Operates on a single batch of data from the test set. In this step you'd normally generate examples or calculate anything of interest, such as accuracy.
  Parameters:
  - batch: The output of your DataLoader.
  - batch_idx (int): The index of this batch.
  - dataloader_idx (int): The index of the dataloader that produced this batch (only if multiple test datasets are used).
  Returns: Dict or OrderedDict - passed to the test_epoch_end() method. If you defined test_step_end(), it will go to that first.
- def test_step_end(self, batch_parts_outputs): (OPTIONAL)
  Use this when testing with dp or ddp2, because test_step() will operate on only part of the batch. This is optional and only needed for things like softmax or NCE loss.
  Parameters:
  - batch_parts_outputs: What you return in test_step() for each batch part.
  Returns: Dict or OrderedDict - passed to the test_epoch_end() method.
- def test_epoch_end(self, outputs): (OPTIONAL)
Called at the end of a test epoch with the output of all test steps.
  Parameters:
  - outputs: List of outputs you defined in test_step_end(), or if there are multiple dataloaders, a list containing a list of outputs for each dataloader.
  Returns: Dict or OrderedDict. May have the following optional keys:
  - progress_bar (dict for progress bar display; tensors only)
  - log (dict of metrics to add to the logger; tensors only)

```python
def test_epoch_end(self, outputs):
    test_acc_mean = 0
    for output in outputs:
        test_acc_mean += output['test_acc']
    test_acc_mean /= len(outputs)
    tqdm_dict = {'test_acc': test_acc_mean.item()}

    # show test_acc in the progress bar and log it
    results = {
        'progress_bar': tqdm_dict,
        'log': {'test_acc': test_acc_mean.item()}
    }
    return results
```
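As with validation, the epoch-end example above expects a 'test_acc' key from each test_step(). A minimal sketch of such a step, still assuming the hypothetical classifier from earlier:

```python
def test_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    # per-batch accuracy, aggregated later in test_epoch_end
    test_acc = (logits.argmax(dim=1) == y).float().mean()
    return {'test_acc': test_acc}
```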
- def prepare_data(self): (OPTIONAL)
  Use this to download and prepare data. In distributed settings (GPU, TPU), this will only be called once. It is called before the dataloaders are requested:

```python
def prepare_data(self):
    download_imagenet()
    clean_imagenet()
    cache_imagenet()
```
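The helpers above (download_imagenet(), etc.) are placeholders. As a concrete sketch, prepare_data() might look like this with torchvision's MNIST (the data path is illustrative):

```python
from torchvision import datasets, transforms

def prepare_data(self):
    # download once here; build datasets/loaders in the dataloader methods
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.ToTensor())
    datasets.MNIST('./data', train=False, download=True,
                   transform=transforms.ToTensor())
```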
- def train_dataloader(self): (REQUIRED)
  Implement a PyTorch DataLoader for training.
  Returns:
  - Single PyTorch DataLoader
- def val_dataloader(self): (OPTIONAL)
  Implement a PyTorch DataLoader for validation.
  Returns:
  - Single PyTorch DataLoader
- def test_dataloader(self): (OPTIONAL)
  Implement a PyTorch DataLoader for testing.
  Returns:
  - Single PyTorch DataLoader
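For illustration, here is a sketch of train_dataloader() and val_dataloader() that reuses the MNIST data downloaded in the prepare_data() sketch above (the batch size and path are arbitrary):

```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def train_dataloader(self):
    # a single PyTorch DataLoader for training
    train_ds = datasets.MNIST('./data', train=True,
                              transform=transforms.ToTensor())
    return DataLoader(train_ds, batch_size=64, shuffle=True)

def val_dataloader(self):
    # a single PyTorch DataLoader for validation
    val_ds = datasets.MNIST('./data', train=False,
                            transform=transforms.ToTensor())
    return DataLoader(val_ds, batch_size=64)
```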
Trainer
For more info visit: https://pytorch-lightning.readthedocs.io/en/latest/trainer.html#trainer-class
```python
from pytorch_lightning import Trainer

model = LitMNIST()
trainer = Trainer(gpus=1)
trainer.fit(model)
```
Args:
max_epochs: Stop training once this number of epochs is reached. Default: 1000.
min_epochs: Force training for at least this many epochs.
max_steps: Stop training after this number of steps. Disabled by default (None).
min_steps: Force training for at least this number of steps. Disabled by default (None).
logger: Logger (or iterable collection of loggers) for experiment tracking.
checkpoint_callback: Callback for checkpointing.
early_stop_callback: An EarlyStopping callback (pytorch_lightning.callbacks.EarlyStopping).
callbacks: Add a list of callbacks.
default_root_dir: Default path for logs and weights when no logger/ckpt_callback is passed.
gradient_clip_val: 0 means don't clip.
process_position: Orders the progress bar when running multiple models on the same machine.
num_nodes: Number of GPU nodes for distributed training.
gpus: Which GPUs to train on.
auto_select_gpus: If enabled and gpus is an integer, pick available GPUs automatically. This is especially useful when GPUs are configured to be in "exclusive mode", such that only one process at a time can access them.
num_tpu_cores: How many TPU cores to train on (1 or 8).
progress_bar_refresh_rate: How often to refresh the progress bar (in steps). A value of 0 disables the progress bar. Ignored when a custom callback is passed to Trainer.callbacks.
check_val_every_n_epoch: Check val every n train epochs.
fast_dev_run: Runs 1 batch of train, test and val to find any bugs (i.e. a sort of unit test).
accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.
train_percent_check: How much of the training dataset to check.
val_percent_check: How much of the validation dataset to check.
test_percent_check: How much of the test dataset to check.
val_check_interval: How often within one training epoch to check the validation set.
log_save_interval: Writes logs to disk this often.
row_log_interval: How often to add logging rows (does not write to disk).
distributed_backend: The distributed backend to use.
precision: Full precision (32) or half precision (16).
weights_summary: Prints a summary of the weights when training begins.
weights_save_path: Where to save weights if specified. Overrides default_root_dir for checkpoints only. Use this if for whatever reason you need the checkpoints stored in a different place than the logs written in default_root_dir.
amp_level: The optimization level to use (O1, O2, etc.).
num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.
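To tie a few of these arguments together, here is an illustrative sketch (the specific values are arbitrary, not recommendations; model is assumed to be the LitMNIST instance from the quickstart above):

```python
from pytorch_lightning import Trainer

trainer = Trainer(
    gpus=1,                      # train on one GPU
    max_epochs=50,               # stop after at most 50 epochs
    gradient_clip_val=0.5,       # clip gradients at 0.5
    accumulate_grad_batches=4,   # accumulate grads over 4 batches
    precision=16,                # half precision
    check_val_every_n_epoch=1,   # run the validation loop every epoch
    num_sanity_val_steps=2,      # sanity-check 2 val batches before training
)
trainer.fit(model)
```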
Callbacks
Early Stopping
Stop training when a monitored quantity has stopped improving.
Parameters:
- monitor (str): quantity to be monitored. Default: ‘val_loss’.
- min_delta (float) – minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta will count as no improvement. Default: 0.
- patience (int) – number of epochs with no improvement after which training will be stopped. Default: 0.
- verbose (bool) – verbosity mode. Default: False.
- mode (str) – one of {auto, min, max}. In min mode, training will stop when the quantity monitored has stopped decreasing; in max mode it will stop when the quantity monitored has stopped increasing; in auto mode, the direction is automatically inferred from the name of the monitored quantity. Default: ‘auto’.
- strict (bool) – whether to crash the training if monitor is not found in the metrics. Default: True.
```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

early_stopping = EarlyStopping('val_loss')
trainer = Trainer(early_stop_callback=early_stopping)
```
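A fuller configuration using the parameters listed above might look like this sketch (the thresholds are illustrative):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

early_stopping = EarlyStopping(
    monitor='val_loss',   # quantity to watch
    min_delta=0.001,      # smaller changes don't count as improvement
    patience=3,           # stop after 3 epochs without improvement
    mode='min',           # lower val_loss is better
    verbose=True,
)
trainer = Trainer(early_stop_callback=early_stopping)
```

For the monitor to be found, 'val_loss' needs to appear in the metrics your validation loop returns (e.g. in the log dict of validation_epoch_end()).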
Model Checkpointing
Automatically save model checkpoints during training.
pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint(filepath=None, monitor='val_loss', verbose=False, save_top_k=1, save_weights_only=False, mode='auto', period=1, prefix='')
Parameters:
- filepath (Optional[str]): Path to save the model file. Can contain named formatting options to be auto-filled.

```python
# saves a file like: my/path/epoch=2-val_loss=0.2_other_metric=0.3.ckpt
checkpoint_callback = ModelCheckpoint(
    filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}'
)
```
- monitor (str): quantity to monitor.
- verbose (bool): verbosity mode. Default: False.
- save_top_k (int): if save_top_k == k, the best k models according to the quantity monitored will be saved. if save_top_k == 0, no models are saved. if save_top_k == -1, all models are saved.
- mode (str): one of {auto, min, max}. If save_top_k != 0, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For val_acc, this should be max, for val_loss this should be min, etc. In auto mode, the direction is automatically inferred from the name of the monitored quantity.
- save_weights_only (bool): if True, then only the model’s weights will be saved (model.save_weights(filepath)), else the full model is saved (model.save(filepath)).
- period (int): Interval (number of epochs) between checkpoints.
```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# saves checkpoints to 'my/path/' whenever 'val_loss' has a new min
checkpoint_callback = ModelCheckpoint(filepath='my/path/')
trainer = Trainer(checkpoint_callback=checkpoint_callback)

# save epoch and val_loss in the name
# saves a file like: my/path/sample-mnist_epoch=02_val_loss=0.32.ckpt
checkpoint_callback = ModelCheckpoint(
    filepath='my/path/sample-mnist_{epoch:02d}-{val_loss:.2f}')
```
```python
import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# DEFAULTS used by the Trainer
checkpoint_callback = ModelCheckpoint(
    filepath=os.getcwd(),
    save_top_k=1,
    verbose=True,
    monitor='val_loss',
    mode='min',
    prefix=''
)
trainer = Trainer(checkpoint_callback=checkpoint_callback)
```
Restoring Training State
If you don't want to just load the weights, but instead restore the full training state, do the following:
```python
model = LitModel()
trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')

# automatically restores model, epoch, step, LR schedulers, apex, etc...
trainer.fit(model)
```
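If you only need the weights rather than the full training state, a sketch using LightningModule.load_from_checkpoint (reusing the path from the example above) would be:

```python
# loads the model weights from a checkpoint without restoring the trainer state
model = LitModel.load_from_checkpoint('some/path/to/my_checkpoint.ckpt')
```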