probnmn.modules.seq2seq_base
-
class probnmn.modules.seq2seq_base.Seq2SeqBase(vocabulary: allennlp.data.vocabulary.Vocabulary, source_namespace: str, target_namespace: str, input_size: int = 256, hidden_size: int = 256, num_layers: int = 2, dropout: float = 0.0, max_decoding_steps: int = 30)
Bases: allennlp.models.encoder_decoders.simple_seq2seq.SimpleSeq2Seq
A wrapper over AllenNLP’s SimpleSeq2Seq class. It serves as the base class for ProgramGenerator and QuestionReconstructor. The key differences from the super class are:
- It does not use beam search; instead, it performs categorical sampling or greedy decoding, as specified explicitly in the forward() call.
- It records four metrics: perplexity, sequence accuracy, word error rate and BLEU score.
- It has sensible defaults for the super class (dot-product attention, embeddings, etc.).
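At each decoding step, the two strategies differ only in how the next token is chosen from the decoder's unnormalized scores. The following is a minimal plain-Python sketch of that choice, not the library's implementation: softmax and choose_next_token are hypothetical helpers, and the actual class operates on batched torch tensors rather than Python lists.

```python
import math
import random
from typing import List, Optional


def softmax(logits: List[float]) -> List[float]:
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def choose_next_token(logits: List[float], decoding_strategy: str = "sampling",
                      rng: Optional[random.Random] = None) -> int:
    """Pick the next token index from one step of decoder logits.

    Hypothetical helper for illustration; not part of probnmn or AllenNLP.
    """
    if decoding_strategy == "greedy":
        # Greedy decoding: take the highest-scoring token.
        return max(range(len(logits)), key=lambda i: logits[i])
    if decoding_strategy == "sampling":
        # Categorical sampling: draw a token with probability softmax(logits).
        rng = rng or random.Random()
        return rng.choices(range(len(logits)), weights=softmax(logits), k=1)[0]
    raise ValueError(f"Unknown decoding_strategy: {decoding_strategy!r}")


logits = [0.1, 2.5, -1.0]
print(choose_next_token(logits, "greedy"))  # 1
print(choose_next_token(logits, "sampling", random.Random(0)))  # some index in {0, 1, 2}
```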
- Parameters
- vocabulary: allennlp.data.vocabulary.Vocabulary
AllenNLP’s vocabulary. This vocabulary has three namespaces, “questions”, “programs” and “answers”, which contain the respective token-to-integer mappings.
- source_namespace: str, required
Namespace for source tokens: “programs” for QuestionReconstructor and “questions” for ProgramGenerator.
- target_namespace: str, required
Namespace for target tokens: “programs” for ProgramGenerator and “questions” for QuestionReconstructor.
- input_size: int, optional (default = 256)
The dimension of the inputs to the LSTM.
- hidden_size: int, optional (default = 256)
The dimension of the outputs of the LSTM.
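- num_layers: int, optional (default = 2)
Number of recurrent layers of the LSTM.
- dropout: float, optional (default = 0.0)
Dropout probability.
- max_decoding_steps: int, optional (default = 30)
Maximum number of decoding steps, i.e. the maximum length of a decoded sequence.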
-
forward(self, source_tokens: torch.LongTensor, target_tokens: Union[torch.LongTensor, NoneType] = None, decoding_strategy: str = 'sampling') → Dict[str, torch.Tensor]
Override AllenNLP’s forward, changing the decoder logic. Performs either categorical sampling or greedy decoding, as specified by decoding_strategy.
- Parameters
- source_tokens: torch.LongTensor
Tokenized source sequences, padded to the maximum length. These do not include the @start@ and @end@ sentence boundary tokens. Shape: (batch_size, max_source_length)
- target_tokens: torch.LongTensor, optional (default = None)
Tokenized target sequences, padded to the maximum length. These do not include the @start@ and @end@ sentence boundary tokens. Shape: (batch_size, max_target_length)
- decoding_strategy: str, optional (default = “sampling”)
The decoding strategy to use: one of “sampling” (categorical sampling) or “greedy” (greedy decoding).
- Returns
- Dict[str, torch.Tensor]
-
_forward_loop(self, state: Dict[str, torch.FloatTensor], target_tokens: Dict[str, torch.LongTensor] = None, decoding_strategy: str = 'sampling') → Dict[str, torch.Tensor]
Make a forward pass during training, or decode using the given decoding_strategy (categorical sampling or greedy decoding) during prediction.
Notes
We really only use the predictions from the method to test that beam search with a beam size of 1 gives the same results.
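When no target tokens are given, decoding is autoregressive: the token chosen at each step becomes the input to the next step, for a fixed number of steps (compare the max_decoding_steps constructor argument). A minimal plain-Python sketch under that assumption follows. It is not probnmn's code: decode and toy_step are hypothetical, and toy_step merely stands in for one step of the attention-based LSTM decoder. The loop runs for the full number of steps without stopping at @end@, so any tokens produced after @end@ have to be trimmed separately.

```python
import math
import random
from typing import Callable, List, Optional

Logits = List[float]


def decode(step_fn: Callable[[int], Logits], start_index: int, max_decoding_steps: int,
           decoding_strategy: str = "greedy", rng: Optional[random.Random] = None) -> List[int]:
    """Autoregressively decode one sequence (hypothetical sketch, not probnmn's code)."""
    rng = rng or random.Random()
    predictions: List[int] = []
    last_token = start_index  # decoding starts from the @start@ token
    for _ in range(max_decoding_steps):
        logits = step_fn(last_token)
        if decoding_strategy == "greedy":
            last_token = max(range(len(logits)), key=lambda i: logits[i])
        elif decoding_strategy == "sampling":
            # Unnormalized softmax weights are enough for random.choices.
            m = max(logits)
            weights = [math.exp(x - m) for x in logits]
            last_token = rng.choices(range(len(logits)), weights=weights, k=1)[0]
        else:
            raise ValueError(f"Unknown decoding_strategy: {decoding_strategy!r}")
        predictions.append(last_token)
    return predictions


# Toy decoder over a 4-token vocabulary: 0=@@PADDING@@, 1=@start@, 2=@end@, 3="w".
# It prefers "w" after @start@ or @end@, and @end@ after "w".
def toy_step(prev: int) -> Logits:
    preferred = {0: 0, 1: 3, 2: 3, 3: 2}[prev]
    return [5.0 if i == preferred else 0.0 for i in range(4)]


print(decode(toy_step, start_index=1, max_decoding_steps=5))
# [3, 2, 3, 2, 3]  (tokens after the first @end@, index 2, are leftovers)
```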
-
_trim_predictions(self, predictions: torch.LongTensor)
Trim output predictions at the first “@end@” token and pad the rest of the sequence. The “@end@” token is kept as the last token of the trimmed sequence.
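The trimming rule can be sketched on plain lists of token indices. This is an illustration rather than the method itself: trim_predictions is hypothetical, the padding index is assumed to be 0 (AllenNLP's default for @@PADDING@@), and leaving sequences that contain no @end@ unchanged is an assumption about that edge case.

```python
from typing import List


def trim_predictions(predictions: List[List[int]], end_index: int,
                     pad_index: int = 0) -> List[List[int]]:
    """Cut each sequence after its first @end@ (kept) and pad the rest (hypothetical sketch)."""
    trimmed = []
    for sequence in predictions:
        if end_index in sequence:
            cut = sequence.index(end_index) + 1  # keep @end@ as the last real token
        else:
            cut = len(sequence)  # assumption: no @end@ produced, keep everything
        trimmed.append(sequence[:cut] + [pad_index] * (len(sequence) - cut))
    return trimmed


print(trim_predictions([[3, 2, 3, 2, 3], [3, 3, 3, 3, 3]], end_index=2))
# [[3, 2, 0, 0, 0], [3, 3, 3, 3, 3]]
```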
-
static _get_loss(logits: torch.LongTensor, targets: torch.LongTensor, target_mask: torch.LongTensor)
Override the _get_loss method provided by AllenNLP’s Seq2Seq model, which by default returns the sequence cross entropy averaged over the batch. Instead, return the sequence cross entropy of each sequence in the batch separately.
From AllenNLP documentation:
Compute loss. Takes logits (unnormalized outputs from the decoder) of size (batch_size, num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1) and corresponding masks of size (batch_size, num_decoding_steps+1), and computes cross entropy loss while taking the mask into account. The length of targets is expected to be greater than that of logits because the decoder does not need to compute the output corresponding to the last timestep of targets. This method aligns the inputs appropriately to compute the loss.
During training, we want the logit corresponding to timestep i to be similar to the target token from timestep i + 1. That is, the targets should be shifted by one timestep for appropriate comparison. Consider a single example where the target has 3 words, and padding is to 7 tokens. The complete sequence would correspond to <S> w1 w2 w3 <E> <P> <P>, the mask would be 1 1 1 1 1 0 0, and let the logits be l1 l2 l3 l4 l5 l6.
We actually need to compare:
the sequence: w1 w2 w3 <E> <P> <P>
with masks: 1 1 1 1 0 0
against: l1 l2 l3 l4 l5 l6
(where the input was): <S> w1 w2 w3 <E> <P>
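The one-step shift and the per-sequence (rather than batch-averaged) loss can be sketched in plain Python. The names are hypothetical, and normalizing each sequence's summed loss by its number of unmasked target tokens is an assumption about the exact normalization. The example reuses the quoted case, with <P>=0, <S>=1, <E>=2, w1=3, w2=4, w3=5, and uniform logits, so every unmasked token costs log(6).

```python
import math
from typing import List


def per_sequence_loss(logits: List[List[List[float]]], targets: List[List[int]],
                      target_mask: List[List[int]]) -> List[float]:
    """Masked cross entropy per sequence, with targets shifted by one step (hypothetical sketch).

    logits:      batch_size x num_decoding_steps x num_classes
    targets:     batch_size x (num_decoding_steps + 1), each starting with <S>
    target_mask: batch_size x (num_decoding_steps + 1)
    """
    losses = []
    for seq_logits, seq_targets, seq_mask in zip(logits, targets, target_mask):
        # The logit at step i is compared with the target at step i + 1.
        relevant_targets = seq_targets[1:]
        relevant_mask = seq_mask[1:]
        total, count = 0.0, 0
        for step_logits, target, mask in zip(seq_logits, relevant_targets, relevant_mask):
            if not mask:
                continue  # padding positions do not contribute
            m = max(step_logits)
            log_norm = m + math.log(sum(math.exp(x - m) for x in step_logits))
            total += log_norm - step_logits[target]  # -log softmax(step_logits)[target]
            count += 1
        # Average over this sequence's unmasked tokens only; no averaging over the batch.
        losses.append(total / max(count, 1))
    return losses


uniform = [[0.0] * 6 for _ in range(6)]  # l1 ... l6, each over 6 classes
losses = per_sequence_loss([uniform], [[1, 3, 4, 5, 2, 0, 0]], [[1, 1, 1, 1, 1, 0, 0]])
print([round(x, 4) for x in losses])  # [1.7918], i.e. log(6)
```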
-
get_metrics(self, reset: bool = True) → Dict[str, float]
Return the recorded metrics: perplexity, sequence accuracy, word error rate and BLEU.
- Parameters
- reset: bool, optional (default = True)
Whether to reset the accumulated metrics after retrieving them.
- Returns
- Dict[str, float]
A dictionary with metrics:
{ "perplexity", "sequence_accuracy", "word_error_rate", "BLEU" }
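For reference, textbook definitions of three of these metrics can be sketched in plain Python. This is an assumption about what they measure, not probnmn's code: its accumulation across batches and its normalization may differ, BLEU is omitted, and all function names are hypothetical.

```python
import math
from typing import List, Sequence


def perplexity(per_token_nll: Sequence[float]) -> float:
    # Exponentiated mean negative log-likelihood per token.
    return math.exp(sum(per_token_nll) / len(per_token_nll))


def sequence_accuracy(predictions: List[List[int]], targets: List[List[int]]) -> float:
    # Fraction of sequences predicted exactly, token for token.
    matches = sum(1 for p, t in zip(predictions, targets) if p == t)
    return matches / len(targets)


def word_error_rate(prediction: Sequence[int], target: Sequence[int]) -> float:
    # Levenshtein distance (substitutions, insertions, deletions) divided by reference length.
    previous = list(range(len(prediction) + 1))
    for i, t in enumerate(target, start=1):
        current = [i]
        for j, p in enumerate(prediction, start=1):
            current.append(min(previous[j] + 1,                # deletion
                               current[j - 1] + 1,             # insertion
                               previous[j - 1] + (p != t)))    # substitution or match
        previous = current
    return previous[-1] / len(target)


print(word_error_rate([3, 4, 5], [3, 9, 5, 6]))  # 0.5 (one substitution, one deletion)
print(sequence_accuracy([[3, 2], [3, 4]], [[3, 2], [3, 5]]))  # 0.5
print(round(perplexity([math.log(6)] * 4), 6))  # 6.0
```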