Utils

Angles

normalize_angle_positive(angle)

Wrap the angle between 0 and 2 * pi.

Parameters:

angle (float) – angle to wrap.

Returns:

The wrapped angle.

normalize_angle(angle)

Wrap the angle between -pi and pi.

Parameters:

angle (float) – angle to wrap.

Returns:

The wrapped angle.
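Both wrapping functions can be sketched with the modulo operator. This is a minimal illustration rather than the library's code. It assumes the half-open ranges [0, 2*pi) and [-pi, pi), so an input of exactly pi maps to -pi, which may differ from the library at the boundary.

```python
import math


def normalize_angle_positive(angle):
    """Wrap an angle into [0, 2*pi) (up to floating-point rounding)."""
    # Python's % with a positive modulus never returns a negative result.
    return angle % (2 * math.pi)


def normalize_angle(angle):
    """Wrap an angle into [-pi, pi)."""
    # Shift by pi, wrap into [0, 2*pi), then shift back.
    return normalize_angle_positive(angle + math.pi) - math.pi


print(normalize_angle_positive(-math.pi / 2))  # ≈ 4.712 (3*pi/2)
print(normalize_angle(3 * math.pi / 2))        # ≈ -1.571 (-pi/2)
```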

shortest_angular_distance(from_angle, to_angle)

Compute the shortest distance between two angles.

Parameters:
  • from_angle (float) – starting angle;

  • to_angle (float) – final angle.

Returns:

The shortest distance between from_angle and to_angle.
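A minimal sketch, assuming the distance is the signed difference to_angle - from_angle wrapped into [-pi, pi), which is the convention of the ROS angles utilities. The sign convention is an assumption here, not something the description above confirms.

```python
import math


def normalize_angle(angle):
    """Wrap an angle into [-pi, pi)."""
    return (angle + math.pi) % (2 * math.pi) - math.pi


def shortest_angular_distance(from_angle, to_angle):
    """Signed shortest rotation that takes from_angle to to_angle."""
    # Wrapping the raw difference picks the shorter way around the circle:
    # positive means counter-clockwise, negative means clockwise.
    return normalize_angle(to_angle - from_angle)


# Going from 350 deg to 10 deg is +20 deg, not -340 deg.
d = shortest_angular_distance(math.radians(350), math.radians(10))
print(math.degrees(d))  # ≈ 20.0
```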

quat_to_euler(quat)

Convert a quaternion to Euler angles.

Parameters:

quat (np.ndarray) – quaternion to be converted, in the format [w, x, y, z].

Returns:

The Euler angles [x, y, z] representation of the quaternion.

euler_to_quat(euler)

Convert Euler angles into a quaternion.

Parameters:

euler (np.ndarray) – Euler angles to be converted.

Returns:

The quaternion, in the format [w, x, y, z].
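The documentation does not state the Euler convention. The NumPy-only sketch below assumes [x, y, z] are roll, pitch and yaw applied as extrinsic rotations about x, then y, then z (R = Rz·Ry·Rx, the lowercase 'xyz' convention of SciPy's Rotation); check the library's convention before mixing the two.

```python
import numpy as np


def euler_to_quat(euler):
    """Euler angles [x, y, z] (extrinsic x, then y, then z) -> [w, x, y, z]."""
    half = np.asarray(euler, dtype=float) / 2
    cr, cp, cy = np.cos(half)
    sr, sp, sy = np.sin(half)
    return np.array([
        cr * cp * cy + sr * sp * sy,  # w
        sr * cp * cy - cr * sp * sy,  # x
        cr * sp * cy + sr * cp * sy,  # y
        cr * cp * sy - sr * sp * cy,  # z
    ])


def quat_to_euler(quat):
    """Unit quaternion [w, x, y, z] -> Euler angles [x, y, z]."""
    w, x, y, z = quat
    roll = np.arctan2(2 * (w * x + y * z), 1 - 2 * (x**2 + y**2))
    # Clipping guards against |argument| slightly above 1 due to rounding.
    pitch = np.arcsin(np.clip(2 * (w * y - z * x), -1.0, 1.0))
    yaw = np.arctan2(2 * (w * z + x * y), 1 - 2 * (y**2 + z**2))
    return np.array([roll, pitch, yaw])


euler = np.array([0.1, -0.2, 0.3])
q = euler_to_quat(euler)
print(np.allclose(quat_to_euler(q), euler))  # True
```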

mat_to_euler(mat)

Convert a rotation matrix to Euler angles.

Parameters:

mat (np.ndarray) – a 3D rotation matrix.

Returns:

The Euler angles [x, y, z] representation of the rotation matrix.

euler_to_mat(euler)

Convert Euler angles into a rotation matrix.

Parameters:

euler (np.ndarray) – Euler angles [x, y, z] to be converted.

Returns:

The rotation matrix representation of the Euler angles.
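Under an assumed convention (R = Rz @ Ry @ Rx, extrinsic rotations about x, then y, then z; not stated in the documentation), the matrix conversions can be sketched with NumPy. The inverse is only well defined away from gimbal lock (y = ±pi/2).

```python
import numpy as np


def euler_to_mat(euler):
    """Euler angles [x, y, z] -> 3x3 rotation matrix R = Rz @ Ry @ Rx."""
    rx, ry, rz = euler
    Rx = np.array([[1, 0, 0],
                   [0, np.cos(rx), -np.sin(rx)],
                   [0, np.sin(rx), np.cos(rx)]])
    Ry = np.array([[np.cos(ry), 0, np.sin(ry)],
                   [0, 1, 0],
                   [-np.sin(ry), 0, np.cos(ry)]])
    Rz = np.array([[np.cos(rz), -np.sin(rz), 0],
                   [np.sin(rz), np.cos(rz), 0],
                   [0, 0, 1]])
    return Rz @ Ry @ Rx


def mat_to_euler(mat):
    """Rotation matrix -> Euler angles [x, y, z]; valid away from |y| = pi/2."""
    # For R = Rz @ Ry @ Rx the last row is [-sin(y), cos(y) sin(x), cos(y) cos(x)].
    x = np.arctan2(mat[2, 1], mat[2, 2])
    y = np.arcsin(np.clip(-mat[2, 0], -1.0, 1.0))
    z = np.arctan2(mat[1, 0], mat[0, 0])
    return np.array([x, y, z])


euler = np.array([0.1, -0.2, 0.3])
R = euler_to_mat(euler)
print(np.allclose(R @ R.T, np.eye(3)), np.allclose(mat_to_euler(R), euler))  # True True
```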

Features

uniform_grid(n_centers, low, high, eta=0.25, cyclic=False)

Create the parameters of uniformly spaced radial basis functions with an overlap of eta between neighboring functions. The centers form a uniformly spaced grid with n_centers[i] points along each dimension i. The function also returns a vector containing the appropriate width of the radial basis functions.

Parameters:
  • n_centers (list) – number of centers in each dimension;

  • low (np.ndarray) – lowest value for each dimension;

  • high (np.ndarray) – highest value for each dimension;

  • eta (float, 0.25) – overlap between two radial basis functions;

  • cyclic (bool, False) – whether the state space is a ring or not

Returns:

The uniformly spaced grid and the width vector.
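A hypothetical sketch of how such a grid can be built with NumPy. The grid is the Cartesian product of one np.linspace axis per dimension; in the cyclic case the upper bound is excluded because it coincides with the lower one. The width rule, (1 + eta) times the center spacing, is an assumption of this sketch and is not taken from the library.

```python
import numpy as np


def uniform_grid_sketch(n_centers, low, high, eta=0.25, cyclic=False):
    """Hypothetical sketch: RBF centers on a uniform grid plus per-dimension widths."""
    axes, widths = [], []
    for n, lo, hi in zip(n_centers, low, high):
        if cyclic:
            # On a ring, high wraps onto low, so the last center stops one step short.
            axis = np.linspace(lo, hi, n, endpoint=False)
            spacing = (hi - lo) / n
        else:
            # Assumes n >= 2 so that both bounds are centers.
            axis = np.linspace(lo, hi, n)
            spacing = (hi - lo) / (n - 1)
        axes.append(axis)
        # Assumed width rule: center spacing enlarged by the overlap factor eta.
        widths.append((1 + eta) * spacing)
    # Cartesian product of the per-dimension axes: one row per center.
    mesh = np.meshgrid(*axes, indexing='ij')
    grid = np.stack([m.ravel() for m in mesh], axis=1)
    return grid, np.array(widths)


grid, widths = uniform_grid_sketch([2, 3], np.array([0.0, 0.0]), np.array([1.0, 2.0]))
print(grid.shape)  # (6, 2)
print(widths)      # [1.25 1.25]
```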

Frames

class LazyFrames(frames, history_length)

Bases: object

From OpenAI Baselines: https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py

This class optimizes memory usage when concatenating consecutive frames, e.g. Atari frames in DQN. The frames are stored individually in a list, so observations that share frames hold references to the same arrays instead of copies; the stacked numpy array is only built when it is requested.
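The memory-sharing idea can be illustrated with a minimal, hypothetical LazyFramesSketch class (not the library's class): each observation keeps references to the last history_length frames held in a deque, and the stacked array is only built when NumPy converts the object through __array__.

```python
from collections import deque

import numpy as np


class LazyFramesSketch:
    """Hypothetical minimal version: keeps frame references, stacks on demand."""

    def __init__(self, frames, history_length):
        self._frames = list(frames)  # references to the frames, not copies
        self._history_length = history_length

    def __array__(self, dtype=None, copy=None):
        # The stacked array is only materialized when NumPy asks for it.
        out = np.stack(self._frames[-self._history_length:])
        return out if dtype is None else out.astype(dtype)


history = deque(maxlen=4)
observations = []
for t in range(6):
    history.append(np.full((84, 84), t, dtype=np.uint8))  # one new frame per step
    observations.append(LazyFramesSketch(history, 4))

# Consecutive observations share the underlying frame objects.
print(observations[-1]._frames[0] is observations[-2]._frames[1])  # True
print(np.asarray(observations[-1]).shape)  # (4, 84, 84)
```

Six observations here reference six frames in total instead of storing up to four copies each.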

__init__(frames, history_length)

preprocess_frame(obs, img_size)

Convert a frame from RGB to grayscale and resize it.

Parameters:
  • obs (np.ndarray) – array representing an RGB frame;

  • img_size (tuple) – target size for images.

Returns:

The transformed frame as an 8-bit integer array.
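The documentation does not say which grayscale weights or interpolation are used. This NumPy-only sketch (preprocess_frame_sketch is a hypothetical name) swaps in the ITU-R BT.601 luma weights and nearest-neighbour resizing, and assumes img_size is given as (height, width).

```python
import numpy as np


def preprocess_frame_sketch(obs, img_size):
    """RGB frame (H, W, 3) -> grayscale uint8 frame resized to img_size."""
    # ITU-R BT.601 luma weights for the RGB -> grayscale conversion.
    gray = obs[..., :3].astype(np.float64) @ np.array([0.299, 0.587, 0.114])
    out_h, out_w = img_size  # assumed (height, width) order
    in_h, in_w = gray.shape
    # Nearest-neighbour resize: pick the source row/column for each target pixel.
    rows = np.arange(out_h) * in_h // out_h
    cols = np.arange(out_w) * in_w // out_w
    resized = gray[rows[:, None], cols[None, :]]
    return np.clip(np.round(resized), 0, 255).astype(np.uint8)


frame = np.random.randint(0, 256, size=(210, 160, 3), dtype=np.uint8)
print(preprocess_frame_sketch(frame, (84, 84)).shape)  # (84, 84)
```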

Minibatches

minibatch_number(size, batch_size)

Compute the number of minibatches in a dataset, given the batch size.

Parameters:
  • size (int) – size of the dataset;

  • batch_size (int) – size of the batches.

Returns:

The number of minibatches in the dataset.
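Since batch_size is the maximum size of each minibatch, a final, smaller batch presumably still counts. Under that assumption the count is an integer ceiling division, as in this minimal sketch:

```python
def minibatch_number(size, batch_size):
    """Number of minibatches needed to cover `size` samples."""
    # Integer ceiling division: the last minibatch may be smaller.
    return -(-size // batch_size)


print(minibatch_number(10, 3))  # 4 (batches of 3, 3, 3 and 1)
```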

minibatch_generator(batch_size, *dataset)

Generator that splits the full dataset into minibatches.

Parameters:
  • batch_size (int) – the maximum size of each minibatch;

  • dataset – the dataset to be split.

Returns:

The current minibatch.
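A hypothetical sketch of the generator: every component of *dataset is indexed with the same slice of indexes, so inputs and targets stay aligned. Shuffling once per pass is an assumption of this sketch, not something stated above.

```python
import numpy as np


def minibatch_generator_sketch(batch_size, *dataset):
    """Yield aligned minibatches (one array per dataset component)."""
    size = len(dataset[0])
    # Assumption: samples are shuffled once per pass over the data.
    indexes = np.random.permutation(size)
    for start in range(0, size, batch_size):
        idx = indexes[start:start + batch_size]
        # The same indexes are used for every component, keeping rows aligned.
        yield [np.asarray(component)[idx] for component in dataset]


x = np.arange(10)
y = 2 * x
for x_batch, y_batch in minibatch_generator_sketch(3, x, y):
    print(len(x_batch), np.array_equal(y_batch, 2 * x_batch))
```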

Numerical gradient

numerical_diff_policy(policy, state, action, eps=1e-06)

Compute the gradient of a policy in (state, action) numerically.

Parameters:
  • policy (Policy) – the policy whose gradient has to be returned;

  • state (np.ndarray) – the state;

  • action (np.ndarray) – the action;

  • eps (float, 1e-6) – the value of the perturbation.

Returns:

The gradient of the provided policy in (state, action) computed numerically.
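A minimal sketch of the numerical gradient using central differences on a hypothetical ToyGaussianPolicy, whose get_weights/set_weights methods stand in for the real policy's parametrization interface. The finite-difference scheme is an assumption; the example checks the result against the analytic gradient of the Gaussian density.

```python
import numpy as np


class ToyGaussianPolicy:
    """Hypothetical 1-D policy: a ~ N(w . s, sigma^2), density via __call__."""

    def __init__(self, weights, sigma=0.5):
        self._w = np.array(weights, dtype=float)
        self._sigma = sigma

    def get_weights(self):
        return self._w.copy()

    def set_weights(self, weights):
        self._w = np.array(weights, dtype=float)

    def __call__(self, state, action):
        mu = self._w @ state
        z = (action - mu) / self._sigma
        return float(np.exp(-0.5 * z**2) / (np.sqrt(2 * np.pi) * self._sigma))


def numerical_diff_policy_sketch(policy, state, action, eps=1e-6):
    """Central-difference gradient of policy(state, action) w.r.t. its weights."""
    w = policy.get_weights()
    grad = np.zeros_like(w)
    for i in range(w.size):
        step = np.zeros_like(w)
        step[i] = eps
        policy.set_weights(w + step)
        f_plus = policy(state, action)
        policy.set_weights(w - step)
        f_minus = policy(state, action)
        grad[i] = (f_plus - f_minus) / (2 * eps)
    policy.set_weights(w)  # restore the original parameters
    return grad


sigma = 0.5
policy = ToyGaussianPolicy([0.2, -0.1], sigma=sigma)
state, action = np.array([1.0, 2.0]), 0.3
numeric = numerical_diff_policy_sketch(policy, state, action)
# Analytic gradient of the Gaussian density: p * (a - w.s) / sigma^2 * s
p = policy(state, action)
exact = p * (action - policy.get_weights() @ state) / sigma**2 * state
print(np.allclose(numeric, exact, atol=1e-6))  # True
```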

numerical_diff_dist(dist, theta, eps=1e-06)

Compute the gradient of a distribution in theta numerically.

Parameters:
  • dist (Distribution) – the distribution whose gradient has to be returned;

  • theta (np.ndarray) – the parametrization where to compute the gradient;

  • eps (float, 1e-6) – the value of the perturbation.

Returns:

The gradient of the provided distribution in theta, computed numerically.

numerical_diff_function(function, params, eps=1e-06)

Compute the gradient of a function w.r.t. params numerically.

Parameters:
  • function – a function whose gradient has to be returned;

  • params – parameter vector with respect to which the gradient is computed;

  • eps (float, 1e-6) – the value of the perturbation.

Returns:

The numerical gradient of the function computed w.r.t. parameters params.
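For a plain function of a parameter vector, a central-difference scheme (assumed here, not confirmed by the source) looks like this; numerical_diff_function_sketch is a hypothetical name.

```python
import numpy as np


def numerical_diff_function_sketch(function, params, eps=1e-6):
    """Central-difference gradient of a scalar function w.r.t. params."""
    params = np.asarray(params, dtype=float)
    grad = np.zeros_like(params)
    for i in range(params.size):
        step = np.zeros_like(params)
        step[i] = eps
        # Symmetric perturbation: truncation error is O(eps**2) instead of O(eps).
        grad[i] = (function(params + step) - function(params - step)) / (2 * eps)
    return grad


def f(p):
    return np.sum(p**2) + 3 * p[0] * p[1]


p = np.array([1.0, -2.0])
print(numerical_diff_function_sketch(f, p))  # close to the exact gradient [-4, -1]
```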

Plots

get_mean_and_confidence(data)

Compute the mean and the 95% confidence interval of the data.

Parameters:

data (np.ndarray) – Array of experiment data of shape (n_runs, n_epochs).

Returns:

The mean of the dataset at each epoch along with the confidence interval.
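A sketch of the per-epoch statistics with NumPy and the standard library. It returns the half-width of the interval (the band is mean ± half_width); whether the library returns a half-width or explicit bounds is not stated above. For simplicity it uses the normal quantile (about 1.96); with few runs, the Student-t quantile scipy.stats.t.ppf(0.975, n_runs - 1) gives a wider and more appropriate interval.

```python
from statistics import NormalDist

import numpy as np


def mean_and_confidence_sketch(data, level=0.95):
    """Per-epoch mean and half-width of a normal-approximation CI."""
    data = np.asarray(data, dtype=float)
    n_runs = data.shape[0]
    mean = data.mean(axis=0)
    # Standard error of the mean across runs (sample std, ddof=1).
    sem = data.std(axis=0, ddof=1) / np.sqrt(n_runs)
    z = NormalDist().inv_cdf(0.5 + level / 2)  # ≈ 1.96 for level=0.95
    return mean, z * sem


data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # 3 runs, 2 epochs
mean, half_width = mean_and_confidence_sketch(data)
print(mean)  # [3. 4.]
```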

plot_mean_conf(data, ax, color='blue', line='-', facecolor=None, alpha=0.4, label=None)

Plot the mean and confidence interval of the data on matplotlib axes.

Parameters:
  • data (np.ndarray) – Array of experiment data of shape (n_runs, n_epochs);

  • ax (plt.Axes) – matplotlib axes where to create the curve;

  • color (str, 'blue') – matplotlib color identifier for the mean curve;

  • line (str, '-') – matplotlib line type to be used for the mean curve;

  • facecolor (str, None) – matplotlib color identifier for the confidence interval;

  • alpha (float, 0.4) – transparency of the confidence interval;

  • label (str, None) – legend label for the plotted curve.

Record

class VideoRecorder(path='./mushroom_rl_recordings', tag=None, video_name=None, fps=60)

Bases: object

Simple video recorder that creates a video from a stream of images.

__init__(path='./mushroom_rl_recordings', tag=None, video_name=None, fps=60)

Constructor.

Parameters:
  • path – Path at which videos will be stored.

  • tag – Name of the directory at path in which the video will be stored. If None, a timestamp is used as the directory name.

  • video_name – Name of the video without extension. Default is “recording”.

  • fps – Frame rate of the video.

__call__(frame)

Add a frame to the video.

Parameters:

frame (np.ndarray) – Frame to be added to the video (H, W, RGB).

Torch

class CategoricalWrapper(*args: Any, **kwargs: Any)

Bases: Categorical

Wrapper for the Torch Categorical distribution.

Needed to convert a vector of MushroomRL discrete actions into an input with the proper shape for the original distribution implemented in Torch.

__init__(logits)

__call__(*args: Any, **kwargs: Any) → Any

Call self as a function.

Viewer

class ImageViewer(size, dt, headless=False)

Bases: object

Interface to pygame for visualizing plain images.

__init__(size, dt, headless=False)

Constructor.

Parameters:
  • size ([list, tuple]) – size of the displayed image;

  • dt (float) – duration of a control step;

  • headless (bool, False) – skip the display.

display(img)

Display given frame.

Parameters:

img – image to display.

property size

Property.

Returns:

The size of the screen.

close()

Close the viewer, destroy the window.

class Viewer(env_width, env_height, width=500, height=500, background=(0, 0, 0))

Bases: object

Interface to pygame for visualizing MushroomRL native environments.

__init__(env_width, env_height, width=500, height=500, background=(0, 0, 0))

Constructor.

Parameters:
  • env_width (float) – The x dimension limit of the desired environment;

  • env_height (float) – The y dimension limit of the desired environment;

  • width (int, 500) – width of the environment window;

  • height (int, 500) – height of the environment window;

  • background (tuple, (0, 0, 0)) – background color of the screen.

property screen

Property.

Returns:

The screen created by this viewer.

property size

Property.

Returns:

The size of the screen.

line(start, end, color=(255, 255, 255), width=1)

Draw a line on the screen.

Parameters:
  • start (np.ndarray) – starting point of the line;

  • end (np.ndarray) – end point of the line;

  • color (tuple, (255, 255, 255)) – color of the line;

  • width (int, 1) – width of the line.

square(center, angle, edge, color=(255, 255, 255), width=0)

Draw a square on the screen and apply a roto-translation to it.

Parameters:
  • center (np.ndarray) – the center of the square;

  • angle (float) – the rotation to apply to the square;

  • edge (float) – length of an edge;

  • color (tuple, (255, 255, 255)) – the color of the square;

  • width (int, 0) – the width of the square line, 0 to fill the square.

polygon(center, angle, points, color=(255, 255, 255), width=0)

Draw a polygon on the screen and apply a roto-translation to it.

Parameters:
  • center (np.ndarray) – the center of the polygon;

  • angle (float) – the rotation to apply to the polygon;

  • points (list) – the points of the polygon w.r.t. the center;

  • color (tuple, (255, 255, 255)) – the color of the polygon;

  • width (int, 0) – the width of the polygon line, 0 to fill the polygon.
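The roto-translation used by square and polygon can be sketched in NumPy: rotate the points expressed w.r.t. the center, then translate them to the center. The helper roto_translate is hypothetical; it assumes angle is in radians and counter-clockwise in environment coordinates, and it omits the conversion to screen pixels.

```python
import numpy as np


def roto_translate(points, center, angle):
    """Rotate points (given w.r.t. the center) by angle, then move them to center."""
    c, s = np.cos(angle), np.sin(angle)
    rotation = np.array([[c, -s],
                         [s, c]])
    # Row-vector form: each point p becomes rotation @ p + center.
    return np.asarray(points, dtype=float) @ rotation.T + np.asarray(center, dtype=float)


# A square with edge 2, as points relative to its center.
edge = 2.0
square = [[edge / 2, edge / 2], [-edge / 2, edge / 2],
          [-edge / 2, -edge / 2], [edge / 2, -edge / 2]]
print(roto_translate(square, center=[10.0, 5.0], angle=np.pi / 4))
```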

circle(center, radius, color=(255, 255, 255), width=0)

Draw a circle on the screen.

Parameters:
  • center (np.ndarray) – the center of the circle;

  • radius (float) – the radius of the circle;

  • color (tuple, (255, 255, 255)) – the color of the circle;

  • width (int, 0) – the width of the circle line, 0 to fill the circle.

arrow_head(center, scale, angle, color=(255, 255, 255))

Draw an arrow head.

Parameters:
  • center (np.ndarray) – the position of the arrow head;

  • scale (float) – scale of the arrow, corresponding to its length;

  • angle (float) – the angle of rotation of the arrow head;

  • color (tuple, (255, 255, 255)) – the color of the arrow.

force_arrow(center, direction, force, max_force, max_length, color=(255, 255, 255), width=1)

Draw a force arrow, i.e. an arrow representing a force. The length of the arrow is directly proportional to the force value.

Parameters:
  • center (np.ndarray) – the point where the force is applied;

  • direction (np.ndarray) – the direction of the force;

  • force (float) – the applied force value;

  • max_force (float) – the maximum force value;

  • max_length (float) – the length to use for the maximum force;

  • color (tuple, (255, 255, 255)) – the color of the arrow;

  • width (int, 1) – the width of the force arrow.
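The proportionality between force and arrow length can be sketched as follows. force_arrow_end is a hypothetical helper; normalizing direction and not clipping forces above max_force are assumptions of this sketch. A negative force points the arrow the opposite way.

```python
import numpy as np


def force_arrow_end(center, direction, force, max_force, max_length):
    """End point of a force arrow whose length scales linearly with force."""
    direction = np.asarray(direction, dtype=float)
    unit = direction / np.linalg.norm(direction)
    length = max_length * force / max_force  # force == max_force -> max_length
    return np.asarray(center, dtype=float) + length * unit


print(force_arrow_end([0.0, 0.0], [3.0, 4.0], force=5.0, max_force=10.0, max_length=2.0))
# [0.6 0.8]
```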

torque_arrow(center, torque, max_torque, max_radius, color=(255, 255, 255), width=1)

Draw a torque arrow, i.e. a circular arrow representing a torque. The radius of the arrow is directly proportional to the torque value.

Parameters:
  • center (np.ndarray) – the point where the torque is applied;

  • torque (float) – the applied torque value;

  • max_torque (float) – the maximum torque value;

  • max_radius (float) – the radius to use for the maximum torque;

  • color (tuple, (255, 255, 255)) – the color of the arrow;

  • width (int, 1) – the width of the torque arrow.

background_image(img)

Use the given image as background for the window, rescaling it appropriately.

Parameters:

img – the image to be used.

function(x_s, x_e, f, n_points=100, width=1, color=(255, 255, 255))

Draw the graph of a function in the image.

Parameters:
  • x_s (float) – starting x coordinate;

  • x_e (float) – final x coordinate;

  • f (function) – the function that maps x coordinates into y coordinates;

  • n_points (int, 100) – the number of segments used to approximate the function to draw;

  • width (int, 1) – the width of the line drawn;

  • color (tuple, (255, 255, 255)) – the color of the line.
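A sketch of the sampling behind function(): n_points segments need n_points + 1 sample points, and each consecutive pair could then be passed to line(). function_segments is a hypothetical helper, and the n_points + 1 sampling is inferred from n_points being described as the number of segments.

```python
import numpy as np


def function_segments(x_s, x_e, f, n_points=100):
    """Approximate the graph of f on [x_s, x_e] with n_points line segments."""
    # n_points segments need n_points + 1 sample points.
    xs = np.linspace(x_s, x_e, n_points + 1)
    ys = np.array([f(x) for x in xs])
    points = np.stack([xs, ys], axis=1)
    # Consecutive pairs of points are the (start, end) of each segment.
    return list(zip(points[:-1], points[1:]))


segments = function_segments(0.0, np.pi, np.sin, n_points=4)
print(len(segments))  # 4
```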

static get_frame()

Getter.

Returns:

The current Pygame surface as an RGB array.

display(s)

Display current frame and initialize the next frame to the background color.

Parameters:

s – time to wait in visualization.

close()

Close the viewer, destroy the window.

class CV2Viewer(window_name, dt, width, height)

Bases: object

Simple viewer to display rendered images using cv2.

__init__(window_name, dt, width, height)

display(img)

Display an image.

Parameters:

img (np.ndarray) – image to display.

_wait()

Wait for the specified amount of time, which is expected to be in milliseconds.

_window_was_closed()

Check whether the window was closed.

Returns:

True if the window was closed.