RL Flappy Bird

Overview

This project is a basic application of reinforcement learning.

It uses the Deep Java Library (DJL) to train an agent with a Deep Q-Network (DQN). The pretrained model was trained for three million steps on a single GPU.

You can find an article explaining the training process on Towards Data Science, or the Chinese version (中文版文章).

Build and run the project

This project builds with Maven; use the following command to compile:

mvn compile  

The following command starts training without graphics:

mvn exec:java -Dexec.mainClass="com.kingyu.rlbird.ai.TrainBird"

The above command trains from scratch. You can also train starting from the pretrained weights:

mvn exec:java -Dexec.mainClass="com.kingyu.rlbird.ai.TrainBird" -Dexec.args="-p"

To test with the trained model directly, run:

mvn exec:java -Dexec.mainClass="com.kingyu.rlbird.ai.TrainBird" -Dexec.args="-p -t"  
Argument   Comments
-g         Training with graphics.
-b         Batch size to use for training.
-p         Use pre-trained weights.
-t         Test the trained model.
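For illustration, the flags above could be interpreted as follows in plain Java. This is a hypothetical sketch, not the actual parsing code in com.kingyu.rlbird.ai.TrainBird; the flag semantics are taken from the table, and the default batch size is an assumption.

```java
// Hypothetical sketch of parsing the command-line flags listed above.
public class ArgSketch {
    // Parsed options; field meanings follow the argument table.
    record Options(boolean graphics, int batchSize, boolean pretrained, boolean testMode) {}

    static Options parse(String[] args) {
        boolean graphics = false, pretrained = false, testMode = false;
        int batchSize = 32; // assumed default, not taken from the project
        for (int i = 0; i < args.length; i++) {
            switch (args[i]) {
                case "-g" -> graphics = true;                       // train with graphics
                case "-p" -> pretrained = true;                     // use pre-trained weights
                case "-t" -> testMode = true;                       // test the trained model
                case "-b" -> batchSize = Integer.parseInt(args[++i]); // batch size for training
                default -> throw new IllegalArgumentException("Unknown flag: " + args[i]);
            }
        }
        return new Options(graphics, batchSize, pretrained, testMode);
    }
}
```

So `-Dexec.args="-p -t"` would yield pretrained and test mode both enabled, matching the test command shown above.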

Deep Q-Network Algorithm

The pseudo-code for the deep Q-learning algorithm, as given in Human-level Control through Deep Reinforcement Learning (Nature, 2015), is shown below:

Initialize replay memory D to size N
Initialize action-value function Q with random weights θ
for episode = 1, M do
    Initialize state s_1
    for t = 1, T do
        With probability ϵ select a random action a_t,
        otherwise select a_t = argmax_a Q(s_t, a; θ)
        Execute action a_t in emulator and observe r_t and s_(t+1)
        Store transition (s_t, a_t, r_t, s_(t+1)) in D
        Sample a minibatch of transitions (s_j, a_j, r_j, s_(j+1)) from D
        Set y_j :=
            r_j                                     for terminal s_(j+1)
            r_j + γ * max_(a') Q(s_(j+1), a'; θ)    for non-terminal s_(j+1)
        Perform a gradient step on (y_j - Q(s_j, a_j; θ))^2 with respect to θ
    end for
end for
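The core pieces of the loop above can be sketched in plain Java. The Q-network itself (and the DJL specifics this project uses) are omitted; the replay memory, ϵ-greedy action selection, and Bellman target are shown with assumed hyperparameter values, so this is a minimal illustration rather than the project's implementation.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Random;

// Minimal sketch of the DQN components above. The Q-values would come from
// a neural network Q(s, ·; θ); here they are passed in as plain arrays.
public class DqnSketch {
    // One transition (s_t, a_t, r_t, s_{t+1}) plus a terminal flag.
    record Transition(float[] s, int a, float r, float[] sNext, boolean terminal) {}

    static final int MEMORY_SIZE = 50_000; // replay memory size N (assumed)
    static final float GAMMA = 0.99f;      // discount factor γ (assumed)

    final Deque<Transition> replay = new ArrayDeque<>();
    final Random rng = new Random();

    // ϵ-greedy: random action with probability ϵ, otherwise argmax_a Q(s, a; θ).
    int selectAction(float[] qValues, float epsilon) {
        if (rng.nextFloat() < epsilon) {
            return rng.nextInt(qValues.length);
        }
        int best = 0;
        for (int a = 1; a < qValues.length; a++) {
            if (qValues[a] > qValues[best]) best = a;
        }
        return best;
    }

    // Store a transition in D, evicting the oldest once memory is full.
    void store(Transition t) {
        if (replay.size() >= MEMORY_SIZE) replay.pollFirst();
        replay.addLast(t);
    }

    // Bellman target: y_j = r_j for terminal s_{j+1},
    // y_j = r_j + γ * max_{a'} Q(s_{j+1}, a'; θ) otherwise.
    float target(Transition t, float[] qNext) {
        if (t.terminal()) return t.r();
        float max = qNext[0];
        for (float q : qNext) max = Math.max(max, q);
        return t.r() + GAMMA * max;
    }
}
```

A training step would then sample a minibatch from `replay`, compute `target(...)` for each transition, and take a gradient step on the squared error between the target and the network's current Q-value.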

Notes

Trained Model

  • It may take 10+ hours to train the bird to a perfect state. The model trained with three million steps can be found in the project's resource folder: src/main/resources/model/dqn-trained-0000-params

Troubleshooting

This work is based on the following repos:

License

MIT © Kingyu Luk
