Source code for udataclasses.decorator

  1from . import source
  2from .constants import FACTORY_SENTINEL, FIELDS_NAME, MISSING
  3from .field import FrozenInstanceError
  4from .transform_spec import TransformSpec
  5
  6try:
  7    from collections.abc import Callable
  8    from typing import Any, TypeVar
  9
 10    T = TypeVar("T")
 11except ImportError:
 12    pass
 13
 14
def dataclass(
    cls: type[T] | None = None, **kwargs: Any
) -> type[T] | Callable[[type[T]], type[T]]:
    """Decorator to transform a normal class into a dataclass.

    Both decorator forms are supported:

    * ``@dataclass`` -- applied directly. ``cls`` is the class being
      decorated, and the transformed class is returned immediately.
    * ``@dataclass(...)`` -- called first, with or without keyword options.
      ``cls`` is ``None``, and a one-argument decorator is returned; it
      transforms the class when Python applies it.

    Keyword options are forwarded unchanged to ``_dataclass``, which accepts
    ``init``, ``repr``, ``eq``, ``order``, ``unsafe_hash`` and ``frozen``.
    An unrecognised option therefore raises ``TypeError`` at the moment the
    class is actually transformed. In the parenthesised form, that is when
    the returned decorator is applied, not at this call.

    Args:
        cls: The class to transform, or ``None`` when the decorator is
            invoked with parentheses.
        **kwargs: Options forwarded to ``_dataclass``.

    Returns:
        The transformed class when ``cls`` is given. Otherwise, a decorator
        that transforms a class using ``kwargs``.
    """

    def wrapper(cls: type[T]) -> type[T]:
        return _dataclass(cls, **kwargs)

    if cls is None:
        # Invoked with parentheses, e.g. ``@dataclass()`` or
        # ``@dataclass(frozen=True)``: hand back the real decorator so Python
        # can apply it to the class.
        return wrapper

    # Applied bare as ``@dataclass``: transform the class right away.
    return wrapper(cls)
def _dataclass(
    cls: type[T],
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
) -> type[T]:
    """Apply the dataclass transformation to ``cls`` in place and return it.

    A ``TransformSpec`` is built from the class and the option flags. Every
    member produced by ``make_methods`` is then attached to the class with
    ``setattr``, replacing any attribute of the same name already present.
    Finally, a ``{field name: field}`` mapping is stored on the class under
    ``FIELDS_NAME``.
    """
    spec = TransformSpec(
        cls,
        init=init,
        repr=repr,
        eq=eq,
        order=order,
        unsafe_hash=unsafe_hash,
        frozen=frozen,
    )

    generated = make_methods(spec)
    for attr_name in generated:
        setattr(cls, attr_name, generated[attr_name])

    # Field metadata keyed by field name, kept on the class for introspection.
    field_map = {}
    for fld in spec.fields:
        field_map[fld.name] = fld
    setattr(cls, FIELDS_NAME, field_map)
    return cls


def make_global_bindings(transform: TransformSpec) -> dict[str, Any]:
    """Return the globals namespace used when executing generated source.

    The namespace always exposes ``FrozenInstanceError`` and
    ``FACTORY_SENTINEL``. A field whose ``default`` or ``default_factory``
    is not ``MISSING`` also gets an entry under its ``default_value_name``.
    If a field sets both, the factory is stored second and so wins.
    NOTE(review): presumably the field class rejects that combination --
    confirm.
    """
    namespace: dict[str, Any] = dict(
        FrozenInstanceError=FrozenInstanceError,
        FACTORY_SENTINEL=FACTORY_SENTINEL,
    )
    for fld in transform.fields:
        # Checked in this order so a default factory overrides a plain default.
        for candidate in (fld.default, fld.default_factory):
            if candidate is not MISSING:
                namespace[fld.default_value_name] = candidate
    return namespace


def make_methods(transform: TransformSpec) -> dict[str, Any]:
    """Generate, compile and collect every member the dataclass needs.

    Each snippet from ``source`` is passed to ``exec`` with the namespace
    from ``make_global_bindings`` as its globals and a shared dict as its
    locals. Whatever a snippet defines lands in that dict, which is returned
    as ``{name: object}`` in generation order. Because definitions go into
    the locals dict rather than the globals namespace, one generated
    function cannot look another up by bare name.
    """
    namespace = make_global_bindings(transform)
    compiled: dict[str, Any] = {}

    def compile_into(snippet: str) -> None:
        exec(snippet, namespace, compiled)

    # Per-field accessors; setter and deleter are told whether the class is
    # frozen.
    for fld in transform.fields:
        compile_into(source.getter(fld))
        compile_into(source.setter(fld, transform.frozen))
        compile_into(source.deleter(fld, transform.frozen))

    fields = transform.fields
    if transform.init:
        compile_into(source.init(fields, post_init=transform.post_init))
    if transform.repr:
        compile_into(source.repr(fields))
    if transform.eq:
        compile_into(source.eq(fields))
    if transform.order:
        for make_comparison in (source.lt, source.le, source.gt, source.ge):
            compile_into(make_comparison(fields))

    # ``transform.hash`` is treated as tri-state:
    #   None   -> explicitly mark instances unhashable (__hash__ = None);
    #   truthy -> generate a __hash__ from the fields;
    #   other falsy values (e.g. False) -> add nothing, so whatever
    #   __hash__ the class already resolves to is left in place.
    hash_mode = transform.hash
    if hash_mode is None:
        compiled["__hash__"] = None
    elif hash_mode:
        compile_into(source.hash(fields))

    return compiled