forked from opendilab/LightZero
参考与代码修改.txt (References and code modifications)
https://blog.csdn.net/qq_45691577/article/details/129386350
The obs returned by the `reset()` method is not within the observation space.
Fix: in mazegame.py, change get_obs so that it returns np.array([show]) instead of show. The returned shape then goes from (4,4) to (1,4,4), which matches the declared observation shape (1,4,4).
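A minimal sketch of the shape change described above. The names get_obs_before, get_obs_after and OBS_SHAPE are hypothetical stand-ins and are not taken from mazegame.py; only the 4x4 grid and the (1,4,4) target shape come from the note.

```python
import numpy as np

# Assumed target shape, taken from the note above: observation=(1,4,4).
OBS_SHAPE = (1, 4, 4)


def get_obs_before(show):
    """Original behaviour: return the raw 4x4 grid unchanged."""
    return show


def get_obs_after(show):
    """Fixed behaviour: wrap the grid in a list so NumPy adds a leading axis."""
    return np.array([show])


# Hypothetical 4x4 maze grid used only for this illustration.
grid = np.zeros((4, 4), dtype=np.float32)

print(get_obs_before(grid).shape)  # shape of the unfixed observation
print(get_obs_after(grid).shape)   # shape after adding the leading axis
print(get_obs_after(grid).shape == OBS_SHAPE)
```

np.expand_dims(show, axis=0) would give the same shape without building an intermediate list.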
Traceback (most recent call last):
  File "/kaggle/working/LightZero/./lzero/mcts/tests/test_game_segment_mazegame.py", line 161, in <module>
    test_game_segment(test_algo)
  File "/kaggle/working/LightZero/./lzero/mcts/tests/test_game_segment_mazegame.py", line 54, in test_game_segment
    model.load_state_dict(matched_state_dict, strict=False)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for MuZeroModel:
	size mismatch for dynamics_network.fc_reward_head.3.weight: copying a param with shape torch.Size([601, 32]) from checkpoint, the shape in current model is torch.Size([1, 32]).
	size mismatch for dynamics_network.fc_reward_head.3.bias: copying a param with shape torch.Size([601]) from checkpoint, the shape in current model is torch.Size([1]).
	size mismatch for prediction_network.fc_value.3.weight: copying a param with shape torch.Size([601, 32]) from checkpoint, the shape in current model is torch.Size([1, 32]).
	size mismatch for prediction_network.fc_value.3.bias: copying a param with shape torch.Size([601]) from checkpoint, the shape in current model is torch.Size([1]).
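Cause: line 53 of /lzero/mcts/tests/config/mazegame_muzero_config_for_test.py sets support_scale to 1, which conflicts with the default support_scale=300 used at https://github.com/valkryhx/LightZero/blob/main/zoo/classic_control/mazegame/config/mymaze_muzero_config.py#L37
Why 300 is the default: searching the repo for support_scale=300 shows that the other code uses this value by default. It is best to follow that code and write the value out explicitly in the training config:
support_scale=300, # default is 300; search the repo for the string support_scale=300 to confirm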
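As a rough check of why the checkpoint heads have 601 outputs, the sketch below assumes the common MuZero convention that the categorical support covers every integer from -support_scale to +support_scale. The helper support_size is a hypothetical name introduced here, not a LightZero function.

```python
def support_size(support_scale: int) -> int:
    """Number of bins when the support spans -support_scale..+support_scale.

    Assumption: one bin per integer in that range, which gives
    2 * support_scale + 1 bins in total.
    """
    return 2 * support_scale + 1


# With the default support_scale=300 this gives the 601 seen in the
# traceback for fc_reward_head and fc_value.
print(support_size(300))
```

Under that assumption, a checkpoint trained with support_scale=300 cannot be loaded into a model built from a config with a different scale, because the final layers of the reward and value heads have different output sizes.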
Meaning of the model parameters:
https://github.com/valkryhx/LightZero/blob/4d73183c5b3a40cba3a5a66bf792bb87016d92d2/lzero/policy/muzero.py#L53
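How to sync the fork with upstream using git: if there are conflicts, save the files you modified somewhere else and delete them from the repo, then follow the steps in the links below. Finally, remember to compare your saved files against the upstream versions.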
https://github.com/selfteaching/the-craft-of-selfteaching/issues/67
https://gitbook.tw/chapters/github/syncing-a-fork
# Note: when action not in self.legal_actions, the handling is to pick one of the legal actions at random.
https://github.com/valkryhx/LightZero/blob/1d181b8f85810866ef7ef52ffb3c2c836d0dc4a2/zoo/game_2048/envs/game_2048_env.py#L216
https://github.com/valkryhx/LightZero/blob/1d181b8f85810866ef7ef52ffb3c2c836d0dc4a2/zoo/game_2048/envs/game_2048_env.py#L270
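The fallback described in the note above can be sketched as follows. This is an illustration of the idea only, not the code from game_2048_env.py; resolve_action and its parameters are hypothetical names.

```python
import random


def resolve_action(action, legal_actions, rng=random):
    """Return action if it is legal, otherwise a random legal action.

    rng only needs a choice() method, so a seeded random.Random
    instance can be passed in for reproducible behaviour.
    """
    if action in legal_actions:
        return action
    # Illegal action: fall back to a uniformly random legal one.
    return rng.choice(list(legal_actions))


legal = [0, 2]                    # hypothetical set of legal moves
print(resolve_action(2, legal))   # legal, so it is returned unchanged
print(resolve_action(5, legal))   # illegal, so one of the legal moves is drawn
```

Silently replacing an illegal action keeps the episode running, but the policy does not learn from the mistake, so it may be worth logging how often the fallback triggers.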