Skip to content

Commit

Permalink
Start adjustment
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr committed Aug 23, 2024
1 parent 37d1473 commit 1f74cd3
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 5 deletions.
12 changes: 11 additions & 1 deletion examples/inference/pippy/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
input = torch.randint(
low=0,
high=model.config.vocab_size,
size=(2, 512), # bs x seq_len
size=(1, 512), # bs x seq_len
device="cpu",
dtype=torch.int64,
requires_grad=False,
Expand All @@ -49,6 +49,16 @@
# available on all GPUs
# model = prepare_pippy(model, split_points="auto", example_args=(input,), gather_output=True)

# Create new inputs of the expected size (n_processes)
input = torch.randint(
low=0,
high=model.config.vocab_size,
size=(2, 512), # bs x seq_len
device="cpu",
dtype=torch.int64,
requires_grad=False,
)

# Move the inputs to the first device
input = input.to("cuda:0")

Expand Down
2 changes: 1 addition & 1 deletion examples/inference/pippy/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
input = torch.randint(
low=0,
high=model.config.vocab_size,
size=(2, 1024), # bs x seq_len
size=(1, 1024), # bs x seq_len
device="cpu",
dtype=torch.int64,
requires_grad=False,
Expand Down
2 changes: 1 addition & 1 deletion examples/inference/pippy/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
input = torch.randint(
low=0,
high=model.config.vocab_size,
size=(2, 1024), # bs x seq_len
size=(1, 1024), # bs x seq_len
device="cpu",
dtype=torch.int64,
requires_grad=False,
Expand Down
4 changes: 2 additions & 2 deletions src/accelerate/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ def prepare_pippy(
no_split_module_classes (`List[str]`):
A list of class names for layers we don't want to be split.
example_args (tuple of model inputs):
The expected inputs for the model that uses order-based inputs. Recommended to use this method if possible.
The expected inputs for the model that uses order-based inputs for a *single process*. Recommended to use this method if possible.
example_kwargs (dict of model inputs)
The expected inputs for the model that uses dictionary-based inputs. This is a *highly* limiting structure
The expected inputs for the model that uses dictionary-based inputs for a *single process*. This is a *highly* limiting structure
that requires the same keys be present at *all* inference calls. Not recommended unless the prior condition
is true for all cases.
num_chunks (`int`, defaults to the number of available GPUs):
Expand Down

0 comments on commit 1f74cd3

Please sign in to comment.