# Source code for volumentations.core.composition

import random
from collections import defaultdict

import numpy as np
from volumentations.core.serialization import SERIALIZABLE_REGISTRY, SerializableMeta
from volumentations.core.six import add_metaclass
from volumentations.core.utils import format_args

# Public names exported on a star-import of this module. Transforms,
# BaseCompose and set_always_apply are defined below but deliberately not listed.
__all__ = [
    "Compose",
    "OneOf",
    "OneOrOther",
    "ReplayCompose",
]


# Number of spaces added per nesting level by BaseCompose.indented_repr; it is
# also that method's default starting indent.
REPR_INDENT_STEP = 2


class Transforms:
    """Thin wrapper around a sequence of transforms.

    Attributes are set in ``__init__``:

    * ``transforms`` - the sequence exactly as it was passed in.
    * ``start_end`` - result of ``_find_dual_start_end`` for that sequence.

    Indexing is delegated to the wrapped sequence. Because only
    ``__getitem__`` is defined, iteration relies on Python's legacy sequence
    protocol (index 0, 1, 2, ... until ``IndexError``).
    """

    def __init__(self, transforms):
        self.transforms = transforms
        self.start_end = self._find_dual_start_end(transforms)

    def _find_dual_start_end(self, transforms):
        """Locate the first and last entries that contain a nested match.

        Only entries that are ``BaseCompose`` instances are inspected; each is
        searched recursively and counts as a match when that recursive search
        returns something other than ``None``.

        Returns:
            ``[first_index, last_index]`` of the matching entries, or ``None``
            when no entry matches.

        NOTE(review): no leaf transform ever seeds a match in this code, so
        the result appears to always be ``None`` - confirm whether a
        leaf-type check was intentionally removed.
        """
        first_match = None
        last_match = None
        for position, candidate in enumerate(transforms):
            if not isinstance(candidate, BaseCompose):
                continue
            if self._find_dual_start_end(candidate) is None:
                continue
            last_match = position
            if first_match is None:
                first_match = position
        if first_match is None:
            return None
        return [first_match, last_match]

    def get_always_apply(self, transforms):
        """Collect, flattened, every leaf transform flagged ``always_apply``.

        Nested ``BaseCompose`` entries are expanded recursively; any other
        entry is kept only if its ``always_apply`` attribute is truthy.

        Returns:
            A new ``Transforms`` wrapping the collected list.
        """
        kept = []
        for candidate in transforms:
            if isinstance(candidate, BaseCompose):
                nested = self.get_always_apply(candidate)
                kept.extend(nested)
                continue
            if candidate.always_apply:
                kept.append(candidate)
        return Transforms(kept)

    def __getitem__(self, item):
        """Return ``self.transforms[item]`` (index or slice)."""
        wrapped = self.transforms
        return wrapped[item]


def set_always_apply(transforms):
    """Set ``always_apply = True`` on every object in ``transforms``.

    The objects are mutated in place; nothing is returned.
    """
    for transform in transforms:
        setattr(transform, "always_apply", True)


@add_metaclass(SerializableMeta)
class BaseCompose:
    """Common base for containers that hold a list of child transforms.

    Args:
        transforms: sequence of child transforms; stored wrapped in a
            ``Transforms`` object.
        p: value stored as ``self.p`` and emitted by ``_to_dict``.

    ``replay_mode`` and ``applied_in_replay`` both start out ``False``.
    """

    def __init__(self, transforms, p):
        self.transforms = Transforms(transforms)
        self.p = p

        self.replay_mode = False
        self.applied_in_replay = False

    def __getitem__(self, item):
        """Index into the wrapped child transforms."""
        children = self.transforms
        return children[item]

    def __repr__(self):
        """Return ``indented_repr()`` with its default indent."""
        return self.indented_repr()

    def indented_repr(self, indent=REPR_INDENT_STEP):
        """Build a multi-line representation, one child per line.

        Children exposing ``indented_repr`` are rendered one level deeper;
        all others fall back to ``repr``. The closing bracket is followed by
        every ``_to_dict`` entry except ``"transforms"`` and keys starting
        with a double underscore, formatted by ``format_args``.

        Args:
            indent: number of spaces placed in front of each child line.
        """
        extra_args = {}
        for key, value in self._to_dict().items():
            if not key.startswith("__") and key != "transforms":
                extra_args[key] = value

        child_pad = " " * indent
        pieces = [self.__class__.__name__ + "(["]
        for child in self.transforms:
            if hasattr(child, "indented_repr"):
                rendered = child.indented_repr(indent + REPR_INDENT_STEP)
            else:
                rendered = repr(child)
            pieces.append("\n" + child_pad + rendered + ",")

        closing_pad = " " * (indent - REPR_INDENT_STEP)
        tail = "], {args})".format(args=format_args(extra_args))
        pieces.append("\n" + closing_pad + tail)
        return "".join(pieces)

    @classmethod
    def get_class_fullname(cls):
        """Return ``"<module>.<class name>"`` for this class."""
        return "{}.{}".format(cls.__module__, cls.__name__)

    def _to_dict(self):
        """Serialize this container and, recursively, its children."""
        serialized = {"__class_fullname__": self.get_class_fullname()}
        serialized["p"] = self.p
        serialized["transforms"] = [
            child._to_dict() for child in self.transforms
        ]
        return serialized

    def get_dict_with_id(self):
        """Describe this container and its children, tagged with ``id()``.

        ``"params"`` is always ``None`` at this level.
        """
        described = {"__class_fullname__": self.get_class_fullname()}
        described["id"] = id(self)
        described["params"] = None
        described["transforms"] = [
            child.get_dict_with_id() for child in self.transforms
        ]
        return described

    def add_targets(self, additional_targets):
        """Forward ``additional_targets`` to every child; no-op when falsy."""
        if not additional_targets:
            return
        for child in self.transforms:
            child.add_targets(additional_targets)

    def set_deterministic(self, flag, save_key="replay"):
        """Forward ``set_deterministic(flag, save_key)`` to every child."""
        for child in self.transforms:
            child.set_deterministic(flag, save_key)


