# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
from typing import Any, Dict, Optional
from gym import error
from gym.core import Env
from gym.envs.registration import EnvRegistry, EnvSpec
class MultitaskEnvSpec(EnvSpec):  # type: ignore[misc]
    """Environment spec extended with kwargs for automated testing.

    Adds a ``test_kwargs`` attribute on top of gym's ``EnvSpec`` and exposes
    the stored constructor kwargs through a read-only ``kwargs`` property.
    """

    def __init__(
        self,
        id: str,
        entry_point: Optional[str] = None,
        reward_threshold: Optional[float] = None,
        kwargs: Optional[Dict[str, Any]] = None,
        nondeterministic: bool = False,
        max_episode_steps: Optional[int] = None,
        test_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """A specification for a particular instance of the environment.

        Used to register the parameters for official evaluations.

        Args:
            id (str): The official environment ID
            entry_point (Optional[str]): The Python entrypoint of the
                environment class (e.g. module.name:Class)
            reward_threshold (Optional[float]): The reward threshold before
                the task is considered solved. Annotated as ``float``;
                ``int`` values are still accepted.
            kwargs (Optional[Dict[str, Any]]): The kwargs to pass to the
                environment class
            nondeterministic (bool): Whether this environment is
                non-deterministic even after seeding
            max_episode_steps (Optional[int]): The maximum number of steps
                that an episode can consist of
            test_kwargs (Optional[Dict[str, Any]], optional): Dictionary
                to specify parameters for automated testing. Defaults to
                None.
        """
        super().__init__(
            id=id,
            entry_point=entry_point,
            reward_threshold=reward_threshold,
            nondeterministic=nondeterministic,
            max_episode_steps=max_episode_steps,
            kwargs=kwargs,
        )
        # Kept on this subclass only; the parent constructor never sees it.
        self.test_kwargs = test_kwargs

    def __repr__(self) -> str:
        return f"MultitaskEnvSpec({self.id})"

    @property
    def kwargs(self) -> Dict[str, Any]:
        """Read-only view of the environment-constructor kwargs."""
        # NOTE(review): relies on the parent ``EnvSpec.__init__`` storing the
        # kwargs in ``self._kwargs`` — confirm against the pinned gym version.
        return self._kwargs  # type: ignore[no-any-return]
class MultiEnvRegistry(EnvRegistry):  # type: ignore[misc]
    """Registry whose ``register`` stores ``MultitaskEnvSpec`` entries.

    Construction is inherited unchanged from ``EnvRegistry``.
    """

    def register(self, id: str, **kwargs: Any) -> None:
        """Create and store a ``MultitaskEnvSpec`` under ``id``.

        Args:
            id: Environment ID to register.
            **kwargs: Forwarded to the ``MultitaskEnvSpec`` constructor.

        Raises:
            error.Error: If ``id`` is already present in ``env_specs``.
        """
        specs = self.env_specs
        if id not in specs:
            specs[id] = MultitaskEnvSpec(id, **kwargs)
            return
        raise error.Error(f"Cannot re-register id: {id}")
# Single module-level registry used by the module's register/make/spec helpers.
mtenv_registry: MultiEnvRegistry = MultiEnvRegistry()
def register(id: str, **kwargs: Any) -> None:
    """Register a multitask environment in the global ``mtenv_registry``.

    Args:
        id: Environment ID to register.
        **kwargs: Forwarded unchanged to ``MultiEnvRegistry.register``.

    Raises:
        error.Error: If ``id`` has already been registered.
    """
    mtenv_registry.register(id, **kwargs)
def make(id: str, **kwargs: Any) -> Env:
    """Build a registered environment via the global ``mtenv_registry``.

    Args:
        id: ID the environment was registered under.
        **kwargs: Forwarded unchanged to ``mtenv_registry.make``.

    Returns:
        The environment produced by the registry.
    """
    created = mtenv_registry.make(id, **kwargs)
    # Narrows the registry's untyped result for static type checkers.
    assert isinstance(created, Env)
    return created
def spec(id: str) -> MultitaskEnvSpec:
    """Look up the spec registered under ``id`` in ``mtenv_registry``.

    Args:
        id: ID the environment was registered under.

    Returns:
        The ``MultitaskEnvSpec`` stored for ``id``.
    """
    found = mtenv_registry.spec(id)
    # Narrows the registry's untyped result for static type checkers.
    assert isinstance(found, MultitaskEnvSpec)
    return found