# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""Wrapper to convert an environment into multitask environment."""
from typing import Any, Dict, List, Optional
from gym.core import Env
from gym.spaces.space import Space
from mtenv import MTEnv
from mtenv.utils import seeding
from mtenv.utils.types import (
ActionType,
EnvObsType,
ObsType,
StepReturnType,
TaskObsType,
TaskStateType,
)
class EnvToMTEnv(MTEnv):
    """Multitask wrapper around a single (non-multitask) environment.

    Observations returned by ``reset`` and ``step`` are dicts with the keys
    ``"env_obs"`` (the wrapped environment's observation) and ``"task_obs"``
    (the value returned by :meth:`get_task_obs`). The task-state methods
    (``get_task_state``, ``set_task_state``, ``sample_task_state``) raise
    ``NotImplementedError`` here and are meant to be overridden.

    Attribute lookups that fail on the wrapper are delegated to the wrapped
    environment, except for names starting with ``_``.
    """

    def __init__(self, env: Env, task_observation_space: Space) -> None:
        """Wrapper to convert an environment into a multitask environment.

        Args:
            env (Env): Environment to wrap over.
            task_observation_space (Space): Task observation space for the
                resulting multitask environment.
        """
        super().__init__(
            action_space=env.action_space,
            env_observation_space=env.observation_space,
            task_observation_space=task_observation_space,
        )
        self.env = env
        # Mirror the wrapped env's reward range and metadata on the wrapper
        # so they are read directly rather than through __getattr__.
        self.reward_range = self.env.reward_range
        self.metadata = self.env.metadata

    @property
    def spec(self) -> Any:
        """The wrapped environment's ``spec``."""
        return self.env.spec

    @classmethod
    def class_name(cls) -> str:
        """Return the name of the class this is called on."""
        return cls.__name__

    def _make_observation(self, env_obs: EnvObsType) -> ObsType:
        """Pair an environment observation with the current task observation."""
        return {"env_obs": env_obs, "task_obs": self.get_task_obs()}

    def get_task_obs(self) -> TaskObsType:
        """Return the current task observation (``self._task_obs``).

        NOTE(review): ``_task_obs`` is never assigned in this class; it is
        presumably set by a subclass (e.g. in ``set_task_state``) — confirm.
        If it has not been set, the lookup falls through to ``__getattr__``,
        which raises ``AttributeError`` for private names.
        """
        return self._task_obs

    def get_task_state(self) -> TaskStateType:
        """Return the current task state. Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base wrapper.
        """
        raise NotImplementedError

    def set_task_state(self, task_state: TaskStateType) -> None:
        """Set the current task state. Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base wrapper.
        """
        raise NotImplementedError

    def sample_task_state(self) -> TaskStateType:
        """Sample a new task state. Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base wrapper.
        """
        raise NotImplementedError

    def reset(self, **kwargs: Any) -> ObsType:
        """Reset the wrapped environment and return a multitask observation.

        Args:
            **kwargs: Forwarded unchanged to ``self.env.reset``.

        Returns:
            ObsType: ``{"env_obs": <wrapped env obs>, "task_obs": <task obs>}``.
        """
        self.assert_env_seed_is_set()
        env_obs = self.env.reset(**kwargs)
        return self._make_observation(env_obs=env_obs)

    def reset_task_state(self) -> None:
        """Sample a new task state and make it the current one."""
        self.set_task_state(task_state=self.sample_task_state())

    def step(self, action: ActionType) -> StepReturnType:
        """Step the wrapped environment and attach the task observation.

        Args:
            action (ActionType): Action forwarded to ``self.env.step``.

        Returns:
            StepReturnType: ``(obs, reward, done, info)``. ``obs`` is the
            multitask observation dict; ``reward``, ``done`` and ``info`` are
            passed through unchanged from the wrapped environment.
        """
        env_obs, reward, done, info = self.env.step(action)
        return (
            self._make_observation(env_obs=env_obs),
            reward,
            done,
            info,
        )

    def seed(self, seed: Optional[int] = None) -> List[int]:
        """Seed the wrapper's environment RNG and the wrapped environment.

        Args:
            seed (Optional[int]): Seed passed to ``seeding.np_random``. Its
                second return value (presumably the seed actually used, e.g.
                a generated one when ``seed`` is None) replaces ``seed``.
                That value is what gets forwarded to ``self.env.seed``.

        Returns:
            List[int]: The seed used here. If ``self.env.seed`` returns a
            list, its elements follow.
        """
        self.np_random_env, seed = seeding.np_random(seed)
        env_seeds = self.env.seed(seed)
        if isinstance(env_seeds, list):
            return [seed] + env_seeds
        return [seed]

    def render(self, mode: str = "human", **kwargs: Any) -> Any:
        """Renders the environment.

        Args:
            mode (str): Render mode, passed positionally to ``self.env.render``.
            **kwargs: Forwarded unchanged to ``self.env.render``.
        """
        return self.env.render(mode, **kwargs)

    def close(self) -> Any:
        """Close the wrapped environment and return whatever it returns."""
        return self.env.close()

    def __str__(self) -> str:
        return f"{type(self).__name__}{self.env}"

    def __repr__(self) -> str:
        return str(self)

    @property
    def unwrapped(self) -> Env:
        """The wrapped environment's ``unwrapped`` attribute."""
        return self.env.unwrapped

    def __getattr__(self, name: str) -> Any:
        """Delegate failed attribute lookups to the wrapped environment.

        Only called when normal attribute lookup on the wrapper fails.

        Raises:
            AttributeError: For names starting with ``_``. Also raised for
                ``env`` itself when it has not been assigned yet.
        """
        if name.startswith("_"):
            raise AttributeError(
                "attempted to get missing private attribute '{}'".format(name)
            )
        if name == "env":
            # `env` is only missing before __init__ assigns it (e.g. during
            # super().__init__, or on a half-constructed/unpickled instance).
            # Delegating via `self.env` here would re-enter
            # __getattr__("env") and recurse until RecursionError.
            raise AttributeError(
                "attempted to get missing attribute 'env'; the wrapped "
                "environment has not been set"
            )
        return getattr(self.env, name)