-
Notifications
You must be signed in to change notification settings - Fork 11
/
custom_rl_policy.py
93 lines (79 loc) · 3.25 KB
/
custom_rl_policy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
import gym
import torch as th
from torch import nn
from stable_baselines3.common.policies import ActorCriticPolicy
class Transpose(nn.Module):
    """Flatten every dimension after the first into a single dimension.

    Despite its name, this module does not swap any axes: an input of shape
    ``(batch, d1, d2, ...)`` is returned with shape ``(batch, d1 * d2 * ...)``.
    """

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.shape[0], -1)``.

        :param x: tensor with at least one dimension; dimension 0 is preserved
        :return: 2-D tensor whose rows hold the flattened trailing dimensions
        """
        batch_size = x.shape[0]
        # ``reshape`` rather than ``view``: ``view`` raises on non-contiguous
        # inputs (e.g. the output of ``transpose``/``permute``), whereas
        # ``reshape`` returns a view when it can and copies only when it must.
        return x.reshape(batch_size, -1)
class CustomNetwork(nn.Module):
    """
    Pair of independent MLP heads producing the policy and value latents.

    Each head is ``Linear -> ReLU -> Linear -> Tanh``. The policy head takes
    inputs of width ``last_layer_dim_pi * timesteps * feature_dim`` and the
    value head inputs of width ``last_layer_dim_vf * timesteps * feature_dim``.
    ``forward`` feeds the same tensor to both heads, so the two input widths
    only agree when ``last_layer_dim_pi == last_layer_dim_vf``.

    :param feature_dim: (int) multiplier applied to the input and hidden layer widths
    :param timesteps: (int) multiplier applied to the input layer width
    :param last_layer_dim_pi: (int) number of units for the last layer of the policy network
    :param last_layer_dim_vf: (int) number of units for the last layer of the value network
    """

    def __init__(
        self,
        feature_dim: int,
        timesteps: int = 12,
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
    ):
        super().__init__()

        # IMPORTANT: keep the output widths available as attributes; the
        # original author notes they are used to create the distributions.
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Heads are created policy-first so that parameter initialisation
        # consumes the random number generator in a fixed order.
        self.policy_net = self._make_head(feature_dim, timesteps, last_layer_dim_pi)
        self.value_net = self._make_head(feature_dim, timesteps, last_layer_dim_vf)

    @staticmethod
    def _make_head(feature_dim: int, timesteps: int, out_dim: int) -> nn.Sequential:
        """Build one ``Linear -> ReLU -> Linear -> Tanh`` stack ending in ``out_dim`` units."""
        hidden_dim = out_dim * feature_dim
        input_dim = out_dim * timesteps * feature_dim
        return nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
            nn.Tanh(),
        )

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        Run both heads on the same input.

        :param features: tensor passed unchanged to the policy and value heads
        :return: (th.Tensor, th.Tensor) latent_policy, latent_value
        """
        latent_pi = self.forward_actor(features)
        latent_vf = self.forward_critic(features)
        return latent_pi, latent_vf

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        """Return the policy head's output for ``features``."""
        return self.policy_net(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        """Return the value head's output for ``features``."""
        return self.value_net(features)
class CustomActorCriticPolicy(ActorCriticPolicy):
    """
    Actor-critic policy that uses ``CustomNetwork`` as its MLP extractor.

    The extractor is sized from the spaces: ``action_space.shape[0]`` gives the
    width of both latent outputs, ``observation_space.shape[1]`` the number of
    timesteps and ``observation_space.shape[2]`` the feature dimension, so the
    observation space must have at least three dimensions.

    :param observation_space: observation space handed to the base policy
    :param action_space: action space handed to the base policy
    :param lr_schedule: learning-rate schedule handed to the base policy
    :param net_arch: forwarded unchanged to the base policy
    :param activation_fn: forwarded unchanged to the base policy
    :param args: extra positional arguments for the base policy
    :param kwargs: extra keyword arguments for the base policy
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Callable[[float], float],
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        *args,
        **kwargs,
    ):
        # Disable orthogonal initialization.
        # The flag has to reach the base constructor: assigning
        # ``self.ortho_init`` after ``super().__init__()`` is too late if the
        # base class has already built and initialised its networks there.
        # NOTE(review): this follows the Stable-Baselines3 custom-policy
        # example; confirm against the installed SB3 version.
        if args:
            # NOTE(review): assumes ``ortho_init`` is the first positional
            # parameter after ``activation_fn`` in ``ActorCriticPolicy`` --
            # confirm against the installed SB3 version.
            args = (False,) + tuple(args[1:])
        else:
            kwargs["ortho_init"] = False

        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            # Pass remaining arguments to base class
            *args,
            **kwargs,
        )

    def _build_mlp_extractor(self) -> None:
        """Create ``self.mlp_extractor`` as a ``CustomNetwork`` sized from the spaces."""
        action_dim = self.action_space.shape[0]
        self.mlp_extractor = CustomNetwork(
            last_layer_dim_pi=action_dim,
            last_layer_dim_vf=action_dim,
            timesteps=self.observation_space.shape[1],
            feature_dim=self.observation_space.shape[2],
        )