Profile em (#8)
* Set up code for profiling the EM function and added a bunch of docstrings to the functions.

* Modified the EM_step() function to use a different method for the third step, all in numpy now.

* Move lemur so that the diff makes sense

* Cleaned up the profiling flags and added a few comments

* Changed the lemur file permissions back

---------

Co-authored-by: Bryce Lorenz Kille <[email protected]>
MGNute and bkille authored Mar 27, 2024
1 parent 97d396a commit 5268854
Showing 1 changed file with 98 additions and 16 deletions.
114 changes: 98 additions & 16 deletions lemur
@@ -442,21 +442,21 @@ class LemurRunEnv():
cigar_gene_mat_tuples)
else:
P_rgs_data["log_P"] = pool.starmap(self.score_cigar_markov,
zip(P_rgs_data["cigar"],
repeat(self.transition_mat)))
elif self.aln_score == "edit":
if self.by_gene:
log_P_func = self.score_cigar_fixed(P_rgs_data["cigar"][i], self.gene_edit_cigars[gene])
else:
P_rgs_data["log_P"] = pool.starmap(self.score_cigar_fixed,
zip(P_rgs_data["cigar"],
repeat(self.edit_cigar)))
else:
if self.by_gene:
log_P_func = self.score_cigar_fixed(P_rgs_data["cigar"][i], self.gene_fixed_cigars[gene])
else:
P_rgs_data["log_P"] = pool.starmap(self.score_cigar_fixed,
zip(P_rgs_data["cigar"],
repeat(self.fixed_cigar)))

del P_rgs_data["cigar"]
@@ -585,13 +585,22 @@ class LemurRunEnv():

@staticmethod
def logSumExp(ns):
'''
Computes log(e^x1 + e^x2 + ...) for very negative x values that would run into numerical underflow if
calculated the regular way. Specifically, this function subtracts the maximum value from every number first,
then does the calculation, then adds that maximum back at the end.
Args:
    ns (np.array): array of arguments to be log-sum-exp'ed
'''
# TODO: maybe remove code duplication with the external version of this function
__max = np.max(ns)
if not np.isfinite(__max):
__max = 0
ds = ns - __max
with np.errstate(divide='ignore'):
sumOfExp = np.exp(ds).sum()
return __max + np.log(sumOfExp)
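# Illustrative note (not part of this commit): the shift by __max is what prevents
# underflow in np.exp. For example, assuming ns = np.array([-1000.0, -1001.0]),
# the naive np.log(np.exp(ns).sum()) evaluates to -inf because both exponentials
# underflow to 0.0, while this shifted form returns -1000 + np.log(1 + np.exp(-1)),
# roughly -999.69.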


def EM_step(self, final=False):
@@ -603,6 +612,7 @@ class LemurRunEnv():
M-Step: F(t) := sum_{r_in_Reads} [P(t|r)]
'''
# t0 = datetime.datetime.now().timestamp()
if final:
self.F = self.final_F

@@ -611,31 +621,44 @@ class LemurRunEnv():
self.P_tgr = self.P_rgs_df.reset_index().merge(self.F,
how="inner",
left_on="Target_ID",
right_index=True)
self.P_tgr["P(r|t)*F(t)"] = self.P_tgr.log_P + np.log(self.P_tgr.F)
self.P_tgr_sum = self.P_tgr[["Read_ID", "P(r|t)*F(t)"]].groupby(by="Read_ID", group_keys=False) \
.agg(self.logSumExp)
right_index=True) #Step 0: Merge-F-to_LL
# t1 = datetime.datetime.now().timestamp()

# 2) Compute Likelihood x Prior:
self.P_tgr["P(r|t)*F(t)"] = self.P_tgr.log_P + np.log(self.P_tgr.F) #Step 1: Calc_LL*F
# t2 = datetime.datetime.now().timestamp()

# 3) Compute Lik*Pri Scale factors by Read:
self.P_tgr_sum = EM_get_Prgt_Ft_sums(self.P_tgr, self.args.num_threads)
# t3 = datetime.datetime.now().timestamp()

# 4) Read in per-read scale factors to (read,target)-level dataframe
self.P_tgr = self.P_tgr.merge(self.P_tgr_sum,
how="left",
left_on="Read_ID",
right_index=True,
suffixes=["", "_sum"])
self.P_tgr["P(t|r)"] = self.P_tgr["P(r|t)*F(t)"] - self.P_tgr["P(r|t)*F(t)_sum"]

self.log(set(self.P_tgr["Target_ID"]), logging.DEBUG)
suffixes=["", "_sum"]) #Step 3: Merge-LL*Fsum-to-LL
# t4 = datetime.datetime.now().timestamp()

# 5) E-Step: Calculate P(t|r) = [ P(r|t)*F(t) / sum_{t in taxo} (P(r|t)*F(t)) ]
self.P_tgr["P(t|r)"] = self.P_tgr["P(r|t)*F(t)"] - self.P_tgr["P(r|t)*F(t)_sum"] #Step 4: Calc-Ptgr
# t5 = datetime.datetime.now().timestamp()
self.log(set(self.P_tgr["Target_ID"]), logging.DEBUG)
n_reads = len(self.P_tgr_sum)

# 6) M-Step: Update the estimated values of F(t) = sum_{r}[P(t|r)] #Step 5: Recalc F
self.F = self.P_tgr[["Target_ID", "P(t|r)"]].groupby("Target_ID") \
.agg(lambda x: np.exp(LemurRunEnv.logSumExp(x) - np.log(n_reads)))["P(t|r)"]
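# Sketch (comment only, not part of this commit): np.exp(logSumExp(x) - np.log(n_reads))
# equals np.exp(x).sum() / n_reads, so the aggregation above computes
# F(t) = (1/n_reads) * sum_r P(t|r) in log space to avoid underflow. For example,
# assuming x = np.log([0.5, 0.25]) and n_reads = 2, the result is 0.375.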
self.F.name = "F"
self.F = self.F.loc[self.F!=0]

# Logging: report the sum of the F vector (which should be 1.0)
self.log(self.F.sum(), logging.DEBUG)
# t6 = datetime.datetime.now().timestamp()

if final:
self.final_F = self.F


def compute_loglikelihood(self):
'''Computes the total log-likelihood of the model, which should increase at every iteration; if it does
not, something is wrong.'''
@@ -647,7 +670,7 @@ class LemurRunEnv():
criterion has been met, stopping once it has.'''
n_reads = len(set(self.P_rgs_df.reset_index()["Read_ID"]))
self.low_abundance_threshold = 1. / n_reads

if self.args.width_filter:
__P_rgs_df = self.P_rgs_df.reset_index()
tids = list(self.F.index)
@@ -725,9 +748,9 @@ class LemurRunEnv():
self.collapse_rank()

return

i += 1


@staticmethod
def get_expected_gene_hits(N_genes, N_reads):
@@ -775,11 +798,70 @@ class LemurRunEnv():
self.log(df_emu_copy.nlargest(30, ["F"]), logging.DEBUG)
self.log(f"File generated: {output_path}\n", logging.DEBUG)

def logSumExp_ReadId(readid, ns):
'''
This is a duplicate of the logSumExp function in the LemurRunEnv object, except that it accepts the
Read ID as an argument and returns it alongside the result. This makes it possible to parallelize the
function and still associate each result with its Read ID after it is returned.
'''
__max = np.max(ns)
if not np.isfinite(__max):
__max = 0
ds = ns - __max
with np.errstate(divide='ignore'):
sumOfExp = np.exp(ds).sum()
return readid, __max + np.log(sumOfExp)
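# Illustrative note (not part of this commit): when called through Pool.starmap as in
# EM_get_Prgt_Ft_sums below, e.g. starmap(logSumExp_ReadId, [("r1", ns1), ("r2", ns2)])
# with hypothetical arrays ns1 and ns2, each returned (readid, value) pair can be matched
# back to its read directly, without relying on the ordering of the worker results.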

def EM_get_Prgt_Ft_sums(P_tgr, nthreads):
'''
Computes the per-read log-sum-exp aggregation efficiently.
This function is a pure-numpy replacement for the pandas *.groupby(...).agg(self.logSumExp)
call that was the bottleneck in the EM_step() function. It takes the DF self.P_tgr and returns
a data frame indexed by "Read_ID" with the single column "P(r|t)*F(t)", where each entry is the
log-sum-exp of that column over all rows sharing the same Read_ID. That step was
a real CPU burden in the pandas implementation and this is much faster.
This function is set up to be parallelized using thread pools, but in practice that may not add
a lot of speedup. It does, however, have to live outside the LemurRunEnv object so that it can be
passed to the thread pool.
Args:
    P_tgr (pandas.DataFrame): data frame containing the two columns "Read_ID" and "P(r|t)*F(t)".
    nthreads (int): Number of threads to parallelize with.
Returns:
    P_tgr_sum (pandas.DataFrame)
'''
# Convert each column to individual numpy arrays
Read_ID = P_tgr["Read_ID"].to_numpy().astype(np.str_)
Prgt_Ft = P_tgr["P(r|t)*F(t)"].to_numpy()
# Sort each of them by Read ID
Read_ID_as = Read_ID.argsort()
Read_ID = Read_ID[Read_ID_as]
Prgt_Ft = Prgt_Ft[Read_ID_as]
# Compute indices of unique read IDs
Read_ID_unq, Read_ID_Idx = np.unique(Read_ID, return_index=True)
# Zip up a list of ReadIDs along with the portions of the df to be computed for each read.
map_args = list(zip(Read_ID_unq.tolist(), np.split(Prgt_Ft, Read_ID_Idx[1:])))
thread_pool = Pool(nthreads)
map_res = thread_pool.starmap(logSumExp_ReadId, map_args)
thread_pool.close()
thread_pool.join()
# Reassemble the result into a pandas DF that can go right back into the EM iteration method
P_tgr_sum = pd.DataFrame(map_res, columns=["Read_ID","P(r|t)*F(t)"], dtype='O')
P_tgr_sum['P(r|t)*F(t)'] = P_tgr_sum['P(r|t)*F(t)'].astype(float)
# Make sure the rows are sorted by Read_ID before it is set as the index
rid = P_tgr_sum["Read_ID"].to_numpy()
if not np.all(rid[:-1] < rid[1:]):
    P_tgr_sum.sort_values(["Read_ID"], inplace=True)
P_tgr_sum.set_index(['Read_ID'], inplace=True)
return P_tgr_sum
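# Usage sketch (comment only, not part of this commit): this function is intended as a
# drop-in replacement for the pandas aggregation it supersedes, i.e. roughly
#   P_tgr[["Read_ID", "P(r|t)*F(t)"]].groupby(by="Read_ID").agg(logSumExp)
# but with the grouping done via numpy argsort/unique/split and the per-read
# log-sum-exp calls dispatched to a Pool of nthreads workers.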

def main():
run = LemurRunEnv()

if not run.args.sam_input:
run.log(f"Starting run of minimap2 at {datetime.datetime.now()}", logging.INFO)
ts = time.time_ns()
run.run_minimap2()
t0 = time.time_ns()
