Skip to content

Commit

Permalink
fix: Now works with all solvers in Pulp
Browse files Browse the repository at this point in the history
  • Loading branch information
hoomanzabeti committed Apr 28, 2021
1 parent 137ee50 commit e2fea72
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 50 deletions.
2 changes: 2 additions & 0 deletions group_testing/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ decode:
'is_it_noiseless': True
'lp_relaxation': False
'solver_name': 'CPLEX_PY'
#['GLPK_CMD', 'CPLEX_PY', 'GUROBI_CMD', 'PULP_CBC_CMD', 'COIN_CMD', 'PULP_CHOCO_CMD']
'solver_options':
'timeLimit': 1800
'logPath': 'auto'
'lambda_selection': False
'number_of_folds': 3
'cv_param':
Expand Down
119 changes: 82 additions & 37 deletions group_testing/group_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import argparse
from distutils.util import strtobool
from multiprocessing import Pool
from multiprocessing import cpu_count
import numpy as np
Expand All @@ -21,6 +22,7 @@
from sklearn.metrics import roc_curve, auc, balanced_accuracy_score
from sklearn.metrics import accuracy_score
import sys
import pulp as pl

from group_testing import __version__
from group_testing.generate_groups import gen_measurement_matrix
Expand All @@ -35,14 +37,14 @@ def multi_process_group_testing(design_param, decoder_param):
try:
single_run_start = time.time()
# generate the measurement matrix from the given options
if 'group_size' in design_param.keys() and str(design_param['group_size']).lower()=='auto':
if 'group_size' in design_param.keys() and str(design_param['group_size']).lower() == 'auto':
assert 'N' in design_param.keys(), "To generate the group size automatically parameter 'N' is needed to be" \
"defined in the config file"
assert 's' in design_param.keys(), "To generate the group size automatically parameter 's' is needed to be" \
"defined in the config file"
assert design_param['s'] <= design_param['N'], " 's'> 'N': number of infected individuals can not be " \
"greater than number of individuals."
design_param['group_size'] = utils.auto_group_size(design_param['N'],design_param['s'])
design_param['group_size'] = utils.auto_group_size(design_param['N'], design_param['s'])
print("group size is {}".format(design_param['group_size']))
if design_param['generate_groups'] == 'alternative_module':
generate_groups_alt_module = __import__(design_param['groups_alternative_module'][0], globals(), locals(),
Expand All @@ -53,7 +55,6 @@ def multi_process_group_testing(design_param, decoder_param):
A = generate_groups_alt_function(**passing_param)
elif design_param['generate_groups'] == 'input':
A = np.genfromtxt(design_param['groups_input'], delimiter=',')
# TODO: Check if m and N are defined too
assert np.array_equal(A, A.astype(bool)), "The input design matrix A is not binary!"
design_param['m'], design_param['N'] = A.shape
design_param['group_size'] = int(max(A.sum(axis=1)))
Expand Down Expand Up @@ -87,34 +88,35 @@ def multi_process_group_testing(design_param, decoder_param):
assert np.array_equal(b, b.astype(bool)), "test results input file is not binary!"
elif design_param['generate_test_results'] == 'alternative_module':
test_results_alt_module = __import__(design_param['test_results_alternative_module'][0],
globals(), locals(), [], 0)
globals(), locals(), [], 0)
test_results_alt_function = getattr(test_results_alt_module,
design_param['test_results_alternative_module'][1])
design_param['test_results_alternative_module'][1])
passing_param, temp_remaining_param = utils.param_distributor(design_param, test_results_alt_function)
remaining_param.update(temp_remaining_param)
b = test_results_alt_function(**passing_param)
elif design_param['generate_test_results'] == 'generate':
passing_param, temp_remaining_param = utils.param_distributor(design_param, gen_test_vector)
remaining_param.update(temp_remaining_param)
b = gen_test_vector(A, u,**passing_param)
b = gen_test_vector(A, u, **passing_param)
for main_param in ['N', 'm', 's', 'group_size', 'seed']:
if main_param not in design_param:
if main_param not in remaining_param:
design_param[main_param]= 'N\A'
design_param[main_param] = 'N\A'
else:
design_param[main_param]=remaining_param[main_param]
design_param[main_param] = remaining_param[main_param]
if 'save_to_file' in design_param.keys() and design_param['save_to_file']:
design_path = utils.inner_path_generator(design_param['result_path'], 'Design')
design_matrix_path = utils.inner_path_generator(design_path, 'Design_Matrix')
pd.DataFrame(A).to_csv(utils.report_file_path(design_matrix_path, 'design_matrix', 'csv', design_param),
header=None, index=None)
if design_param['generate_individual_status']:
individual_status_path = utils.inner_path_generator(design_path,'Individual_Status')
pd.DataFrame(u).to_csv(utils.report_file_path(individual_status_path, 'individual_status','csv', design_param),
header=None, index=None)
individual_status_path = utils.inner_path_generator(design_path, 'Individual_Status')
pd.DataFrame(u).to_csv(
utils.report_file_path(individual_status_path, 'individual_status', 'csv', design_param),
header=None, index=None)
if design_param['generate_test_results']:
test_results_path = utils.inner_path_generator(design_path,'Test_Results')
pd.DataFrame(b).to_csv(utils.report_file_path(test_results_path, 'test_results','csv', design_param),
test_results_path = utils.inner_path_generator(design_path, 'Test_Results')
pd.DataFrame(b).to_csv(utils.report_file_path(test_results_path, 'test_results', 'csv', design_param),
header=None, index=None)
except:
e = sys.exc_info()
Expand All @@ -124,10 +126,21 @@ def multi_process_group_testing(design_param, decoder_param):
if decoder_param['decoding']:
try:
if decoder_param['decoder'] == 'generate':
# TODO: this is only for cplex! Change it to more general form!
log_path = utils.inner_path_generator(design_param['result_path'], 'Logs')
decoder_param['solver_options']['logPath'] = utils.report_file_path(log_path, 'log','txt', design_param)
passing_param,_ = utils.param_distributor(decoder_param,GroupTestingDecoder)
if utils.dict_key_checker(decoder_param, 'solver_options') \
and utils.dict_key_checker(decoder_param['solver_options'], 'logPath') \
or decoder_param['lambda_selection']:
log_path = utils.inner_path_generator(design_param['result_path'], 'Logs')
if utils.dict_key_checker(decoder_param, 'solver_options') \
and utils.dict_key_checker(decoder_param['solver_options'], 'logPath'):
if str(decoder_param['solver_options']['logPath']).lower() == 'auto':
decoder_param['solver_options']['logPath'] = utils.report_file_path(log_path, 'log', 'txt',
design_param)
else:
decoder_param['solver_options']['logPath'] = os.path.join(log_path,
decoder_param['solver_options'][
'logPath'])

passing_param, _ = utils.param_distributor(decoder_param, GroupTestingDecoder)
c = GroupTestingDecoder(**passing_param)
single_fit_start = time.time()
if decoder_param['lambda_selection']:
Expand All @@ -143,11 +156,11 @@ def multi_process_group_testing(design_param, decoder_param):
c = grid.best_estimator_
pd.DataFrame.from_dict(grid.cv_results_).to_csv(
utils.report_file_path(log_path,
'cv_results', 'csv', design_param)
'cv_results', 'csv', design_param)
)
pd.DataFrame(grid.best_params_, index=[0]).to_csv(
utils.report_file_path(log_path,
'best_param', 'csv', design_param)
'best_param', 'csv', design_param)
)
else:
c.fit(A, b)
Expand All @@ -158,25 +171,27 @@ def multi_process_group_testing(design_param, decoder_param):
# TODO: CV for alternative module. Is it needed?
single_fit_start = time.time()
decoder_alt_module = __import__(decoder_param['decoder_alternative_module'][0],
globals(), locals(), [], 0)
globals(), locals(), [], 0)
decoder_alt_function = getattr(decoder_alt_module,
decoder_param['decoder_alternative_module'][1])
decoder_param['decoder_alternative_module'][1])
passing_param, _ = utils.param_distributor(decoder_param, decoder_alt_function)
c = decoder_alt_function(**passing_param)
c.fit(A, b)
single_fit_end = time.time()
if 'save_to_file' in design_param.keys() and design_param['save_to_file']:
solution_path = utils.inner_path_generator(design_param['result_path'], 'Solutions')
pd.DataFrame(c.solution()).to_csv(utils.report_file_path(solution_path, 'solution','csv', design_param),
header=None, index=None)
pd.DataFrame(c.solution()).to_csv(
utils.report_file_path(solution_path, 'solution', 'csv', design_param),
header=None, index=None)
# evaluate the accuracy of the solution
if decoder_param['evaluation']:
try:
ev_result = decoder_evaluation(u, c.solution(), decoder_param['eval_metric'])
ev_result['solver_time'] = round(single_fit_end - single_fit_start, 2)
# TODO: this is only for cplex status! Change it to more general form!
ev_result['Status'] = c.prob_.cplex_status
print('Evaluation is DONE!')
ev_result['fitting_time'] = round(single_fit_end - single_fit_start, 2)
if 'cplex_status' in c.prob_.__dict__.keys():
ev_result['Status'] = c.prob_.cplex_status
else:
ev_result['Status'] = pl.LpStatus[c.prob_.status]
except Exception as e:
print(e)
ev_result = {'tn': None, 'fp': None, 'fn': None, 'tp': None}
Expand All @@ -191,35 +206,64 @@ def multi_process_group_testing(design_param, decoder_param):
print("Decoding was not performed!")



