AssertionError: Values of the node_space should be instances of Box or Discrete, got <class 'int'> #1152

Open
johnnytam100 opened this issue Sep 2, 2024 · 7 comments
Labels
question Further information is requested


@johnnytam100

johnnytam100 commented Sep 2, 2024

Question

I want to add a graph to the observation space as follows:

  # Create NetworkX graph
  G = nx.Graph()

  # Add nodes (C-alpha atoms) with 320-dimensional zero embeddings
  for _, row in ca_atoms.iterrows():
      G.add_node(row['residue_number'], feat=np.zeros(320)) 

  # Add edges based on distance threshold
  for i, row_i in ca_atoms.iterrows():
      for j, row_j in ca_atoms.iterrows():
          if i != j:
              distance = np.sqrt(
                  (row_i['x_coord'] - row_j['x_coord'])**2 +
                  (row_i['y_coord'] - row_j['y_coord'])**2 +
                  (row_i['z_coord'] - row_j['z_coord'])**2
              )
              if distance <= 4.0:  # 4 Angstrom threshold
                  G.add_edge(row_i['residue_number'], row_j['residue_number'])

  # Convert NetworkX graph to Gymnasium spaces.Graph 
  num_nodes = G.number_of_nodes()
  node_features = np.zeros((num_nodes, 320))  # All zeros for now
  for i, node in enumerate(G.nodes()):
      node_features[i] = G.nodes[node]['feat']
  node_features = torch.tensor(node_features, dtype=torch.float32)

  edge_indices = torch.tensor(list(G.edges())).t().contiguous()

  obs_graph = spaces.Graph(num_nodes, node_features, edge_indices)

and I'm getting this error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-25-20fd2ce85efa> in <cell line: 271>()
    269 # Create environment
    270 env = ProteinOptimizationEnv()
--> 271 check_env(env)
    272 
    273 # Create and train agent

/usr/local/lib/python3.10/dist-packages/stable_baselines3/common/env_checker.py in check_env(env, warn, skip_render_check)
    431 
    432     try:
--> 433         env.reset(seed=0)
    434     except TypeError as e:
    435         raise TypeError("The reset() method must accept a `seed` parameter") from e

<ipython-input-25-20fd2ce85efa> in reset(self, seed)
     98         )
     99         self.current_pdb = os.path.join(self.output_folder, 'initial.pdb')
--> 100         return self._get_obs(), {}
    101 
    102 

<ipython-input-25-20fd2ce85efa> in _get_obs(self)
    263         edge_indices = torch.tensor(list(G.edges())).t().contiguous()
    264 
--> 265         obs_graph = spaces.Graph(num_nodes, node_features, edge_indices)
    266 
    267         return {'score': np.array([score], dtype=np.float32), 'graph': obs_graph}

/usr/local/lib/python3.10/dist-packages/gymnasium/spaces/graph.py in __init__(self, node_space, edge_space, seed)
     72             seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
     73         """
---> 74         assert isinstance(
     75             node_space, (Box, Discrete)
     76         ), f"Values of the node_space should be instances of Box or Discrete, got {type(node_space)}"

AssertionError: Values of the node_space should be instances of Box or Discrete, got <class 'int'>

Any idea what's wrong with my approach?

johnnytam100 added the question (Further information is requested) label on Sep 2, 2024
@pseudo-rnd-thoughts
Member

Hi, I'm glad that someone is using the Graph space.
In short, the Graph node_space and edge_space arguments expect a space describing the structure of each individual node and each individual edge, rather than the number of nodes or the feature data itself.
https://gymnasium.farama.org/api/spaces/composite/#graph

Therefore, this might be simpler than you imagined. Your node space is the range of possible values of a single node; if I understand correctly, each node starts as np.zeros(320), so something like Box(low=np.zeros(320), high=...). The edge seems to be a binary value, so Discrete(2) should work.

@johnnytam100
Author

Good to know you're glad the Graph space is being used!
However, I'm not sure I understood your response correctly, so here is the whole code I am using:

#@title RL (trial 11)

! rm -rf output_*

! pip install stable-baselines3
! pip install biopandas
! pip install gymnasium
! pip install networkx

import gymnasium as gym
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from biopandas.pdb import PandasPdb  # used in _get_obs()
import networkx as nx  # used in _get_obs()
import subprocess
import numpy as np
import time
import os
import torch

# Custom Environment
class ProteinOptimizationEnv(gym.Env):
    def __init__(self):
        # Define action and observation spaces
        self.action_space = spaces.Box(low=np.array([0, 0, 0, 0]), high=np.array([1, 1, 1, 1]), dtype=np.float32)

        # Define the node and edge spaces
        self.node_space = spaces.Box(low=np.zeros(320), high=np.inf * np.ones(320), dtype=np.float32)  # 320-dim node features, all zeros initially
        self.edge_space = spaces.Discrete(1)  # Single edge type for now (distance-based)

        self.observation_space = spaces.Dict({
            'score': spaces.Box(low=0, high=100, shape=(1,), dtype=np.float32),
            'graph': spaces.Graph(node_space=self.node_space, edge_space=self.edge_space)
        })

        '''
        # Conda environment names (uncomment and replace if needed)
        self.perturbation_env = "perturbation_env"
        self.scorer_env = "scorer_env"
        '''

        # Fixed parameters
        self.fixed_perturbation_param = 0.5
        self.fixed_scorer_param = 0.5

        # Parameter ranges
        self.perturbation_param_min = [0, 0]
        self.perturbation_param_max = [1, 1]
        self.scorer_param_min = [0.9, 0.9]
        self.scorer_param_max = [1, 1]

        # Initialize protein structure and parameters
        self.perturbation_params = [0.5, 0.5]  # Initial values for P1 and P2
        self.scorer_params = [0.5, 0.5]  # Initial values for S1 and S2

        # Initialize time_taken, cycle_count, and outer_cycle_count
        self.time_taken = 0
        self.cycle_count = -1
        self.outer_cycle_count = 0
        self.outer_cycle_start_time = time.time()

        # Generate initial pdb and save it in the first output folder
        self.output_folder = f"output_{self.outer_cycle_count}"
        os.makedirs(self.output_folder, exist_ok=True)
        subprocess.run(
            # f"conda activate {self.perturbation_env} && "  # Uncomment and replace if needed
            f"python perturbation.py --P1 {self.perturbation_params[0]} --P2 {self.perturbation_params[1]} "
            f"--P3 {self.fixed_perturbation_param} --P4 initial.pdb --P5 {os.path.join(self.output_folder, 'initial.pdb')}",
            shell=True,
            check=True
        )
        self.current_pdb = os.path.join(self.output_folder, 'initial.pdb')

        # Flag to track if CUDA device has been switched
        self.cuda_device_switched = False

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)  # Gymnasium reset signature; seeds self.np_random
        # ... (reset other parameters if needed)

        # Reset time_taken
        self.time_taken = 0

        # Create output folder for this outer cycle only if it's the first cycle
        # or if the target score was reached in the previous episode
        if self.cycle_count == 0 or (hasattr(self, 'done') and self.done):
            self.output_folder = f"output_{self.outer_cycle_count}"
            os.makedirs(self.output_folder, exist_ok=True)
            self.cuda_device_switched = False

        # Regenerate initial pdb and save it
        subprocess.run(
            # f"conda activate {self.perturbation_env} && "  # Uncomment and replace if needed
            f"python perturbation.py --P1 {self.perturbation_params[0]} --P2 {self.perturbation_params[1]} "
            f"--P3 {self.fixed_perturbation_param} --P4 initial.pdb --P5 {os.path.join(self.output_folder, 'initial.pdb')}",
            shell=True,
            check=True
        )
        self.current_pdb = os.path.join(self.output_folder, 'initial.pdb')
        return self._get_obs(), {}


    def step(self, action):
        start_time = time.time()

        # Check if CUDA device has been switched
        current_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if current_device.type == "cuda" and not self.cuda_device_switched:
            self.cuda_device_switched = True
            print("Using cuda device")

        # Increment cycle_count at the beginning of the step
        self.cycle_count += 1

        # Update perturbation and scorer parameters based on action
        self.perturbation_params = [
            action[0] * (self.perturbation_param_max[0] - self.perturbation_param_min[0]) + self.perturbation_param_min[0],
            action[1] * (self.perturbation_param_max[1] - self.perturbation_param_min[1]) + self.perturbation_param_min[1]
        ]
        self.scorer_params = [
            action[2] * (self.scorer_param_max[0] - self.scorer_param_min[0]) + self.scorer_param_min[0],
            action[3] * (self.scorer_param_max[1] - self.scorer_param_min[1]) + self.scorer_param_min[1]
        ]

        # Perturb the protein structure and save it
        #try:
        subprocess.run(
            # f"conda activate {self.perturbation_env} && "  # Uncomment and replace if needed
            f"python perturbation.py --P1 {self.perturbation_params[0]} --P2 {self.perturbation_params[1]} "
            f"--P3 {self.fixed_perturbation_param} --P4 {self.current_pdb} --P5 {os.path.join(self.output_folder, f'next_{1 if self.cycle_count == 0 else self.cycle_count}.pdb')}",
            shell=True,
            check=True
        )
        self.current_pdb = os.path.join(self.output_folder, f'next_{1 if self.cycle_count == 0 else self.cycle_count}.pdb')

        # Get score
        obs = self._get_obs()
        # The observation is a dict, so read the score from its 'score' entry
        if obs is None or not np.isscalar(obs['score'][0]) or not np.isfinite(obs['score'][0]):
            print("Error: _get_obs() returned an invalid score")
            done = True
            reward = -1.0  # Penalty for invalid score
            truncated = False
            info = {"error": "scoring"}
        else:
            score = float(obs['score'][0])

            # Check if done
            done = score >= 90
            done = bool(done)

            end_time = time.time()
            time_taken_this_step = end_time - start_time
            time_taken_so_far = end_time - self.outer_cycle_start_time

            # Increment time_taken
            self.time_taken += time_taken_this_step

            # Calculate score increment
            score_increment = score - self.prev_score if hasattr(self, 'prev_score') else 0
            self.prev_score = score  # Store current score for the next step

            # Calculate reward as score increment divided by time taken
            if score_increment > 0:
                reward = score_increment / time_taken_this_step # if score increased, shorter time -> higher reward
            else:
                reward = score_increment * time_taken_this_step # if score decreased, longer time -> higher penalty

            # Print verbose output (using cycle_count)
            print(f"Cycle: {self.cycle_count}, "
                  f"Perturbation Params: {self.perturbation_params}, "
                  f"Scorer Params: {self.scorer_params}, "
                  f"Score: {score}, "
                  f"Time (this step): {time_taken_this_step:.2f} seconds, "
                  f"Time (total): {time_taken_so_far:.2f} seconds, ")

            # Calculate reward using exponential decay
            if done:
                print("Finished!")

                # Calculate and print outer cycle time
                outer_cycle_time = time.time() - self.outer_cycle_start_time
                print(f"Outer Cycle Time: {outer_cycle_time:.2f} seconds")

                # Increment outer_cycle_count, reset outer_cycle_start_time, and reset cycle_count
                self.outer_cycle_count += 1
                self.outer_cycle_start_time = time.time()
                self.cycle_count = 0

            truncated = False
            info = {}

        '''
        except subprocess.CalledProcessError as e:
            print(f"Error perturbing protein: {e}")
            done = True
            reward = -1.0  # Penalty for perturbation error
            truncated = False
            info = {"error": "perturbation"}

        except Exception as e:
            print(f"Unexpected error in step: {e}")
            done = True
            reward = -1.0  # Penalty for unexpected error
            truncated = False
            info = {"error": "unexpected"}
        '''

        return self._get_obs(), reward, done, truncated, info


    def _get_obs(self):
        # Get score using scorer command and current_pdb
        #try:
        scoring_result = subprocess.run(
            # f"conda activate {self.scorer_env} && "
            f"python scorer.py --S1 {self.scorer_params[0]} --S2 {self.scorer_params[1]} "
            f"--S3 {self.fixed_scorer_param} --S4 {self.current_pdb}",
            shell=True,
            capture_output=True,
            text=True,
            check=True
        )
        score = float(scoring_result.stdout)
        '''
        except (subprocess.CalledProcessError, ValueError) as e:
            print(f"Error scoring protein: {e}")
            score = 0.0
        '''

        # Generate graph representation from current_pdb using NetworkX
        ppdb = PandasPdb().read_pdb(self.current_pdb)
        df = ppdb.df['ATOM']
        
        # Filter for C-alpha atoms
        ca_atoms = df[df['atom_name'] == 'CA']

        # Create NetworkX graph
        G = nx.Graph()

        # Add nodes (C-alpha atoms) with 320-dimensional zero embeddings
        for _, row in ca_atoms.iterrows():
            G.add_node(row['residue_number'], feat=np.zeros(320)) 

        # Add edges based on distance threshold
        for i, row_i in ca_atoms.iterrows():
            for j, row_j in ca_atoms.iterrows():
                if i != j:
                    distance = np.sqrt(
                        (row_i['x_coord'] - row_j['x_coord'])**2 +
                        (row_i['y_coord'] - row_j['y_coord'])**2 +
                        (row_i['z_coord'] - row_j['z_coord'])**2
                    )
                    if distance <= 4.0:  # 4 Angstrom threshold
                        G.add_edge(row_i['residue_number'], row_j['residue_number'])

        # Convert NetworkX graph to Gymnasium spaces.Graph 
        num_nodes = G.number_of_nodes()
        node_features = np.zeros((num_nodes, 320))  # All zeros for now
        for i, node in enumerate(G.nodes()):
            node_features[i] = G.nodes[node]['feat']
        node_features = torch.tensor(node_features, dtype=torch.float32)

        edge_indices = torch.tensor(list(G.edges())).t().contiguous()

        obs_graph = spaces.Graph(self.node_space, self.edge_space, edge_indices, node_features)
        obs_graph.ndata['feat'] = node_features

        return {'score': np.array([score], dtype=np.float32), 'graph': obs_graph}

# Create environment
env = ProteinOptimizationEnv()
check_env(env)

# Create and train agent
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000)

# Evaluate or use the trained agent
obs, _ = env.reset()  # Gymnasium reset() returns (obs, info)
while True:
    action, _states = model.predict(obs)
    # Gymnasium step() returns five values
    obs, reward, terminated, truncated, info = env.step(action)
    # env.render() omitted: this env does not implement render()

    if terminated or truncated:
        # `seed` is never defined in this script, so a fixed seed is used here
        obs, _ = env.reset(seed=0)
        break

Unfortunately, now another error appears:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/stable_baselines3/common/env_checker.py in check_env(env, warn, skip_render_check)
    432     try:
--> 433         env.reset(seed=0)
    434     except TypeError as e:

<ipython-input-34-9489b1de1055> in reset(self, seed)
     97         self.current_pdb = os.path.join(self.output_folder, 'initial.pdb')
---> 98         return self._get_obs(), {}
     99 

<ipython-input-34-9489b1de1055> in _get_obs(self)
    262 
--> 263         obs_graph = spaces.Graph(self.node_space, self.edge_space, edge_indices, node_features)
    264         obs_graph.ndata['feat'] = node_features

TypeError: Graph.__init__() takes from 3 to 4 positional arguments but 5 were given

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
<ipython-input-34-9489b1de1055> in <cell line: 270>()
    268 # Create environment
    269 env = ProteinOptimizationEnv()
--> 270 check_env(env)
    271 
    272 # Create and train agent

/usr/local/lib/python3.10/dist-packages/stable_baselines3/common/env_checker.py in check_env(env, warn, skip_render_check)
    433         env.reset(seed=0)
    434     except TypeError as e:
--> 435         raise TypeError("The reset() method must accept a `seed` parameter") from e
    436 
    437     # Warn the user if needed.

TypeError: The reset() method must accept a `seed` parameter

@pseudo-rnd-thoughts
Member

I can't run the script because you haven't provided the perturbation script.

@pseudo-rnd-thoughts
Member

The error was actually

TypeError: Graph.__init__() takes from 3 to 4 positional arguments but 5 were given

raised by obs_graph = spaces.Graph(self.node_space, self.edge_space, edge_indices, node_features) in _get_obs(). The "reset() method must accept a `seed` parameter" message is only check_env re-raising that TypeError.

You shouldn't be using a spaces.Graph to generate an observation; use a GraphInstance object instead.

@johnnytam100
Copy link
Author

Um... I don't think I know exactly how to use this GraphInstance object.
Would you mind giving a very brief example?

@pseudo-rnd-thoughts
Member

pseudo-rnd-thoughts commented Sep 12, 2024

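It simply acts as a tuple with three elements:

class GraphInstance(NamedTuple):

You can include the nodes, edges and edge links. The edge links are an integer array of shape (num_edges, 2), where each row holds the indices of the two nodes that the edge connects, and the edges contain the per-edge values (e.g. the weights).

To make one, call GraphInstance(nodes, edges, edge_links).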
