Skip to content

Commit

Permalink
extend tokens to 50
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed Apr 16, 2024
1 parent 88870d6 commit 7aaa778
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions litgpt/deploy/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ def __init__(self,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
temperature: float = 0.8,
top_k: int = 200,
max_generated_tokens: int = 30) -> None:
max_new_tokens: int = 50) -> None:

super().__init__()
self.checkpoint_dir = checkpoint_dir
self.precision = precision
self.quantize = quantize
self.temperature = temperature
self.top_k = top_k
self.max_generated_tokens = max_generated_tokens
self.max_new_tokens = max_new_tokens

def setup(self, device: str) -> None:
# Setup the model so it can be called in `predict`.
Expand Down Expand Up @@ -80,7 +80,7 @@ def decode_request(self, request: Dict[str, Any]) -> Any:
def predict(self, inputs: torch.Tensor) -> Any:
# Run the model on the input and return the output.
prompt_length = inputs.size(0)
max_returned_tokens = prompt_length + self.max_generated_tokens
max_returned_tokens = prompt_length + self.max_new_tokens

y = generate(
self.model,
Expand All @@ -107,7 +107,7 @@ def run_server(
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
temperature: float = 0.8,
top_k: int = 200,
max_generated_tokens: int = 50,
max_new_tokens: int = 50,
devices: int = 1,
accelerator: str = "cuda",
port: int = 8000
Expand All @@ -126,7 +126,7 @@ def run_server(
Values below 1 decrease randomness.
top_k: The size of the pool of potential next tokens. Values larger than 1 result in more novel
generated text but can also lead to more incoherent texts.
max_generated_tokens: How many new tokens, in addition to the prompt length, to generate.
max_new_tokens: The number of new tokens to generate, in addition to the prompt length.
devices: How many devices/GPUs to use.
accelerator: The type of accelerator to use. For example, "cuda" or "cpu".
port: The network port number on which the model is configured to be served.
Expand All @@ -140,7 +140,7 @@ def run_server(
quantize=quantize,
temperature=temperature,
top_k=top_k,
max_generated_tokens=max_generated_tokens,
max_new_tokens=max_new_tokens,
),
accelerator=accelerator,
devices=devices)
Expand Down

0 comments on commit 7aaa778

Please sign in to comment.