diff --git a/examples/inference/pippy/bert.py b/examples/inference/pippy/bert.py
index bed3562337b..474409f5d0f 100644
--- a/examples/inference/pippy/bert.py
+++ b/examples/inference/pippy/bert.py
@@ -32,7 +32,7 @@
 input = torch.randint(
     low=0,
     high=model.config.vocab_size,
-    size=(2, 512),  # bs x seq_len
+    size=(1, 512),  # bs x seq_len
     device="cpu",
     dtype=torch.int64,
     requires_grad=False,
@@ -49,6 +49,16 @@
 # available on all GPUs
 # model = prepare_pippy(model, split_points="auto", example_args=(input,), gather_output=True)
 
+# Create new inputs of the expected size (n_processes)
+input = torch.randint(
+    low=0,
+    high=model.config.vocab_size,
+    size=(2, 512),  # bs x seq_len
+    device="cpu",
+    dtype=torch.int64,
+    requires_grad=False,
+)
+
 # Move the inputs to the first device
 input = input.to("cuda:0")
 
diff --git a/examples/inference/pippy/gpt2.py b/examples/inference/pippy/gpt2.py
index 994327f3c0d..e08e0149822 100644
--- a/examples/inference/pippy/gpt2.py
+++ b/examples/inference/pippy/gpt2.py
@@ -32,7 +32,7 @@
 input = torch.randint(
     low=0,
     high=model.config.vocab_size,
-    size=(2, 1024),  # bs x seq_len
+    size=(1, 1024),  # bs x seq_len
     device="cpu",
     dtype=torch.int64,
     requires_grad=False,
diff --git a/examples/inference/pippy/t5.py b/examples/inference/pippy/t5.py
index 2f9218aef14..8ce438acd6f 100644
--- a/examples/inference/pippy/t5.py
+++ b/examples/inference/pippy/t5.py
@@ -32,7 +32,7 @@
 input = torch.randint(
     low=0,
     high=model.config.vocab_size,
-    size=(2, 1024),  # bs x seq_len
+    size=(1, 1024),  # bs x seq_len
     device="cpu",
     dtype=torch.int64,
     requires_grad=False,
diff --git a/src/accelerate/inference.py b/src/accelerate/inference.py
index 8ea47f4d8cf..276dfb4fab1 100644
--- a/src/accelerate/inference.py
+++ b/src/accelerate/inference.py
@@ -142,9 +142,9 @@ def prepare_pippy(
         no_split_module_classes (`List[str]`):
             A list of class names for layers we don't want to be split.
         example_args (tuple of model inputs):
-            The expected inputs for the model that uses order-based inputs. Recommended to use this method if possible.
+            The expected inputs for the model that uses order-based inputs for a *single process*. Recommended to use this method if possible.
         example_kwargs (dict of model inputs)
-            The expected inputs for the model that uses dictionary-based inputs. This is a *highly* limiting structure
+            The expected inputs for the model that uses dictionary-based inputs for a *single process*. This is a *highly* limiting structure
             that requires the same keys be present at *all* inference calls. Not recommended unless the prior condition
             is true for all cases.
         num_chunks (`int`, defaults to the number of available GPUs):
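
For context, a minimal usage sketch of the pattern the updated examples follow: the input passed as `example_args` to `prepare_pippy` is sized for a *single* process, while the batch actually fed to the pipelined model covers all processes. Everything below outside the diff is an illustration under assumptions (2 processes launched with `accelerate launch --num_processes 2`, the `bert-base-uncased` checkpoint, and the top-level `prepare_pippy` export), not part of the change itself.

    import torch
    from transformers import AutoModelForMaskedLM
    from accelerate import prepare_pippy

    # Assumed checkpoint for illustration
    model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
    model.eval()

    # Example input used only for tracing/splitting: sized for a single process (bs=1)
    example_input = torch.randint(0, model.config.vocab_size, (1, 512), dtype=torch.int64)
    model = prepare_pippy(model, split_points="auto", example_args=(example_input,))

    # Real input fed to the pipeline: one sequence per process (bs=2 for 2 processes)
    input = torch.randint(0, model.config.vocab_size, (2, 512), dtype=torch.int64).to("cuda:0")
    with torch.no_grad():
        output = model(input)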