diff --git a/acegen/data/__init__.py b/acegen/data/__init__.py
index afafbfa..f7a73d1 100644
--- a/acegen/data/__init__.py
+++ b/acegen/data/__init__.py
@@ -1,2 +1,2 @@
 from acegen.data.smiles_dataset import load_dataset, MolBloomDataset, SMILESDataset
-from acegen.data.utils import smiles_to_tensordict, collate_smiles_to_tensordict
+from acegen.data.utils import collate_smiles_to_tensordict, smiles_to_tensordict
diff --git a/acegen/data/utils.py b/acegen/data/utils.py
index 2414bbb..c1b7149 100644
--- a/acegen/data/utils.py
+++ b/acegen/data/utils.py
@@ -69,12 +69,16 @@ def smiles_to_tensordict(
     return smiles_tensordict
 
 
-def collate_smiles_to_tensordict(arr, max_length: int, reward: torch.Tensor = None, device: str = "cpu"):
+def collate_smiles_to_tensordict(
+    arr, max_length: int, reward: torch.Tensor = None, device: str = "cpu"
+):
     """Function to take a list of encoded sequences and turn them into a tensordict."""
     collated_arr = torch.ones(len(arr), max_length) * -1
     for i, seq in enumerate(arr):
         collated_arr[i, : seq.size(0)] = seq
-    data = smiles_to_tensordict(collated_arr, reward=reward, replace_mask_value=0, device=device)
+    data = smiles_to_tensordict(
+        collated_arr, reward=reward, replace_mask_value=0, device=device
+    )
     data.set("sequence", data.get("observation"))
     data.set("sequence_mask", data.get("mask"))
-    return data
\ No newline at end of file
+    return data
diff --git a/scripts/augmented_memory/augmem.py b/scripts/augmented_memory/augmem.py
index f1c0f05..bcdbe29 100644
--- a/scripts/augmented_memory/augmem.py
+++ b/scripts/augmented_memory/augmem.py
@@ -11,10 +11,10 @@ import torch
 import tqdm
 import yaml
 
+from acegen.data import collate_smiles_to_tensordict
 from acegen.models import adapt_state_dict, models, register_model
 from acegen.rl_env import generate_complete_smiles, TokenEnv
-from acegen.data import collate_smiles_to_tensordict
 from acegen.scoring_functions import (
     custom_scoring_functions,
     register_custom_scoring_function,
 )
@@ -182,9 +182,7 @@ def run_reinvent(cfg, task):
     actor_training = actor_training.to(device)
 
     prior, _ = create_actor(vocabulary_size=len(vocabulary))
-    prior.load_state_dict(
-        adapt_state_dict(deepcopy(ckpt), prior.state_dict())
-    )
+    prior.load_state_dict(adapt_state_dict(deepcopy(ckpt), prior.state_dict()))
     prior = prior.to(device)
 
     # Create RL environment
@@ -288,7 +286,7 @@ def create_env_fn():
         )
 
         data, loss, agent_likelihood = compute_loss(data, actor_training, prior, sigma)
-        
+
         # Average loss over the batch
         loss = loss.mean()
 
@@ -306,14 +304,26 @@ def create_env_fn():
             sampled_smiles = augment_smiles(data.get("SMILES").cpu().data)
             sampled_reward = data.get(("next", "reward")).squeeze(-1).sum(-1)
             # Sample replay buffer
-            replay_smiles, replay_reward = task.replay(cfg.replay_batch_size, augment=True)
+            replay_smiles, replay_reward = task.replay(
+                cfg.replay_batch_size, augment=True
+            )
             replay_reward = torch.tensor(replay_reward, device=device).float()
             # Concatenate and create tensor
-            aug_tokens = [torch.tensor(vocabulary.encode(smi)) for smi in sampled_smiles + replay_smiles]
+            aug_tokens = [
+                torch.tensor(vocabulary.encode(smi))
+                for smi in sampled_smiles + replay_smiles
+            ]
             aug_reward = torch.cat([sampled_reward, replay_reward], dim=0)
-            aug_data = collate_smiles_to_tensordict(arr=aug_tokens, max_length=env.max_length, reward=aug_reward, device=device)
+            aug_data = collate_smiles_to_tensordict(
+                arr=aug_tokens,
+                max_length=env.max_length,
+                reward=aug_reward,
+                device=device,
+            )
             # Compute loss
-            aug_data, loss, agent_likelihood = compute_loss(aug_data, actor_training, prior, sigma)
+            aug_data, loss, agent_likelihood = compute_loss(
+                aug_data, actor_training, prior, sigma
+            )
             # Average loss over the batch
             loss = loss.mean()
             # Add regularizer that penalizes high likelihood for the entire sequence
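
For reference, a minimal sketch of how the reformatted collate_smiles_to_tensordict helper can be exercised on its own, assuming acegen and torch are installed. The integer tokens and reward values below are arbitrary stand-ins for vocabulary-encoded SMILES and task scores, not output from a real vocabulary.encode call or scoring function.

    import torch

    from acegen.data import collate_smiles_to_tensordict

    # Stand-in token sequences of unequal length (a real run would use vocabulary.encode(smiles)).
    tokens = [
        torch.tensor([1, 5, 7, 2]),
        torch.tensor([1, 5, 9, 9, 7, 2]),
    ]
    # One reward per sequence, matching how aug_reward is assembled in augmem.py.
    reward = torch.tensor([0.3, 0.8])

    # Pads every sequence up to max_length, builds the underlying tensordict via
    # smiles_to_tensordict, and exposes the padded batch under "sequence" / "sequence_mask".
    data = collate_smiles_to_tensordict(arr=tokens, max_length=10, reward=reward, device="cpu")
    print(data.get("sequence").shape)
    print(data.get("sequence_mask").shape)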