retrieval_metrics_test.py
import pytest
from continuous_eval.llm_factory import LLMFactory
from continuous_eval.metrics.retrieval import (
    ExactSentenceMatch,
    LLMBasedContextCoverage,
    LLMBasedContextPrecision,
    PrecisionRecallF1,
    RankedRetrievalMetrics,
    RougeChunkMatch,
    RougeSentenceMatch,
    TokenCount,
)
from tests.helpers import example_datum
from tests.helpers.utils import all_close, in_zero_one
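

# Each example datum is a mapping of keyword arguments for the metric under test;
# the fields are defined in tests.helpers.example_datum.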
def test_precision_recall_exact_chunk_match():
    data = [example_datum.CAPITAL_OF_FRANCE, example_datum.ROMEO_AND_JULIET]
    expected_results = [
        {"context_precision": 0.0, "context_recall": 0.0, "context_f1": 0.0},
        {"context_precision": 1.0, "context_recall": 1.0, "context_f1": 1.0},
    ]
    metric = PrecisionRecallF1(RougeChunkMatch(threshold=0.7))
    assert all(all_close(metric(**datum), expected) for datum, expected in zip(data, expected_results))  # type: ignore


def test_precision_recall_exact_sentence_match():
    data = [example_datum.ROMEO_AND_JULIET]
    expected_results = [{"context_precision": 1.0, "context_recall": 1.0, "context_f1": 1.0}]
    metric = PrecisionRecallF1(RougeSentenceMatch(threshold=0.8))
    assert all(all_close(metric(**datum), expected) for datum, expected in zip(data, expected_results))  # type: ignore


def test_precision_recall_rouge_sentence_match():
    data = [example_datum.CAPITAL_OF_FRANCE, example_datum.IMPLICATIONS_GLOBAL_WARMING]
    expected_results = [
        {"context_precision": 0.0, "context_recall": 0.0, "context_f1": 0.0},
        {
            "context_precision": 0.09090909090909091,
            "context_recall": 0.5,
            "context_f1": 0.15384615384615385,
        },
    ]
    metric = PrecisionRecallF1(RougeSentenceMatch())
    assert all(all_close(metric(**datum), expected) for datum, expected in zip(data, expected_results))  # type: ignore


def test_ranked_retrieval_exact_chunk_match():
    data = [example_datum.CAPITAL_OF_FRANCE, example_datum.ROMEO_AND_JULIET]
    expected_results = [
        {"average_precision": 0, "reciprocal_rank": 0, "ndcg": 0.0},
        {"average_precision": 1.0, "reciprocal_rank": 1.0, "ndcg": 1.0},
    ]
    metric = RankedRetrievalMetrics(RougeChunkMatch())
    assert all(all_close(metric(**datum), expected) for datum, expected in zip(data, expected_results))  # type: ignore
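

# RankedRetrievalMetrics expects a chunk-level matcher, so constructing it with a
# sentence-level matcher should fail.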
def test_ranked_retrieval_exact_sentence_match():
    with pytest.raises(AssertionError):
        RankedRetrievalMetrics(ExactSentenceMatch())
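

# The LLM-based metrics call out to a model; these tests only check that the
# returned scores fall within [0, 1].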
def test_llm_based_context_precision():
    data = [example_datum.CAPITAL_OF_FRANCE, example_datum.ROMEO_AND_JULIET]
    metric = LLMBasedContextPrecision()
    assert all(in_zero_one(metric(**datum)) for datum in data)


def test_llm_based_context_coverage_openai():
    data = [example_datum.CAPITAL_OF_FRANCE, example_datum.ROMEO_AND_JULIET]
    metric = LLMBasedContextCoverage(model=LLMFactory("gpt-3.5-turbo-1106"))
    assert all(in_zero_one(metric(**datum)["LLM_based_context_coverage"]) for datum in data)
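

# Token counts for the two example data, computed with the o200k_base encoding and
# with the approximate ("approx") counter.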
def test_token_count():
    data = [example_datum.CAPITAL_OF_FRANCE, example_datum.ROMEO_AND_JULIET]
    metric = TokenCount("o200k_base")
    expected = [17, 16]
    assert (result := [metric(**datum)["num_tokens"] for datum in data]) == expected, result
    expected = [17, 18]
    metric = TokenCount("approx")
    assert (result := [metric(**datum)["num_tokens"] for datum in data]) == expected, result