A JAX implementation of the MuZero agent.
Everything, including the MCTS, is implemented in JAX, so the entire search process can be JIT-compiled with jax.jit and run on accelerators such as GPUs.
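A search can be compiled this way when each of its steps is a pure function of fixed-shape arrays. The sketch below illustrates this for one MCTS ingredient, pUCT action selection. It is not code from this repository. The constants c1 = 1.25 and c2 = 19652 follow the MuZero paper, and the function names puct_scores and select_actions are hypothetical. Two further simplifications are assumptions: the parent visit count is approximated by the sum of the children's visit counts, and the min-max normalization of Q that MuZero applies is omitted.

```python
import jax
import jax.numpy as jnp

# pUCT constants as given in the MuZero paper (c1 = 1.25, c2 = 19652).
C1 = 1.25
C2 = 19652.0


def puct_scores(prior, visit_counts, value_sums):
    """pUCT score of every action at a single node (hypothetical helper).

    prior, visit_counts, value_sums: float arrays of shape [num_actions].
    """
    # Assumption: approximate the parent's visit count by the sum over children.
    total = jnp.sum(visit_counts)
    # Mean value Q(s, a); unvisited actions get 0, as in the MuZero pseudocode.
    # jnp.maximum avoids dividing by zero in the branch that where() discards.
    q = jnp.where(visit_counts > 0, value_sums / jnp.maximum(visit_counts, 1.0), 0.0)
    # Exploration coefficient that grows slowly with the parent visit count.
    exploration = C1 + jnp.log((total + C2 + 1.0) / C2)
    u = prior * jnp.sqrt(total) / (1.0 + visit_counts) * exploration
    return q + u


@jax.jit
def select_actions(prior, visit_counts, value_sums):
    """Pick the pUCT-maximizing action for a batch of nodes (hypothetical helper)."""
    # vmap maps the single-node function over the leading batch axis.
    scores = jax.vmap(puct_scores)(prior, visit_counts, value_sums)
    return jnp.argmax(scores, axis=-1)
```

Because the function uses only fixed-shape array operations, jax.jit traces it once per input shape. jax.vmap then evaluates a whole batch of nodes in a single call, for example one root per parallel environment.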
Run the following command to create a new conda environment with all dependencies:
conda env create -f conda_env.yml
Then activate the environment with:
conda activate muzero
Alternatively, if you prefer to use your own Python environment, install the dependencies with:
pip install -r requirements.txt
To train an agent on the Atari game Breakout, run:
python -m experiments.breakout
Figure: median human-normalized score.
.
├── algorithms # Files for the MuZero algorithm.
│ ├── actors.py # Agent-environment interaction.
│ ├── agents.py # An RL agent that plans with a learned model by MCTS.
│ ├── haiku_nets.py # Neural networks.
│ ├── muzero.py # The training pipeline.
│ ├── replay_buffers.py # Experience replay.
│ ├── types.py # Custom data structures.
│ └── utils.py # Helper functions.
├── environments # The Atari environment interface and wrappers.
├── experiments # Experiment configuration files.
├── vec_env # Vectorized environment interfaces.
├── conda_env.yml # Conda environment specification.
├── requirements.txt # Python dependencies.
├── LICENSE
└── README.md
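The agent in agents.py plans with a learned model. In MuZero, that model consists of three functions: a representation function that encodes an observation into a latent state, a dynamics function that predicts the next latent state and reward from a state and an action, and a prediction function that outputs policy logits and a value. The sketch below unrolls such a model over a sequence of actions with jax.lax.scan, which keeps the unroll jittable. The single linear layers, the sizes OBS_DIM, LATENT_DIM, and NUM_ACTIONS, and all function names are hypothetical stand-ins, not the networks defined in haiku_nets.py.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes, chosen only for illustration.
OBS_DIM, LATENT_DIM, NUM_ACTIONS = 8, 16, 4


def init_params(key):
    """Random weights for hypothetical single-layer stand-ins of the three functions."""
    k1, k2, k3, k4, k5 = jax.random.split(key, 5)
    scale = 0.1
    return {
        "repr": scale * jax.random.normal(k1, (OBS_DIM, LATENT_DIM)),
        "dyn": scale * jax.random.normal(k2, (LATENT_DIM + NUM_ACTIONS, LATENT_DIM)),
        "reward": scale * jax.random.normal(k3, (LATENT_DIM,)),
        "policy": scale * jax.random.normal(k4, (LATENT_DIM, NUM_ACTIONS)),
        "value": scale * jax.random.normal(k5, (LATENT_DIM,)),
    }


def representation(params, obs):
    # h: observation -> initial latent state.
    return jnp.tanh(obs @ params["repr"])


def dynamics(params, state, action):
    # g: (latent state, action) -> (next latent state, predicted reward).
    x = jnp.concatenate([state, jax.nn.one_hot(action, NUM_ACTIONS)])
    next_state = jnp.tanh(x @ params["dyn"])
    reward = next_state @ params["reward"]
    return next_state, reward


def prediction(params, state):
    # f: latent state -> (policy logits, value).
    return state @ params["policy"], state @ params["value"]


@jax.jit
def unroll(params, obs, actions):
    """Encode obs, then step the learned model once per action in `actions`."""
    s0 = representation(params, obs)

    def step(state, action):
        next_state, reward = dynamics(params, state, action)
        logits, value = prediction(params, next_state)
        return next_state, (reward, logits, value)

    # scan stacks the per-step outputs along a new leading axis.
    _, (rewards, logits, values) = jax.lax.scan(step, s0, actions)
    return prediction(params, s0), rewards, logits, values
```

For example, with params = init_params(jax.random.PRNGKey(0)) and an integer array of five actions, unroll returns the root policy logits and value together with per-step rewards of shape (5,), logits of shape (5, NUM_ACTIONS), and values of shape (5,). Using lax.scan rather than a Python loop keeps the compiled graph size independent of the unroll length.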
- NeurIPS 2020: JAX Ecosystem Meetup, video and slides
- https://arxiv.org/src/1911.08265v2/anc/pseudocode.py
- https://github.com/YeWR/EfficientZero