diff --git a/README.md b/README.md
index b8b8db6..57e4be0 100644
--- a/README.md
+++ b/README.md
@@ -44,7 +44,82 @@ Instructions coming soon...
 Instructions coming soon...
 
 ### Molecular optimization
-Instructions coming soon...
+To run the optimization algorithm, define an Oracle class that is responsible for computing the oracle function for the given molecules.
+
+```python
+# Oracle implementation scheme
+
+class ExampleOracle:
+    def __init__(self, ...):
+        # maximum number of oracle calls to make
+        self.max_oracle_calls: int = ...
+
+        # the frequency with which to log
+        self.freq_log: int = ...
+
+        # the buffer to keep track of all unique molecules generated
+        self.mol_buffer: Dict = ...
+
+        # the maximum possible oracle score or an upper bound
+        self.max_possible_oracle_score: float = ...
+
+    def __call__(self, molecules):
+        """
+        Evaluate and return the oracle scores for the molecules. Log intermediate results if necessary.
+        """
+        ...
+        return oracle_scores
+
+    @property
+    def finish(self):
+        """
+        Specify the stopping condition for the optimization process.
+        """
+        return stopping_condition
+```
+
+Define the configuration and hyperparameters used for the optimization process in a YAML file.
+
+```yaml
+# yaml config scheme
+
+checkpoint_path: /path/to/model_dir
+tokenizer_path: /path/to/tokenizer_dir
+
+... optimization algorithm hyperparameters (pool size, number of similar molecules to use, etc.) ...
+
+generation_config:
+  ... molecule generation hyperparameters ...
+
+strategy: [rej-sample-v2] # or use [default] to skip the fine-tuning step during the optimization
+
+rej_sample_config:
+  ... fine-tuning hyperparameters ...
+```
+
+Putting everything together and running the optimization process:
+
+```python
+import yaml
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from chemlactica.mol_opt.optimization import optimize
+
+# Load the config
+config = yaml.safe_load(open(path_to_yaml_config))
+
+# Load the model and the tokenizer
+model = AutoModelForCausalLM.from_pretrained(...)
+tokenizer = AutoTokenizer.from_pretrained(...)
+
+# Create the Oracle
+oracle = ExampleOracle(...)
+
+# Call the optimize function to optimize against the defined oracle
+optimize(
+    model, tokenizer,
+    oracle, config
+)
+```
+
+[example_run.py]() illustrates a full working example of an optimization run. For more complex examples, refer to the [mol_opt/run.py]() and [retmol/run_qed.py]() files in the [ChemlacticaTestSuit]() repository.
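+
+As a concrete illustration, the following minimal oracle follows the scheme above and scores molecules by their QED value. This is a sketch, not part of the repository: it assumes RDKit is installed, receives plain SMILES strings (i.e. `takes_entry` is left unset), and the names `QEDOracle` and `smiles_list` are illustrative. It also implements `__len__` and `budget`, which the optimizer uses for logging and the temperature schedule.
+
+```python
+from typing import Dict, List
+
+from rdkit import Chem
+from rdkit.Chem import QED
+
+
+class QEDOracle:
+    def __init__(self, max_oracle_calls: int):
+        self.max_oracle_calls = max_oracle_calls
+        self.freq_log = 100
+        # maps a SMILES string to [score, order of discovery]
+        self.mol_buffer: Dict[str, list] = {}
+        self.max_possible_oracle_score = 1.0
+
+    def __call__(self, smiles_list: List[str]):
+        oracle_scores = []
+        for smiles in smiles_list:
+            if smiles in self.mol_buffer:
+                oracle_scores.append(self.mol_buffer[smiles][0])
+                continue
+            mol = Chem.MolFromSmiles(smiles)
+            # invalid SMILES receive the lowest possible score
+            score = QED.qed(mol) if mol is not None else 0.0
+            self.mol_buffer[smiles] = [score, len(self.mol_buffer) + 1]
+            oracle_scores.append(score)
+        return oracle_scores
+
+    def __len__(self):
+        return len(self.mol_buffer)
+
+    @property
+    def budget(self):
+        return self.max_oracle_calls
+
+    @property
+    def finish(self):
+        # stop once the call budget is exhausted
+        return len(self.mol_buffer) >= self.max_oracle_calls
+```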
 ## Tests
 
 The test for running a small-sized model with the same
diff --git a/chemlactica/mol_opt/chemlactica_125m_hparams.yaml b/chemlactica/mol_opt/chemlactica_125m_hparams.yaml
index 04227b4..c08cccb 100644
--- a/chemlactica/mol_opt/chemlactica_125m_hparams.yaml
+++ b/chemlactica/mol_opt/chemlactica_125m_hparams.yaml
@@ -1,5 +1,5 @@
-checkpoint_path: /nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/1f289ff103034364bd27e1c3/checkpoint-18000
-tokenizer_path: ChemLactica/chemlactica/tokenizer/ChemLacticaTokenizer66
+checkpoint_path: /path/to/model_dir
+tokenizer_path: /path/to/tokenizer_dir
 pool_size: 10
 validation_perc: 0.2
 num_mols: 0
diff --git a/chemlactica/mol_opt/example_run.py b/chemlactica/mol_opt/example_run.py
new file mode 100644
index 0000000..fc49466
--- /dev/null
+++ b/chemlactica/mol_opt/example_run.py
@@ -0,0 +1,113 @@
+from typing import List
+import yaml
+import argparse
+import os
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
+import numpy as np
+from rdkit.Chem import rdMolDescriptors
+from chemlactica.mol_opt.optimization import optimize
+from chemlactica.mol_opt.utils import set_seed, MoleculeEntry
+
+
+class TPSA_Weight_Oracle:
+    def __init__(self, max_oracle_calls: int):
+        # maximum number of oracle calls to make
+        self.max_oracle_calls = max_oracle_calls
+
+        # the frequency with which to log
+        self.freq_log = 100
+
+        # the buffer to keep track of all unique molecules generated
+        self.mol_buffer = {}
+
+        # the maximum possible oracle score or an upper bound
+        self.max_possible_oracle_score = 1.0
+
+        # if True, the __call__ function takes a list of MoleculeEntry objects
+        # if False (or unspecified), the __call__ function takes a list of SMILES strings
+        self.takes_entry = True
+
+    def __call__(self, molecules: List[MoleculeEntry]):
+        """
+        Evaluate and return the oracle scores for the molecules. Log intermediate results if necessary.
+        """
+        oracle_scores = []
+        for molecule in molecules:
+            if self.mol_buffer.get(molecule.smiles):
+                oracle_scores.append(self.mol_buffer[molecule.smiles][0])
+            else:
+                try:
+                    tpsa = rdMolDescriptors.CalcTPSA(molecule.mol)
+                    tpsa_score = min(tpsa / 1000, 1)
+                    weight = rdMolDescriptors.CalcExactMolWt(molecule.mol)
+                    if weight <= 349:
+                        weight_score = 1
+                    elif weight >= 500:
+                        weight_score = 0
+                    else:
+                        weight_score = -0.00662 * weight + 3.31125
+
+                    # average of the two component scores
+                    oracle_score = (tpsa_score + weight_score) / 2
+                except Exception as e:
+                    print(e)
+                    oracle_score = 0
+                self.mol_buffer[molecule.smiles] = [oracle_score, len(self.mol_buffer) + 1]
+                if len(self.mol_buffer) % self.freq_log == 0:
+                    self.log_intermediate()
+                oracle_scores.append(oracle_score)
+        return oracle_scores
+
+    def log_intermediate(self):
+        scores = [v[0] for v in self.mol_buffer.values()][-self.max_oracle_calls:]
+        scores_sorted = sorted(scores, reverse=True)[:100]
+        n_calls = len(self.mol_buffer)
+
+        score_avg_top1 = np.max(scores_sorted)
+        score_avg_top10 = np.mean(scores_sorted[:10])
+        score_avg_top100 = np.mean(scores_sorted)
+
+        print(f"{n_calls}/{self.max_oracle_calls} | ",
+              f'avg_top1: {score_avg_top1:.3f} | '
+              f'avg_top10: {score_avg_top10:.3f} | '
+              f'avg_top100: {score_avg_top100:.3f}')
+
+    def __len__(self):
+        return len(self.mol_buffer)
+
+    @property
+    def budget(self):
+        return self.max_oracle_calls
+
+    @property
+    def finish(self):
+        # the stopping condition for the optimization process
+        return len(self.mol_buffer) >= self.max_oracle_calls
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--output_dir", type=str, required=True)
+    parser.add_argument("--config_default", type=str, required=True)
+    parser.add_argument("--n_runs", type=int, required=False, default=1)
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    config = yaml.safe_load(open(args.config_default))
+
+    model = AutoModelForCausalLM.from_pretrained(config["checkpoint_path"], torch_dtype=torch.bfloat16).to(config["device"])
+    tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_path"], padding_side="left")
+
+    seeds = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31]
+    for i in range(args.n_runs):
+        set_seed(seeds[i])
+        oracle = TPSA_Weight_Oracle(max_oracle_calls=1000)
+        config["log_dir"] = os.path.join(args.output_dir, "results_tpsa+weight.log")
+        config["max_possible_oracle_score"] = oracle.max_possible_oracle_score
+        optimize(
+            model, tokenizer,
+            oracle, config
+        )
\ No newline at end of file
diff --git a/chemlactica/mol_opt/hparams_tune.yaml b/chemlactica/mol_opt/hparams_tune.yaml
index 71721f6..bde5ce7 100644
--- a/chemlactica/mol_opt/hparams_tune.yaml
+++ b/chemlactica/mol_opt/hparams_tune.yaml
@@ -12,7 +12,7 @@ parameters:
   num_gens_per_iter: [200, 400, 600]
   generation_temperature: [[1.0, 1.0], [1.5, 1.5], [1.0, 1.5]]
 
-  # rej_sample_config:
-  #   num_train_epochs: [1, 3, 5, 7, 9]
-  #   train_tol_level: [1, 3, 5, 7, 9]
-  #   max_learning_rate: [0.0001, 0.00001, 0.000001]
\ No newline at end of file
+  rej_sample_config:
+    num_train_epochs: [1, 3, 5, 7, 9]
+    train_tol_level: [1, 3, 5, 7, 9]
+    max_learning_rate: [0.0001, 0.00001, 0.000001]
\ No newline at end of file
diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py
new file mode 100644
index 0000000..9cb2897
--- /dev/null
+++ b/chemlactica/mol_opt/optimization.py
@@ -0,0 +1,238 @@
+from typing import Dict, List
+import torch
+from datasets import Dataset
+import gc
+import shutil
+from trl import SFTTrainer
+from transformers import OPTForCausalLM
+from chemlactica.mol_opt.utils import OptimEntry, MoleculeEntry, Pool
+from chemlactica.mol_opt.tunning import get_training_arguments, get_optimizer_and_lr_scheduler, CustomEarlyStopCallback, CustomModelSelectionCallback
+
+
+def create_similar_mol_entries(pool, mol_entry, num_similars):
+    similar_entries = [e.last_entry for e in pool.random_subset(num_similars)]
+    count = 0
+    valid_similar_entries = []
+    for similar_entry in similar_entries:
+        if count >= num_similars:
+            break
+        if similar_entry == mol_entry:
+            continue
+        valid_similar_entries.append(similar_entry)
+        count += 1
+    return valid_similar_entries
+
+
+def create_optimization_entries(num_entries, pool, config):
+    optim_entries = []
+    for i in range(num_entries):
+        mol_entries = [e.last_entry for e in pool.random_subset(config["num_mols"])]
+        entries = []
+        for mol_entry in mol_entries:
+            similar_mol_entries = create_similar_mol_entries(pool, mol_entry, num_similars=config["num_similars"])
+            mol_entry.similar_mol_entries = similar_mol_entries
+            entries.append(mol_entry)
+        optim_entries.append(OptimEntry(None, entries))
+    return optim_entries
+
+
+def create_molecule_entry(output_text):
+    start_smiles_tag, end_smiles_tag = "[START_SMILES]", "[END_SMILES]"
+    start_ind = output_text.rfind(start_smiles_tag)
+    end_ind = output_text.rfind(end_smiles_tag)
+    if start_ind == -1 or end_ind == -1:
+        return None
+    generated_smiles = output_text[start_ind+len(start_smiles_tag):end_ind]
+    if len(generated_smiles) == 0:
+        return None
+
+    try:
+        molecule = MoleculeEntry(
+            smiles=generated_smiles,
+        )
+        return molecule
+    except Exception:
+        return None
+
+
+def optimize(
+        model, tokenizer,
+        oracle, config,
+        additional_properties={}
+    ):
+    file = open(config["log_dir"], "w")
+    print("config", config)
+    # print("molecule generation arguments", config["generation_config"])
+    pool = Pool(config["pool_size"], validation_perc=config["validation_perc"])
+
+    config["generation_config"]["temperature"] = config["generation_temperature"][0]
+
+    if "rej-sample-v2" in config["strategy"]:
+        training_args = get_training_arguments(config["rej_sample_config"])
+        effective_batch_size = config["rej_sample_config"]["gradient_accumulation_steps"] * config["rej_sample_config"]["train_batch_size"]
+        num_single_train_steps = config["rej_sample_config"]["num_train_epochs"] * ((1 - config["validation_perc"]) * config["pool_size"] / effective_batch_size)
+        max_num_trains = oracle.max_oracle_calls / (config["rej_sample_config"]["train_tol_level"] * config["num_gens_per_iter"])
+        max_num_train_steps = int(max_num_trains * num_single_train_steps)
+        optimizer, lr_scheduler = get_optimizer_and_lr_scheduler(model, config["rej_sample_config"], max_num_train_steps)
+    max_score = 0
+    tol_level = 0
+    num_iter = 0
+    prev_train_iter = 0
+    while True:
+        model.eval()
+        new_best_molecule_generated = False
+        iter_unique_optim_entries: Dict[str, OptimEntry] = {}
+        while len(iter_unique_optim_entries) < config["num_gens_per_iter"]:
+            optim_entries = create_optimization_entries(
+                config["generation_batch_size"], pool,
+                config=config
+            )
+            for i in range(len(optim_entries)):
+                last_entry = MoleculeEntry(smiles="")
+                last_entry.similar_mol_entries = create_similar_mol_entries(
+                    pool, last_entry, config["num_similars"]
+                )
+                for prop_name, prop_spec in additional_properties.items():
+                    last_entry.add_props[prop_name] = prop_spec
+                optim_entries[i].last_entry = last_entry
+
+            prompts = [
+                optim_entry.to_prompt(
+                    is_generation=True, include_oracle_score=prev_train_iter != 0,
+                    config=config, max_score=max_score
+                )
+                for optim_entry in optim_entries
+            ]
+            output_texts = []
+            data = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
+            if type(model) == OPTForCausalLM:
+                del data["token_type_ids"]
+            for key, value in data.items():
+                data[key] = value[:, -2048 + config["generation_config"]["max_new_tokens"]:]
+            output = model.generate(
+                **data,
+                **config["generation_config"]
+            )
+            gc.collect()
+            torch.cuda.empty_cache()
+            output_texts.extend(tokenizer.batch_decode(output))
+
+            current_unique_optim_entries = {}
+            # with multiprocessing.Pool(processes=config["num_processes"]) as pol:
+            for i, molecule in enumerate(map(create_molecule_entry, output_texts)):
+                if molecule and not optim_entries[i].contains_entry(molecule):
+                    if molecule.smiles not in oracle.mol_buffer and molecule.smiles not in current_unique_optim_entries:
+                        molecule.similar_mol_entries = optim_entries[i].last_entry.similar_mol_entries
+                        for prop_name, prop_spec in additional_properties.items():
+                            molecule.add_props[prop_name] = prop_spec
+                            molecule.add_props[prop_name]["value"] = molecule.add_props[prop_name]["calculate_value"](molecule)
+                        optim_entries[i].last_entry = molecule
+                        current_unique_optim_entries[molecule.smiles] = optim_entries[i]
+
+            num_of_molecules_to_score = min(len(current_unique_optim_entries), config["num_gens_per_iter"] - len(iter_unique_optim_entries))
+            current_unique_smiles_list = list(current_unique_optim_entries.keys())[:num_of_molecules_to_score]
+            current_unique_optim_entries = {smiles: current_unique_optim_entries[smiles] for smiles in current_unique_smiles_list}
+
+            if getattr(oracle, "takes_entry", False):
+                oracle_scores = oracle([current_unique_optim_entries[smiles].last_entry for smiles in current_unique_smiles_list])
+            else:
+                oracle_scores = oracle(current_unique_smiles_list)
+
+            for smiles, oracle_score in zip(current_unique_smiles_list, oracle_scores):
+                current_unique_optim_entries[smiles].last_entry.score = oracle_score
+                iter_unique_optim_entries[smiles] = current_unique_optim_entries[smiles]
+                file.write(f"generated smiles: {smiles}, score: {current_unique_optim_entries[smiles].last_entry.score:.4f}\n")
+                if max_score >= config["max_possible_oracle_score"] - 1e-2 or current_unique_optim_entries[smiles].last_entry.score > max_score:
+                    max_score = max(max_score, current_unique_optim_entries[smiles].last_entry.score)
+                    new_best_molecule_generated = True
+
+            print(f"Iter unique optim entries: {len(iter_unique_optim_entries)}, oracle calls used: {len(oracle)}")
+
+            if oracle.finish:
+                break
+
+        if oracle.finish:
+            break
+        initial_num_iter = num_iter
+        num_iter = len(oracle.mol_buffer) // config["num_gens_per_iter"]
+        if num_iter > initial_num_iter:
+            tol_level += 1
+
+        if new_best_molecule_generated:
+            tol_level = 0
+
+        print(f"num_iter: {num_iter}, tol_level: {tol_level}, prev_train_iter: {prev_train_iter}")
+        if num_iter > initial_num_iter:
+            config["generation_config"]["temperature"] += config["num_gens_per_iter"] / (oracle.budget - config["num_gens_per_iter"]) * (config["generation_temperature"][1] - config["generation_temperature"][0])
+            print(f"Generation temperature: {config['generation_config']['temperature']}")
+
+        # diversity_score = 1 / (1 + math.log(1 + repeated_max_score) / math.log(10))
+        pool.add(list(iter_unique_optim_entries.values()))
+        file.write("Pool\n")
+        for i, optim_entry in enumerate(pool.optim_entries):
+            file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n")
+
+        if "rej-sample-v2" in config["strategy"]:
+            # round_entries.extend(current_entries)
+            # round_entries = list(np.unique(round_entries))[::-1]
+            # top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"])
+            # if top_k >= config["rej_sample_config"]["num_samples_per_round"]:
+            if tol_level >= config["rej_sample_config"]["train_tol_level"]:
+                train_entries, validation_entries = pool.get_train_valid_entries()
+                print(f"Num of training examples: {len(train_entries)}, num of validation examples: {len(validation_entries)}.")
+                file.write("Training entries\n")
+                for i, optim_entry in enumerate(train_entries):
+                    file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n")
+                file.write("Validation entries\n")
+                for i, optim_entry in enumerate(validation_entries):
+                    file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n")
+
+                train_dataset = Dataset.from_dict({
+                    "sample": [
+                        optim_entry.to_prompt(
+                            is_generation=False, include_oracle_score=True,
+                            config=config, max_score=config["max_possible_oracle_score"]
+                        )
+                        for optim_entry in train_entries
+                    ]
+                })
+                validation_dataset = Dataset.from_dict({
+                    "sample": [
+                        optim_entry.to_prompt(
+                            is_generation=False, include_oracle_score=True,
+                            config=config, max_score=config["max_possible_oracle_score"]
+                        )
+                        for optim_entry in validation_entries
+                    ]
+                })
+                # Dataset.shuffle returns a new dataset, so reassign
+                train_dataset = train_dataset.shuffle(seed=42)
+                validation_dataset = validation_dataset.shuffle(seed=42)
+
+                # early_stopping_callback = CustomEarlyStopCallback(
+                #     early_stopping_patience=1,
+                #     early_stopping_threshold=0.0001
+                # )
+                model_selection_callback = CustomModelSelectionCallback()
+
+                model.train()
+                trainer = SFTTrainer(
+                    model=model,
+                    train_dataset=train_dataset,
+                    eval_dataset=validation_dataset,
+                    formatting_func=lambda x: x["sample"],
+                    args=training_args,
+                    packing=config["rej_sample_config"]["packing"],
+                    tokenizer=tokenizer,
+                    max_seq_length=config["rej_sample_config"]["max_seq_length"],
+                    # data_collator=collator,
+                    callbacks=[model_selection_callback],
+                    optimizers=(optimizer, lr_scheduler),
+                )
+                trainer.train()
+                print(f"Loading the best model state dict with validation loss {model_selection_callback.best_validation_loss}")
+                model.load_state_dict(model_selection_callback.best_model_state_dict)
+                del model_selection_callback.best_model_state_dict
+                gc.collect()
+                torch.cuda.empty_cache()
+                tol_level = 0
+                prev_train_iter = num_iter
\ No newline at end of file
diff --git a/chemlactica/mol_opt/optimization_run_example.py b/chemlactica/mol_opt/optimization_run_example.py
deleted file mode 100644
index d295736..0000000
--- a/chemlactica/mol_opt/optimization_run_example.py
+++ /dev/null
@@ -1,37 +0,0 @@
-from transformers import AutoModelForCausalLM, AutoTokenizer
-import torch
-import yaml
-import datetime
-import argparse
-import os
-from utils import ConstrainedTPSAOracle
-from typing import List
-from chemlactica.mol_opt.optimization import optimize
-
-os.environ["TOKENIZERS_PARALLELISM"] = "true"
-
-
-def parse_arguments():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--run_name", type=str, required=False)
-    parser.add_argument("--output_dir", type=str, required=True)
-    parser.add_argument("--config_default", type=str, required=False, default="chemlactica/chemlactica_125m_hparams.yaml")
-    parser.add_argument("--n_runs", type=int, required=False, default=1)
-    args = parser.parse_args()
-    return args
-
-
-if __name__ == "__main__":
-    args = parse_arguments()
-    config = yaml.safe_load(open(args.config_default))
-
-    model = AutoModelForCausalLM.from_pretrained(config["checkpoint_path"], torch_dtype=torch.bfloat16).to(config["device"])
-    tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_path"], padding_side="left")
-
-    seeds = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31]
-    oracle = ConstrainedTPSAOracle(max_oracle_calls=5000)
-    config["log_dir"] = os.path.join(args.output_dir, "results_tpsa+weight+num_rungs.log")
-    optimize(
-        model, tokenizer,
-        oracle, config
-    )
\ No newline at end of file
diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py
index 4bf89ea..a602189 100644
--- a/chemlactica/mol_opt/utils.py
+++ b/chemlactica/mol_opt/utils.py
@@ -5,7 +5,6 @@ from pathlib import Path
 import numpy as np
 import torch
 
-from chemlactica.mol_opt.metrics import top_auc
 from rdkit import Chem, DataStructs, RDLogger
 from rdkit.Chem import AllChem, MACCSkeys, rdMolDescriptors
 
@@ -86,78 +85,6 @@ def __hash__(self):
         return hash(self.smiles)
 
 
-class ConstrainedTPSAOracle:
-    def __init__(self, max_oracle_calls: int):
-        self.max_oracle_calls = max_oracle_calls
-        self.freq_log = 100
-        self.mol_buffer = {}
-        self.max_possible_oracle_score = 1.0
-        self.takes_entry = True
-
-    def __call__(self, molecules):
-        oracle_scores = []
-        for molecule in molecules:
-            if self.mol_buffer.get(molecule.smiles):
-                oracle_scores.append(sum(self.mol_buffer[molecule.smiles][0]))
-            else:
-                try:
-                    tpsa = rdMolDescriptors.CalcTPSA(molecule.mol)
-                    tpsa_score = min(tpsa / 1000, 1)
-                    weight = rdMolDescriptors.CalcExactMolWt(molecule.mol)
-                    if weight <= 349:
-                        weight_score = 1
-                    elif weight >= 500:
-                        weight_score = 0
-                    else:
-                        weight_score = -0.00662 * weight + 3.31125
-
-                    num_rings = rdMolDescriptors.CalcNumRings(molecule.mol)
-                    if num_rings >= 2:
-                        num_rights_score = 1
-                    else:
-                        num_rights_score = 0
-                    # print(tpsa_score, weight_score, num_rights_score)
-                    oracle_score = (tpsa_score + weight_score + num_rights_score) / 3
-                except Exception as e:
-                    print(e)
-                    oracle_score = 0
-                self.mol_buffer[molecule.smiles] = [oracle_score, len(self.mol_buffer) + 1]
-                if len(self.mol_buffer) % 100 == 0:
-                    self.log_intermediate()
-                oracle_scores.append(oracle_score)
-        return oracle_scores
-
-    def log_intermediate(self):
-        scores = [v[0] for v in self.mol_buffer.values()]
-        scores_sorted = sorted(scores, reverse=True)[:100]
-        n_calls = len(self.mol_buffer)
-
-        score_avg_top1 = np.max(scores_sorted)
-        score_avg_top10 = np.mean(scores_sorted[:10])
-        score_avg_top100 = np.mean(scores_sorted)
-
-        print(f"{n_calls}/{self.max_oracle_calls} | ",
-              f"auc_top1: {top_auc(self.mol_buffer, 1, False, self.freq_log, self.max_oracle_calls)} | ",
-              f"auc_top10: {top_auc(self.mol_buffer, 10, False, self.freq_log, self.max_oracle_calls)} | ",
-              f"auc_top100: {top_auc(self.mol_buffer, 100, False, self.freq_log, self.max_oracle_calls)}")
-
-        print(f'avg_top1: {score_avg_top1:.3f} | '
-              f'avg_top10: {score_avg_top10:.3f} | '
-              f'avg_top100: {score_avg_top100:.3f}')
-
-    def __len__(self):
-        return len(self.mol_buffer)
-
-    @property
-    def budget(self):
-        return self.max_oracle_calls
-
-    @property
-    def finish(self):
-        return len(self.mol_buffer) >= self.max_oracle_calls
-
-
 class Pool:
     def __init__(self, size, validation_perc: float):
         self.size = size