We propose a Multi-stage Semantic Enhancement and Aggregation framework (Multi-SEA) with novel networks and training schemes. Multi-SEA first designs a fusion module built on agent attention and a gating mechanism, which enhances uni-modal information and aggregates fine-grained cross-modal information across different stages. Multi-SEA then introduces a three-stage scheme that integrates the two structures above.
Finally, Multi-SEA utilizes a negative-sample queue and a hierarchical scheme to facilitate robust contrastive learning and to promote expressive capability from implicit information. Experimental results demonstrate that Multi-SEA outperforms state-of-the-art schemes by a large margin.
We employ the RoBERTa-base and ViT-B/16 models as backbones to preliminarily encode the raw data; both can be downloaded quickly from the links below.
| Visual Backbone | Text Backbone |
|---|---|
| vit-b-16 | Roberta-base |
We provide the TensorBoard logs and checkpoints fine-tuned on Flickr30k and MSCOCO. The checkpoints contain weights, configs, optimizer states, and other training information saved by PyTorch Lightning.
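Since the checkpoints are ordinary PyTorch Lightning files, they can be inspected with `torch.load` before evaluating. A minimal sketch; the stand-in checkpoint below is built in place so the snippet is self-contained, and the real `Multi-SEA_*.ckpt` files may contain more entries:

```python
import torch

# Build a minimal stand-in checkpoint so the snippet runs on its own;
# in practice, point ckpt_path at the downloaded Multi-SEA_*.ckpt file.
ckpt_path = "demo.ckpt"
torch.save({"state_dict": {"w": torch.zeros(2)}, "epoch": 3}, ckpt_path)

# Lightning checkpoints are plain dicts serialized with torch.save.
ckpt = torch.load(ckpt_path, map_location="cpu")
print(sorted(ckpt.keys()))  # e.g. ['epoch', 'state_dict']
```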
| Checkpoint on Flickr30k | Checkpoint on MSCOCO |
|---|---|
| Multi-SEA_flickr30k | Multi-SEA_mscoco |
- Python version >= 3.9.0
- PyTorch version >= 2.0.0
- Install other libraries via

```shell
pip install -r requirements.txt
```
We follow ViLT and use pyarrow to serialize the datasets. See this link for details.
Find the config file at src/subsrc/configs/retrieval_flickr30k.yaml and replace the following paths:

- data_root : the directory of your dataset
- vit : the directory of the image backbone
- tokenizer : the directory of the text backbone

Finally, run the script from /src:

```shell
python main.py --config=./subsrc/configs/retrieval_flickr30k.yaml --devices=[0]
```
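For illustration, the edited section of the config might look like the excerpt below; the paths are placeholders, so substitute your own directories:

```yaml
# retrieval_flickr30k.yaml (excerpt; paths are placeholders)
data_root: /path/to/flickr30k_arrows
vit: /path/to/vit-b-16
tokenizer: /path/to/roberta-base
```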
Find the config file at src/subsrc/configs/retrieval_coco.yaml and replace the following paths:

- data_root : the directory of your dataset
- vit : the directory of the image backbone
- tokenizer : the directory of the text backbone

Finally, run the script from /src:

```shell
python main.py --config=./subsrc/configs/retrieval_coco.yaml --devices=[0]
```
Find the config file at src/subsrc/configs/retrieval_flickr30k.yaml and replace the following paths:

- data_root : the directory of your dataset
- test_checkpoints_dir : the directory of the checkpoint
- vit : the directory of the image backbone
- tokenizer : the directory of the text backbone

Finally, run the script from /src:

```shell
python main.py --config=./subsrc/configs/retrieval_flickr30k.yaml --devices=[0] --test_only
```
Find the config file at src/subsrc/configs/retrieval_coco.yaml and replace the following paths:

- data_root : the directory of your dataset
- test_checkpoints_dir : the directory of the checkpoint
- vit : the directory of the image backbone
- tokenizer : the directory of the text backbone

Finally, run the script from /src:

```shell
python main.py --config=./subsrc/configs/retrieval_coco.yaml --devices=[0] --test_only
```
If you use our work, please cite:
The implementation of Multi-SEA relies on resources from BERT (PyTorch), CLIP, LLaMA, and PyTorch Lightning. We thank the original authors for open-sourcing their work.