-
Notifications
You must be signed in to change notification settings - Fork 191
/
query_sentiment_detection_mapper.py
85 lines (70 loc) · 3.25 KB
/
query_sentiment_detection_mapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from typing import Dict, Optional
from data_juicer.utils.common_utils import nested_set
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.model_utils import get_model, prepare_model
from ..base_op import OPERATORS, Mapper
OP_NAME = 'query_sentiment_detection_mapper'


@OPERATORS.register_module(OP_NAME)
class QuerySentimentDetectionMapper(Mapper):
    """
    Mapper to predict the user's sentiment label (e.g. 'negative',
    'neutral' and 'positive' for the default model) of the query read
    from ``query_key``.

    For every sample, the predicted label and its score are stored in the
    Data-Juicer meta field under ``MetaKeys.query_sentiment_label`` and
    ``MetaKeys.query_sentiment_score`` respectively.
    """

    _accelerator = 'cuda'
    _batched_op = True

    def __init__(
            self,
            hf_model: str = 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis',  # noqa: E501
            zh_to_en_hf_model: Optional[str] = 'Helsinki-NLP/opus-mt-zh-en',
            model_params: Optional[Dict] = None,
            zh_to_en_model_params: Optional[Dict] = None,
            **kwargs):
        """
        Initialization method.

        :param hf_model: Hugging Face model ID used to predict the sentiment
            label; it is prepared as a 'text-classification' pipeline.
        :param zh_to_en_hf_model: Hugging Face model ID of a Chinese-to-English
            translation model, prepared as a 'translation' pipeline. If not
            None, queries are translated from Chinese to English before
            classification; if None, queries are classified unchanged.
        :param model_params: Extra keyword arguments forwarded to
            ``prepare_model`` for ``hf_model``. None means no extra
            arguments.
        :param zh_to_en_model_params: Extra keyword arguments forwarded to
            ``prepare_model`` for ``zh_to_en_hf_model``. None means no extra
            arguments.
        :param kwargs: Extra keyword arguments passed to the base Mapper.
        """
        super().__init__(**kwargs)

        # ``None`` defaults (instead of ``{}``) avoid one mutable dict being
        # shared by every instance created with the default.
        if model_params is None:
            model_params = {}
        if zh_to_en_model_params is None:
            zh_to_en_model_params = {}

        self.model_key = prepare_model(model_type='huggingface',
                                       pretrained_model_name_or_path=hf_model,
                                       return_pipe=True,
                                       pipe_task='text-classification',
                                       **model_params)

        if zh_to_en_hf_model is not None:
            self.zh_to_en_model_key = prepare_model(
                model_type='huggingface',
                pretrained_model_name_or_path=zh_to_en_hf_model,
                return_pipe=True,
                pipe_task='translation',
                **zh_to_en_model_params)
        else:
            # No translation step: queries go straight to the classifier.
            self.zh_to_en_model_key = None

    def process_batched(self, samples, rank=None):
        """
        Classify the sentiment of every query in a batch.

        :param samples: Batch in column format: ``samples[self.query_key]``
            is the list of queries and ``samples[Fields.meta]`` (created if
            absent) is a list of per-sample meta dicts.
        :param rank: Forwarded to ``get_model`` when fetching the pipelines
            (presumably the device rank — confirm against get_model).
        :return: The same ``samples`` object with each meta dict updated
            with the predicted label and score.
        """
        queries = samples[self.query_key]

        if self.zh_to_en_model_key is not None:
            # Translate Chinese queries to English before classification.
            translator, _ = get_model(self.zh_to_en_model_key, rank,
                                      self.use_cuda())
            translations = translator(queries)
            queries = [item['translation_text'] for item in translations]

        classifier, _ = get_model(self.model_key, rank, self.use_cuda())
        predictions = classifier(queries)
        labels = [pred['label'] for pred in predictions]
        scores = [pred['score'] for pred in predictions]

        if Fields.meta not in samples:
            samples[Fields.meta] = [{} for _ in labels]

        metas = samples[Fields.meta]
        for i, meta in enumerate(metas):
            meta = nested_set(meta, MetaKeys.query_sentiment_label, labels[i])
            metas[i] = nested_set(meta, MetaKeys.query_sentiment_score,
                                  scores[i])
        return samples