probnmn.utils.checkpointing

class probnmn.utils.checkpointing.CheckpointManager(serialization_dir: str = '/tmp', keep_recent: int = 10, **checkpointables: Any)[source]

Bases: object

A CheckpointManager periodically serializes models and other checkpointable objects (any object that implements a state_dict method) as .pth files during training, and optionally keeps track of the best-performing checkpoint based on an observed metric.

This class closely follows the API of PyTorch optimizers and learning rate schedulers.

Note

For DataParallel and DistributedDataParallel objects, module.state_dict is called instead of state_dict.
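The unwrapping described in this note can be sketched without PyTorch. This is a minimal illustration, not the library's code: TinyModel, ParallelWrapper and unwrapped_state_dict are hypothetical stand-ins. The only PyTorch behaviour assumed is that a parallel wrapper keeps the wrapped model under .module and prefixes its own state dict keys with "module.".

```python
from typing import Any, Dict


class TinyModel:
    """Hypothetical stand-in for a model that exposes state_dict()."""

    def __init__(self, weight: float) -> None:
        self.weight = weight

    def state_dict(self) -> Dict[str, Any]:
        return {"weight": self.weight}


class ParallelWrapper:
    """Hypothetical stand-in for DataParallel / DistributedDataParallel.

    It holds the wrapped model under `.module` and prefixes its keys.
    """

    def __init__(self, module: TinyModel) -> None:
        self.module = module

    def state_dict(self) -> Dict[str, Any]:
        return {f"module.{k}": v for k, v in self.module.state_dict().items()}


def unwrapped_state_dict(obj: Any) -> Dict[str, Any]:
    """Return the state dict of the wrapped model when `obj` is a wrapper."""
    if isinstance(obj, ParallelWrapper):
        return obj.module.state_dict()
    return obj.state_dict()


model = TinyModel(1.5)
wrapped = ParallelWrapper(model)

plain_keys = sorted(unwrapped_state_dict(model))
wrapped_keys = sorted(wrapped.state_dict())
unwrapped_keys = sorted(unwrapped_state_dict(wrapped))
```

Saving the unwrapped state dict keeps key names free of the "module." prefix, so the checkpoint can later be loaded into a model that is not wrapped.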

Note

The observed metric used to select the best checkpoint is assumed to be “higher is better”. For a “lower is better” metric such as a loss, flip its sign before passing it.

Parameters
serialization_dir: str

Path to an empty or non-existent directory to save checkpoints.

keep_recent: int, optional (default=10)

Number of most recent checkpoints to keep on disk. Older checkpoints are removed. Set this to a very large value to keep all checkpoints.

checkpointables: Any

Keyword arguments naming any checkpointable objects, for example a model, an optimizer, or a learning rate scheduler. The state dict of each object is accessible under the name of its keyword.

Examples

>>> model = torch.nn.Linear(10, 2)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> ckpt_manager = CheckpointManager("/tmp/ckpt", model=model, optimizer=optimizer)
>>> num_epochs = 20
>>> for epoch in range(num_epochs):
...     train(model)
...     val_loss = validate(model)
...     ckpt_manager.step(epoch, -val_loss)
step(self, iteration:int, metric:Union[float, NoneType]=None)[source]

Serialize a checkpoint for the given iteration and, if a metric is provided, update the best checkpoint based on it.
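The "higher is better" bookkeeping can be sketched as follows. This is an illustration under stated assumptions rather than the actual implementation: BestMetricTracker and its attributes are hypothetical names, and the real method also writes files to disk.

```python
import math
from typing import Optional


class BestMetricTracker:
    """Hypothetical helper that remembers the highest metric seen so far."""

    def __init__(self) -> None:
        self.best_metric = -math.inf
        self.best_iteration = -1

    def update(self, iteration: int, metric: Optional[float] = None) -> bool:
        """Return True if `metric` is a new best; ignore a missing metric."""
        if metric is None:
            return False
        if metric > self.best_metric:
            self.best_metric = metric
            self.best_iteration = iteration
            return True
        return False


tracker = BestMetricTracker()
val_losses = [0.9, 0.5, 0.7]  # made-up validation losses, lower is better
improved = [
    # Negate the loss so that a lower loss counts as a higher metric.
    tracker.update(iteration, -loss)
    for iteration, loss in enumerate(val_losses)
]
```

With these made-up losses, the best checkpoint corresponds to iteration 1, the one with the lowest loss.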

_state_dict(self)[source]

Return a dict containing the state dicts of all checkpointables.

remove_earliest_checkpoint(self)[source]

Remove the earliest serialized checkpoint from disk.
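One way to picture how keep_recent and this method interact is a rolling window of checkpoint paths. The sketch below is an assumption-laden illustration: save_with_rotation, the checkpoint_{iteration}.pth file name and the empty placeholder files are hypothetical and not taken from the library.

```python
import collections
import tempfile
from pathlib import Path
from typing import Deque


def save_with_rotation(
    directory: Path, iteration: int, recent: Deque[Path], keep_recent: int
) -> Path:
    """Write a placeholder checkpoint, then drop the oldest ones beyond the limit."""
    path = directory / f"checkpoint_{iteration}.pth"  # hypothetical naming scheme
    path.write_bytes(b"")  # placeholder for real serialized state
    recent.append(path)
    while len(recent) > keep_recent:
        # The leftmost entry is the earliest checkpoint still on disk.
        recent.popleft().unlink()
    return path


with tempfile.TemporaryDirectory() as tmp:
    recent_paths: Deque[Path] = collections.deque()
    for iteration in range(5):
        save_with_rotation(Path(tmp), iteration, recent_paths, keep_recent=2)
    remaining = sorted(p.name for p in Path(tmp).iterdir())
```

After five saves with keep_recent=2, only the files for iterations 3 and 4 remain in the directory.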

load(self, checkpoint_path:str)[source]

Load a serialized checkpoint from a path. This method tries to find each checkpointable in the file and load its state dict. Because the checkpointables are held by reference, they are updated in place and are not returned.

Parameters
checkpoint_path: str

Path to a checkpoint serialized by step().

Returns
int

Iteration corresponding to the loaded checkpoint, which is useful for resuming training. This is -1 for the best checkpoint, or if the iteration information does not exist in the file.
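The load-in-place behaviour and the -1 fallback can be sketched with the standard library. Everything here is a hypothetical stand-in: Counter, load_into, the "iteration" key and the file names are assumptions for illustration, and pickle takes the place of the PyTorch serialization that the real class would use.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict


class Counter:
    """Hypothetical checkpointable with state_dict / load_state_dict."""

    def __init__(self) -> None:
        self.value = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"value": self.value}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.value = state["value"]


def load_into(checkpoint_path: Path, **checkpointables: Any) -> int:
    """Restore each named object in place and return the stored iteration."""
    with open(checkpoint_path, "rb") as handle:
        payload = pickle.load(handle)
    for name, obj in checkpointables.items():
        if name in payload:  # skip objects that are missing from the file
            obj.load_state_dict(payload[name])
    return payload.get("iteration", -1)  # -1 when no iteration was recorded


with tempfile.TemporaryDirectory() as tmp:
    regular = Path(tmp) / "checkpoint_7.pth"
    regular.write_bytes(pickle.dumps({"iteration": 7, "counter": {"value": 42}}))
    no_info = Path(tmp) / "checkpoint_no_iteration.pth"
    no_info.write_bytes(pickle.dumps({"counter": {"value": 3}}))

    counter = Counter()
    loaded_iteration = load_into(regular, counter=counter)
    restored_value = counter.value

    other = Counter()
    missing_iteration = load_into(no_info, counter=other)

# Resume from the iteration after the one that was checkpointed.
start_iteration = loaded_iteration + 1
```

Callers should check for a return value of -1 before computing a start iteration from it, since adding one to -1 would restart the loop at 0.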