trainer
Source code: tianshou/highlevel/trainer.py
- class TrainerCallbacks(epoch_callback_train: TrainerEpochCallbackTrain | None = None, epoch_callback_test: TrainerEpochCallbackTest | None = None, stop_callback: TrainerStopCallback | None = None)
Container for callbacks used during training.
- epoch_callback_test: TrainerEpochCallbackTest | None = None
- epoch_callback_train: TrainerEpochCallbackTrain | None = None
- stop_callback: TrainerStopCallback | None = None
- class TrainerEpochCallbackTest
Callback which is called at the beginning of the test phase of each epoch.
- abstract callback(epoch: int, env_step: int | None, context: TrainingContext) → None
- get_trainer_fn(context: TrainingContext) → Callable[[int, int | None], None]
- class TrainerEpochCallbackTrain
Callback which is called at the beginning of the training phase of each epoch.
- abstract callback(epoch: int, env_step: int, context: TrainingContext) → None
- get_trainer_fn(context: TrainingContext) → Callable[[int, int], None]
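The relationship between callback and get_trainer_fn can be sketched as follows. This is a minimal illustration that does not import tianshou: ContextStandIn, EpochCallbackTrainStandIn and RecordEpochs are hypothetical stand-ins that only mirror the signatures documented above, and the closure body of get_trainer_fn is an assumption about how the context gets bound, not the library's confirmed implementation.

```python
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass


@dataclass
class ContextStandIn:
    """Hypothetical stand-in for TrainingContext (policy, envs, logger)."""

    policy: object = None
    envs: object = None
    logger: object = None


class EpochCallbackTrainStandIn(ABC):
    """Stand-in mirroring the documented TrainerEpochCallbackTrain signatures."""

    @abstractmethod
    def callback(self, epoch: int, env_step: int, context: ContextStandIn) -> None:
        ...

    def get_trainer_fn(self, context: ContextStandIn) -> Callable[[int, int], None]:
        # Assumption: the context is captured in a closure, so the trainer
        # only has to pass (epoch, env_step) when it invokes the hook.
        def fn(epoch: int, env_step: int) -> None:
            self.callback(epoch, env_step, context)

        return fn


class RecordEpochs(EpochCallbackTrainStandIn):
    """Hypothetical callback that records each (epoch, env_step) pair it receives."""

    def __init__(self) -> None:
        self.seen: list[tuple[int, int]] = []

    def callback(self, epoch: int, env_step: int, context: ContextStandIn) -> None:
        self.seen.append((epoch, env_step))


recorder = RecordEpochs()
train_fn = recorder.get_trainer_fn(ContextStandIn())

# Simulate the trainer invoking the hook at the start of two epochs.
train_fn(1, 0)
train_fn(2, 5000)
```

With the real classes, the same shape would apply: subclass TrainerEpochCallbackTrain, implement callback, and pass an instance as epoch_callback_train to TrainerCallbacks.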
- class TrainerStopCallback
Callback indicating whether training should stop.
- get_trainer_fn(context: TrainingContext) → Callable[[float], bool]
- abstract should_stop(mean_rewards: float, context: TrainingContext) → bool
Determines whether training should stop.
- Parameters:
mean_rewards – the average undiscounted returns of the testing result
context – the training context
- Returns:
True if the goal has been reached and training should stop, False otherwise
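A reward-threshold stop criterion is one plausible use of should_stop. The sketch below does not import tianshou: StopCallbackStandIn and RewardThresholdStop are hypothetical names that mirror the documented signatures, the threshold value is arbitrary, and the closure in get_trainer_fn is an assumed implementation.

```python
from abc import ABC, abstractmethod
from collections.abc import Callable


class StopCallbackStandIn(ABC):
    """Stand-in mirroring the documented TrainerStopCallback signatures."""

    @abstractmethod
    def should_stop(self, mean_rewards: float, context: object) -> bool:
        ...

    def get_trainer_fn(self, context: object) -> Callable[[float], bool]:
        # Assumption: the context is bound here so the trainer can call the
        # returned function with the mean test reward alone.
        def fn(mean_rewards: float) -> bool:
            return self.should_stop(mean_rewards, context)

        return fn


class RewardThresholdStop(StopCallbackStandIn):
    """Hypothetical criterion: stop once the mean test reward reaches a threshold."""

    def __init__(self, threshold: float) -> None:
        self.threshold = threshold

    def should_stop(self, mean_rewards: float, context: object) -> bool:
        return mean_rewards >= self.threshold


# The context is unused by this criterion, so None suffices for the sketch.
stop_fn = RewardThresholdStop(threshold=195.0).get_trainer_fn(context=None)
reached = stop_fn(200.0)
not_reached = stop_fn(150.0)
```

A criterion that needs environment-specific information (for example a reward threshold defined by the environment) could read it from the context argument instead of a constructor parameter.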
- class TrainingContext(policy: TPolicy, envs: Environments, logger: BaseLogger)