import warnings
import numpy as np
from mushroom_rl.core.serialization import Serializable
class MDPInfo(Serializable):
    """
    Container for the static information describing an environment:
    observation/action spaces, discount factor, horizon and control
    timestep.
    """
    def __init__(self, observation_space, action_space, gamma, horizon, dt=1e-1):
        """
        Constructor.

        Args:
            observation_space ([Box, Discrete]): the state space;
            action_space ([Box, Discrete]): the action space;
            gamma (float): the discount factor;
            horizon (int): the horizon;
            dt (float, 1e-1): the control timestep of the environment.

        """
        self.observation_space = observation_space
        self.action_space = action_space
        self.gamma = gamma
        self.horizon = horizon
        self.dt = dt

        # Declare how each attribute is (de)serialized by Serializable:
        # spaces use the mushroom mechanism, scalars are primitives.
        self._add_save_attr(
            observation_space='mushroom',
            action_space='mushroom',
            gamma='primitive',
            horizon='primitive',
            dt='primitive'
        )

    @property
    def size(self):
        """
        Returns:
            The sum of the number of discrete states and discrete
            actions. Only works for discrete spaces.

        """
        return self.observation_space.size + self.action_space.size

    @property
    def shape(self):
        """
        Returns:
            The concatenation of the shape tuple of the state and
            action spaces.

        """
        return self.observation_space.shape + self.action_space.shape
class Environment(object):
    """
    Basic interface used by any mushroom environment.
    """
    # Global registry mapping a class name to its environment class;
    # shared by every subclass through register()/make().
    _registered_envs = dict()

    @classmethod
    def register(cls):
        """
        Register an environment in the environment list.
        """
        # setdefault inserts only when the name is absent, so registering
        # the same class twice is a harmless no-op.
        Environment._registered_envs.setdefault(cls.__name__, cls)

    @staticmethod
    def list_registered():
        """
        List registered environments.

        Returns:
            The list of the registered environments.

        """
        # Iterating a dict yields its keys, i.e. the registered names.
        return list(Environment._registered_envs)

    @staticmethod
    def make(env_name, *args, **kwargs):
        """
        Generate an environment given an environment name and parameters.

        The environment is created using the ``generate`` method, if
        available; otherwise, the constructor is used. The ``generate``
        method has a simpler interface than the constructor, making it
        easier to build a standard version of the environment. If the
        environment name contains a '.' separator, the string is split:
        the first element selects the environment and the remaining
        elements are passed as positional parameters.

        Args:
            env_name (str): Name of the environment,
            *args: positional arguments to be provided to the environment
                generator;
            **kwargs: keyword arguments to be provided to the environment
                generator.

        Returns:
            An instance of the constructed environment.

        """
        if '.' in env_name:
            # 'Name.p1.p2' selects environment 'Name' and prepends
            # 'p1', 'p2' (as strings) to the positional arguments.
            selected, *extra = env_name.split('.')
            env_name = selected
            args = extra + list(args)

        env = Environment._registered_envs[env_name]

        if hasattr(env, 'generate'):
            return env.generate(*args, **kwargs)
        else:
            return env(*args, **kwargs)

    def __init__(self, mdp_info):
        """
        Constructor.

        Args:
            mdp_info (MDPInfo): an object containing the info of the
                environment.

        """
        self._mdp_info = mdp_info

    def seed(self, seed):
        """
        Set the seed of the environment.

        Args:
            seed (float): the value of the seed.

        """
        # Forward the seed to a wrapped environment only when one exists
        # and exposes its own seed() method.
        wrapped = getattr(self, 'env', None)
        if wrapped is not None and hasattr(wrapped, 'seed'):
            wrapped.seed(seed)
        else:
            warnings.warn('This environment has no custom seed. The call will have no effect. '
                          'You can set the seed manually by setting numpy/torch seed')

    def reset(self, state=None):
        """
        Reset the current state.

        Args:
            state (np.ndarray, None): the state to set to the current
                state.

        Returns:
            The current state.

        """
        raise NotImplementedError

    def step(self, action):
        """
        Move the agent from its current state according to the action.

        Args:
            action (np.ndarray): the action to execute.

        Returns:
            The state reached by the agent executing ``action`` in its
            current state, the reward obtained in the transition and a
            flag to signal if the next state is absorbing. Also, an
            additional dictionary is returned (possibly empty).

        """
        raise NotImplementedError

    def render(self, record=False):
        """
        Args:
            record (bool, False): whether the visualized image should be
                returned or not.

        Returns:
            The visualized image, or None if the record flag is set to
            false.

        """
        raise NotImplementedError

    def stop(self):
        """
        Method used to stop an mdp. Useful when dealing with real world
        environments, simulators, or when using openai-gym rendering.
        """
        pass

    @property
    def info(self):
        """
        Returns:
            An object containing the info of the environment.

        """
        return self._mdp_info

    @staticmethod
    def _bound(x, min_value, max_value):
        """
        Method used to bound state and action variables.

        Args:
            x: the variable to bound;
            min_value: the minimum value;
            max_value: the maximum value;

        Returns:
            The bounded variable.

        """
        # Clip from above first, then from below; when min_value exceeds
        # max_value the lower bound wins, matching the original
        # np.maximum(min, np.minimum(x, max)) ordering (np.clip would
        # resolve that degenerate case the other way).
        upper_clipped = np.minimum(x, max_value)
        return np.maximum(min_value, upper_clipped)