[Project Page] [Paper] [Models] [Poster] [Slides] [Blog (In Chinese)]
This repo provides official code and checkpoints for iVideoGPT, a generic and efficient world model architecture that has been pre-trained on millions of human and robotic manipulation trajectories.
- 🚩 2024.11.01: NeurIPS 2024 camera-ready version is released on arXiv.
- 🚩 2024.09.26: iVideoGPT has been accepted by NeurIPS 2024, congrats!
- 🚩 2024.08.31: Training code is released (work in progress 🚧, please stay tuned!)
- 🚩 2024.05.31: Project website with video samples is released.
- 🚩 2024.05.30: Model pre-trained on Open X-Embodiment and inference code are released.
- 🚩 2024.05.27: Our paper is released on arXiv.
conda create -n ivideogpt python==3.9
conda activate ivideogpt
pip install -r requirements.txt
To evaluate the FVD metric, download the pretrained I3D model into pretrained_models/i3d/i3d_torchscript.pt.
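FVD fits a Gaussian to I3D features of real videos and another to features of generated videos, then reports the Fréchet distance d² = ||μ_r − μ_g||² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}). The following is a minimal NumPy sketch of that distance only. It assumes you already have one feature row per video. How the repository runs i3d_torchscript.pt to extract features is not shown, and this is not the repository's own evaluation code.

```python
import numpy as np


def frechet_distance(feats_real: np.ndarray, feats_gen: np.ndarray) -> float:
    """Fréchet distance between Gaussians fitted to two feature sets.

    Each input has shape (num_videos, feature_dim), e.g. I3D features.
    """
    mu_r, mu_g = feats_real.mean(axis=0), feats_gen.mean(axis=0)
    # rowvar=False: rows are samples, columns are feature dimensions
    cov_r = np.cov(feats_real, rowvar=False)
    cov_g = np.cov(feats_gen, rowvar=False)

    # Tr((cov_r cov_g)^{1/2}) equals the sum of square roots of the eigenvalues
    # of the symmetric PSD matrix cov_r^{1/2} @ cov_g @ cov_r^{1/2}.
    evals_r, evecs_r = np.linalg.eigh(cov_r)
    sqrt_cov_r = (evecs_r * np.sqrt(np.clip(evals_r, 0, None))) @ evecs_r.T
    middle = sqrt_cov_r @ cov_g @ sqrt_cov_r
    tr_covmean = np.sqrt(np.clip(np.linalg.eigvalsh(middle), 0, None)).sum()

    mean_term = np.sum((mu_r - mu_g) ** 2)
    return float(mean_term + np.trace(cov_r) + np.trace(cov_g) - 2.0 * tr_covmean)
```

Using a symmetric eigendecomposition avoids a general matrix square root, which can return small imaginary parts for nearly singular covariances.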
At the moment we provide the following models:
Model | Resolution | Action-conditioned | Goal-conditioned | Tokenizer Params | Transformer Params |
---|---|---|---|---|---|
ivideogpt-oxe-64-act-free | 64x64 | No | No | 114M | 138M |
ivideogpt-oxe-64-act-free-medium | 64x64 | No | No | 114M | 436M |
ivideogpt-oxe-64-goal-cond | 64x64 | No | Yes | 114M | 138M |
ivideogpt-oxe-256-act-free | 256x256 | No | No | 310M | 138M |
If you cannot connect to Hugging Face, you can download the models manually from Tsinghua Cloud.
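If you can reach Hugging Face, the checkpoints can also be fetched ahead of time with `huggingface_hub.snapshot_download`. The following is a minimal sketch. It assumes the `huggingface_hub` package is installed and that every repository ID follows the `thuml/<model name>` pattern used in the inference example. The helper names `local_dir_for` and `download` are hypothetical, and the `pretrained_models/<model name>` layout mirrors the paths used in the fine-tuning commands.

```python
import os

# Model names from the table above; repository IDs assume the "thuml/<name>" pattern.
MODEL_NAMES = [
    "ivideogpt-oxe-64-act-free",
    "ivideogpt-oxe-64-act-free-medium",
    "ivideogpt-oxe-64-goal-cond",
    "ivideogpt-oxe-256-act-free",
]


def local_dir_for(model_name: str, root: str = "pretrained_models") -> str:
    """Hypothetical helper: local directory for one checkpoint."""
    return os.path.join(root, model_name)


def download(model_name: str, root: str = "pretrained_models") -> str:
    """Download one checkpoint from Hugging Face into root/model_name."""
    # Imported lazily so the path helper works without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    return snapshot_download(
        repo_id=f"thuml/{model_name}",
        local_dir=local_dir_for(model_name, root),
    )
```

For example, `download("ivideogpt-oxe-64-act-free")` would place the files under `pretrained_models/ivideogpt-oxe-64-act-free`. This assumes the repository contains the `tokenizer/` and `transformer/` subfolders that the fine-tuning commands reference.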
Notes:
- Due to the heterogeneity of action spaces, we currently do not have an action-conditioned prediction model on OXE.
- Pre-trained models at 256x256 resolution may not perform best due to insufficient training, but can serve as a good starting point for downstream fine-tuning.
Open X-Embodiment: Download the datasets from Open X-Embodiment and extract individual episodes as .npz files:
python datasets/oxe_data_converter.py --dataset_name {dataset name, e.g. bridge} --input_path {path to downloaded OXE} --output_path {path to stored npz}
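To sanity-check the converter output, you can list the arrays stored in one episode file. The following is a minimal NumPy sketch. The array names inside each .npz are whatever oxe_data_converter.py writes, so the function reports them rather than assuming specific keys. `describe_episode` is a hypothetical helper.

```python
import numpy as np


def describe_episode(path: str) -> dict:
    """Return {array name: (shape, dtype string)} for every array in an .npz episode."""
    # allow_pickle=True only because this is trusted, locally converted data;
    # it is needed if the converter stored any object arrays.
    with np.load(path, allow_pickle=True) as episode:
        return {name: (episode[name].shape, str(episode[name].dtype)) for name in episode.files}
```

For instance, `print(describe_episode("path/to/episode.npz"))` shows every stored array with its shape, which makes it easy to spot an episode that was extracted with an unexpected resolution.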
To replicate our pre-training on OXE, you need to extract all datasets listed under OXE_SELECT in ivideogpt/data/dataset_mixes.py.
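Extracting every dataset in OXE_SELECT means running the converter once per dataset. The following is a minimal sketch that builds and runs those commands with `subprocess`, using only the flags shown in the converter command. The `DATASETS` list is a hypothetical placeholder that you should fill from OXE_SELECT, whose exact structure is not reproduced here. It also assumes a single `--output_path` can be shared by all datasets. Adjust this if the converter expects one directory per dataset.

```python
import subprocess
import sys

# Placeholder: replace with the dataset names listed under OXE_SELECT
# in ivideogpt/data/dataset_mixes.py.
DATASETS = ["bridge", "fractal20220817_data"]


def converter_command(dataset_name: str, input_path: str, output_path: str) -> list:
    """Build the argument list for datasets/oxe_data_converter.py."""
    return [
        sys.executable, "datasets/oxe_data_converter.py",
        "--dataset_name", dataset_name,
        "--input_path", input_path,
        "--output_path", output_path,
    ]


def convert_all(input_path: str, output_path: str, datasets=DATASETS) -> None:
    """Run the converter for each dataset in turn, from the repository root."""
    for name in datasets:
        # check=True stops at the first dataset whose conversion fails
        subprocess.run(converter_command(name, input_path, output_path), check=True)
```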
See the instructions in datasets for preprocessing more datasets.
For action-free video prediction on Open X-Embodiment, run:
python inference/predict.py --pretrained_model_name_or_path "thuml/ivideogpt-oxe-64-act-free" --input_path inference/samples/fractal_sample.npz --dataset_name fractal20220817_data
See more examples in inference.
To pre-train iVideoGPT, adjust the arguments in the script below as needed and run:
bash ./scripts/pretrain/ivideogpt-oxe-64-act-free.sh
See more scripts for pre-trained models in scripts/pretrain.
After preparing the BAIR dataset, fine-tune the tokenizer by running:
accelerate launch train_tokenizer.py \
--exp_name bair_tokenizer_ft --output_dir log_vqgan --seed 0 --mixed_precision bf16 \
--model_type ctx_vqgan \
--train_batch_size 16 --gradient_accumulation_steps 1 --disc_start 1000005 \
--oxe_data_mixes_type bair --resolution 64 --dataloader_num_workers 16 \
--rand_select --video_stepsize 1 --segment_horizon 16 --segment_length 8 --context_length 1 \
--pretrained_model_name_or_path pretrained_models/ivideogpt-oxe-64-act-free/tokenizer
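When launched with Accelerate on several GPUs, each optimizer update sees the per-device batch multiplied by the gradient-accumulation steps and the number of processes. The following small worked sketch shows that arithmetic for the tokenizer command. It assumes `--train_batch_size` is a per-device value, as `--per_device_train_batch_size` is in the transformer command, and it uses a hypothetical 4-GPU launch.

```python
def effective_batch_size(per_device_batch: int, grad_accum_steps: int, num_processes: int) -> int:
    """Samples contributing to a single optimizer update under data parallelism."""
    return per_device_batch * grad_accum_steps * num_processes


# --train_batch_size 16 --gradient_accumulation_steps 1, on a hypothetical 4-GPU launch
print(effective_batch_size(16, 1, 4))  # 64
```

To keep this number fixed when training on fewer GPUs, raise `--gradient_accumulation_steps` proportionally.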
Then, to fine-tune the transformer for action-conditioned video prediction, run:
accelerate launch train_gpt.py \
--exp_name bair_llama_ft --output_dir log_trm --seed 0 --mixed_precision bf16 \
--vqgan_type ctx_vqgan \
--pretrained_model_name_or_path {log directory of finetuned tokenizer}/unwrapped_model \
--config_name configs/llama/config.json --load_internal_llm --action_conditioned --action_dim 4 \
--pretrained_transformer_path pretrained_models/ivideogpt-oxe-64-act-free/transformer \
--per_device_train_batch_size 16 --gradient_accumulation_steps 1 \
--learning_rate 1e-4 --lr_scheduler_type cosine \
--oxe_data_mixes_type bair --resolution 64 --dataloader_num_workers 16 \
--video_stepsize 1 --segment_length 16 --context_length 1 \
--use_eval_dataset --use_fvd --use_frame_metrics \
--weight_decay 0.01 --llama_attn_drop 0.1 --embed_no_wd
For action-free video prediction, remove the --load_internal_llm and --action_conditioned flags.
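If you script both runs, the action-free argument list can be derived from the action-conditioned one by dropping those two flags. The following is a small sketch built on the standard-library `shlex` module. The helper `to_action_free` is hypothetical. It removes only the two flags named above and leaves every other argument, such as `--action_dim 4`, untouched.

```python
import shlex

# The two flags the action-free setup omits.
ACTION_ONLY_FLAGS = {"--load_internal_llm", "--action_conditioned"}


def to_action_free(args: list) -> list:
    """Return a copy of the argument list without the action-conditioning flags."""
    return [arg for arg in args if arg not in ACTION_ONLY_FLAGS]


action_cond = shlex.split(
    "--config_name configs/llama/config.json --load_internal_llm --action_conditioned --action_dim 4"
)
print(shlex.join(to_action_free(action_cond)))
# --config_name configs/llama/config.json --action_dim 4
```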
For model-based RL on Metaworld, install the Metaworld version we used:
pip install git+https://github.com/Farama-Foundation/Metaworld.git@83ac03ca3207c0060112bfc101393ca794ebf1bd
Modify the paths in mbrl/cfgs/mbpo_config.yaml to your own paths (currently only absolute paths are supported), then run:
python mbrl/train_metaworld_mbpo.py task=plate_slide num_train_frames=100002 demo=true
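Because the config currently accepts only absolute paths, a quick pre-flight check can catch relative entries before a long run. The following is a minimal sketch. It assumes the YAML has already been loaded into a dict, for example with PyYAML's `yaml.safe_load`. It also relies on a hypothetical convention that path-valued keys end in `_path` or `_dir`, so adapt the key test to the actual fields in mbpo_config.yaml.

```python
import os


def relative_paths(cfg: dict, prefix: str = "") -> list:
    """Return dotted keys whose path-like string values are not absolute."""
    bad = []
    for key, value in cfg.items():
        name = str(key)
        dotted = f"{prefix}{name}"
        if isinstance(value, dict):
            # Recurse into nested config sections.
            bad.extend(relative_paths(value, prefix=f"{dotted}."))
        elif isinstance(value, str) and name.endswith(("_path", "_dir")):
            # Hypothetical convention: only keys ending in _path/_dir hold paths.
            if not os.path.isabs(value):
                bad.append(dotted)
    return bad
```

An empty result means every matched entry is already absolute. Any keys it returns point to the entries to fix before launching training.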
If you find this project useful, please cite our paper as:
@inproceedings{wu2024ivideogpt,
title={iVideoGPT: Interactive VideoGPTs are Scalable World Models},
author={Jialong Wu and Shaofeng Yin and Ningya Feng and Xu He and Dong Li and Jianye Hao and Mingsheng Long},
booktitle={Advances in Neural Information Processing Systems},
year={2024},
}
If you have any questions, please contact [email protected].
Our codebase is based on huggingface/diffusers and facebookresearch/drqv2.