virtex.modules.visual_backbones


class virtex.modules.visual_backbones.VisualBackbone(visual_feature_size: int)[source]

Bases: torch.nn.modules.module.Module

Base class for all visual backbones. Child classes could simply inherit from torch.nn.Module; this class exists to provide uniform type annotations.

class virtex.modules.visual_backbones.TorchvisionVisualBackbone(name: str = 'resnet50', visual_feature_size: int = 2048, pretrained: bool = False, frozen: bool = False)[source]

Bases: virtex.modules.visual_backbones.VisualBackbone

A visual backbone from the Torchvision model zoo. Any model can be specified by the name of its corresponding constructor method in the model zoo (for example, 'resnet50').

Parameters
  • name – Name of the model from the Torchvision model zoo.

  • visual_feature_size – Size of the channel dimension of the output visual features from the forward pass.

  • pretrained – Whether to load ImageNet pretrained weights from Torchvision.

  • frozen – Whether to keep all weights frozen during training.

forward(image: torch.Tensor) → torch.Tensor[source]

Compute visual features for a batch of input images.

Parameters

image – Batch of input images. A tensor of shape (batch_size, 3, height, width).

Returns

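A tensor of shape (batch_size, channels, height, width), where height and width are the spatial dimensions of the output feature map rather than of the input image. For example, the shape is (batch_size, 2048, 7, 7) for ResNet-50 given 224 × 224 input images.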
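
The relationship between input and output shapes can be checked with a little arithmetic. The sketch below is not part of virtex: `expected_feature_shape` is a hypothetical helper, and it assumes the backbone downsamples each spatial dimension by a total stride of 32 (the usual figure for ResNet-50), rounding up.

```python
import math
from typing import Tuple


def expected_feature_shape(
    batch_size: int,
    height: int,
    width: int,
    channels: int = 2048,
    stride: int = 32,
) -> Tuple[int, int, int, int]:
    """Estimate the shape of the tensor returned by forward().

    Hypothetical helper: assumes each spatial dimension shrinks by a
    total factor of `stride`, rounded up, as with padded ResNet-style
    convolutions. Other backbones may use a different stride.
    """
    return (
        batch_size,
        channels,
        math.ceil(height / stride),
        math.ceil(width / stride),
    )


# A batch of 16 images at 224 x 224 resolution.
print(expected_feature_shape(16, 224, 224))  # (16, 2048, 7, 7)
```

Under the same assumption, larger inputs produce proportionally larger feature maps, so code that consumes these features should not hard-code a 7 × 7 grid.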

detectron2_backbone_state_dict() → Dict[str, Any][source]

Return state dict of visual backbone which can be loaded with Detectron2. This is useful for downstream tasks based on Detectron2 (such as object detection and instance segmentation). This method renames certain parameters from Torchvision-style to Detectron2-style.

Returns

A dict with three keys: {"model", "author", "matching_heuristics"}. These keys are necessary for loading this state dict properly with Detectron2.

Return type

Dict[str, Any]
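
To illustrate what renaming parameters from Torchvision-style to Detectron2-style might look like, here is a minimal sketch on plain dictionaries. It is not the virtex implementation: the substring mapping, the helper name `to_detectron2_style`, and the values stored under "author" and "matching_heuristics" are all assumptions made for illustration.

```python
from typing import Any, Dict

# Hypothetical substring mapping, for illustration only. The mapping that
# virtex actually applies may contain different or additional entries.
RENAME_MAPPING: Dict[str, str] = {
    "layer1": "res2",
    "layer2": "res3",
    "layer3": "res4",
    "layer4": "res5",
    "bn1": "conv1.norm",
    "downsample.0": "shortcut",
}


def to_detectron2_style(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Rename parameter keys and wrap them in a three-key dict."""
    renamed: Dict[str, Any] = {}
    for name, param in state_dict.items():
        # Apply every substring replacement to the parameter name.
        for old, new in RENAME_MAPPING.items():
            name = name.replace(old, new)
        renamed[name] = param

    return {
        "model": renamed,
        "author": "placeholder",      # assumed value
        "matching_heuristics": True,  # assumed value
    }


# Strings stand in for tensors so the sketch runs without PyTorch.
toy_state_dict = {
    "layer1.0.bn1.weight": "w0",
    "layer4.2.downsample.0.weight": "w1",
}
converted = to_detectron2_style(toy_state_dict)
print(sorted(converted["model"]))
# ['res2.0.conv1.norm.weight', 'res5.2.shortcut.weight']
```

In practice the method is called on the backbone itself, and the returned dict is what gets serialized for Detectron2 to load.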