This module contains algorithms we choose to implement and test.
Here is an example of training PPO on the CartPole-v1 environment. Because PPO is a PyTorch Lightning module, it is trained with the Lightning Trainer API.
Note that this PPO implementation has not yet been thoroughly benchmarked and should be treated as a work in progress.
The reload_dataloaders_every_epoch flag is needed so that the dataloader is rebuilt at the start of every epoch, which ensures that the updates in each epoch are computed on the latest batch of collected data.
To see how we implement this, view the source code for the PPO class.
import pytorch_lightning as pl  # PPO itself is provided by this module
agent = PPO("CartPole-v1")
trainer = pl.Trainer(reload_dataloaders_every_epoch=True, max_epochs=25)
trainer.fit(agent)
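
Why reloading matters for an on-policy method such as PPO can be illustrated with a small, library-free sketch. It is not the PPO class's actual implementation: ToyPolicy, collect_batch, and run_epochs are hypothetical names invented for this illustration. Each sample is tagged with the policy version that generated it, and the loop records how many updates behind the training data is in each epoch.

```python
from typing import List


class ToyPolicy:
    """Hypothetical stand-in for an on-policy agent; `version` counts updates."""

    def __init__(self) -> None:
        self.version = 0

    def collect_batch(self, size: int = 4) -> List[int]:
        # Tag every sample with the policy version that produced it.
        return [self.version] * size

    def update(self) -> None:
        # Pretend one epoch of training changed the policy parameters.
        self.version += 1


def run_epochs(policy: ToyPolicy, num_epochs: int, reload_every_epoch: bool) -> List[int]:
    """Return, per epoch, how many policy versions behind the training data is."""
    staleness = []
    batch = policy.collect_batch()  # built once, before training starts
    for _ in range(num_epochs):
        if reload_every_epoch:
            batch = policy.collect_batch()  # rebuilt from the current policy
        staleness.append(policy.version - batch[0])
        policy.update()
    return staleness


fresh = run_epochs(ToyPolicy(), 3, reload_every_epoch=True)
stale = run_epochs(ToyPolicy(), 3, reload_every_epoch=False)
print(fresh)  # [0, 0, 0]
print(stale)  # [0, 1, 2]
```

With reloading enabled, every epoch trains on data from the current policy. Without it, the data falls further behind with each update, which breaks the on-policy assumption. Newer PyTorch Lightning releases replaced this flag with reload_dataloaders_every_n_epochs, so check which argument your installed version expects.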