Skip to content

Instantly share code, notes, and snippets.

@parity3
Created Jan 10, 2021
Embed
What would you like to do?
import functools
import logging
import os
import struct
from typing import Dict, cast, Optional, List
import cytoolz
from trio.hazmat import wait_readable
import inotify_simple
import trio
class InotifyWrapper:
    """Recursively watch a directory tree with inotify and yield file events in batches.

    Intended use::

        async with InotifyWrapper(b'/path/to/dir', logger) as watcher:
            async for events in watcher:
                ...

    Each yielded batch is a non-empty list of ``inotify_simple.Event`` whose ``name``
    is the affected file's path relative to ``dir_to_watch``. Only non-directory
    ``CLOSE_WRITE`` / ``MOVED_TO`` events are yielded; directory create/move/delete
    events are consumed internally to keep the set of watches in step with the tree.

    Directly under ``dir_to_watch`` only subdirectories whose name starts with
    ``top_level_prefix`` are watched (e.g. an ``archive`` directory is skipped);
    deeper subdirectories are all watched. Files directly inside ``dir_to_watch``
    are always reported.
    """

    # Watch mask registered on every directory (ONLYDIR: only watch the path if it is a directory).
    mask = (
        inotify_simple.flags.MODIFY
        | inotify_simple.flags.CLOSE_WRITE
        | inotify_simple.flags.ONLYDIR
        | inotify_simple.flags.CREATE
        | inotify_simple.flags.DELETE
        | inotify_simple.flags.MOVED_FROM
        | inotify_simple.flags.MOVED_TO
    )

    # Fixed-size header of each kernel record (struct inotify_event: int wd;
    # uint32 mask, cookie, len), followed by `len` bytes of NUL-padded name.
    _EVENT_HEADER = struct.Struct('iIII')
    # Upper bound on bytes requested per os.read() of the inotify fd.
    _READ_SIZE = 0x20000

    def __init__(self, dir_to_watch: bytes, log_instance: logging.Logger, *, top_level_prefix: bytes = b'a_'):
        """
        :param dir_to_watch: root directory (bytes path) to watch recursively.
        :param log_instance: logger used for diagnostics.
        :param top_level_prefix: only direct subdirectories of ``dir_to_watch`` whose
            name starts with this prefix are watched.
        """
        self.dir_to_watch = dir_to_watch
        self.log_instance = log_instance
        self.top_level_prefix = top_level_prefix
        # watch descriptor -> directory path relative to dir_to_watch (b'' is the root)
        self._wds = {}  # type: Dict[int, bytes]
        # reverse mapping of _wds
        self._paths = {}  # type: Dict[bytes, int]
        self.ino = None  # type: Optional[inotify_simple.INotify]
        # events whose watch descriptor was known (i.e. actually dispatched)
        self.num_event_loops = 0
        # file events handed to the consumer
        self.num_events = 0

    async def __aenter__(self):
        """Create the inotify instance and register watches in a worker thread (the tree walk blocks).

        :raises RuntimeError: if already entered.
        """
        if self.ino is not None:
            raise RuntimeError("non-reentrant")
        ino = await trio.to_thread.run_sync(self._get_inotify)  # type: inotify_simple.INotify
        self.ino = ino
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the inotify fd and forget all watches so the wrapper can be entered again."""
        ino = self.ino  # type: Optional[inotify_simple.INotify]
        self.ino = None
        # Closing the fd discards every watch; stale entries would otherwise make
        # _add_dir skip re-registering them on the next __aenter__.
        self._wds.clear()
        self._paths.clear()
        if ino is not None:
            ino.close()

    async def __aiter__(self):
        """Yield batches (lists) of file events; see the class docstring.

        Pending events are handed out just before the iterator would block waiting for
        more data, so all events decoded from the data already read arrive together.
        Iteration ends when a read of the inotify fd returns no data.

        :raises RuntimeError: if iterated before entering the context manager.
        """
        ino = self.ino
        if ino is None:
            raise RuntimeError('InotifyWrapper must be entered with "async with" before iterating')
        header = self._EVENT_HEADER
        fd = ino.fd
        buf = bytearray()
        events = []  # type: List[inotify_simple.Event]
        while True:
            # Gather a full fixed-size header, flushing the pending batch before each blocking read.
            while len(buf) < header.size:
                if events:
                    self.num_events += len(events)
                    yield events
                    events = []
                chunk = await self._read_chunk(fd)
                if not chunk:
                    if buf:
                        self.log_instance.error('inotify stream terminated mid-packet')
                    return
                buf += chunk
            wd, event_mask, cookie, namesize = header.unpack_from(buf, 0)
            del buf[:header.size]
            # Gather the variable-length name that follows the header.
            while len(buf) < namesize:
                chunk = await self._read_chunk(fd)
                if not chunk:
                    self.log_instance.error('inotify stream terminated mid-packet')
                    return
                buf += chunk
            name = bytes(buf[:namesize]).rstrip(b'\0')  # type: bytes
            del buf[:namesize]
            event = await self._process_event(ino, wd, event_mask, cookie, name)
            if event is not None:
                events.append(event)

    async def _read_chunk(self, fd: int) -> bytes:
        """Wait until *fd* is readable, then read up to _READ_SIZE bytes (b'' means no more data)."""
        await wait_readable(fd)
        return os.read(fd, self._READ_SIZE)

    async def _process_event(self, ino, wd: int, event_mask: int, cookie: int, name: bytes):
        """Handle one decoded inotify record.

        Directory events update the watch set (in a worker thread); qualifying file
        events are returned as an ``inotify_simple.Event`` with ``name`` relative to
        ``dir_to_watch``. Returns None when there is nothing to yield.
        """
        flags = inotify_simple.flags
        if event_mask & flags.Q_OVERFLOW:
            # The kernel queue overflowed; events were dropped and cannot be recovered here.
            self.log_instance.warning('inotify event queue overflowed; some events were lost')
            return None
        if event_mask & flags.IGNORED:
            # The kernel has removed this watch (after rm_watch, or the directory went away).
            self._forget_wd(wd)
            return None
        parent_path = self._wds.get(wd)
        if parent_path is None:
            self.log_instance.info(
                'ignoring peculiar error where we received an event without its watched parent')
            return None
        subpath = os.path.join(parent_path, name) if parent_path else name  # type: bytes
        self.num_event_loops += 1
        if not event_mask & flags.ISDIR:
            if event_mask & (flags.CLOSE_WRITE | flags.MOVED_TO):
                return inotify_simple.Event(wd, event_mask, cookie, subpath)
            return None
        if event_mask & (flags.DELETE | flags.MOVED_FROM):
            await trio.to_thread.run_sync(self._rm_watch, ino, wd, subpath)
        elif event_mask & (flags.CREATE | flags.MOVED_TO):
            # Apply the same top-level filter as the initial scan in _add_subdirs_recursive.
            if parent_path or self._watches_top_level_dir(name):
                await trio.to_thread.run_sync(
                    self._add_dir_and_subdirs, ino, os.path.join(self.dir_to_watch, subpath))
        return None

    def _watches_top_level_dir(self, name: bytes) -> bool:
        """Return True if a direct subdirectory of dir_to_watch named *name* should be watched."""
        return name.startswith(self.top_level_prefix)

    def _get_inotify(self):
        """Create an INotify watching dir_to_watch and its eligible subdirectories (blocking).

        On failure the new fd is closed and the bookkeeping cleared before re-raising.
        """
        ino = inotify_simple.INotify()
        try:
            self._add_dir_and_subdirs(ino, self.dir_to_watch)
        except BaseException:
            ino.close()
            self._wds.clear()
            self._paths.clear()
            raise
        return ino

    def _add_dir_and_subdirs(self, ino: inotify_simple.INotify, new_path: bytes) -> None:
        """Watch *new_path* and every eligible directory beneath it (blocking)."""
        self._add_dir(ino, new_path)
        self._add_subdirs_recursive(ino, new_path)

    def _add_dir(self, ino: inotify_simple.INotify, new_path: bytes) -> bool:
        """Register a watch on *new_path*, a directory at or under dir_to_watch.

        :returns: True if a new watch was recorded; False if the path was already
            tracked or ``add_watch`` raised OSError (which is logged and ignored).
        """
        if new_path == self.dir_to_watch:
            subpath = b''
        else:
            subpath = os.path.relpath(new_path, self.dir_to_watch)
        if subpath in self._paths:
            return False
        try:
            new_wd = ino.add_watch(new_path, self.mask)
        except OSError as add_watch_error:
            self.log_instance.info(f'ignoring {add_watch_error=}')
            return False
        self._wds[new_wd] = subpath
        self._paths[subpath] = new_wd
        return True

    def _add_subdirs_recursive(self, ino, new_path):
        """Watch every directory strictly beneath *new_path*.

        Directly under dir_to_watch, directories rejected by _watches_top_level_dir
        are pruned and not descended into.
        """
        dtw = self.dir_to_watch
        for parent_path, dirs, _names in os.walk(new_path):  # type: bytes, List[bytes], List[bytes]
            if parent_path == dtw:  # make sure to remove any non-incoming_feeds dirs (like archive)
                # Slice-assign (never remove while iterating) so os.walk also skips the pruned dirs.
                dirs[:] = [d for d in dirs if self._watches_top_level_dir(d)]
            if not dirs:
                continue
            paths = map(functools.partial(os.path.join, parent_path), dirs)
            num_added = cytoolz.count(filter(None, map(functools.partial(self._add_dir, ino), paths)))
            self.log_instance.info(f'added {num_added} / {len(dirs)} dirs for {parent_path}, now at: {len(self._wds)}')

    def _rm_watch(self, ino, wd, relpath):
        """Stop watching directory *relpath* (relative to dir_to_watch) and all watched dirs beneath it.

        :param wd: watch descriptor the triggering event arrived on, i.e. the *parent*
            directory's watch. It is deliberately left in place; the directory's own
            watch is looked up via *relpath*.
        """
        prefix = relpath + b'/'
        doomed = [(path, swd) for path, swd in self._paths.items()
                  if path == relpath or path.startswith(prefix)]
        for path, swd in doomed:
            try:
                ino.rm_watch(swd)
            except OSError as rm_watch_error:
                # Expected when the kernel already dropped the watch (e.g. directory deleted).
                self.log_instance.info(f'ignoring {rm_watch_error=} removing {path=}')
            self._paths.pop(path, None)
            self._wds.pop(swd, None)
        self.log_instance.info(f'removed watch for: {relpath}')

    def _forget_wd(self, wd: int) -> None:
        """Drop bookkeeping for a watch descriptor the kernel reports as removed."""
        path = self._wds.pop(wd, None)
        if path is not None and self._paths.get(path) == wd:
            del self._paths[path]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment