-
Notifications
You must be signed in to change notification settings - Fork 191
/
test_query_topic_detection_mapper.py
59 lines (46 loc) · 1.88 KB
/
test_query_topic_detection_mapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import unittest
import json
from loguru import logger
from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.query_topic_detection_mapper import QueryTopicDetectionMapper
from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
DataJuicerTestCaseBase)
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.common_utils import nested_access
class TestQueryTopicDetectionMapper(DataJuicerTestCaseBase):
    """Unit tests for ``QueryTopicDetectionMapper``.

    Each test builds a small dataset of ``{'query': ...}`` samples, maps the
    operator's ``process`` method over it, and compares the label found in
    each sample's meta field against an expected topic string.
    """

    # Hugging Face model identifiers passed to the operator under test.
    hf_model = 'dstefa/roberta-base_topic_classification_nyt_news'
    zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en'

    def _run_op(self, op, samples, label_key, targets):
        """Apply ``op`` to ``samples`` and check the resulting labels.

        Args:
            op: Operator instance whose ``process`` method is mapped over
                the dataset.
            samples: List of sample dicts used to build the dataset.
            label_key: Key handed to ``nested_access`` to locate the label
                inside each sample's ``Fields.meta`` entry.
            targets: Expected labels, in the same order as ``samples``.
        """
        dataset = Dataset.from_list(samples)
        dataset = dataset.map(op.process, batch_size=2)

        # zip() stops at the shorter input, so without this check a dataset
        # that lost rows (or ended up empty) would let the test pass
        # without comparing every expected target.
        self.assertEqual(len(dataset), len(targets))

        for sample, target in zip(dataset, targets):
            label = nested_access(sample[Fields.meta], label_key)
            self.assertEqual(label, target)

    def test_default(self):
        """Chinese queries with a zh-to-en model configured."""
        samples = [
            {'query': '今天火箭和快船的比赛谁赢了。'},
            {'query': '你最近身体怎么样。'},
        ]
        targets = ['Sports', 'Health and Wellness']

        op = QueryTopicDetectionMapper(
            hf_model=self.hf_model,
            zh_to_en_hf_model=self.zh_to_en_hf_model,
        )
        self._run_op(op, samples, MetaKeys.query_topic_label, targets)

    def test_no_zh_to_en(self):
        """One Chinese and one English query with ``zh_to_en_hf_model=None``."""
        samples = [
            {'query': '这样好吗?'},
            {'query': 'Is this okay?'},
        ]
        targets = ['Lifestyle and Fashion', 'Health and Wellness']

        op = QueryTopicDetectionMapper(
            hf_model=self.hf_model,
            zh_to_en_hf_model=None,
        )
        self._run_op(op, samples, MetaKeys.query_topic_label, targets)
# Run the test cases in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()