
GTPrompt model training problem #16

Open
yihp opened this issue Sep 13, 2024 · 13 comments

@yihp

yihp commented Sep 13, 2024

Hi! Thanks for your contribution. It is an excellent piece of work!

My task language is Chinese. I have trained the MultiCXR model on my own vocabulary, and I have the following problem when training the GTPrompt model:

I cannot load the multi_ckpt_name aehrc/cxrmate-multi-tf that you trained, because the word embedding dimension is different, and the cxrmate-multi-tf-cn model I trained myself was not saved as a pytorch_model.bin file, so I don't know how to load it.

How should I load the trained MultiCXR model in .ckpt format?

# Load multi checkpoint:
if encoder_decoder_ckpt_name:
    encoder_decoder = AutoModel.from_pretrained(encoder_decoder_ckpt_name, trust_remote_code=True)
    self.load_state_dict(encoder_decoder.state_dict())
else:
    warnings.warn('The encoder-to-decoder model was not warm-started before applying low-rank approximation.')
@anicolson
Member

Hi @yihp,

If test_ckpt_name is in your config, it will use the Hugging Face from_pretrained method during testing.

Otherwise, it will automatically load the .ckpt as the Lightning module from your exp_dir:

ckpt_path = get_test_ckpt_path(

and

model = TaskModel.load_from_checkpoint(checkpoint_path=ckpt_path, **vars(args))

Hence, remove test_ckpt_name from your config to test the .ckpt file.
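
As a minimal sketch, assuming a YAML test config like the existing ones (the exp_dir value is illustrative):

# Present: testing uses the Hugging Face from_pretrained method.
# test_ckpt_name: aehrc/cxrmate-multi-tf

# Absent (as above): testing loads the .ckpt from exp_dir.
exp_dir: experiments/cxrmate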

@yihp
Author

yihp commented Sep 16, 2024

Hi @anicolson ,

Thank you very much for your reply.

Are the network structures of the GTPrompt model and the MultiCXR model the same? If so, can I load the MultiCXR model's .ckpt checkpoint when training the GTPrompt model?

Looking forward to your reply!

@anicolson
Member

Hi @yihp,

Yes, the MultiCXR model is used to warm-start GTPrompt.

@yihp
Author

yihp commented Sep 16, 2024

Hi @anicolson ,

But I don't know whether there is a problem with the pytorch_model.bin I saved when training the MultiCXR model, because it produces garbled output during validation.

So can I specify the last.ckpt of the MultiCXR model instead?

@anicolson
Member

Hi @yihp,

Specify warm_start_ckpt_path in your config:

model = TaskModel.load_from_checkpoint(checkpoint_path=args.warm_start_ckpt_path, **vars(args))
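
For example (the checkpoint path is illustrative):

warm_start_ckpt_path: experiments/cxrmate/multi_tf/trial_0/last.ckpt

It can also be passed on the command line as --warm-start-ckpt-path.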

@yihp
Author

yihp commented Sep 16, 2024

OK, I will go to the lab and try it later. Thank you very much for your reply!

@yihp
Author

yihp commented Sep 16, 2024

Hi @anicolson ,

I specified warm_start_ckpt_path for training:
dlhpcstarter -t cxrmate -c config/train/longitudinal_gt_prompt_tf_qwen.yaml --stages_module tools.stages --train --trial 5 --warm-start-ckpt-path experiments/cxrmate/multi_tf/trial_0/epoch=3-step=7840-val_report_nlg_bleu_4=0.017195.ckpt
But the following error occurred:

Traceback (most recent call last):
  File "/home/maiyue/anaconda3/envs/cxrmate/bin/dlhpcstarter", line 8, in <module>
    sys.exit(main())
  File "/home/maiyue/anaconda3/envs/cxrmate/lib/python3.8/site-packages/dlhpcstarter/__main__.py", line 126, in main
    submit(args, cmd_line_args, stages_fnc)
  File "/home/maiyue/anaconda3/envs/cxrmate/lib/python3.8/site-packages/dlhpcstarter/__main__.py", line 21, in submit
    stages_fnc(args)
  File "/public-data/yhp/cxrmate/tools/stages.py", line 49, in stages
    model = TaskModel.load_from_checkpoint(checkpoint_path=args.warm_start_ckpt_path, **vars(args))
  File "/home/maiyue/anaconda3/envs/cxrmate/lib/python3.8/site-packages/lightning/pytorch/utilities/model_helpers.py", line 125, in wrapper
    return self.method(cls, *args, **kwargs)
  File "/home/maiyue/anaconda3/envs/cxrmate/lib/python3.8/site-packages/lightning/pytorch/core/module.py", line 1586, in load_from_checkpoint
    loaded = _load_from_checkpoint(
  File "/home/maiyue/anaconda3/envs/cxrmate/lib/python3.8/site-packages/lightning/pytorch/core/saving.py", line 91, in _load_from_checkpoint
    model = _load_state(cls, checkpoint, strict=strict, **kwargs)
  File "/home/maiyue/anaconda3/envs/cxrmate/lib/python3.8/site-packages/lightning/pytorch/core/saving.py", line 187, in _load_state
    keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)
  File "/home/maiyue/anaconda3/envs/cxrmate/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2215, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GTPrompt:
        Missing key(s) in state_dict: "encoder_decoder.decoder.base_model.model.bert.embeddings.word_embeddings.weight", "encoder_decoder.decoder.base_model.model.bert.embeddings.position_embeddings.weight", "encoder_decoder.decoder.base_model.model.bert.embeddings.token_type_embeddings.weight", "encoder_decoder.decoder.base_model.model.bert.embeddings.LayerNorm.weight", "encoder_decoder.decoder.base_model.model.bert.embeddings.LayerNorm.bias", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.self.query.base_layer.weight", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.self.query.base_layer.bias", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.self.query.lora_A.default.weight", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.self.query.lora_B.default.weight", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.self.key.base_layer.weight", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.self.key.base_layer.bias", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.self.key.lora_A.default.weight", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.self.key.lora_B.default.weight", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.self.value.weight", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.self.value.bias", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.output.dense.weight", "encoder_decoder.decoder.base_model.model.bert.encoder.layer.0.attention.output.dense.bias",

Are the network structures of the GTPrompt model and the MultiCXR model the same? Why don't they match?

Looking forward to your reply!

@anicolson
Member

anicolson commented Sep 16, 2024

Hi @yihp,

Ah, I am sorry, I forgot about LoRA. GTPrompt is MultiCXR + LoRA: the model is warm-started and then LoRA is added. You can see this here:

self.encoder_decoder = LongitudinalPromptMultiCXREncoderDecoderModel(

and here:

So all this has to happen within the class due to the differences.

This is a bit annoying, but you have to save the .ckpt as a Hugging Face model checkpoint: https://github.com/aehrc/cxrmate/blob/main/modules/transformers/multi_tf_model_to_hub.ipynb

Instead of setting warm_start_ckpt_path, set multi_ckpt_name in your config; multi_ckpt_name should be the save_path from that notebook.
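
As a rough sketch of what that notebook does, assuming the LightningModule stores the model under an encoder_decoder. prefix (the prefix, paths, and base checkpoint here are illustrative; the notebook has the exact steps):

import torch
from transformers import AutoModel

# Load the Lightning .ckpt and keep only the encoder-decoder weights.
ckpt = torch.load('experiments/cxrmate/multi_tf/trial_0/last.ckpt', map_location='cpu')
prefix = 'encoder_decoder.'
state_dict = {k[len(prefix):]: v for k, v in ckpt['state_dict'].items() if k.startswith(prefix)}

# Instantiate the Hugging Face model (with your own vocabulary/config if it differs
# from the released checkpoint), load the weights, and save them.
model = AutoModel.from_pretrained('aehrc/cxrmate-multi-tf', trust_remote_code=True)
model.load_state_dict(state_dict)

save_path = 'checkpoints/cxrmate-multi-tf-cn'
model.save_pretrained(save_path)  # point multi_ckpt_name at this save_path

If the keys do not line up one-to-one, comparing state_dict.keys() with model.state_dict().keys() is the quickest way to see which prefixes need adjusting.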

Sorry for the confusion.

@anicolson
Member

anicolson commented Sep 24, 2024 via email

@yihp
Author

yihp commented Sep 25, 2024

Hi @anicolson ,

Thank you very much for your reply.

I have two questions:

  1. How do I save the aehrc/cxrmate-tf Hugging Face model checkpoint?
  2. The paper states that for SCST, validation was performed every 1/10 of an epoch. How should this be set? every_n_epochs: 0.1 did not work.

Looking forward to your reply!

@anicolson
Member

anicolson commented Sep 25, 2024

With 1), I've added the remaining notebooks for saving Hugging Face checkpoints here: https://github.com/aehrc/cxrmate/tree/main/modules/transformers. It can be a bit of a nightmare getting the key names right in the state_dict, so you might have to play around with that.

With 2), this has been added back into the configs.
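
If the config exposes the standard Lightning Trainer arguments, validating every 1/10 of an epoch is done with val_check_interval rather than every_n_epochs (the exact key name in the cxrmate configs may differ):

val_check_interval: 0.1  # fraction of a training epoch between validation runs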

@yihp
Author

yihp commented Sep 25, 2024

Hi @anicolson ,

Thank you very much for your reply!
I have two questions:

Firstly, if I want to use BERTScore as a reward, do you have any related experiments?
Do I just need to change ckpt_name = 'microsoft/BiomedVLP-CXR-BERT-specialized' to ckpt_name = 'microsoft/bert-base-chinese', and should I use the last-layer [CLS] output as the word embedding for the cosine similarity?

ckpt_name = 'microsoft/BiomedVLP-CXR-BERT-specialized'

Secondly, during the training of the different models (single_tf, multi_tf, longitudinal_gt_prompt_tf, longitudinal_gt_prompt_cxr-bert), how did you set the following training parameters:

devices: 
max_epochs: 
mbatch_size: 
accumulated_mbatch_size: 

Looking forward to your reply!

@anicolson
Member

Hi @yihp,

See https://github.com/aehrc/cxrmate-ed/blob/main/tools/rewards/bertscore.py

And https://github.com/aehrc/cxrmate-ed/blob/17bb8f1131f58c151ccb7b46667ed5a98e79e660/modules/lightning_modules/cxrmate_ed/scst_rewards.py#L9

Note that the cxrmate-ed repo will be heavily refactored in a couple of weeks.
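
As a rough sketch of a BERTScore reward for SCST, using the bert_score package directly rather than the cxrmate-ed implementation (lang='zh', which selects a Chinese BERT backbone, is only an example):

from bert_score import BERTScorer

# BERTScorer with a Chinese backbone; pass model_type=... instead to pin a specific checkpoint.
scorer = BERTScorer(lang='zh', rescale_with_baseline=False)

def scst_reward(predictions, references):
    """Return BERTScore F1 per generated report as the SCST reward."""
    _, _, f1 = scorer.score(predictions, references)
    return f1  # tensor of shape [batch_size]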

I was using 4x P100 GPUs to train the models.

For single_tf and multi_tf:

devices: 4
max_epochs: 32
mbatch_size: 8
accumulated_mbatch_size: 32

For longitudinal_gt_prompt_tf:

devices: 4
max_epochs: 32
mbatch_size: 2
accumulated_mbatch_size: 32

For longitudinal_gt_prompt_cxr-bert:

devices: 4
max_epochs: 32
mbatch_size: 1   # See paper for explanation of this.
accumulated_mbatch_size: 32
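
For reference, if accumulated_mbatch_size is the effective batch size across devices and gradient accumulation (an assumption about how these keys are interpreted), the implied number of accumulation steps is:

# Illustrative arithmetic only; assumes accumulated_mbatch_size = devices * mbatch_size * accumulation_steps.
devices, mbatch_size, accumulated_mbatch_size = 4, 2, 32  # longitudinal_gt_prompt_tf
accumulation_steps = accumulated_mbatch_size // (devices * mbatch_size)
print(accumulation_steps)  # 4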
