pre-commit formatting
MorganCThomas committed Aug 13, 2024
1 parent df09f26 commit 967cbfb
Showing 3 changed files with 27 additions and 13 deletions.
acegen/data/__init__.py (1 addition, 1 deletion)
@@ -1,2 +1,2 @@
 from acegen.data.smiles_dataset import load_dataset, MolBloomDataset, SMILESDataset
-from acegen.data.utils import smiles_to_tensordict, collate_smiles_to_tensordict
+from acegen.data.utils import collate_smiles_to_tensordict, smiles_to_tensordict
acegen/data/utils.py (7 additions, 3 deletions)
@@ -69,12 +69,16 @@ def smiles_to_tensordict(
     return smiles_tensordict


-def collate_smiles_to_tensordict(arr, max_length: int, reward: torch.Tensor = None, device: str = "cpu"):
+def collate_smiles_to_tensordict(
+    arr, max_length: int, reward: torch.Tensor = None, device: str = "cpu"
+):
     """Function to take a list of encoded sequences and turn them into a tensordict."""
     collated_arr = torch.ones(len(arr), max_length) * -1
     for i, seq in enumerate(arr):
         collated_arr[i, : seq.size(0)] = seq
-    data = smiles_to_tensordict(collated_arr, reward=reward, replace_mask_value=0, device=device)
+    data = smiles_to_tensordict(
+        collated_arr, reward=reward, replace_mask_value=0, device=device
+    )
     data.set("sequence", data.get("observation"))
     data.set("sequence_mask", data.get("mask"))
-    return data
+    return data
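
For reference, a minimal usage sketch of the reformatted helper (the token tensors, lengths, and rewards are invented for illustration; the behavior described follows the function body above):

    import torch

    from acegen.data import collate_smiles_to_tensordict

    # Two made-up encoded SMILES sequences of different lengths.
    tokens = [torch.tensor([1, 5, 9, 2]), torch.tensor([1, 7, 2])]
    reward = torch.tensor([0.8, 0.1])  # one terminal reward per sequence

    # Each sequence is right-padded to max_length with -1, then wrapped in a
    # tensordict whose "sequence"/"sequence_mask" entries alias "observation"/"mask".
    data = collate_smiles_to_tensordict(arr=tokens, max_length=10, reward=reward, device="cpu")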
scripts/augmented_memory/augmem.py (19 additions, 9 deletions)
@@ -11,10 +11,10 @@
 import torch
 import tqdm
 import yaml
+from acegen.data import collate_smiles_to_tensordict

 from acegen.models import adapt_state_dict, models, register_model
 from acegen.rl_env import generate_complete_smiles, TokenEnv
-from acegen.data import collate_smiles_to_tensordict
 from acegen.scoring_functions import (
     custom_scoring_functions,
     register_custom_scoring_function,
@@ -182,9 +182,7 @@ def run_reinvent(cfg, task):
     actor_training = actor_training.to(device)

     prior, _ = create_actor(vocabulary_size=len(vocabulary))
-    prior.load_state_dict(
-        adapt_state_dict(deepcopy(ckpt), prior.state_dict())
-    )
+    prior.load_state_dict(adapt_state_dict(deepcopy(ckpt), prior.state_dict()))
     prior = prior.to(device)

     # Create RL environment
@@ -288,7 +286,7 @@ def create_env_fn():
         )

         data, loss, agent_likelihood = compute_loss(data, actor_training, prior, sigma)
-        
+
         # Average loss over the batch
         loss = loss.mean()
@@ -306,14 +304,26 @@ def create_env_fn():
         sampled_smiles = augment_smiles(data.get("SMILES").cpu().data)
         sampled_reward = data.get(("next", "reward")).squeeze(-1).sum(-1)
         # Sample replay buffer
-        replay_smiles, replay_reward = task.replay(cfg.replay_batch_size, augment=True)
+        replay_smiles, replay_reward = task.replay(
+            cfg.replay_batch_size, augment=True
+        )
         replay_reward = torch.tensor(replay_reward, device=device).float()
         # Concatenate and create tensor
-        aug_tokens = [torch.tensor(vocabulary.encode(smi)) for smi in sampled_smiles + replay_smiles]
+        aug_tokens = [
+            torch.tensor(vocabulary.encode(smi))
+            for smi in sampled_smiles + replay_smiles
+        ]
         aug_reward = torch.cat([sampled_reward, replay_reward], dim=0)
-        aug_data = collate_smiles_to_tensordict(arr=aug_tokens, max_length=env.max_length, reward=aug_reward, device=device)
+        aug_data = collate_smiles_to_tensordict(
+            arr=aug_tokens,
+            max_length=env.max_length,
+            reward=aug_reward,
+            device=device,
+        )
         # Compute loss
-        aug_data, loss, agent_likelihood = compute_loss(aug_data, actor_training, prior, sigma)
+        aug_data, loss, agent_likelihood = compute_loss(
+            aug_data, actor_training, prior, sigma
+        )
         # Average loss over the batch
         loss = loss.mean()
         # Add regularizer that penalizes high likelihood for the entire sequence
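
The hunk ends at the regularizer comment, so the penalty itself is not shown in this diff. In REINVENT-style agents it is typically an inverse-likelihood term; a sketch under that assumption (the 5e3 weight is the value commonly used in REINVENT implementations, not confirmed here):

    # Assumed REINVENT-style regularizer (not shown in this diff): penalize
    # sequences the agent already assigns very high likelihood, discouraging
    # collapse onto a few high-reward SMILES.
    loss_p = -(1 / agent_likelihood).mean()
    loss = loss + 5e3 * loss_p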