forked from Yu-Group/clinical-rule-vetting
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
221 lines (194 loc) · 8.71 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import os
import random
from abc import abstractmethod
from os.path import join as oj
from typing import Dict, Tuple
import numpy as np
import pandas as pd
from joblib import Memory
import rulevetting
from vflow import init_args, Vset, build_vset
class DatasetTemplate:
    """Classes that use this template should be called "Dataset"
    All functions take **kwargs, so you can specify any judgement calls you aren't sure about with a kwarg flag.
    Please refrain from shuffling / reordering the data in any of these functions, to ensure a consistent test set.

    NOTE: this class does not derive from ``abc.ABC``, so the ``@abstractmethod``
    decorators below are not enforced: a subclass that forgets to override one of
    them can still be instantiated, and the base implementation returns ``NotImplemented``.
    """

    # File names of the three splits, in the same order split_data returns them.
    _SPLIT_FILENAMES = ('train.csv', 'tune.csv', 'test.csv')

    @abstractmethod
    def clean_data(self, data_path: str = rulevetting.DATA_PATH, **kwargs) -> pd.DataFrame:
        """
        Convert the raw data files into a pandas dataframe.
        Dataframe keys should be reasonable (lowercase, underscore-separated).
        Data types should be reasonable.

        Params
        ------
        data_path: str, optional
            Path to all data files
        kwargs: dict
            Dictionary of hyperparameters specifying judgement calls

        Returns
        -------
        cleaned_data: pd.DataFrame
        """
        return NotImplemented

    @abstractmethod
    def preprocess_data(self, cleaned_data: pd.DataFrame, **kwargs) -> pd.DataFrame:
        """Preprocess the data.
        Impute missing values.
        Scale/transform values.
        Should put the prediction target in a column named "outcome"

        Parameters
        ----------
        cleaned_data: pd.DataFrame
        kwargs: dict
            Dictionary of hyperparameters specifying judgement calls

        Returns
        -------
        preprocessed_data: pd.DataFrame
        """
        return NotImplemented

    @abstractmethod
    def extract_features(self, preprocessed_data: pd.DataFrame, **kwargs) -> pd.DataFrame:
        """Extract features from preprocessed data
        All features should be binary

        Parameters
        ----------
        preprocessed_data: pd.DataFrame
        kwargs: dict
            Dictionary of hyperparameters specifying judgement calls

        Returns
        -------
        extracted_features: pd.DataFrame
        """
        return NotImplemented

    def split_data(self, preprocessed_data: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Split into 3 sets: training, tuning, testing.
        Do not modify (to ensure consistent test set).
        Keep in mind any natural splits (e.g. hospitals).
        Ensure that there are positive points in all splits.

        Rows are shuffled with a fixed seed (``random_state=42``) and then cut
        positionally: the first 60% go to train, the next 20% to tune and the
        remaining rows to test.

        Parameters
        ----------
        preprocessed_data: pd.DataFrame

        Returns
        -------
        df_train
        df_tune
        df_test
        """
        shuffled = preprocessed_data.sample(frac=1, random_state=42)
        n_rows = len(preprocessed_data)
        train_end = int(.6 * n_rows)  # 60% train
        tune_end = int(.8 * n_rows)  # 20% tune, 20% test
        # Positional slicing gives the same result as np.split(shuffled, [train_end, tune_end])
        # without routing through DataFrame.swapaxes (deprecated in pandas 2.1).
        return (shuffled.iloc[:train_end],
                shuffled.iloc[train_end:tune_end],
                shuffled.iloc[tune_end:])

    @abstractmethod
    def get_outcome_name(self) -> str:
        """Should return the name of the outcome we are predicting
        """
        return NotImplemented

    @abstractmethod
    def get_dataset_id(self) -> str:
        """Should return the name of the dataset id (str)
        """
        return NotImplemented

    @abstractmethod
    def get_meta_keys(self) -> list:
        """Return list of keys which should not be used in fitting but are still useful for analysis
        """
        return NotImplemented

    def get_judgement_calls_dictionary(self) -> Dict[str, Dict[str, list]]:
        """Return dictionary of keyword arguments for each function in the dataset class.
        Each key should be a string with the name of the arg.
        Each value should be a list of values, with the default value coming first.

        Example
        -------
        return {
            'clean_data': {},
            'preprocess_data': {
                'imputation_strategy': ['mean', 'median'],  # first value is default
            },
            'extract_features': {},
        }
        """
        return NotImplemented

    def get_data(self, save_csvs: bool = False,
                 data_path: str = rulevetting.DATA_PATH,
                 load_csvs: bool = False,
                 run_perturbations: bool = False) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Runs all the processing and returns the data.
        This method does not need to be overridden.

        Params
        ------
        save_csvs: bool, optional
            Whether to save csv files of the processed data
        data_path: str, optional
            Path to all data
        load_csvs: bool, optional
            Whether to skip all processing and load data directly from csvs
            (only the feature csvs are loaded; meta columns live in separate meta_*.csv files)
        run_perturbations: bool, optional
            Whether to run / save data pipeline for all combinations of judgement calls

        Returns
        -------
        If ``run_perturbations`` is False (or ``load_csvs`` is True):
            df_train, df_tune, df_test
        If ``run_perturbations`` is True:
            the dict output of the vflow ``split_data`` Vset (regardless of ``save_csvs``).
            Each tuple key identifies one perturbation and maps to (df_train, df_tune, df_test);
            non-tuple keys may also be present and are skipped when saving.

        Raises
        ------
        NotImplementedError
            If the subclass does not override ``get_judgement_calls_dictionary``.
        """
        processed_path = oj(data_path, self.get_dataset_id(), 'processed')
        if load_csvs:
            return tuple(pd.read_csv(oj(processed_path, fname), index_col=0)
                         for fname in self._SPLIT_FILENAMES)

        # Seed the global RNGs so any randomness in the processing steps is reproducible.
        np.random.seed(0)
        random.seed(0)

        judgement_calls = self.get_judgement_calls_dictionary()
        if judgement_calls is NotImplemented:
            raise NotImplementedError(
                f'{type(self).__name__} must override get_judgement_calls_dictionary()')
        # The first value listed for each judgement call is its default.
        default_kwargs = {
            func_name: {arg: values[0] for arg, values in func_kwargs.items()}
            for func_name, func_kwargs in judgement_calls.items()
        }
        print('kwargs', default_kwargs)

        if not run_perturbations:
            cleaned_data = self.clean_data(data_path=data_path, **default_kwargs['clean_data'])
            preprocessed_data = self.preprocess_data(cleaned_data, **default_kwargs['preprocess_data'])
            extracted_features = self.extract_features(preprocessed_data, **default_kwargs['extract_features'])
            df_train, df_tune, df_test = self.split_data(extracted_features)
            if save_csvs:
                os.makedirs(processed_path, exist_ok=True)
                self._save_split_csvs((df_train, df_tune, df_test), processed_path)
            return df_train, df_tune, df_test

        # Run every combination of judgement calls through a vflow pipeline, caching each stage.
        cache_path = oj(data_path, 'joblib_cache')
        data_path_arg = init_args([data_path], names=['data_path'])[0]
        clean_set = build_vset('clean_data', self.clean_data, param_dict=judgement_calls['clean_data'],
                               cache_dir=cache_path, output_matching=True)
        cleaned_data = clean_set(data_path_arg)
        preprocess_set = build_vset('preprocess_data', self.preprocess_data,
                                    param_dict=judgement_calls['preprocess_data'],
                                    cache_dir=cache_path, output_matching=True)
        preprocessed_data = preprocess_set(cleaned_data)
        extract_set = build_vset('extract_features', self.extract_features,
                                 param_dict=judgement_calls['extract_features'],
                                 cache_dir=cache_path, output_matching=True)
        extracted_features = extract_set(preprocessed_data)
        split_set = Vset('split_data', modules=[self.split_data])
        dfs = split_set(extracted_features)

        if save_csvs:
            os.makedirs(processed_path, exist_ok=True)
            for key in dfs.keys():
                if not isinstance(key, tuple):
                    continue  # only tuple keys hold a perturbation's (train, tune, test) splits
                perturbation_name = str(key).replace(', ', '_').replace('(', '').replace(')', '')
                perturbed_path = oj(processed_path, 'perturbed_data', perturbation_name)
                os.makedirs(perturbed_path, exist_ok=True)  # also creates 'perturbed_data'
                self._save_split_csvs(dfs[key], perturbed_path)
        # Return the full dict in both cases so the return type does not depend on save_csvs.
        return dfs

    def _save_split_csvs(self, splits: Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame], out_dir: str) -> None:
        """Write each split to ``out_dir`` as a feature csv plus a ``meta_``-prefixed csv.

        Columns derived from ``get_meta_keys()`` go to ``meta_<split>.csv``; all other
        columns go to ``<split>.csv``. ``out_dir`` must already exist.
        """
        meta_base_keys = self.get_meta_keys()
        for df, fname in zip(splits, self._SPLIT_FILENAMES):
            meta_keys = rulevetting.api.util.get_feat_names_from_base_feats(df.keys(), meta_base_keys)
            df.loc[:, meta_keys].to_csv(oj(out_dir, f'meta_{fname}'))
            df.drop(columns=meta_keys).to_csv(oj(out_dir, fname))