# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import math

import numpy as np
from gym import logger, spaces

from mtenv import MTEnv
from mtenv.utils import seeding

# Classic cart-pole system, implemented based on the work of Rich Sutton et al.
class MTCartPole(MTEnv):
    """A cartpole environment with varying physical values.

    The task state is an indexable ``mu`` of five numbers that perturbs the
    nominal physical constants of the system (see ``_mu_to_vars``).
    Observations are dictionaries with an ``"env_obs"`` entry (the cart-pole
    state ``x, x_dot, theta, theta_dot``) and a ``"task_obs"`` entry (the
    current task state).
    """

    metadata = {
        "render.modes": ["human", "rgb_array"],
        "video.frames_per_second": 50,
    }

    def _mu_to_vars(self, mu):
        """Set the physical constants of the system from the task state.

        Args:
            mu: indexable with at least five numeric entries. ``mu[0]``
                offsets gravity, ``mu[1]`` the cart mass, ``mu[2]`` the pole
                mass, ``mu[3]`` the pole half-length and ``mu[4]`` scales
                the force magnitude.
        """
        self.gravity = 9.8 + mu[0] * 5
        self.masscart = 1.0 + mu[1] * 0.5
        self.masspole = 0.1 + mu[2] * 0.09
        self.total_mass = self.masspole + self.masscart
        self.length = 0.5 + mu[3] * 0.3
        self.polemass_length = self.masspole * self.length
        self.force_mag = 10 * mu[4]
        # A multiplier of exactly zero falls back to the nominal force
        # (presumably so that the two actions do not both apply zero force).
        if mu[4] == 0:
            self.force_mag = 10

    def __init__(self):
        """Build the observation, action and task spaces and set the nominal physics."""
        self.x_threshold = 2.4
        # Angle at which to fail the episode (12 degrees, expressed in radians).
        self.theta_threshold_radians = 12 * 2 * math.pi / 360

        # Angle limit set to 2 * theta_threshold_radians so failing
        # observation is still within bounds
        obs_high = np.array(
            [
                self.x_threshold * 2,
                np.finfo(np.float32).max,
                self.theta_threshold_radians * 2,
                np.finfo(np.float32).max,
            ]
        )
        observation_space = spaces.Box(-obs_high, obs_high, dtype=np.float32)
        action_space = spaces.Discrete(2)

        # Each of the five task parameters lies in [-1, 1].
        task_high = np.ones(5)
        task_space = spaces.Box(-task_high, task_high, dtype=np.float32)

        super().__init__(
            action_space=action_space,
            env_observation_space=observation_space,
            task_observation_space=task_space,
        )

        # Nominal physical constants; overwritten by ``_mu_to_vars`` on every
        # ``reset`` and ``step``.
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = self.masspole + self.masscart
        self.length = 0.5  # actually half the pole's length
        self.polemass_length = self.masspole * self.length
        self.force_mag = 10.0
        self.tau = 0.02  # seconds between state updates
        self.kinematics_integrator = "euler"

        self.state = None
        self.steps_beyond_done = None
        self.task_state = None
        # Step counter; initialised here so the attribute exists before the
        # first ``reset``.
        self.t = 0

    def step(self, action):
        """Advance the simulation by one time step of ``self.tau`` seconds.

        Args:
            action: an element of ``self.action_space``; ``1`` applies
                ``+force_mag`` to the cart, any other valid action applies
                ``-force_mag``.

        Returns:
            A tuple ``(obs, reward, done, info)`` where ``obs`` is a dict
            with keys ``"env_obs"`` and ``"task_obs"``, ``reward`` is ``1.0``
            up to and including the step on which the episode terminates and
            ``0.0`` afterwards, ``done`` is a ``bool`` and ``info`` is an
            empty dict.

        Raises:
            AssertionError: if ``action`` is not contained in the action space.
        """
        # Validate first so that a rejected action leaves the environment
        # (step counter and physical parameters) untouched.
        assert self.action_space.contains(action), "%r (%s) invalid" % (
            action,
            type(action),
        )
        self.t += 1
        self._mu_to_vars(self.task_state)

        x, x_dot, theta, theta_dot = self.state
        force = self.force_mag if action == 1 else -self.force_mag
        costheta = math.cos(theta)
        sintheta = math.sin(theta)

        temp = (
            force + self.polemass_length * theta_dot * theta_dot * sintheta
        ) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (
            self.length
            * (4.0 / 3.0 - self.masspole * costheta * costheta / self.total_mass)
        )
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        if self.kinematics_integrator == "euler":
            # Explicit Euler: positions advance with the *old* velocities.
            x = x + self.tau * x_dot
            x_dot = x_dot + self.tau * xacc
            theta = theta + self.tau * theta_dot
            theta_dot = theta_dot + self.tau * thetaacc
        else:  # semi-implicit euler
            # Velocities are updated first and then used for the positions.
            x_dot = x_dot + self.tau * xacc
            x = x + self.tau * x_dot
            theta_dot = theta_dot + self.tau * thetaacc
            theta = theta + self.tau * theta_dot

        self.state = [x, x_dot, theta, theta_dot]

        done = bool(
            x < -self.x_threshold
            or x > self.x_threshold
            or theta < -self.theta_threshold_radians
            or theta > self.theta_threshold_radians
        )

        if not done:
            reward = 1.0
        elif self.steps_beyond_done is None:
            # Pole just fell!
            self.steps_beyond_done = 0
            reward = 1.0
        else:
            # Warn only on the first step taken after termination.
            if self.steps_beyond_done == 0:
                logger.warn(
                    "You are calling 'step()' even though this environment "
                    "has already returned done = True. You should always "
                    "call 'reset()' once you receive 'done = True' -- any "
                    "further steps are undefined behavior."
                )
            self.steps_beyond_done += 1
            reward = 0.0

        # NOTE(review): ``env_obs`` is a plain list here, whereas ``reset``
        # returns whatever ``np_random_env.uniform(..., size=(4,))`` yields
        # (presumably a NumPy array) -- confirm whether callers rely on this.
        return (
            {"env_obs": self.state, "task_obs": self.get_task_obs()},
            reward,
            done,
            {},
        )

    def reset(self, **args):
        """Start a new episode and return the initial observation.

        Requires the environment seed and the task state to be set. Each of
        the four state variables is drawn uniformly from ``[-0.05, 0.05)``.
        Extra keyword arguments are accepted and ignored.

        Returns:
            A dict with keys ``"env_obs"`` and ``"task_obs"``.

        Raises:
            AssertionError: if no task state has been set.
        """
        self.assert_env_seed_is_set()
        assert self.task_state is not None
        self._mu_to_vars(self.task_state)
        self.state = self.np_random_env.uniform(low=-0.05, high=0.05, size=(4,))
        self.steps_beyond_done = None
        self.t = 0
        return {"env_obs": self.state, "task_obs": self.get_task_obs()}

    def get_task_obs(self):
        """Return the task observation, which is the task state itself."""
        return self.task_state

    def get_task_state(self):
        """Return the current task state (``None`` until one is set)."""
        return self.task_state

    def set_task_state(self, task_state):
        """Store ``task_state``; it is applied to the physics on the next ``reset`` or ``step``."""
        self.task_state = task_state

    def sample_task_state(self):
        """Draw a new task state without installing it.

        Requires the task seed to be set.

        Returns:
            A list of five floats, each drawn from ``np_random_task.uniform(-1, 1)``.
        """
        self.assert_task_seed_is_set()
        # Kept from the original implementation; the base-class return value
        # is not used.
        super().sample_task_state()
        return [self.np_random_task.uniform(-1, 1) for _ in range(5)]

    def seed(self, env_seed):
        """Seed the environment random generator and return ``[seed]``."""
        self.np_random_env, seed = seeding.np_random(env_seed)
        return [seed]

    def seed_task(self, task_seed):
        """Seed the task random generator and return ``[seed]``."""
        self.np_random_task, seed = seeding.np_random(task_seed)
        return [seed]
class CartPole(MTCartPole):
    """The original cartpole environment in the MTEnv fashion.

    The only task state this class ever samples is the all-zero vector.
    """

    def __init__(self):
        super().__init__()

    def sample_task_state(self):
        """Return the fixed task state: a list of five ``0.0`` entries."""
        return [0.0] * 5
if __name__ == "__main__":
    # Manual smoke run: play one episode with random actions and print
    # every observation along the way.
    cartpole = MTCartPole()
    cartpole.seed(5)
    cartpole.seed_task(15)
    cartpole.reset_task_state()

    observation = cartpole.reset()
    print(observation)

    episode_over = False
    while not episode_over:
        random_action = np.random.randint(cartpole.action_space.n)
        observation, _reward, episode_over, _info = cartpole.step(random_action)
        print(observation)