# main method for testing
def main(sysargs=sys.argv[1:]):
start_time = time.time()

# argparse
parser = argparse.ArgumentParser(prog='GroupTesting', description='Description')
required_args= parser.add_argument_group('required arguments')
required_args = parser.add_argument_group('required arguments')
parser.add_argument(
'--version', action='version',
'--version', action='version',
version="%(prog)s version {version}".format(version=__version__)
)
required_args.add_argument(
'--config', dest='config', metavar='FILE',
'--config', dest='config', metavar='FILE',
help='Path to the config.yml file', required=True,
)
parser.add_argument(
'--output-dir', dest='output_path', metavar='DIR',
'--output-dir', dest='output_path', metavar='DIR',
help='Path to the output directory',
)
args = parser.parse_args()
parser.add_argument(
'--parallel', dest='parallel', metavar='BOOL',
help='whether to use multiprocessing. The default value is True.', default=True, type=bool
)
# parser.add_argument(
# '--seed', dest='seed', metavar='RAND_SEED',
# help='Random seed', type=int
# )
# parser.add_argument(
# '-N', dest='N', metavar='NUM_OF_INDIVIDUALS',
# help='Population size', type=int
# )
# parser.add_argument(
# '-m', dest='m', metavar='NUM_OF_TEST',
# help='Number of tests', type=int
# )
# parser.add_argument(
# '-g', '--group-size', dest='group_size',
# help='Group size', type=int
# )
# parser.add_argument(
# '-d', '--divisibility', dest='max_tests_per_individual', metavar='Divisibility',
# help='Divisibility or maximum number of times a person’s sample can be included in a test. ', type=int
# )
# parser.add_argument(
# '-s', dest='s', metavar='NUM_OF_INFECTED',
# help='number of infecteds', type=int
# )

args = parser.parse_args()
#print("------------------->",args.N)
# Read config file
design_param, decoder_param = utils.config_reader(args.config)
# output files path
current_path, result_path = utils.result_path_generator(args)
for d in design_param:
d['result_path'] = result_path
if not any([d['lambda_selection'] if 'lambda_selection' in d.keys() else False for d in decoder_param]):
if args.parallel \
and not any([d['lambda_selection'] if 'lambda_selection' in d.keys() else False for d in decoder_param]):
with Pool(processes=cpu_count()) as pool:
results = pool.starmap(multi_process_group_testing, itertools.product(design_param, decoder_param))
pool.close()
Expand All @@ -231,8 +275,9 @@ def main(sysargs=sys.argv[1:]):
pd.DataFrame(design_param).to_csv(os.path.join(result_path, 'opts.csv'))
if all(v is not None for v in results):
column_names = ['N', 'm', 's', 'group_size', 'seed', 'max_tests_per_individual', 'tn', 'fp', 'fn', 'tp',
decoder_param[0]['eval_metric'], 'Status', 'solver_time', 'time']
decoder_param[0]['eval_metric'], 'Status', 'fitting_time', 'time']
print(pd.DataFrame(results).reindex(columns=column_names))
pd.DataFrame(results).reindex(columns=column_names).to_csv(os.path.join(result_path, 'ConfusionMatrix.csv'))

end_time = time.time()
print(end_time - start_time)
print(end_time - start_time)
25 changes: 16 additions & 9 deletions group_testing/group_testing_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,12 @@ def __init__(self, lambda_w=1, lambda_p=1, lambda_n=1, lambda_e=None, defective_
self.solver_name = solver_name
self.solver_options = solver_options
self.prob_ = None
self.ep_cat = 'Binary'
self.en_cat = 'Binary'
self.en_upBound = 1
# self.ep_cat = 'Binary'
# self.en_cat = 'Binary'
if self.lp_relaxation:
self.en_upBound = 1
else:
self.en_upBound = None

def fit(self, A, label):
if self.lambda_e is not None:
Expand Down Expand Up @@ -67,6 +70,7 @@ def fit(self, A, label):
# Variables kind
if self.lp_relaxation:
varCategory = 'Continuous'
#self.solver_options['mip']= True
else:
varCategory = 'Binary'
# Variable w
Expand All @@ -93,11 +97,11 @@ def fit(self, A, label):
en = []
# Variable ep
if len(positive_label) != 0:
ep = LpVariable.dicts(name='ep', indexs=list(positive_label), lowBound=0, upBound=1, cat=self.ep_cat)
ep = pl.LpVariable.dicts(name='ep', indexs=list(positive_label), lowBound=0, upBound=1, cat=varCategory)
# Variable en
if len(negative_label) != 0:
en = LpVariable.dicts(name='en', indexs=list(negative_label), lowBound=0, upBound=self.en_upBound,
cat=self.en_cat)
en = pl.LpVariable.dicts(name='en', indexs=list(negative_label), lowBound=0, upBound=self.en_upBound,
cat=varCategory)
# Defining the objective function
p += pl.lpSum([self.lambda_w * w[i] if isinstance(self.lambda_w, (int, float)) else self.lambda_w[i] * w[i]
for i in range(n)]) + \
Expand All @@ -107,20 +111,23 @@ def fit(self, A, label):
for i in positive_label:
p += pl.lpSum([A[i][j] * w[j] for j in range(n)] + ep[i]) >= 1
for i in negative_label:
if self.en_cat == 'Continuous':
if varCategory == 'Continuous':
p += pl.lpSum([A[i][j] * w[j] for j in range(n)] + -1 * en[i]) == 0
else:
p += pl.lpSum([-1 * A[i][j] * w[j] for j in range(n)] + alpha[i] * en[i]) >= 0
# Prevalence lower-bound
if self.defective_num_lower_bound is not None:
p += pl.lpSum([w[i] for i in range(n)]) >= self.defective_num_lower_bound
solver = pl.get_solver(self.solver_name, **self.solver_options)
if self.solver_options is not None:
solver = pl.get_solver(self.solver_name, **self.solver_options)
else:
solver = pl.get_solver(self.solver_name)
p.solve(solver)
if not self.lp_relaxation:
p.roundSolution()
# ----------------
self.prob_ = p
print("Status:", pl.LpStatus[p.status])
#print("Status:", pl.LpStatus[p.status])
return self

def get_params_w(self, deep=True):
Expand Down
3 changes: 1 addition & 2 deletions group_testing/group_testing_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ def decoder_evaluation(w_true, sln, ev_metric='balanced_accuracy'):
tn, fp, fn, tp = confusion_matrix(w_true, sln).ravel()
eval_metric = getattr(sklearn.metrics,'{}_score'.format(ev_metric))
eval_score = eval_metric(w_true, sln)
ev_result = {'tn': tn, 'fp': fp, 'fn': fn, 'tp':tp, ev_metric:round(eval_score, 3)}
print(ev_result)
ev_result = {'tn': tn, 'fp': fp, 'fn': fn, 'tp': tp, ev_metric: round(eval_score, 3)}
return ev_result


Expand Down
13 changes: 11 additions & 2 deletions group_testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,22 @@ def report_file_path(report_path, report_label, report_extension, params):
return report_path


def dict_key_checker(current_dict, current_key):
if current_key in current_dict.keys():
return True
else:
return False


def result_path_generator(args):
current_path = os.getcwd()
currentDate = datetime.datetime.now()
if args.output_path is None:
dir_name = currentDate.strftime("%b_%d_%Y_%H_%M_%S")
result_path = os.path.join(current_path, "Results/{}".format(dir_name))
else:
dir_name = args.output_path
result_path = os.path.join(current_path, "Results/{}".format(dir_name))
result_path = os.path.join(current_path, dir_name)
if not os.path.isdir(result_path):
try:
os.makedirs(result_path)
Expand All @@ -179,7 +187,8 @@ def result_path_generator(args):
else:
print("Successfully created the directory %s " % result_path)
# Copy config file
copyfile(args.config, os.path.join(result_path, 'config.yml'))
if os.path.isfile(args.config):
copyfile(args.config, os.path.join(result_path, 'config.yml'))
return current_path, result_path


Expand Down

0 comments on commit e2fea72

Please sign in to comment.