I'm trying to implement a list that returns wrapped elements both when accessed directly (by index) and when iterated with a `for ... in` loop. At present I have the following code:
```python
class ItemWrappedList(list):
    def __getitem__(self, y):
        result = super().__getitem__(y)
        return f'<item>{result}</item>'


lst = ItemWrappedList(['1', '2', '3'])

for pos in range(len(lst)):
    print(lst[pos])
#>>> <item>1</item>
#>>> <item>2</item>
#>>> <item>3</item>

for value in lst:
    print(value)
#>>> 1
#>>> 2
#>>> 3
```
As you can see, the first loop works as intended. But the second loop calls `list` methods that I was unable to override. Is there any way to force the `ItemWrappedList` class to behave the same way in all situations?
P.S. I tried several ways of implementing/overriding the `__next__` and `__iter__` methods, but it seems that `__next__` is never called.
Answer:
Overriding the __iter__ method should do the trick:
```python
class ItemWrappedList(list):
    def __getitem__(self, y):
        result = super().__getitem__(y)
        return f'<item>{result}</item>'

    def __iter__(self):
        # Iterate over the raw items with list's own iterator and wrap each one.
        return (f'<item>{x}</item>' for x in list.__iter__(self))


lst = ItemWrappedList(['1', '2', '3'])

for pos in range(len(lst)):
    print(lst[pos])

for value in lst:
    print(value)
```

Output:

```
<item>1</item>
<item>2</item>
<item>3</item>
<item>1</item>
<item>2</item>
<item>3</item>
```
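One caveat worth checking (an observation, not part of the answer above): `__getitem__` also receives `slice` objects, so with the class above `lst[0:2]` returns the single string `"<item>['1', '2']</item>"` rather than wrapped items. A minimal sketch that treats integer indices and slices separately, assuming you want a slice to come back as another `ItemWrappedList` holding the raw items (the `_wrap` helper is a name introduced here for illustration):

```python
class ItemWrappedList(list):
    """List whose items come back wrapped in <item> tags on indexing and iteration."""

    @staticmethod
    def _wrap(value):
        # Hypothetical helper: the single place where the wrapping format lives.
        return f'<item>{value}</item>'

    def __getitem__(self, index):
        result = super().__getitem__(index)
        if isinstance(index, slice):
            # Slicing a list returns a plain list of raw items; rewrap it in this
            # class so the items are wrapped when the slice is indexed or iterated.
            return ItemWrappedList(result)
        return self._wrap(result)

    def __iter__(self):
        # list.__iter__ reads the stored items directly, so wrap each one here.
        return (self._wrap(x) for x in super().__iter__())


lst = ItemWrappedList(['1', '2', '3'])
print(lst[0])                    # <item>1</item>
print([x for x in lst[0:2]])     # ['<item>1</item>', '<item>2</item>']
```

The raw items stay stored unwrapped, so wrapping happens only on the way out and a slice of a slice never gets double-wrapped.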
Answer:
That's because when you write `for i in lst`, Python first calls `iter(lst)`. If the class of `lst` defines an `__iter__` method, that method is used. Otherwise, `iter()` falls back to the sequence protocol: it calls `__getitem__` with the indices 0, 1, 2, ... until an `IndexError` is raised. If the class defines neither, `iter()` raises `TypeError` ("object is not iterable").
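To illustrate the fallback just described, here is a small sketch with two made-up classes: `OnlyGetItem` defines `__getitem__` but no `__iter__`, so a `for` loop goes through `__getitem__`; `Neither` defines neither method, so `iter()` fails.

```python
class OnlyGetItem:
    """Defines __getitem__ but not __iter__, so iter() uses the sequence protocol."""

    def __init__(self, items):
        self._items = list(items)

    def __getitem__(self, index):
        # Indexing past the end raises IndexError, which ends the iteration.
        return f'<item>{self._items[index]}</item>'


seq = OnlyGetItem(['1', '2', '3'])
for value in seq:
    print(value)
# <item>1</item>
# <item>2</item>
# <item>3</item>


class Neither:
    """Defines neither __iter__ nor __getitem__."""


try:
    iter(Neither())
except TypeError as exc:
    error_message = str(exc)
print(error_message)  # 'Neither' object is not iterable
```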
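Since you inherit from `list`, and `list` already defines `__iter__`, `iter()` always uses that method and never falls back to your `__getitem__`. In CPython, the built-in list iterator reads the stored items directly, without going through `__getitem__`.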
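Roughly, `for value in lst:` behaves like the explicit loop below (a simplified sketch of the iterator protocol). It also explains the P.S. in the question: `next()` is called on the iterator that `__iter__` returns, which for a `list` subclass without its own `__iter__` is a built-in `list_iterator`, not the list itself. A `__next__` defined on the list subclass is therefore never consulted; the one below exists only to show that.

```python
class ItemWrappedList(list):
    def __getitem__(self, y):
        result = super().__getitem__(y)
        return f'<item>{result}</item>'

    def __next__(self):
        # Never used by a for loop: next() is called on the iterator, not the list.
        raise RuntimeError('not reached')


lst = ItemWrappedList(['1', '2', '3'])

it = iter(lst)                # list.__iter__ returns a built-in list_iterator
print(type(it).__name__)      # list_iterator

collected = []
while True:
    try:
        value = next(it)      # list_iterator.__next__ reads the raw stored item
    except StopIteration:
        break
    collected.append(value)

print(collected)              # ['1', '2', '3']  (unwrapped: __getitem__ was bypassed)
```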
So just implement __iter__ for your class:
```python
class ItemWrappedList(list):
    def __getitem__(self, y):
        result = super().__getitem__(y)
        return f"<item>{result}</item>"

    def __iter__(self):
        # The outermost iterable of a generator expression is evaluated
        # immediately inside this method, so zero-argument super() works here.
        return (f"<item>{item}</item>" for item in super().__iter__())


lst = ItemWrappedList(["1", "2", "3"])

for pos in range(len(lst)):
    print(lst[pos])

for value in lst:
    print(value)
```
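An alternative worth considering (not taken from the answers above): instead of subclassing `list`, wrap one and derive from `collections.abc.Sequence`. That ABC implements its mixin methods (`__iter__`, `__contains__`, `__reversed__`, `index`, `count`) in terms of `__getitem__` and `__len__`, so customizing `__getitem__` alone changes every read path. A minimal sketch; the class name `WrappedSequence` is hypothetical:

```python
from collections.abc import Sequence


class WrappedSequence(Sequence):
    """Read-only sequence whose every read path goes through __getitem__."""

    def __init__(self, items):
        self._items = list(items)

    def __len__(self):
        return len(self._items)

    def __getitem__(self, index):
        if isinstance(index, slice):
            # Return the same type so slices keep the wrapping behaviour.
            return WrappedSequence(self._items[index])
        # Out-of-range ints raise IndexError, which ends Sequence.__iter__.
        return f'<item>{self._items[index]}</item>'


seq = WrappedSequence(['1', '2', '3'])
print(seq[0])                    # <item>1</item>
print(list(seq))                 # ['<item>1</item>', '<item>2</item>', '<item>3</item>']
print(list(reversed(seq)))       # ['<item>3</item>', '<item>2</item>', '<item>1</item>']
print('<item>2</item>' in seq)   # True
```

The trade-off: the result is no longer a `list` (`isinstance(seq, list)` is `False`) and it is read-only unless you switch to `MutableSequence` and implement `__setitem__`, `__delitem__` and `insert` as well.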
