# Copyright (c) 2018 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import (List, Set, Dict, Type, TypeVar, Any, Union, Optional, Tuple, Iterator, Callable,
NewType)
import attr
import copy
from .....api import JSON
from .serializable import SerializerError, Serializable, GenericSerializable
from .obj import Obj, Lst
T = TypeVar("T")
T2 = TypeVar("T2")
# A function that converts a Python value into JSON-compatible data.
Serializer = NewType("Serializer", Callable[[T], JSON])
# A function that converts JSON data into a Python value.
Deserializer = NewType("Deserializer", Callable[[JSON], T])
# Custom (de)serializers keyed by the type they handle, filled in by the
# @serializer / @deserializer decorators below. The keys are arbitrary types, so they are
# annotated with Type[Any] rather than an unbound TypeVar.
serializer_map: Dict[Type[Any], Serializer] = {}
deserializer_map: Dict[Type[Any], Deserializer] = {}
def serializer(elem_type: Type[T]) -> Callable[[Serializer], Serializer]:
    """Register the decorated function as the custom serializer for ``elem_type``.

    Usage::

        @serializer(SomeType)
        def serialize_some_type(val: SomeType) -> JSON: ...

    The function is stored in ``serializer_map`` (which ``_attrs_to_dict`` consults by a
    field's declared type) and returned unchanged, so it can still be called directly.

    :param elem_type: The type the decorated function serializes.
    :return: A decorator that performs the registration.
    """
    def register(func: Serializer) -> Serializer:
        serializer_map[elem_type] = func
        return func

    return register
def deserializer(elem_type: Type[T]) -> Callable[[Deserializer], Deserializer]:
    """Register the decorated function as the custom deserializer for ``elem_type``.

    Usage::

        @deserializer(SomeType)
        def deserialize_some_type(data: JSON) -> SomeType: ...

    The function is stored in ``deserializer_map`` (which ``_deserialize`` consults before
    any built-in handling) and returned unchanged, so it can still be called directly.

    :param elem_type: The type the decorated function deserializes.
    :return: A decorator that performs the registration.
    """
    def register(func: Deserializer) -> Deserializer:
        deserializer_map[elem_type] = func
        return func

    return register
def _fields(attrs_type: Type[T], only_if_flatten: Optional[bool] = None
            ) -> Iterator[Tuple[str, attr.Attribute]]:
    """Yield ``(json_key, field)`` pairs for the attrs fields of ``attrs_type``.

    :param attrs_type: An attrs-decorated class.
    :param only_if_flatten: If ``None``, every field is yielded. Otherwise only fields whose
        ``flatten`` metadata (default ``False``) equals this value are yielded.
    :return: A lazy iterator of ``(json_key, field)``. ``json_key`` is the field's ``json``
        metadata when set, else the attribute name, and ``field`` is the ``attr.Attribute``.
    """
    return ((field.metadata.get("json", field.name), field)
            for field in attr.fields(attrs_type)
            if only_if_flatten is None
            or field.metadata.get("flatten", False) == only_if_flatten)
# Types whose values can be shared as defaults without copying (see _safe_default).
immutable = (
    int,
    str,
    float,
    bool,
    type(None),
)
def _safe_default(val: T) -> T:
    """Return a value derived from the default ``val`` that is safe to give a new object.

    Immutable values are returned as-is. An ``attr.Factory`` is called to produce a fresh
    value, since copying the Factory itself would store a Factory object as the field value.
    The ``attr.NOTHING`` sentinel (no default declared) becomes ``None`` rather than leaking
    into instances. Anything else is shallow-copied so mutable defaults aren't shared.

    :param val: A default, typically an attrs field's ``default``.
    :return: The value to use.
    """
    if isinstance(val, immutable):
        return val
    if val is attr.NOTHING:
        return None
    if isinstance(val, attr.Factory):
        if val.takes_self:
            # A self-taking factory needs the instance being built, which doesn't exist yet.
            return None
        return val.factory()
    return copy.copy(val)
def _dict_to_attrs(attrs_type: Type[T], data: JSON, default: Optional[T] = None,
                   default_if_empty: bool = False) -> T:
    """Build an instance of the attrs class ``attrs_type`` from the JSON object ``data``.

    Fields with ``flatten`` metadata are deserialized from the whole object. Other fields are
    matched by their JSON key (``json`` metadata, or the attribute name). Keys that match no
    non-flattened field are stored on the result as ``unrecognized_``.

    :param attrs_type: The attrs class to instantiate.
    :param data: The JSON object; ``None`` or empty is treated as ``{}``.
    :param default: Returned (through ``_safe_default``) instead of an instance when
        ``default_if_empty`` is set and no field values were collected.
    :param default_if_empty: See ``default``.
    :return: The new instance, or the default as described above.
    :raises SerializerError: If a field fails to deserialize (and its ``ignore_errors``
        metadata isn't set), or if ``attrs_type`` rejects the collected values.
    """
    data = data or {}
    unrecognized = {}
    # Flattened fields read from the entire object rather than from a single key.
    new_items = {field.name.lstrip("_"):
                 _try_deserialize(field.type, data, field.default,
                                  field.metadata.get("ignore_errors", False))
                 for _, field in _fields(attrs_type, only_if_flatten=True)}
    fields = dict(_fields(attrs_type, only_if_flatten=False))
    for key, value in data.items():
        try:
            field = fields[key]
        except KeyError:
            unrecognized[key] = value
            continue
        # Keyword name for attrs' generated __init__ (leading underscores stripped).
        name = field.name.lstrip("_")
        new_items[name] = _try_deserialize(field.type, value, field.default,
                                           field.metadata.get("ignore_errors", False))
    if len(new_items) == 0 and default_if_empty:
        return _safe_default(default)
    try:
        obj = attrs_type(**new_items)
    except TypeError as e:
        # Try to pinpoint a missing required argument. new_items is keyed by the __init__
        # keyword name, not the JSON key. A field is required only when it has no default at
        # all (a falsy default such as None or 0 is still a default).
        for _, field in _fields(attrs_type):
            if (field.init and field.default is attr.NOTHING
                    and field.name.lstrip("_") not in new_items):
                raise SerializerError(
                    f"Missing value for required key {field.name} in {attrs_type.__name__}") from e
        raise SerializerError("Unknown serialization error") from e
    if len(unrecognized) > 0:
        obj.unrecognized_ = unrecognized
    return obj
def _try_deserialize(cls: Type[T], value: JSON, default: Optional[T] = None,
                     ignore_errors: bool = False) -> T:
    """Deserialize ``value`` as ``cls`` via ``_deserialize``, optionally swallowing failures.

    :param cls: The target type hint.
    :param value: The JSON value to convert.
    :param default: Passed through to ``_deserialize``.
    :param ignore_errors: If true, failures yield ``None`` instead of raising.
    :return: The deserialized value, or ``None`` when a failure was ignored.
    :raises SerializerError: On failure when ``ignore_errors`` is false. A ``TypeError``,
        ``ValueError`` or ``KeyError`` is wrapped in a generic ``SerializerError``.
    """
    try:
        return _deserialize(cls, value, default)
    except SerializerError:
        if ignore_errors:
            return None
        raise
    except (TypeError, ValueError, KeyError) as e:
        if ignore_errors:
            return None
        raise SerializerError("Unknown serialization error") from e
def _deserialize_items(kind: type, cls: Any, value: JSON) -> Any:
    """Deserialize the members of a JSON array/object into a ``list``, ``set`` or ``dict``.

    :param kind: The container to build: ``list``, ``set`` or ``dict``.
    :param cls: The (possibly parameterized) container type hint. Its ``__args__``, when
        present, give the item type (or key and value types for ``dict``).
    :param value: A JSON array (for ``list``/``set``) or object (for ``dict``).
    """
    # Bare generics may lack __args__ or, on Python 3.6, have __args__ = None.
    args = getattr(cls, "__args__", None) or ()
    if kind is dict:
        val_cls = args[1] if len(args) == 2 else None
        # JSON object keys are always strings, so only the values are deserialized.
        return {key: _deserialize(val_cls, item) for key, item in value.items()}
    item_cls = args[0] if len(args) == 1 else None
    if kind is set:
        return {_deserialize(item_cls, item) for item in value}
    return [_deserialize(item_cls, item) for item in value]


def _deserialize(cls: Type[T], value: JSON, default: Optional[T] = None) -> T:
    """Convert the JSON value ``value`` into the type described by the type hint ``cls``.

    Resolution order:

    1. ``None`` yields ``_safe_default(default)``.
    2. ``NewType`` aliases are unwrapped, then a registered custom deserializer is used.
    3. attrs classes go through their own ``deserialize`` override, or ``_dict_to_attrs``.
    4. ``Any``/``JSON`` return the value unchanged, and ``Optional[X]`` recurses into ``X``.
    5. ``Serializable`` classes and list/set/dict hints (plain or parameterized) are handled.
    6. Otherwise, lists are wrapped in ``Lst``, dicts in ``Obj``, and scalars are returned.

    :param cls: The target type hint.
    :param value: The JSON value.
    :param default: Used when ``value`` is ``None`` and, for attrs classes, when the object
        is empty.
    :return: The deserialized value.
    """
    if value is None:
        return _safe_default(default)
    cls = getattr(cls, "__supertype__", None) or cls
    # Look the deserializer up separately from calling it, so a KeyError raised *inside* a
    # custom deserializer propagates instead of silently falling through to generic handling.
    custom = deserializer_map.get(cls)
    if custom is not None:
        return custom(value)
    origin = getattr(cls, "__origin__", None)
    if attr.has(cls):
        if (issubclass(cls, Serializable)
                and cls.deserialize.__func__ != SerializableAttrs.deserialize.__func__):
            # The class overrides SerializableAttrs.deserialize, so respect the override.
            return cls.deserialize(value)
        return _dict_to_attrs(cls, value, default, default_if_empty=True)
    elif cls == Any or cls == JSON:
        return value
    elif origin is Union:
        args = cls.__args__
        # Optional[X] (NoneType in either position). Other unions fall through to the
        # generic handling below.
        if len(args) == 2 and type(None) in args:
            inner = args[1] if args[0] is type(None) else args[0]
            return _deserialize(inner, value, default)
    elif isinstance(cls, type):
        if issubclass(cls, Serializable):
            return cls.deserialize(value)
        elif issubclass(cls, List):
            return _deserialize_items(list, cls, value)
        elif issubclass(cls, Set):
            return _deserialize_items(set, cls, value)
        elif issubclass(cls, Dict):
            return _deserialize_items(dict, cls, value)
    elif origin in (list, set, dict):
        # On Python 3.7+, List[X]/Set[X]/Dict[K, V] are typing aliases rather than classes,
        # so the issubclass() branch above never sees them. Dispatch on __origin__ instead.
        return _deserialize_items(origin, cls, value)
    if isinstance(value, list):
        return Lst(value)
    elif isinstance(value, dict):
        return Obj(**value)
    return value
def _attrs_to_dict(data: T) -> JSON:
    """Serialize the attrs instance ``data`` into a JSON object.

    Per-field metadata controls the output:

    - ``json``: the output key (an empty key skips the field).
    - ``omitempty`` (default ``True``): drop ``None`` values. When false, the field's
      default is serialized instead.
    - ``omitdefault`` (default ``False``): drop values equal to the field's default.
    - ``flatten``: merge a dict result into the top-level object.

    A field whose declared type has a serializer in ``serializer_map`` uses it. Other fields
    go through ``_serialize``. Entries in ``data.unrecognized_`` are merged in last.

    :param data: An instance of an attrs class.
    :return: The JSON object.
    """
    new_dict = {}
    for json_name, field in _fields(data.__class__):
        if not json_name:
            continue
        field_val = getattr(data, field.name)
        if field_val is None:
            if field.metadata.get("omitempty", True):
                continue
            field_val = field.default
        if field.metadata.get("omitdefault", False) and field_val == field.default:
            continue
        # Look the serializer up separately from calling it, so a KeyError raised *inside* a
        # custom serializer propagates instead of silently falling back to _serialize.
        custom = serializer_map.get(field.type)
        serialized = custom(field_val) if custom is not None else _serialize(field_val)
        if field.metadata.get("flatten", False) and isinstance(serialized, dict):
            new_dict.update(serialized)
        else:
            new_dict[json_name] = serialized
    try:
        new_dict.update(data.unrecognized_)
    except (AttributeError, TypeError):
        # unrecognized_ is unset or not a mapping (e.g. None): nothing extra to emit.
        pass
    return new_dict
def _serialize(val: Any) -> JSON:
    """Recursively convert ``val`` into JSON-compatible data.

    ``Serializable`` objects use their own ``serialize()``. Tuples, lists and sets become
    lists. Dict keys and values are converted recursively. Other attrs instances go through
    ``_attrs_to_dict``. Anything else is returned unchanged.
    """
    if isinstance(val, Serializable):
        return val.serialize()
    if isinstance(val, (tuple, list, set)):
        return list(map(_serialize, val))
    if isinstance(val, dict):
        return {_serialize(key): _serialize(item) for key, item in val.items()}
    if attr.has(val.__class__):
        return _attrs_to_dict(val)
    return val
class SerializableAttrs(GenericSerializable[T]):
    """An abstract :class:`Serializable` that assumes the subclass is an attrs class.

    (De)serialization is driven by the subclass's attrs fields through ``_dict_to_attrs`` and
    ``_attrs_to_dict``. Deserialization stores JSON keys that match no field in
    ``unrecognized_``, and serialization writes them back out.

    Instances also support item access. ``obj[key]`` reads an attribute, falling back to
    ``unrecognized_``. ``obj[key] = value`` sets an existing attribute or else stores the
    value in ``unrecognized_``.
    """
    # JSON keys that didn't match any field; may be unset on instances built by attrs' own
    # __init__, or None.
    unrecognized_: Optional[JSON]

    def __init__(self):
        self.unrecognized_ = {}

    @classmethod
    def deserialize(cls, data: JSON) -> T:
        """Create an instance of this class from the JSON object ``data``."""
        return _dict_to_attrs(cls, data)

    def serialize(self) -> JSON:
        """Convert this object into a JSON object, including any unrecognized keys."""
        return _attrs_to_dict(self)

    def get(self, item, default=None):
        """Return ``self[item]``, or ``default`` if there is no such key."""
        try:
            return self[item]
        except KeyError:
            return default

    def __getitem__(self, item):
        """Return attribute ``item``, else ``unrecognized_[item]``; raise KeyError if neither."""
        try:
            return getattr(self, item)
        except AttributeError:
            pass
        unrecognized = getattr(self, "unrecognized_", None)
        if unrecognized is None:
            # Treat an unset or None unrecognized_ as empty (None is allowed by the annotation).
            self.unrecognized_ = {}
            raise KeyError(item) from None
        return unrecognized[item]

    def __setitem__(self, item, value):
        """Set attribute ``item`` if it exists, otherwise store it in ``unrecognized_``."""
        if hasattr(self, item):
            setattr(self, item, value)
            return
        unrecognized = getattr(self, "unrecognized_", None)
        if unrecognized is None:
            self.unrecognized_ = {item: value}
        else:
            unrecognized[item] = value