# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""Wrapper to (lazily) construct a multitask environment from a list of
constructors (list of functions to construct the environments)."""
from typing import Callable, List, Optional
from gym.core import Env
from gym.spaces.discrete import Discrete as DiscreteSpace
from mtenv import MTEnv
from mtenv.utils import seeding
from mtenv.utils.types import ActionType, EnvObsType, ObsType, StepReturnType
EnvBuilderType = Callable[[], Env]
TaskStateType = int
TaskObsType = int
[docs]class MultiEnvWrapper(MTEnv):
def __init__(
self,
funcs_to_make_envs: List[EnvBuilderType],
initial_task_state: TaskStateType,
) -> None:
"""Wrapper to (lazily) construct a multitask environment from a
list of constructors (list of functions to construct the
environments).
The wrapper enables activating/slecting any environment (from the
list of environments that can be created) and that environment is
treated as the current task. The environments are created lazily.
Note that this wrapper is experimental and may change in the future.
Args:
funcs_to_make_envs (List[EnvBuilderType]): list of constructor
functions to make the environments.
initial_task_state (TaskStateType): intial task/environment
to select.
"""
self._num_tasks = len(funcs_to_make_envs)
self._funcs_to_make_envs = funcs_to_make_envs
self._envs = [None for _ in range(self._num_tasks)]
self._envs[initial_task_state] = funcs_to_make_envs[initial_task_state]()
self.env: Env = self._envs[initial_task_state]
super().__init__(
action_space=self.env.action_space,
env_observation_space=self.env.observation_space,
task_observation_space=DiscreteSpace(n=self._num_tasks),
)
self.task_obs: TaskObsType = initial_task_state
def _make_observation(self, env_obs: EnvObsType) -> ObsType:
return {
"env_obs": env_obs,
"task_obs": self.task_obs,
}
[docs] def step(self, action: ActionType) -> StepReturnType:
env_obs, reward, done, info = self.env.step(action)
return self._make_observation(env_obs=env_obs), reward, done, info
[docs] def get_task_state(self) -> TaskStateType:
return self.task_obs
[docs] def set_task_state(self, task_state: TaskStateType) -> None:
self.task_obs = task_state
if self._envs[task_state] is None:
self._envs[task_state] = self._funcs_to_make_envs[task_state]()
self.env = self._envs[task_state]
[docs] def assert_env_seed_is_set(self) -> None:
"""The seed is set during the call to the constructor of self.env"""
pass
[docs] def assert_task_seed_is_set(self) -> None:
assert self.np_random_task is not None, "please call `seed_task()` first"
[docs] def reset(self) -> ObsType:
return self._make_observation(env_obs=self.env.reset())
[docs] def sample_task_state(self) -> TaskStateType:
self.assert_task_seed_is_set()
task_state = self.np_random_task.randint(self._num_tasks) # type: ignore[union-attr]
# The assert statement (at the start of the function) ensures that self.np_random_task
# is not None. Mypy is raising the warning incorrectly.
assert isinstance(task_state, int)
return task_state
[docs] def reset_task_state(self) -> None:
self.set_task_state(task_state=self.sample_task_state())
[docs] def seed(self, seed: Optional[int] = None) -> List[int]:
self.np_random_env, seed = seeding.np_random(seed)
env_seeds = self.env.seed(seed)
if isinstance(env_seeds, list):
return [seed] + env_seeds
return [seed]