Skip to content

Instantly share code, notes, and snippets.

@sam-goodwin
Last active July 2, 2024 05:41
Show Gist options
  • Save sam-goodwin/d8dd76ad58a241cdb14deba9cb53c2bf to your computer and use it in GitHub Desktop.
Backfill stale partitioned assets
import json
from collections import defaultdict
from dataclasses import dataclass
from typing import Any
import requests
from pydantic import BaseModel
from typing_extensions import TypedDict
# Base URL of the Dagster webserver; adjust when not running locally.
dagster_endpoint = "http://localhost:3000"
# All queries/mutations below are POSTed to this GraphQL endpoint.
dagster_graphql_url = f"{dagster_endpoint}/graphql"
class GraphQLError(Exception): ...
class AssetKey(TypedDict):
    """Shape of an asset key as exchanged with the Dagster GraphQL API."""

    # Path components of the asset key; may have more than one element
    # (see the multi-component TODO where backfills are launched).
    path: list[str]
@dataclass
class AssetPartitionStatus:
    """Pairs an asset with a set of its partition names.

    NOTE(review): not referenced anywhere in this file — kept in case external
    code imports it; confirm before removing.
    """

    # The asset this status refers to.
    asset_key: AssetKey
    # Partition names associated with the asset.
    partitions: list[str]
class MaterializeBatchRequest(BaseModel):
    """One backfill request: materialize ``partitions`` for all ``asset_keys``."""

    # Assets that share an identical set of stale partitions.
    asset_keys: list[AssetKey]
    # Partition names to materialize for every asset above.
    partitions: list[str]
StaleAssetsByPartition = dict[str, list[AssetKey]]
def materialize_stale_partitions(group_name: str):
    """Find stale/missing partitions in *group_name* and launch backfills for them.

    Returns the list of raw backfill responses, or ``None`` when nothing in the
    group is stale.
    """
    print(f"Materializing stale partitions for group '{group_name}'")
    # Fetch stale partitions directly
    stale = fetch_assets_data(group_name)
    if not stale:
        print(f"No stale partitions found in group {group_name}")
        return None
    print(f"Materializing assets:\n{json.dumps(stale, indent=2)}")
    return launch_backfill(
        stale,
        title=f"Backfill for stale partitions in group {group_name}",
        description=f"Materializing stale partitions for assets in group {group_name}",
    )
def launch_backfill(stale_assets_by_partition: dict[str, list[AssetKey]], title: str, description: str):
    """Launch one Dagster backfill per batch of (asset set, partitions).

    Batches come from :func:`group_partitions_into_batches`; each becomes a
    single ``launchPartitionBackfill`` mutation. Returns the raw GraphQL
    responses, one per launched backfill.
    """
    mutation = """
    mutation LaunchBackfillForAssets($backfillParams: LaunchBackfillParams!) {
        launchPartitionBackfill(backfillParams: $backfillParams) {
            __typename
            ... on LaunchBackfillSuccess {
                backfillId
                launchedRunIds
            }
            ... on PythonError {
                message
                stack
            }
            ... on UnauthorizedError {
                message
            }
            ... on InvalidSubsetError {
                message
            }
            ... on PartitionSetNotFoundError {
                message
            }
        }
    }
    """
    results = []
    for index, request in enumerate(group_partitions_into_batches(stale_assets_by_partition), start=1):
        print(f"Launching backfill for batch {index}:\n{json.dumps(request.model_dump(), indent=2)}")
        response = execute_query(
            mutation,
            {
                "backfillParams": {
                    # TODO(sgoodwin): handle asset keys with multiple path components
                    "assetSelection": request.asset_keys,
                    "partitionNames": request.partitions,
                    "fromFailure": False,
                    "title": f"{title} - Batch {index}",
                    "description": f"{description} - Batch {index}",
                },
            },
        )
        backfill_id = response["data"]["launchPartitionBackfill"]["backfillId"]
        # Print a direct link to the backfill in the Dagster UI.
        print(f"{dagster_endpoint}/overview/backfills/{backfill_id}")
        results.append(response)
    return results
