Let's say I have an enum:
from enum import Enum

class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"
I want to create a ColorDict class that works like a native Python dictionary but only accepts a Color member or its corresponding string value as a key.
d = ColorDict() # I want to implement a ColorDict class such that ...
d[Color.RED] = 123
d["RED"] = 456 # I want this to override the previous value
d[Color.RED] # ==> 456
d["foo"] = 789 # I want this to raise a KeyError exception
What's the "Pythonic way" of implementing this ColorDict class? Should I use inheritance (subclassing Python's native dict) or composition (keeping a dict as a member)?
CodePudding user response:
A simple solution would be to slightly modify your Color object and then subclass dict to add a test for the key. I would do something like this:
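from enum import Enum

class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"

    @classmethod
    def is_color(cls, color):
        # True for a Color member or for the string value of one.
        if isinstance(color, cls):
            return True
        return color in {member.value for member in cls}

class ColorDict(dict):
    def __setitem__(self, k, v):
        if Color.is_color(k):
            # Color(k) maps both a member and its value to the member.
            super().__setitem__(Color(k), v)
        else:
            raise KeyError(f"Color {k} is not valid")

    def __getitem__(self, k):
        if isinstance(k, str):
            # Upper-case the string so that "red" also finds Color.RED.
            try:
                k = Color(k.upper())
            except ValueError:
                raise KeyError(k) from None
        return super().__getitem__(k)

d = ColorDict()
d[Color.RED] = 123
d["RED"] = 456 # overrides the previous value
d[Color.RED] # ==> 456
d["foo"] = 789 # raises KeyError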
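In the Color class, I have added an is_color class method that returns True if the argument is a Color member or the string value of one, and False otherwise. In __getitem__, upper() converts a string key to upper case so that it can be matched against the predefined values. Note that __setitem__ does not do this, so a lowercase string such as "red" is rejected on assignment.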
Then I have subclassed the dict object to override the __setitem__ special method to include a test of the value passed, and an override of __getitem__ to convert any key passed as str into the correct Enum. Depending on the specifics of how you want to use the ColorDict class, you may need to override more functions. There's a good explanation of that here: How to properly subclass dict and override __getitem__ & __setitem__
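To illustrate why more overrides may be needed, here is a minimal sketch. The class name CheckedDict is hypothetical and only used for this demonstration; it validates keys in __setitem__ alone and shows that, in CPython, dict.update and dict.setdefault do not go through that override:

```python
from enum import Enum


class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"


class CheckedDict(dict):
    """Hypothetical dict subclass that validates keys only in __setitem__."""

    def __setitem__(self, key, value):
        try:
            color = Color(key)  # accepts a Color member or its string value
        except ValueError:
            raise KeyError(key) from None
        super().__setitem__(color, value)


d = CheckedDict()
d["RED"] = 1            # goes through __setitem__, stored under Color.RED
d.update({"foo": 2})    # dict.update does not call the overridden __setitem__
d.setdefault("bar", 3)  # neither does dict.setdefault

print(list(d))
```

The invalid keys "foo" and "bar" end up in the dictionary unchecked, so a dict subclass would also have to override update, setdefault and the constructor to be safe.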
CodePudding user response:
One way is to use the abstract base class collections.abc.MutableMapping. That way, you only need to implement the abstract methods, and you can be sure that every access goes through your logic. You can do this with dict too, but overriding dict.__setitem__, for example, will not affect dict.update, dict.setdefault, etc., so you would have to override those by hand as well. Usually, it is easier to just use the abstract base class:
from collections.abc import MutableMapping
from enum import Enum

class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"

class ColorDict(MutableMapping):
    def __init__(self): # could handle more ways of initializing but for simplicity...
        self._data = {}

    def __getitem__(self, item):
        # Normalize the key first so that "RED" and Color.RED are equivalent.
        color = self._handle_item(item)
        return self._data[color]

    def __setitem__(self, item, value):
        color = self._handle_item(item)
        self._data[color] = value

    def __delitem__(self, item):
        color = self._handle_item(item)
        del self._data[color]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def _handle_item(self, item):
        # Color(item) accepts a member or its value and raises ValueError otherwise.
        try:
            color = Color(item)
        except ValueError:
            raise KeyError(item) from None
        return color
Note that you can also add a __repr__ to the class for easier debugging:
    def __repr__(self):
        return repr(self._data)
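The key normalization in _handle_item relies on how Enum lookup by value behaves. The following short sketch, which only uses the Color enum from the question, shows the three cases it depends on:

```python
from enum import Enum


class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"


# Calling the Enum class looks a member up by its value.
assert Color("RED") is Color.RED
# Passing an existing member returns that same member.
assert Color(Color.RED) is Color.RED

# An unknown value raises ValueError, which _handle_item turns into KeyError.
try:
    Color("foo")
except ValueError as exc:
    error = exc
else:
    error = None

print(type(error).__name__)
```

Because both a member and its string value resolve to the same member, the underlying dictionary only ever stores Color members as keys.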
An example in the REPL:
In [3]: d = ColorDict() # I want to implement a ColorDict class such that ...
   ...:
   ...: d[Color.RED] = 123
   ...: d["RED"] = 456 # I want this to override the previous value
   ...: d[Color.RED] # ==> 456
Out[3]: 456

In [4]: d["foo"] = 789 # I want this to produce an KeyError exception
   ...:
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-4-9cf80d6dd8b4> in <module>
----> 1 d["foo"] = 789 # I want this to produce an KeyError exception

<ipython-input-2-a0780e16594b> in __setitem__(self, item, value)
     17
     18     def __setitem__(self, item, value):
---> 19         color = self._handle_item(item)
     20         self._data[color] = value
     21

<ipython-input-2-a0780e16594b> in _handle_item(self, item)
     34             color = Color(item)
     35         except ValueError:
---> 36             raise KeyError(item) from None
     37         return color
     38     def __repr__(self): return repr(self._data)

KeyError: 'foo'

In [5]: d
Out[5]: {<Color.RED: 'RED'>: 456}
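As a sketch of the point about update and setdefault, the version below extends the same idea. The *args/**kwargs constructor is an assumption added for illustration (the answer leaves initialization open), and it relies on the update and setdefault mixin methods that MutableMapping provides, which are implemented in terms of __getitem__ and __setitem__:

```python
from collections.abc import MutableMapping
from enum import Enum


class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"


class ColorDict(MutableMapping):
    def __init__(self, *args, **kwargs):
        # Assumed constructor: accepts the same arguments as update().
        self._data = {}
        # MutableMapping.update assigns via self[key] = value,
        # so initial data is validated as well.
        self.update(*args, **kwargs)

    def __getitem__(self, item):
        return self._data[self._handle_item(item)]

    def __setitem__(self, item, value):
        self._data[self._handle_item(item)] = value

    def __delitem__(self, item):
        del self._data[self._handle_item(item)]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def __repr__(self):
        return f"{type(self).__name__}({self._data!r})"

    @staticmethod
    def _handle_item(item):
        # Map a Color member or its string value to the member.
        try:
            return Color(item)
        except ValueError:
            raise KeyError(item) from None


d = ColorDict({"RED": 1}, GREEN=2)
d.update({Color.RED: 10})   # routed through __setitem__
d.setdefault("BLUE", 3)     # routed through __getitem__/__setitem__

try:
    d.update({"foo": 99})   # invalid key is rejected here too
except KeyError:
    rejected = True
else:
    rejected = False
```

With this approach no extra overrides are needed: every mixin method ends up in the four methods that normalize the key.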
