Skip to content

Instantly share code, notes, and snippets.

@AlexanderVR
Created December 30, 2022 21:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AlexanderVR/67ade7ecfb1d72578675eca2171f92cd to your computer and use it in GitHub Desktop.
Save AlexanderVR/67ade7ecfb1d72578675eca2171f92cd to your computer and use it in GitHub Desktop.
Custom Dagster PartitionMapping for assets
from typing import Callable, Optional
from dagster import (
AssetIn,
AssetSelection,
Definitions,
PartitionKeyRange,
PartitionMapping,
PartitionsDefinition,
SourceAsset,
StaticPartitionsDefinition,
asset,
build_asset_reconciliation_sensor,
materialize_to_memory,
)
from dagster._core.definitions.partition import PartitionsSubset
class CoalescingPartitionMapping(PartitionMapping):
    """
    Partition mapping defined by an arbitrary function from upstream partition
    keys to downstream partition keys.

    ``mapper`` receives an upstream partition key and returns the downstream
    partition key that it feeds. Several upstream keys may map to the same
    downstream key (e.g. ``"p_1"``, ``"p_2"`` -> ``"p"``), which coalesces many
    upstream partitions into a single downstream partition.

    Only the subset-based methods are implemented. The key-range-based methods
    raise ``NotImplementedError`` because, for an arbitrary ``mapper``, the
    upstream keys of a downstream range need not form a contiguous range.
    """

    def __init__(self, mapper: Callable[[str], str]):
        self._mapper = mapper

    def get_downstream_partitions_for_partitions(
        self,
        upstream_partitions_subset: PartitionsSubset,
        downstream_partitions_def: Optional[PartitionsDefinition],
    ) -> PartitionsSubset:
        """
        Return the downstream partitions fed by ``upstream_partitions_subset``.

        Each upstream key is passed through ``mapper``. Mapped keys that are not
        partitions of ``downstream_partitions_def`` are dropped. This mirrors
        ``get_upstream_partitions_for_partitions``, which only returns upstream
        keys whose mapped key is one of the requested downstream partitions.

        Raises:
            NotImplementedError: if ``downstream_partitions_def`` is None.
        """
        if downstream_partitions_def is None:
            raise NotImplementedError(
                "CoalescingPartitionMapping requires a downstream PartitionsDefinition"
            )
        # Build the lookup set once so each membership test is O(1).
        valid_downstream_keys = set(downstream_partitions_def.get_partition_keys())
        downstream_keys = [
            mapped_key
            for mapped_key in map(
                self._mapper, upstream_partitions_subset.get_partition_keys()
            )
            if mapped_key in valid_downstream_keys
        ]
        return downstream_partitions_def.empty_subset().with_partition_keys(
            downstream_keys
        )

    def get_upstream_partitions_for_partitions(
        self,
        downstream_partitions_subset: Optional[PartitionsSubset],
        upstream_partitions_def: PartitionsDefinition,
    ) -> PartitionsSubset:
        """
        Return every upstream partition whose mapped key is in
        ``downstream_partitions_subset``.

        Every key of ``upstream_partitions_def`` is passed through ``mapper``,
        and the key is kept when the result is one of the downstream keys.

        Raises:
            NotImplementedError: if ``downstream_partitions_subset`` is None.
        """
        if downstream_partitions_subset is None:
            raise NotImplementedError(
                "CoalescingPartitionMapping requires a downstream PartitionsSubset"
            )
        # Build the target set once so each membership test is O(1).
        target_keys = set(downstream_partitions_subset.get_partition_keys())
        upstream_keys = [
            key
            for key in upstream_partitions_def.get_partition_keys()
            if self._mapper(key) in target_keys
        ]
        return upstream_partitions_def.empty_subset().with_partition_keys(
            upstream_keys
        )

    def get_upstream_partitions_for_partition_range(
        self,
        downstream_partition_key_range: Optional[PartitionKeyRange],
        downstream_partitions_def: Optional[PartitionsDefinition],
        upstream_partitions_def: PartitionsDefinition,
    ) -> PartitionKeyRange:
        """Unsupported: an arbitrary mapper has no faithful key-range form."""
        raise NotImplementedError(
            "CoalescingPartitionMapping does not support partition key ranges; "
            "use get_upstream_partitions_for_partitions instead"
        )

    def get_downstream_partitions_for_partition_range(
        self,
        upstream_partition_key_range: PartitionKeyRange,
        downstream_partitions_def: Optional[PartitionsDefinition],
        upstream_partitions_def: PartitionsDefinition,
    ) -> PartitionKeyRange:
        """Unsupported: an arbitrary mapper has no faithful key-range form."""
        raise NotImplementedError(
            "CoalescingPartitionMapping does not support partition key ranges; "
            "use get_downstream_partitions_for_partitions instead"
        )
# Upstream keys follow a "<group>_<n>" pattern; each group name is also a
# downstream partition, so several upstream keys feed one downstream key.
upstream_parts = StaticPartitionsDefinition(
    [
        "p_1",
        "p_2",
        "p_3",
        "q_1",
        "q_2",
        "r_1",
    ]
)
downstream_parts = StaticPartitionsDefinition(
    ["p", "q", "r"]
)
def key_mapper(p: str) -> str:
    """Return the group prefix of an upstream partition key.

    The prefix is everything before the first ``"_"``; a key containing no
    underscore is returned unchanged (e.g. ``"p_1"`` -> ``"p"``).
    """
    prefix, _separator, _remainder = p.partition("_")
    return prefix
@asset(partitions_def=upstream_parts)
def a(context):
    # Produce a value derived from the partition key being materialized,
    # prefixed with "key-".
    partition_key = context.asset_partition_key_for_output()
    return "key-" + partition_key
@asset(
    partitions_def=downstream_parts,
    ins={
        "up": AssetIn(
            "a",
            partition_mapping=CoalescingPartitionMapping(key_mapper),
        )
    },
)
def b(context, up):
    # `up` holds whatever was loaded for the partitions of `a` that key_mapper
    # maps onto this run's partition. Its exact shape presumably depends on
    # the IO manager (e.g. one value per upstream partition) -- TODO confirm.
    # It is passed through unchanged.
    loaded_upstream = up
    return loaded_upstream
# Definitions for this code location: both assets, plus an asset
# reconciliation sensor whose selection covers every asset.
defs = Definitions(
    assets=[a, b],
    sensors=[build_asset_reconciliation_sensor(AssetSelection.all())],
)
if __name__ == "__main__":
for upstream_key in ("p_1", "p_2", "p_3"):
materialize_to_memory([a], partition_key=upstream_key)
r = materialize_to_memory(
[SourceAsset("a", partitions_def=upstream_parts), b], partition_key="p"
)
print(r.output_for_node("b"))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment