Skip to content

Commit

Permalink
fix: migrate types
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanleomk committed Nov 21, 2024
1 parent a0b8dc5 commit 8652b9c
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 11 deletions.
9 changes: 6 additions & 3 deletions tests/llm/test_writer/evals/test_classification_enums.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import enum
from itertools import product
from typing import List, Tuple
from writerai import Writer

import pytest
Expand Down Expand Up @@ -38,7 +37,9 @@ class SinglePrediction(BaseModel):


@pytest.mark.parametrize("model, data, mode", product(models, data, modes))
def test_writer_classification(model: str, data: List[Tuple], mode: instructor.Mode):
def test_writer_classification(
model: str, data: list[tuple[str, Labels]], mode: instructor.Mode
):
client = instructor.from_writer(client=Writer(), mode=mode)

input, expected = data
Expand Down Expand Up @@ -82,7 +83,9 @@ class MultiClassPrediction(BaseModel):


@pytest.mark.parametrize("model, data, mode", product(models, data, modes))
def test_writer_multi_classify(model: str, data: List[Tuple], mode: instructor.Mode):
def test_writer_multi_classify(
model: str, data: list[tuple[str, list[MultiLabels]]], mode: instructor.Mode
):
client = instructor.from_writer(client=Writer(), mode=mode)

if (mode, model) in {
Expand Down
14 changes: 11 additions & 3 deletions tests/llm/test_writer/evals/test_classification_literals.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from itertools import product
from typing import Literal, List, Tuple
from typing import Literal
from writerai import AsyncWriter

import pytest
Expand All @@ -26,7 +26,11 @@ class SinglePrediction(BaseModel):

@pytest.mark.parametrize("model, data, mode", product(models, data, modes))
@pytest.mark.asyncio
async def test_classification(model: str, data: List[Tuple], mode: instructor.Mode):
async def test_classification(
model: str,
data: list[tuple[str, Literal["spam", "not_spam"]]],
mode: instructor.Mode,
):
client = instructor.from_writer(client=AsyncWriter(), mode=mode)

input, expected = data
Expand Down Expand Up @@ -65,7 +69,11 @@ class MultiClassPrediction(BaseModel):

@pytest.mark.parametrize("model, data, mode", product(models, data, modes))
@pytest.mark.asyncio
async def test_writer_multi_classify(model: str, data: List[Tuple], mode: instructor.Mode):
async def test_writer_multi_classify(
model: str,
data: list[tuple[str, list[Literal["billing", "general_query", "hardware"]]]],
mode: instructor.Mode,
):
client = instructor.from_writer(client=AsyncWriter(), mode=mode)

input, expected = data
Expand Down
6 changes: 3 additions & 3 deletions tests/llm/test_writer/evals/test_extract_users.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List, Tuple

import pytest
from itertools import product
from pydantic import BaseModel
Expand All @@ -21,7 +19,9 @@ class UserDetails(BaseModel):


@pytest.mark.parametrize("model, data, mode", product(models, test_data, modes))
def test_writer_extract(model: str, data: List[Tuple], mode: instructor.Mode):
def test_writer_extract(
model: str, data: list[tuple[str, str, int]], mode: instructor.Mode
):
client = instructor.from_writer(client=Writer(), mode=mode)

sample_data, expected_name, expected_age = data
Expand Down
5 changes: 3 additions & 2 deletions tests/llm/test_writer/evals/test_sentiment_analysis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import enum
from itertools import product
from typing import List, Tuple

from pydantic import BaseModel
from writerai import Writer
Expand Down Expand Up @@ -36,7 +35,9 @@ class SentimentAnalysis(BaseModel):


@pytest.mark.parametrize("model, data, mode", product(models, test_data, modes))
def test_writer_sentiment_analysis(model: str, data: List[Tuple], mode: instructor.Mode):
def test_writer_sentiment_analysis(
model: str, data: List[Tuple], mode: instructor.Mode
):
client = instructor.from_writer(client=Writer(), mode=mode)

sample_data, expected_sentiment = data
Expand Down

0 comments on commit 8652b9c

Please sign in to comment.