# Source code for mushroom_rl.core.serialization

import sys
import json
import torch
import pickle
import numpy as np

from copy import deepcopy
from pathlib import Path

if sys.version_info >= (3, 7):
    from zipfile import ZipFile
else:
    from zipfile37 import ZipFile


class Serializable(object):
    """
    Interface to implement serialization of a MushroomRL object.

    This provides load and save functionality to store the object in a zip
    file. The state of the object can be saved with two levels of detail:
    attributes registered with a method ending in "!" are written only when
    ``full_save`` is True, while all the other registered attributes are
    always written.

    """
    def save(self, path, full_save=False):
        """
        Serialize and save the object to the given path on disk.

        Args:
            path (Path, str): Relative or absolute path to the object save
                location;
            full_save (bool): Flag to specify the amount of data to save for
                MushroomRL data structures.

        """
        path = Path(path)
        # Create the missing parent folders so that ZipFile can open the file.
        path.parent.mkdir(parents=True, exist_ok=True)

        with ZipFile(path, 'w') as zip_file:
            self.save_zip(zip_file, full_save)

    def save_zip(self, zip_file, full_save, folder=''):
        """
        Serialize and save the object inside the given zip file.

        Args:
            zip_file (ZipFile): ZipFile where the object needs to be saved;
            full_save (bool): flag to specify the amount of data to save for
                MushroomRL data structures;
            folder (string, ''): subfolder to be used by the save method.

        Raises:
            NotImplementedError: if an attribute is registered with a method
                for which no ``_save_<method>`` function exists.

        """
        primitive_dictionary = dict()

        for att, method in self._save_attributes.items():
            optional = method.endswith('!')
            if optional and not full_save:
                # Optional attributes are stored only in full save mode.
                continue

            method = method[:-1] if optional else method
            attribute = getattr(self, att, None)
            if attribute is None:
                # Missing or None attributes are not written; they are
                # restored as None when loading.
                continue

            if method == 'primitive':
                # Primitives are stored directly inside the config pickle.
                primitive_dictionary[att] = attribute
            elif hasattr(self, '_save_{}'.format(method)):
                save_method = getattr(self, '_save_{}'.format(method))
                file_name = "{}.{}".format(att, method)
                save_method(zip_file, file_name, attribute,
                            full_save=full_save, folder=folder)
            else:
                raise NotImplementedError(
                    "Method _save_{} is not implemented for class '{}'".
                    format(method, self.__class__.__name__)
                )

        config_data = dict(
            type=type(self),
            save_attributes=self._save_attributes,
            primitive_dictionary=primitive_dictionary
        )

        self._save_pickle(zip_file, 'config', config_data, folder=folder)

    @classmethod
    def load(cls, path):
        """
        Load and deserialize the object from the given location on disk.

        NOTE: the archive content is deserialized with pickle (and
        torch.load), which can execute arbitrary code. Only load files coming
        from a trusted source.

        Args:
            path (Path, string): Relative or absolute path to the object save
                location.

        Returns:
            The loaded object.

        Raises:
            ValueError: if the given path does not exist.

        """
        path = Path(path)
        if not path.exists():
            raise ValueError("Path to load agent is not valid")

        with ZipFile(path, 'r') as zip_file:
            loaded_object = cls.load_zip(zip_file)

        return loaded_object

    @classmethod
    def load_zip(cls, zip_file, folder=''):
        """
        Load and deserialize the object stored in a folder of a zip file.

        Args:
            zip_file (ZipFile): ZipFile from which the object is read;
            folder (string, ''): subfolder of the archive containing the
                object.

        Returns:
            The loaded object, or a list of loaded objects if a list was
            stored in the given folder.

        Raises:
            NotImplementedError: if a stored attribute uses a method for which
                no ``_load_<method>`` function exists.

        """
        config_path = Serializable._append_folder(folder, 'config')

        # The config dictionary is always written with the keys in this
        # order, so its values are unpacked positionally.
        object_type, save_attributes, primitive_dictionary = \
            cls._load_pickle(zip_file, config_path).values()

        if object_type is list:
            return cls._load_list(zip_file, folder,
                                  primitive_dictionary['len'])

        # Bypass __init__: the state is restored attribute by attribute.
        loaded_object = object_type.__new__(object_type)
        setattr(loaded_object, '_save_attributes', save_attributes)

        # Build the entry lookup once instead of once per attribute.
        archive_names = set(zip_file.namelist())

        for att, method in save_attributes.items():
            mandatory = not method.endswith('!')
            method = method[:-1] if not mandatory else method
            file_name = Serializable._append_folder(
                folder, '{}.{}'.format(att, method)
            )

            if method == 'mushroom':
                # Nested objects are stored as folders, which never appear
                # as archive entries themselves: look for their config file.
                stored = Serializable._append_folder(file_name, 'config') \
                    in archive_names
            else:
                stored = file_name in archive_names

            if method == 'primitive' and att in primitive_dictionary:
                setattr(loaded_object, att, primitive_dictionary[att])
            elif stored:
                load_method = getattr(cls, '_load_{}'.format(method), None)
                if load_method is None:
                    raise NotImplementedError(
                        'Method _load_{} is not implemented'.format(method)
                    )
                att_val = load_method(zip_file, file_name)
                setattr(loaded_object, att, att_val)
            else:
                # Attribute not present in the archive (None at save time or
                # optional data skipped because full_save was False).
                setattr(loaded_object, att, None)

        loaded_object._post_load()

        return loaded_object

    @classmethod
    def _load_list(cls, zip_file, folder, length):
        """
        Load a list of serializable objects, each one stored in a subfolder
        named after its index.

        Args:
            zip_file (ZipFile): ZipFile from which the list is read;
            folder (string): folder of the archive containing the list;
            length (int): number of elements of the list.

        Returns:
            The list of loaded objects.

        """
        loaded_list = list()

        for i in range(length):
            element_folder = Serializable._append_folder(folder, str(i))
            loaded_element = Serializable.load_zip(zip_file, element_folder)
            loaded_list.append(loaded_element)

        return loaded_list

    def copy(self):
        """
        Returns:
             A deepcopy of the object.

        """
        return deepcopy(self)

    def _add_save_attr(self, **attr_dict):
        """
        Add attributes that should be saved for an object.

        Args:
            **attr_dict: dictionary of attributes mapped to the method that
                should be used to save and load them. If a "!" character is
                added at the end of the method, the field will be saved only
                if full_save is set to True.

        """
        if not hasattr(self, '_save_attributes'):
            self._save_attributes = dict()
        self._save_attributes.update(attr_dict)

    def _post_load(self):
        """
        This method can be overwritten to implement logic that is executed
        after the loading of the object.

        """
        pass

    @staticmethod
    def _append_folder(folder, name):
        """
        Returns:
            The archive path of ``name`` inside ``folder``, or ``name`` alone
            when ``folder`` is empty.

        """
        if folder:
            return folder + '/' + name
        else:
            return name

    @staticmethod
    def _load_pickle(zip_file, name):
        # NOTE: pickle can execute arbitrary code, use only on trusted files.
        with zip_file.open(name, 'r') as f:
            return pickle.load(f)

    @staticmethod
    def _load_numpy(zip_file, name):
        with zip_file.open(name, 'r') as f:
            return np.load(f)

    @staticmethod
    def _load_torch(zip_file, name):
        # NOTE: torch.load relies on pickle, use only on trusted files.
        with zip_file.open(name, 'r') as f:
            return torch.load(f)

    @staticmethod
    def _load_json(zip_file, name):
        with zip_file.open(name, 'r') as f:
            return json.load(f)

    @staticmethod
    def _load_mushroom(zip_file, name):
        # A nested serializable object lives in a folder named ``name``.
        return Serializable.load_zip(zip_file, name)

    @staticmethod
    def _save_pickle(zip_file, name, obj, folder, **_):
        path = Serializable._append_folder(folder, name)
        with zip_file.open(path, 'w') as f:
            pickle.dump(obj, f, protocol=pickle.DEFAULT_PROTOCOL)

    @staticmethod
    def _save_numpy(zip_file, name, obj, folder, **_):
        path = Serializable._append_folder(folder, name)
        with zip_file.open(path, 'w') as f:
            np.save(f, obj)

    @staticmethod
    def _save_torch(zip_file, name, obj, folder, **_):
        path = Serializable._append_folder(folder, name)
        with zip_file.open(path, 'w') as f:
            torch.save(obj, f)

    @staticmethod
    def _save_json(zip_file, name, obj, folder, **_):
        path = Serializable._append_folder(folder, name)
        with zip_file.open(path, 'w') as f:
            string = json.dumps(obj)
            # Entries of the archive are opened in binary mode.
            f.write(string.encode('utf8'))

    @staticmethod
    def _save_mushroom(zip_file, name, obj, folder, full_save):
        """
        Save a serializable object, or a list of serializable objects, in a
        subfolder of the archive named ``name``.

        """
        new_folder = Serializable._append_folder(folder, name)

        if isinstance(obj, list):
            # A list is described by a config holding only its length; each
            # element is then saved in a subfolder named after its index.
            config_data = dict(
                type=list,
                save_attributes=dict(),
                primitive_dictionary=dict(len=len(obj))
            )

            Serializable._save_pickle(zip_file, 'config', config_data,
                                      folder=new_folder)

            for i, element in enumerate(obj):
                element_folder = Serializable._append_folder(new_folder,
                                                             str(i))
                element.save_zip(zip_file, full_save=full_save,
                                 folder=element_folder)
        else:
            obj.save_zip(zip_file, full_save=full_save, folder=new_folder)

    @staticmethod
    def _get_serialization_method(class_name):
        """
        Returns:
            'mushroom' if the given class is a subclass of Serializable,
            'pickle' otherwise.

        """
        if issubclass(class_name, Serializable):
            return 'mushroom'
        else:
            return 'pickle'