Created February 2, 2021 15:00
-
-
Save baratrion/1640443e1f80b19962a626d4859bba0d to your computer and use it in GitHub Desktop.
Line-by-line memory profiling for Metaflow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import wraps, partial | |
| |
from metaflow import FlowSpec, catch, step | |
| |
| |
def profile_memory(f):
    """Wrap a Metaflow step so its body is memory-profiled line by line.

    In this file it is stacked above ``@step`` (and ``@catch``).
    ``functools.wraps`` copies the wrapped function's ``__dict__``,
    ``__name__``, ``__doc__`` etc. onto the wrapper, so attributes set by the
    decorators below it remain visible on the wrapper.

    Each call runs ``f`` under a ``memory_profiler.LineProfiler`` that uses the
    ``psutil`` backend. It then prints a per-line memory report with one
    decimal place. The report is printed even if ``f`` raises, so a step that
    fails (e.g. while allocating) still shows where memory went.

    Args:
        f: The step function to profile.

    Returns:
        A wrapper with ``f``'s signature that returns whatever ``f`` returns
        and propagates any exception ``f`` raises.
    """
    @wraps(f)
    def func(self, *args, **kwargs):
        # Imported here rather than at module level. Importing this module
        # therefore does not require memory_profiler; only executing a
        # decorated step does.
        from memory_profiler import choose_backend, LineProfiler, show_results

        prof = LineProfiler(backend=choose_backend('psutil'))
        try:
            # prof(f) returns a wrapper that enables line tracing around f.
            return prof(f)(self, *args, **kwargs)
        finally:
            # Report even on failure. stream=None selects memory_profiler's
            # default output stream.
            show_results(prof, stream=None, precision=1)
    return func
| |
| |
class ProfileFlow(FlowSpec):
    """Demo flow whose steps are memory-profiled with ``profile_memory``.

    ``start`` fans out over ``range(3)``. Each ``compute`` branch allocates
    ``self.input`` large strings. ``join`` prints the parameter of every
    branch whose ``compute`` failed.
    """

    @profile_memory
    @step
    def start(self):
        """Fan out to one ``compute`` branch per value in ``range(3)``."""
        self.params = range(3)
        self.next(self.compute, foreach='params')

    @profile_memory
    @catch(var='compute_failed')
    @step
    def compute(self):
        """Allocate ``self.input`` strings of ``100 * 1024**2`` characters.

        The strings are kept only in a local list, so the line profiler shows
        the memory growth. They are not stored as artifacts.
        """
        self.i = self.input
        chunks = []  # deliberately local: not persisted as an artifact
        for _ in range(self.i):
            chunks.append('a' * (100 * 1024**2))
        self.next(self.join)

    @profile_memory
    @step
    def join(self, inputs):
        """Print the parameter of each ``compute`` branch that failed.

        ``compute_failed`` is the artifact named by ``@catch(var=...)`` on
        ``compute``. It is presumably falsy for branches that succeeded;
        confirm against the Metaflow ``@catch`` docs.
        """
        for inp in inputs:
            if inp.compute_failed:
                print('compute failed for parameter: %d' % inp.i)
        self.next(self.end)

    @step
    def end(self):
        """Terminal step; does nothing."""
        pass
| |
if __name__ == "__main__":
    # Script entry point: instantiating the flow class hands control to
    # Metaflow (presumably its CLI, e.g. `python file.py run`; see docs).
    ProfileFlow()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.