Skip to content

Commit

Permalink
streamline usage
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed May 17, 2024
1 parent c890334 commit e981528
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 11 deletions.
9 changes: 4 additions & 5 deletions litgpt/deploy/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,19 +144,18 @@ def predict(self, inputs: torch.Tensor) -> Any:
for block in self.model.transformer.h:
block.attn.kv_cache.reset_parameters()

yield generate(
yield from generate(
self.model,
inputs,
max_returned_tokens,
temperature=self.temperature,
top_k=self.top_k,
top_p=self.top_p,
eos_id=self.tokenizer.eos_id
)
)

def encode_response(self, output_stream):
for outputs in output_stream:
yield [json.dumps({"output": self.tokenizer.decode(output)}) for output in outputs]
def encode_response(self, output):
yield {"output": self.tokenizer.decode(next(output))}


def run_server(
Expand Down
7 changes: 1 addition & 6 deletions tests/test_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,5 @@ def test_simple(tmp_path):
)
with TestClient(server.app) as client:
response = client.post("/predict", json={"prompt": "Hello world"})
response_list = response.json()
parsed_response = []

for item in response_list:
parsed_dict = json.loads(item)
parsed_response.append(parsed_dict)
assert parsed_response[0]["output"][:19] == "Hello world statues"
assert response.json()["output"][:19] == "Hello world statues"

0 comments on commit e981528

Please sign in to comment.