Profile em #8

Merged
merged 7 commits on Mar 27, 2024
lemur: 114 changes (98 additions, 16 deletions)
@@ -442,21 +442,21 @@ class LemurRunEnv():
cigar_gene_mat_tuples)
else:
P_rgs_data["log_P"] = pool.starmap(self.score_cigar_markov,
zip(P_rgs_data["cigar"],
zip(P_rgs_data["cigar"],
repeat(self.transition_mat)))
elif self.aln_score == "edit":
if self.by_gene:
log_P_func = self.score_cigar_fixed(P_rgs_data["cigar"][i], self.gene_edit_cigars[gene])
else:
P_rgs_data["log_P"] = pool.starmap(self.score_cigar_fixed,
zip(P_rgs_data["cigar"],
zip(P_rgs_data["cigar"],
repeat(self.edit_cigar)))
else:
if self.by_gene:
log_P_func = self.score_cigar_fixed(P_rgs_data["cigar"][i], self.gene_fixed_cigars[gene])
else:
P_rgs_data["log_P"] = pool.starmap(self.score_cigar_fixed,
zip(P_rgs_data["cigar"],
zip(P_rgs_data["cigar"],
repeat(self.fixed_cigar)))

del P_rgs_data["cigar"]
@@ -585,13 +585,22 @@ class LemurRunEnv():

@staticmethod
def logSumExp(ns):
'''
Computes log(e^x1 + e^x2 + ...) in a numerically stable way for arguments whose exponentials would
underflow or overflow if calculated the regular way. Specifically, this function subtracts the maximum
value from every argument first, does the calculation, then adds the maximum back at the end.

Args:
ns (np.array): array of arguments to be log-sum-exp'ed
'''
# TODO: maybe remove code duplication with the external version of this function
__max = np.max(ns)
if not np.isfinite(__max):
__max = 0
ds = ns - __max
with np.errstate(divide='ignore'):
sumOfExp = np.exp(ds).sum()
return __max + np.log(sumOfExp)
return __max + np.log(sumOfExp)
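
A quick worked example of the stabilization described above (the array values are made up purely for illustration): subtracting the maximum keeps np.exp from underflowing to zero before the sum is taken.

import numpy as np

xs = np.array([-1000.0, -1001.0, -1002.0])

# Naive evaluation underflows: exp(-1000) is 0.0 in float64, so the sum is 0 and its log is -inf.
naive = np.log(np.exp(xs).sum())           # -inf

# Stabilized evaluation: factor out the maximum, sum the remaining O(1) terms, add the maximum back.
m = xs.max()                               # -1000.0
stable = m + np.log(np.exp(xs - m).sum())  # about -999.59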


def EM_step(self, final=False):
@@ -603,6 +612,7 @@ class LemurRunEnv():
M-Step: F(t) := sum_{r_in_Reads} [P(t|r)]

'''
# t0 = datetime.datetime.now().timestamp()
if final:
self.F = self.final_F

@@ -611,31 +621,44 @@ class LemurRunEnv():
self.P_tgr = self.P_rgs_df.reset_index().merge(self.F,
how="inner",
left_on="Target_ID",
right_index=True)
self.P_tgr["P(r|t)*F(t)"] = self.P_tgr.log_P + np.log(self.P_tgr.F)
self.P_tgr_sum = self.P_tgr[["Read_ID", "P(r|t)*F(t)"]].groupby(by="Read_ID", group_keys=False) \
.agg(self.logSumExp)
right_index=True) #Step 0: Merge-F-to_LL
# t1 = datetime.datetime.now().timestamp()

# 2) Compute Likelihood x Prior:
self.P_tgr["P(r|t)*F(t)"] = self.P_tgr.log_P + np.log(self.P_tgr.F) #Step 1: Calc_LL*F
# t2 = datetime.datetime.now().timestamp()

# 3) Compute Lik*Pri Scale factors by Read:
self.P_tgr_sum = EM_get_Prgt_Ft_sums(self.P_tgr, self.args.num_threads)
# t3 = datetime.datetime.now().timestamp()

# 4) Read in per-read scale factors to (read,target)-level dataframe
self.P_tgr = self.P_tgr.merge(self.P_tgr_sum,
how="left",
left_on="Read_ID",
right_index=True,
suffixes=["", "_sum"])
self.P_tgr["P(t|r)"] = self.P_tgr["P(r|t)*F(t)"] - self.P_tgr["P(r|t)*F(t)_sum"]

self.log(set(self.P_tgr["Target_ID"]), logging.DEBUG)
suffixes=["", "_sum"]) #Step 3: Merge-LL*Fsum-to-LL
# t4 = datetime.datetime.now().timestamp()

# 5) E-Step: Calculate P(t|r) = [ P(r|t)*F(t) / sum_{t in taxo} (P(r|t)*F(t)) ]
self.P_tgr["P(t|r)"] = self.P_tgr["P(r|t)*F(t)"] - self.P_tgr["P(r|t)*F(t)_sum"] #Step 4: Calc-Ptgr
# t5 = datetime.datetime.now().timestamp()
self.log(set(self.P_tgr["Target_ID"]), logging.DEBUG)
n_reads = len(self.P_tgr_sum)

# 6) M-Step: Update the estimated values of F(t) = sum_{r}[P(t|r)] #Step 5: Recalc F
self.F = self.P_tgr[["Target_ID", "P(t|r)"]].groupby("Target_ID") \
.agg(lambda x: np.exp(LemurRunEnv.logSumExp(x) - np.log(n_reads)))["P(t|r)"]
self.F.name = "F"
self.F = self.F.loc[self.F!=0]

# Logging: report the sum of the F vector (which should be 1.0)
self.log(self.F.sum(), logging.DEBUG)
# t6 = datetime.datetime.now().timestamp()

if final:
self.final_F = self.F
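
For intuition, here is a minimal, self-contained sketch of one E/M update in log space on a toy alignment table. The column names mirror the ones used above, the values are made up, and scipy's logsumexp stands in for the class helper purely for brevity.

import numpy as np
import pandas as pd
from scipy.special import logsumexp

# Toy P(r|t): two reads, two candidate targets, log-likelihood of each alignment.
P_rgs = pd.DataFrame({
    "Read_ID":   ["r1", "r1", "r2", "r2"],
    "Target_ID": ["t1", "t2", "t1", "t2"],
    "log_P":     [-5.0, -9.0, -6.0, -6.5],
})
F = pd.Series({"t1": 0.5, "t2": 0.5}, name="F")  # uniform starting abundances

# E-step: P(t|r) is proportional to P(r|t) * F(t), normalized per read in log space.
P_rgs["log_P*F"] = P_rgs["log_P"] + np.log(P_rgs["Target_ID"].map(F))
per_read_sum = P_rgs.groupby("Read_ID")["log_P*F"].transform(logsumexp)
P_rgs["log_P(t|r)"] = P_rgs["log_P*F"] - per_read_sum

# M-step: F(t) = (1 / n_reads) * sum_r P(t|r), computed from the log responsibilities.
n_reads = P_rgs["Read_ID"].nunique()
F_new = np.exp(P_rgs.groupby("Target_ID")["log_P(t|r)"].agg(logsumexp) - np.log(n_reads))
# F_new sums to 1.0, which is the same invariant checked via self.F.sum() above.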


def compute_loglikelihood(self):
'''Computes the total log-likelihood of the model, which should increase at every
iteration; if it ever decreases, something is wrong.'''
@@ -647,7 +670,7 @@ class LemurRunEnv():
criteria has been met, stopping once it has.'''
n_reads = len(set(self.P_rgs_df.reset_index()["Read_ID"]))
self.low_abundance_threshold = 1. / n_reads

if self.args.width_filter:
__P_rgs_df = self.P_rgs_df.reset_index()
tids = list(self.F.index)
@@ -725,9 +748,9 @@ class LemurRunEnv():
self.collapse_rank()

return

i += 1


@staticmethod
def get_expected_gene_hits(N_genes, N_reads):
@@ -775,11 +798,70 @@ class LemurRunEnv():
self.log(df_emu_copy.nlargest(30, ["F"]), logging.DEBUG)
self.log(f"File generated: {output_path}\n", logging.DEBUG)

def logSumExp_ReadId(readid, ns):
'''
This is a duplicate of the logSumExp function in the LemurRunEnv object, except that it takes the
Read ID as an additional argument and returns it alongside the result. This makes it possible to
parallelize the function and still associate each result with its Read ID after it is returned.
'''
__max = np.max(ns)
if not np.isfinite(__max):
__max = 0
ds = ns - __max
with np.errstate(divide='ignore'):
sumOfExp = np.exp(ds).sum()
return readid, __max + np.log(sumOfExp)
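
A minimal usage sketch for the helper above, with hypothetical read IDs and values; the ThreadPool import is an assumption here, since the module's actual Pool import sits outside this diff hunk. Each worker returns its Read_ID together with the log-sum-exp, so results stay attributable to reads regardless of how the work is distributed.

import numpy as np
from multiprocessing.pool import ThreadPool  # assumed stand-in for the module's Pool

map_args = [("read_1", np.array([-3.0, -4.0])),
            ("read_2", np.array([-2.0, -2.5]))]
with ThreadPool(2) as pool:
    results = pool.starmap(logSumExp_ReadId, map_args)
# results == [("read_1", -2.686...), ("read_2", -1.525...)]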

def EM_get_Prgt_Ft_sums(P_tgr, nthreads):
'''
Computes the per-read log-sum-exp aggregation of P(r|t)*F(t) efficiently.

This function is a pure-NumPy replacement for the pandas *.groupby(...).agg(self.logSumExp)
call that was the bottleneck in the EM_step() function. It takes the data frame self.P_tgr
and returns a data frame indexed by "Read_ID" with a single column "P(r|t)*F(t)", where each
row is the log-sum-exp of that column over all input rows sharing that Read_ID. That step was
a real CPU burden in the pandas implementation, and this version is much faster.

The function is set up to be parallelized with a thread pool, although in practice that may not
add much speedup. It has to live outside the LemurRunEnv object so that it can be passed to the
thread pool.

Args:
P_tgr (pandas.DataFrame): data frame containing the two columns "Read_ID" and "P(r|t)*F(t)".
nthreads (int): number of threads to parallelize with.

Returns:
P_tgr_sum (pandas.DataFrame): per-read log-sum-exp values, indexed by "Read_ID".
'''
# Convert each column to individual numpy arrays
Read_ID = P_tgr["Read_ID"].to_numpy().astype(np.str_)
Prgt_Ft = P_tgr["P(r|t)*F(t)"].to_numpy()
# Sort each of them by Read ID
Read_ID_as = Read_ID.argsort()
Read_ID = Read_ID[Read_ID_as]
Prgt_Ft = Prgt_Ft[Read_ID_as]
# Compute indices of unique read IDs
Read_ID_unq, Read_ID_Idx = np.unique(Read_ID, return_index=True)
# Zip up a list of ReadIDs along with the portions of the df to be computed for each read.
map_args = list(zip(Read_ID_unq.tolist(), np.split(Prgt_Ft, Read_ID_Idx[1:])))
thread_pool = Pool(nthreads)
map_res = thread_pool.starmap(logSumExp_ReadId, map_args)
thread_pool.close()
thread_pool.join()
# Reassemble the result into a pandas DF that can go right back into the EM iteration method
P_tgr_sum = pd.DataFrame(map_res, columns=["Read_ID","P(r|t)*F(t)"], dtype='O')
P_tgr_sum['P(r|t)*F(t)'] = P_tgr_sum['P(r|t)*F(t)'].astype(float)
# Ensure rows are sorted by Read_ID (starmap returns results in argument order and np.unique sorts,
# so this should already hold), then move Read_ID into the index so EM_step can merge on it.
rid = P_tgr_sum["Read_ID"].to_numpy()
if not np.all(rid[:-1] < rid[1:]):
P_tgr_sum.sort_values(["Read_ID"], inplace=True)
P_tgr_sum.set_index(["Read_ID"], inplace=True)
return P_tgr_sum
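
A brief usage sketch (toy data, single thread, assuming the module's numpy/pandas/Pool imports are in scope): the returned frame should match the original pandas aggregation up to floating-point noise.

import pandas as pd

toy = pd.DataFrame({
    "Read_ID":     ["r2", "r1", "r1", "r2"],
    "P(r|t)*F(t)": [-4.0, -3.0, -5.0, -4.5],
})
sums = EM_get_Prgt_Ft_sums(toy, nthreads=1)
# sums is indexed by Read_ID with one log-sum-exp per read, equivalent to
# toy.groupby("Read_ID")["P(r|t)*F(t)"].agg(LemurRunEnv.logSumExp).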

def main():
run = LemurRunEnv()

if not run.args.sam_input:
run.log(f"Starting run of minimap2 at {datetime.datetime.now()}", logging.INFO)
ts = time.time_ns()
run.run_minimap2()
t0 = time.time_ns()