This project is a basic application of Reinforcement Learning.
It uses the Deep Java Library (DJL) to train an agent with DQN. The pretrained model was trained for 3M steps on a single GPU.
You can find an article explaining the training process on Towards Data Science, or read the Chinese version of the article.
This project supports building with Maven. Use the following command to build:
mvn compile
The following command starts training without graphics:
mvn exec:java -Dexec.mainClass="com.kingyu.rlbird.ai.TrainBird"
The above command trains from scratch. You can also continue training from the pretrained weights:
mvn exec:java -Dexec.mainClass="com.kingyu.rlbird.ai.TrainBird" -Dexec.args="-p"
To test the model directly, run the following:
mvn exec:java -Dexec.mainClass="com.kingyu.rlbird.ai.TrainBird" -Dexec.args="-p -t"
Argument | Comments |
---|---|
-g | Training with graphics. |
-b | Batch size to use for training. |
-p | Use pre-trained weights. |
-t | Test the trained model. |
The pseudocode for the Deep Q-Learning algorithm, as given in "Human-level Control through Deep Reinforcement Learning" (Nature), is shown below:
Initialize replay memory D to size N
Initialize action-value function Q with random weights
for episode = 1, M do
    Initialize state s_1
    for t = 1, T do
        With probability ϵ select random action a_t
        otherwise select a_t = argmax_a Q(s_t, a; θ_i)
        Execute action a_t in emulator and observe r_t and s_(t+1)
        Store transition (s_t, a_t, r_t, s_(t+1)) in D
        Sample a minibatch of transitions (s_j, a_j, r_j, s_(j+1)) from D
        Set y_j :=
            r_j                                     for terminal s_(j+1)
            r_j + γ * max_(a') Q(s_(j+1), a'; θ_i)   for non-terminal s_(j+1)
        Perform a gradient step on (y_j - Q(s_j, a_j; θ_i))^2 with respect to θ_i
    end for
end for
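The core steps of the pseudocode above can be sketched in plain Java. This is a minimal illustration, not the project's actual implementation: the class and method names (`DqnSketch`, `Transition`, `ReplayMemory`, `selectAction`, `target`) are hypothetical, Q-values are passed in as plain arrays rather than computed by a DJL network, and the gradient step is left out.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/** Illustrative DQN building blocks. All names here are hypothetical. */
final class DqnSketch {

    /** One stored transition (s_j, a_j, r_j, s_(j+1)) plus a terminal flag. */
    static final class Transition {
        final double[] state;
        final int action;
        final double reward;
        final double[] nextState;
        final boolean terminal;

        Transition(double[] state, int action, double reward,
                   double[] nextState, boolean terminal) {
            this.state = state;
            this.action = action;
            this.reward = reward;
            this.nextState = nextState;
            this.terminal = terminal;
        }
    }

    /** Replay memory D of size N; overwrites the oldest entry once full. */
    static final class ReplayMemory {
        private final Transition[] buffer;
        private int next = 0;
        private int size = 0;

        ReplayMemory(int capacity) {
            buffer = new Transition[capacity];
        }

        void store(Transition t) {
            buffer[next] = t;
            next = (next + 1) % buffer.length;
            size = Math.min(size + 1, buffer.length);
        }

        int size() {
            return size;
        }

        /** Samples a minibatch uniformly at random, with replacement. */
        List<Transition> sample(int batchSize, Random rng) {
            if (size == 0) {
                throw new IllegalStateException("replay memory is empty");
            }
            List<Transition> batch = new ArrayList<>(batchSize);
            for (int i = 0; i < batchSize; i++) {
                batch.add(buffer[rng.nextInt(size)]);
            }
            return batch;
        }
    }

    /** Index of the largest Q-value (first one wins on ties). */
    static int argMax(double[] qValues) {
        int best = 0;
        for (int a = 1; a < qValues.length; a++) {
            if (qValues[a] > qValues[best]) {
                best = a;
            }
        }
        return best;
    }

    /** Epsilon-greedy: random action with probability epsilon, else argmax_a Q(s_t, a). */
    static int selectAction(double[] qValues, double epsilon, Random rng) {
        if (rng.nextDouble() < epsilon) {
            return rng.nextInt(qValues.length);
        }
        return argMax(qValues);
    }

    /** Target y_j: r_j if s_(j+1) is terminal, else r_j + gamma * max_a' Q(s_(j+1), a'). */
    static double target(double reward, double[] nextQValues, double gamma, boolean terminal) {
        if (terminal) {
            return reward;
        }
        return reward + gamma * nextQValues[argMax(nextQValues)];
    }
}
```

In a full training loop, the Q-value arrays would come from the network's forward pass on `state` and `nextState`, and the squared difference between `target(...)` and the predicted Q-value for the taken action would be the loss for the gradient step.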
Trained Model
- It may take 10+ hours to train a bird to a perfect state. You can find the model trained for three million steps in the project resource folder:
src/main/resources/model/dqn-trained-0000-params
Troubleshooting
This work is based on the following repos:
MIT © Kingyu Luk