Skip to content

Instantly share code, notes, and snippets.

@kodekracker
Created January 6, 2024 17:19
Show Gist options
  • Save kodekracker/d0e738933675e2534f047ea463aefed3 to your computer and use it in GitHub Desktop.
Save kodekracker/d0e738933675e2534f047ea463aefed3 to your computer and use it in GitHub Desktop.
Multi dict data stucture utilities
"""
Source: https://github.com/mitmproxy/mitmproxy/blob/main/mitmproxy/coretypes/multidict.py
"""
from abc import ABCMeta
from abc import abstractmethod
from collections.abc import Iterator
from collections.abc import MutableMapping
from collections.abc import Sequence
from typing import TypeVar
from mitmproxy.coretypes import serializable
KT = TypeVar("KT")
VT = TypeVar("VT")
class _MultiDict(MutableMapping[KT, VT], metaclass=ABCMeta):
"""
A MultiDict is a dictionary-like data structure that supports multiple values per key.
"""
fields: tuple[tuple[KT, VT], ...]
"""The underlying raw datastructure."""
def __repr__(self):
fields = (repr(field) for field in self.fields)
return "{cls}[{fields}]".format(
cls=type(self).__name__, fields=", ".join(fields)
)
@staticmethod
@abstractmethod
def _reduce_values(values: Sequence[VT]) -> VT:
"""
If a user accesses multidict["foo"], this method
reduces all values for "foo" to a single value that is returned.
For example, HTTP headers are folded, whereas we will just take
the first cookie we found with that name.
"""
@staticmethod
@abstractmethod
def _kconv(key: KT) -> KT:
"""
This method converts a key to its canonical representation.
For example, HTTP headers are case-insensitive, so this method returns key.lower().
"""
def __getitem__(self, key: KT) -> VT:
values = self.get_all(key)
if not values:
raise KeyError(key)
return self._reduce_values(values)
def __setitem__(self, key: KT, value: VT) -> None:
self.set_all(key, [value])
def __delitem__(self, key: KT) -> None:
if key not in self:
raise KeyError(key)
key = self._kconv(key)
self.fields = tuple(
field for field in self.fields if key != self._kconv(field[0])
)
def __iter__(self) -> Iterator[KT]:
seen = set()
for key, _ in self.fields:
key_kconv = self._kconv(key)
if key_kconv not in seen:
seen.add(key_kconv)
yield key
def __len__(self) -> int:
return len({self._kconv(key) for key, _ in self.fields})
def __eq__(self, other) -> bool:
if isinstance(other, MultiDict):
return self.fields == other.fields
return False
def get_all(self, key: KT) -> list[VT]:
"""
Return the list of all values for a given key.
If that key is not in the MultiDict, the return value will be an empty list.
"""
key = self._kconv(key)
return [value for k, value in self.fields if self._kconv(k) == key]
def set_all(self, key: KT, values: list[VT]) -> None:
"""
Remove the old values for a key and add new ones.
"""
key_kconv = self._kconv(key)
new_fields: list[tuple[KT, VT]] = []
for field in self.fields:
if self._kconv(field[0]) == key_kconv:
if values:
new_fields.append((field[0], values.pop(0)))
else:
new_fields.append(field)
while values:
new_fields.append((key, values.pop(0)))
self.fields = tuple(new_fields)
def add(self, key: KT, value: VT) -> None:
"""
Add an additional value for the given key at the bottom.
"""
self.insert(len(self.fields), key, value)
def insert(self, index: int, key: KT, value: VT) -> None:
"""
Insert an additional value for the given key at the specified position.
"""
item = (key, value)
self.fields = self.fields[:index] + (item,) + self.fields[index:]
def keys(self, multi: bool = False):
"""
Get all keys.
If `multi` is True, one key per value will be returned.
If `multi` is False, duplicate keys will only be returned once.
"""
return (k for k, _ in self.items(multi))
def values(self, multi: bool = False):
"""
Get all values.
If `multi` is True, all values will be returned.
If `multi` is False, only the first value per key will be returned.
"""
return (v for _, v in self.items(multi))
def items(self, multi: bool = False):
"""
Get all (key, value) tuples.
If `multi` is True, all `(key, value)` pairs will be returned.
If False, only one tuple per key is returned.
"""
if multi:
return self.fields
else:
return super().items()
class MultiDict(_MultiDict[KT, VT], serializable.Serializable):
"""A concrete MultiDict, storing its own data."""
def __init__(self, fields=()):
super().__init__()
self.fields = tuple(tuple(i) for i in fields) # type: ignore
@staticmethod
def _reduce_values(values):
return values[0]
@staticmethod
def _kconv(key):
return key
def get_state(self):
return self.fields
def set_state(self, state):
self.fields = tuple(tuple(x) for x in state) # type: ignore
@classmethod
def from_state(cls, state):
return cls(state)
class MultiDictView(_MultiDict[KT, VT]):
"""
The MultiDictView provides the MultiDict interface over calculated data.
The view itself contains no state - data is retrieved from the parent on
request, and stored back to the parent on change.
"""
def __init__(self, getter, setter):
self._getter = getter
self._setter = setter
super().__init__()
@staticmethod
def _kconv(key):
# All request-attributes are case-sensitive.
return key
@staticmethod
def _reduce_values(values):
# We just return the first element if
# multiple elements exist with the same key.
return values[0]
@property # type: ignore
def fields(self):
return self._getter()
@fields.setter
def fields(self, value):
self._setter(value)
def copy(self) -> "MultiDict[KT,VT]":
return MultiDict(self.fields)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment