This module defines losses for a variety of RL agents.

actor_critic_value_loss[source]

actor_critic_value_loss(value_estimates:Tensor, env_returns:Tensor)

Loss for an actor-critic value function.

This is simply the mean squared error (MSE) between the value estimates and the real returns.

Args:

  • value_estimates (torch.Tensor): Estimates of state-value from the critic network.
  • env_returns (torch.Tensor): Real returns from the environment.

Returns:

  • value_loss (torch.Tensor): MSE loss between the value estimates and the real returns.
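
A minimal sketch of the computation described above, assuming the loss is plain MSE between the critic's estimates and the observed returns (tensor shapes are illustrative):

```python
import torch
import torch.nn.functional as F

# Plain MSE between critic estimates and observed returns (illustrative only).
value_estimates = torch.randn(32)  # critic outputs for a batch of 32 states
env_returns = torch.randn(32)      # empirical returns for the same states
value_loss = F.mse_loss(value_estimates, env_returns)
```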

reinforce_policy_loss[source]

reinforce_policy_loss(logps:Tensor, env_returns:Tensor)

REINFORCE policy gradient loss: $-\log \pi(a \mid s) \cdot R_t$

Args:

  • logps (torch.Tensor): Action log-probabilities.
  • env_returns (torch.Tensor): Returns from the environment.

Returns:

  • reinforce_loss (torch.Tensor): REINFORCE loss term.
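
A minimal sketch of the objective above: negate the log-probabilities weighted by the returns and average over the batch (tensor values are illustrative):

```python
import torch

logps = torch.randn(32)        # log pi(a|s) for the sampled actions
env_returns = torch.randn(32)  # returns R_t observed in the environment
# Negative log-probability weighted by return, averaged over the batch.
reinforce_loss = -(logps * env_returns).mean()
```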

a2c_policy_loss[source]

a2c_policy_loss(logps:Tensor, advs:Tensor)

Loss function for an A2C policy: $-\log \pi(a \mid s) \cdot A_t$

Args:

  • logps (torch.Tensor): Log-probabilities of selected actions.
  • advs (torch.Tensor): Advantage estimates of selected actions.

Returns:

  • a2c_loss (torch.Tensor): A2C loss term.
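
The same pattern as REINFORCE, but weighted by advantage estimates rather than raw returns; a minimal, illustrative sketch:

```python
import torch

logps = torch.randn(32)  # log pi(a|s) for the selected actions
advs = torch.randn(32)   # advantage estimates A_t for those actions
# Negative log-probability weighted by advantage, averaged over the batch.
a2c_loss = -(logps * advs).mean()
```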

ppo_clip_policy_loss[source]

ppo_clip_policy_loss(logps:Tensor, logps_old:Tensor, advs:Tensor, clipratio:Optional[float]=0.2)

Loss function for a PPO-clip policy. See the paper for the full loss-function details: https://arxiv.org/abs/1707.06347

Args:

  • logps (torch.Tensor): Action log-probabilities under the current policy.
  • logps_old (torch.Tensor): Action log-probabilities under the old (pre-update) policy.
  • advs (torch.Tensor): Advantage estimates for the actions taken.
  • clipratio (float): Clipping parameter for the PPO-clip loss. The default value is generally fine.

Returns:

  • ppo_loss (torch.Tensor): Loss term for PPO agent.
  • kl (torch.Tensor): KL-divergence estimate between new and old policies.
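
An illustrative sketch of the standard PPO-clip surrogate from the paper; the library's exact implementation (in particular its KL estimate) may differ:

```python
import torch

def ppo_clip_sketch(logps, logps_old, advs, clipratio=0.2):
    # Probability ratio pi_new(a|s) / pi_old(a|s), computed in log-space.
    ratio = torch.exp(logps - logps_old)
    clipped_adv = torch.clamp(ratio, 1 - clipratio, 1 + clipratio) * advs
    # Take the pessimistic (minimum) objective and negate it for gradient descent.
    ppo_loss = -torch.min(ratio * advs, clipped_adv).mean()
    # Simple sample-based estimate of the KL between old and new policies.
    kl = (logps_old - logps).mean()
    return ppo_loss, kl
```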

ddpg_policy_loss[source]

ddpg_policy_loss(states:Tensor, qfunc:Module, policy:Module)

Policy loss function for a DDPG agent. See the paper: https://arxiv.org/abs/1509.02971

Args:

  • states (torch.Tensor): States to get Q-policy estimates for.
  • qfunc (nn.Module): Q-function network.
  • policy (nn.Module): Policy network.

Returns:

  • q_policy_loss (torch.Tensor): Loss term for DDPG policy.
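
A rough sketch of the deterministic policy objective: maximize Q(s, μ(s)) by minimizing its negative. It assumes the Q-network takes (states, actions) as inputs, which may not match the library's exact call signature:

```python
import torch

def ddpg_policy_loss_sketch(states, qfunc, policy):
    actions = policy(states)
    # Minimizing the negative Q-value pushes the policy toward higher-value actions.
    return -qfunc(states, actions).mean()
```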

ddpg_qfunc_loss[source]

ddpg_qfunc_loss(data:Tuple[Tensor, Tensor, Tensor, Tensor, Tensor], qfunc:Module, qfunc_target:Module, policy_target:Module, gamma:Optional[float]=0.99)

Loss for a DDPG Q-function. See the paper: https://arxiv.org/abs/1509.02971

Args:

  • data (tuple of torch.Tensor): Input data batch of five tensors: (states, next_states, actions, rewards, dones).
  • qfunc (nn.Module): Q-function network being trained.
  • qfunc_target (nn.Module): Q-function target network.
  • policy_target (nn.Module): Policy target network.
  • gamma (float): Discount factor.

Returns:

  • loss_q (torch.Tensor): DDPG loss for the Q-function.
  • loss_info (dict): Dictionary containing useful loss info for logging.
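
A sketch of the critic update, assuming data unpacks as documented above and that the Q-networks take (states, actions); the contents of loss_info here are hypothetical:

```python
import torch
import torch.nn.functional as F

def ddpg_qfunc_loss_sketch(data, qfunc, qfunc_target, policy_target, gamma=0.99):
    states, next_states, actions, rewards, dones = data
    q_vals = qfunc(states, actions)
    with torch.no_grad():
        # Bootstrapped Bellman target from the target networks; zeroed at episode end.
        next_actions = policy_target(next_states)
        target = rewards + gamma * (1 - dones) * qfunc_target(next_states, next_actions)
    loss_q = F.mse_loss(q_vals, target)
    loss_info = {"q_values": q_vals.detach()}  # hypothetical logging payload
    return loss_q, loss_info
```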

td3_policy_loss[source]

td3_policy_loss(states:Tensor, qfunc:Module, policy:Module)

Calculate the policy loss for a TD3 agent. See the paper: https://arxiv.org/abs/1802.09477

Args:

  • states (torch.Tensor): Input states to get policy loss for.
  • qfunc (nn.Module): TD3 Q-function network.
  • policy (nn.Module): Policy network.

Returns:

  • q_policy_loss (torch.Tensor): The TD3 policy loss term.
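
Structurally this mirrors the DDPG policy loss above; only the first Q-function is used for the policy update. A minimal sketch under the same (states, actions) call-signature assumption:

```python
import torch

def td3_policy_loss_sketch(states, qfunc, policy):
    # As in DDPG: maximize Q1(s, mu(s)) by minimizing its negative.
    return -qfunc(states, policy(states)).mean()
```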

td3_qfunc_loss[source]

td3_qfunc_loss(data:Tuple[Tensor, Tensor, Tensor, Tensor, Tensor], qfunc1:Module, qfunc2:Module, qfunc1_target:Module, qfunc2_target:Module, policy:Module, act_limit:Union[float, int], target_noise:Optional[float]=0.2, noise_clip:Optional[float]=0.5, gamma:Optional[float]=0.99)

Calculate the Q-function loss for a TD3 agent. See the paper: https://arxiv.org/abs/1802.09477

Args:

  • data (tuple of torch.Tensor): Input data batch of five tensors: (states, next_states, actions, rewards, dones).
  • qfunc1 (nn.Module): First Q-function network being trained.
  • qfunc2 (nn.Module): Second Q-function network being trained.
  • qfunc1_target (nn.Module): First Q-function target network.
  • qfunc2_target (nn.Module): Second Q-function target network.
  • policy (nn.Module): Policy network.
  • act_limit (float or int): Action limit from the environment.
  • target_noise (float): Scale of the smoothing noise added to the target actions.
  • noise_clip (float): Clip the noise to within ± this range.
  • gamma (float): Discount factor.

Returns:

  • loss_q (torch.Tensor): TD3 loss for the Q-function.
  • loss_info (dict): Dictionary containing useful loss info for logging.
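
A sketch of how the TD3 target is typically formed, assuming the standard recipe from the paper: target policy smoothing with clipped Gaussian noise, plus clipped double-Q learning. The call signatures are assumptions, not the library's exact API; both Q-networks are then regressed onto this target with MSE, as in the DDPG sketch above.

```python
import torch

def td3_target_sketch(next_states, rewards, dones, qfunc1_target, qfunc2_target,
                      policy, act_limit, target_noise=0.2, noise_clip=0.5, gamma=0.99):
    with torch.no_grad():
        # Target policy smoothing: perturb the target actions with clipped noise.
        next_actions = policy(next_states)
        noise = torch.clamp(target_noise * torch.randn_like(next_actions),
                            -noise_clip, noise_clip)
        next_actions = torch.clamp(next_actions + noise, -act_limit, act_limit)
        # Clipped double-Q: bootstrap from the smaller of the two target critics.
        q_target = torch.min(qfunc1_target(next_states, next_actions),
                             qfunc2_target(next_states, next_actions))
        return rewards + gamma * (1 - dones) * q_target
```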

sac_policy_loss[source]

sac_policy_loss(states:Tensor, qfunc1:Module, qfunc2:Module, policy:Module, alpha:Optional[float]=0.2)

Calculate the policy loss for a Soft Actor-Critic (SAC) agent. See the paper: https://arxiv.org/abs/1801.01290

Args:

  • states (torch.Tensor): Input states for the policy.
  • qfunc1 (nn.Module): First Q-function in SAC agent.
  • qfunc2 (nn.Module): Second Q-function in SAC agent.
  • policy (nn.Module): Policy network.
  • alpha (float): Entropy-regularization coefficient for the policy loss.

Returns:

  • loss_policy (torch.Tensor): The policy loss term.
  • policy_info (dict): Useful logging info for the policy.
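
A sketch of the entropy-regularized policy objective, assuming the policy returns sampled actions together with their log-probabilities (a common convention, not necessarily this library's exact API); the contents of policy_info here are hypothetical:

```python
import torch

def sac_policy_loss_sketch(states, qfunc1, qfunc2, policy, alpha=0.2):
    actions, logps = policy(states)
    # Pessimistic value estimate from the twin critics.
    q_min = torch.min(qfunc1(states, actions), qfunc2(states, actions))
    # Entropy-regularized objective: trade off value against policy entropy.
    loss_policy = (alpha * logps - q_min).mean()
    policy_info = {"logp_pi": logps.detach()}  # hypothetical logging payload
    return loss_policy, policy_info
```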

sac_qfunc_loss[source]

sac_qfunc_loss(data, qfunc1:Module, qfunc2:Module, qfunc1_target:Module, qfunc2_target:Module, policy:Module, gamma:Optional[float]=0.99, alpha:Optional[float]=0.2)

Q-function loss for a Soft Actor-Critic (SAC) agent. See the paper: https://arxiv.org/abs/1801.01290

Args:

  • data (tuple of torch.Tensor): Input data batch of five tensors: (states, next_states, actions, rewards, dones).
  • qfunc1 (nn.Module): First Q-function network being trained.
  • qfunc2 (nn.Module): Second Q-function network being trained.
  • qfunc1_target (nn.Module): First Q-function target network.
  • qfunc2_target (nn.Module): Second Q-function target network.
  • policy (nn.Module): Policy network.
  • gamma (float): Discount factor.
  • alpha (float): Entropy-regularization coefficient for the Q-function targets.

Returns:

  • loss_q (torch.Tensor): SAC loss for the Q-function.
  • loss_info (dict): Dictionary containing useful loss info for logging.
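
A sketch of the entropy-regularized critic target, under the same assumptions as the policy-loss sketch above (the policy returns actions plus log-probabilities, and the critics take (states, actions)). Both Q-networks are then regressed onto this target with MSE:

```python
import torch

def sac_target_sketch(next_states, rewards, dones, qfunc1_target, qfunc2_target,
                      policy, gamma=0.99, alpha=0.2):
    with torch.no_grad():
        next_actions, next_logps = policy(next_states)
        # Pessimistic target value from the twin target critics, minus the entropy term.
        q_min = torch.min(qfunc1_target(next_states, next_actions),
                          qfunc2_target(next_states, next_actions))
        return rewards + gamma * (1 - dones) * (q_min - alpha * next_logps)
```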