Source code for mtenv.envs.hipbmdp.dmc_env

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Any, Dict, Optional

import gym
from gym.core import Env
from gym.envs.registration import register

from mtenv.envs.hipbmdp.wrappers import framestack, sticky_observation


def _build_env(
    domain_name: str,
    task_name: str,
    seed: int = 1,
    xml_file_id: str = "none",
    visualize_reward: bool = True,
    from_pixels: bool = False,
    height: int = 84,
    width: int = 84,
    camera_id: int = 0,
    frame_skip: int = 1,
    environment_kwargs: Any = None,
    episode_length: int = 1000,
) -> Env:
    if xml_file_id is None:
        env_id = "dmc_%s_%s_%s-v1" % (domain_name, task_name, seed)
    else:
        env_id = "dmc_%s_%s_%s_%s-v1" % (domain_name, task_name, xml_file_id, seed)

    if from_pixels:
        assert (
            not visualize_reward
        ), "cannot use visualize reward when learning from pixels"

    # Each env step consumes frame_skip frames, so cap the episode at
    # ceil(episode_length / frame_skip) env steps.
    max_episode_steps = (episode_length + frame_skip - 1) // frame_skip

    if env_id not in gym.envs.registry.env_specs:
        register(
            id=env_id,
            entry_point="mtenv.envs.hipbmdp.wrappers.dmc_wrapper:DMCWrapper",
            kwargs={
                "domain_name": domain_name,
                "task_name": task_name,
                "task_kwargs": {"random": seed, "xml_file_id": xml_file_id},
                "environment_kwargs": environment_kwargs,
                "visualize_reward": visualize_reward,
                "from_pixels": from_pixels,
                "height": height,
                "width": width,
                "camera_id": camera_id,
                "frame_skip": frame_skip,
            },
            max_episode_steps=max_episode_steps,
        )
    return gym.make(env_id)
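
# Illustration of the registration above (assumed values, not defaults from
# this module): _build_env("cheetah", "run", seed=1, xml_file_id="xml_1",
# frame_skip=4) registers and returns "dmc_cheetah_run_xml_1_1-v1", with
# max_episode_steps = (1000 + 4 - 1) // 4 = 250.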


def build_dmc_env(
    domain_name: str,
    task_name: str,
    seed: int,
    xml_file_id: str,
    visualize_reward: bool,
    from_pixels: bool,
    height: int,
    width: int,
    frame_skip: int,
    frame_stack: int,
    sticky_observation_cfg: Dict[str, Any],
) -> Env:
    """Build a single DMC environment as described in
    :cite:`tassa2020dmcontrol`.

    Args:
        domain_name (str): name of the domain.
        task_name (str): name of the task.
        seed (int): environment seed (for reproducibility).
        xml_file_id (str): id of the xml file to use.
        visualize_reward (bool): should the reward be visualized?
        from_pixels (bool): should pixel observations be returned?
        height (int): height of pixel frames.
        width (int): width of pixel frames.
        frame_skip (int): number of frames to skip per action.
        frame_stack (int): number of frames to stack together.
        sticky_observation_cfg (Dict[str, Any]): configuration for using
            sticky observations. It should be a dictionary with three keys:
            `should_use`, which specifies if the config should be used;
            `sticky_probability`, which specifies the probability of
            returning a previous observation; and `last_k`, which specifies
            the number of previous frames to choose from.

    Returns:
        Env: the wrapped DMC environment.
    """
    env = _build_env(
        domain_name=domain_name,
        task_name=task_name,
        seed=seed,
        visualize_reward=visualize_reward,
        from_pixels=from_pixels,
        height=height,
        width=width,
        frame_skip=frame_skip,
        xml_file_id=xml_file_id,
    )
    if from_pixels:
        env = framestack.FrameStack(env, k=frame_stack)
    if sticky_observation_cfg and sticky_observation_cfg["should_use"]:
        env = sticky_observation.StickyObservation(  # type: ignore[attr-defined]
            env=env,
            sticky_probability=sticky_observation_cfg["sticky_probability"],
            last_k=sticky_observation_cfg["last_k"],
        )
    return env
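

# A minimal usage sketch (assumed values, not defaults from this module;
# "cheetah"/"run" are standard dm_control names, and the "xml_1" id format
# is an assumption about the HiP-BMDP xml naming, not taken from this file):
#
#     env = build_dmc_env(
#         domain_name="cheetah",
#         task_name="run",
#         seed=1,
#         xml_file_id="xml_1",
#         visualize_reward=False,
#         from_pixels=True,
#         height=84,
#         width=84,
#         frame_skip=4,
#         frame_stack=3,
#         sticky_observation_cfg={
#             "should_use": True,
#             "sticky_probability": 0.5,
#             "last_k": 3,
#         },
#     )
#     obs = env.reset()  # pixel observation with the last 3 frames stacked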