From 8652b9ccec7b0735e76076856325931506b438bd Mon Sep 17 00:00:00 2001 From: Ivan Leo Date: Thu, 21 Nov 2024 08:24:25 +0800 Subject: [PATCH] fix: migrate types --- .../test_writer/evals/test_classification_enums.py | 9 ++++++--- .../evals/test_classification_literals.py | 14 +++++++++++--- tests/llm/test_writer/evals/test_extract_users.py | 6 +++--- .../test_writer/evals/test_sentiment_analysis.py | 5 +++-- 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/tests/llm/test_writer/evals/test_classification_enums.py b/tests/llm/test_writer/evals/test_classification_enums.py index cf555d27c..6426130c3 100644 --- a/tests/llm/test_writer/evals/test_classification_enums.py +++ b/tests/llm/test_writer/evals/test_classification_enums.py @@ -1,6 +1,5 @@ import enum from itertools import product -from typing import List, Tuple from writerai import Writer import pytest @@ -38,7 +37,9 @@ class SinglePrediction(BaseModel): @pytest.mark.parametrize("model, data, mode", product(models, data, modes)) -def test_writer_classification(model: str, data: List[Tuple], mode: instructor.Mode): +def test_writer_classification( + model: str, data: list[tuple[str, Labels]], mode: instructor.Mode +): client = instructor.from_writer(client=Writer(), mode=mode) input, expected = data @@ -82,7 +83,9 @@ class MultiClassPrediction(BaseModel): @pytest.mark.parametrize("model, data, mode", product(models, data, modes)) -def test_writer_multi_classify(model: str, data: List[Tuple], mode: instructor.Mode): +def test_writer_multi_classify( + model: str, data: list[tuple[str, list[MultiLabels]]], mode: instructor.Mode +): client = instructor.from_writer(client=Writer(), mode=mode) if (mode, model) in { diff --git a/tests/llm/test_writer/evals/test_classification_literals.py b/tests/llm/test_writer/evals/test_classification_literals.py index fe615d19c..9ff0d8fdf 100644 --- a/tests/llm/test_writer/evals/test_classification_literals.py +++ b/tests/llm/test_writer/evals/test_classification_literals.py @@ -1,5 +1,5 @@ from itertools import product -from typing import Literal, List, Tuple +from typing import Literal from writerai import AsyncWriter import pytest @@ -26,7 +26,11 @@ class SinglePrediction(BaseModel): @pytest.mark.parametrize("model, data, mode", product(models, data, modes)) @pytest.mark.asyncio -async def test_classification(model: str, data: List[Tuple], mode: instructor.Mode): +async def test_classification( + model: str, + data: list[tuple[str, Literal["spam", "not_spam"]]], + mode: instructor.Mode, +): client = instructor.from_writer(client=AsyncWriter(), mode=mode) input, expected = data @@ -65,7 +69,11 @@ class MultiClassPrediction(BaseModel): @pytest.mark.parametrize("model, data, mode", product(models, data, modes)) @pytest.mark.asyncio -async def test_writer_multi_classify(model: str, data: List[Tuple], mode: instructor.Mode): +async def test_writer_multi_classify( + model: str, + data: list[tuple[str, list[Literal["billing", "general_query", "hardware"]]]], + mode: instructor.Mode, +): client = instructor.from_writer(client=AsyncWriter(), mode=mode) input, expected = data diff --git a/tests/llm/test_writer/evals/test_extract_users.py b/tests/llm/test_writer/evals/test_extract_users.py index a90ac3c7f..fac768da1 100644 --- a/tests/llm/test_writer/evals/test_extract_users.py +++ b/tests/llm/test_writer/evals/test_extract_users.py @@ -1,5 +1,3 @@ -from typing import List, Tuple - import pytest from itertools import product from pydantic import BaseModel @@ -21,7 +19,9 @@ class UserDetails(BaseModel): @pytest.mark.parametrize("model, data, mode", product(models, test_data, modes)) -def test_writer_extract(model: str, data: List[Tuple], mode: instructor.Mode): +def test_writer_extract( + model: str, data: list[tuple[str, str, int]], mode: instructor.Mode +): client = instructor.from_writer(client=Writer(), mode=mode) sample_data, expected_name, expected_age = data diff --git a/tests/llm/test_writer/evals/test_sentiment_analysis.py b/tests/llm/test_writer/evals/test_sentiment_analysis.py index 509a5e24c..9e84ec381 100644 --- a/tests/llm/test_writer/evals/test_sentiment_analysis.py +++ b/tests/llm/test_writer/evals/test_sentiment_analysis.py @@ -1,6 +1,5 @@ import enum from itertools import product -from typing import List, Tuple from pydantic import BaseModel from writerai import Writer @@ -36,7 +35,9 @@ class SentimentAnalysis(BaseModel): @pytest.mark.parametrize("model, data, mode", product(models, test_data, modes)) -def test_writer_sentiment_analysis(model: str, data: List[Tuple], mode: instructor.Mode): +def test_writer_sentiment_analysis( + model: str, data: List[Tuple], mode: instructor.Mode +): client = instructor.from_writer(client=Writer(), mode=mode) sample_data, expected_sentiment = data