Source code for mushroom_rl.core.serialization

import json
import torch
import pickle
import numpy as np

from copy import deepcopy
from pathlib import Path

from mushroom_rl.utils.torch import TorchUtils

from zipfile import ZipFile


class Serializable(object):
    """
    Interface to implement serialization of a MushroomRL object.

    This provides load and save functionality to store the object in a zip
    file. The amount of state that is written is controlled by the
    ``full_save`` flag: attributes whose save method ends with "!" are only
    stored when ``full_save`` is True.

    """
    def save(self, path, full_save=False):
        """
        Serialize and save the object to the given path on disk.

        Args:
            path (Path, str): Relative or absolute path to the object save
                location;
            full_save (bool): Flag to specify the amount of data to save for
                MushroomRL data structures.

        """
        path = Path(path)
        # Create missing parent folders so that saving never fails on them.
        path.parent.mkdir(parents=True, exist_ok=True)

        with ZipFile(path, 'w') as zip_file:
            self.save_zip(zip_file, full_save)

    def save_zip(self, zip_file, full_save, folder=''):
        """
        Serialize and save the object into an already opened zip file.

        Args:
            zip_file (ZipFile): ZipFile where the object needs to be saved;
            full_save (bool): flag to specify the amount of data to save for
                MushroomRL data structures;
            folder (string, ''): subfolder to be used by the save method.

        Raises:
            NotImplementedError: if an attribute is registered with a method
                for which no ``_save_<method>`` exists.

        """
        primitive_dictionary = dict()

        for att, method in self._save_attributes.items():
            optional = method.endswith('!')
            if optional:
                if not full_save:
                    continue
                method = method[:-1]

            # Attributes that are missing or None are never written.
            attribute = getattr(self, att, None)
            if attribute is None or method == 'none':
                continue

            if method == 'primitive':
                # Primitives are stored inside the pickled config file.
                primitive_dictionary[att] = attribute
                continue

            save_method = getattr(self, '_save_{}'.format(method), None)
            if save_method is None:
                raise NotImplementedError(
                    "Method _save_{} is not implemented for class '{}'".
                    format(method, self.__class__.__name__)
                )

            file_name = "{}.{}".format(att, method)
            save_method(zip_file, file_name, attribute, full_save=full_save,
                        folder=folder)

        config_data = dict(
            type=type(self),
            save_attributes=self._save_attributes,
            primitive_dictionary=primitive_dictionary
        )

        self._save_pickle(zip_file, 'config', config_data, folder=folder)

    @classmethod
    def load(cls, path):
        """
        Load and deserialize the object from the given location on disk.

        NOTE: the archive contains pickled data, hence only files coming from
        a trusted source should be loaded.

        Args:
            path (Path, string): Relative or absolute path to the object save
                location.

        Returns:
            The loaded object.

        Raises:
            ValueError: if the given path does not exist.

        """
        path = Path(path)
        if not path.exists():
            raise ValueError("Path to load agent is not valid")

        with ZipFile(path, 'r') as zip_file:
            return cls.load_zip(zip_file)

    @classmethod
    def load_zip(cls, zip_file, folder=''):
        """
        Load and deserialize an object from an already opened zip file.

        Args:
            zip_file (ZipFile): ZipFile from which the object is read;
            folder (string, ''): subfolder of the archive that contains the
                object.

        Returns:
            The loaded object, a list of loaded objects if a list was saved,
            or None if the folder contains no config file.

        Raises:
            NotImplementedError: if an attribute was saved with a method for
                which no ``_load_<method>`` exists.

        """
        config_path = Serializable._append_folder(folder, 'config')

        try:
            object_type, save_attributes, primitive_dictionary = \
                cls._load_pickle(zip_file, config_path).values()
        except KeyError:
            # ZipFile.open raises KeyError when the entry is not in the
            # archive, i.e. nothing was saved in this folder.
            return None

        if object_type is list:
            return cls._load_list(zip_file, folder,
                                  primitive_dictionary['len'])

        # Build the set of entries once instead of once per attribute.
        archive_names = set(zip_file.namelist())

        # Bypass __init__: the state is restored attribute by attribute.
        loaded_object = object_type.__new__(object_type)
        setattr(loaded_object, '_save_attributes', save_attributes)

        for att, method in save_attributes.items():
            if method.endswith('!'):
                method = method[:-1]

            file_name = Serializable._append_folder(
                folder, '{}.{}'.format(att, method)
            )

            if method == 'primitive' and att in primitive_dictionary:
                setattr(loaded_object, att, primitive_dictionary[att])
                continue

            if method == 'mushroom':
                # Nested objects are stored as a subfolder, so the archive
                # has no entry called file_name: look for their config file.
                nested_config = Serializable._append_folder(file_name,
                                                            'config')
                is_stored = nested_config in archive_names
            else:
                is_stored = file_name in archive_names

            if not is_stored:
                setattr(loaded_object, att, None)
                continue

            load_method = getattr(cls, '_load_{}'.format(method), None)
            if load_method is None:
                raise NotImplementedError(
                    'Method _load_{} is not implemented'.format(method)
                )

            setattr(loaded_object, att, load_method(zip_file, file_name))

        loaded_object._post_load()

        return loaded_object

    @classmethod
    def _load_list(cls, zip_file, folder, length):
        """
        Load a list of objects, one per numbered subfolder of ``folder``.

        Args:
            zip_file (ZipFile): ZipFile from which the elements are read;
            folder (string): folder of the archive that contains the list;
            length (int): number of elements of the list.

        Returns:
            The list of loaded elements.

        """
        return [
            Serializable.load_zip(
                zip_file, Serializable._append_folder(folder, str(i))
            )
            for i in range(length)
        ]

    def copy(self):
        """
        Returns:
             A deepcopy of the object.

        """
        return deepcopy(self)

    def _add_save_attr(self, **attr_dict):
        """
        Add attributes that should be saved for an object. For every
        attribute, it is necessary to specify the method to be used to save
        and load.
        Available methods are: numpy, mushroom, torch, json, pickle, primitive
        and none. The primitive method can be used to store primitive
        attributes, while the none method always skips the attribute, but
        ensures that it is initialized to None after the load. The mushroom
        method can be used with classes that implement the Serializable
        interface. All the other methods use the library named.
        If a "!" character is added at the end of the method, the field will
        be saved only if full_save is set to True.

        Args:
            **attr_dict: dictionary of attributes mapped to the method
                that should be used to save and load them.

        """
        if not hasattr(self, '_save_attributes'):
            self._save_attributes = dict()

        self._save_attributes.update(attr_dict)

    def _post_load(self):
        """
        This method can be overwritten to implement logic that is executed
        after the loading of the object.

        """
        pass

    @staticmethod
    def _append_folder(folder, name):
        """
        Returns:
            ``name`` prefixed by ``folder`` and a '/', or ``name`` alone if
            ``folder`` is empty.

        """
        return folder + '/' + name if folder else name

    @staticmethod
    def _load_pickle(zip_file, name):
        """
        Returns:
            The object unpickled from the entry ``name`` of the archive.

        """
        with zip_file.open(name, 'r') as f:
            return pickle.load(f)

    @staticmethod
    def _load_numpy(zip_file, name):
        """
        Returns:
            The result of ``np.load`` on the entry ``name`` of the archive.

        """
        with zip_file.open(name, 'r') as f:
            return np.load(f)

    @staticmethod
    def _load_torch(zip_file, name):
        """
        Returns:
            The result of ``torch.load`` on the entry ``name`` of the archive,
            mapped to the device returned by ``TorchUtils.get_device()``.

        """
        with zip_file.open(name, 'r') as f:
            return torch.load(f, map_location=TorchUtils.get_device())

    @staticmethod
    def _load_json(zip_file, name):
        """
        Returns:
            The data decoded from the json entry ``name`` of the archive.

        """
        with zip_file.open(name, 'r') as f:
            return json.load(f)

    @staticmethod
    def _load_mushroom(zip_file, name):
        """
        Returns:
            The Serializable object (or list of objects) stored in the
            subfolder ``name`` of the archive.

        """
        return Serializable.load_zip(zip_file, name)

    @staticmethod
    def _save_pickle(zip_file, name, obj, folder, **_):
        """
        Pickle ``obj`` into the entry ``folder/name`` of the archive.

        """
        path = Serializable._append_folder(folder, name)
        with zip_file.open(path, 'w') as f:
            pickle.dump(obj, f, protocol=pickle.DEFAULT_PROTOCOL)

    @staticmethod
    def _save_numpy(zip_file, name, obj, folder, **_):
        """
        Write ``obj`` with ``np.save`` into the entry ``folder/name`` of the
        archive.

        """
        path = Serializable._append_folder(folder, name)
        with zip_file.open(path, 'w') as f:
            np.save(f, obj)

    @staticmethod
    def _save_torch(zip_file, name, obj, folder, **_):
        """
        Write ``obj`` with ``torch.save`` into the entry ``folder/name`` of
        the archive.

        """
        path = Serializable._append_folder(folder, name)
        with zip_file.open(path, 'w') as f:
            torch.save(obj, f)

    @staticmethod
    def _save_json(zip_file, name, obj, folder, **_):
        """
        Write ``obj`` as utf8-encoded json into the entry ``folder/name`` of
        the archive.

        """
        path = Serializable._append_folder(folder, name)
        with zip_file.open(path, 'w') as f:
            f.write(json.dumps(obj).encode('utf8'))

    @staticmethod
    def _save_mushroom(zip_file, name, obj, folder, full_save):
        """
        Save a Serializable object, or a list of them, into the subfolder
        ``folder/name`` of the archive.

        A list is stored as a config file holding its length, plus one
        numbered subfolder per element.

        """
        new_folder = Serializable._append_folder(folder, name)

        if not isinstance(obj, list):
            obj.save_zip(zip_file, full_save=full_save, folder=new_folder)
            return

        config_data = dict(
            type=list,
            save_attributes=dict(),
            primitive_dictionary=dict(len=len(obj))
        )
        Serializable._save_pickle(zip_file, 'config', config_data,
                                  folder=new_folder)

        for i, element in enumerate(obj):
            element_folder = Serializable._append_folder(new_folder, str(i))
            element.save_zip(zip_file, full_save=full_save,
                             folder=element_folder)

    @staticmethod
    def _get_serialization_method(class_name):
        """
        Returns:
            'mushroom' if ``class_name`` is a subclass of Serializable,
            'pickle' otherwise.

        """
        return 'mushroom' if issubclass(class_name, Serializable) else 'pickle'