-
Notifications
You must be signed in to change notification settings - Fork 2
/
eval_submission.py
110 lines (91 loc) · 5.19 KB
/
eval_submission.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import logging
import argparse
import json
import polars as pl
import os.path
import config
log = logging.getLogger(os.path.basename(__file__))

if __name__ == '__main__':
    # Evaluate a submission file against the held-out test labels and store
    # recall@20 per event type plus the weighted total as JSON and CSV.
    #
    # Usage example:
    #   python -m model.eval_submission --file_submit v1.4.0-20230130152337-3f932a03.csv
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_split_alias', default='train-test')
    parser.add_argument('--file_submit', default='submission-v1.0.0-0505c388-20230119171211.csv')
    args = parser.parse_args()

    log.info(f'Running {os.path.basename(__file__)} with parameters: \n' + json.dumps(vars(args), indent=2))
    log.info('This evaluates a submission on test data.')

    labels = pl.read_parquet(f'{config.DIR_DATA}/{args.data_split_alias}-parquet/test_labels/*.parquet')
    submission = pl.read_csv(f'{config.DIR_DATA}/{args.data_split_alias}-submit/{args.file_submit}')

    dir_out = f'{config.DIR_DATA}/{args.data_split_alias}-eval-submissions'
    os.makedirs(dir_out, exist_ok=True)

    # Translate the integer event-type code of the labels into its name so the
    # labels can be joined with the submission on the textual 'type' column.
    # The lookup key is cast to Int8 for the join -- presumably this matches
    # the dtype of the labels' 'type' column (TODO confirm against the parquet schema).
    type_lookup = pl.DataFrame({'type_int': [0, 1, 2], 'type': ['clicks', 'carts', 'orders']}) \
        .with_column(pl.col('type_int').cast(pl.Int8))
    labels = (
        labels
        .rename({'type': 'type_int'})
        .join(type_lookup, on='type_int')
        .drop('type_int')
        # Marker column: every label row is one ground-truth item.
        .with_column(pl.lit(1).alias('target'))
    )

    # Bring the submission into long format, one row per (session, type, aid):
    # 'session_type' is split on '_' into session id and event type, and the
    # space-separated 'labels' string is exploded into one aid per row.
    submission = (
        submission
        .with_column(pl.col('session_type').cast(str).str.split('_').alias('session_type_split'))
        .with_column(pl.col('session_type_split').arr.get(0).alias('session').cast(pl.Int32))
        .with_column(pl.col('session_type_split').arr.get(1).alias('type'))
        .with_column(pl.col('labels').cast(str).str.split(' '))
        .explode('labels')
        .with_column(pl.col('labels').cast(pl.Int32).alias('aid'))
        .drop(['labels', 'session_type', 'session_type_split'])
        # Marker column: every remaining row is one submitted item.
        .with_column(pl.lit(1).alias('submit'))
    )

    # Outer join keeps ground-truth-only rows (submit -> 0) as well as
    # submission-only rows (target -> 0); target * submit is 1 only for hits.
    joined = labels.join(submission, on=['session', 'type', 'aid'], how='outer').fill_null(0)

    # First aggregate per (session, type), capping the number of ground-truth
    # items at 20, then sum hits and capped ground truth per type.
    joined = (
        joined
        .groupby(['session', 'type'])
        .agg([pl.sum('target').clip_max(20).alias('true'),
              (pl.col('target') * pl.col('submit')).sum().alias('hit')])
        .groupby('type')
        .agg([pl.sum('hit'), pl.sum('true')])
        .with_column((pl.col('hit') / pl.col('true')).alias('recall@20'))
    )

    # Weighted total: 0.1 * clicks + 0.3 * carts + 0.6 * orders.
    type_weights = pl.DataFrame({'type': ['clicks', 'carts', 'orders'], 'weight': [0.1, 0.3, 0.6]})
    recall_agg = (
        joined
        .join(type_weights, on='type')
        .with_column(pl.col('recall@20') * pl.col('weight'))
        .sum()
        # The column-wise sum leaves no usable 'type' value; label the row explicitly.
        .with_column(pl.lit('total').alias('type'))
    )

    # Stack per-type rows and the total row, then force a fixed row order via
    # a helper 'order' column (the groupby result order is not fixed).
    row_order = pl.DataFrame({'type': ['clicks', 'carts', 'orders', 'total'], 'order': [1, 2, 3, 4]})
    res = (
        pl.concat([joined[['type', 'recall@20']], recall_agg[['type', 'recall@20']]])
        .join(row_order, on='type')
        .sort('order')
        .drop('order')
    )
    log.debug('Recall@20 per type & weighted total: \n' + str(res))

    # Swap only the trailing extension for '.json'. str.replace would touch
    # every '.csv' occurrence and, for names without '.csv', produce the same
    # path as the CSV output below, which would then overwrite the JSON.
    file_json = os.path.splitext(args.file_submit)[0] + '.json'
    with open(f'{dir_out}/{file_json}', 'w') as f:
        json.dump(dict(zip(res['type'], res['recall@20'])), f, indent=2, sort_keys=True)

    # polars DataFrames are written with write_csv; there is no index column
    # and no pandas-style 'index' keyword.
    res.write_csv(f'{dir_out}/{args.file_submit}')
# v1.0.0-7fa08333-20230119143255.csv
# ┌────────┬────────┬─────────┬───────────┐
# │ type ┆ hit ┆ true ┆ recall@20 │
# ╞════════╪════════╪═════════╪═══════════╡
# │ carts ┆ 230643 ┆ 566105 ┆ 0.407421 │
# │ clicks ┆ 855952 ┆ 1737968 ┆ 0.492502 │
# │ orders ┆ 202361 ┆ 310905 ┆ 0.650877 │
# │ total ┆ ┆ ┆ 0.562002 │
# └────────┴────────┴─────────┴───────────┘
# v1.2.0-20230129142628-4a0d1182.csv
# ┌────────┬────────┬─────────┬───────────┐
# │ type ┆ hit ┆ true ┆ recall@20 │
# ╞════════╪════════╪═════════╪═══════════╡
# │ carts ┆ 229674 ┆ 566105 ┆ 0.405709 │
# │ orders ┆ 202256 ┆ 310905 ┆ 0.65054 │
# │ clicks ┆ 856203 ┆ 1737968 ┆ 0.492646 │
# │ total ┆ ┆ ┆ 0.561301 │
# └────────┴────────┴─────────┴───────────┘
# v1.4.0-20230130152337-3f932a03.csv
# ┌────────┬───────────┐
# │ type ┆ recall@20 │
# ╞════════╪═══════════╡
# │ orders ┆ 0.65299 │
# │ clicks ┆ 0.493545 │
# │ carts ┆ 0.407908 │
# │ total ┆ 0.563521 │
# └────────┴───────────┘
# v1.5.0-20230131110348-bc9b575e.csv
# ┌────────┬───────────┐
# │ type ┆ recall@20 │
# ╞════════╪═══════════╡
# │ clicks ┆ 0.498642 │
# │ carts ┆ 0.411609 │
# │ orders ┆ 0.654711 │
# │ total ┆ 0.566174 │
# └────────┴───────────┘