
[PAD] tokens override text report generation when batch size > 1 #11

Open
AceMcAwesome77 opened this issue Aug 14, 2024 · 5 comments

@AceMcAwesome77

AceMcAwesome77 commented Aug 14, 2024

Hi, thanks again for this helpful repo. I am implementing this model training code but running into a strange problem. When I set my batch_size to 1 for both TF (teacher forcing) and SCST (self-critical sequence training), I don't have any problems during training or validation, and all the text reports look good. The training step of TF also works well when the batch size is > 1, but when the batch size is > 1 during the validation step, I get text reports that look like this (I am leaving the special tokens in on purpose here):

[BOS][PAD][PAD][PAD][PAD][PAD], catheters and devices: left chest wall pacemaker. lungs: left lung base atelectasis or infiltrate. pleural spaces: left pleural effusion. heart/mediastinum: cardiomegaly. bones/joints: unremarkable.[SEP]left pleural effusion. left lung base atelectasis or infiltrate.[EOS]

These results have the first part of the text report cut off: the beginning should read "Tubes, catheters, and devices". This occurs for roughly half of my samples when batch_size = 2. It also occurs during the training step of SCST (but not TF) if the batch size is > 1, which makes me think it is related to the self.encoder_decoder.generate() function. It looks like it occurs when the prompt lengths of the previous reports differ between the two samples in the batch, so [PAD] tokens are added to the shorter one to make them the same length. However, these [PAD] tokens do appear to be appropriately masked by the decoder_attention_mask. I don't understand why the [PAD] tokens would override other correct caption words, rather than the generated text simply starting after the end of the [PAD] tokens.
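For reference, here is a toy check (using a generic Hugging Face tokenizer, not the repo's own tokenizer) of how padding='longest' behaves when the two previous reports in a batch differ in length; the shorter prompt is right-padded and the padded positions get 0 in the attention mask:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained('bert-base-uncased')  # stand-in tokenizer for illustration
    batch = tok(
        ['left pleural effusion. cardiomegaly.', 'no acute findings.'],
        padding='longest',
        return_tensors='pt',
        add_special_tokens=False,
    )
    print(batch.input_ids)       # the shorter report is right-padded with the pad token id
    print(batch.attention_mask)  # padded positions are 0, so attention ignores them,
                                 # but generation still continues from the end of each row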

Is this a common problem for this model? I see in the paper that the model was trained with a mini-batch size of 32, so higher batch sizes must be possible.

As a side question: is there any expected difference in the model's behaviour if the [BOS] token comes after all those [PAD] tokens in the prompt, rather than before them?

Thanks!

@AceMcAwesome77
Author

After investigating further, I think this is occurring because of a difference in where the PAD tokens are placed relative to the BOS token between training and validation. During training_step, add_bos_token_id is False, so the BOS token is not appended in the tokenize_prompt function; it is appended later in the tokenize_report_teacher_forcing function. This means that the padding done by the previous_sections = tokenizer() call in tokenize_prompt when the batch size is > 1 places the PAD tokens before BOS during training_step. During validation_step, add_bos_token_id is True, so BOS is appended prior to all the PAD tokens that come from the padding='longest' parameter in previous_sections = tokenizer().
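To make the two layouts concrete, here is a toy illustration of the prompt input_ids described above (the token ids are made up: 0 = [PAD], 1 = [BOS], 20+ = previous-report tokens):

    import torch

    # training_step: the prompt is padded first and [BOS] is appended afterwards,
    # so the pads sit before [BOS] and every row ends in [BOS].
    train_prompt = torch.tensor([[20, 21, 22, 1],
                                 [23, 24,  0, 1]])

    # validation_step: [BOS] is appended before padding='longest' adds the pads,
    # so the shorter row ends in [PAD] and generation continues after those pads
    # rather than directly after [BOS].
    val_prompt = torch.tensor([[20, 21, 22, 1],
                               [23, 24,  1, 0]])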

In theory, I wouldn't think this would affect anything if the attention mask is masking out all the PAD tokens. However, when I manually changed the code to append the BOS token after the PAD tokens in tokenize_prompt even during validation_step, the issue I described above (where the beginning of the generated text was cut off when using a batch size > 1) was resolved. This did introduce a new issue, though, where the beginning of the generated report is sensible but the end of the report devolves into repetitions and nonsense, which I am looking into now.

@anicolson
Member

Hi @AceMcAwesome77,

Could you look at the inputs to the decoder during training versus testing here?

Specifically, the decoder_input_ids, the decoder_attention_mask, the decoder_token_type_ids, and the decoder_position_ids (decoder_token_type_ids and decoder_position_ids should be in kwargs_decoder).

Checking the location of the special tokens in decoder_input_ids will be important; the attention mask should handle the padding fine.
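A minimal sketch of the kind of check meant here (the kwargs_decoder key names are assumptions; adjust them to whatever your forward pass actually receives):

    def dump_decoder_inputs(decoder_input_ids, decoder_attention_mask, kwargs_decoder):
        # Print the raw tensors so the special-token positions can be compared
        # between training_step and validation_step.
        print('decoder_input_ids:\n', decoder_input_ids)
        print('decoder_attention_mask:\n', decoder_attention_mask)
        print('decoder_token_type_ids:\n', kwargs_decoder.get('token_type_ids'))
        print('decoder_position_ids:\n', kwargs_decoder.get('position_ids'))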

I'll be honest, I developed and tested the model with a mini-batch size greater than one.

Can I ask about your training setup, including your dataset size, etc.?

@AceMcAwesome77
Copy link
Author

AceMcAwesome77 commented Aug 27, 2024

I had modified the self.decoder() call to take custom decoder embeddings so I could add age and gender as covariates to the embeddings. Therefore, my inputs to self.decoder() were a little different from the source code:

    decoder_outputs = self.decoder(
        input_ids=None,
        attention_mask=new_attention_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_outputs.attention_mask,
        inputs_embeds=new_decoder_embeddings,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        use_cache=use_cache,
        past_key_values=past_key_values,
        return_dict=return_dict,
        **kwargs_decoder,
    )

I had prepended two items (age, gender) to create new_decoder_embeddings and also prepended two 1-values to create new_attention_mask. I thought that would keep the mask correct, but it could have created problems compared to the original code.
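For context, a self-contained sketch of the kind of prepending described here (the shapes and the names age_emb / gender_emb are made up for illustration):

    import torch

    batch, seq_len, hidden = 2, 8, 768                          # dummy shapes
    decoder_inputs_embeds = torch.randn(batch, seq_len, hidden)
    decoder_attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
    age_emb = torch.randn(batch, hidden)                        # hypothetical covariate embeddings
    gender_emb = torch.randn(batch, hidden)

    # Prepend the two covariate embeddings and extend the mask with two 1s so the
    # new positions are attended to.
    new_decoder_embeddings = torch.cat(
        [age_emb.unsqueeze(1), gender_emb.unsqueeze(1), decoder_inputs_embeds], dim=1,
    )  # (batch, 2 + seq_len, hidden)
    new_attention_mask = torch.cat(
        [torch.ones(batch, 2, dtype=torch.long), decoder_attention_mask], dim=1,
    )  # (batch, 2 + seq_len)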

However, my most recent edit seems to have fixed the issue entirely, so I think we can close this issue. I was able to resolve it by changing some of the logic in the tokenize_prompt function to move the validation-step BOS token placement from before the padding to after the padding, like so:

        previous_sections = [
            f'[PMT]{ind} {pf}[PMT-SEP]{pi}' \
            for ind, pf, pi in zip(indication, previous_findings, previous_impression)
        ]
                
        # Tokenize the combined sections
        previous_sections = tokenizer(
            previous_sections,
            padding='longest',
            truncation=True,
            max_length=max_len,
            return_tensors='pt',
            return_token_type_ids=False,
            add_special_tokens=False,
        ).to(self.device)
        
        if add_bos_token_id:
            # Get the batch size (number of sequences) and current sequence length
            batch_size, seq_len = previous_sections.input_ids.shape
            
            if seq_len == max_len:
                # Replace the last token with BOS token
                previous_sections.input_ids[:, -1] = 1
                
            else:    
                # Append the token_id '1' for BOS to the end of each sequence in previous_sections.input_ids
                append_token = torch.full((batch_size, 1), 1, dtype=previous_sections.input_ids.dtype, device=self.device)
                previous_sections.input_ids = torch.cat([previous_sections.input_ids, append_token], dim=1)
                
                # Append '1' to the end of each sequence in previous_sections.attention_mask
                append_attention = torch.full((batch_size, 1), 1, dtype=previous_sections.attention_mask.dtype, device=self.device)
                previous_sections.attention_mask = torch.cat([previous_sections.attention_mask, append_attention], dim=1)

and then by commenting out this section:

        # Ensure BOS token identifier is at the end of the input_ids:
        if previous_sections.input_ids.shape[1] == max_len:
            previous_sections.input_ids[:, -1] = torch.where(
                previous_sections.attention_mask[:, -1] == 1,
                tokenizer.bos_token_id,
                previous_sections.input_ids[:, -1],
            ) 

In the example above, I had also added the clinical indication as 'ind' to the prompt. I think this is far enough removed from the original codebase that the bug is probably related to my edits. But if it would generally be better to have the order of the prompt padding and the BOS token synchronized between training_step and validation_step, this is one way to do it.

I'm training this on 7 of the 8 A100s on an NVIDIA DGX. The dataset is about 300,000 chest X-rays right now.

@anicolson
Member

Wow, good stuff.

I agree, it would be much cleaner if the BOS token were added solely in tokenize_prompt.

The indication section should give you a nice boost in performance.

That's a nice amount of compute. I used 4x P100s for CXRMate with DDP. Maybe you can look into using accumulated_mbatch_size: https://github.com/csiro-mlai/dl_hpc_starter_pack/blob/efaae3e703492fabffcc69edfd8976f72d0fd1bd/src/dlhpcstarter/trainer.py#L360C8-L360C31

You could set the following in your config:

devices: 4
mbatch_size: 4 # Or less.
accumulated_mbatch_size: 32 # Your effective mini-batch size.

And this would help you avoid using DeepSpeed.

See the following for more details: https://lightning.ai/docs/pytorch/stable/advanced/training_tricks.html#accumulate-gradients

The number of accumulation steps is calculated for you from your num_nodes, devices, and mini-batch size:

https://github.com/csiro-mlai/dl_hpc_starter_pack/blob/efaae3e703492fabffcc69edfd8976f72d0fd1bd/src/dlhpcstarter/trainer.py#L362
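Roughly, that calculation boils down to something like this (a simplified sketch, not the exact implementation in the linked code):

    # Effective mini-batch size = devices * num_nodes * mbatch_size * accumulate_grad_batches,
    # so the number of gradient-accumulation steps is derived as:
    devices, num_nodes, mbatch_size = 4, 1, 4
    accumulated_mbatch_size = 32
    accumulate_grad_batches = accumulated_mbatch_size // (devices * num_nodes * mbatch_size)
    print(accumulate_grad_batches)  # 2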

@AceMcAwesome77
Author

Thanks for the tip about gradient accumulation; that could be useful, particularly since I've had problems with NaN loss in the middle of epochs at low batch sizes. I don't want to avoid using DeepSpeed entirely, though: I've found that even on a single GPU, DeepSpeed reduces the memory required during training. When experimenting with a batch_size of 1, I was able to push my maximum image resolution from 768x768 all the way to 1120x1120 without OOM errors using stage_2_offload. I'm training on 768x768 images with a batch size of 4 right now, but I'd like to push the image resolution higher eventually, and gradient accumulation could be helpful there.
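As a rough sketch, the equivalent setup with PyTorch Lightning's built-in DeepSpeed strategy looks something like this (assuming a Lightning 2.x Trainer rather than the repo's config-driven launcher; the numbers are just examples):

    import lightning.pytorch as pl

    trainer = pl.Trainer(
        devices=1,
        strategy='deepspeed_stage_2_offload',  # ZeRO stage 2 with CPU offload
        precision='16-mixed',
        accumulate_grad_batches=2,             # pairs gradient accumulation with offload
    )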
