Source code for volumentations.core.serialization

from __future__ import absolute_import

import json
import warnings

from volumentations import __version__

try:
    import yaml

    yaml_available = True
except ImportError:
    yaml_available = False


# Names exported by `from volumentations.core.serialization import *`.
__all__ = ["to_dict", "from_dict", "save", "load"]


# Maps the value returned by a class's `get_class_fullname()` to the class object.
# Entries are added by `SerializableMeta.__new__` and looked up by `from_dict`.
SERIALIZABLE_REGISTRY = {}


class SerializableMeta(type):
    """
    Metaclass that records every class it creates in `SERIALIZABLE_REGISTRY`.

    The registry key is whatever the freshly created class returns from
    `get_class_fullname()`, so a serialized pipeline can later be rebuilt by
    looking classes up under that name.
    """

    def __new__(mcs, name, bases, class_dict):
        new_class = type.__new__(mcs, name, bases, class_dict)
        # The key is computed by calling `get_class_fullname` on the class object itself.
        fullname = new_class.get_class_fullname()
        SERIALIZABLE_REGISTRY[fullname] = new_class
        return new_class


def to_dict(transform, on_not_implemented_error="raise"):
    """
    Take a transform pipeline and convert it to a serializable representation that
    uses only standard python data types: dictionaries, lists, strings, integers,
    and floats.

    Args:
        transform (object): A transform that should be serialized. Its `_to_dict`
            method is called to obtain the transform's own representation.
        on_not_implemented_error (str): What to do when `transform._to_dict()` raises
            `NotImplementedError`. With 'raise' the exception is propagated. With
            'warn' a warning is issued instead and an empty dictionary is stored for
            the transform, so no transform parameters are serialized.

    Returns:
        dict: A dictionary with two keys: `"__version__"` (the library version) and
        `"transform"` (the dictionary produced by the transform).

    Raises:
        ValueError: If `on_not_implemented_error` is neither 'raise' nor 'warn'.
        NotImplementedError: If the transform cannot be serialized and
            `on_not_implemented_error` is 'raise'.
    """
    if on_not_implemented_error not in {"raise", "warn"}:
        raise ValueError(
            "Unknown on_not_implemented_error value: {}. Supported values are: "
            "'raise' and 'warn'".format(on_not_implemented_error)
        )
    try:
        transform_dict = transform._to_dict()  # skipcq: PYL-W0212
    except NotImplementedError:
        if on_not_implemented_error == "raise":
            # Bare `raise` re-raises the active exception with its original traceback.
            raise
        transform_dict = {}
        warnings.warn(
            "Got NotImplementedError while trying to serialize {obj}. Object arguments "
            "are not preserved. Implement either '{cls_name}.get_transform_init_args_names' "
            "or '{cls_name}.get_transform_init_args' "
            "method to make the transform serializable".format(
                obj=transform, cls_name=transform.__class__.__name__
            )
        )
    return {"__version__": __version__, "transform": transform_dict}
def instantiate_lambda(transform, lambda_transforms=None):
    """
    Resolve a serialized Lambda entry to the object supplied by the caller.

    Args:
        transform (dict): Serialized description of a single transform.
        lambda_transforms (dict): Mapping from Lambda names to the objects that
            should be used for them.

    Returns:
        The object stored in `lambda_transforms` under the entry's `"__name__"`,
        or `None` when the entry's `"__type__"` is not `"Lambda"`.

    Raises:
        ValueError: If the entry is a Lambda and `lambda_transforms` is `None`, or
            the name is missing from (or mapped to `None` in) `lambda_transforms`.
    """
    # Anything that is not tagged as a Lambda is left for the caller to handle.
    if transform.get("__type__") != "Lambda":
        return None

    name = transform["__name__"]
    if lambda_transforms is None:
        raise ValueError(
            "To deserialize a Lambda transform with name {name} you need to pass "
            "a dict with this transform "
            "as the `lambda_transforms` argument".format(name=name),
        )

    restored = lambda_transforms.get(name)
    if restored is None:
        raise ValueError(
            "Lambda transform with {name} was not found in `lambda_transforms`".format(
                name=name
            )
        )
    return restored
