Source code for tianshou.highlevel.env

from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from enum import Enum
from typing import Any, TypeAlias, cast

import gymnasium as gym

from tianshou.env import (
    BaseVectorEnv,
    DummyVectorEnv,
    RayVectorEnv,
    ShmemVectorEnv,
    SubprocVectorEnv,
)
from tianshou.highlevel.persistence import Persistence
from tianshou.utils.net.common import TActionShape
from tianshou.utils.string import ToStringMixin

# Shape of an environment's observations as reported by `Environments.get_observation_shape`:
# either a sequence of dimension sizes (a space's `shape`) or a single int (used when a space
# has no/empty shape and its number of elements `n` is taken instead).
TObservationShape: TypeAlias = int | Sequence[int]


class EnvType(Enum):
    """Enumeration of environment types, distinguishing continuous from discrete environments."""

    CONTINUOUS = "continuous"
    DISCRETE = "discrete"

    def is_discrete(self) -> bool:
        """:return: True if this member is `EnvType.DISCRETE`, False otherwise"""
        return self is EnvType.DISCRETE

    def is_continuous(self) -> bool:
        """:return: True if this member is `EnvType.CONTINUOUS`, False otherwise"""
        return self is EnvType.CONTINUOUS

    def assert_continuous(self, requiring_entity: Any) -> None:
        """Ensures that this is the continuous environment type.

        :param requiring_entity: the entity requiring continuous environments (used in the error message)
        :raises AssertionError: if this is not `EnvType.CONTINUOUS`
        """
        if self is EnvType.CONTINUOUS:
            return
        raise AssertionError(f"{requiring_entity} requires continuous environments")

    def assert_discrete(self, requiring_entity: Any) -> None:
        """Ensures that this is the discrete environment type.

        :param requiring_entity: the entity requiring discrete environments (used in the error message)
        :raises AssertionError: if this is not `EnvType.DISCRETE`
        """
        if self is EnvType.DISCRETE:
            return
        raise AssertionError(f"{requiring_entity} requires discrete environments")
class VectorEnvType(Enum):
    """Enumeration of the ways in which the environments of a vectorized environment can be executed."""

    DUMMY = "dummy"
    """Vectorized environment without parallelization; environments are processed sequentially"""
    SUBPROC = "subproc"
    """Parallelization based on `subprocess`"""
    SUBPROC_SHARED_MEM = "shmem"
    """Parallelization based on `subprocess` with shared memory"""
    RAY = "ray"
    """Parallelization based on the `ray` library"""

    def create_venv(self, factories: list[Callable[[], gym.Env]]) -> BaseVectorEnv:
        """Creates a vectorized environment of this type.

        :param factories: one factory per environment instance to be contained in the vectorized environment
        :return: the vectorized environment
        :raises NotImplementedError: if no vector environment class is registered for this member
        """
        # built locally (not as a class attribute), since assignments in an Enum body become members
        venv_class_by_type: dict[VectorEnvType, type[BaseVectorEnv]] = {
            VectorEnvType.DUMMY: DummyVectorEnv,
            VectorEnvType.SUBPROC: SubprocVectorEnv,
            VectorEnvType.SUBPROC_SHARED_MEM: ShmemVectorEnv,
            VectorEnvType.RAY: RayVectorEnv,
        }
        venv_class = venv_class_by_type.get(self)
        if venv_class is None:
            raise NotImplementedError(self)
        return venv_class(factories)
