rainbow
Source code: tianshou/policy/modelfree/rainbow.py
- class RainbowPolicy(*, model: Module, optim: Optimizer, action_space: Discrete, discount_factor: float = 0.99, num_atoms: int = 51, v_min: float = -10.0, v_max: float = 10.0, estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, is_double: bool = True, clip_loss_grad: bool = False, observation_space: gymnasium.spaces.space.Space | None = None, lr_scheduler: torch.optim.lr_scheduler.LRScheduler | MultipleLRSchedulers | None = None)
Implementation of Rainbow DQN. arXiv:1710.02298.
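The num_atoms, v_min, and v_max parameters in the signature describe a categorical return distribution in the C51 style. The following is a minimal sketch in plain PyTorch, not Tianshou's internal code, of how such a fixed support can turn per-action atom probabilities into Q-values. The tensor shapes, batch_size, n_actions, and the random logits are assumptions made for illustration only.

```python
import torch

# Values mirror the defaults shown in the class signature.
num_atoms, v_min, v_max = 51, -10.0, 10.0

# Fixed support: num_atoms evenly spaced return values in [v_min, v_max].
support = torch.linspace(v_min, v_max, num_atoms)
delta_z = (v_max - v_min) / (num_atoms - 1)  # spacing between neighbouring atoms

# Hypothetical network output: logits over atoms for every action,
# shape [batch_size, n_actions, num_atoms].
batch_size, n_actions = 4, 3
logits = torch.randn(batch_size, n_actions, num_atoms)
probs = logits.softmax(dim=-1)  # each action's atom probabilities sum to 1

# Expected Q-value per action: probability-weighted sum over the support.
q_values = (probs * support).sum(dim=-1)  # shape [batch_size, n_actions]
greedy_actions = q_values.argmax(dim=-1)  # shape [batch_size]
```

Because every Q-value is a convex combination of the atoms, it always lies within [v_min, v_max], so these bounds should cover the range of returns expected in the environment.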
Same parameters as C51Policy.
See also
Please refer to C51Policy for more detailed explanation.
- learn(batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]
Update policy with a given batch of data.
- Returns:
A dict containing the data to be logged (e.g., loss).
Note
To distinguish the collecting, updating, and testing states, check the policy state via self.training and self.updating. Please refer to States for policy for more detailed explanation.
Warning
If you use torch.distributions.Normal and torch.distributions.Categorical to calculate the log_prob, be careful about the shape: the Categorical distribution gives a “[batch_size]” shape, while the Normal distribution gives a “[batch_size, 1]” shape. Automatic broadcasting in arithmetic on torch tensors will silently amplify this mismatch.
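The shape pitfall described in this warning can be reproduced in a few lines. This is a small sketch assuming a batch of 4 samples, 3 discrete actions, and a 1-dimensional continuous action. The advantages tensor is a hypothetical stand-in for any per-sample quantity of shape [batch_size].

```python
import torch
from torch.distributions import Categorical, Normal

batch_size = 4

# Discrete case: one categorical distribution per sample over 3 actions.
cat = Categorical(logits=torch.zeros(batch_size, 3))
cat_logp = cat.log_prob(torch.zeros(batch_size, dtype=torch.long))  # shape [4]

# Continuous case: a 1-dimensional Gaussian action per sample.
normal = Normal(torch.zeros(batch_size, 1), torch.ones(batch_size, 1))
normal_logp = normal.log_prob(torch.zeros(batch_size, 1))  # shape [4, 1]

# Hypothetical per-sample weights, e.g. advantages, with shape [4].
advantages = torch.ones(batch_size)

# Pitfall: [4, 1] * [4] broadcasts to [4, 4] without raising an error.
wrong = normal_logp * advantages

# One possible fix: drop the trailing dimension so both operands are [4].
right = normal_logp.squeeze(-1) * advantages
```

Here wrong has 16 elements instead of 4, so any mean or sum taken over it mixes every sample with every other sample. Checking the shapes before combining log-probabilities with other per-sample tensors is a cheap safeguard.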