virtex.optim.lookahead


Lookahead Optimizer: k steps forward, 1 step back.

This implementation is adapted with minimal modifications from the authors’ implementation.

If you use this implementation, please cite the original paper:

@inproceedings{zhang2019lookahead,
    title={Lookahead Optimizer: k steps forward, 1 step back},
    author={Zhang, Michael R and Lucas, James and Hinton, Geoffrey and Ba, Jimmy},
    booktitle={NeurIPS},
    year={2019}
}
class virtex.optim.lookahead.Lookahead(optimizer: torch.optim.optimizer.Optimizer, k: int = 5, alpha: float = 0.8)

Bases: torch.optim.optimizer.Optimizer

Implements the Lookahead optimizer.

Parameters
  • optimizer – The wrapped inner optimizer. The weights it manages are the “fast” weights.

  • k – Number of inner-optimizer (fast) steps between consecutive updates of the “slow” weights.

  • alpha – Linear interpolation factor used to move the slow weights toward the fast weights; alpha = 1.0 recovers the inner optimizer.
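In the notation of the cited paper (θ for the fast weights, φ for the slow weights, and A for the inner optimizer's update on loss L and minibatch d), one outer iteration of Lookahead can be summarized as below. This is a sketch of the paper's algorithm, not a transcription of this class's code:

```latex
\begin{aligned}
\theta_{t,0} &\leftarrow \phi_{t-1} \\
\theta_{t,i} &\leftarrow \theta_{t,i-1} + A(L, \theta_{t,i-1}, d), \qquad i = 1, \dots, k \\
\phi_t &\leftarrow \phi_{t-1} + \alpha \, (\theta_{t,k} - \phi_{t-1})
\end{aligned}
```

Setting α = 1 gives φ_t = θ_{t,k}, so the slow weights simply follow the inner optimizer, which is why alpha = 1.0 recovers it.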

zero_grad()

Clear all gradient buffers at the start of a new forward pass.

state_dict()

Returns the state of the optimizer as a dict.

It contains two entries:

  • state - a dict holding the current optimization state. Its content differs between optimizer classes.

  • param_groups - a list containing all parameter groups, where each parameter group is a dict.

load_state_dict(state_dict: Dict[str, Any])

Loads the optimizer state.

Parameters

state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().

step(closure: Optional[Callable] = None)

Perform a single Lookahead optimization step.

Parameters

closure – An optional callable that re-evaluates the model and returns the loss.
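To make the "k steps forward, 1 step back" mechanics of step() concrete, here is a minimal pure-Python sketch. It is not the virtex code: the names ToySGD and ToyLookahead are hypothetical, parameters are plain floats instead of tensors, and, as a simplification, the closure only supplies the loss while gradients come from a hypothetical grad_fn.

```python
from typing import Callable, List, Optional


class ToySGD:
    """Hypothetical inner optimizer: plain gradient descent on a list of floats."""

    def __init__(
        self,
        params: List[float],
        grad_fn: Callable[[List[float]], List[float]],
        lr: float = 0.1,
    ) -> None:
        self.params = params    # the "fast" weights, updated in place
        self.grad_fn = grad_fn  # returns d(loss)/d(param) for each parameter
        self.lr = lr

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        # If given, the closure re-evaluates the model and returns the loss.
        loss = closure() if closure is not None else None
        grads = self.grad_fn(self.params)
        for i, g in enumerate(grads):
            self.params[i] -= self.lr * g
        return loss


class ToyLookahead:
    """Minimal Lookahead sketch: k fast steps forward, then 1 step back."""

    def __init__(self, optimizer: ToySGD, k: int = 5, alpha: float = 0.8) -> None:
        self.optimizer = optimizer
        self.k = k
        self.alpha = alpha
        self.k_counter = 0
        # Slow weights start as a copy of the initial fast weights.
        self.slow = list(optimizer.params)

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        loss = self.optimizer.step(closure)  # one fast-weight step
        self.k_counter += 1
        if self.k_counter >= self.k:
            self.k_counter = 0
            fast = self.optimizer.params
            for i in range(len(fast)):
                # Move the slow weights a fraction alpha toward the fast weights...
                self.slow[i] += self.alpha * (fast[i] - self.slow[i])
                # ...then restart the fast weights from the new slow weights.
                fast[i] = self.slow[i]
        return loss


# Usage: minimize f(x) = x**2 (gradient 2x) starting from x = 1.0.
x = [1.0]
opt = ToyLookahead(ToySGD(x, lambda p: [2 * v for v in p], lr=0.1), k=5, alpha=0.8)
for _ in range(5):
    opt.step(closure=lambda: x[0] ** 2)
print(x[0])  # about 0.462144, i.e. 1 + 0.8 * (0.8**5 - 1)
```

Right after the fifth step the fast and slow weights coincide, and with alpha=1.0 the sketch reduces to plain ToySGD, matching the documented role of alpha.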

load_slow_weights()

Load the slow weights held by the Lookahead optimizer into the model parameters. This is useful for evaluating on the slow weights, which typically generalize better.

This method backs up the fast weights so that they can be restored after evaluation. There is no need to call it if evaluation happens right after a lookahead (slow-weight) update, because the fast and slow weights are identical at that point.

restore_fast_weights()

Restore the backed-up fast weights so that optimization can resume. Call this after evaluation if load_slow_weights() was called.
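The evaluation pattern described by load_slow_weights() and restore_fast_weights() amounts to a backup-and-swap. The sketch below shows that pattern on plain lists of floats; SlowWeightSwapper is a hypothetical helper, not part of virtex, and the real methods operate on the wrapped optimizer's tensors instead.

```python
from typing import List, Optional


class SlowWeightSwapper:
    """Hypothetical helper mirroring load_slow_weights()/restore_fast_weights()."""

    def __init__(self, params: List[float], slow: List[float]) -> None:
        self.params = params                        # live ("fast") weights used for training
        self.slow = slow                            # slow weights kept by Lookahead
        self.backup: Optional[List[float]] = None   # fast-weight copy during evaluation

    def load_slow_weights(self) -> None:
        # Back up the fast weights, then copy the slow weights in place.
        self.backup = list(self.params)
        self.params[:] = self.slow

    def restore_fast_weights(self) -> None:
        # Copy the backed-up fast weights back and discard the backup.
        if self.backup is None:
            raise RuntimeError("load_slow_weights() was not called")
        self.params[:] = self.backup
        self.backup = None


params = [0.5, -1.0]
swap = SlowWeightSwapper(params, slow=[0.4, -0.9])
swap.load_slow_weights()
# ... evaluate the model here; `params` now holds the slow weights ...
swap.restore_fast_weights()
# ... training continues from the original fast weights ...
```

Raising on a restore without a prior load is a design choice of this sketch; it makes a forgotten load_slow_weights() call fail loudly instead of silently leaving the wrong weights in place.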