def from_dict(transform_dict, lambda_transforms=None):
    """
    Rebuild a transform pipeline from its serialized dictionary representation.

    Args:
        transform_dict (dict): A dictionary with serialized transform pipeline. The
            pipeline description is read from its `"transform"` key.
        lambda_transforms (dict): A dictionary that contains lambda transforms, that
            is instances of the Lambda class. This dictionary is required when you
            are restoring a pipeline that contains lambda transforms. Keys in that
            dictionary should be named same as `name` arguments in respective lambda
            transforms from a serialized pipeline.

    Returns:
        The object taken from `lambda_transforms` for a Lambda entry; otherwise an
        instance of the class registered in `SERIALIZABLE_REGISTRY` under the
        entry's `"__class_fullname__"`, built from the remaining keys as keyword
        arguments.

    Raises:
        KeyError: If `"transform"` or `"__class_fullname__"` is missing, or the class
            name is not present in `SERIALIZABLE_REGISTRY`.
        ValueError: If a Lambda entry cannot be resolved through `lambda_transforms`.
    """
    transform = transform_dict["transform"]
    lmbd = instantiate_lambda(transform, lambda_transforms)
    # `instantiate_lambda` returns None only for non-Lambda entries, so test for None
    # explicitly; a truthiness test would wrongly skip a restored object that is falsy.
    if lmbd is not None:
        return lmbd
    name = transform["__class_fullname__"]
    args = {k: v for k, v in transform.items() if k != "__class_fullname__"}
    cls = SERIALIZABLE_REGISTRY[name]
    if "transforms" in args:
        # Nested pipelines are stored as a list of serialized transforms; restore each
        # one by wrapping it in the same envelope this function expects.
        args["transforms"] = [
            from_dict({"transform": t}, lambda_transforms=lambda_transforms)
            for t in args["transforms"]
        ]
    return cls(**args)
def check_data_format(data_format):
    """
    Validate the name of a serialization format.

    Args:
        data_format (str): Format name to check.

    Raises:
        ValueError: If `data_format` is neither 'json' nor 'yaml'.
    """
    supported_formats = {"json", "yaml"}
    if data_format in supported_formats:
        return
    message = "Unknown data_format {}. Supported formats are: 'json' and 'yaml'".format(
        data_format
    )
    raise ValueError(message)
def save(transform, filepath, data_format="json", on_not_implemented_error="raise"):
    """
    Take a transform pipeline, serialize it and save a serialized version to a file
    using either json or yaml format.

    Args:
        transform (obj): Transform to serialize.
        filepath (str): Filepath to write to.
        data_format (str): Serialization format. Should be either `json` or 'yaml'.
        on_not_implemented_error (str): Parameter that describes what to do if a
            transform cannot be serialized (`NotImplementedError` is raised while
            converting it to a dictionary). If 'raise' then `NotImplementedError` is
            raised, if `warn` then the exception will be ignored and no transform
            arguments will be saved.

    Raises:
        ValueError: If `data_format` is not supported, or if 'yaml' is requested but
            the `yaml` module could not be imported.
    """
    check_data_format(data_format)
    # `yaml` is an optional import; fail with a clear message before doing any work
    # instead of hitting a NameError on the undefined `yaml` name below.
    if data_format == "yaml" and not yaml_available:
        raise ValueError(
            "You need to install PyYAML to use the 'yaml' data format"
        )
    transform_dict = to_dict(
        transform, on_not_implemented_error=on_not_implemented_error
    )
    dump_fn = json.dump if data_format == "json" else yaml.safe_dump
    with open(filepath, "w") as f:
        dump_fn(transform_dict, f)
def load(filepath, data_format="json", lambda_transforms=None):
    """
    Load a serialized pipeline from a json or yaml file and construct a transform
    pipeline.

    Args:
        filepath (str): Filepath to read from.
        data_format (str): Serialization format. Should be either `json` or 'yaml'.
        lambda_transforms (dict): A dictionary that contains lambda transforms, that
            is instances of the Lambda class. This dictionary is required when you
            are restoring a pipeline that contains lambda transforms. Keys in that
            dictionary should be named same as `name` arguments in respective lambda
            transforms from a serialized pipeline.

    Returns:
        The transform pipeline reconstructed by `from_dict` from the file contents.

    Raises:
        ValueError: If `data_format` is not supported, or if 'yaml' is requested but
            the `yaml` module could not be imported.
    """
    check_data_format(data_format)
    # `yaml` is an optional import; fail with a clear message instead of hitting a
    # NameError on the undefined `yaml` name below.
    if data_format == "yaml" and not yaml_available:
        raise ValueError(
            "You need to install PyYAML to use the 'yaml' data format"
        )
    load_fn = json.load if data_format == "json" else yaml.safe_load
    with open(filepath) as f:
        transform_dict = load_fn(f)
    return from_dict(transform_dict, lambda_transforms=lambda_transforms)