probnmn.trainers.module_training_trainer
- class probnmn.trainers.module_training_trainer.ModuleTrainingTrainer(config: probnmn.config.Config, serialization_dir: str, gpu_ids: List[int] = [0], cpu_workers: int = 0)
Bases: probnmn.trainers._trainer._Trainer
Performs training for the module_training phase, using batches of training examples from ModuleTrainingDataset.
- Parameters
- config: Config
A Config object with all the relevant configuration parameters.
- serialization_dir: str
Path to a directory for tensorboard logging and serializing checkpoints.
- gpu_ids: List[int], optional (default = [0])
List of GPU IDs to use; [-1] means use the CPU.
- cpu_workers: int, optional (default = 0)
Number of CPU workers to use for fetching batch examples in the dataloader.
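The gpu_ids convention described above can be illustrated with a small helper. This is only a sketch: resolve_device is a hypothetical name, not part of probnmn, and it returns plain device strings rather than touching any GPU library. It also assumes the first listed ID acts as the primary device, which the documentation does not state.

```python
from typing import List


def resolve_device(gpu_ids: List[int]) -> str:
    """Hypothetical helper: map a gpu_ids list to a primary device string."""
    if not gpu_ids:
        raise ValueError("gpu_ids must contain at least one entry")
    # By the documented convention, [-1] means "use CPU".
    if gpu_ids[0] < 0:
        return "cpu"
    # Assumption: the first listed GPU is treated as the primary device.
    return f"cuda:{gpu_ids[0]}"


print(resolve_device([0]))
print(resolve_device([-1]))
```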
Examples
>>> config = Config("config.yaml")  # PHASE must be "module_training"
>>> trainer = ModuleTrainingTrainer(config, serialization_dir="/tmp")
>>> evaluator = ModuleTrainingEvaluator(config, trainer.models)
>>> for iteration in range(100):
>>>     trainer.step()
>>>     # validation every 100 steps
>>>     if iteration % 100 == 0:
>>>         val_metrics = evaluator.evaluate()
>>>         trainer.after_validation(val_metrics, iteration)
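One detail of the loop above is worth checking: with range(100), the condition iteration % 100 == 0 holds only at iteration 0. The sketch below lists which iterations would trigger validation for a given run length. validation_iterations is a hypothetical helper written for illustration and is not part of probnmn.

```python
from typing import List


def validation_iterations(num_iterations: int, every: int) -> List[int]:
    """Hypothetical helper: iterations in range(num_iterations) where validation fires."""
    return [i for i in range(num_iterations) if i % every == 0]


print(validation_iterations(100, 100))  # [0]
print(validation_iterations(300, 100))  # [0, 100, 200]
```

A longer run, or a condition based on iteration + 1, would be needed for validation to fire more than once.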
- _do_iteration(self, batch: Dict[str, Any]) → Dict[str, Any]
Forward and backward passes on the models, given a batch sampled from the dataloader.
- Parameters
- batch: Dict[str, Any]
A batch of training examples sampled from the dataloader. See step() and _cycle() for how this batch is sampled.
- Returns
- Dict[str, Any]
An output dictionary typically returned by the models. It is passed to _after_iteration() for tensorboard logging.
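The relationship between step(), _cycle(), _do_iteration() and _after_iteration() described above might look roughly like the sketch below. It is an assumption about the control flow, not the actual _Trainer source: ToyTrainer, the list-based dataloader and the fake loss are all hypothetical stand-ins.

```python
from typing import Any, Dict, Iterable, Iterator, List


class ToyTrainer:
    """Hypothetical stand-in for a _Trainer-style class (not probnmn code)."""

    def __init__(self, dataloader: Iterable[Dict[str, Any]]):
        self._dataloader = dataloader
        self._batches = self._cycle()
        self._iteration = 0
        self.logged: List[Dict[str, Any]] = []

    def _cycle(self) -> Iterator[Dict[str, Any]]:
        # Yield batches forever, restarting the dataloader when exhausted.
        while True:
            for batch in self._dataloader:
                yield batch

    def _do_iteration(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        # A real trainer would run forward and backward passes here.
        return {"loss": float(sum(batch["values"]))}

    def _after_iteration(self, output_dict: Dict[str, Any]) -> None:
        # A real trainer would log to tensorboard; here we just record it.
        self.logged.append(output_dict)
        self._iteration += 1

    def step(self) -> None:
        batch = next(self._batches)
        output_dict = self._do_iteration(batch)
        self._after_iteration(output_dict)


trainer = ToyTrainer([{"values": [1, 2]}, {"values": [4]}])
for _ in range(3):
    trainer.step()
print(trainer.logged)
```

In this sketch the third call to step() wraps around to the first batch again, so callers never need to handle the end of an epoch themselves.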
- after_validation(self, val_metrics: Dict[str, Any], iteration: Union[int, NoneType] = None)
Set the "metric" key in val_metrics; this governs learning rate scheduling and tracking of the best checkpoint (in the super method). This metric will be answer accuracy.
The super method will perform learning rate scheduling, serialize a checkpoint, and log all the validation metrics to tensorboard.
- Parameters
- val_metrics: Dict[str, Any]
Validation metrics of NeuralModuleNetwork. Returned by the evaluate method of ModuleTrainingEvaluator.
- iteration: int, optional (default = None)
Iteration number. If None, use the internal self._iteration counter.
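The override pattern described for after_validation can be sketched as follows. Everything here is illustrative: ToyBaseTrainer and ToyModuleTrainer are hypothetical classes, the "answer_accuracy" key is an assumed name for wherever answer accuracy lives in val_metrics, and the base class only tracks the best metric instead of doing real learning rate scheduling or checkpointing.

```python
from typing import Any, Dict, Optional


class ToyBaseTrainer:
    """Hypothetical base class standing in for the super method's bookkeeping."""

    def __init__(self) -> None:
        self._iteration = 0
        self.best_metric = float("-inf")
        self.best_iteration: Optional[int] = None

    def after_validation(
        self, val_metrics: Dict[str, Any], iteration: Optional[int] = None
    ) -> None:
        # Fall back to the internal counter when no iteration is given.
        if iteration is None:
            iteration = self._iteration
        # Track the best checkpoint by the "metric" key (higher is better here).
        if val_metrics["metric"] > self.best_metric:
            self.best_metric = val_metrics["metric"]
            self.best_iteration = iteration


class ToyModuleTrainer(ToyBaseTrainer):
    def after_validation(
        self, val_metrics: Dict[str, Any], iteration: Optional[int] = None
    ) -> None:
        # Assumption: answer accuracy is stored under "answer_accuracy".
        val_metrics["metric"] = val_metrics["answer_accuracy"]
        super().after_validation(val_metrics, iteration)


trainer = ToyModuleTrainer()
trainer.after_validation({"answer_accuracy": 0.61}, iteration=100)
trainer.after_validation({"answer_accuracy": 0.58}, iteration=200)
print(trainer.best_metric, trainer.best_iteration)  # 0.61 100
```

The subclass only decides which value counts as the metric; the shared bookkeeping stays in the base class.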