Source code for mushroom.environments.grid_world

import numpy as np
from scipy.stats import norm

from mushroom.environments import Environment, MDPInfo
from mushroom.utils import spaces
from mushroom.utils.viewer import Viewer


class AbstractGridWorld(Environment):
    """
    Abstract class to build a grid world.

    Positions are handled in two encodings: grid coordinates
    ``(row, column)`` (used for ``start``, ``goal`` and inside ``_step``)
    and the flattened index ``row * width + column`` stored as a
    one-element numpy array (the state returned by ``reset`` and ``step``).

    """
    def __init__(self, mdp_info, height, width, start, goal):
        """
        Constructor.

        Args:
            mdp_info (MDPInfo): information about the MDP, forwarded to
                ``Environment``;
            height (int): height of the grid (number of rows);
            width (int): width of the grid (number of columns);
            start (tuple): row-column coordinates of the start cell;
            goal (tuple): row-column coordinates of the goal cell.

        """
        assert not np.array_equal(start, goal)

        assert 0 <= goal[0] < height and 0 <= goal[1] < width,\
            'Goal position not suitable for the grid world dimension.'
        # The start cell is validated like the goal: an out-of-grid start
        # would be flattened by ``reset`` into a wrong or out-of-range state.
        assert 0 <= start[0] < height and 0 <= start[1] < width,\
            'Start position not suitable for the grid world dimension.'

        self._state = None
        self._height = height
        self._width = width
        self._start = start
        self._goal = goal

        # Visualization
        # NOTE(review): arguments presumably are the environment size followed
        # by the window size in pixels (500 wide, height scaled to keep the
        # grid's aspect ratio) -- confirm against ``Viewer``.
        self._viewer = Viewer(self._width, self._height, 500,
                              self._height * 500 // self._width)

        super().__init__(mdp_info)

    def reset(self, state=None):
        """
        Reset the environment.

        Args:
            state (np.ndarray, None): flattened state to start from; when
                None, the start cell is used.

        Returns:
            The current (flattened) state.

        """
        if state is None:
            state = self.convert_to_int(self._start, self._width)

        self._state = state

        return self._state

    def step(self, action):
        """
        Apply ``action``: decode the current flattened state into grid
        coordinates, delegate the transition to ``_step`` and re-encode the
        resulting position.

        Returns:
            The tuple ``(state, reward, absorbing, info)`` where ``state`` is
            the new flattened state.

        """
        state = self.convert_to_grid(self._state, self._width)

        new_state, reward, absorbing, info = self._step(state, action)
        self._state = self.convert_to_int(new_state, self._width)

        return self._state, reward, absorbing, info

    def render(self):
        """
        Draw the grid, the goal (green square), the start (red square) and
        the current position (blue circle).

        """
        # Vertical and horizontal separators are drawn in two independent
        # loops: a nested loop would skip all separators of a 1-wide or
        # 1-tall grid and redraw every line many times.
        for col in range(1, self._width):
            self._viewer.line(np.array([col, 0]),
                              np.array([col, self._height]))
        for row in range(1, self._height):
            self._viewer.line(np.array([0, row]),
                              np.array([self._width, row]))

        self._viewer.square(self._cell_center(self._goal), 0, 1, (0, 255, 0))

        # ``self._start`` is already in grid coordinates; it must not be
        # passed through ``convert_to_grid`` (which expects a flattened
        # index), otherwise the start is drawn in the wrong cell.
        self._viewer.square(self._cell_center(self._start), 0, 1,
                            (255, 0, 0))

        state_grid = self.convert_to_grid(self._state, self._width)
        self._viewer.circle(self._cell_center(state_grid), .4, (0, 0, 255))

        self._viewer.display(.1)

    def _cell_center(self, grid):
        # Viewer coordinates of the center of cell (row, column); the row
        # axis is flipped so that row 0 maps to y = height - .5.
        return np.array([.5 + grid[1], self._height - (.5 + grid[0])])

    def _step(self, state, action):
        """
        Compute the transition from ``state`` (grid coordinates) under
        ``action``. Must be implemented by subclasses and return
        ``(new_state, reward, absorbing, info)`` with ``new_state`` in grid
        coordinates.

        """
        raise NotImplementedError('AbstractGridWorld is an abstract class.')

    @staticmethod
    def convert_to_grid(state, width):
        """
        Decode a flattened state (one-element array) into ``[row, column]``.

        """
        return np.array([state[0] // width, state[0] % width])

    @staticmethod
    def convert_to_int(state, width):
        """
        Encode grid coordinates ``[row, column]`` into a one-element array
        holding the flattened index.

        """
        return np.array([state[0] * width + state[1]])
class GridWorld(AbstractGridWorld):
    """
    Standard grid world.

    Actions: 0 decrements the row, 1 increments the row, 2 decrements the
    column, 3 increments the column; a move that would leave the grid keeps
    the agent in place. Entering the goal cell yields reward 10 and is
    absorbing; every other step yields reward 0.

    """
    def __init__(self, height, width, goal, start=(0, 0)):
        """
        Constructor.

        Args:
            height (int): number of rows of the grid;
            width (int): number of columns of the grid;
            goal (tuple): row-column coordinates of the goal cell;
            start (tuple, (0, 0)): row-column coordinates of the start cell.

        """
        # MDP properties: one discrete state per cell, four moves,
        # discount .9 and episodes capped at 100 steps.
        n_cells = height * width
        states = spaces.Discrete(n_cells)
        moves = spaces.Discrete(4)
        discount = .9
        max_steps = 100

        info = MDPInfo(states, moves, discount, max_steps)

        super().__init__(info, height, width, start, goal)

    def _step(self, state, action):
        row, col = state[0], state[1]

        # Each branch only fires when the move stays inside the grid.
        if action == 0 and row > 0:
            row -= 1
        elif action == 1 and row + 1 < self._height:
            row += 1
        elif action == 2 and col > 0:
            col -= 1
        elif action == 3 and col + 1 < self._width:
            col += 1

        state[0], state[1] = row, col

        at_goal = np.array_equal(state, self._goal)
        reward, absorbing = (10, True) if at_goal else (0, False)

        return state, reward, absorbing, {}
class GridWorldVanHasselt(AbstractGridWorld):
    """
    A variant of the grid world as presented in:
    "Double Q-Learning". Hasselt H. V.. 2010.

    Any action taken in the goal cell yields reward 5 and is absorbing.
    Every other step moves the agent (0: row - 1, 1: row + 1,
    2: column - 1, 3: column + 1, staying put at the border) and yields a
    reward drawn uniformly from {-12, 10}.

    """
    def __init__(self, height=3, width=3, goal=(0, 2), start=(2, 0)):
        """
        Constructor.

        Args:
            height (int, 3): number of rows of the grid;
            width (int, 3): number of columns of the grid;
            goal (tuple, (0, 2)): row-column coordinates of the goal cell;
            start (tuple, (2, 0)): row-column coordinates of the start cell.

        """
        # MDP properties: one discrete state per cell, four moves,
        # discount .95 and no step limit.
        n_cells = height * width
        states = spaces.Discrete(n_cells)
        moves = spaces.Discrete(4)
        discount = .95
        max_steps = np.inf

        info = MDPInfo(states, moves, discount, max_steps)

        super().__init__(info, height, width, start, goal)

    def _step(self, state, action):
        # Acting from the goal ends the episode without moving.
        if np.array_equal(state, self._goal):
            return state, 5, True, {}

        row, col = state[0], state[1]

        # Each branch only fires when the move stays inside the grid.
        if action == 0 and row > 0:
            row -= 1
        elif action == 1 and row + 1 < self._height:
            row += 1
        elif action == 2 and col > 0:
            col -= 1
        elif action == 3 and col + 1 < self._width:
            col += 1

        state[0], state[1] = row, col

        return state, np.random.choice([-12, 10]), False, {}