Skip to content

Instantly share code, notes, and snippets.

@jelmervdl
Created June 17, 2022 14:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jelmervdl/1474ee859d44f179370860c4241b0408 to your computer and use it in GitHub Desktop.
Save jelmervdl/1474ee859d44f179370860c4241b0408 to your computer and use it in GitHub Desktop.
Head + tail + random sample of file
import random
from math import exp, log, floor
def _random_nonzero(rand):
    """Return a uniform float from the open interval (0, 1).

    ``rand.random()`` may return exactly 0.0, and ``log(0.0)`` raises
    ``ValueError``; such a draw is simply repeated.
    """
    u = rand.random()
    while u == 0.0:
        u = rand.random()
    return u


def reservoir_sample(k, it, *, rand: "random.Random | None" = None):
    """Return a uniformly random sample of up to ``k`` items from ``it``.

    Uses reservoir sampling with geometric skips (Li's "Algorithm L"): once
    the reservoir is full, the number of items to pass over before the next
    replacement is drawn directly, so random numbers are drawn per
    replacement rather than per input item.

    Args:
        k: Reservoir size. ``0`` gives an empty sample.
        it: Any iterable. It is always consumed to exhaustion, even when
            ``k`` is 0 (``sample()`` reads its tailer's tail only after this
            function has drained it).
        rand: Source of randomness (anything with ``random()`` and
            ``randrange()``). Defaults to the ``random`` module's functions,
            which share its global generator.

    Returns:
        A list of ``min(k, n)`` items for an input of ``n`` items. When
        ``n <= k`` the items keep their input order; otherwise the order is
        arbitrary.

    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        raise ValueError("sample size must be non-negative, got %r" % (k,))
    rng = random if rand is None else rand
    numbered_it = enumerate(it)

    if k == 0:
        # Nothing is kept, but drain the input just like the k > 0 path does.
        for _ in numbered_it:
            pass
        return []

    # Fill the reservoir. zip() asks range(k) first, so once k items have
    # been taken no further input item is consumed (and silently lost).
    reservoir = [line for _, (_, line) in zip(range(k), numbered_it)]
    if len(reservoir) < k:
        # Input ran out before the reservoir filled: every item is kept.
        return reservoir
    i = k - 1  # enumerate index of the most recently consumed item

    w = exp(log(_random_nonzero(rng)) / k)
    try:
        while True:
            # Index of the next item that enters the reservoir (geometric skip).
            next_i = i + floor(log(_random_nonzero(rng)) / log(1 - w)) + 1
            # Skip forward, discarding the items in between.
            while i < next_i:
                i, line = next(numbered_it)
            reservoir[rng.randrange(k)] = line
            w *= exp(log(_random_nonzero(rng)) / k)
    except StopIteration:
        pass  # input exhausted; the reservoir is final
    return reservoir
class Tailer:
    """Iterator over all but the last ``k`` items of ``it``.

    Once iteration has finished, the held-back items (the last ``k`` items,
    or fewer if the input was shorter) are available in input order from
    ``tail``.
    """

    def __init__(self, k, it):
        self.sample = []  # ring buffer of the most recently read items
        self.k = k
        self.i = 0  # number of items read from ``it`` so far
        self.it = iter(it)

    def __iter__(self):
        if self.k <= 0:
            # Nothing to hold back: pass every item straight through.
            for line in self.it:
                self.i += 1
                yield line
            return

        # Fill the buffer. A bare next() here would turn into RuntimeError
        # inside this generator (PEP 479) when the input has < k items;
        # instead stop quietly and leave everything in the tail.
        while self.i < self.k:
            try:
                line = next(self.it)
            except StopIteration:
                return
            self.sample.append(line)
            self.i += 1

        for line in self.it:
            slot = self.i % len(self.sample)
            # The oldest buffered item can no longer be among the last k.
            yield self.sample[slot]
            self.sample[slot] = line
            self.i += 1

    @property
    def tail(self):
        """The held-back items, oldest first (empty when ``k`` <= 0)."""
        if not self.sample:
            return []
        start = self.i % len(self.sample)
        return self.sample[start:] + self.sample[:start]
def sample(k, items):
    """Split ``items`` into a head, a random middle sample and a tail.

    Returns a ``(head, middle, tail)`` tuple of lists:
    ``head`` holds the first ``k`` items, ``tail`` the last ``k`` items
    (via ``Tailer``), and ``middle`` is a sample of up to ``k`` of the items
    in between (via ``reservoir_sample``). The three never share an input
    position. ``items`` is read in a single pass, so a one-shot stream works.
    """
    it = iter(items)
    # zip() pulls from range(k) first, so no item beyond the k-th is consumed;
    # with fewer than k items the head simply gets them all (a bare next()
    # here would leak StopIteration to the caller).
    head = [item for _, item in zip(range(k), it)]
    tailer = Tailer(k, it)
    # reservoir_sample reads the tailer to exhaustion, after which
    # tailer.tail holds the final items.
    middle = reservoir_sample(k, tailer)
    return head, middle, tailer.tail
if __name__ == '__main__':
    import sys
    import gzip
    from contextlib import ExitStack
    from itertools import count, chain

    # At least one file is required: with none, zip() below would only see
    # the infinite line-number generator and never terminate.
    if len(sys.argv) < 3:
        sys.exit("usage: %s K FILE.gz [FILE.gz ...]" % sys.argv[0])
    k = int(sys.argv[1])

    with ExitStack() as ctx:
        files = [ctx.enter_context(gzip.open(file, 'rb')) for file in sys.argv[2:]]
        # Files are read in lockstep (zip stops at the shortest); each tuple is
        # (line-number label, line from file 1, line from file 2, ...).
        pairs = zip(
            (str(i).encode() + b":\n" for i in count()),  # Line numbers
            *files,
        )
        # Use the K given on the command line for head, sample and tail.
        head, middle, tail = sample(k, pairs)
        for pair in chain(head, middle, tail):
            for entry in pair:
                sys.stdout.buffer.write(entry)
            # Blank line between consecutive tuples.
            sys.stdout.buffer.write(b'\n')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment