trainer#


class TrainerCallbacks(epoch_callback_train: TrainerEpochCallbackTrain | None = None, epoch_callback_test: TrainerEpochCallbackTest | None = None, stop_callback: TrainerStopCallback | None = None)[source]#

Container for callbacks used during training.

epoch_callback_test: TrainerEpochCallbackTest | None = None#
epoch_callback_train: TrainerEpochCallbackTrain | None = None#
stop_callback: TrainerStopCallback | None = None#
class TrainerEpochCallbackTest[source]#

Callback which is called at the beginning of the test phase of each epoch.

abstract callback(epoch: int, env_step: int | None, context: TrainingContext) -> None[source]#
get_trainer_fn(context: TrainingContext) -> Callable[[int, int | None], None][source]#
class TrainerEpochCallbackTrain[source]#

Callback which is called at the beginning of the training phase of each epoch.

abstract callback(epoch: int, env_step: int, context: TrainingContext) -> None[source]#
get_trainer_fn(context: TrainingContext) -> Callable[[int, int], None][source]#
class TrainerStopCallback[source]#

Callback indicating whether training should stop.

get_trainer_fn(context: TrainingContext) -> Callable[[float], bool][source]#
abstract should_stop(mean_rewards: float, context: TrainingContext) -> bool[source]#

Determines whether training should stop.

Parameters:
  • mean_rewards – the average undiscounted return achieved in testing

  • context – the training context

Returns:

True if the goal has been reached and training should stop, False otherwise

class TrainingContext(policy: TPolicy, envs: Environments, logger: BaseLogger)[source]#