update CategoricalDQN_CartPole-v1
johnjim0816 committed Dec 19, 2024
1 parent ab8e571 commit 5c06cae
Showing 1 changed file with 17 additions and 19 deletions.
36 changes: 17 additions & 19 deletions notebooks/CategoricalDQN/CategoricalDQN_CartPole-v1.ipynb
@@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": 247,
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
@@ -39,7 +39,7 @@
" # 各层对应的激活函数\n",
" batch_size = x.size(0)\n",
" logits = self.layers(x)\n",
" dist = torch.softmax(logits.view(batch_size, self.action_dim, self.cfg.n_atoms), dim=2) # (batch_size, a_dim, n_atoms)\n",
" dist = torch.softmax(logits.view(-1, self.action_dim, self.cfg.n_atoms), dim=2) # (batch_size, a_dim, n_atoms)\n",
" return dist"
]
},
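The view(-1, ...) change above makes the forward pass batch-shape agnostic: for every action the head returns a softmax distribution over the n_atoms support points, and Q-values are recovered as the expectation of the support under that distribution. A minimal sketch of such a distributional head, with a hypothetical class name (C51Head) and illustrative sizes rather than the notebook's cfg values:

```python
import torch
import torch.nn as nn

class C51Head(nn.Module):
    """Sketch of a C51-style distributional Q-network head (not the notebook's Model class)."""
    def __init__(self, state_dim, action_dim, n_atoms=51, hidden_dim=128):
        super().__init__()
        self.action_dim, self.n_atoms = action_dim, n_atoms
        self.layers = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, action_dim * n_atoms),
        )

    def forward(self, x):
        logits = self.layers(x)                                  # [batch, action_dim * n_atoms]
        logits = logits.view(-1, self.action_dim, self.n_atoms)  # [batch, action_dim, n_atoms]
        return torch.softmax(logits, dim=2)                      # per-action distribution over the atoms

# Q(s, a) is the expected return under each action's distribution:
# atoms = torch.linspace(v_min, v_max, n_atoms); q_values = (dist * atoms).sum(dim=2)
```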
@@ -55,7 +55,7 @@
},
{
"cell_type": "code",
"execution_count": 248,
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
@@ -102,7 +102,7 @@
},
{
"cell_type": "code",
"execution_count": 249,
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
@@ -124,6 +124,7 @@
" self.epsilon_decay = cfg.epsilon_decay\n",
" self.n_atoms = cfg.n_atoms\n",
" self.atoms = torch.linspace(self.cfg.v_min, self.cfg.v_max, steps=self.cfg.n_atoms,device=self.device)\n",
" self.proj_dist = torch.zeros((self.cfg.batch_size, self.n_atoms), device=self.device) # [batch_size, n_atoms]\n",
" self.memory = ReplayBuffer(self.cfg.buffer_size)\n",
" # 当前网络和目标网络\n",
" self.model = Model(self.cfg, self.cfg.state_dim, self.cfg.action_dim, hidden_dim = self.cfg.hidden_dim).to(self.device)\n",
@@ -156,7 +157,7 @@
" state = torch.tensor(state, device=self.device, dtype=torch.float32).unsqueeze(dim=0) # [1, state_dim]\n",
" dist = self.model(state)\n",
" q_values = (dist * self.atoms).sum(2)\n",
" action = torch.argmax(q_values, dim=1).item()\n",
" action = torch.argmax(q_values).detach().cpu().numpy().item()\n",
" return action\n",
" \n",
" def _projection_distribution(self, next_states, rewards, dones):\n",
@@ -165,7 +166,7 @@
" with torch.no_grad():\n",
" delta_z = float(self.cfg.v_max - self.cfg.v_min) / (self.n_atoms - 1)\n",
" next_dist = self.target_model(next_states) # [batch_size, action_dim, n_atoms]\n",
" next_q_values = (next_dist * self.n_atoms).sum(2)\n",
" next_q_values = (next_dist * self.n_atoms).sum(2) # [batch_size, action_dim]\n",
" next_dist = next_dist[torch.arange(self.cfg.batch_size), torch.argmax(next_q_values, dim=1)]\n",
" Tz = rewards + (1-dones) * self.gamma * self.atoms\n",
" Tz = Tz.clamp(min=self.cfg.v_min, max=self.cfg.v_max)\n",
@@ -175,10 +176,9 @@
" delta_m_l = (u + (l == u) - b) * next_dist # (batch_size, n_atoms)\n",
" delta_m_u = (b - l) * next_dist # (batch_size, n_atoms)\n",
" offset = torch.linspace(0, (self.cfg.batch_size - 1) * self.n_atoms, self.cfg.batch_size,device=self.device).unsqueeze(-1).long() \n",
" proj_dist = torch.zeros((self.cfg.batch_size, self.n_atoms), device=self.device) # [batch_size, n_atoms]\n",
" proj_dist.view(-1).index_add_(0, (l + offset).view(-1), delta_m_l.view(-1))\n",
" proj_dist.view(-1).index_add_(0, (u + offset).view(-1), delta_m_u.view(-1))\n",
" return proj_dist\n",
" self.proj_dist *= 0\n",
" self.proj_dist.view(-1).index_add_(0, (l + offset).view(-1), delta_m_l.view(-1))\n",
" self.proj_dist.view(-1).index_add_(0, (u + offset).view(-1), delta_m_u.view(-1))\n",
"\n",
" def update(self):\n",
" if len(self.memory) < self.cfg.batch_size: # 当经验回放中不满足一个批量时,不更新策略\n",
Expand All @@ -193,17 +193,15 @@
" next_states = torch.tensor(np.array(next_states), device=self.device, dtype=torch.float) # [batch_size, state_dim]\n",
" dones = torch.tensor(np.float32(dones), device=self.device).unsqueeze(1) # [batch_size,1]\n",
" # 计算下一时刻的分布\n",
" proj_dist = self._projection_distribution(next_states, rewards, dones)\n",
" self._projection_distribution(next_states, rewards, dones)\n",
" # 计算当前状态的分布\n",
" dist = self.model(states) # [batch_size, action_dim, n_atoms]\n",
" # print(dist.grad_fn,dist.requires_grad,dist.shape,states.shape)\n",
" # print(actions.squeeze(1).shape)\n",
" \n",
" dist = dist[torch.arange(self.cfg.batch_size), actions.squeeze(1).cpu().numpy()] # [batch_size, n_atoms]\n",
" dist = dist.gather(1, actions.unsqueeze(1).expand(self.cfg.batch_size, 1, self.cfg.n_atoms)).squeeze(1) # [batch_size, n_atoms]\n",
" # dist = dist[torch.arange(self.cfg.batch_size), actions.squeeze(1).cpu().numpy()] # [batch_size, n_atoms]\n",
" # action_ = actions.unsqueeze(1).expand(self.batch_size, 1, self.n_atoms)\n",
" # dist = dist.gather(1, action_).squeeze(1) # [batch_size, n_atoms]\n",
" # dist.data.clamp_(0.01, 0.99) # 为了数值稳定,将概率值限制在[0.01, 0.99]之间\n",
" loss = (-(proj_dist* dist.clamp(min=1e-5, max=1 - 1e-5).log()).sum(-1)).mean()\n",
" loss = (-(self.proj_dist * dist.clamp(min=1e-5, max=1 - 1e-5).log()).sum(-1)).mean()\n",
" # print(proj_dist.grad_fn, dist.grad_fn)\n",
" # print(loss)\n",
" self.optimizer.zero_grad() \n",
@@ -233,7 +231,7 @@
},
{
"cell_type": "code",
"execution_count": 250,
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
@@ -288,7 +286,7 @@
},
{
"cell_type": "code",
"execution_count": 251,
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
@@ -329,7 +327,7 @@
},
{
"cell_type": "code",
"execution_count": 252,
"execution_count": 36,
"metadata": {},
"outputs": [
{