import time
import gym
from gym import spaces as gym_spaces
import numpy as np
try:
import pybullet_envs
pybullet_found = True
except ImportError:
pybullet_found = False
from mushroom_rl.core import Environment, MDPInfo
from mushroom_rl.utils.spaces import *
gym.logger.set_level(40)
[docs]class Gym(Environment):
"""
Interface for OpenAI Gym environments. It makes it possible to use every
Gym environment just providing the id, except for the Atari games that
are managed in a separate class.
"""
[docs] def __init__(self, name, horizon=None, gamma=0.99, wrappers=None, wrappers_args=None,
**env_args):
"""
Constructor.
Args:
name (str): gym id of the environment;
horizon (int): the horizon. If None, use the one from Gym;
gamma (float, 0.99): the discount factor;
wrappers (list, None): list of wrappers to apply over the environment. It
is possible to pass arguments to the wrappers by providing
a tuple with two elements: the gym wrapper class and a
dictionary containing the parameters needed by the wrapper
constructor;
wrappers_args (list, None): list of list of arguments for each wrapper;
** env_args: other gym environment parameters.
"""
# MDP creation
self._not_pybullet = True
self._first = True
if pybullet_found and '- ' + name in pybullet_envs.getList():
import pybullet
pybullet.connect(pybullet.DIRECT)
self._not_pybullet = False
self.env = gym.make(name, **env_args)
if wrappers is not None:
if wrappers_args is None:
wrappers_args = [dict()] * len(wrappers)
for wrapper, args in zip(wrappers, wrappers_args):
if isinstance(wrapper, tuple):
self.env = wrapper[0](self.env, *args, **wrapper[1])
else:
self.env = wrapper(self.env, *args, **env_args)
horizon = self._set_horizon(self.env, horizon)
# MDP properties
assert not isinstance(self.env.observation_space,
gym_spaces.MultiDiscrete)
assert not isinstance(self.env.action_space, gym_spaces.MultiDiscrete)
dt = self.env.unwrapped.dt if hasattr(self.env.unwrapped, "dt") else 0.1
action_space = self._convert_gym_space(self.env.action_space)
observation_space = self._convert_gym_space(self.env.observation_space)
mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt)
if isinstance(action_space, Discrete):
self._convert_action = lambda a: a[0]
else:
self._convert_action = lambda a: a
super().__init__(mdp_info)
[docs] def reset(self, state=None):
if state is None:
return np.atleast_1d(self.env.reset())
else:
self.env.reset()
self.env.state = state
return np.atleast_1d(state)
[docs] def step(self, action):
action = self._convert_action(action)
obs, reward, absorbing, info = self.env.step(action)
return np.atleast_1d(obs), reward, absorbing, info
[docs] def render(self, record=False):
if self._first or self._not_pybullet:
self.env.render(mode='human')
self._first = False
time.sleep(self.info.dt)
if record:
return self.env.render(mode='rgb_array')
else:
return None
return None
[docs] def stop(self):
try:
if self._not_pybullet:
self.env.close()
except:
pass
@staticmethod
def _set_horizon(env, horizon):
while not hasattr(env, '_max_episode_steps') and env.env != env.unwrapped:
env = env.env
if horizon is None:
if not hasattr(env, '_max_episode_steps'):
raise RuntimeError('This gym environment has no specified time limit!')
horizon = env._max_episode_steps
if hasattr(env, '_max_episode_steps'):
env._max_episode_steps = np.inf # Hack to ignore gym time limit.
return horizon
@staticmethod
def _convert_gym_space(space):
if isinstance(space, gym_spaces.Discrete):
return Discrete(space.n)
elif isinstance(space, gym_spaces.Box):
return Box(low=space.low, high=space.high, shape=space.shape)
else:
raise ValueError