probnmn.trainers.question_coding_trainer

class probnmn.trainers.question_coding_trainer.QuestionCodingTrainer(config: probnmn.config.Config, serialization_dir: str, gpu_ids: List[int] = [0], cpu_workers: int = 0)[source]

Bases: probnmn.trainers._trainer._Trainer

Performs training for the question_coding phase, using batches of training examples from QuestionCodingDataset.

Parameters
config: Config

A Config object with all the relevant configuration parameters.

serialization_dir: str

Path to a directory for tensorboard logging and serializing checkpoints.

gpu_ids: List[int], optional (default=[0])

List of GPU IDs to use. Set this to [-1] to use CPU.

cpu_workers: int, optional (default=0)

Number of CPU workers to use for fetching batch examples in dataloader.

Examples

>>> config = Config("config.yaml")  # PHASE must be "question_coding"
>>> trainer = QuestionCodingTrainer(config, serialization_dir="/tmp")
>>> evaluator = QuestionCodingEvaluator(config, trainer.models)
>>> for iteration in range(100):
...     trainer.step()
...     # validation every 100 steps
...     if iteration % 100 == 0:
...         val_metrics = evaluator.evaluate()
...         trainer.after_validation(val_metrics, iteration)
_do_iteration(self, batch: Dict[str, Any]) → Dict[str, Any][source]

Forward and backward passes on models, given a batch sampled from dataloader.

Parameters
batch: Dict[str, Any]

A batch of training examples sampled from the dataloader. See step() and _cycle() for details on how this batch is sampled.

Returns
Dict[str, Any]

An output dictionary, as typically returned by the models. This is passed to _after_iteration() for tensorboard logging.
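
For intuition, the body of this method follows the usual forward/backward pattern: reset gradients, run the models on the batch, backpropagate the loss, and step the optimizer. The sketch below is illustrative only; the attribute names (self._model, self._optimizer) and the "loss" key are assumptions, not the exact internals of this class.

>>> # Illustrative sketch of a _do_iteration override. self._model and
>>> # self._optimizer are assumed names, not this class's real attributes.
>>> from typing import Any, Dict
>>> class MyQuestionCodingTrainer(QuestionCodingTrainer):
...     def _do_iteration(self, batch: Dict[str, Any]) -> Dict[str, Any]:
...         self._optimizer.zero_grad()
...         output_dict = self._model(batch)           # forward pass
...         output_dict["loss"].mean().backward()      # backward pass
...         self._optimizer.step()
...         return output_dict                         # consumed by _after_iteration()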

after_validation(self, val_metrics: Dict[str, Any], iteration: Optional[int] = None)[source]

Steps to do after an external _Evaluator performs evaluation. This is not called by step(); call it from outside at an appropriate time. The default behavior is to perform learning rate scheduling, serialize a checkpoint, and log validation metrics to tensorboard.

This implementation assumes a key "metric" in val_metrics. When overriding this method for a phase with multiple models and multiple metrics, it is convenient to set this key to the one metric that decides the best checkpoint (see the sketch at the end of this section).

Parameters
val_metrics: Dict[str, Any]

Validation metrics for all the models. Returned by the evaluate method of _Evaluator (or a subclass).

iteration: int, optional (default = None)

Iteration number. If None, the internal self._iteration counter is used.
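
For example, when several models report several metrics, one way to satisfy the "metric" assumption is to set that key before delegating to the base implementation. The nested metric names below are assumptions for illustration, not guaranteed by the library.

>>> # Illustrative override: pick one metric to decide the best checkpoint,
>>> # then let the base class handle LR scheduling, checkpointing and logging.
>>> class MyQuestionCodingTrainer(QuestionCodingTrainer):
...     def after_validation(self, val_metrics, iteration=None):
...         val_metrics["metric"] = val_metrics["program_generator"]["sequence_accuracy"]
...         super().after_validation(val_metrics, iteration)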