GTPrompt model training problem #16
Hi @yihp, if ..., else it will automatically load the ... (see line 103 and line 110 at commit 820607a).
Hence, remove ...
Hi @anicolson, thank you very much for your reply. Are the network structures of the GTPrompt model and the MultiCXR model the same? Looking forward to your reply!
Hi @yihp, yes, the MultiCXR model is used to warm-start GTPrompt.
Hi @anicolson, but I don't know if there is a problem with the ... So can I specify the ...?
OK, I will go to the lab to try it later. Thank you very much for your reply!
Hi @anicolson, I specified ...
Are the network structures of the GTPrompt model and the MultiCXR model the same? Why don't they match? Looking forward to your reply!
Hi @yihp, ah, I am sorry, I forgot about LoRA. GTPrompt is MultiCXR + LoRA: the model is warm-started from MultiCXR and then LoRA is added. You can see this here: ... and here: ...
So all of this has to happen within the class due to the differences. This is a bit annoying, but you have to save the .ckpt as a Hugging Face model checkpoint: https://github.com/aehrc/cxrmate/blob/main/modules/transformers/multi_tf_model_to_hub.ipynb. And instead of setting ... Sorry for the confusion.
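As a rough illustration of this warm-start-then-adapt pattern, a minimal sketch using the Hugging Face peft library is shown below. This is not the cxrmate code (which adds LoRA inside the model class); the LoRA hyperparameters and target modules are assumptions for illustration only.

```python
import transformers
from peft import LoraConfig, get_peft_model

# Warm-start from the MultiCXR Hugging Face checkpoint (name taken from this thread).
model = transformers.AutoModel.from_pretrained('aehrc/cxrmate-multi-tf', trust_remote_code=True)

# Add LoRA adapters on top of the warm-started weights.
# r, lora_alpha, target_modules, and lora_dropout are assumed values, not cxrmate's settings.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=['query', 'value'],
    lora_dropout=0.1,
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA parameters remain trainable
```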
Hi, I see you removed your comment, are you still interested in this?
From: yihp
Hi @anicolson,
I have another question: how do I save the aehrc/cxrmate-tf Hugging Face model checkpoint?
Is aehrc/cxrmate-tf the LongitudinalPromptMultiCXREncoderDecoderModel model class? Am I right to convert it in the following way:
import warnings

import transformers

# CvtWithProjectionHeadConfig and LongitudinalPromptMultiCXREncoderDecoderModel are the custom
# classes from the cxrmate repository. `state_dict` is assumed to already hold the weights from
# the Lightning checkpoint, e.g. torch.load(ckpt_path, map_location='cpu')['state_dict'].

# Encoder & decoder config:
config_decoder = transformers.BertConfig(
    vocab_size=151659,
    num_hidden_layers=6,
    type_vocab_size=2,
)  # BERT, as it includes token_type_ids.
encoder_ckpt_name = 'microsoft/cvt-21-384-22k'
config_encoder = CvtWithProjectionHeadConfig.from_pretrained(
    '/public-data/yhp/cxrmate/microsoft/cvt-21-384-22k',
    # os.path.join(ckpt_zoo_dir, encoder_ckpt_name),
    local_files_only=True,
    projection_size=config_decoder.hidden_size,
)
config = transformers.VisionEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

# Encoder-to-decoder instance:
LongitudinalPromptMultiCXREncoderDecoderModel.register_for_auto_class("AutoModel")
encoder_decoder = LongitudinalPromptMultiCXREncoderDecoderModel(config=config)

# Rename the checkpoint keys so they match the Hugging Face module names:
for key in list(state_dict.keys()):
    if 'encoder_projection' in key:
        state_dict[key.replace('encoder_projection', 'encoder.projection_head.projection')] = state_dict.pop(key)
    elif 'last_hidden_state_layer_norm' in key:
        state_dict[key.replace('last_hidden_state_layer_norm', 'encoder.projection_head.layer_norm')] = state_dict.pop(key)
    elif 'encoder.encoder' in key:
        state_dict[key.replace('encoder.encoder', 'encoder.cvt.encoder')] = state_dict.pop(key)
    elif 'encoder_decoder.' in key:
        state_dict[key.replace('encoder_decoder.', '')] = state_dict.pop(key)
    else:
        warnings.warn(f'Key not found: {key}')

encoder_decoder.load_state_dict(state_dict)
encoder_decoder.save_pretrained(save_path)
I converted it like this and trained the SCST model with the config public-longitudinal_gt_prompt_cxr-bert.yaml, but the model output was garbled.
Looking forward to your reply!
Hi @anicolson, thank you very much for your reply. I have two questions:
1) ...
2) ...
Looking forward to your reply!
With 1), I've added the remaining notebooks for saving to Hugging Face checkpoints here: https://github.com/aehrc/cxrmate/tree/main/modules/transformers. It can be a bit of a nightmare getting the key names right in the state_dict, so you might have to play around with that (see the sketch after this comment). With 2), this has been added back into the configs: ...
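For reference, one way to "play around" with the key names is to load the Lightning .ckpt directly and list its keys before writing the renaming rules. This is a minimal sketch assuming a standard PyTorch Lightning checkpoint layout; the checkpoint path is hypothetical.

```python
import torch

# Hypothetical path to the Lightning .ckpt produced during training.
ckpt_path = 'experiments/.../last.ckpt'

# PyTorch Lightning checkpoints store the model weights under the 'state_dict' key.
state_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']

# Print every key and its shape so the renaming rules (e.g. stripping the
# 'encoder_decoder.' prefix) can be matched against the Hugging Face module names.
for key, tensor in state_dict.items():
    print(key, tuple(tensor.shape))
```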
Hi @anicolson, thank you very much for your reply! Firstly, if I want to use bert_score as a reward, do you have any related experiments? (See cxrmate/tools/rewards/cxrbert.py, line 15 at b106927.)
Secondly, during the training of different models (...), ...
Looking forward to your reply!
Hi @yihp, see https://github.com/aehrc/cxrmate-ed/blob/main/tools/rewards/bertscore.py (a sketch of the idea follows this comment). Note that the cxrmate-ed repo will be heavily refactored in a couple of weeks. I was using 4x P100 GPUs to train the model. For single_tf and multi_tf: ...
For longitudinal_gt_prompt_tf: ...
For longitudinal_gt_prompt_cxr-bert: ...
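For illustration only, a minimal sketch of computing a BERTScore-style per-sample reward with the bert_score package is given below. The linked bertscore.py in cxrmate-ed is the reference implementation; the model_type below is an assumption, not the checkpoint cxrmate actually uses.

```python
from bert_score import BERTScorer

# Assumed scorer configuration for illustration.
scorer = BERTScorer(model_type='distilbert-base-uncased', num_layers=5)

def bertscore_reward(generated_reports, reference_reports):
    """Return per-sample BERTScore F1, usable as an SCST reward."""
    _, _, f1 = scorer.score(generated_reports, reference_reports)
    return f1

# Example usage:
rewards = bertscore_reward(
    ['No acute cardiopulmonary abnormality.'],
    ['No acute cardiopulmonary process.'],
)
print(rewards)
```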
Hi! Thanks for your contribution. It is an excellent piece of work!
My task language is Chinese, and I have trained the MultiCXR model on my own vocabulary. I have the following problem when training the GTPrompt model: I cannot load the multi_ckpt_name: aehrc/cxrmate-multi-tf checkpoint you trained, because the word embedding dimension size is different, and the cxrmate-multi-tf-cn model I trained myself did not save a model file in the pytorch_model.bin format, so I don't know how to load it. How should I load the trained MultiCXR model in .ckpt format?