Vision Transformer in PyTorch Lightning

This is a third-party implementation of the Vision Transformer paper in PyTorch Lightning, with a focus on transparency in training/fine-tuning the model. It is heavily based on Google's official implementation in Flax.


Feature roadmap (✔️ marks implemented items):

  • ✔️ Architecture as PyTorch modules (a minimal encoder-block sketch is shown after this list).

TODO: Sparse and Linear Transformers utilities

  • ✔️ Support for loading checkpoints (i.e., pretrained weights) saved as .npz by the Flax model into a PyTorch model that matches it in architecture and naming conventions (see the conversion sketch after this list).

Have to look at the load_pretrained function in checkpoints.py thoroughly for conversion into the torch.nn.Module.state_dict() format.

  • ✔️ General model architecture as a pl.LightningModule object with transparent, readable code for tokenisation, training steps, etc.
  • Implementation of the 4 ViT variants (b16, b32, l16, l32) in PyTorch based on configs.py in the official repo (see the config sketch after this list).

Have to remove the hardcoded variables and express them in terms of self.hparams.

  • Implementation of training_step and configure_optimizers in the LightningModule to truly support fine-tuning on a custom dataset (a fine-tuning skeleton is sketched after this list).
  • Implementation of a reusable torch.utils.data.Dataset class that outputs tokenised images (with positional encodings) for use in ViT (see the Dataset sketch after this list).

Have to look at prefetch, get_data and get_dataset_info in train.py for this.

  • Support for multi-GPU training/fine-tuning using PyTorch Lightning's built-in features (see the Trainer call in the fine-tuning sketch after this list).
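
To make the "architecture as PyTorch modules" item concrete, below is a minimal sketch of one pre-norm Transformer encoder block of the kind ViT stacks. The class name and defaults are illustrative, not the repo's actual code.

```python
import torch
import torch.nn as nn


class EncoderBlock(nn.Module):
    """One pre-norm Transformer encoder block, as stacked in ViT."""

    def __init__(self, hidden_size: int, num_heads: int, mlp_dim: int, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, hidden_size),
            nn.Dropout(dropout),
        )

    def forward(self, x):  # x: (batch, num_tokens, hidden_size)
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # self-attention + residual
        x = x + self.mlp(self.norm2(x))                    # MLP + residual
        return x
```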
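
For the checkpoint-loading item, the core idea is to read the .npz file with NumPy and remap each Flax parameter name to the corresponding PyTorch name. The sketch below assumes a repo-specific `name_map`; the authoritative mapping is what load_pretrained in checkpoints.py has to implement.

```python
import numpy as np
import torch


def flax_npz_to_state_dict(npz_path, name_map):
    """Convert a Flax .npz checkpoint into a PyTorch state_dict.

    name_map maps Flax parameter names (keys in the .npz file) to the
    PyTorch parameter names of the matching nn.Module -- this mapping
    is specific to how the PyTorch model names its submodules.
    """
    flax_params = np.load(npz_path)
    state_dict = {}
    for flax_name, torch_name in name_map.items():
        array = np.asarray(flax_params[flax_name])
        # Flax stores dense-layer kernels as (in_features, out_features);
        # torch.nn.Linear expects (out_features, in_features), so transpose 2-D kernels.
        if array.ndim == 2 and flax_name.endswith("kernel"):
            array = array.T
        state_dict[torch_name] = torch.from_numpy(array)
    return state_dict


# Hypothetical usage:
# model.load_state_dict(flax_npz_to_state_dict("ViT-B_16.npz", name_map))
```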
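
One way to express the four variants is as hyperparameter dictionaries. The numbers follow the ViT paper's Base and Large settings; the key names are illustrative and may differ from configs.py.

```python
# Hyperparameters per variant (ViT-Base and ViT-Large from the paper); key names are illustrative.
VIT_CONFIGS = {
    "b16": dict(patch_size=16, hidden_size=768,  mlp_dim=3072, num_heads=12, num_layers=12),
    "b32": dict(patch_size=32, hidden_size=768,  mlp_dim=3072, num_heads=12, num_layers=12),
    "l16": dict(patch_size=16, hidden_size=1024, mlp_dim=4096, num_heads=16, num_layers=24),
    "l32": dict(patch_size=32, hidden_size=1024, mlp_dim=4096, num_heads=16, num_layers=24),
}
```

Passing one of these dicts to the LightningModule's __init__ and calling self.save_hyperparameters() makes every value available as self.hparams.<name>, which is essentially what replacing the hardcoded variables amounts to.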
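
A sketch of the tokenising Dataset idea: wrap any image dataset and return flattened, non-overlapping patches per image. Whether positional encodings are added here or inside the model is a design choice; the class below only produces the patch tokens and is not the repo's actual code.

```python
import torch
from torch.utils.data import Dataset


class PatchTokenDataset(Dataset):
    """Wraps an image dataset and yields flattened patch tokens per image."""

    def __init__(self, base_dataset, patch_size: int = 16):
        self.base = base_dataset        # any dataset yielding ((C, H, W) tensor, label)
        self.patch_size = patch_size

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        image, label = self.base[idx]   # image: (C, H, W), H and W divisible by patch_size
        p = self.patch_size
        c, h, w = image.shape
        # Split into non-overlapping p x p patches and flatten each:
        # (C, H, W) -> (C, H/p, W/p, p, p) -> (num_patches, C * p * p)
        patches = image.unfold(1, p, p).unfold(2, p, p)
        tokens = patches.permute(1, 2, 0, 3, 4).reshape(-1, c * p * p)
        return tokens, label
```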
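
Finally, a skeleton of what the fine-tuning LightningModule could look like, covering training_step, configure_optimizers and the multi-GPU Trainer call. The backbone argument and hyperparameter names are assumptions for illustration.

```python
import pytorch_lightning as pl
import torch
import torch.nn.functional as F


class ViTFineTuner(pl.LightningModule):
    def __init__(self, backbone: torch.nn.Module, num_classes: int, lr: float = 3e-4):
        super().__init__()
        self.save_hyperparameters(ignore=["backbone"])
        self.backbone = backbone                      # ViT encoder producing per-image features
        self.head = torch.nn.LazyLinear(num_classes)  # classification head, infers in_features

    def forward(self, x):
        return self.head(self.backbone(x))

    def training_step(self, batch, batch_idx):
        images, labels = batch
        loss = F.cross_entropy(self(images), labels)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)


# Multi-GPU fine-tuning is then a Trainer setting rather than model code, e.g.:
# trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp", max_epochs=10)
# trainer.fit(ViTFineTuner(backbone, num_classes=10), train_dataloader)
```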
