Skip to content

Instantly share code, notes, and snippets.

@Bibo-Joshi
Last active February 13, 2022 19:48
Show Gist options
  • Select an option

  • Save Bibo-Joshi/e5d74648566fb9b2d138cc6020f45def to your computer and use it in GitHub Desktop.

Select an option

Save Bibo-Joshi/e5d74648566fb9b2d138cc6020f45def to your computer and use it in GitHub Desktop.
A defaultdict that keeps track of the accessed keys.
"""This module contains subclasses of :class:`collections.abc.MutableMapping` that keeps track of the
keys that were accessed.
Warning:
Tested only manually.
"""
from typing import (
TypeVar,
DefaultDict,
Callable,
Set,
ClassVar,
Iterator,
Optional,
Union,
Tuple,
overload,
MutableMapping,
List,
Mapping,
)
from collections import defaultdict
from telegram._utils.defaultvalue import DEFAULT_NONE, DefaultValue
_VT = TypeVar('_VT')
_KT = TypeVar('_KT')
_T = TypeVar('_T')
class TrackingDefaultDict(MutableMapping[_KT, _VT]):
    """DefaultDict that keeps track of which keys were accessed.

    The actual data is stored in an internal :class:`collections.defaultdict`. This wrapper
    records which keys were read and/or written so that they can be retrieved (and reset)
    via the ``pop_accessed_*`` methods.

    Note:
        * ``key in tdd`` is not considered reading
        * ``setdefault()`` is considered both reading and writing depending on
          whether or not the key is present
        * ``pop`` is only considered writing, since the value is deleted instead of being
          changed. Popping a key that is not present removes nothing and is not tracked.
        * ``tdd[key]`` for a missing key creates the default entry (``defaultdict`` behavior)
          and is tracked as reading only.

    Args:
        default_factory (Callable): Default factory for missing entries
        track_read (:obj:`bool`): Whether read access should be tracked. Deleting entries is
            not considered reading.
        track_write (:obj:`bool`): Whether write access should be tracked. Deleting entries is
            considered writing.
    """

    DELETED: ClassVar = object()
    """Special marker indicating that an entry was deleted."""

    __slots__ = ('_data', '_write_access_keys', '_read_access_keys', 'track_read', 'track_write')

    def __init__(self, default_factory: Callable[[], _VT], track_read: bool, track_write: bool):
        # The default_factory argument for defaultdict is positional only!
        self._data: DefaultDict[_KT, _VT] = defaultdict(default_factory)
        self.track_read = track_read
        self.track_write = track_write
        self._write_access_keys: Set[_KT] = set()
        self._read_access_keys: Set[_KT] = set()

    def __track_read(self, key: Union[_KT, Set[_KT]]) -> None:
        """Record ``key`` (or every element of ``key`` if it is a ``set``) as read-accessed.

        Does nothing when read tracking is disabled.
        """
        if self.track_read:
            if isinstance(key, set):
                self._read_access_keys |= key
            else:
                self._read_access_keys.add(key)

    def __track_write(self, key: Union[_KT, Set[_KT]]) -> None:
        """Record ``key`` (or every element of ``key`` if it is a ``set``) as write-accessed.

        Does nothing when write tracking is disabled.
        """
        if self.track_write:
            if isinstance(key, set):
                self._write_access_keys |= key
            else:
                self._write_access_keys.add(key)

    def __repr__(self) -> str:
        return repr(self._data)

    def __str__(self) -> str:
        return str(self._data)

    def __eq__(self, other: object) -> bool:
        # Compare against the underlying data only; the tracking state is ignored.
        return other == self._data

    def pop_accessed_read_keys(self) -> Set[_KT]:
        """Returns all keys that were read-accessed since the last time this method was called.

        Raises:
            RuntimeError: If read access is not being tracked.
        """
        if not self.track_read:
            raise RuntimeError('Not tracking read access!')
        out = self._read_access_keys
        # Hand out the old set and start over with a fresh one.
        self._read_access_keys = set()
        return out

    def pop_accessed_write_keys(self) -> Set[_KT]:
        """Returns all keys that were write-accessed since the last time this method was called.

        Raises:
            RuntimeError: If write access is not being tracked.
        """
        if not self.track_write:
            raise RuntimeError('Not tracking write access!')
        out = self._write_access_keys
        # Hand out the old set and start over with a fresh one.
        self._write_access_keys = set()
        return out

    def pop_accessed_read_items(self) -> List[Tuple[_KT, _VT]]:
        """
        Returns all keys & corresponding values as list of tuples that were read-accessed since
        the last time this method was called. If a key was deleted after it was read, the value
        will be :attr:`DELETED`.

        Raises:
            RuntimeError: If read access is not being tracked.
        """
        keys = self.pop_accessed_read_keys()
        data = self._data
        # Check for presence first: indexing the defaultdict with a key that was deleted in
        # the meantime would silently re-create the entry via the default factory.
        return [(key, data[key] if key in data else self.DELETED) for key in keys]

    def pop_accessed_write_items(self) -> List[Tuple[_KT, _VT]]:
        """
        Returns all keys & corresponding values as list of tuples that were write-accessed since
        the last time this method was called. If a key was deleted, the value will be
        :attr:`DELETED`.

        Raises:
            RuntimeError: If write access is not being tracked.
        """
        keys = self.pop_accessed_write_keys()
        data = self._data
        return [(key, data[key] if key in data else self.DELETED) for key in keys]

    # Implement abstract interface
    def __getitem__(self, key: _KT) -> _VT:
        # Look up first: if the lookup raises, nothing is tracked.
        item = self._data[key]
        self.__track_read(key)
        return item

    def __setitem__(self, key: _KT, value: _VT) -> None:
        self._data[key] = value
        self.__track_write(key)

    def __delitem__(self, key: _KT) -> None:
        # Delete first: a failing delete (missing key) is not tracked.
        del self._data[key]
        self.__track_write(key)

    def __iter__(self) -> Iterator[_KT]:
        # Keys are tracked lazily, i.e. only those that are actually handed out.
        for key in self._data:
            self.__track_read(key)
            yield key

    def __len__(self) -> int:
        return len(self._data)

    def update_no_track(self, mapping: Mapping[_KT, _VT]) -> None:
        """Updates the underlying data with ``mapping`` without tracking any access."""
        self._data.update(mapping)

    # Override some methods so that they fit better with the read/write access book keeping
    def __contains__(self, key: object) -> bool:
        # Checked directly on the data so that membership tests are not tracked as reading.
        return key in self._data

    # Mypy seems a bit inconsistent about what it wants as types for `default` and return value
    # so we just ignore a bit
    def pop(  # type: ignore[override]
        self, key: _KT, default: _VT = DEFAULT_NONE  # type: ignore[assignment]
    ) -> _VT:
        """Removes ``key`` and returns its value.

        The key is tracked as written only if it is present, i.e. if something is removed.

        Args:
            key: The key to remove.
            default: Optional. Value to return if ``key`` is not present.

        Raises:
            KeyError: If ``key`` is not present and no ``default`` was passed.
        """
        if key in self._data:
            self.__track_write(key)
        if isinstance(default, DefaultValue):
            return self._data.pop(key)
        # ``dict.pop`` accepts the default positionally only; passing it as keyword argument
        # raises a TypeError.
        return self._data.pop(key, default)

    def clear(self) -> None:
        """Removes all entries. All keys present before the call are tracked as written."""
        self.__track_write(set(self._data.keys()))
        self._data.clear()

    # Mypy seems a bit inconsistent about what it wants as types for `default` and return value
    # so we just ignore a bit
    def setdefault(self: 'TrackingDefaultDict[_KT, _T]', key: _KT, default: _T = None) -> _T:
        """Returns the value for ``key`` if present (tracked as reading). Otherwise stores and
        returns ``default`` (tracked as writing).
        """
        if key in self._data:
            self.__track_read(key)
            return self._data[key]
        self.__track_write(key)
        self._data[key] = default  # type: ignore[assignment]
        return default  # type: ignore[return-value]

    # Overriding to comply with the behavior of `defaultdict`
    @overload
    def get(self, key: _KT) -> Optional[_VT]:  # pylint: disable=arguments-differ
        ...

    @overload
    def get(self, key: _KT, default: _T) -> _T:  # pylint: disable=signature-differs
        ...

    def get(self, key: _KT, default: _T = None) -> Optional[Union[_VT, _T]]:
        """Returns the value for ``key`` if present (tracked as reading), else ``default``.

        A missing key is neither created nor tracked.
        """
        if key in self:
            return self[key]
        return default
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment