I want my dataclass to have a field that can either be provided explicitly or, if it isn't, be inferred at initialization from the other fields. MWE:
from collections.abc import Sized
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
@dataclass
class Foo:
    """Question MWE: ``index`` defaults to ``None`` and is filled in after init.

    ``__post_init__`` guarantees a list at runtime, but because the field is
    declared ``Optional[list[int]]`` mypy still reports the Optional type on
    every access -- which is exactly the problem being asked about.
    """

    data: Sized
    index: Optional[list[int]] = None

    def __post_init__(self) -> None:
        # The ``-> None`` annotation makes mypy type-check this body under its
        # default settings (bodies of fully unannotated functions are skipped).
        if self.index is None:
            self.index = list(range(len(self.data)))


# ``reveal_type`` is understood by mypy but is not a runtime builtin; guarding
# the calls keeps the snippet runnable with plain ``python`` while mypy still
# analyses (and reveals) them.
if TYPE_CHECKING:
    reveal_type(Foo.index)  # Union[None, list[int]]
    reveal_type(Foo([1, 2, 3]).index)  # Union[None, list[int]]
How can this be implemented such that:
- it complies with mypy type checking, and
- `index` is guaranteed to be of type `list[int]`?

I considered using `field(default_factory=list)`; however, how would one then distinguish a user explicitly passing `index=[]` from the sentinel value? Is there a proper solution besides doing
index: list[int] = None # type: ignore[assignment]
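Answer:
Use `NotImplemented` as the sentinel default. It is a singleton, so `__post_init__` can detect it with `is`. Typeshed declares its type as a subclass of `Any`, so mypy accepts it as the default of a `list[int]` field: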
from collections.abc import Sized
from dataclasses import dataclass
@dataclass
class Foo:
    """Dataclass whose ``index`` is inferred from ``data`` when not given.

    ``NotImplemented`` serves as the "not provided" sentinel. It is a
    singleton, so it is detected with ``is``, and it is distinct from every
    list a caller might pass -- including ``[]``.
    """

    data: Sized
    # Sentinel default; __post_init__ replaces it, so instances always hold a list.
    index: list[int] = NotImplemented

    def __post_init__(self) -> None:
        # ``-> None`` is needed for mypy to type-check this body by default.
        if self.index is NotImplemented:
            self.index = list(range(len(self.data)))
Answer:
You can have the `default_factory` return a list whose only element is a sentinel object. The sentinel must be an instance of `int` (here, an instance of a private `int` subclass); otherwise mypy will complain that the factory does not produce a `list[int]`. The sentinel compares equal to `0`, so `__post_init__` must check it by identity (`is`) rather than equality. That way a user-supplied `[0]` or `[]` is never mistaken for the default.
from collections.abc import Sized
from dataclasses import dataclass, field
@dataclass
class Foo:
    """Dataclass whose ``index`` is inferred from ``data`` when not given.

    The default ``index`` is a one-element list holding a private ``int``
    subclass instance. Because that instance is an ``int``, mypy accepts the
    factory as producing ``list[int]``. Because ``__post_init__`` compares by
    identity (``is``), a caller-supplied ``[0]`` -- which compares *equal*
    to the sentinel -- is never mistaken for it, and ``index=[]`` stays ``[]``.
    """

    class _IdxSentinel(int):
        """Private ``int`` subclass used only to create the sentinel object."""

    # Equal to 0, but a distinct object from any plain int a caller passes.
    _idx_sentinel = _IdxSentinel()

    @staticmethod
    def _idx_sentinel_factory() -> list[int]:
        """Return a fresh one-element list containing the sentinel."""
        return [Foo._idx_sentinel]

    data: Sized
    # Call through a lambda: inside the class body the name above refers to the
    # raw ``staticmethod`` object, which is not callable before Python 3.10.
    index: list[int] = field(
        default_factory=lambda: Foo._idx_sentinel_factory()
    )

    def __post_init__(self) -> None:
        # Compare against the same ``Foo._idx_sentinel`` the factory used, so a
        # subclass overriding ``_idx_sentinel`` cannot desynchronise the check.
        if len(self.index) == 1 and self.index[0] is Foo._idx_sentinel:
            self.index = list(range(len(self.data)))
I put the entire factory and sentinel logic inside `Foo`, but if you prefer, you can move it out to module level:
from collections.abc import Sized
from dataclasses import dataclass, field
class _IdxSentinel(int):
    """Private ``int`` subclass; its single instance marks "index not given"."""


# Equal to 0, yet identity-distinct from every plain int a caller could pass.
_idx_sentinel = _IdxSentinel()


def _idx_sentinel_factory() -> list[int]:
    """Build the default ``index``: a new list holding only the sentinel."""
    placeholder: list[int] = [_idx_sentinel]
    return placeholder


@dataclass
class Foo:
    """Dataclass whose ``index`` defaults to ``range(len(data))`` as a list.

    A caller-supplied ``index`` (including ``[]`` or ``[0]``) is kept as is;
    only the untouched default, recognised by identity, is replaced.
    """

    data: Sized
    index: list[int] = field(default_factory=_idx_sentinel_factory)

    def __post_init__(self) -> None:
        current = self.index
        # Guard clause: anything other than the exact default list is user data.
        if len(current) != 1 or current[0] is not _idx_sentinel:
            return
        self.index = [*range(len(self.data))]
