probnmn.trainers.program_prior_trainer
class probnmn.trainers.program_prior_trainer.ProgramPriorTrainer(config: probnmn.config.Config, serialization_dir: str, gpu_ids: List[int] = [0], cpu_workers: int = 0)

Bases: probnmn.trainers._trainer._Trainer
Performs training for the program_prior phase, using batches of training examples from ProgramPriorDataset.

- Parameters
- config: Config
A Config object with all the relevant configuration parameters.
- serialization_dir: str
Path to a directory for TensorBoard logging and serializing checkpoints.
- gpu_ids: List[int], optional (default = [0])
List of GPU IDs to use. Pass [-1] to use the CPU.
- cpu_workers: int, optional (default = 0)
Number of CPU workers to use for fetching batch examples in the dataloader.
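The gpu_ids convention can be illustrated with a small helper that turns the list into device names. This is a minimal sketch under stated assumptions: resolve_devices is hypothetical and not part of probnmn, the "cuda:N" strings follow PyTorch's device naming, and it does not show how ProgramPriorTrainer actually places its models.

```python
from typing import List


def resolve_devices(gpu_ids: List[int]) -> List[str]:
    """Hypothetical helper: map the documented gpu_ids convention to device names."""
    # [-1] is the documented sentinel for "use the CPU".
    if gpu_ids == [-1]:
        return ["cpu"]
    # An empty list, or a negative ID mixed with others, is ambiguous: reject it.
    if not gpu_ids or any(g < 0 for g in gpu_ids):
        raise ValueError("gpu_ids must be [-1] or a non-empty list of GPU IDs")
    # PyTorch-style device strings, one per requested GPU.
    return [f"cuda:{g}" for g in gpu_ids]
```

For example, resolve_devices([0, 1]) gives ["cuda:0", "cuda:1"], while resolve_devices([-1]) gives ["cpu"].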
Examples
>>> from probnmn.config import Config
>>> from probnmn.trainers.program_prior_trainer import ProgramPriorTrainer
>>> from probnmn.evaluators import ProgramPriorEvaluator  # import path assumed
>>> config = Config("config.yaml")  # PHASE must be "program_prior"
>>> trainer = ProgramPriorTrainer(config, serialization_dir="/tmp")
>>> evaluator = ProgramPriorEvaluator(config, trainer.models)
>>> for iteration in range(1000):
...     trainer.step()
...     # Validate every 100 steps.
...     if (iteration + 1) % 100 == 0:
...         val_metrics = evaluator.evaluate()
...         trainer.after_validation(val_metrics, iteration)
_do_iteration(self, batch: Dict[str, Any]) → Dict[str, Any]

Performs forward and backward passes on the models, given a batch sampled from the dataloader.
- Parameters
- batch: Dict[str, Any]
A batch of training examples sampled from the dataloader. See step() and _cycle() for how this batch is sampled.
- Returns
- Dict[str, Any]
An output dictionary typically returned by the models. It is passed to _after_iteration() for TensorBoard logging.
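To make "forward and backward passes" concrete, the sketch below performs one training iteration on a toy program prior: a single softmax over program-token IDs (a unigram model), trained with hand-derived gradients and plain SGD. Everything here is an assumption for illustration. The batch key "program", the do_iteration function, and the unigram model are hypothetical stand-ins, not the real ProgramPrior or its optimizer. The gradient of the mean negative log-likelihood with respect to logit k is p_k minus the empirical frequency of token k.

```python
import math
from typing import Any, Dict, List


def softmax(logits: List[float]) -> List[float]:
    # Subtract the max logit for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def do_iteration(logits: List[float], batch: Dict[str, Any], lr: float = 0.5) -> Dict[str, Any]:
    """Hypothetical stand-in for _do_iteration on a unigram program prior.

    batch["program"] is assumed to be a list of token-ID sequences.
    Updates logits in place with one SGD step and returns an output dict.
    """
    tokens = [t for seq in batch["program"] for t in seq]
    n = len(tokens)
    probs = softmax(logits)

    # Forward pass: mean negative log-likelihood over all tokens in the batch.
    loss = -sum(math.log(probs[t]) for t in tokens) / n

    # Backward pass: d(loss)/d(logit_k) = p_k - count_k / n.
    counts = [0] * len(logits)
    for t in tokens:
        counts[t] += 1
    grads = [p - c / n for p, c in zip(probs, counts)]

    # Optimizer step (plain SGD).
    for k, g in enumerate(grads):
        logits[k] -= lr * g

    # Output dict, analogous to what would be handed to _after_iteration().
    return {"loss": loss}
```

Calling do_iteration twice on the same batch, starting from uniform logits such as [0.0] * 4, reports a loss of log(4) on the first call and a smaller loss on the second, because the first call's update moved the logits toward the batch's token frequencies.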
after_validation(self, val_metrics: Dict[str, Any], iteration: Union[int, NoneType] = None)

Sets the "metric" key in val_metrics. This key governs learning rate scheduling and tracking of the best checkpoint (in the super method). For this trainer, the metric is the perplexity of ProgramPrior (lower is better). The super method performs learning rate scheduling, serializes a checkpoint, and logs all the validation metrics to TensorBoard.
- Parameters
- val_metrics: Dict[str, Any]
Validation metrics of ProgramPrior, as returned by the evaluate() method of ProgramPriorEvaluator.
- iteration: int, optional (default = None)
Iteration number. If None, use the internal self._iteration counter.
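Perplexity is the exponential of the mean per-token negative log-likelihood, so a lower value means the prior assigns higher probability to validation programs. The sketch below computes it and shows lower-is-better best-checkpoint bookkeeping keyed on "metric". Both perplexity and BestCheckpointTracker are hypothetical illustrations and do not reproduce the _Trainer super method, which also handles learning rate scheduling, serialization, and logging.

```python
import math
from typing import Any, Dict, List, Optional


def perplexity(token_log_probs: List[float]) -> float:
    # exp(mean negative log-likelihood per token); lower is better.
    return math.exp(-sum(token_log_probs) / len(token_log_probs))


class BestCheckpointTracker:
    """Hypothetical helper: remembers the lowest "metric" seen so far."""

    def __init__(self) -> None:
        self.best_metric: Optional[float] = None
        self.best_iteration: Optional[int] = None

    def update(self, val_metrics: Dict[str, Any], iteration: int) -> bool:
        metric = val_metrics["metric"]
        # Lower is better for perplexity, so only a strictly smaller value wins.
        if self.best_metric is None or metric < self.best_metric:
            self.best_metric, self.best_iteration = metric, iteration
            return True  # a real trainer would save a "best" checkpoint here
        return False
```

If every validation token has probability 0.5, perplexity([math.log(0.5)] * 4) is 2.0. Setting val_metrics["metric"] to that value and calling tracker.update(val_metrics, iteration) records it as the best so far unless an earlier validation was lower.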