Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed Apr 16, 2024
1 parent 02e25b4 commit 090107c
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions tests/test_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ def test_simple(tmp_path):
yaml.dump(asdict(ours_config), fp)

accelerator = "cpu"
server = LitServer(SimpleLitAPI(checkpoint_dir=tmp_path), accelerator=accelerator, devices=1, timeout=60)
server = LitServer(
SimpleLitAPI(checkpoint_dir=tmp_path, temperature=1, top_k=1),
accelerator=accelerator, devices=1, timeout=60
)

with TestClient(server.app) as client:
response = client.post("/predict", json={"prompt": "Hello, World"})
if accelerator == "gpu":
assert response.json()["output"][:25] == "Hello, World gcc exchange", response.json()["output"][:25]
else:
assert response.json()["output"][:25] == "Hello, World Associatedim", response.json()["output"][:25]
response = client.post("/predict", json={"prompt": "Hello world"})
# Model is a small random model, not trained, hence the gibberish.
# We are just testing that the server works.
assert response.json()["output"][:19] == "Hello world statues"

0 comments on commit 090107c

Please sign in to comment.