probnmn.modules.seq2seq_base

class probnmn.modules.seq2seq_base.Seq2SeqBase(vocabulary: allennlp.data.vocabulary.Vocabulary, source_namespace: str, target_namespace: str, input_size: int = 256, hidden_size: int = 256, num_layers: int = 2, dropout: float = 0.0, max_decoding_steps: int = 30)[source]

Bases: allennlp.models.encoder_decoders.simple_seq2seq.SimpleSeq2Seq

A wrapper over AllenNLP’s SimpleSeq2Seq class. This serves as a base class for the ProgramGenerator and QuestionReconstructor. The key differences from the super class are:

  1. This class doesn’t use beam search; it performs categorical sampling or greedy decoding, as explicitly specified in the forward() call.

  2. This class records four metrics: perplexity, sequence_accuracy, word error rate and BLEU score.

  3. Has sensible defaults for the super class (dot-product attention, embedding, etc.).

Parameters
vocabulary: allennlp.data.vocabulary.Vocabulary

AllenNLP’s vocabulary. This vocabulary has three namespaces - “questions”, “programs” and “answers”, which contain respective token to integer mappings.

source_namespace: str, required

Namespace for source tokens, “programs” for QuestionReconstructor and “questions” for ProgramGenerator.

target_namespace: str, required

Namespace for target tokens, “programs” for ProgramGenerator and “questions” for QuestionReconstructor.

input_size: int, optional (default = 256)

The dimension of the inputs to the LSTM.

hidden_size: int, optional (default = 256)

The dimension of the outputs of the LSTM.

num_layers: int, optional (default = 2)

Number of recurrent layers of the LSTM.
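
A minimal construction sketch (not taken from the library’s own docs) is shown below; the vocabulary path and variable names are hypothetical, and the keyword values simply restate the defaults above:

from allennlp.data import Vocabulary

from probnmn.modules.seq2seq_base import Seq2SeqBase

# Hypothetical vocabulary directory with "questions", "programs" and "answers" namespaces.
vocabulary = Vocabulary.from_files("data/vocabulary")

# A program generator maps questions to programs, so "questions" is the source
# namespace and "programs" is the target namespace.
program_generator = Seq2SeqBase(
    vocabulary,
    source_namespace="questions",
    target_namespace="programs",
    input_size=256,
    hidden_size=256,
    num_layers=2,
    dropout=0.0,
    max_decoding_steps=30,
)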

forward(self, source_tokens: torch.LongTensor, target_tokens: Union[torch.LongTensor, NoneType] = None, decoding_strategy: str = 'sampling') → Dict[str, torch.Tensor][source]

Override AllenNLP’s forward, changing the decoder logic. Performs either categorical sampling or greedy decoding, as specified by decoding_strategy.

Parameters
source_tokens: torch.LongTensor

Tokenized source sequences padded to maximum length. These are not padded with @start@ and @end@ sentence boundaries. Shape: (batch_size, max_source_length)

target_tokens: torch.LongTensor, optional (default = None)

Tokenized target sequences padded to maximum length. These are not padded with @start@ and @end@ sentence boundaries. Shape: (batch_size, max_target_length)

decoding_strategy: str, optional (default = “sampling”)

How to perform decoding: one of "sampling" or "greedy".

Returns
Dict[str, torch.Tensor]
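
A hedged usage sketch for forward(), continuing the hypothetical program_generator above; the batch of source tokens is fabricated purely for illustration:

import torch

# Hypothetical batch of 8 tokenized questions, padded to a maximum length of 20.
source_tokens = torch.randint(
    low=2, high=vocabulary.get_vocab_size("questions"), size=(8, 20)
)

# Categorical sampling: draw a token from the decoder's output distribution at each step.
sampled_output = program_generator(source_tokens, decoding_strategy="sampling")

# Greedy decoding: take the argmax token at each step instead.
greedy_output = program_generator(source_tokens, decoding_strategy="greedy")
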
_forward_loop(self, state: Dict[str, torch.FloatTensor], target_tokens: Dict[str, torch.LongTensor] = None, decoding_strategy: str = 'sampling') → Dict[str, torch.Tensor][source]

Make a forward pass during training, or perform decoding (categorical sampling or greedy) during prediction.

Notes

We really only use the predictions from the method to test that beam search with a beam size of 1 gives the same results.
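
The per-step decoding choice can be pictured with the simplified sketch below; it illustrates the two strategies rather than the actual loop body, and the function name is hypothetical:

import torch
from torch.distributions import Categorical

def choose_next_tokens(class_log_probabilities: torch.Tensor,
                       decoding_strategy: str = "sampling") -> torch.LongTensor:
    # class_log_probabilities: (batch_size, num_classes) log-probs from one decoder step.
    if decoding_strategy == "sampling":
        # Categorical sampling: draw one token id per sequence from the distribution.
        return Categorical(logits=class_log_probabilities).sample()
    # Greedy decoding: pick the highest-scoring token id per sequence.
    return torch.argmax(class_log_probabilities, dim=-1)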

_trim_predictions(self, predictions: torch.LongTensor)[source]

Trim output predictions at the first "@end@" token and pad the rest of the sequence. The trimmed sequence includes "@end@" as its last token.
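
A rough sketch of this trimming behaviour, assuming predictions of shape (batch_size, max_length) and known indices of the "@end@" and padding tokens (the function name is hypothetical):

import torch

def trim_predictions(predictions: torch.LongTensor, end_index: int,
                     pad_index: int = 0) -> torch.LongTensor:
    trimmed = predictions.clone()
    for i, sequence in enumerate(predictions.tolist()):
        if end_index in sequence:
            end_position = sequence.index(end_index)
            # Keep everything up to and including the first "@end@"; pad the rest.
            trimmed[i, end_position + 1:] = pad_index
    return trimmed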

static _get_loss(logits: torch.LongTensor, targets: torch.LongTensor, target_mask: torch.LongTensor)[source]

Override the _get_loss method provided by AllenNLP’s Seq2Seq model, which by default returns sequence cross entropy averaged over the batch. Instead, return the sequence cross entropy of each sequence in the batch separately.

From AllenNLP documentation:

Compute loss. Takes logits (unnormalized outputs from the decoder) of size (batch_size, num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps + 1) and corresponding masks of size (batch_size, num_decoding_steps + 1), and computes cross entropy loss while taking the mask into account. The length of targets is expected to be greater than that of logits because the decoder does not need to compute the output corresponding to the last timestep of targets. This method aligns the inputs appropriately to compute the loss. During training, we want the logit corresponding to timestep i to be similar to the target token from timestep i + 1. That is, the targets should be shifted by one timestep for appropriate comparison. Consider a single example where the target has 3 words, and padding is to 7 tokens:

The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
and the mask would be                     1   1   1   1   1   0   0
and let the logits be                     l1  l2  l3  l4  l5  l6

We actually need to compare:

the sequence           w1  w2  w3  <E> <P> <P>
with masks             1   1   1   1   0   0
against                l1  l2  l3  l4  l5  l6
(where the input was)  <S> w1  w2  w3  <E> <P>
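
A sketch of how such a per-sequence loss can be computed with AllenNLP’s sequence_cross_entropy_with_logits, passing average=None so that one cross entropy value is returned per batch element (shapes follow the description above); this is a sketch of the idea, not necessarily the exact implementation:

import torch
from allennlp.nn.util import sequence_cross_entropy_with_logits

def get_loss_per_sequence(logits: torch.FloatTensor,
                          targets: torch.LongTensor,
                          target_mask: torch.LongTensor) -> torch.Tensor:
    # Drop the first timestep (<S>) from targets and mask so that the logit at
    # step i is compared against the target token at step i + 1.
    relevant_targets = targets[:, 1:].contiguous()
    relevant_mask = target_mask[:, 1:].contiguous()
    # average=None keeps one cross entropy value per sequence instead of
    # averaging over the batch.
    return sequence_cross_entropy_with_logits(
        logits, relevant_targets, relevant_mask, average=None
    )
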
get_metrics(self, reset: bool = True) → Dict[str, float][source]

Return recorded metrics: perplexity, sequence accuracy, word error rate, and BLEU.

Parameters
reset: bool, optional (default = True)

Whether to reset the accumulated metrics after retrieving them.

Returns
Dict[str, float]

A dictionary with metrics:

{
    "perplexity": float,
    "sequence_accuracy": float,
    "word_error_rate": float,
    "BLEU": float
}
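
A small hypothetical snippet reading the metrics off after a validation pass, continuing the program_generator example above:

# Fetch the accumulated metrics and reset them for the next epoch.
metrics = program_generator.get_metrics(reset=True)
print(f"perplexity: {metrics['perplexity']:.2f}, "
      f"sequence accuracy: {metrics['sequence_accuracy']:.3f}")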