Skip to content

Instantly share code, notes, and snippets.

@pryce-turner
Created November 29, 2023 02:06
Show Gist options
  • Save pryce-turner/a84f4cbeea1cf3923f625e4407cf46f0 to your computer and use it in GitHub Desktop.
Save pryce-turner/a84f4cbeea1cf3923f625e4407cf46f0 to your computer and use it in GitHub Desktop.
Node caching: client-side implementation
import os
import shutil
import hashlib
from time import sleep
from typing import List
from random import randint
from pathlib import Path
from flytekit import task, workflow, dynamic
from flytekit.types.file import FlyteFile
from flytekitplugins.pod import Pod
from kubernetes.client.models import (
V1PodSpec,
V1Volume,
V1Container,
V1VolumeMount,
V1PersistentVolumeClaimVolumeSource,
)
# Directory inside the task container at which the cache volume is mounted.
pod_mount_path = "/nodecache"

# Shared name that links the pod-level volume to the container-level mount.
vol_name = "task-cache-vol"

# Pod volume sourced from the PersistentVolumeClaim named "task-cache-pvc".
_cache_volume = V1Volume(
    name=vol_name,
    persistent_volume_claim=V1PersistentVolumeClaimVolumeSource(
        claim_name="task-cache-pvc"
    ),
)

# Mounts the "task_cache" sub-path of the volume at ``pod_mount_path``.
_cache_mount = V1VolumeMount(
    name=vol_name,
    sub_path="task_cache",
    mount_path=pod_mount_path,
)

# NOTE(review): the noop image is presumably a placeholder that flytekit
# replaces with the task image for the "primary" container -- confirm.
_primary_container = V1Container(
    name="primary",
    image="docker.io/rwgrim/docker-noop",
    image_pull_policy="IfNotPresent",
    volume_mounts=[_cache_mount],
)

# Pod spec handed to ``Pod(pod_spec=...)`` by the tasks in this file.
persist_local_ps = V1PodSpec(
    volumes=[_cache_volume],
    containers=[_primary_container],
)
class CacheFile:
    """A FlyteFile cached under ``pod_mount_path`` and shared between tasks.

    The cache entry is keyed by the file name of ``ff.path``. Tasks coordinate
    through a ``<name>.caching`` lock file that lives next to the cache entry.

    Attributes:
        ff: The wrapped FlyteFile.
        fname: Final path component of ``ff.path``; used as the cache key.
        path: Location of the cached copy. ``None`` until ``check_cache`` runs.
    """

    def __init__(self, ff: FlyteFile):
        """Wrap *ff* and verify that the cache mount and a downloader exist.

        Raises:
            FileNotFoundError: If ``pod_mount_path`` does not exist.
            AssertionError: If the FlyteFile's downloader looks like a no-op.
        """
        self.ff = ff
        self.fname = Path(ff.path).name
        self.path = None

        try:
            os.listdir(pod_mount_path)
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"The default mount path ({pod_mount_path}) does not exist. "
                f"Did you use the appropriate pod spec in your task config?"
            ) from e

        # More informative error when a FlyteFile is initialized without a
        # downloader. Raised explicitly (rather than via ``assert``) so the
        # check still runs under ``python -O``; the exception type is kept.
        if 'noop' in self.ff._downloader.__str__():
            raise AssertionError(
                "FlyteFile initialized with no downloader. "
                "Was it not created at the task boundary?"
            )

    def check_cache(self) -> str:
        """Ensure the file is present in the cache and report how it got there.

        Sets ``self.path`` to the cache location as a side effect.

        Returns:
            ``'HIT'`` if the file was already cached, ``'CACHED'`` if this call
            downloaded it (``self.ff.path`` is then repointed at the cached
            copy), or ``'CACHED_OTHER'`` if another task cached it while this
            call was contending for or waiting on the lock.
        """
        self.path = Path(pod_mount_path).joinpath(self.fname)
        if self.path.exists():
            return 'HIT'

        lock = self.path.with_name(self.path.name + ".caching")
        while True:
            try:
                # O_CREAT | O_EXCL makes "create if absent" a single atomic
                # step, so exactly one task can win the lock.
                fd = os.open(lock, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
            except FileExistsError:
                # Another task holds the lock: wait for it to finish.
                # NOTE: a lock orphaned by a killed task is not detected here.
                while lock.exists():
                    sleep(5)
                if self.path.exists():
                    return 'CACHED_OTHER'
                # The lock holder failed without producing the file; retry.
                continue
            os.close(fd)

            try:
                # The file may have landed between the first existence check
                # and the moment the lock was acquired.
                if self.path.exists():
                    return 'CACHED_OTHER'

                self.ff.download()
                # Stage under a temporary name and rename into place so that
                # other tasks never observe a partially copied cache entry.
                staging = self.path.with_name(self.path.name + ".partial")
                shutil.move(self.ff.path, staging)
                os.replace(staging, self.path)
                self.ff.path = self.path
                return 'CACHED'
            finally:
                # Always release the lock, even when the download fails, so
                # waiting tasks do not block forever.
                try:
                    lock.unlink()
                except FileNotFoundError:
                    pass
@task(task_config=Pod(pod_spec=persist_local_ps))
def scratch(ff: FlyteFile) -> str:
    """Make sure *ff* is cached, print its first 50 lines, return the status.

    Returns:
        The string reported by ``CacheFile.check_cache``.
    """
    cached = CacheFile(ff)
    status = cached.check_cache()

    preview_limit = 50
    with open(cached.path, "r") as handle:
        for shown, text in enumerate(handle, start=1):
            # strip() drops surrounding whitespace, including the newline.
            print(text.strip())
            if shown >= preview_limit:
                break

    return status
@task(task_config=Pod(pod_spec=persist_local_ps))
def drain_cache(ff: FlyteFile, s: List[str]):
    """Delete the cached copy of *ff* from the cache mount.

    *s* is not read in the body.
    NOTE(review): presumably it exists only so this task is scheduled after
    the tasks that produced those values -- confirm.
    """
    entry = CacheFile(ff)
    # check_cache() is what populates entry.path; it runs before the delete.
    entry.check_cache()
    entry.path.unlink()
@dynamic
def wf():
    """Run ``scratch`` ten times on the same reference, then drain the cache."""
    reference = "s3://my-s3-bucket/my-data/refs/GRCh38.fasta"
    stats = [scratch(ff=reference) for _ in range(10)]
    drain_cache(ff=reference, s=stats)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment