
🧩 SimMIM

This section covers the implementation of SimMIM (Simple Masked Image Modeling) in the ViT-SSL framework.
It follows the original SimMIM paper.


Overview

SimMIM performs masked patch prediction, where parts of an input image are hidden (masked) and the model learns to reconstruct those regions.
The approach is conceptually similar to BERT-style pretraining but adapted for vision using patch-level masking and pixel-level regression.


Architecture: SimMIMViT

Defined in model.py, this model consists of:

  • A custom patch embedding via Unfold + Linear
  • Learnable mask token for masked patches
  • Positional embeddings
  • Multiple EncoderBlocks
  • A simple MLP head for predicting RGB pixel values of masked patches

The learnable mask token replaces the embeddings of masked patches directly in the input sequence before encoding, so the encoder still receives one token per patch. No [CLS] token is used during pretraining.
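
The Unfold + Linear patch embedding listed above can be sketched as follows. This is a minimal illustration, not the code in model.py: the class name `PatchEmbedding`, its arguments, and all sizes are assumptions chosen for the example.

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Hypothetical Unfold + Linear patch embedding (names and sizes are assumptions)."""

    def __init__(self, patch_size=4, in_chans=3, embed_dim=32, num_patches=64):
        super().__init__()
        # Unfold cuts the image into non-overlapping patch_size x patch_size patches.
        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        # Linear maps each flattened patch (C * p * p values) to the embedding size.
        self.proj = nn.Linear(in_chans * patch_size * patch_size, embed_dim)
        self.pos_emb = nn.Parameter(torch.zeros(1, num_patches, embed_dim))

    def forward(self, images):
        # (B, C*p*p, N) -> (B, N, C*p*p): one row of raw pixels per patch.
        patches = self.unfold(images).transpose(1, 2)
        tokens = self.proj(patches) + self.pos_emb      # (B, N, D)
        return patches, tokens

embed = PatchEmbedding()
images = torch.randn(2, 3, 32, 32)      # 32 / 4 = 8 patches per side, 64 in total
patches, tokens = embed(images)
print(patches.shape, tokens.shape)      # torch.Size([2, 64, 48]) torch.Size([2, 64, 32])
```

Returning the raw patches alongside the tokens is a convenience for this sketch, since the same pixel values later serve as reconstruction targets.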


Masking: simple_masking

Defined in masking.py, this function:

  • Selects a random subset of patches to mask
  • Returns:
    • The original patches (unchanged)
    • A boolean mask indicating which patches are masked
    • The target pixels for loss computation
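
A function with this return contract could look roughly like the sketch below. The `mask_ratio` parameter, its default, and the rank-based selection are assumptions made for illustration; the actual implementation in masking.py may differ.

```python
import torch

def simple_masking(patches, mask_ratio=0.5):
    """Sketch of random patch masking; signature and mask_ratio are assumptions."""
    batch, num_patches, _ = patches.shape
    num_masked = int(num_patches * mask_ratio)
    # Rank random noise per image; the num_masked lowest-ranked patches are masked.
    noise = torch.rand(batch, num_patches, device=patches.device)
    ranks = noise.argsort(dim=1).argsort(dim=1)
    bool_mask = ranks < num_masked            # (B, N), True = masked
    targets = patches[bool_mask]              # (B * num_masked, patch_dim)
    return patches, bool_mask, targets

patches = torch.randn(2, 64, 48)
out_patches, bool_mask, targets = simple_masking(patches)
print(bool_mask.sum(dim=1))   # tensor([32, 32])
print(targets.shape)          # torch.Size([64, 48])
```

Ranking the noise guarantees that every image in the batch masks exactly the same number of patches, which keeps the target tensor rectangular.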

Forward Pass

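patches = Unfold(image)                                # (B, N, patch_dim)
patches, bool_mask, targets = simple_masking(patches)  # bool_mask: (B, N), True = masked

# patch_proj is the Linear half of the patch embedding (name assumed here)
projected_patches = patch_proj(patches)
encoder_input = torch.where(bool_mask.unsqueeze(-1), mask_token, projected_patches)
encoder_input = encoder_input + pos_emb
encoded = transformer(encoder_input)

output = simmim_head(encoded[bool_mask])               # predictions for masked patches only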

The model predicts only the masked patches, which keeps the head and the loss focused on the hidden regions.
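
To make the shapes concrete, here is a small runnable sketch of the same flow. It is not the framework's SimMIMViT: `nn.TransformerEncoder` stands in for the EncoderBlocks, a single `nn.Linear` stands in for the MLP head, the name `patch_proj` is assumed, and all sizes are illustrative.

```python
import torch
import torch.nn as nn

B, N, patch_dim, D = 2, 16, 48, 32                 # illustrative sizes

patch_proj = nn.Linear(patch_dim, D)               # Linear half of the patch embedding
mask_token = nn.Parameter(torch.zeros(1, 1, D))
pos_emb = nn.Parameter(torch.zeros(1, N, D))
# Stand-in for the framework's EncoderBlocks (assumption).
layer = nn.TransformerEncoderLayer(d_model=D, nhead=4, dim_feedforward=64, batch_first=True)
transformer = nn.TransformerEncoder(layer, num_layers=2)
simmim_head = nn.Linear(D, patch_dim)              # stand-in for the MLP head

patches = torch.randn(B, N, patch_dim)             # already-unfolded patches
bool_mask = torch.zeros(B, N, dtype=torch.bool)
bool_mask[:, : N // 2] = True                      # fixed mask: first half of each image

# Swap in the mask token wherever bool_mask is True; broadcasting handles the shapes.
encoder_input = torch.where(bool_mask.unsqueeze(-1), mask_token, patch_proj(patches))
encoded = transformer(encoder_input + pos_emb)     # (B, N, D)
output = simmim_head(encoded[bool_mask])           # (total masked patches, patch_dim)
targets = patches[bool_mask]
print(output.shape, targets.shape)                 # torch.Size([16, 48]) torch.Size([16, 48])
```

Indexing with the boolean mask flattens the batch and patch dimensions, so predictions and targets line up row by row without further reshaping.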

Loss

The loss is a pixel-wise regression loss (e.g., MSE or L1) between the predicted and ground truth pixel values of the masked patches:

loss = criterion(predicted_pixels, target_pixels)
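
As a small worked example, the two candidate criteria can be compared on hand-picked values. The tensors below are made up for illustration and use the standard `torch.nn.functional` losses.

```python
import torch
import torch.nn.functional as F

# Two masked "patches" with two pixel values each (toy numbers).
predicted_pixels = torch.tensor([[0.5, 0.0], [1.0, 1.0]])
target_pixels = torch.tensor([[1.0, 0.0], [0.0, 1.0]])

l1 = F.l1_loss(predicted_pixels, target_pixels)     # mean of |pred - target|
mse = F.mse_loss(predicted_pixels, target_pixels)   # mean of (pred - target)^2
print(l1.item(), mse.item())                        # 0.375 0.3125
```

The absolute errors are 0.5, 0, 1, 0, so L1 averages to 0.375, while squaring them gives 0.25, 0, 1, 0 and an MSE of 0.3125.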

Training: SimMIMTrainer

Implemented in trainer.py, this trainer:

  • Loads input images and applies patch masking
  • Flattens predictions and targets for loss computation
  • Supports warmup schedulers and logs training/validation metrics
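
The steps above can be sketched as a bare training loop. This is an assumption-laden illustration rather than the SimMIMTrainer code: an `nn.Linear` stands in for the model, masked patches are zeroed instead of replaced by a mask token, and linear warmup is implemented with `LambdaLR`.

```python
import torch
import torch.nn as nn

model = nn.Linear(48, 48)                           # stand-in for SimMIMViT (assumption)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
warmup_steps = 5
# Linear warmup: the learning rate ramps up to its base value over warmup_steps.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps)
)
criterion = nn.L1Loss()

lrs, losses = [], []
for step in range(8):
    patches = torch.randn(4, 16, 48)                # random stand-in batch
    bool_mask = torch.rand(4, 16) < 0.5
    inputs = patches.masked_fill(bool_mask.unsqueeze(-1), 0.0)   # hide masked patches
    preds = model(inputs)[bool_mask]                # keep masked positions only
    targets = patches[bool_mask]
    # Flatten both tensors so the criterion compares plain vectors of pixel values.
    loss = criterion(preds.flatten(), targets.flatten())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    lrs.append(optimizer.param_groups[0]["lr"])
    scheduler.step()
    losses.append(loss.item())

print(lrs)   # rises over the first 5 steps, then stays at the base rate
```

Stepping the scheduler once per batch gives a per-step warmup; a per-epoch warmup would call `scheduler.step()` at the end of each epoch instead.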

Validation

Validation follows the same logic as training, but without gradient updates. It:

  • Reconstructs patches
  • Measures loss
  • Logs predictions for analysis or visual debugging
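
A validation pass along these lines might look like the sketch below. The function name `validate`, the batch format, and the zero-filling of masked patches are assumptions for the example; logging is reduced to returning the mean loss.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def validate(model, batches, criterion):
    """Hypothetical validation loop: mean masked-patch loss, no parameter updates."""
    model.eval()
    total, count = 0.0, 0
    for patches, bool_mask in batches:
        inputs = patches.masked_fill(bool_mask.unsqueeze(-1), 0.0)
        preds = model(inputs)[bool_mask]            # reconstruct masked patches
        total += criterion(preds, patches[bool_mask]).item()
        count += 1
    return total / max(count, 1)

model = nn.Linear(48, 48)                           # stand-in for SimMIMViT (assumption)
mask = torch.zeros(4, 16, dtype=torch.bool)
mask[:, ::2] = True                                 # mask every other patch
batches = [(torch.randn(4, 16, 48), mask) for _ in range(3)]
val_loss = validate(model, batches, nn.L1Loss())
print(val_loss)
```

Because the whole function runs under `torch.no_grad()`, no computation graph is built and no gradients are stored on the model parameters.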

Component        File         Role
---------------  -----------  ----------------------------------------------------------
SimMIMViT        model.py     Vision Transformer with patch masking and pixel prediction
simple_masking   masking.py   Random masking of input patches and target generation
SimMIMTrainer    trainer.py   Training/validation loop with patch reconstruction loss