1from . import source
2from .constants import FACTORY_SENTINEL, FIELDS_NAME, MISSING
3from .field import FrozenInstanceError
4from .transform_spec import TransformSpec
5
6try:
7 from collections.abc import Callable
8 from typing import Any, TypeVar
9
10 T = TypeVar("T")
11except ImportError:
12 pass
13
14
def dataclass(
    cls: type[T] | None = None, **kwargs: Any
) -> type[T] | Callable[[type[T]], type[T]]:
    """Decorator that turns a plain class into a dataclass.

    Usable in two forms:

    * bare, ``@dataclass`` -- ``cls`` is the decorated class and the result
      of ``_dataclass(cls, **kwargs)`` is returned directly;
    * called, ``@dataclass(...)`` -- ``cls`` is ``None`` and a one-argument
      decorator is returned that applies the same keyword options later.

    All keyword arguments are forwarded unchanged to ``_dataclass``.
    """

    def decorate(cls: type[T]) -> type[T]:
        # Closes over ``kwargs`` so the options survive until a class arrives.
        return _dataclass(cls, **kwargs)

    if cls is not None:
        # Bare form: the class was handed to us directly, transform it now.
        return decorate(cls)

    # Called form: no class yet, hand back the decorator itself.
    return decorate
29
30
def _dataclass(
    cls: type[T],
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
) -> type[T]:
    """Apply the dataclass transformation to ``cls`` in place and return it.

    The class and the option flags are wrapped in a ``TransformSpec``; every
    entry produced by ``make_methods`` for that spec is attached to ``cls``,
    and a ``{field name: field}`` mapping is stored on the class under the
    attribute named by ``FIELDS_NAME``.

    All options are keyword-only and are passed through to ``TransformSpec``
    unchanged.
    """
    spec = TransformSpec(
        cls,
        init=init,
        repr=repr,
        eq=eq,
        order=order,
        unsafe_hash=unsafe_hash,
        frozen=frozen,
    )

    # Attach each generated attribute to the class being decorated.
    generated = make_methods(spec)
    for attr_name, attr_value in generated.items():
        setattr(cls, attr_name, attr_value)

    # Record the field objects on the class, keyed by field name.
    fields_by_name = {}
    for f in spec.fields:
        fields_by_name[f.name] = f
    setattr(cls, FIELDS_NAME, fields_by_name)

    return cls
57
58
def make_global_bindings(transform: TransformSpec) -> dict[str, Any]:
    """Build the name -> object mapping used as globals for generated code.

    The mapping always contains ``FrozenInstanceError`` and
    ``FACTORY_SENTINEL``.  For every field in ``transform.fields`` whose
    ``default`` or ``default_factory`` is not ``MISSING``, that object is
    bound under the field's ``default_value_name``.  If a field has both,
    the factory is stored last and therefore wins.
    """
    namespace: dict[str, Any] = {}
    namespace["FrozenInstanceError"] = FrozenInstanceError
    namespace["FACTORY_SENTINEL"] = FACTORY_SENTINEL

    for spec in transform.fields:
        # Order matters: ``default_factory`` is checked second so it
        # overrides ``default`` under the shared binding name.
        for attr in ("default", "default_factory"):
            candidate = getattr(spec, attr)
            if candidate is not MISSING:
                namespace[spec.default_value_name] = candidate

    return namespace
70
71
def make_methods(transform: TransformSpec) -> dict[str, Any]:
    """Generate the attributes to install on a dataclass.

    Source snippets produced by the ``source`` module are executed with the
    mapping from ``make_global_bindings(transform)`` as globals; whatever
    names those snippets define are collected in the returned dict.

    Generated, in order:

    * a getter, setter and deleter snippet for every field (the setter and
      deleter snippets receive ``transform.frozen``);
    * ``init`` / ``repr`` / ``eq`` snippets when the matching flag is set;
    * ``lt`` / ``le`` / ``gt`` / ``ge`` snippets when ``transform.order``
      is set;
    * for hashing: ``"__hash__"`` is mapped to ``None`` when
      ``transform.hash`` is ``None``, a ``hash`` snippet is executed when it
      is truthy, and nothing is added otherwise.
    """
    exec_globals = make_global_bindings(transform)
    generated: dict[str, Any] = {}

    def define(code: str) -> None:
        # Names defined by ``code`` land in ``generated`` (the exec locals).
        exec(code, exec_globals, generated)

    # Per-field accessors.
    for field in transform.fields:
        define(source.getter(field))
        define(source.setter(field, transform.frozen))
        define(source.deleter(field, transform.frozen))

    # Optional whole-class methods, each gated by its own flag.
    if transform.init:
        define(source.init(transform.fields, post_init=transform.post_init))
    if transform.repr:
        define(source.repr(transform.fields))
    if transform.eq:
        define(source.eq(transform.fields))
    if transform.order:
        for build in (source.lt, source.le, source.gt, source.ge):
            define(build(transform.fields))

    # ``None`` means "explicitly no hash"; a falsy non-None value means
    # "leave ``__hash__`` alone".
    hash_mode = transform.hash
    if hash_mode is None:
        generated["__hash__"] = None
    elif hash_mode:
        define(source.hash(transform.fields))

    return generated
101 return methods