Skip to content

Commit

Permalink
Merge pull request #35 from YerevaNN/mol_opt
Browse files Browse the repository at this point in the history
add optim instructions to README.md
  • Loading branch information
tigranfah authored Jul 27, 2024
2 parents 34d3f57 + a03e50b commit dfc6d3f
Show file tree
Hide file tree
Showing 7 changed files with 433 additions and 117 deletions.
77 changes: 76 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,82 @@ Instructions coming soon...
Instructions coming soon...

### Molecular optimization
Instructions coming soon...
In order to run the optimization algorithm, define the Oracle as a class, that is responsible for calculating the Oracle fuction for given molecules.

```python
# Oracle implementation scheme

class ExampleOracle:
def __init__(self, ...):
# maximum number of oracle calls to make
self.max_oracle_calls: int = ...

# the frequence with which to log
self.freq_log: int = ...

# the buffer to keep track of all unique molecules generated
self.mol_buffer: Dict = ...

# the maximum possible oracle score or an upper bound
self.max_possible_oracle_score: float = ...

def __call__(self, molecules):
"""
Evaluate and return the oracle scores for molecules. Log the intermediate results if necessary.
"""
...
return oracle_scores

@property
def finish(self):
"""
Specify the stopping condition for the optimization process.
"""
return stopping_condition
```

Define configuration and hyperparameters used for the optimization process in a yaml file.

```yaml
# yaml config scheme

checkpoint_path: /path/to/model_dir
tokenizer_path: /path/to/tokenizer_dir

... optimization algorithm hyperparameter (pool size, number of similar molecules to use, etc.) ...

generation_config:
... molecule generation hyperparameters ...

strategy: [rej-sample-v2] # or use [default] for not performing the fine-tuning step during the optimization.

rej_sample_config:
... fine tuning hyperparameters ...
```
Putting everything toghether and running the optimization process.
```python
from chemlactica.mol_opt.optimization import optimize

# Load config
config = yaml.safe_load(open(path_to_yaml_config))

# Load the model and the tokenizer
model = AutoModelForCausalLM.from_pretrained(...)
tokenizer = AutoTokenizer.from_pretrained(...)

# Create Oracle
oracle = ExampleOracle(...)

# Call the optimize function to optimize against the defined oracle
optimize(
model, tokenizer,
oracle, config
)
```

[example_run.py]() illustrates a full working example of an optimization run. For more complex examples refer to the [ChemlacticaTestSuit]() repository [mol_opt/run.py]() and [retmol/run_qed.py]() files.

## Tests
The test for running the a small sized model with the same
Expand Down
4 changes: 2 additions & 2 deletions chemlactica/mol_opt/chemlactica_125m_hparams.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
checkpoint_path: /nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/1f289ff103034364bd27e1c3/checkpoint-18000
tokenizer_path: ChemLactica/chemlactica/tokenizer/ChemLacticaTokenizer66
checkpoint_path: /path/to/model_dir
tokenizer_path: /path/to/tokenizer_dir
pool_size: 10
validation_perc: 0.2
num_mols: 0
Expand Down
113 changes: 113 additions & 0 deletions chemlactica/mol_opt/example_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from typing import List
import yaml
import argparse
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import numpy as np
from rdkit.Chem import rdMolDescriptors
from chemlactica.mol_opt.optimization import optimize
from chemlactica.mol_opt.utils import set_seed, MoleculeEntry


class TPSA_Weight_Oracle:
def __init__(self, max_oracle_calls: int):
# maximum number of oracle calls to make
self.max_oracle_calls = max_oracle_calls

# the frequence with which to log
self.freq_log = 100

# the buffer to keep track of all unique molecules generated
self.mol_buffer = {}

# the maximum possible oracle score or an upper bound
self.max_possible_oracle_score = 1.0

# if True the __call__ function takes list of MoleculeEntry objects
# if False (or unspecified) the __call__ function takes list of SMILES strings
self.takes_entry = True

def __call__(self, molecules: List[MoleculeEntry]):
"""
Evaluate and return the oracle scores for molecules. Log the intermediate results if necessary.
"""
oracle_scores = []
for molecule in molecules:
if self.mol_buffer.get(molecule.smiles):
oracle_scores.append(sum(self.mol_buffer[molecule.smiles][0]))
else:
try:
tpsa = rdMolDescriptors.CalcTPSA(molecule.mol)
tpsa_score = min(tpsa / 1000, 1)
weight = rdMolDescriptors.CalcExactMolWt(molecule.mol)
if weight <= 349:
weight_score = 1
elif weight >= 500:
weight_score = 0
else:
weight_score = -0.00662 * weight + 3.31125

oracle_score = (tpsa_score + weight_score) / 3
except Exception as e:
print(e)
oracle_score = 0
self.mol_buffer[molecule.smiles] = [oracle_score, len(self.mol_buffer) + 1]
if len(self.mol_buffer) % 100 == 0:
self.log_intermediate()
oracle_scores.append(oracle_score)
return oracle_scores

def log_intermediate(self):
scores = [v[0] for v in self.mol_buffer.values()][-self.max_oracle_calls:]
scores_sorted = sorted(scores, reverse=True)[:100]
n_calls = len(self.mol_buffer)

score_avg_top1 = np.max(scores_sorted)
score_avg_top10 = np.mean(scores_sorted[:10])
score_avg_top100 = np.mean(scores_sorted)

print(f"{n_calls}/{self.max_oracle_calls} | ",
f'avg_top1: {score_avg_top1:.3f} | '
f'avg_top10: {score_avg_top10:.3f} | '
f'avg_top100: {score_avg_top100:.3f}')

def __len__(self):
return len(self.mol_buffer)

@property
def budget(self):
return self.max_oracle_calls

@property
def finish(self):
# the stopping condition for the optimization process
return len(self.mol_buffer) >= self.max_oracle_calls


def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument("--config_default", type=str, required=True)
parser.add_argument("--n_runs", type=int, required=False, default=1)
args = parser.parse_args()
return args


if __name__ == "__main__":
args = parse_arguments()
config = yaml.safe_load(open(args.config_default))

model = AutoModelForCausalLM.from_pretrained(config["checkpoint_path"], torch_dtype=torch.bfloat16).to(config["device"])
tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_path"], padding_side="left")

seeds = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31]
for i in range(args.n_runs):
set_seed(seeds[i])
oracle = TPSA_Weight_Oracle(max_oracle_calls=1000)
config["log_dir"] = os.path.join(args.output_dir, "results_tpsa+weight+num_rungs.log")
config["max_possible_oracle_score"] = oracle.max_possible_oracle_score
optimize(
model, tokenizer,
oracle, config
)
8 changes: 4 additions & 4 deletions chemlactica/mol_opt/hparams_tune.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ parameters:
num_gens_per_iter: [200, 400, 600]
generation_temperature: [[1.0, 1.0], [1.5, 1.5], [1.0, 1.5]]

# rej_sample_config:
# num_train_epochs: [1, 3, 5, 7, 9]
# train_tol_level: [1, 3, 5, 7, 9]
# max_learning_rate: [0.0001, 0.00001, 0.000001]
rej_sample_config:
num_train_epochs: [1, 3, 5, 7, 9]
train_tol_level: [1, 3, 5, 7, 9]
max_learning_rate: [0.0001, 0.00001, 0.000001]
Loading

0 comments on commit dfc6d3f

Please sign in to comment.