Skip to content

Instantly share code, notes, and snippets.

@bddppq
Last active February 2, 2023 10:58
Show Gist options
  • Save bddppq/090fc17720b7b3517c90b6676e52549a to your computer and use it in GitHub Desktop.
import functools
import time
from timeit import default_timer as timer
class DelayedValue:
    """Placeholder for a result that may not have been computed yet.

    ``done`` tells whether ``val`` already holds the real result.
    ``thunk`` is a zero-argument callable that ``get_`` invokes to produce
    the result when ``done`` is false.  Rendering the object with
    ``str()``/``repr()`` goes through ``get_``, so printing it forces
    resolution.
    """

    def __init__(self, done, thunk):
        self.val = None     # the resolved result; None until filled in
        self.thunk = thunk  # callable run by get_ to resolve this value
        self.done = done    # True once ``val`` is valid

    def __str__(self):
        return str(get_(self))

    def __repr__(self):
        return repr(get_(self))
def get_(val):
    """Resolve a DelayedValue, or a nested list of them, to concrete values.

    A value that is not yet ``done`` has its ``thunk`` invoked first.  For
    values produced by ``batch``, that thunk runs the pending batch, which
    fills in ``val.val`` and sets ``val.done``.

    Args:
        val: a ``DelayedValue`` or a (possibly nested) list of them.

    Returns:
        The resolved value, or a list mirroring the input's nesting.

    Raises:
        TypeError: if ``val`` (or any list element) is neither a list nor a
            ``DelayedValue``.  TypeError subclasses Exception, so existing
            ``except Exception`` handlers still catch it.
    """
    if isinstance(val, list):
        return [get_(v) for v in val]
    if not isinstance(val, DelayedValue):
        raise TypeError('Only DelayedValue can be get_')
    if not val.done:
        # Lazily trigger the computation that produces this value.
        val.thunk()
    return val.val
def batch(max_count):
    """Turn a list-in/list-out function into a lazily batched one.

    The decorated ``fn`` must take a list of inputs and return outputs of
    the same length, in the same order.  The wrapper it produces takes a
    list of inputs per call, queues each item, and returns one
    ``DelayedValue`` per item.  ``fn`` runs as soon as ``max_count`` items
    are queued.  It can also run earlier, when ``get_`` resolves a
    still-pending value, because that value's thunk flushes whatever
    partial batch is queued.

    Args:
        max_count: number of queued items that triggers a flush.

    Returns:
        A decorator that wraps ``fn`` in the batching logic above.  The
        wrapper keeps ``fn``'s name and docstring.

    Raises:
        ValueError: raised when a batch is flushed, if ``fn`` returns a
            different number of outputs than it was given inputs.
    """
    def decorator(fn):
        # Queued items and their placeholders.  Index i of one matches
        # index i of the other.  ``do`` rebinds both to fresh lists after a
        # successful flush.
        inputs = []
        delayed_vals = []

        def do():
            """Run ``fn`` on every queued item and resolve the placeholders."""
            nonlocal inputs, delayed_vals
            if not inputs:
                # Nothing queued: don't invoke fn with an empty batch.
                return
            outputs = list(fn(inputs))
            if len(outputs) != len(delayed_vals):
                # Without this check, zip() would silently drop the extras
                # and leave some placeholders unresolved forever.
                raise ValueError(
                    f'batched function returned {len(outputs)} outputs '
                    f'for {len(delayed_vals)} inputs')
            for dv, o in zip(delayed_vals, outputs):
                dv.done = True
                dv.val = o
            inputs = []
            delayed_vals = []

        @functools.wraps(fn)
        def batched_fn(inp):
            vals = []
            # Creating one placeholder per item keeps inputs and
            # placeholders aligned, even when a caller passes several items
            # (or none) in one call.
            for item in inp:
                inputs.append(item)
                val = DelayedValue(False, do)
                delayed_vals.append(val)
                vals.append(val)
                if len(inputs) >= max_count:
                    do()
            return vals

        return batched_fn

    return decorator
def preprocess(x):
    """Return ``x + 1``; the demo's per-item preprocessing step."""
    shifted = x + 1
    return shifted
@batch(max_count=4)
def main_model(x_batch):
    """Square every element of ``x_batch`` in one deliberately slow call.

    Each invocation prints ``called`` and sleeps for one second.  The number
    of prints, and the elapsed time, therefore show how many batches ran.

    Args:
        x_batch: list of numbers accumulated by the ``batch`` decorator.

    Returns:
        A list with each element squared, in the same order.
    """
    print('called')
    time.sleep(1)
    squared = [value ** 2 for value in x_batch]
    return squared
# Demo driver: push ten preprocessed inputs through the batched model, one
# item per call.  With max_count=4, calls 4 and 8 each run a full batch.
# The last two items stay pending until print() renders their
# DelayedValues, which flushes the partial batch, so the elapsed time
# includes that final batch.
t1 = timer()
results = []
for x in range(10):
    y = preprocess(x)
    z = main_model([y])  # a one-element list holding a DelayedValue
    results.extend(z)
print(results)  # rendering forces resolution of still-pending values
t2 = timer()
print(f'{t2 - t1:.2f}')
----------- Output -----------
called
called
called
[1, 4, 9, 16, 25, 36, 49, 64, 81, 100]
3.00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment