diff --git a/litgpt/deploy/serve.py b/litgpt/deploy/serve.py
index 79bf2ac657..5cffbb0808 100644
--- a/litgpt/deploy/serve.py
+++ b/litgpt/deploy/serve.py
@@ -144,7 +144,7 @@ def predict(self, inputs: torch.Tensor) -> Any:
         for block in self.model.transformer.h:
             block.attn.kv_cache.reset_parameters()
 
-        yield generate(
+        yield from generate(
             self.model,
             inputs,
             max_returned_tokens,
@@ -152,11 +152,10 @@ def predict(self, inputs: torch.Tensor) -> Any:
             top_k=self.top_k,
             top_p=self.top_p,
             eos_id=self.tokenizer.eos_id
-            )
+        )
 
-    def encode_response(self, output_stream):
-        for outputs in output_stream:
-            yield [json.dumps({"output": self.tokenizer.decode(output)}) for output in outputs]
+    def encode_response(self, output):
+        yield {"output": self.tokenizer.decode(next(output))}
 
 
 def run_server(
diff --git a/tests/test_serve.py b/tests/test_serve.py
index dcc79319ea..0e9b0c6d44 100644
--- a/tests/test_serve.py
+++ b/tests/test_serve.py
@@ -47,10 +47,5 @@ def test_simple(tmp_path):
     )
     with TestClient(server.app) as client:
         response = client.post("/predict", json={"prompt": "Hello world"})
-        response_list = response.json()
-        parsed_response = []
-        for item in response_list:
-            parsed_dict = json.loads(item)
-            parsed_response.append(parsed_dict)
-        assert parsed_response[0]["output"][:19] == "Hello world statues"
+        assert response.json()["output"][:19] == "Hello world statues"
 