How to Save and Load (Serializable interface)
In this tutorial, we explain in detail the Serializable interface, i.e., the interface to save and load classes from disk. We first explain how to use classes implementing the Serializable interface, and then we provide a small example of how to implement the Serializable interface on a custom class to serialize the object properly on disk.
The MushroomRL save format (extension .msh) is simply a zip file containing the information needed to load the object (stored in the config file). This information is easy to access, and you can try to recover it by hand from corrupted files.
Note that it is always possible to serialize Python objects with the pickle library. However, the MushroomRL serialization interface uses a native format, is easy to use, and is more robust to code changes, as it doesn’t serialize the entire class, but only the data. Furthermore, it is possible to avoid serializing some class variables, such as shared objects or big arrays, e.g., replay memories.
Save and load from disk
Many MushroomRL objects implement the serialization interface. All the algorithms, policies, approximators, and parameters implemented in MushroomRL use the Serializable interface.
As an example, we save a MushroomRL Parameter on disk. We create the parameter and then we serialize it to disk using the save method of the Serializable interface:
from mushroom_rl.rl_utils.parameters import Parameter
parameter = Parameter(1.0)
print('Initial parameter value: ', parameter())
parameter.save('parameter.msh')
This code creates a parameter.msh file in the working directory.
You can also specify a directory:
from pathlib import Path
base_dir = Path('tmp')
file_name = base_dir / 'parameter.msh'
parameter.save(file_name)
This creates a tmp folder (if it doesn’t exist) in the working directory and saves the parameter.msh file inside it.
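The folder-creation behavior described above can be reproduced with pathlib. The helper below is a hypothetical sketch, not MushroomRL's own save code: it creates any missing parent directories and then writes a zip archive from a name-to-bytes mapping.

```python
import zipfile
from pathlib import Path


def save_archive(path, entries):
    """Write `entries` (name -> bytes) to a zip file, creating parent folders first.

    Hypothetical helper mirroring the folder-creation behavior described above.
    """
    path = Path(path)
    # parents=True creates intermediate folders; exist_ok=True tolerates existing ones
    path.parent.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(path, 'w') as archive:
        for name, data in entries.items():
            archive.writestr(name, data)
    return path


if __name__ == '__main__':
    saved = save_archive(Path('tmp') / 'demo.msh', {'config': b'example'})
    print(saved.exists())  # True
```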
Now, we overwrite our parameter variable with a new parameter holding a different value:
parameter = Parameter(0.5)
print('Modified parameter value: ', parameter())
Finally, we load the previously stored parameter to go back to the previous state using the load method:
parameter = Parameter.load('parameter.msh')
print('Loaded parameter value: ', parameter())
You can also call the load method directly on the Serializable class:
from mushroom_rl.core import Serializable
parameter = Serializable.load('parameter.msh')
print('Loaded parameter value (Serializable): ', parameter())
The same approach can be used to save an agent, a policy, or an approximator.
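Calling load on the base class can work because the archive records which class produced it, in the config file mentioned at the beginning of this tutorial. The toy below is a stdlib-only sketch of that idea under stated assumptions: the class name is stored in a hypothetical config entry, subclasses are found through a registry, and the object is rebuilt without calling __init__. ToySerializable, ToyParameter, save_bytes, and load_bytes are illustrative names, not MushroomRL API, and the real implementation differs in its details.

```python
import io
import json
import pickle
import zipfile


class ToySerializable:
    """Hypothetical, heavily simplified stand-in for a serializable base class."""

    _registry = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Remember every subclass by name so that loading can find it again
        ToySerializable._registry[cls.__name__] = cls

    def save_bytes(self):
        buffer = io.BytesIO()
        with zipfile.ZipFile(buffer, 'w') as archive:
            # The config entry records which concrete class produced the data
            archive.writestr('config', json.dumps({'type': type(self).__name__}))
            archive.writestr('data', pickle.dumps(vars(self)))
        return buffer.getvalue()

    @classmethod
    def load_bytes(cls, raw):
        with zipfile.ZipFile(io.BytesIO(raw), 'r') as archive:
            type_name = json.loads(archive.read('config'))['type']
            state = pickle.loads(archive.read('data'))
        object_type = ToySerializable._registry[type_name]
        # Rebuild an instance of the stored type without calling __init__
        obj = object_type.__new__(object_type)
        obj.__dict__.update(state)
        return obj


class ToyParameter(ToySerializable):
    def __init__(self, value):
        self._value = value

    def __call__(self):
        return self._value


if __name__ == '__main__':
    raw = ToyParameter(1.0).save_bytes()
    restored = ToySerializable.load_bytes(raw)  # called on the base class
    print(type(restored).__name__, restored())  # ToyParameter 1.0
```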
Full Save
The save method has an optional full_save flag, which by default is set to False. In the previous parameter example, this flag has no effect. However, when saving a Reinforcement Learning algorithm or other complex objects, setting this flag to True forces the agent to save data structures that are normally excluded from a save file, such as the replay memory in DQN.
This implementation choice avoids large save files for agents with huge data structures, and avoids storing duplicated information (such as the Q-function of an epsilon-greedy policy, when saving the algorithm).
The full_save flag, instead, enforces a complete serialization of the agent, retaining all the information.
Implementing the Serializable interface
We give a simple example of how to implement the Serializable interface in MushroomRL for a custom class. The example uses almost all the available data persistence types.
We start the example by importing the Serializable interface, the torch library, the NumPy library, and the MushroomRL Parameter class.
from mushroom_rl.core import Serializable
import torch
import numpy as np
from mushroom_rl.rl_utils.parameters import Parameter
While importing the Serializable interface is required, the other three imports are only needed by this example, as they are used to create variables of those types.
Now we define a class implementing the Serializable interface. In this case, we call it TestClass.
The constructor can be divided into three parts: first, we build a set of variables of different types. Then, we call the superclass constructor, i.e., the constructor of Serializable. Finally, we specify which variables we want to save in the MushroomRL file by passing keyword arguments to the self._add_save_attr method.
class TestClass(Serializable):
    def __init__(self, value):
        # Create some different types of variables
        self._primitive_variable = value  # Primitive python variable
        self._numpy_vector = np.array([1, 2, 3]) * value  # Numpy array, scaled by value
        self._dictionary = dict(some='random', keywords=2, fill='the dictionary')  # A dictionary

        # Building a torch object
        data_array = np.ones(3) * value
        data_tensor = torch.from_numpy(data_array)
        self._torch_object = torch.nn.Parameter(data_tensor)

        # Some variables that implement the Serializable interface
        self._mushroom_parameter = Parameter(2.0 * value)
        self._list_of_objects = [Parameter(i) for i in range(value)]  # This is a list!

        # A variable that is not important e.g. a buffer
        self.not_important = np.zeros(10000)

        # A variable that contains a reference to another variable
        self._list_reference = [self._dictionary]

        # Superclass constructor
        super().__init__()

        # Here we specify how to save each component
        self._add_save_attr(
            _primitive_variable='primitive',
            _numpy_vector='numpy',
            _dictionary='pickle',
            _torch_object='torch',
            _mushroom_parameter='mushroom',
            # List of mushroom objects can also be saved with the 'mushroom' mode
            _list_of_objects='mushroom',
            # The '!' is to specify that we save the variable only if full_save is True
            not_important='numpy!',
        )
Some remarks about the self._add_save_attr method: the keyword name must be the name of the variable we want to store in the file, while the associated value is the method we want to use to store that variable.
The available methods are:
primitive, to store any primitive type. This includes lists and dictionaries of primitive values.
mushroom, to store any type implementing the Serializable interface. Also, lists of serializable objects are supported.
numpy, to store NumPy arrays.
torch, to store any torch object.
pickle, to store any Python object that cannot be stored with the above methods.
json, to store the variable in a textual format that is easy to read.
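To make the keyword-to-method mapping concrete, here is a hypothetical stdlib-only sketch of how a registry filled through keyword arguments can drive the choice of serializer for each attribute. It covers only the json and pickle methods (numpy, torch, and mushroom are left out to avoid dependencies), and the names ToySaver, add_save_attr, and serialize_all are illustrative, not MushroomRL API.

```python
import json
import pickle


class ToySaver:
    """Hypothetical sketch: attribute name -> persistence method, given as keyword arguments."""

    # Only two of the methods listed above are sketched here
    _serializers = {
        'json': lambda value: json.dumps(value).encode('utf-8'),
        'pickle': pickle.dumps,
    }

    def __init__(self):
        self._save_attributes = {}

    def add_save_attr(self, **attr_dict):
        # Each keyword is an attribute name, each value the method used to store it
        self._save_attributes.update(attr_dict)

    def serialize_all(self):
        """Return {attribute name: bytes}, using the registered method of each attribute."""
        return {name: self._serializers[method](getattr(self, name))
                for name, method in self._save_attributes.items()}


class Example(ToySaver):
    def __init__(self):
        super().__init__()
        self.settings = {'gamma': 0.99}                       # plain, human-readable data
        self.history = {(0, 1): 'tuple keys are not valid JSON'}  # needs pickle
        self.add_save_attr(settings='json', history='pickle')


if __name__ == '__main__':
    for name, blob in Example().serialize_all().items():
        print(name, len(blob), 'bytes')
```

The second attribute illustrates why pickle is the fallback: json.dumps rejects dictionaries with tuple keys, while pickle stores them without complaint.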
Another important aspect to remember is that any method name can end with a ‘!’, to specify that the field must be serialized if and only if the full_save flag is set to True.
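The ‘!’ rule can be sketched as a small filtering step. The function below is hypothetical and only illustrates the rule stated above: an attribute whose method ends with ‘!’ is written only when full_save is True, while every other attribute is always written.

```python
def attributes_to_save(save_attributes, full_save=False):
    """Return {attribute: method without '!'} for the attributes that would be written.

    Hypothetical helper illustrating the '!' rule; not part of MushroomRL.
    """
    selected = {}
    for name, method in save_attributes.items():
        optional = method.endswith('!')
        if optional and not full_save:
            continue  # optional field, skipped unless a full save is requested
        selected[name] = method[:-1] if optional else method
    return selected


if __name__ == '__main__':
    attrs = {'_numpy_vector': 'numpy', 'not_important': 'numpy!'}
    print(attributes_to_save(attrs))                  # {'_numpy_vector': 'numpy'}
    print(attributes_to_save(attrs, full_save=True))  # {'_numpy_vector': 'numpy', 'not_important': 'numpy'}
```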
To conclude the implementation of our Serializable interface, we may also want to implement the self._post_load method. This method is executed after all the data specified in self._add_save_attr has been loaded into the class. It can be useful to set the variables not saved in the file to a default value.
    def _post_load(self):
        if self.not_important is None:
            self.not_important = np.zeros(10000)

        self._list_reference = [self._dictionary]
In this scenario, we have to set the self.not_important variable to its default value, but only if it’s None, i.e., it has not been loaded because the file didn’t contain it. Also, we set the self._list_reference variable to maintain its original semantics, i.e., to contain a reference to the content of the self._dictionary variable.
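Why _post_load is needed becomes clearer if you assume, as a simplification, that loading rebuilds the object without calling __init__ and sets optional attributes missing from the file to None. The hypothetical stdlib-only sketch below reproduces that situation for a buffer and a shared reference; ToyAgent and toy_load are illustrative names, and a short list stands in for the large NumPy buffer.

```python
class ToyAgent:
    """Hypothetical class mirroring the not_important / _list_reference example."""

    def __init__(self):
        self._dictionary = {'some': 'random'}
        self.not_important = [0.0] * 5             # stands in for the large buffer
        self._list_reference = [self._dictionary]  # shares the dictionary object

    def _post_load(self):
        if self.not_important is None:
            self.not_important = [0.0] * 5
        # Re-create the reference so it points at the loaded dictionary
        self._list_reference = [self._dictionary]


def toy_load(cls, saved_state, optional_attributes=()):
    """Rebuild an object without __init__, then run its post-load hook (assumed behavior)."""
    obj = cls.__new__(cls)
    for name, value in saved_state.items():
        setattr(obj, name, value)
    # Optional attributes absent from the file are set to None
    for name in optional_attributes:
        if name not in saved_state:
            setattr(obj, name, None)
    obj._post_load()
    return obj


if __name__ == '__main__':
    restored = toy_load(ToyAgent, {'_dictionary': {'some': 'random'}},
                        optional_attributes=('not_important',))
    print(restored.not_important)                               # [0.0, 0.0, 0.0, 0.0, 0.0]
    print(restored._list_reference[0] is restored._dictionary)  # True
```

Without the hook, the restored object would have no _list_reference attribute at all, because __init__ never ran.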
To test the implementation, we write a function that prints the content of the class in an easy-to-read way:
def print_variables(obj):
    for label, var in vars(obj).items():
        if label != '_save_attributes':
            if isinstance(var, Parameter):
                print(f'{label}: Parameter({var()})')
            # Check that the list is non-empty before inspecting its first element
            elif isinstance(var, list) and var and isinstance(var[0], Parameter):
                new_list = [f'Parameter({item()})' for item in var]
                print(f'{label}: {new_list}')
            else:
                print(label, ': ', var)
Finally, we test the save functionality with the following code:
if __name__ == '__main__':
    # Create test object and print its variables
    test_object = TestClass(1)
    print('###########################################################################################################')
    print('The test object contains the following:')
    print('-----------------------------------------------------------------------------------------------------------')
    print_variables(test_object)

    # Changing the buffer
    test_object.not_important[0] = 1

    # Save the object on disk
    test_object.save('test.msh')

    # Create another test object
    test_object = TestClass(2)
    print('###########################################################################################################')
    print('After overwriting the test object:')
    print('-----------------------------------------------------------------------------------------------------------')
    print_variables(test_object)

    # Changing the buffer again
    test_object.not_important[0] = 1

    # Save the other test object, this time remember buffer
    test_object.save('test_full.msh', full_save=True)

    # Load first test object and print its variables
    print('###########################################################################################################')
    test_object = TestClass.load('test.msh')
    print('Loading previous test object:')
    print('-----------------------------------------------------------------------------------------------------------')
    print_variables(test_object)

    # Load second test object and print its variables
    print('###########################################################################################################')
    test_object = TestClass.load('test_full.msh')
    print('Loading the fully saved test object:')
    print('-----------------------------------------------------------------------------------------------------------')
    print_variables(test_object)
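We can see that the content of self.not_important is stored only if the full_save flag is set to True: the object loaded from test.msh has its buffer reset to zeros by _post_load, while the object loaded from test_full.msh keeps the modified first entry.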
A last remark: the Serializable interface also works in the presence of inheritance. If you extend a serializable class, you only need to add the new attributes defined by the child class.
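A minimal sketch of why a child class only has to register its own attributes: if the registry is a dictionary created once and extended by every _add_save_attr call (an assumption about the mechanism, shown here with hypothetical stdlib-only classes ToyBase, Parent, and Child), the parent's entries are already present when the child adds its own.

```python
class ToyBase:
    """Hypothetical stand-in for a serializable base class."""

    def __init__(self):
        self._save_attributes = {}

    def _add_save_attr(self, **attr_dict):
        # Extend, never replace, the registry built so far
        self._save_attributes.update(attr_dict)


class Parent(ToyBase):
    def __init__(self):
        super().__init__()
        self.gamma = 0.99
        self._add_save_attr(gamma='primitive')


class Child(Parent):
    def __init__(self):
        super().__init__()  # the parent registers its own attributes here
        self.buffer = []
        # Only the attribute introduced by the child is added
        self._add_save_attr(buffer='pickle!')


if __name__ == '__main__':
    print(Child()._save_attributes)  # {'gamma': 'primitive', 'buffer': 'pickle!'}
```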