-
Notifications
You must be signed in to change notification settings - Fork 3
/
m1_fit.py
130 lines (108 loc) · 4.62 KB
/
m1_fit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import argparse
import os
import pickle
import datetime
import numpy as np
import pandas as pd
from utils.parallel import get_pool
from utils.model import model
from utils.agent import *
# Directory containing this script. The data/ and fits/ paths below are built
# relative to it, so the script does not depend on the working directory.
path = os.path.dirname(os.path.abspath(__file__))

## pass the hyperparams
parser = argparse.ArgumentParser(description='Test for argparse')
parser.add_argument('--n_fit', '-f', help='fit times', type=int, default=1)
parser.add_argument('--data_set', '-d', help='which_data', type=str, default='gain_data')
parser.add_argument('--method', '-m', help='fitting methods', type=str, default='map')
parser.add_argument('--group', '-g', help='fit to ind or fit to the whole group', type=str, default='ind')
parser.add_argument('--agent_name', '-n', help='choose agent', default='mix_pol_3w')
parser.add_argument('--n_cores', '-c', help='number of CPU cores used for parallel computing',
                    type=int, default=1)
parser.add_argument('--seed', '-s', help='random seed', type=int, default=420)
args = parser.parse_args()

# Resolve the agent by name. The name is expected to be one of the objects
# brought into scope by `from utils.agent import *`.
# NOTE(review): eval() executes any expression given on the command line.
# That is tolerable for a local research script, but a plain name lookup such
# as globals()[args.agent_name] would be safer.
args.agent = eval(args.agent_name)

# Create fits/ and fits/<agent_name>/ if they do not exist yet.
# makedirs(exist_ok=True) creates the parent too, and avoids the
# check-then-create race when several fitting jobs start at the same time.
os.makedirs(f'{path}/fits/{args.agent_name}', exist_ok=True)
def fit_parallel(pool, data, subj, verbose, args):
    '''Run ``args.n_fit`` independent fits in ``pool`` and keep the best one.

    Each fit calls ``subj.fit(data, args.method, seed, verbose)`` with seed
    ``args.seed + 2*i`` (i = 0 .. n_fit-1). Each call must return
    ``(params, loss)``, where loss is the negative log-likelihood. The fit
    with the lowest loss is kept, and AIC / BIC are computed from it.

    Args:
        pool: object exposing ``apply_async(func, args=...)`` whose results
            provide ``.get()`` (e.g. a ``multiprocessing.Pool``).
        data: mapping whose values expose ``.shape[0]``. The total number of
            rows across all values is used as the observation count in BIC.
        subj: model wrapper providing the ``fit`` method described above.
        verbose: forwarded unchanged to ``subj.fit``.
        args: namespace with ``seed``, ``n_fit``, ``method`` and ``agent``.
            The agent must provide ``n_params`` and ``p_name``.

    Returns:
        A one-row ``pd.DataFrame`` with columns
        ``args.agent.p_name + ['nll', 'aic', 'bic']``.

    Raises:
        RuntimeError: if no fit produced a loss below ``inf`` (e.g. every
            loss was NaN, or ``args.n_fit < 1``).
    '''
    seed = args.seed
    n_params = args.agent.n_params
    # Number of observations for BIC: rows summed over every entry of data.
    m_data = np.sum([data[key].shape[0] for key in data.keys()])

    # Submit all restarts first so they run concurrently, then collect them.
    results = [pool.apply_async(subj.fit,
                                args=(data, args.method, seed + 2*i, verbose))
               for i in range(args.n_fit)]
    opt_nll, opt_params = np.inf, None
    for p in results:
        params, loss = p.get()
        # NaN never compares below opt_nll, so NaN losses are skipped.
        if loss < opt_nll:
            opt_nll, opt_params = loss, params
    if opt_params is None:
        raise RuntimeError(
            f'none of the {args.n_fit} fit(s) returned a finite loss')

    aic = 2*n_params + 2*opt_nll
    # BIC = k*ln(n) + 2*NLL. The penalty grows with the log of the number of
    # observations, not linearly in it.
    bic = n_params*np.log(m_data) + 2*opt_nll
    fit_mat = np.hstack([opt_params, opt_nll, aic, bic]).reshape([1, -1])

    ## Save the params + nll + aic + bic
    col = args.agent.p_name + ['nll', 'aic', 'bic']
    print(f' nll: {fit_mat[0, -3]:.4f}')
    return pd.DataFrame(fit_mat, columns=col)