[docs]class Environments(ToStringMixin, ABC): """Represents (vectorized) environments.""" def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): self.env = env self.train_envs = train_envs self.test_envs = test_envs self.persistence: Sequence[Persistence] = []
[docs] @staticmethod def from_factory_and_type( factory_fn: Callable[[], gym.Env], env_type: EnvType, venv_type: VectorEnvType, num_training_envs: int, num_test_envs: int, ) -> "Environments": """Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete). :param factory_fn: the factory for a single environment instance :param env_type: the type of environments created by `factory_fn` :param venv_type: the vector environment type to use for parallelization :param num_training_envs: the number of training environments to create :param num_test_envs: the number of test environments to create :return: the instance """ train_envs = venv_type.create_venv([factory_fn] * num_training_envs) test_envs = venv_type.create_venv([factory_fn] * num_test_envs) env = factory_fn() match env_type: case EnvType.CONTINUOUS: return ContinuousEnvironments(env, train_envs, test_envs) case EnvType.DISCRETE: return DiscreteEnvironments(env, train_envs, test_envs) case _: raise ValueError(f"Environment type {env_type} not handled")
def _tostring_includes(self) -> list[str]: return [] def _tostring_additional_entries(self) -> dict[str, Any]: return self.info()
[docs] def info(self) -> dict[str, Any]: return { "action_shape": self.get_action_shape(), "state_shape": self.get_observation_shape(), }
[docs] def set_persistence(self, *p: Persistence) -> None: """Associates the given persistence handlers which may persist and restore environment-specific information. :param p: persistence handlers """ self.persistence = p
[docs] @abstractmethod def get_action_shape(self) -> TActionShape: pass
[docs] @abstractmethod def get_observation_shape(self) -> TObservationShape: pass
[docs] def get_action_space(self) -> gym.Space: return self.env.action_space
[docs] def get_observation_space(self) -> gym.Space: return self.env.observation_space
[docs] @abstractmethod def get_type(self) -> EnvType: pass
class ContinuousEnvironments(Environments):
    """Represents (vectorized) continuous environments, i.e. environments with a `Box` action space."""

    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        super().__init__(env, train_envs, test_envs)
        env_info = self._get_continuous_env_info(env)
        self.state_shape, self.action_shape, self.max_action = env_info

    @staticmethod
    def from_factory(
        factory_fn: Callable[[], gym.Env],
        venv_type: VectorEnvType,
        num_training_envs: int,
        num_test_envs: int,
    ) -> "ContinuousEnvironments":
        """Creates an instance from a factory function that creates a single instance.

        :param factory_fn: the factory for a single environment instance
        :param venv_type: the vector environment type to use for parallelization
        :param num_training_envs: the number of training environments to create
        :param num_test_envs: the number of test environments to create
        :return: the instance
        """
        environments = Environments.from_factory_and_type(
            factory_fn,
            EnvType.CONTINUOUS,
            venv_type,
            num_training_envs,
            num_test_envs,
        )
        return cast(ContinuousEnvironments, environments)

    def info(self) -> dict[str, Any]:
        """:return: the base class's info entries extended by "max_action\""""
        return {**super().info(), "max_action": self.max_action}

    @staticmethod
    def _get_continuous_env_info(
        env: gym.Env,
    ) -> tuple[tuple[int, ...], tuple[int, ...], float]:
        # Returns (state_shape, action_shape, max_action); max_action is taken from the first
        # dimension of the action space's upper bound only.
        action_space = env.action_space
        if not isinstance(action_space, gym.spaces.Box):
            raise ValueError(
                "Only environments with continuous action space are supported here. "
                f"But got env with action space: {action_space.__class__}.",
            )
        observation_space = env.observation_space
        # observation spaces without a (non-empty) shape fall back to their number of elements `n`
        state_shape = observation_space.shape or observation_space.n  # type: ignore
        if not state_shape:
            raise ValueError("Observation space shape is not defined")
        return state_shape, action_space.shape, action_space.high[0]

    def get_action_shape(self) -> TActionShape:
        """:return: the shape of the `Box` action space"""
        return self.action_shape

    def get_observation_shape(self) -> TObservationShape:
        """:return: the observation space's shape (or its `n` if it has no shape)"""
        return self.state_shape

    def get_type(self) -> EnvType:
        """:return: `EnvType.CONTINUOUS`"""
        return EnvType.CONTINUOUS
class DiscreteEnvironments(Environments):
    """Represents (vectorized) discrete environments."""

    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        super().__init__(env, train_envs, test_envs)
        # spaces without a (non-empty) shape fall back to their number of elements `n`
        self.observation_shape = env.observation_space.shape or env.observation_space.n  # type: ignore
        self.action_shape = env.action_space.shape or env.action_space.n  # type: ignore

    @staticmethod
    def from_factory(
        factory_fn: Callable[[], gym.Env],
        venv_type: VectorEnvType,
        num_training_envs: int,
        num_test_envs: int,
    ) -> "DiscreteEnvironments":
        """Creates an instance from a factory function that creates a single instance.

        :param factory_fn: the factory for a single environment instance
        :param venv_type: the vector environment type to use for parallelization
        :param num_training_envs: the number of training environments to create
        :param num_test_envs: the number of test environments to create
        :return: the instance
        """
        return cast(
            DiscreteEnvironments,
            Environments.from_factory_and_type(
                factory_fn,
                # must be DISCRETE: CONTINUOUS would construct a ContinuousEnvironments instance,
                # which rejects non-Box action spaces (and is not a DiscreteEnvironments anyway)
                EnvType.DISCRETE,
                venv_type,
                num_training_envs,
                num_test_envs,
            ),
        )

    def get_action_shape(self) -> TActionShape:
        """:return: the action space's shape, or its `n` if it has no (non-empty) shape"""
        return self.action_shape

    def get_observation_shape(self) -> TObservationShape:
        """:return: the observation space's shape, or its `n` if it has no (non-empty) shape"""
        return self.observation_shape

    def get_type(self) -> EnvType:
        """:return: `EnvType.DISCRETE`"""
        return EnvType.DISCRETE
class EnvFactory(ToStringMixin, ABC):
    """Abstract factory for the creation of (vectorized) training and test environments."""

    @abstractmethod
    def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
        """Creates the environments.

        :param num_training_envs: the number of training environments to create
        :param num_test_envs: the number of test environments to create
        :return: the environments
        """