@fjetter
Last active April 30, 2021 09:56
A draft for a possible new multi-layered spilling interface for dask.distributed
from __future__ import annotations

import io
from collections import defaultdict
from collections.abc import MutableMapping
from typing import Callable, Dict, List, Set

from dask.sizeof import sizeof


class Data:
    async def open(self) -> io.IOBase:
        return io.StringIO()

    async def close(self):
        pass


class DataProcessor:
    async def store(self, data: Data) -> Data:
        # Compress, spill, remote store, etc.
        return data

    async def retrieve(self, data: Data) -> Data:
        # Restore the original representation
        return data


class CompressionProcessor(DataProcessor):
    pass


class FileProcessor(DataProcessor):
    pass


class RemoteStorageProcessor(DataProcessor):
    pass


class DataProxy:
    def __init__(self, data: Data, processors: List[DataProcessor]):
        # Indicate how many levels this data can be pushed down
        self.max_level = len(processors)
        # This is the size *in-memory*. For in-memory data this would be the
        # raw amount, for remote storage it may be just the size of the key,
        # etc. We probably want to store the raw size as well for tracking,
        # but that's not the point of this draft.
        self.size = sizeof(data)
        self.processors = processors
        # We keep a refcount to avoid duplicates and to avoid spilling data
        # that is currently in use
        self.refcount = 0
        self.current_level = 0
        self.raw = data

    async def archive(self):
        # Only allow archival (spill, compression, etc.) if the data is not
        # currently in use and can still be pushed down a level
        if self.refcount == 0 and self.current_level < self.max_level:
            processor = self.processors[self.current_level]
            self.raw = await processor.store(self.raw)
            self.size = sizeof(self.raw)
            self.current_level += 1
        return self.current_level

    async def unarchive(self):
        if self.current_level:
            # The processor that archived us last sits at current_level - 1
            processor = self.processors[self.current_level - 1]
            self.raw = await processor.retrieve(self.raw)
            self.size = sizeof(self.raw)
            self.current_level -= 1
        return self.current_level

    async def open(self) -> io.IOBase:
        self.refcount += 1
        while self.current_level:
            await self.unarchive()
        return await self.raw.open()

    async def close(self):
        self.refcount -= 1
        await self.raw.close()


def proxy_factory(data: Data) -> DataProxy:
    # Maybe we want to have different processors per object type; the
    # isinstance check below is a placeholder for such a dispatch.
    if isinstance(data, object):
        return DataProxy(
            data,
            processors=[
                CompressionProcessor(),
                FileProcessor(),
            ],
        )
    else:
        raise RuntimeError()
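

# A small, hypothetical usage sketch of the proxy on its own: archive()
# pushes the data down one processor level at a time, while open() pulls
# it back up to level 0 before handing out the stream. The function name
# _proxy_example is illustrative only.
async def _proxy_example():
    proxy = proxy_factory(Data())
    await proxy.archive()  # level 0 -> 1, e.g. compressed
    await proxy.archive()  # level 1 -> 2, e.g. spilled to disk
    stream = await proxy.open()  # unarchives back to level 0 first
    assert proxy.current_level == 0
    await proxy.close()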


Level = int


class SmartMutableBuffer(MutableMapping):
    def __init__(
        self,
        data: Dict[str, Data],
        target: int | Dict[Level, int],
        factory: Callable[[Data], DataProxy],
    ):
        self._data: Dict[str, DataProxy] = {}
        self.factory = factory
        # We may want to limit targets by level, e.g. if we can move data
        # through different layers of memory. That's a use case for GPUs, I
        # believe. Standard users would probably simply define an integer.
        if isinstance(target, int):
            # This is not accurate. If we have multiple processors which are
            # using actual memory, we might need to refine this target a bit,
            # but it should give the idea.
            self.target = {0: target}
        else:
            self.target = target
        self.size_by_level: Dict[Level, int] = defaultdict(int)
        self.keys_by_level: Dict[Level, Set[str]] = defaultdict(set)
        for k, value in data.items():
            self[k] = value

    def __getitem__(self, k: str) -> DataProxy:
        return self._data[k]

    def __setitem__(self, k: str, v: Data) -> None:
        proxy = self.factory(v)
        self._data[k] = proxy
        self.size_by_level[proxy.current_level] += proxy.size
        self.keys_by_level[proxy.current_level].add(k)

    def __delitem__(self, k: str) -> None:
        # Required by MutableMapping; keep the level bookkeeping consistent
        proxy = self._data.pop(k)
        self.size_by_level[proxy.current_level] -= proxy.size
        self.keys_by_level[proxy.current_level].discard(k)

    def __iter__(self):
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)
    async def balance(self):
        """Balance all layers such that they fall below their given targets, if possible."""
        # It may also be necessary to refactor this to allow for explicit
        # evict calls, c.f. memory.target vs memory.spill.
        # This is just an example without a robust exit condition, so it may
        # never settle, but it should give an idea. It could be further
        # optimized with various LRU / size / other policies.
        for level, target in self.target.items():
            # If we are breaching the desired target for a given level, start
            # archiving and move data to lower levels
            while self.size_by_level[level] > target:
                if not self.keys_by_level[level]:
                    break
                k = self.keys_by_level[level].pop()
                proxy = self._data[k]
                self.size_by_level[level] -= proxy.size
                new_level = await proxy.archive()
                self.size_by_level[new_level] += proxy.size
                self.keys_by_level[new_level].add(k)
            # If we are otherwise well below the target, we may unarchive data
            # to keep it in hot storage since eventually we'll need it again.
            # This mirrors the archive loop above; as a simplification it only
            # pulls from the immediately lower level.
            while self.size_by_level[level] < target * 0.7:
                if not self.keys_by_level[level + 1]:
                    break
                k = self.keys_by_level[level + 1].pop()
                proxy = self._data[k]
                self.size_by_level[level + 1] -= proxy.size
                new_level = await proxy.unarchive()
                self.size_by_level[new_level] += proxy.size
                self.keys_by_level[new_level].add(k)
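

# A hedged end-to-end sketch: the buffer keeps roughly `target` bytes of
# hot (level 0) data and archives the rest on balance(). The target value
# and the main() wrapper are illustrative assumptions, not part of the
# draft interface.
async def main():
    buf = SmartMutableBuffer(
        data={"x": Data(), "y": Data()},
        target=1_000_000,
        factory=proxy_factory,
    )
    await buf.balance()
    # Accessing a key pins it in memory until close() is called
    stream = await buf["x"].open()
    await buf["x"].close()


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())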