From 5c06cae67ac02541ca06f56f919dd8dd30aa496c Mon Sep 17 00:00:00 2001
From: johnjim0816
Date: Thu, 19 Dec 2024 23:42:43 +0800
Subject: [PATCH] update CategoricalDQN_CartPole-v1

---
 .../CategoricalDQN_CartPole-v1.ipynb | 36 +++++++++----------
 1 file changed, 17 insertions(+), 19 deletions(-)

diff --git a/notebooks/CategoricalDQN/CategoricalDQN_CartPole-v1.ipynb b/notebooks/CategoricalDQN/CategoricalDQN_CartPole-v1.ipynb
index 2927986..9e0531a 100644
--- a/notebooks/CategoricalDQN/CategoricalDQN_CartPole-v1.ipynb
+++ b/notebooks/CategoricalDQN/CategoricalDQN_CartPole-v1.ipynb
@@ -12,7 +12,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 247,
+   "execution_count": 31,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -39,7 +39,7 @@
     "        # activation functions for each layer\n",
     "        batch_size = x.size(0)\n",
     "        logits = self.layers(x)\n",
-    "        dist = torch.softmax(logits.view(batch_size, self.action_dim, self.cfg.n_atoms), dim=2) # (batch_size, a_dim, n_atoms)\n",
+    "        dist = torch.softmax(logits.view(-1, self.action_dim, self.cfg.n_atoms), dim=2) # (batch_size, a_dim, n_atoms)\n",
     "        return dist"
    ]
   },
@@ -55,7 +55,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 248,
+   "execution_count": 32,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -102,7 +102,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 249,
+   "execution_count": 33,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -124,6 +124,7 @@
     "        self.epsilon_decay = cfg.epsilon_decay\n",
     "        self.n_atoms = cfg.n_atoms\n",
     "        self.atoms = torch.linspace(self.cfg.v_min, self.cfg.v_max, steps=self.cfg.n_atoms,device=self.device)\n",
+    "        self.proj_dist = torch.zeros((self.cfg.batch_size, self.n_atoms), device=self.device) # [batch_size, n_atoms]\n",
     "        self.memory = ReplayBuffer(self.cfg.buffer_size)\n",
     "        # online network and target network\n",
     "        self.model = Model(self.cfg, self.cfg.state_dim, self.cfg.action_dim, hidden_dim = self.cfg.hidden_dim).to(self.device)\n",
@@ -156,7 +157,7 @@
     "        state = torch.tensor(state, device=self.device, dtype=torch.float32).unsqueeze(dim=0) # [1, state_dim]\n",
     "        dist = self.model(state)\n",
     "        q_values = (dist * self.atoms).sum(2)\n",
-    "        action = torch.argmax(q_values, dim=1).item()\n",
+    "        action = torch.argmax(q_values).detach().cpu().numpy().item()\n",
     "        return action\n",
     "    \n",
     "    def _projection_distribution(self, next_states, rewards, dones):\n",
@@ -165,7 +166,7 @@
     "        with torch.no_grad():\n",
     "            delta_z = float(self.cfg.v_max - self.cfg.v_min) / (self.n_atoms - 1)\n",
     "            next_dist = self.target_model(next_states) # [batch_size, action_dim, n_atoms]\n",
-    "            next_q_values = (next_dist * self.n_atoms).sum(2)\n",
+    "            next_q_values = (next_dist * self.atoms).sum(2) # expected Q over the support: [batch_size, action_dim]\n",
     "            next_dist = next_dist[torch.arange(self.cfg.batch_size), torch.argmax(next_q_values, dim=1)]\n",
     "            Tz = rewards + (1-dones) * self.gamma * self.atoms\n",
     "            Tz = Tz.clamp(min=self.cfg.v_min, max=self.cfg.v_max)\n",
@@ -175,10 +176,9 @@
     "            delta_m_l = (u + (l == u) - b) * next_dist # (batch_size, n_atoms)\n",
     "            delta_m_u = (b - l) * next_dist # (batch_size, n_atoms)\n",
     "            offset = torch.linspace(0, (self.cfg.batch_size - 1) * self.n_atoms, self.cfg.batch_size,device=self.device).unsqueeze(-1).long() \n",
-    "            proj_dist = torch.zeros((self.cfg.batch_size, self.n_atoms), device=self.device) # [batch_size, n_atoms]\n",
-    "            proj_dist.view(-1).index_add_(0, (l + offset).view(-1), delta_m_l.view(-1))\n",
-    "            proj_dist.view(-1).index_add_(0, (u + offset).view(-1), delta_m_u.view(-1))\n",
-    "            return proj_dist\n",
+    "            self.proj_dist *= 0\n",
+    "            self.proj_dist.view(-1).index_add_(0, (l + offset).view(-1), delta_m_l.view(-1))\n",
+    "            self.proj_dist.view(-1).index_add_(0, (u + offset).view(-1), delta_m_u.view(-1))\n",
     "\n",
     "    def update(self):\n",
     "        if len(self.memory) < self.cfg.batch_size: # do not update the policy until the replay buffer holds at least one batch\n",
@@ -193,17 +193,15 @@
     "        next_states = torch.tensor(np.array(next_states), device=self.device, dtype=torch.float) # [batch_size, state_dim]\n",
     "        dones = torch.tensor(np.float32(dones), device=self.device).unsqueeze(1) # [batch_size,1]\n",
     "        # compute the projected distribution for the next step\n",
-    "        proj_dist = self._projection_distribution(next_states, rewards, dones)\n",
+    "        self._projection_distribution(next_states, rewards, dones)\n",
     "        # compute the distribution for the current states\n",
     "        dist = self.model(states) # [batch_size, action_dim, n_atoms]\n",
-    "        # print(dist.grad_fn,dist.requires_grad,dist.shape,states.shape)\n",
-    "        # print(actions.squeeze(1).shape)\n",
-    "        \n",
-    "        dist = dist[torch.arange(self.cfg.batch_size), actions.squeeze(1).cpu().numpy()] # [batch_size, n_atoms]\n",
+    "        dist = dist.gather(1, actions.unsqueeze(1).expand(self.cfg.batch_size, 1, self.cfg.n_atoms)).squeeze(1) # [batch_size, n_atoms]\n",
+    "        # dist = dist[torch.arange(self.cfg.batch_size), actions.squeeze(1).cpu().numpy()] # [batch_size, n_atoms]\n",
     "        # action_ = actions.unsqueeze(1).expand(self.batch_size, 1, self.n_atoms)\n",
     "        # dist = dist.gather(1, action_).squeeze(1) # [batch_size, n_atoms]\n",
     "        # dist.data.clamp_(0.01, 0.99) # clamp probabilities to [0.01, 0.99] for numerical stability\n",
-    "        loss = (-(proj_dist* dist.clamp(min=1e-5, max=1 - 1e-5).log()).sum(-1)).mean()\n",
+    "        loss = (-(self.proj_dist * dist.clamp(min=1e-5, max=1 - 1e-5).log()).sum(-1)).mean()\n",
     "        # print(proj_dist.grad_fn, dist.grad_fn)\n",
     "        # print(loss)\n",
     "        self.optimizer.zero_grad() \n",
@@ -233,7 +231,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 250,
+   "execution_count": 34,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -288,7 +286,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 251,
+   "execution_count": 35,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -329,7 +327,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 252,
+   "execution_count": 36,
    "metadata": {},
    "outputs": [
     {
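
For reference, the projection step that the patch rewrites around a reused self.proj_dist buffer can be exercised on its own. The sketch below is illustrative only: the hyperparameters (v_min, v_max, n_atoms, gamma) and the random batch are placeholders rather than the notebook's cfg values, but the clamp/floor/ceil bookkeeping, the index_add_ scatter, and the in-place reset (proj_dist *= 0) follow the patched _projection_distribution.

# Standalone sketch of the categorical (C51) projection with a preallocated buffer.
# Values below are made up for illustration; they are not the notebook's settings.
import torch

batch_size, n_atoms, gamma = 4, 51, 0.99
v_min, v_max = 0.0, 200.0
delta_z = (v_max - v_min) / (n_atoms - 1)
atoms = torch.linspace(v_min, v_max, n_atoms)        # support z_1..z_N
proj_dist = torch.zeros(batch_size, n_atoms)         # reused across updates

# stand-in inputs: next-state distribution of the greedy action, rewards, done flags
next_dist = torch.softmax(torch.randn(batch_size, n_atoms), dim=1)
rewards = torch.rand(batch_size, 1)
dones = torch.zeros(batch_size, 1)

with torch.no_grad():
    Tz = (rewards + (1 - dones) * gamma * atoms).clamp(v_min, v_max)  # Bellman-shifted support
    b = (Tz - v_min) / delta_z                                        # fractional atom index
    l, u = b.floor().long(), b.ceil().long()
    delta_m_l = (u + (l == u) - b) * next_dist       # mass assigned to the lower atom
    delta_m_u = (b - l) * next_dist                  # mass assigned to the upper atom
    offset = (torch.arange(batch_size) * n_atoms).unsqueeze(1)
    proj_dist *= 0                                   # reset the shared buffer instead of reallocating
    proj_dist.view(-1).index_add_(0, (l + offset).view(-1), delta_m_l.view(-1))
    proj_dist.view(-1).index_add_(0, (u + offset).view(-1), delta_m_u.view(-1))

print(proj_dist.sum(dim=1))  # each row should sum to ~1

Zeroing one preallocated [batch_size, n_atoms] buffer in place avoids allocating a fresh tensor on every update, which matches moving proj_dist into __init__ in the patch.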