# Source code for tianshou.data.types
import numpy as np
import torch
from tianshou.data import Batch
from tianshou.data.batch import BatchProtocol, arr_type
class ObsBatchProtocol(BatchProtocol):
    """Environment observations that a policy can turn into actions.

    Batches of this shape are typically consumed inside a policy's ``forward``.
    """

    # Either an ``arr_type`` value or a nested batch.
    obs: arr_type | BatchProtocol
    info: arr_type
class RolloutBatchProtocol(ObsBatchProtocol):
    """Batch typically obtained by sampling from a replay buffer.

    On top of the ``obs``/``info`` fields of :class:`ObsBatchProtocol` it
    declares ``obs_next``, ``act``, ``rew``, ``terminated`` and ``truncated``.
    """

    obs_next: arr_type | BatchProtocol
    act: arr_type
    # NOTE: unlike the neighbouring fields, ``rew`` is annotated as NumPy-only.
    rew: np.ndarray
    terminated: arr_type
    truncated: arr_type
class BatchWithReturnsProtocol(RolloutBatchProtocol):
    """Rollout batch extended with a ``returns`` field.

    The returns are usually computed with GAE.
    """

    returns: arr_type
class PrioBatchProtocol(RolloutBatchProtocol):
    """Rollout batch carrying ``weight`` values usable for prioritized replay."""

    # Either NumPy or torch storage is accepted.
    weight: np.ndarray | torch.Tensor
class RecurrentStateBatch(BatchProtocol):
    """Recurrent state used by RNNs inside policies.

    Holds the ``hidden`` and ``cell`` tensors.
    """

    hidden: torch.Tensor
    cell: torch.Tensor
class ActBatchProtocol(BatchProtocol):
    """The simplest batch: it holds nothing but the action ``act``.

    Sufficient, for instance, for a random policy.
    """

    act: arr_type
class ActStateBatchProtocol(ActBatchProtocol):
    """Action batch that also carries a ``state``, which may be ``None``.

    Useful for policies that can support RNNs.
    """

    state: dict | BatchProtocol | np.ndarray | None
class ModelOutputBatchProtocol(ActStateBatchProtocol):
    """Action/state batch that additionally holds the model output (``logits``).

    ``act`` and ``state`` (including its ``None`` option) are inherited from
    :class:`ActStateBatchProtocol`. ``state`` is deliberately not redeclared
    here, so that there is a single annotation for it that cannot drift out
    of sync.
    """

    logits: torch.Tensor
class FQFBatchProtocol(ModelOutputBatchProtocol):
    """Model-output batch specific to the FQF model.

    Adds the ``fractions`` and ``quantiles_tau`` tensors.
    """

    fractions: torch.Tensor
    quantiles_tau: torch.Tensor
class BatchWithAdvantagesProtocol(BatchWithReturnsProtocol):
    """Batch with returns plus estimated advantages (``adv``) and values (``v_s``).

    The returns are usually obtained from the GAE advantages by adding the
    value estimate.
    """

    adv: torch.Tensor
    v_s: torch.Tensor
class DistBatchProtocol(ModelOutputBatchProtocol):
    """Model-output batch holding action distributions in ``dist``.

    The distributions are created by ``dist_fn`` and are usually categorical
    or normal.
    """

    dist: torch.distributions.Distribution
class DistLogProbBatchProtocol(DistBatchProtocol):
    """Distribution batch that also stores ``log_prob`` of the action taken.

    The ``dist`` objects it carries can be sampled from.
    """

    log_prob: torch.Tensor
class LogpOldProtocol(BatchWithAdvantagesProtocol):
    """Advantage/value batch extended with ``logp_old``.

    ``logp_old`` is frequently needed for importance weights, most notably
    in PPO.
    """

    logp_old: torch.Tensor
class QuantileRegressionBatchProtocol(ModelOutputBatchProtocol):
    """Model-output batch with ``taus`` for quantile-regression algorithms.

    For an example, see https://arxiv.org/abs/1806.06923.
    """

    taus: torch.Tensor
class ImitationBatchProtocol(ActBatchProtocol):
    """Action batch that also contains ``imitation_logits`` and ``q_value`` fields.

    This class derives from :class:`ActBatchProtocol` rather than
    :class:`ActStateBatchProtocol`, so ``state`` is declared here. Its
    annotation uses ``BatchProtocol`` (not the concrete ``Batch``) so that it
    matches the ``state`` annotation used by the other action/state batches.
    """

    # NOTE(review): assumes the concrete ``Batch`` implements ``BatchProtocol``,
    # so values previously typed as ``Batch`` remain valid.
    state: dict | BatchProtocol | np.ndarray | None
    q_value: torch.Tensor
    imitation_logits: torch.Tensor