def group_partitions_into_batches(stale_assets_by_partitions: StaleAssetsByPartition) -> list[MaterializeBatchRequest]:
    """Group partitions that share an identical set of stale assets into batches.

    Partitions whose stale-asset sets match can be materialized together in one
    Dagster backfill, so we bucket partitions by their asset set.

    Args:
        stale_assets_by_partitions: partition name -> stale asset keys.

    Returns:
        One :class:`MaterializeBatchRequest` per distinct asset set.
    """
    # { sorted tuple of asset-key path tuples: list[partition] }.
    # Keying on tuples (instead of joining paths with ":" / "/") cannot collide
    # when a path component happens to contain the separator character, which
    # resolves the original "is `:` a safe separator?" TODO.
    partitions_by_asset_graph: dict[tuple[tuple[str, ...], ...], list[str]] = defaultdict(list)
    for partition, asset_list in stale_assets_by_partitions.items():
        # Sort the asset keys to ensure determinism of the grouping key.
        group_key = tuple(sorted(tuple(asset_key["path"]) for asset_key in asset_list))
        partitions_by_asset_graph[group_key].append(partition)
    return [
        MaterializeBatchRequest(
            asset_keys=[AssetKey(path=list(path)) for path in group_key],
            partitions=partitions,
        )
        for group_key, partitions in partitions_by_asset_graph.items()
    ]
def fetch_assets_data(group_name: str) -> StaleAssetsByPartition:
    """Return a mapping of partition name -> asset keys that are stale/missing.

    Issues one query for the group's assets, then one follow-up request per
    asset for its per-partition staleness.
    """
    # TODO(sgoodwin): can we use staleStatusByPartition in this query instead of executing follow up requests?
    # NOTE(review): repository/location names are hard-coded for this deployment.
    query = """
    query AssetsByGroup($groupName: String!) {
        assetNodes(group: {
            groupName: $groupName,
            repositoryName:"__repository__",
            repositoryLocationName:"noetik_pipeline_methods.defs"
        }) {
            id
            assetKey {
                path
            }
            partitionDefinition {
                name
            }
            partitionKeys
        }
    }
    """
    data = execute_query(query, {"groupName": group_name})
    stale_by_partition: defaultdict[str, list[AssetKey]] = defaultdict(list)
    for node in data["data"]["assetNodes"]:
        key = node["assetKey"]
        partitions = node["partitionKeys"] or []
        # Fetch stale statuses for each asset.
        statuses = fetch_stale_statuses(group_name, key, partitions)
        for partition_name, status in zip(partitions, statuses, strict=False):
            if status in ("MISSING", "STALE"):
                stale_by_partition[partition_name].append(key)
    return dict(stale_by_partition)
def fetch_stale_statuses(group_name: str, asset_key: AssetKey, partition_keys: list[str]) -> list[str]:
    """Fetch stale statuses for specific partitions of an asset."""
    if not partition_keys:
        return []
    query = """
    query AssetStaleStatus($groupName: String!, $assetKey: AssetKeyInput!, $partitionKeys: [String!]!) {
        assetNodes(group: {
            groupName: $groupName,
            repositoryName:"__repository__",
            repositoryLocationName:"noetik_pipeline_methods.defs"
        }, assetKeys: [$assetKey]) {
            id
            staleStatusByPartition(partitions: $partitionKeys)
        }
    }
    """
    data = execute_query(
        query,
        {
            "groupName": group_name,
            "assetKey": asset_key,
            "partitionKeys": partition_keys,
        },
    )
    nodes = data["data"]["assetNodes"]
    if not nodes or not nodes[0]["staleStatusByPartition"]:
        # Default to UNKNOWN if no status is returned.
        return ["UNKNOWN"] * len(partition_keys)
    return nodes[0]["staleStatusByPartition"]
def execute_query(query: str, variables: dict[str, Any]) -> Any:
    """POST a GraphQL query/mutation to Dagster and return the decoded JSON.

    Args:
        query: GraphQL document text.
        variables: values for the document's variables.

    Raises:
        GraphQLError: on a non-2xx HTTP status, or when the response body
            carries a top-level ``errors`` array — GraphQL servers commonly
            return HTTP 200 even for failed operations, and without this check
            callers would crash later with a KeyError on the missing "data".
    """
    response = requests.post(
        dagster_graphql_url,
        json={
            "query": query,
            "variables": variables,
        },
        timeout=30,
    )
    if not response.ok:
        raise GraphQLError(f"GraphQL request failed with status code {response.status_code}: {response.text}")
    payload = response.json()
    if isinstance(payload, dict) and payload.get("errors"):
        raise GraphQLError(f"GraphQL query returned errors: {json.dumps(payload['errors'])}")
    return payload
# Example usage: materialize everything stale in the "test" asset group.
if __name__ == "__main__":
    group_name = "test"
    materialize_stale_partitions(group_name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment