actor_critic_value_loss
[source]
actor_critic_value_loss(value_estimates: Tensor, env_returns: Tensor)
Loss for an actor-critic value function. This is simply the mean-squared error between the value estimates and the real returns.
Args:
- value_estimates (torch.Tensor): Estimates of state-value from the critic network.
- env_returns (torch.Tensor): Real returns from the environment.
Returns:
- value_loss (torch.Tensor): MSE loss between the value estimates and the real returns.
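Since the documented behavior is just mean-squared error, a minimal PyTorch sketch of an equivalent computation (illustrative only, not necessarily the library's exact implementation) is:

```python
import torch
import torch.nn.functional as F

def value_loss_sketch(value_estimates: torch.Tensor, env_returns: torch.Tensor) -> torch.Tensor:
    # MSE between the critic's state-value estimates and the observed returns.
    return F.mse_loss(value_estimates, env_returns)
```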
reinforce_policy_loss
[source]
reinforce_policy_loss(logps: Tensor, env_returns: Tensor)
REINFORCE policy gradient loss: $-\log \pi(a_t \mid s_t) \cdot R_t$
Args:
- logps (torch.Tensor): Action log-probabilities.
- env_returns (torch.Tensor): Returns from the environment.
Returns:
- reinforce_loss (torch.Tensor): REINFORCE loss term.
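An equivalent sketch of the documented formula (illustrative, assuming a simple mean over the batch):

```python
import torch

def reinforce_loss_sketch(logps: torch.Tensor, env_returns: torch.Tensor) -> torch.Tensor:
    # Score-function estimator: weight each log-probability by its return,
    # negated so that minimizing the loss performs gradient ascent on return.
    return -(logps * env_returns).mean()
```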
a2c_policy_loss
[source]
a2c_policy_loss(logps: Tensor, advs: Tensor)
Loss function for an A2C policy: $-\log \pi(a_t \mid s_t) \cdot A_t$
Args:
- logps (torch.Tensor): Log-probabilities of selected actions.
- advs (torch.Tensor): Advantage estimates of selected actions.
Returns:
- a2c_loss (torch.Tensor): A2C loss term.
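A minimal sketch of the documented formula, assuming the advantages are treated as constants (detached) by the caller:

```python
import torch

def a2c_loss_sketch(logps: torch.Tensor, advs: torch.Tensor) -> torch.Tensor:
    # Policy-gradient term weighted by the advantage estimate.
    return -(logps * advs).mean()
```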
ppo_clip_policy_loss
[source]
ppo_clip_policy_loss(logps: Tensor, logps_old: Tensor, advs: Tensor, clipratio: Optional[float] = 0.2)
Loss function for a PPO-clip policy. See the paper for the full loss function: https://arxiv.org/abs/1707.06347
Args:
- logps (torch.Tensor): Action log-probabilities under the current policy.
- logps_old (torch.Tensor): Action log-probabilities under the old (pre-update) policy.
- advs (torch.Tensor): Advantage estimates for the actions taken.
- clipratio (float): Clipping parameter for the PPO-clip loss. The default is generally fine.
Returns:
- ppo_loss (torch.Tensor): Loss term for PPO agent.
- kl (torch.Tensor): KL-divergence estimate between new and old policies.
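A sketch of the clipped surrogate objective from the paper, paired with a simple sample-based KL estimate (illustrative only; the library's exact KL estimator may differ):

```python
import torch

def ppo_clip_loss_sketch(logps, logps_old, advs, clipratio=0.2):
    # Probability ratio pi_new(a|s) / pi_old(a|s).
    ratio = torch.exp(logps - logps_old)
    # Clipped surrogate objective: take the pessimistic (min) branch.
    clipped_adv = torch.clamp(ratio, 1 - clipratio, 1 + clipratio) * advs
    ppo_loss = -torch.min(ratio * advs, clipped_adv).mean()
    # Crude sample-based KL estimate, useful for logging or early stopping.
    kl = (logps_old - logps).mean()
    return ppo_loss, kl
```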
ddpg_policy_loss
[source]
ddpg_policy_loss(states: Tensor, qfunc: Module, policy: Module)
Policy loss function for a DDPG agent. See the paper: https://arxiv.org/abs/1509.02971
Args:
- states (torch.Tensor): States to get Q-policy estimates for.
- qfunc (nn.Module): Q-function network.
- policy (nn.Module): Policy network.
Returns:
- q_policy_loss (torch.Tensor): Loss term for DDPG policy.
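A sketch of the usual DDPG policy objective (maximize Q along the policy's actions); the qfunc(states, actions) call convention is an assumption about the critic's forward signature:

```python
import torch
import torch.nn as nn

def ddpg_policy_loss_sketch(states: torch.Tensor, qfunc: nn.Module, policy: nn.Module) -> torch.Tensor:
    # Assumed call convention: the critic takes (state, action) pairs.
    q_policy = qfunc(states, policy(states))
    # Minimizing the negative mean Q performs gradient ascent on Q.
    return -q_policy.mean()
```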
ddpg_qfunc_loss
[source]
ddpg_qfunc_loss(data: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor], qfunc: Module, qfunc_target: Module, policy_target: Module, gamma: Optional[float] = 0.99)
Loss for a DDPG Q-function. See the paper: https://arxiv.org/abs/1509.02971
Args:
- data (tuple of torch.Tensor): Input data batch of five tensors: (states, next_states, actions, rewards, dones).
- qfunc (nn.Module): Q-function network being trained.
- qfunc_target (nn.Module): Q-function target network.
- policy_target (nn.Module): Policy target network.
- gamma (float): Discount factor.
Returns:
- loss_q (torch.Tensor): DDPG loss for the Q-function.
- loss_info (dict): Dictionary containing useful loss info for logging.
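A sketch of the standard DDPG critic update described above; the qfunc(states, actions) call convention and the loss_info key are assumptions made for illustration:

```python
import torch
import torch.nn.functional as F

def ddpg_qfunc_loss_sketch(data, qfunc, qfunc_target, policy_target, gamma=0.99):
    states, next_states, actions, rewards, dones = data
    q = qfunc(states, actions)  # Q(s, a) under the critic being trained
    with torch.no_grad():
        # Bellman backup computed with the target networks.
        q_next = qfunc_target(next_states, policy_target(next_states))
        backup = rewards + gamma * (1 - dones) * q_next
    loss_q = F.mse_loss(q, backup)
    loss_info = {"QVals": q.detach().mean().item()}  # hypothetical logging key
    return loss_q, loss_info
```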
td3_policy_loss
[source]
td3_policy_loss(states: Tensor, qfunc: Module, policy: Module)
Calculate the policy loss for a TD3 agent. See the paper: https://arxiv.org/abs/1802.09477
Args:
- states (torch.Tensor): Input states to compute the policy loss for.
- qfunc (nn.Module): TD3 Q-function network.
- policy (nn.Module): Policy network.
Returns:
- q_policy_loss (torch.Tensor): The TD3 policy loss term.
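The TD3 policy objective has the same form as DDPG's, typically evaluated with the first critic; a minimal sketch under the same assumed qfunc(states, actions) convention:

```python
import torch
import torch.nn as nn

def td3_policy_loss_sketch(states: torch.Tensor, qfunc: nn.Module, policy: nn.Module) -> torch.Tensor:
    # Ascend Q along the deterministic policy's actions (loss is negative mean Q).
    return -qfunc(states, policy(states)).mean()
```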
td3_qfunc_loss
[source]
td3_qfunc_loss(data: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor], qfunc1: Module, qfunc2: Module, qfunc1_target: Module, qfunc2_target: Module, policy: Module, act_limit: Union[float, int], target_noise: Optional[float] = 0.2, noise_clip: Optional[float] = 0.5, gamma: Optional[float] = 0.99)
Calculate the Q-function loss for a TD3 agent. See the paper: https://arxiv.org/abs/1802.09477
Args:
- data (tuple of torch.Tensor): Input data batch of five tensors: (states, next_states, actions, rewards, dones).
- qfunc1 (nn.Module): First Q-function network being trained.
- qfunc2 (nn.Module): Other Q-function network being trained.
- qfunc1_target (nn.Module): First Q-function target network.
- qfunc2_target (nn.Module): Other Q-function target network.
- policy (nn.Module): Policy network.
- act_limit (float or int): Action limit from the environment.
- target_noise (float): Noise to apply to policy target network.
- noise_clip (float): Clip the target-policy noise to within ± this value.
- gamma (float): Discount factor.
Returns:
- loss_q (torch.Tensor): TD3 loss for the Q-function.
- loss_info (dict): Dictionary containing useful loss info for logging.
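A sketch of the clipped double-Q target with target-policy smoothing from the TD3 paper; the qfunc(states, actions) convention, the use of the passed-in policy for the target action, and the loss_info keys are assumptions:

```python
import torch
import torch.nn.functional as F

def td3_qfunc_loss_sketch(data, qfunc1, qfunc2, qfunc1_target, qfunc2_target, policy,
                          act_limit, target_noise=0.2, noise_clip=0.5, gamma=0.99):
    states, next_states, actions, rewards, dones = data
    q1 = qfunc1(states, actions)
    q2 = qfunc2(states, actions)
    with torch.no_grad():
        # Target-policy smoothing: perturb the target action with clipped noise,
        # then clamp back to the environment's action limits.
        next_actions = policy(next_states)
        noise = torch.clamp(target_noise * torch.randn_like(next_actions),
                            -noise_clip, noise_clip)
        next_actions = torch.clamp(next_actions + noise, -act_limit, act_limit)
        # Clipped double-Q: bootstrap from the smaller of the two target critics.
        q_targ = torch.min(qfunc1_target(next_states, next_actions),
                           qfunc2_target(next_states, next_actions))
        backup = rewards + gamma * (1 - dones) * q_targ
    loss_q = F.mse_loss(q1, backup) + F.mse_loss(q2, backup)
    loss_info = {"Q1Vals": q1.detach().mean().item(),   # hypothetical logging keys
                 "Q2Vals": q2.detach().mean().item()}
    return loss_q, loss_info
```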
sac_policy_loss
[source]
sac_policy_loss(states: Tensor, qfunc1: Module, qfunc2: Module, policy: Module, alpha: Optional[float] = 0.2)
Calculate the policy loss for a Soft Actor-Critic agent. See the paper: https://arxiv.org/abs/1801.01290
Args:
- states (torch.Tensor): Input states for the policy.
- qfunc1 (nn.Module): First Q-function in SAC agent.
- qfunc2 (nn.Module): Second Q-function in SAC agent.
- policy (nn.Module): Policy network.
- alpha (float): Entropy-regularization coefficient for the policy loss.
Returns:
- loss_policy (torch.Tensor): The policy loss term.
- policy_info (dict): Useful logging info for the policy.
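A sketch of the entropy-regularized SAC policy objective; it assumes the policy returns a reparameterized action sample together with its log-probability, which may not match the library's policy interface, and the policy_info key is hypothetical:

```python
import torch

def sac_policy_loss_sketch(states, qfunc1, qfunc2, policy, alpha=0.2):
    # Assumed policy interface: returns (sampled_action, log_prob).
    actions, logps = policy(states)
    # Take the pessimistic estimate across the two critics.
    q = torch.min(qfunc1(states, actions), qfunc2(states, actions))
    # Maximize Q - alpha * log pi, expressed as a loss to minimize.
    loss_policy = (alpha * logps - q).mean()
    policy_info = {"LogPi": logps.detach().mean().item()}  # hypothetical logging key
    return loss_policy, policy_info
```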
sac_qfunc_loss
[source]
sac_qfunc_loss(data, qfunc1: Module, qfunc2: Module, qfunc1_target: Module, qfunc2_target: Module, policy: Module, gamma: Optional[float] = 0.99, alpha: Optional[float] = 0.2)
Q-function loss for a Soft Actor-Critic agent. See the paper: https://arxiv.org/abs/1801.01290
Args:
- data (tuple of torch.Tensor): Input data batch of five tensors: (states, next_states, actions, rewards, dones).
- qfunc1 (nn.Module): First Q-function network being trained.
- qfunc2 (nn.Module): Other Q-function network being trained.
- qfunc1_target (nn.Module): First Q-function target network.
- qfunc2_target (nn.Module): Other Q-function target network.
- policy (nn.Module): Policy network.
- gamma (float): Discount factor.
- alpha (float): Entropy-regularization coefficient for the loss term.
Returns:
- loss_q (torch.Tensor): SAC loss for the Q-function.
- loss_info (dict): Dictionary containing useful loss info for logging.
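A sketch of the entropy-regularized Bellman backup used by SAC; the (state, action) critic convention, the policy interface, and the logging keys are assumptions as above:

```python
import torch
import torch.nn.functional as F

def sac_qfunc_loss_sketch(data, qfunc1, qfunc2, qfunc1_target, qfunc2_target,
                          policy, gamma=0.99, alpha=0.2):
    states, next_states, actions, rewards, dones = data
    q1 = qfunc1(states, actions)
    q2 = qfunc2(states, actions)
    with torch.no_grad():
        # Next action and its log-prob sampled from the current policy.
        next_actions, next_logps = policy(next_states)
        q_targ = torch.min(qfunc1_target(next_states, next_actions),
                           qfunc2_target(next_states, next_actions))
        # Entropy-regularized target: subtract alpha * log pi from the bootstrap value.
        backup = rewards + gamma * (1 - dones) * (q_targ - alpha * next_logps)
    loss_q = F.mse_loss(q1, backup) + F.mse_loss(q2, backup)
    loss_info = {"Q1Vals": q1.detach().mean().item(),   # hypothetical logging keys
                 "Q2Vals": q2.detach().mean().item()}
    return loss_q, loss_info
```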