Source code for mtenv.envs.hipbmdp.wrappers.sticky_observation

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""Wrapper to enable sticky observations for single task environments."""
# type: ignore
import random
from collections import deque

import gym


class StickyObservation(gym.Wrapper):
    """Env wrapper that returns a previous observation with probability `p`
    and the current observation with probability `1 - p`.

    The last `last_k` previous observations are stored in a bounded deque
    (plus the current one, so the buffer actually holds `last_k + 1` items).
    """

    def __init__(self, env: gym.Env, sticky_probability: float, last_k: int):
        """Initialize the wrapper.

        Args:
            env (gym.Env): Single task environment.
            sticky_probability (float): Probability `p` for returning a
                previous observation.
            last_k (int): Number of previous observations to store.

        Raises:
            ValueError: Raise a ValueError if `sticky_probability` is not
                in range `[0, 1]`.
        """
        # BUG FIX: the original called `super().__init__(self, env)`, which
        # passes `self` twice (i.e. gym.Wrapper.__init__(self, self, env))
        # and raises TypeError on construction.
        super().__init__(env)
        if 0 <= sticky_probability <= 1:
            self._sticky_probability = sticky_probability
        else:
            raise ValueError(
                f"sticky_probability = {sticky_probability} is not in the interval [0, 1]."
            )
        # Buffer holds `last_k` old observations plus the current one.
        self._last_k = last_k + 1
        self._observations: deque = deque([], maxlen=self._last_k)
        self.observation_space = env.observation_space
        # NOTE(review): assumes the wrapped env exposes `_max_episode_steps`
        # (e.g. a TimeLimit-wrapped env) — confirm against callers.
        self._max_episode_steps = env._max_episode_steps

    def reset(self):
        """Reset the environment and pre-fill the observation buffer.

        The buffer is filled with copies of the initial observation so that
        `_get_obs` always has a full window to sample from.
        """
        obs = self.env.reset()
        for _ in range(self._last_k):
            self._observations.append(obs)
        return self._get_obs()

    def step(self, action):
        """Step the wrapped env, record the new observation, and return a
        (possibly sticky) observation along with reward/done/info."""
        obs, reward, done, info = self.env.step(action)
        self._observations.append(obs)
        return self._get_obs(), reward, done, info

    def _get_obs(self):
        """Return an old observation with probability `p`, else the current one.

        With probability `p` a uniformly random element among the `last_k`
        *previous* observations (indices 0 .. last_k-2) is returned; otherwise
        the most recent observation (index -1) is returned.
        """
        assert len(self._observations) == self._last_k
        should_choose_old_observation = random.random() < self._sticky_probability
        if should_choose_old_observation:
            # randint's upper bound is inclusive; -2 excludes the current obs.
            index = random.randint(0, self._last_k - 2)
            return self._observations[index]
        else:
            return self._observations[-1]