# Source code for tianshou.highlevel.trainer

from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeVar

from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import TLogger
from tianshou.policy import BasePolicy
from tianshou.utils.string import ToStringMixin

TPolicy = TypeVar("TPolicy", bound=BasePolicy)


class TrainingContext:
    """Bundle of objects that trainer callbacks are given access to.

    :param policy: the policy instance being trained
    :param envs: the :class:`Environments` object of the experiment
    :param logger: the logger of the experiment
    """

    def __init__(self, policy: TPolicy, envs: Environments, logger: TLogger):
        # Stored as plain public attributes so that callbacks can read them directly.
        self.logger = logger
        self.envs = envs
        self.policy = policy
class TrainerEpochCallbackTrain(ToStringMixin, ABC):
    """Callback which is called at the beginning of each epoch."""

    @abstractmethod
    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        """Performs the callback's action.

        :param epoch: the current epoch
        :param env_step: the environment step count supplied by the trainer
        :param context: the training context
        """

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int], None]:
        """Adapts :meth:`callback` to the two-argument signature used by the trainer.

        :param context: the training context that is forwarded on every invocation
        :return: a function ``(epoch, env_step) -> None`` which delegates to :meth:`callback`
        """

        def trainer_fn(epoch: int, env_step: int) -> None:
            # `self.callback` is looked up at call time, not when the wrapper is created.
            return self.callback(epoch, env_step, context)

        return trainer_fn
class TrainerEpochCallbackTest(ToStringMixin, ABC):
    """Callback which is called at the beginning of the test phase of each epoch.

    Unlike :class:`TrainerEpochCallbackTrain`, the environment step passed to this
    callback may be ``None``.
    """

    @abstractmethod
    def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
        """Performs the callback's action.

        :param epoch: the current epoch
        :param env_step: the environment step count supplied by the trainer, or ``None``
            if the trainer does not provide one
        :param context: the training context
        """

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int | None], None]:
        """Adapts :meth:`callback` to the two-argument signature used by the trainer.

        :param context: the training context that is forwarded on every invocation
        :return: a function ``(epoch, env_step) -> None`` which delegates to :meth:`callback`
        """

        def fn(epoch: int, env_step: int | None) -> None:
            return self.callback(epoch, env_step, context)

        return fn
class TrainerStopCallback(ToStringMixin, ABC):
    """Callback deciding, based on test results, whether training should be terminated."""

    @abstractmethod
    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
        """Determines whether training should stop.

        :param mean_rewards: the average undiscounted returns of the testing result
        :param context: the training context
        :return: True if the goal has been reached and training should stop, False otherwise
        """

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[float], bool]:
        """Adapts :meth:`should_stop` to the single-argument signature used by the trainer.

        :param context: the training context that is forwarded on every invocation
        :return: a function ``(mean_rewards) -> bool`` which delegates to :meth:`should_stop`
        """

        def stop_fn(mean_rewards: float) -> bool:
            # Late lookup of `self.should_stop`, exactly as in a direct method call.
            return self.should_stop(mean_rewards, context)

        return stop_fn
@dataclass
class TrainerCallbacks:
    """Optional set of callbacks to be hooked into the training process.

    Every field defaults to ``None``, meaning that the respective hook is not used.
    """

    # invoked at the start of every epoch
    epoch_callback_train: TrainerEpochCallbackTrain | None = None
    # invoked by way of the trainer's test hook
    epoch_callback_test: TrainerEpochCallbackTest | None = None
    # decides whether training should terminate early
    stop_callback: TrainerStopCallback | None = None