# Source code for mushroom.environments.generators.grid_world

```
import numpy as np
from mushroom.environments.finite_mdp import FiniteMDP
def generate_grid_world(grid, prob, pos_rew, neg_rew, gamma=.9, horizon=100):
    """
    Build a grid-world FiniteMDP from a plain-text description of the grid.

    Each line of the file is a row of the grid and each character is a cell.
    Five cell markers are recognised:

    - 'S': a starting position for the agent;
    - 'G': a goal state, entering it yields ``pos_rew``;
    - '.': a normal cell;
    - '*': a hole, entering it yields ``neg_rew`` and the episode ends;
    - '#': a wall, trying to move onto it leaves the agent where it is.

    From a normal or start cell, every action moves the agent to the
    adjacent cell with probability ``prob`` and leaves it in place otherwise.
    The initial state is drawn uniformly among all the 'S' cells. The grid is
    expected to be rectangular.

    Args:
        grid (str): the path of the file containing the grid structure;
        prob (float): probability of success of an action;
        pos_rew (float): reward obtained in goal states;
        neg_rew (float): reward obtained in "hole" states;
        gamma (float, .9): discount factor;
        horizon (int, 100): the horizon.

    Returns:
        A FiniteMDP object built with the provided parameters.

    """
    grid_map, cell_list = parse_grid(grid)

    transition_probabilities = compute_probabilities(grid_map, cell_list, prob)
    rewards = compute_reward(grid_map, cell_list, pos_rew, neg_rew)
    initial_distribution = compute_mu(grid_map, cell_list)

    return FiniteMDP(transition_probabilities, rewards, initial_distribution,
                     gamma, horizon)
def parse_grid(grid):
    """
    Parse the grid file.

    Each line of the file is one row of the grid and every character must be
    one of '#', '.', 'S', 'G' or '*'. The last row is kept even when the file
    does not end with a newline, and trailing blank lines are ignored.

    Args:
        grid (str): the path of the file containing the grid structure.

    Returns:
        A tuple ``(grid_map, cell_list)``: ``grid_map`` is a list of rows,
        each row being a list of single-character cell markers; ``cell_list``
        is the list of ``[row, column]`` coordinates of every non-wall cell,
        in row-major order.

    Raises:
        ValueError: if the grid has no 'S' or no 'G' cell, contains an
            unknown marker, or is not rectangular.

    """
    with open(grid, 'r') as f:
        m = f.read()

    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if 'S' not in m or 'G' not in m:
        raise ValueError('The grid must contain at least one start (S) '
                         'and one goal (G) cell.')

    # Text mode already normalises line endings to '\n'.
    lines = m.split('\n')
    # Drop the empty string(s) produced by a final newline / trailing blank
    # lines, so no empty row ends up in ``grid_map``.
    while lines and not lines[-1]:
        lines.pop()

    grid_map = list()
    cell_list = list()
    for row_idx, line in enumerate(lines):
        row = list()
        for col_idx, c in enumerate(line):
            if c not in ['#', '.', 'S', 'G', '*']:
                raise ValueError('Unknown marker.')
            row.append(c)
            if c != '#':
                cell_list.append([row_idx, col_idx])
        grid_map.append(row)

    # Downstream code builds a 2-D array from ``grid_map``; a ragged grid
    # (including a blank line in the middle) cannot be represented.
    if any(len(row) != len(grid_map[0]) for row in grid_map):
        raise ValueError('The grid must be rectangular.')

    return grid_map, cell_list
def compute_probabilities(grid_map, cell_list, prob):
    """
    Compute the transition probability matrix.

    Only normal ('.') and start ('S') cells get outgoing transitions. For each
    of the four actions (row - 1, row + 1, column - 1, column + 1), the agent
    moves to the adjacent cell with probability ``prob`` and stays in place
    with probability ``1 - prob``. If the adjacent cell is a wall or lies
    outside the grid, the agent stays in place with probability 1. The rows
    of goal ('G') and hole ('*') cells are left all-zero, i.e. they have no
    outgoing transitions.

    Args:
        grid_map (list): list containing the grid structure;
        cell_list (list): list of non-wall cells;
        prob (float): probability of success of an action.

    Returns:
        The transition probability matrix, an array of shape
        ``(n_states, 4, n_states)`` indexed as ``[state, action, next_state]``.

    """
    n_states = len(cell_list)
    p = np.zeros((n_states, 4, n_states))
    # Coordinate -> state index, so each neighbour is found with a dict
    # lookup instead of scanning every cell (O(n) overall instead of O(n^2)).
    index_of = {tuple(cell): idx for idx, cell in enumerate(cell_list)}
    directions = [[-1, 0], [1, 0], [0, -1], [0, 1]]

    for i, (row, col) in enumerate(cell_list):
        if grid_map[row][col] not in ['.', 'S']:
            continue
        for a, (d_row, d_col) in enumerate(directions):
            j = index_of.get((row + d_row, col + d_col))
            if j is None:
                # Wall or outside the grid: the move is blocked.
                p[i, a, i] = 1.
            else:
                p[i, a, i] = 1. - prob
                p[i, a, j] = prob
    return p
def compute_reward(grid_map, cell_list, pos_rew, neg_rew):
    """
    Compute the reward matrix.

    ``r[i, a, j]`` is ``pos_rew`` when action ``a`` leads from the non-wall
    cell ``i`` into the adjacent goal ('G') cell ``j``, ``neg_rew`` when it
    leads into the adjacent hole ('*') cell ``j``, and 0 otherwise.

    Args:
        grid_map (list): list containing the grid structure;
        cell_list (list): list of non-wall cells;
        pos_rew (float): reward obtained in goal states;
        neg_rew (float): reward obtained in "hole" states.

    Returns:
        The reward matrix, an array of shape ``(n_states, 4, n_states)``
        indexed as ``[state, action, next_state]``.

    """
    n_states = len(cell_list)
    r = np.zeros((n_states, 4, n_states))
    # Coordinate -> state index. A real membership test: ``array in array``
    # in NumPy means ``(a == b).any()``, which matches on a single coordinate.
    index_of = {tuple(cell): idx for idx, cell in enumerate(cell_list)}
    directions = [[-1, 0], [1, 0], [0, -1], [0, 1]]
    rewarded_markers = {'G': pos_rew, '*': neg_rew}

    for j, (row, col) in enumerate(cell_list):
        marker = grid_map[row][col]
        if marker not in rewarded_markers:
            continue
        reward = rewarded_markers[marker]
        for a, (d_row, d_col) in enumerate(directions):
            # The cell from which action ``a`` lands on (row, col).
            i = index_of.get((row - d_row, col - d_col))
            if i is not None:
                r[i, a, j] = reward
    return r
def compute_mu(grid_map, cell_list):
    """
    Compute the initial states distribution.

    Every start ('S') cell of the grid receives probability
    ``1 / number_of_start_cells``; every other state receives 0. If the grid
    has no start cell, the returned distribution is all zeros.

    Args:
        grid_map (list): list containing the grid structure;
        cell_list (list): list of non-wall cells.

    Returns:
        The initial states distribution, an array with one entry per cell in
        ``cell_list``.

    """
    cells = np.array(cell_list)
    start_positions = np.argwhere(np.array(grid_map) == 'S')
    n_starts = len(start_positions)

    mu = np.zeros(len(cells))
    for position in start_positions:
        # Boolean mask selecting the state whose coordinates equal this start.
        mu[(cells == position).all(axis=1)] = 1. / n_starts
    return mu
```