Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save wild-endeavor/79a861d0fa3b69dc7fdbfc4514b06daf to your computer and use it in GitHub Desktop.
Save wild-endeavor/79a861d0fa3b69dc7fdbfc4514b06daf to your computer and use it in GitHub Desktop.
Dynamic pod task with a custom container image
from typing import List
import flytekit.configuration
from flytekit import dynamic, task
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig
from flytekit.core import context_manager
from flytekit.core.context_manager import ExecutionState
from flytekit.core.type_engine import TypeEngine
from flytekitplugins.pod import Pod
from kubernetes.client.models import (
V1Container,
V1EmptyDirVolumeSource,
V1PodSpec,
V1ResourceRequirements,
V1Volume,
V1VolumeMount,
)
def generate_pod_spec_for_task():
    """Build the two-container pod spec used by ``run_splat``.

    Containers:

    * ``primary`` - no image is set here; per the original author's note, the
      default image built for the Flyte task is used for the primary container.
    * ``secondary`` - an ``alpine`` container that runs a one-line ``/bin/sh``
      command.

    Both containers request and are limited to 1 CPU / 100Mi memory, and both
    mount the in-memory ``shared-data`` emptyDir volume at ``/data``.

    Returns:
        V1PodSpec: the assembled pod spec.
    """
    # A single requirements object and a single mount object are reused by
    # both containers (each container still gets its own volume_mounts list).
    shared_resources = V1ResourceRequirements(
        requests={"cpu": "1", "memory": "100Mi"},
        limits={"cpu": "1", "memory": "100Mi"},
    )
    data_mount = V1VolumeMount(name="shared-data", mount_path="/data")

    # Primary container: image intentionally omitted (see docstring).
    primary = V1Container(
        name="primary",
        resources=shared_resources,
        volume_mounts=[data_mount],
    )

    # Non-primary containers must specify their image explicitly.
    # NOTE(review): the command writes to /tmp/blah, which is not under the
    # shared /data mount, so the primary container presumably cannot read the
    # file -- confirm whether /data/... was intended.
    sidecar = V1Container(
        name="secondary",
        image="alpine",
        command=["/bin/sh"],
        args=["-c", "echo hi pod world > /tmp/blah"],
        resources=shared_resources,
        volume_mounts=[data_mount],
    )

    shared_volume = V1Volume(
        name="shared-data",
        empty_dir=V1EmptyDirVolumeSource(medium="Memory"),
    )
    return V1PodSpec(containers=[primary, sidecar], volumes=[shared_volume])
# Pod-plugin task: its pod spec comes from generate_pod_spec_for_task() and the
# container named "primary" is declared as the primary container. The task
# also pins its own container image and is cached under version "1.6.23".
# NOTE(review): this cache_version differs from _splat_dynamic's ("1.6.0");
# confirm that is intentional.
# (Comments instead of a docstring: Flyte captures task docstrings as metadata.)
@task(
    container_image="ghcr.io/special:image",
    task_config=Pod(
        primary_container_name="primary",
        pod_spec=generate_pod_spec_for_task(),
    ),
    cache_version="1.6.23",
    cache=True,
)
def run_splat() -> int:
    # Constant result; presumably a stand-in for real work in this repro.
    splat_value = 5
    return splat_value
# Dynamic task (cached, pinned to the same custom image as run_splat) that
# invokes run_splat once per entry in splat_templates. The docstring below is
# kept verbatim because Flyte records it as task metadata.
@dynamic(
    cache=True,
    cache_version="1.6.0",
    container_image="ghcr.io/special:image",
)
def _splat_dynamic(splat_templates: List[int]) -> List[int]:
    """Calls splat for every set of control values provided"""
    # The template values themselves are not passed to run_splat; only their
    # count matters. An explicit loop is used because each call adds a node to
    # the dynamic job spec (a side effect), in template order.
    splat_outputs = []
    for _template in splat_templates:
        splat_outputs.append(run_splat())
    return splat_outputs
if __name__ == "__main__":
    # Locally compile _splat_dynamic into its dynamic job spec and print it.
    # Step 1: a Flyte context carrying serialization settings. The default image
    # uses placeholder ("bad") names, presumably so that any use of the default
    # image instead of the tasks' own container_image stands out -- confirm.
    base_ctx = context_manager.FlyteContextManager.current_context()
    serialization_settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(
            Image(name="default_bad_name", fqn="default_badimage", tag="bad_tag")
        ),
        env={},
        fast_serialization_settings=FastSerializationSettings(
            enabled=True,
            # NOTE(review): possibly meant "/Users/..."; left as-is.
            destination_dir="/User/flyte/workflows",
            distribution_location="s3://my-s3-bucket/fast/123",
        ),
    )
    with context_manager.FlyteContextManager.with_context(
        base_ctx.with_serialization_settings(serialization_settings)
    ) as serialization_ctx:
        # Step 2: switch the execution state to TASK_EXECUTION mode.
        task_exec_state = serialization_ctx.execution_state.with_params(
            mode=ExecutionState.Mode.TASK_EXECUTION,
        )
        with context_manager.FlyteContextManager.with_context(
            serialization_ctx.with_execution_state(task_exec_state)
        ) as exec_ctx:
            # Step 3: convert Python inputs to a literal map and dispatch.
            inputs = TypeEngine.dict_to_literal_map(
                exec_ctx,
                {"splat_templates": [5, 6]},
                type_hints={"splat_templates": List[int]},
            )
            dynamic_job_spec = _splat_dynamic.dispatch_execute(exec_ctx, inputs)
            print(dynamic_job_spec)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment