Skip to content

Commit

Permalink
now support using provided connectivity matrix as initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
anyuzx committed Nov 29, 2021
1 parent 3b87234 commit 178661a
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 58 deletions.
28 changes: 18 additions & 10 deletions HippsDimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,20 +324,23 @@ def checkEMD(ddmap):


class Optimize:
def __init__(self, ddmap_target):
def __init__(self, ddmap_target, connectivity_matrix=None):
# ddmap_target is the targeted matrix we would like to match
# note that ddmap_taret is the mean SQUARED distance matrix, not mean distance matrix
self.ddmap_target = ddmap_target

# get the size of system
self.n = ddmap_target.shape[0]

# initialize the connectivity matrix
# here the connectivity matrix is initialized as a simple rouse chain whose spring constant is determined such\
# that its radius of gyration is close to the target
rg2 = .5 * np.nanmean(self.ddmap_target)
k = self.n / (4. * rg2)
self.A = construct_connectivity_matrix_rouse(self.n, k)
if connectivity_matrix is None:
# initialize the connectivity matrix
# here the connectivity matrix is initialized as a simple rouse chain whose spring constant is determined such\
# that its radius of gyration is close to the target
rg2 = .5 * np.nanmean(self.ddmap_target)
k = self.n / (4. * rg2)
self.A = construct_connectivity_matrix_rouse(self.n, k)
else:
self.A = connectivity_matrix

# initialize the loss
self.loss = None
Expand All @@ -360,7 +363,6 @@ def __update_parameter(self, t, learning_rate, lamd=0.0, reg='l2', method='IS',
compare_ratio = ddmap_t / self.ddmap_target
# compute the prefactor for iterative scaling
fhash = np.nansum(ddmap_t) / 2.
#fhash = 1.0

if method == 'IS':
# compute the gradient
Expand Down Expand Up @@ -444,6 +446,7 @@ def run(self, epoch, general_method='optimization', **kwargs):
@click.command()
@click.argument('input', nargs=1)
@click.argument('output-prefix', nargs=1)
@click.option('-k', '--connectivity-matrix', type=str, required=False, help='Use provided connectivity matrix as initialization. Useful when restart from previous run')
@click.option('-e', '--ensemble', type=int, default=1000, show_default=True, help='specify the number of conformations generated')
@click.option('-a', '--alpha', type=float, default=4.0, show_default=True, help='specify the value of cmap-to-dmap conversion exponent')
@click.option('-s', '--selection', type=str, help='specify which chromosome or region to run the model on if the input file is Hi-C data in cooler format. Accept any valid options for [fetch] method in cooler.Cooler.matrix() selector')
Expand All @@ -463,7 +466,8 @@ def run(self, epoch, general_method='optimization', **kwargs):
@click.option('--balance', is_flag=True, default=False, show_default=True, help='Turn on the matrix balance for contact map. Only effective when input_type == cmap and input_format == cooler')
@click.option('--not-normalize', is_flag=True, default=False, show_default=True, help='Turn off auto normalization of contact map. Only effective when the input is contact map')
@click.option('--enforce-nonnegative-connectivity-matrix', is_flag=True, default=False, show_default=True, help='Enforcing that the "spring constants" in the connectivity matrix can only be nonnegative')
def main(input, output_prefix, ensemble, alpha, selection, method, lamd, reg, iteration, learning_rate, input_type, input_format, log, no_xyzs, ignore_missing_data, balance, not_normalize, enforce_nonnegative_connectivity_matrix):
def main(input, output_prefix, connectivity_matrix, ensemble, alpha, selection, method, lamd, reg, iteration, learning_rate, input_type, \
input_format, log, no_xyzs, ignore_missing_data, balance, not_normalize, enforce_nonnegative_connectivity_matrix):
"""
Script to run HIPPS/DIMES to generate ensemble of genome structures from either contact map or mean distance map\n
INPUT: Specify the path to the input file\n
Expand Down Expand Up @@ -510,6 +514,10 @@ def main(input, output_prefix, ensemble, alpha, selection, method, lamd, reg, it
else:
dmap_target = cmap2dmap(cmap, alpha, not_normalize)
dmap_target = ((3. * np.pi) / 8.) * np.power(dmap_target, 2.)

if connectivity_matrix is not None:
connectivity_matrix = np.loadtxt(connectivity_matrix)
console.print("Load the provided connectivity matrix and will use it as initialization.")
console.print("Initialization completed")


Expand Down Expand Up @@ -544,7 +552,7 @@ def main(input, output_prefix, ensemble, alpha, selection, method, lamd, reg, it
)
console.print(table)

model = Optimize(dmap_target)
model = Optimize(dmap_target, connectivity_matrix=connectivity_matrix)
keyword_arguments = {'learning_rate': learning_rate, 'lamd': lamd, 'reg': reg, 'method': method,
'enforce_nonnegative_connectivity_matrix': enforce_nonnegative_connectivity_matrix}

Expand Down
Loading

0 comments on commit 178661a

Please sign in to comment.