Source code for mushroom_rl.environments.atari

from copy import deepcopy
from collections import deque

import gym
import numpy as np

from mushroom_rl.core import Environment, MDPInfo
from mushroom_rl.utils.spaces import *
from mushroom_rl.utils.frames import LazyFrames, preprocess_frame


class MaxAndSkip(gym.Wrapper):
    def __init__(self, env, skip, max_pooling=True):
        gym.Wrapper.__init__(self, env)
        self._obs_buffer = np.zeros((2,) + env.observation_space.shape,
                                    dtype=np.uint8)
        self._skip = skip
        self._max_pooling = max_pooling

    def step(self, action):
        total_reward = 0.
        for i in range(self._skip):
            obs, reward, absorbing, info = self.env.step(action)
            if i == self._skip - 2:
                self._obs_buffer[0] = obs
            if i == self._skip - 1:
                self._obs_buffer[1] = obs
            total_reward += reward
            if absorbing:
                break
        if self._max_pooling:
            frame = self._obs_buffer.max(axis=0)
        else:
            frame = self._obs_buffer.mean(axis=0)

        return frame, total_reward, absorbing, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
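
# A minimal usage sketch for ``MaxAndSkip``, assuming a Gym Atari install
# that provides 'BreakoutNoFrameskip-v4': each call to ``step`` repeats the
# action ``skip`` times and pools the last two raw frames, which removes the
# flickering of sprites that Atari games draw on alternate frames.
#
#     env = MaxAndSkip(gym.make('BreakoutNoFrameskip-v4'), skip=4)
#     obs = env.reset()
#     obs, reward, absorbing, info = env.step(env.action_space.sample())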


class Atari(Environment):
    """
    The Atari environment as presented in:
    "Human-level control through deep reinforcement learning", Mnih et al.,
    2015.

    """

    def __init__(self, name, width=84, height=84, ends_at_life=False,
                 max_pooling=True, history_length=4, max_no_op_actions=30):
        """
        Constructor.

        Args:
            name (str): id name of the Atari game in Gym;
            width (int, 84): width of the screen;
            height (int, 84): height of the screen;
            ends_at_life (bool, False): whether the episode ends when a life
                is lost or not;
            max_pooling (bool, True): whether to do max-pooling or
                average-pooling of the last two frames when using NoFrameskip;
            history_length (int, 4): number of frames to form a state;
            max_no_op_actions (int, 30): maximum number of no-op actions to
                execute at the beginning of an episode.

        """
        # MDP creation
        if 'NoFrameskip' in name:
            self.env = MaxAndSkip(gym.make(name), history_length, max_pooling)
        else:
            self.env = gym.make(name)

        # MDP parameters
        self._img_size = (width, height)
        self._episode_ends_at_life = ends_at_life
        self._max_lives = self.env.unwrapped.ale.lives()
        self._lives = self._max_lives
        self._force_fire = None
        self._real_reset = True
        self._max_no_op_actions = max_no_op_actions
        self._history_length = history_length
        self._current_no_op = None

        assert self.env.unwrapped.get_action_meanings()[0] == 'NOOP'

        # MDP properties
        action_space = Discrete(self.env.action_space.n)
        observation_space = Box(
            low=0., high=255.,
            shape=(history_length, self._img_size[1], self._img_size[0]))
        horizon = np.inf  # the gym time limit is used.
        gamma = .99
        dt = 1 / 60
        mdp_info = MDPInfo(observation_space, action_space, gamma, horizon,
                           dt)

        super().__init__(mdp_info)
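
    # A construction sketch, assuming the usual ALE Gym ids are installed: a
    # 'NoFrameskip' id delegates frame skipping to the ``MaxAndSkip`` wrapper
    # above, recovering the DQN-style preprocessing of Mnih et al. (2015):
    #
    #     mdp = Atari('BreakoutNoFrameskip-v4', ends_at_life=True)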

    def reset(self, state=None):
        if self._real_reset:
            self._state = preprocess_frame(self.env.reset(), self._img_size)
            self._state = deque([deepcopy(self._state)
                                 for _ in range(self._history_length)],
                                maxlen=self._history_length)
            self._lives = self._max_lives

        self._force_fire = \
            self.env.unwrapped.get_action_meanings()[1] == 'FIRE'
        self._current_no_op = np.random.randint(self._max_no_op_actions + 1)

        return LazyFrames(list(self._state), self._history_length)

    def step(self, action):
        action = action[0]

        # Force FIRE action to start episodes in games with lives
        if self._force_fire:
            obs, _, _, _ = self.env.env.step(1)
            self._force_fire = False

        # Execute the sampled number of no-op actions before the episode
        while self._current_no_op > 0:
            obs, _, _, _ = self.env.env.step(0)
            self._current_no_op -= 1

        obs, reward, absorbing, info = self.env.step(action)
        self._real_reset = absorbing
        if info['lives'] != self._lives:
            if self._episode_ends_at_life:
                absorbing = True
            self._lives = info['lives']
            self._force_fire = \
                self.env.unwrapped.get_action_meanings()[1] == 'FIRE'

        self._state.append(preprocess_frame(obs, self._img_size))

        return LazyFrames(list(self._state), self._history_length), reward, \
            absorbing, info

    def render(self, record=False):
        self.env.render(mode='human')

        if record:
            return self.env.render(mode='rgb_array')
        else:
            return None

    def stop(self):
        self.env.close()
        self._real_reset = True

    def set_episode_end(self, ends_at_life):
        """
        Setter.

        Args:
            ends_at_life (bool): whether the episode ends when a life is
                lost or not.

        """
        self._episode_ends_at_life = ends_at_life
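

# A minimal end-to-end sketch, assuming a Gym Atari install that provides
# 'BreakoutNoFrameskip-v4'. MushroomRL passes actions as arrays, which is why
# ``step`` unpacks ``action[0]`` above and why the sampled action is wrapped
# in ``np.array([...])`` here.
if __name__ == '__main__':
    mdp = Atari('BreakoutNoFrameskip-v4', ends_at_life=True)

    state = mdp.reset()
    absorbing = False
    while not absorbing:
        # Sample a random action id and wrap it as a 1-element array.
        action = np.array([np.random.randint(mdp.info.action_space.n)])
        state, reward, absorbing, _ = mdp.step(action)

    mdp.stop()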