From 23316ed54eaa806be09d5037f8d0e37a93338309 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Mon, 26 Aug 2024 23:38:25 -0700 Subject: [PATCH] Revert "Update torchpippy (#2938)" This reverts commit 939ce400cbc0478fd281b9fcc9a2a85691ba9f7d. --- examples/inference/pippy/bert.py | 12 +--------- examples/inference/pippy/gpt2.py | 12 +--------- examples/inference/pippy/llama.py | 4 +--- examples/inference/pippy/t5.py | 9 ------- src/accelerate/inference.py | 39 +++++++++++++++++-------------- src/accelerate/utils/imports.py | 6 ++++- 6 files changed, 29 insertions(+), 53 deletions(-) diff --git a/examples/inference/pippy/bert.py b/examples/inference/pippy/bert.py index 474409f5d0f..bed3562337b 100644 --- a/examples/inference/pippy/bert.py +++ b/examples/inference/pippy/bert.py @@ -32,7 +32,7 @@ input = torch.randint( low=0, high=model.config.vocab_size, - size=(1, 512), # bs x seq_len + size=(2, 512), # bs x seq_len device="cpu", dtype=torch.int64, requires_grad=False, @@ -49,16 +49,6 @@ # available on all GPUs # model = prepare_pippy(model, split_points="auto", example_args=(input,), gather_output=True) -# Create new inputs of the expected size (n_processes) -input = torch.randint( - low=0, - high=model.config.vocab_size, - size=(2, 512), # bs x seq_len - device="cpu", - dtype=torch.int64, - requires_grad=False, -) - # Move the inputs to the first device input = input.to("cuda:0") diff --git a/examples/inference/pippy/gpt2.py b/examples/inference/pippy/gpt2.py index d1f232b51de..994327f3c0d 100644 --- a/examples/inference/pippy/gpt2.py +++ b/examples/inference/pippy/gpt2.py @@ -32,7 +32,7 @@ input = torch.randint( low=0, high=model.config.vocab_size, - size=(1, 1024), # bs x seq_len + size=(2, 1024), # bs x seq_len device="cpu", dtype=torch.int64, requires_grad=False, @@ -48,16 +48,6 @@ # available on all GPUs # model = prepare_pippy(model, split_points="auto", example_args=(input,), gather_output=True) -# Create new inputs of the expected size (n_processes) -input = torch.randint( - low=0, - high=model.config.vocab_size, - size=(2, 1024), # bs x seq_len - device="cpu", - dtype=torch.int64, - requires_grad=False, -) - # Move the inputs to the first device input = input.to("cuda:0") diff --git a/examples/inference/pippy/llama.py b/examples/inference/pippy/llama.py index 631da07bfcf..a1b2e12bb8f 100644 --- a/examples/inference/pippy/llama.py +++ b/examples/inference/pippy/llama.py @@ -27,7 +27,7 @@ # Input configs # Create example inputs for the model tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") -prompts = ("I would like to", "I really like to") # bs = 2, sending 2 per process +prompts = ("I would like to", "I really like to", "The weather is pretty") # bs = 3 tokenizer.pad_token = tokenizer.eos_token inputs = tokenizer(prompts, return_tensors="pt", padding=True) @@ -43,8 +43,6 @@ # currently we don't support `model.generate` # output = model.generate(**inputs, max_new_tokens=1) -prompts = ("I would like to", "I really like to", "The weather is pretty") # bs = 3 -inputs = tokenizer(prompts, return_tensors="pt", padding=True) inputs = inputs.to(0) with torch.no_grad(): output = model(**inputs) diff --git a/examples/inference/pippy/t5.py b/examples/inference/pippy/t5.py index b134eb5372c..2f9218aef14 100644 --- a/examples/inference/pippy/t5.py +++ b/examples/inference/pippy/t5.py @@ -14,21 +14,12 @@ import time import torch -from packaging import version from transformers import AutoModelForSeq2SeqLM from accelerate import PartialState, prepare_pippy -from accelerate import __version__ as accelerate_version from accelerate.utils import set_seed -if version.parse(accelerate_version) > version.parse("0.33.0"): - raise RuntimeError( - "Using encoder/decoder models is not supported with the `torch.pipelining` integration or accelerate>=0.34.0. " - "Please use a lower accelerate version and `torchpippy`, which this example uses." - ) - - # Set the random seed to have reproducable outputs set_seed(42) diff --git a/src/accelerate/inference.py b/src/accelerate/inference.py index 7ee0bdd6f63..4f9b07081b2 100644 --- a/src/accelerate/inference.py +++ b/src/accelerate/inference.py @@ -79,21 +79,22 @@ def build_pipeline(model, split_points, args, kwargs, num_chunks): `AcceleratorState.num_processes` """ # Note: We import here to reduce import time from general modules, and isolate outside dependencies - from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline + from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points + from pippy.PipelineStage import PipelineStage # We need to annotate the split points in the model for PiPPy state = PartialState() - split_spec = {split_point: SplitPoint.BEGINNING for split_point in split_points} - pipe = pipeline( - model, - mb_args=args, - mb_kwargs=kwargs, - split_spec=split_spec, - ) - stage = pipe.build_stage(state.local_process_index, device=state.device) - schedule = ScheduleGPipe(stage, num_chunks) + annotate_split_points(model, {split_point: PipeSplitWrapper.SplitPoint.BEGINNING for split_point in split_points}) + found_batch_size = find_pippy_batch_size(args, kwargs) + if found_batch_size != num_chunks: + if args is not None: + args = pad_input_tensors(args, found_batch_size, num_chunks) + if kwargs is not None: + kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks) + pipe = Pipe.from_tracing(model, num_chunks=num_chunks, example_args=args, example_kwargs=kwargs) + stage = PipelineStage(pipe, state.local_process_index, device=state.device) - return schedule + return stage def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs): @@ -142,12 +143,11 @@ def prepare_pippy( no_split_module_classes (`List[str]`): A list of class names for layers we don't want to be split. example_args (tuple of model inputs): - The expected inputs for the model that uses order-based inputs for a *single process*. Recommended to use - this method if possible. + The expected inputs for the model that uses order-based inputs. Recommended to use this method if possible. example_kwargs (dict of model inputs) - The expected inputs for the model that uses dictionary-based inputs for a *single process*. This is a - *highly* limiting structure that requires the same keys be present at *all* inference calls. Not - recommended unless the prior condition is true for all cases. + The expected inputs for the model that uses dictionary-based inputs. This is a *highly* limiting structure + that requires the same keys be present at *all* inference calls. Not recommended unless the prior condition + is true for all cases. num_chunks (`int`, defaults to the number of available GPUs): The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but this can be tuned and played with. In general one should have num_chunks >= num_gpus. @@ -155,7 +155,10 @@ def prepare_pippy( If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs. """ if not is_pippy_available(): - raise ImportError("Using `torch.distributed.pipelining` requires PyTorch 2.4.0 or later.") + raise ImportError( + "`pippy` was not found to be installed on your system. Please " + "install using `pip install torchpippy` or ensure you have at least version 0.2.0" + ) state = PartialState() example_args = send_to_device(example_args, "cpu") example_kwargs = send_to_device(example_kwargs, "cpu") @@ -174,7 +177,7 @@ def prepare_pippy( model.hf_split_points = split_points def forward(*args, **kwargs): - return pippy_forward(stage.step, num_chunks, gather_output, *args, **kwargs) + return pippy_forward(stage.forward, num_chunks, gather_output, *args, **kwargs) # To act like a decorator so that it can be popped when doing `extract_model_from_parallel` # Note: creates an infinite recursion loop with `generate` diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index 15f802e5926..3badfefd684 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -178,7 +178,11 @@ def is_deepspeed_available(): def is_pippy_available(): - return is_torch_version(">=", "2.4.0") + package_exists = _is_package_available("pippy", "torchpippy") + if package_exists: + pippy_version = version.parse(importlib.metadata.version("torchpippy")) + return compare_versions(pippy_version, ">", "0.1.1") + return False def is_bf16_available(ignore_tpu=False):