diff --git a/lemur b/lemur
index ea20741..40b320f 100755
--- a/lemur
+++ b/lemur
@@ -442,21 +442,21 @@ class LemurRunEnv():
                                                     cigar_gene_mat_tuples)
             else:
                 P_rgs_data["log_P"] = pool.starmap(self.score_cigar_markov,
-                                                   zip(P_rgs_data["cigar"], 
+                                                   zip(P_rgs_data["cigar"],
                                                        repeat(self.transition_mat)))
         elif self.aln_score == "edit":
             if self.by_gene:
                 log_P_func = self.score_cigar_fixed(P_rgs_data["cigar"][i], self.gene_edit_cigars[gene])
             else:
                 P_rgs_data["log_P"] = pool.starmap(self.score_cigar_fixed,
-                                                   zip(P_rgs_data["cigar"], 
+                                                   zip(P_rgs_data["cigar"],
                                                        repeat(self.edit_cigar)))
         else:
             if self.by_gene:
                 log_P_func = self.score_cigar_fixed(P_rgs_data["cigar"][i], self.gene_fixed_cigars[gene])
             else:
                 P_rgs_data["log_P"] = pool.starmap(self.score_cigar_fixed,
-                                                   zip(P_rgs_data["cigar"], 
+                                                   zip(P_rgs_data["cigar"],
                                                        repeat(self.fixed_cigar)))
 
         del P_rgs_data["cigar"]
@@ -585,13 +585,22 @@ class LemurRunEnv():
 
     @staticmethod
     def logSumExp(ns):
+        '''
+        Computes log(e^x1 + e^x2 + ...) for very small x values that would run into numerical stability problems
+        if calculated naively. Specifically, this function subtracts a constant (the maximum) from every value
+        first, does the calculation, then adds the constant back at the end.
+
+        Args:
+            ns (np.array): array of arguments to be log-sum-exp'ed
+        '''
+        # TODO: maybe remove code duplication with the external version of this function
         __max = np.max(ns)
         if not np.isfinite(__max):
             __max = 0
         ds = ns - __max
         with np.errstate(divide='ignore'):
             sumOfExp = np.exp(ds).sum()
-        return __max + np.log(sumOfExp) 
+        return __max + np.log(sumOfExp)
 
 
     def EM_step(self, final=False):
@@ -603,6 +612,7 @@ class LemurRunEnv():
 
         M-Step: F(t) := sum_{r_in_Reads} [P(t|r)]
         '''
+        # t0 = datetime.datetime.now().timestamp()
         if final:
             self.F = self.final_F
 
@@ -611,31 +621,44 @@ class LemurRunEnv():
         self.P_tgr = self.P_rgs_df.reset_index().merge(self.F,
                                                         how="inner",
                                                         left_on="Target_ID",
-                                                        right_index=True)
-        self.P_tgr["P(r|t)*F(t)"] = self.P_tgr.log_P + np.log(self.P_tgr.F)
-        self.P_tgr_sum = self.P_tgr[["Read_ID", "P(r|t)*F(t)"]].groupby(by="Read_ID", group_keys=False) \
-            .agg(self.logSumExp)
+                                                        right_index=True) #Step 0: Merge-F-to_LL
+        # t1 = datetime.datetime.now().timestamp()
+
+        # 2) Compute Likelihood x Prior:
+        self.P_tgr["P(r|t)*F(t)"] = self.P_tgr.log_P + np.log(self.P_tgr.F) #Step 1: Calc_LL*F
+        # t2 = datetime.datetime.now().timestamp()
+
+        # 3) Compute Lik*Pri Scale factors by Read:
+        self.P_tgr_sum = EM_get_Prgt_Ft_sums(self.P_tgr, self.args.num_threads)
+        # t3 = datetime.datetime.now().timestamp()
+
+        # 4) Merge the per-read scale factors back into the (read, target)-level dataframe
         self.P_tgr = self.P_tgr.merge(self.P_tgr_sum,
                                       how="left",
                                       left_on="Read_ID",
                                       right_index=True,
-                                      suffixes=["", "_sum"])
-        self.P_tgr["P(t|r)"] = self.P_tgr["P(r|t)*F(t)"] - self.P_tgr["P(r|t)*F(t)_sum"]
-
-        self.log(set(self.P_tgr["Target_ID"]), logging.DEBUG)
+                                      suffixes=["", "_sum"]) #Step 3: Merge-LL*Fsum-to-LL
+        # t4 = datetime.datetime.now().timestamp()
+        # 5) E-Step: Calculate P(t|r) = [ P(r|t)*F(t) / sum_{t in taxo} (P(r|t)*F(t)) ]
+        self.P_tgr["P(t|r)"] = self.P_tgr["P(r|t)*F(t)"] - self.P_tgr["P(r|t)*F(t)_sum"] #Step 4: Calc-Ptgr
+        # t5 = datetime.datetime.now().timestamp()
+        self.log(set(self.P_tgr["Target_ID"]), logging.DEBUG)
 
         n_reads = len(self.P_tgr_sum)
+        # 6) M-Step: Update the estimated values of F(t) = sum_{r}[P(t|r)] #Step 5: Recalc F
         self.F = self.P_tgr[["Target_ID", "P(t|r)"]].groupby("Target_ID") \
                     .agg(lambda x: np.exp(LemurRunEnv.logSumExp(x) - np.log(n_reads)))["P(t|r)"]
         self.F.name = "F"
         self.F = self.F.loc[self.F!=0]
+
+        # Logging: report the sum of the F vector (which should be 1.0)
         self.log(self.F.sum(), logging.DEBUG)
+        # t6 = datetime.datetime.now().timestamp()
 
         if final:
             self.final_F = self.F
-
 
     def compute_loglikelihood(self):
         '''Computes the total loglikelihood of the model, which should increase at every iteration or something is wrong.'''
@@ -647,7 +670,7 @@ class LemurRunEnv():
         criteria has been met, stopping once it has.'''
         n_reads = len(set(self.P_rgs_df.reset_index()["Read_ID"]))
         self.low_abundance_threshold = 1. / n_reads
-        
+
         if self.args.width_filter:
             __P_rgs_df = self.P_rgs_df.reset_index()
             tids = list(self.F.index)
@@ -725,9 +748,9 @@ class LemurRunEnv():
                 self.collapse_rank()
                 return
-            
+
             i += 1
-    
+
 
     @staticmethod
     def get_expected_gene_hits(N_genes, N_reads):
@@ -775,11 +798,70 @@ class LemurRunEnv():
         self.log(df_emu_copy.nlargest(30, ["F"]), logging.DEBUG)
         self.log(f"File generated: {output_path}\n", logging.DEBUG)
 
+def logSumExp_ReadId(readid, ns):
+    '''
+    This is a duplicate of the logSumExp method of LemurRunEnv, except that it also accepts and
+    returns the Read ID. This makes it possible to parallelize the function and still associate
+    each result with its Read ID after it is returned.
+    '''
+    __max = np.max(ns)
+    if not np.isfinite(__max):
+        __max = 0
+    ds = ns - __max
+    with np.errstate(divide='ignore'):
+        sumOfExp = np.exp(ds).sum()
+    return readid, __max + np.log(sumOfExp)
+
+def EM_get_Prgt_Ft_sums(P_tgr, nthreads):
+    '''
+    Computes the log-sum-exp aggregation of "P(r|t)*F(t)" at the read level efficiently.
+
+    This function is a numpy-only replacement for the pandas *.groupby(...).agg(self.logSumExp)
+    call that was the bottleneck in EM_step(). It takes the data frame self.P_tgr and returns a
+    2-column data frame with columns ["Read_ID", "P(r|t)*F(t)"], where each row holds the
+    log-sum-exp of the second column over all rows whose Read_ID matches the first column. That
+    step was a major CPU burden in the pandas implementation, and this version is much faster.
+
+    This function is set up to be parallelized using thread pools, although in practice that may
+    not add much speedup. It has to live outside the LemurRunEnv object so that it can be passed
+    to the thread pool.
+
+    Args:
+        P_tgr (pandas.DataFrame): data frame containing the two columns "Read_ID" and "P(r|t)*F(t)".
+        nthreads (int): Number of threads to parallelize with.
+
+    Returns:
+        P_tgr_sum (pandas.DataFrame): one row per Read_ID with its aggregated "P(r|t)*F(t)".
+    '''
+    # Convert each column to individual numpy arrays
+    Read_ID = P_tgr["Read_ID"].to_numpy().astype(np.str_)
+    Prgt_Ft = P_tgr["P(r|t)*F(t)"].to_numpy()
+    # Sort each of them by Read ID
+    Read_ID_as = Read_ID.argsort()
+    Read_ID = Read_ID[Read_ID_as]
+    Prgt_Ft = Prgt_Ft[Read_ID_as]
+    # Compute indices of unique read IDs
+    Read_ID_unq, Read_ID_Idx = np.unique(Read_ID, return_index=True)
+    # Zip up a list of Read IDs along with the portion of the array to be aggregated for each read.
+    map_args = list(zip(Read_ID_unq.tolist(), np.split(Prgt_Ft, Read_ID_Idx[1:])))
+    thread_pool = Pool(nthreads)
+    map_res = thread_pool.starmap(logSumExp_ReadId, map_args)
+    thread_pool.close()
+    thread_pool.join()
+    # Reassemble the result into a pandas DF that can go right back into the EM iteration method
+    P_tgr_sum = pd.DataFrame(map_res, columns=["Read_ID","P(r|t)*F(t)"], dtype='O')
+    P_tgr_sum['P(r|t)*F(t)'] = P_tgr_sum['P(r|t)*F(t)'].astype(float)
+    rid=P_tgr_sum["Read_ID"].to_numpy()
+    if not np.all(rid[:-1]