def fit(pool, data, args):
    '''Fit the agent's free parameters and write the results to CSV.

    With ``args.group == 'ind'`` each subject (each key of ``data``) is fitted
    separately. Each result goes to
    ``fits/<agent_name>/params-<data_set>-<method>-<subject>.csv``.
    With ``args.group == 'avg'`` the whole ``data`` mapping is fitted once.
    That result goes to ``fits/<agent_name>/params-<data_set>-<method>-avg.csv``.

    Args:
        pool: parallel pool handed straight to ``fit_parallel``.
        data: mapping loaded from ``data/<data_set>.pkl``. In 'ind' mode each
            value is one subject's data (presumably a dict of per-block
            tables -- TODO confirm against the data files).
        args: parsed CLI namespace. Uses ``agent``, ``agent_name``,
            ``group``, ``data_set`` and ``method``; ``fit_parallel`` reads more.

    Raises:
        ValueError: if ``args.group`` is neither 'ind' nor 'avg'.
    '''
    # Validate first so a typo does not silently produce no output files.
    if args.group not in ('ind', 'avg'):
        raise ValueError(f"group must be 'ind' or 'avg', got {args.group!r}")

    ## Define the RL model
    subj = model(args.agent)
    folder = f'{path}/fits/{args.agent_name}'
    ## Start
    start_time = datetime.datetime.now()

    if args.group == 'ind':
        ## Fit params to each individual
        all_subj = len(data.keys())
        for done_subj, sub_idx in enumerate(data.keys()):
            print(f'Fitting subject {sub_idx}, progress: {(done_subj*100)/all_subj:.2f}%')
            fit_res = fit_parallel(pool, data[sub_idx], subj, False, args)
            fit_res.to_csv(f'{folder}/params-{args.data_set}-{args.method}-{sub_idx}.csv')
    else:
        ## Fit params to the population level.
        # pool comes first, matching fit_parallel(pool, data, subj, verbose, args).
        fit_res = fit_parallel(pool, data, subj, True, args)
        fit_res.to_csv(f'{folder}/params-{args.data_set}-{args.method}-avg.csv')

    ## END!!!
    end_time = datetime.datetime.now()
    print('\nparallel computing spend {:.2f} seconds'.format(
        (end_time - start_time).total_seconds()))
def summary(data, args):
    '''Aggregate the per-subject fit files into a mean / SEM table.

    For every key of ``data``, reads the first row of
    ``fits/<agent_name>/params-<data_set>-<method>-<subject>.csv``.
    It then writes a two-row table, rounded to 4 decimals, to
    ``fits/params-<data_set>-<method>-<agent_name>-ind.csv``:
    row 0 is the mean across subjects, and row 1 is
    ``std / sqrt(n_sub)``, using ``np.std``'s default ``ddof=0``.
    Column names are taken from the first subject's file.
    '''
    n_cols = args.agent.n_params + 3
    subjects = list(data.keys())
    n_sub = len(subjects)
    folder = f'{path}/fits/{args.agent_name}'

    # One row per subject, pre-filled with NaN.
    res_mat = np.full([n_sub, n_cols], np.nan)
    for row, sub_idx in enumerate(subjects):
        log = pd.read_csv(f'{folder}/params-{args.data_set}-{args.method}-{sub_idx}.csv',
                          index_col=0)
        res_mat[row, :] = log.iloc[0, :].values
        if row == 0:
            col = log.columns

    # Stack the mean row on top of the standard-error row.
    res_smry = np.vstack([
        np.mean(res_mat, axis=0),
        np.std(res_mat, axis=0) / np.sqrt(n_sub),
    ])
    out_name = f'{path}/fits/params-{args.data_set}-{args.method}-{args.agent_name}-ind.csv'
    pd.DataFrame(res_smry, columns=col).round(4).to_csv(out_name)
if __name__ == '__main__':
    # Acquire the parallel pool up front; fit() passes it to fit_parallel().
    pool = get_pool(args)

    # Load the pickled dataset stored at data/<data_set>.pkl next to this script.
    # NOTE(review): pickle.load can execute arbitrary code, so only point
    # --data_set at trusted files.
    data_file = f'{path}/data/{args.data_set}.pkl'
    with open(data_file, 'rb') as handle:
        data = pickle.load(handle)

    # Fit and save: one CSV per subject ('ind') or one for the group ('avg').
    fit(pool, data, args)

    # Per-subject fits are additionally summarised as mean and SEM.
    if args.group == 'ind':
        summary(data, args)