class Compose(BaseCompose):
    """Run a list of transforms one after another.

    Args:
        transforms (list): transforms to chain; ``None`` entries are dropped.
        additional_targets (dict): maps a new target name to an existing
            target name, e.g. ``{'image2': 'image'}``.
        p (float): probability of running the whole list. Default: 1.0.
            When the list is not run, only the transforms flagged
            ``always_apply`` are executed.
    """

    def __init__(
        self,
        transforms,
        additional_targets=None,
        p=1.0,
    ):
        kept = [t for t in transforms if t is not None]
        super(Compose, self).__init__(kept, p)

        # Starts empty here; the loops over it below only do work if
        # something else fills it in.
        self.processors = {}

        if additional_targets is None:
            additional_targets = {}
        self.additional_targets = additional_targets

        for processor in self.processors.values():
            processor.ensure_transforms_valid(self.transforms)

        self.add_targets(additional_targets)

    def __call__(self, force_apply=False, **data):
        """Apply the pipeline to ``data`` and return the resulting dict.

        Args:
            force_apply: when truthy the full list runs regardless of ``p``
                (no random draw is made) and the flag is passed on to every
                child.
            **data: named targets handed to each transform in turn.
        """
        # Short-circuit: the RNG is only consulted when not forced.
        run_everything = force_apply or random.random() < self.p

        for processor in self.processors.values():
            processor.ensure_data_valid(data)

        if run_everything:
            pipeline = self.transforms
        else:
            pipeline = self.transforms.get_always_apply(self.transforms)

        span = pipeline.start_end if self.processors else None

        for position, transform in enumerate(pipeline):
            if span is not None and position == span[0]:
                for processor in self.processors.values():
                    processor.preprocess(data)

            data = transform(force_apply=force_apply, **data)

            if span is not None and position == span[1]:
                for processor in self.processors.values():
                    processor.postprocess(data)

        return data

    def _to_dict(self):
        """Extend the base serialization with ``additional_targets``."""
        serialized = super(Compose, self)._to_dict()
        serialized["additional_targets"] = self.additional_targets
        return serialized
class OneOf(BaseCompose):
    """Pick a single child transform at random and apply it.

    Args:
        transforms (list): candidate transforms. Each one's ``p`` is used as
            its selection weight; the weights are normalised to sum to 1.
        p (float): probability of applying the selected transform.
            Default: 0.5.
    """

    def __init__(self, transforms, p=0.5):
        super(OneOf, self).__init__(transforms, p)
        weights = [t.p for t in transforms]
        total = sum(weights)
        self.transforms_ps = [weight / total for weight in weights]

    def __call__(self, force_apply=False, **data):
        """Return ``data`` after applying at most one child.

        In replay mode every child is called in order instead of sampling.
        """
        if self.replay_mode:
            for child in self.transforms:
                data = child(**data)
            return data

        if not self.transforms_ps:
            return data
        # Short-circuit: the RNG is only consulted when not forced.
        if not (force_apply or random.random() < self.p):
            return data

        # Seed a private NumPy generator from the stdlib RNG for the draw.
        seed = random.randint(0, 2 ** 32 - 1)
        chooser = np.random.RandomState(seed)
        chosen = chooser.choice(self.transforms.transforms, p=self.transforms_ps)
        return chosen(force_apply=True, **data)
