Add HPT model and corresponding examples. #839

Closed
wants to merge 2 commits into from
3 changes: 3 additions & 0 deletions ding/model/template/__init__.py
@@ -24,8 +24,11 @@
from .vae import VanillaVAE
from .decision_transformer import DecisionTransformer
from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS
from .hpt import HPT

from .bcq import BCQ
from .edac import EDAC
from .qgpo import QGPO
from .ebm import EBM, AutoregressiveEBM
from .havac import HAVAC

152 changes: 152 additions & 0 deletions ding/model/template/hpt.py
@@ -0,0 +1,152 @@
from typing import Union, Optional, Dict, Callable, List
from einops import rearrange, repeat
import torch
import torch.nn as nn
from ding.model.common.head import DuelingHead
from ding.utils.registry_factory import MODEL_REGISTRY


INIT_CONST = 0.02

@MODEL_REGISTRY.register('hpt')
class HPT(nn.Module):
Review comment (Member): add the paper link and the original GitHub repo link to credit the referenced work.

    def __init__(self, state_dim: int, action_dim: int):
        super(HPT, self).__init__()
        # Initialise the Policy Stem; the observation feature dimension comes from state_dim.
        self.policy_stem = PolicyStem(feature_dim=state_dim)
        self.policy_stem.init_cross_attn()

        # Dueling Head: input is the flattened 16 x 128 = 2048 latent features, output is the action dimension.
        self.head = DuelingHead(hidden_size=16 * 128, output_size=action_dim)

    def forward(self, x):
# Policy Stem Outputs [B, 16, 128]
tokens = self.policy_stem.compute_latent(x)
# Flatten Operation
tokens_flattened = tokens.view(tokens.size(0), -1) # [B, 16*128]
# Enter to Dueling Head
q_values = self.head(tokens_flattened)
return q_values
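
    # Dataflow sketch for a LunarLander-style input (illustrative only, not part of this diff):
    #   x: [B, 8] -> PolicyStem.compute_latent -> [B, 16, 128]
    #   -> flatten -> [B, 2048] -> DuelingHead -> expected {'logit': [B, action_dim]}
    # DuelingHead follows DI-engine's head convention of returning a dict with a 'logit' key.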



class PolicyStem(nn.Module):
"""policy stem
Overview:
The reference uses PolicyStem from
<https://github.com/liruiw/HPT/blob/main/hpt/models/policy_stem.py>
"""
def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs):
super().__init__()
# Initialise the feature extraction module
self.feature_extractor = nn.Linear(feature_dim, token_dim)
# Initialise CrossAttention
self.init_cross_attn()

def init_cross_attn(self):
"""Initialize cross attention module and learnable tokens."""
token_num = 16
self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * INIT_CONST)
self.cross_attention = CrossAttention(128, heads=8, dim_head=64, dropout=0.1)

def compute_latent(self, x: torch.Tensor) -> torch.Tensor:
"""
Compute latent representations of input data using attention.
Args:
x (torch.Tensor): Input tensor with shape [B, T, ..., F].
Returns:
torch.Tensor: Latent tokens, shape [B, 16, 128].
"""
# Using the Feature Extractor
stem_feat = self.feature_extractor(x)
stem_feat = stem_feat.reshape(stem_feat.shape[0], -1, stem_feat.shape[-1]) # (B, N, 128)
# Calculating latent tokens using CrossAttention
stem_tokens = self.tokens.repeat(len(stem_feat), 1, 1) # (B, 16, 128)
stem_tokens = self.cross_attention(stem_tokens, stem_feat) # (B, 16, 128)
        return stem_tokens

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass to compute latent tokens.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Latent tokens tensor.
"""
return self.compute_latent(x)

def freeze(self):
for param in self.parameters():
param.requires_grad = False

def unfreeze(self):
for param in self.parameters():
param.requires_grad = True

    def save(self, path: str):
torch.save(self.state_dict(), path)

@property
def device(self):
return next(self.parameters()).device
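
    # Shape walkthrough for compute_latent (illustrative sketch, not part of this diff),
    # using the default feature_dim=8 and token_dim=128:
    #   x: [B, 8]  -> feature_extractor (Linear 8 -> 128) -> [B, 128]
    #              -> reshape -> [B, 1, 128]  (sequence inputs [B, T, ..., 8] give [B, N, 128])
    #   learnable tokens repeated to [B, 16, 128]
    #   cross_attention(tokens, features) -> [B, 16, 128]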

class CrossAttention(nn.Module):
"""
CrossAttention module used in the Perceiver IO model.
Args:
query_dim (int): The dimension of the query input.
heads (int, optional): The number of attention heads. Defaults to 8.
dim_head (int, optional): The dimension of each attention head. Defaults to 64.
dropout (float, optional): The dropout probability. Defaults to 0.0.
"""

def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0):
super().__init__()
inner_dim = dim_head * heads
context_dim = query_dim
self.scale = dim_head**-0.5
self.heads = heads

self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, query_dim)

self.dropout = nn.Dropout(dropout)

def forward(self, x: torch.Tensor, context: torch.Tensor,
mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Forward pass of the CrossAttention module.
Args:
x (torch.Tensor): The query input tensor.
context (torch.Tensor): The context input tensor.
mask (torch.Tensor, optional): The attention mask tensor. Defaults to None.
Returns:
torch.Tensor: The output tensor.
"""
h = self.heads
q = self.to_q(x)
k, v = self.to_kv(context).chunk(2, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

if mask is not None:
            # fill masked-out positions with a large negative value before softmax
mask = rearrange(mask, "b ... -> b (...)")
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, "b j -> (b h) () j", h=h)
sim.masked_fill_(~mask, max_neg_value)

# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)

# dropout
attn = self.dropout(attn)
out = torch.einsum("b i j, b j d -> b i d", attn, v)
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return self.to_out(out)
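

if __name__ == "__main__":
    # Illustrative shape checks (a sketch, not part of this diff); uses only the classes
    # defined above and LunarLander-style dimensions (8-d observation, 4 discrete actions).
    attn = CrossAttention(query_dim=128, heads=8, dim_head=64, dropout=0.0)
    queries = torch.randn(2, 16, 128)  # [B, token_num, query_dim]
    context = torch.randn(2, 5, 128)  # [B, N, context_dim], context_dim == query_dim here
    assert attn(queries, context).shape == (2, 16, 128)

    stem = PolicyStem(feature_dim=8, token_dim=128)
    assert stem(torch.randn(4, 8)).shape == (4, 16, 128)

    model = HPT(state_dim=8, action_dim=4)
    out = model(torch.randn(2, 8))
    # DuelingHead is expected to return a dict containing a 'logit' tensor of shape [B, action_dim].
    print(out['logit'].shape)  # expected: torch.Size([2, 4])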
77 changes: 77 additions & 0 deletions dizoo/box2d/lunarlander/config/lunarlander_hpt_config.py
@@ -0,0 +1,77 @@
from easydict import EasyDict

nstep = 3
lunarlander_hpt_config = dict(
exp_name='lunarlander_hpt_seed0',
env=dict(
        # Number of environments for the collector and evaluator, respectively.
collector_env_num=8,
evaluator_env_num=8,
env_id='LunarLander-v2',
n_evaluator_episode=8,
stop_value=200,
# The path to save the game replay
# replay_path='./lunarlander_hpt_seed0/video',
),
policy=dict(
# Whether to use cuda for network.
cuda=True,
load_path="./lunarlander_hpt_seed0/ckpt/ckpt_best.pth.tar",
model=dict(
obs_shape=8,
action_shape=4,
),
# Reward's future discount factor, aka. gamma.
discount_factor=0.99,
# How many steps in td error.
nstep=nstep,
# learn_mode config
learn=dict(
update_per_collect=10,
batch_size=64,
learning_rate=0.0005,
# Frequency of target network update.
target_update_freq=100,
),
# collect_mode config
collect=dict(
# You can use either "n_sample" or "n_episode" in collector.collect.
# Get "n_sample" samples per collect.
n_sample=64,
# Cut trajectories into pieces with length "unroll_len".
unroll_len=1,
),
# command_mode config
other=dict(
# Epsilon greedy with decay.
eps=dict(
# Decay type. Support ['exp', 'linear'].
type='exp',
start=0.95,
end=0.1,
decay=50000,
),
replay_buffer=dict(replay_buffer_size=100000, )
),
),
)
lunarlander_hpt_config = EasyDict(lunarlander_hpt_config)
main_config = lunarlander_hpt_config

lunarlander_hpt_create_config = dict(
env=dict(
type='lunarlander',
import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
),
env_manager=dict(type='subprocess'),
# env_manager=dict(type='base'),
policy=dict(type='dqn'),
)
lunarlander_hpt_create_config = EasyDict(lunarlander_hpt_create_config)
create_config = lunarlander_hpt_create_config

if __name__ == "__main__":
    # or you can enter `ding -m serial -c lunarlander_hpt_config.py -s 0`
from ding.entry import serial_pipeline
serial_pipeline([main_config, create_config], seed=0)
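
# Note (illustrative, not part of this diff): the policy type above is plain 'dqn', so
# `serial_pipeline` builds the default DQN network from `model=dict(...)`. The HPT model
# is injected explicitly in dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py, roughly:
#   model = HPT(cfg.policy.model.obs_shape, cfg.policy.model.action_shape).to(device)
#   policy = DQNPolicy(cfg.policy, model=model)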
14 changes: 13 additions & 1 deletion dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py
@@ -1,4 +1,5 @@
import gym
import torch
from ditk import logging
from ding.data.model_loader import FileModelLoader
from ding.data.storage_loader import FileStorageLoader
@@ -16,10 +17,13 @@
from dizoo.box2d.lunarlander.config.lunarlander_dqn_config import main_config, create_config




def main():
logging.getLogger().setLevel(logging.INFO)
cfg = compile_config(main_config, create_cfg=create_config, auto=True, save_cfg=task.router.node_id == 0)
ding_init(cfg)

with task.start(async_mode=False, ctx=OnlineRLContext()):
collector_env = SubprocessEnvManagerV2(
env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)],
@@ -32,10 +36,16 @@ def main():

set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

model = DQN(**cfg.policy.model)
        # Migrate the model to the GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DQN(**cfg.policy.model).to(device)

buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)

# Pass the model into Policy
policy = DQNPolicy(cfg.policy, model=model)


# Consider the case with multiple processes
if task.router.is_active:
# You can use labels to distinguish between workers with different roles,
@@ -50,8 +60,10 @@
# Sync their context and model between each worker.
task.use(ContextExchanger(skip_n_iter=1))
task.use(ModelExchanger(model))


# Here is the part of single process pipeline.
evaluator_env.enable_save_replay(replay_path='./video')
task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(eps_greedy_handler(cfg))
task.use(StepCollector(cfg, policy.collect_mode, collector_env))
80 changes: 80 additions & 0 deletions dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py
@@ -0,0 +1,80 @@

import gym
import torch
import torch.nn as nn
from ditk import logging
from ding.data.model_loader import FileModelLoader
from ding.data.storage_loader import FileStorageLoader
from ding.model.common.head import DuelingHead
from ding.model.template.hpt import HPT
from ding.policy import DQNPolicy
from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2
from ding.data import DequeBuffer
from ding.config import compile_config
from ding.framework import task, ding_init
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
eps_greedy_handler, CkptSaver, ContextExchanger, ModelExchanger, online_logger, termination_checker, \
nstep_reward_enhancer
from ding.utils import set_pkg_seed
from dizoo.box2d.lunarlander.config.lunarlander_hpt_config import main_config, create_config




def main():
logging.getLogger().setLevel(logging.INFO)
cfg = compile_config(main_config, create_cfg=create_config, auto=True, save_cfg=task.router.node_id == 0)
ding_init(cfg)

with task.start(async_mode=False, ctx=OnlineRLContext()):
collector_env = SubprocessEnvManagerV2(
env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)],
cfg=cfg.env.manager
)
evaluator_env = SubprocessEnvManagerV2(
env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)],
cfg=cfg.env.manager
)

set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

# Migrating models to the GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = HPT(cfg.policy.model.obs_shape, cfg.policy.model.action_shape).to(device)
buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)

# Pass the model into Policy
policy = DQNPolicy(cfg.policy, model=model)

# Consider the case with multiple processes
if task.router.is_active:
# You can use labels to distinguish between workers with different roles,
# here we use node_id to distinguish.
if task.router.node_id == 0:
task.add_role(task.role.LEARNER)
elif task.router.node_id == 1:
task.add_role(task.role.EVALUATOR)
else:
task.add_role(task.role.COLLECTOR)

# Sync their context and model between each worker.
task.use(ContextExchanger(skip_n_iter=1))
task.use(ModelExchanger(model))


# Here is the part of single process pipeline.
task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(eps_greedy_handler(cfg))
task.use(StepCollector(cfg, policy.collect_mode, collector_env))
task.use(nstep_reward_enhancer(cfg))
task.use(data_pusher(cfg, buffer_))
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
task.use(online_logger(train_show_freq=50))
task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
task.use(termination_checker(max_env_step=int(3e6)))
task.run()


if __name__ == "__main__":
main()
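
# To run this single-process example (assuming DI-engine and Box2D's LunarLander-v2 are installed):
#   python dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py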