torch.compile: generate should use call instead of forward #34906

SilverSoldier opened this issue Nov 25, 2024 · 4 comments · May be fixed by #34907
SilverSoldier commented Nov 25, 2024

System Info

  • transformers version: 4.47.0.dev0
  • Platform: Linux-5.14.0-284.73.1.el9_2.x86_64-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.25.2
  • Safetensors version: 0.4.5
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.6.0.dev20241008+cu124 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: NO

Who can help?

@ArthurZucker @Cyrilvallez

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "facebook/opt-125m"
length = 100

prompt_text = 'In a small, bustling cafe nestled in the heart of a vibrant city, a serendipitous event unfolded, leaving a lasting impression on all who witnessed it. As the patrons sat sipping their coffees and engaging in animated conversations, a talented street musician entered the cafe, carrying a weathered guitar and radiating an aura of creativity.'

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.compile()  # compile the model in place (wraps forward with torch.compile)
input_ids = tokenizer(prompt_text, add_special_tokens=False, return_tensors='pt').input_ids
output = model.generate(input_ids, max_new_tokens=length)

Expected behavior

The expected behaviour is that generate uses the compiled forward function.

When a model is compiled with the model.compile() API, the __call__ method dispatches to an internal attribute that holds the compiled forward, rather than the uncompiled forward.

(I raised a related issue in PyTorch; this is Option 2 there.)

So generate should use __call__ rather than forward, in order to pick up the compiled version of forward (for this particular case of model.compile()).
However, a recent change replaced this call with model.forward() instead of model() for every token after the first:

def _sample():
    ...
    def model_forward(model, *args, **kwargs):
        return model.forward(*args, **kwargs)
    ...
        if i == 0:
            outputs = self(**model_inputs, return_dict=True)
            i += 1
        else:
            outputs = model_forward(self, return_dict=True, **model_inputs)

model_forward should be changed to call model() instead of model.forward()
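
A minimal sketch of the change being proposed (the helper name follows the snippet above; the surrounding generation loop is elided):

def model_forward(model, *args, **kwargs):
    # Going through __call__ (i.e. model(...)) picks up the compiled forward
    # that model.compile() installs, whereas model.forward(...) stays eager.
    return model(*args, **kwargs)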

SilverSoldier linked a pull request (#34907) on Nov 25, 2024 that will close this issue.
ydshieh (Collaborator) commented Nov 25, 2024

Hi, so if I understand correctly, the goal of this change is to make future model(...) calls (potentially used outside generate) use the already compiled version (if compilation happens inside generate here). That's good, but the current code isn't really a bug, right?

SilverSoldier (Author) commented Nov 25, 2024

The previous code called model(), which works with model.compile(). The current code (from this recent commit) changed this to model.forward().
So the current code does not work as expected with model.compile(): it uses the compiled version for the first iteration and eager mode for the remaining iterations. This is sort of a bug.
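
A minimal sketch of the behaviour being described, assuming PyTorch's nn.Module.compile() API (the toy module is illustrative):

import torch
from torch import nn

layer = nn.Linear(4, 4)
layer.compile()            # wraps the module's forward with torch.compile

x = torch.randn(2, 4)
y1 = layer(x)              # __call__ dispatches to the compiled forward
y2 = layer.forward(x)      # calling forward directly bypasses compilation (eager)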

Cyrilvallez (Member) commented Nov 25, 2024

Hi @SilverSoldier, thanks for opening the issue! Indeed, I agree that we should use __call__, for consistency between all methods. We wanted to refine this part anyway; this was the first shot! I'll take care of it very soon. Curious to see what the PyTorch team has to say with regard to the consistency of the different ways to compile a model, but I'm not sure this can be solved in general.

Also, we introduced the recent changes because it is inefficient to compile the forward in all cases (we only want to compile the iterative decoding part, not the prefill), so your code would be more efficient if you don't call compile yourself, only pass cache_implementation="static" to generate, and let us compile just the iterative decoding part 🤗
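
A minimal sketch of the suggested usage, adapted from the reproduction above (per the comment, generate then compiles the decoding step internally; argument values are for illustration):

output = model.generate(
    input_ids,
    max_new_tokens=length,
    cache_implementation="static",  # use a static KV cache; no manual model.compile()
)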

ydshieh (Collaborator) commented Nov 25, 2024

I see, it's another API of compile! Thanks!