class OneOrOther(BaseCompose):
    """Apply either the first or the last child transform.

    Args:
        first: first alternative; used only when ``transforms`` is None.
        second: second alternative; used only when ``transforms`` is None.
        transforms (list): explicit list of children; when given, ``first``
            and ``second`` are ignored.
        p (float): probability of choosing the first child. Default: 0.5.
    """

    def __init__(self, first=None, second=None, transforms=None, p=0.5):
        children = [first, second] if transforms is None else transforms
        super(OneOrOther, self).__init__(children, p)

    def __call__(self, force_apply=False, **data):
        """Return the output of the chosen child.

        The chosen child is always called with ``force_apply=True``; the
        ``force_apply`` argument of this call is not consulted. In replay
        mode every child is called in order instead.
        """
        if self.replay_mode:
            for child in self.transforms:
                data = child(**data)
            return data

        index = 0 if random.random() < self.p else -1
        return self.transforms[index](force_apply=True, **data)
class ReplayCompose(Compose):
    """A ``Compose`` that records each run so that it can be replayed.

    Args:
        transforms (list): transforms to chain.
        additional_targets (dict): forwarded to ``Compose``.
        p (float): forwarded to ``Compose``. Default: 1.0.
        save_key (str): key of the result dict under which the recorded
            pipeline description is stored. Default: "replay".
    """

    def __init__(
        self,
        transforms,
        additional_targets=None,
        p=1.0,
        save_key="replay",
    ):
        super(ReplayCompose, self).__init__(transforms, additional_targets, p)
        self.set_deterministic(True, save_key=save_key)
        self.save_key = save_key

    def __call__(self, force_apply=False, **kwargs):
        """Run the pipeline and attach its description under ``save_key``."""
        key = self.save_key
        kwargs[key] = defaultdict(dict)
        output = super(ReplayCompose, self).__call__(
            force_apply=force_apply, **kwargs
        )

        description = self.get_dict_with_id()
        self.fill_with_params(description, output[key])
        self.fill_applied(description)
        output[key] = description
        return output

    @staticmethod
    def replay(saved_augmentations, **kwargs):
        """Rebuild a recorded pipeline and run it, forced, on ``kwargs``."""
        pipeline = ReplayCompose._restore_for_replay(saved_augmentations)
        return pipeline(force_apply=True, **kwargs)

    @staticmethod
    def _restore_for_replay(transform_dict):
        """Rebuild a transform, in replay mode, from its recorded description.

        Args:
            transform_dict (dict): serialized description containing at least
                ``"applied"``, ``"params"`` and ``"__class_fullname__"``.

        Returns:
            The instantiated transform, with ``params``, ``replay_mode`` and
            ``applied_in_replay`` set on it.
        """
        applied = transform_dict["applied"]
        params = transform_dict["params"]
        fullname = transform_dict["__class_fullname__"]

        bookkeeping = ["__class_fullname__", "applied", "params"]
        init_args = {}
        for key, value in transform_dict.items():
            if key not in bookkeeping:
                init_args[key] = value

        factory = SERIALIZABLE_REGISTRY[fullname]
        if "transforms" in init_args:
            init_args["transforms"] = [
                ReplayCompose._restore_for_replay(child)
                for child in init_args["transforms"]
            ]

        restored = factory(**init_args)
        restored.params = params
        restored.replay_mode = True
        restored.applied_in_replay = applied
        return restored

    def fill_with_params(self, serialized, all_params):
        """Swap each node's ``"id"`` for the params recorded under that id.

        Nodes with no recorded entry get ``params = None``. The description
        is modified in place, recursively through ``"transforms"``.
        """
        serialized["params"] = all_params.get(serialized.get("id"))
        del serialized["id"]
        for child in serialized.get("transforms", []):
            self.fill_with_params(child, all_params)

    def fill_applied(self, serialized):
        """Set and return ``"applied"`` for a node, recursing into children.

        A container counts as applied when any child was; a leaf when its
        ``"params"`` is not ``None``.
        """
        if "transforms" not in serialized:
            was_applied = serialized.get("params") is not None
        else:
            # Build the full list first so every child gets annotated;
            # feeding a generator to any() would stop at the first True.
            child_flags = [
                self.fill_applied(child) for child in serialized["transforms"]
            ]
            was_applied = any(child_flags)
        serialized["applied"] = was_applied
        return was_applied

    def _to_dict(self):
        """Serialization is not supported for this class."""
        raise NotImplementedError("You cannot serialize ReplayCompose")