Source code for tianshou.policy.modelfree.rainbow

from typing import Any

from torch import nn

from tianshou.data.types import RolloutBatchProtocol
from tianshou.policy import C51Policy
from tianshou.utils.net.discrete import NoisyLinear


# TODO: this is a hacky thing interviewing side-effects and a return. Should improve.
def _sample_noise(model: nn.Module) -> bool:
    """Sample the random noises of NoisyLinear modules in the model.

    Returns True if at least one NoisyLinear submodule was found.

    :param model: a PyTorch module which may have NoisyLinear submodules.
    :returns: True if model has at least one NoisyLinear submodule;
        otherwise, False.
    """
    sampled_any_noise = False
    for m in model.modules():
        if isinstance(m, NoisyLinear):
            m.sample()
            sampled_any_noise = True
    return sampled_any_noise


# TODO: is this class worth keeping? It barely does anything
[docs]class RainbowPolicy(C51Policy): """Implementation of Rainbow DQN. arXiv:1710.02298. Same parameters as :class:`~tianshou.policy.C51Policy`. .. seealso:: Please refer to :class:`~tianshou.policy.C51Policy` for more detailed explanation. """
[docs] def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: _sample_noise(self.model) if self._target and _sample_noise(self.model_old): self.model_old.train() # so that NoisyLinear takes effect return super().learn(batch, **kwargs)