virtex.modules.embedding


class virtex.modules.embedding.WordAndPositionalEmbedding(vocab_size: int, hidden_size: int, dropout: float = 0.0, max_caption_length: int = 30, padding_idx: int = 0)

Bases: torch.nn.modules.module.Module

A Module for learned word embeddings and position embeddings for input tokens. Each token is mapped to a fixed-dimensional word embedding and a corresponding positional embedding based on its index. These are summed together, followed by layer normalization and an optional dropout.
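The computation described above can be sketched in PyTorch as follows. This is a minimal illustration written from the description only, not the virtex source. The class name `WordAndPositionalEmbeddingSketch` is hypothetical, and the sketch assumes that position indices run 0, 1, …, seq_len - 1 and that layer normalization uses the default `nn.LayerNorm` settings.

```python
import torch
from torch import nn


class WordAndPositionalEmbeddingSketch(nn.Module):
    # Hypothetical sketch based on the description; not the virtex implementation.
    def __init__(self, vocab_size: int, hidden_size: int, dropout: float = 0.0,
                 max_caption_length: int = 30, padding_idx: int = 0):
        super().__init__()
        # Word embeddings; the row at padding_idx starts at zero and gets no gradient.
        self.words = nn.Embedding(vocab_size, hidden_size, padding_idx=padding_idx)
        # Fixed-size lookup table of learned positional embeddings.
        self.positions = nn.Embedding(max_caption_length, hidden_size)
        # Assumption: default LayerNorm settings (eps=1e-5, learnable affine).
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (batch_size, seq_len) integer IDs, seq_len <= max_caption_length.
        seq_len = tokens.size(1)
        position_indices = torch.arange(seq_len, device=tokens.device)
        # (batch, seq_len, hidden) + (seq_len, hidden) broadcasts over the batch.
        embeddings = self.words(tokens) + self.positions(position_indices)
        # Layer normalization, then the (optional) dropout.
        return self.dropout(self.layer_norm(embeddings))


torch.manual_seed(0)
module = WordAndPositionalEmbeddingSketch(vocab_size=100, hidden_size=16)
tokens = torch.randint(0, 100, (2, 30))
out = module(tokens)  # shape (2, 30, 16)
```

Because `dropout` defaults to 0.0, the dropout layer is an identity here; with a non-zero probability it only has an effect while the module is in training mode.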

Parameters
  • vocab_size – Size of token vocabulary.

  • hidden_size – Size of token embedding vectors.

  • dropout – Probability for final dropout applied after layer normalization.

  • max_caption_length – Maximum length of input captions; this is used to create a fixed positional embedding lookup table.

  • padding_idx – Token index of the [PAD] token; the word embedding for these tokens will be a vector of zeros (and is not trainable).
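The `padding_idx` and `max_caption_length` behavior can be checked directly on `torch.nn.Embedding`, assuming (as the parameter descriptions suggest) that the word and positional lookup tables are ordinary embedding layers. The sizes and token IDs below are illustrative.

```python
import torch
from torch import nn

vocab_size, hidden_size, max_caption_length, padding_idx = 10, 4, 30, 0

# Word table with a [PAD] row at padding_idx.
words = nn.Embedding(vocab_size, hidden_size, padding_idx=padding_idx)
# Positional table: one learned vector per index in [0, max_caption_length).
positions = nn.Embedding(max_caption_length, hidden_size)

# A freshly constructed table has an all-zero [PAD] row.
pad_row = words.weight[padding_idx].detach().clone()

# Backpropagate through a batch containing [PAD] tokens.
tokens = torch.tensor([[5, 3, padding_idx, padding_idx]])
words(tokens).sum().backward()

# The [PAD] row receives no gradient, so an optimizer never updates it.
pad_grad = words.weight.grad[padding_idx]
```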

forward(tokens: torch.Tensor) → torch.Tensor

Get combined word and positional embeddings for input tokens.

Parameters

tokens – A tensor of shape (batch_size, max_caption_length) containing a batch of caption tokens, with values in [0, vocab_size).

Returns

A tensor of shape (batch_size, max_caption_length, hidden_size) containing corresponding token embeddings.
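The `tokens` argument is a padded batch of integer IDs. The following minimal sketch builds one; `pad_captions` is a hypothetical helper and the token IDs are made up. Each caption is truncated to `max_caption_length`, and the remaining positions are filled with `padding_idx`. The commented lines show the call to the documented module, which requires virtex to be installed.

```python
import torch


def pad_captions(captions, max_caption_length=30, padding_idx=0):
    # Hypothetical helper: pack lists of token IDs into a
    # (batch_size, max_caption_length) LongTensor, padding with padding_idx.
    tokens = torch.full((len(captions), max_caption_length), padding_idx, dtype=torch.long)
    for i, caption in enumerate(captions):
        caption = caption[:max_caption_length]  # truncate over-long captions
        tokens[i, :len(caption)] = torch.tensor(caption, dtype=torch.long)
    return tokens


tokens = pad_captions([[1, 7, 4, 2], [1, 9, 2]], max_caption_length=8)

# With virtex installed (constructor arguments as documented above):
# from virtex.modules.embedding import WordAndPositionalEmbedding
# embedding = WordAndPositionalEmbedding(vocab_size=10000, hidden_size=512, max_caption_length=8)
# out = embedding(tokens)  # expected shape (2, 8, 512) per the Returns section
```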