Skip to content

Instantly share code, notes, and snippets.

@DavidKatz-il
Last active February 19, 2023 07:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save DavidKatz-il/e2caf17285f8ef2d4dd6e70beb8186b0 to your computer and use it in GitHub Desktop.
Save DavidKatz-il/e2caf17285f8ef2d4dd6e70beb8186b0 to your computer and use it in GitHub Desktop.
API client for airflow
from dataclasses import dataclass
from typing import Any, List, Optional, Dict, Tuple, Union
from urllib.parse import urlencode
import urllib3
import json
# Page size sent as the "limit" query parameter on list requests;
# the default on airflow api is 100.
# NOTE(review): the server may cap this value through its own settings - confirm.
LIMIT = 10_000
@dataclass
class Configuration:
    """
    Configuration to use with an airflow api client

    :param host: base url of the api, request paths are appended to it
        as "<host>/<path>" (e.g. "http://localhost:8080/api/v1")
    :param username: user for basic auth, None to disable authentication
    :param password: password for basic auth, None to disable authentication
    """

    host: str
    username: Optional[str]
    password: Optional[str]

    def auth_settings(self) -> dict:
        """
        Get auth settings dict for airflow api client
        :return: dict of auth settings keyed by scheme name ("Basic"),
            empty when username or password is not set
        """
        # Without a complete credential pair there is nothing to advertise.
        if self.username is None or self.password is None:
            return {}

        authorization = urllib3.util.make_headers(
            basic_auth=f"{self.username}:{self.password}"
        ).get("authorization")
        return {
            "Basic": {
                "type": "basic",
                "in": "header",
                "key": "Authorization",
                "value": authorization,
            }
        }
class AirflowClientAPI:
    """
    API client for airflow

    Every public method builds a request, sends it through ``_request``
    and returns the decoded json response (or a part of it).
    """

    def __init__(self, configuration: "Configuration"):
        """
        Initialize an AirflowClientAPI instance
        :param configuration: airflow configurations
        """
        self.configuration = configuration
        self.pool_manager = urllib3.PoolManager()
        self.default_headers = {
            "Content-Type": "application/json",
            "User-Agent": "OpenAPI-Generator/2.4.0/python",
        }

    def _request(
        self,
        method: str,
        url: str,
        headers: Optional[Dict[str, str]] = None,
        query_params: Optional[Dict[str, Union[int, str, list, tuple]]] = None,
        fields: Optional[Dict[str, Any]] = None,
        body: Optional[Any] = None,
        auth: str = "Basic",
    ) -> Any:
        """
        Send a request to the api and decode the json response
        :param method: the http method
        :param url: the path relative to the configured host
        :param headers: extra headers, they take precedence over the defaults
        :param query_params: parameters to encode into the query string
        :param fields: fields passed as is to the pool manager
        :param body: json serializable body, None to send no body
        :param auth: name of the auth scheme to take from the configuration
        :return: the decoded response, None when the response has no content
        :raises Exception: when the response status is not a 2xx
        """
        # Start from the defaults so explicit headers win, and build a new
        # dict so the caller's one is never mutated.
        request_headers = dict(self.default_headers)
        request_headers.update(headers or {})

        # The configuration returns no entry when credentials are missing;
        # in that case the request is sent unauthenticated.
        auth_setting = self.configuration.auth_settings().get(auth)
        if auth_setting:
            request_headers[auth_setting["key"]] = auth_setting["value"]

        if query_params:
            url += "?" + urlencode(self.__parameters_to_tuples(query_params))

        # "is not None" (not truthiness) so an empty dict is sent as "{}".
        payload = json.dumps(body) if body is not None else None

        response = self.pool_manager.request(
            method,
            f"{self.configuration.host}/{url}",
            fields=fields or {},
            headers=request_headers,
            body=payload,
        )
        if not 200 <= response.status < 300:
            raise Exception(f"request failed with status: {response.status}")
        if not response.data:
            return None
        return json.loads(response.data)

    @staticmethod
    def __parameters_to_tuples(
        params: Dict[str, Union[int, str, list, tuple]]
    ) -> List[Tuple]:
        """
        Flatten query parameters to (key, value) tuples
        :param params: parameters, a list/tuple value is repeated per item
        :return: list of (key, value) tuples in insertion order
        :raises TypeError: when a value has an unsupported type
        """
        new_params = []
        for k, v in params.items():
            if isinstance(v, (list, tuple)):
                new_params.extend((k, value) for value in v)
            elif isinstance(v, (int, str)):
                # bool is a subclass of int, so flags are accepted here too
                new_params.append((k, v))
            else:
                raise TypeError(
                    f"key: {k} has a type: {type(v)} that is not supported."
                )
        return new_params

    def _get_all_pages(
        self,
        url: str,
        collection_key: str,
        query_params: Optional[Dict[str, Union[int, str, list, tuple]]] = None,
    ) -> List[Dict]:
        """
        Collect every entry of a paginated list endpoint
        :param url: the path relative to the configured host
        :param collection_key: key of the list in the response
        :param query_params: filters to send with every page request
        :return: the entries of all the pages
        """
        params = dict(query_params or {})
        params.setdefault("limit", LIMIT)
        entries: List[Dict] = []
        while True:
            params["offset"] = len(entries)
            page = self._request("GET", url, query_params=params)
            batch = page[collection_key]
            entries.extend(batch)
            total = page.get("total_entries")
            # Stop on an empty page, when the total is unknown (single page
            # behaviour) or once everything was collected.
            if not batch or total is None or len(entries) >= total:
                return entries

    def _get_dags(
        self,
        dag_id_pattern: Optional[str] = None,
        tags: Optional[List[str]] = None,
        only_active: Optional[bool] = True,
    ) -> List[Dict]:
        """
        Get all dags info
        :param dag_id_pattern: the dag_id pattern
        :param tags: tags to filter on
        :param only_active: filter on only_active or not, None to not send it
        :return: all dags info
        """
        query_params = {}
        if only_active is not None:
            query_params["only_active"] = only_active
        if dag_id_pattern:
            query_params["dag_id_pattern"] = dag_id_pattern
        if tags:
            query_params["tags"] = tags
        return self._get_all_pages("dags", "dags", query_params)

    def get_all_dag_ids(
        self,
        dag_id_pattern: Optional[str] = None,
        tags: Optional[List[str]] = None,
        only_active: Optional[bool] = True,
    ) -> List[str]:
        """
        Get all dag_id's
        :param dag_id_pattern: the dag_id pattern
        :param tags: tags to filter on
        :param only_active: filter on only_active or not
        :return: all matched dag_id's
        """
        dags = self._get_dags(dag_id_pattern, tags, only_active)
        return [dag["dag_id"] for dag in dags]

    def _get_dag_runs(self, dag_id: str) -> List[Dict]:
        """
        Get all dag runs of a given dag_id
        :param dag_id: the dag_id
        :return: all dag runs
        """
        return self._get_all_pages(f"dags/{dag_id}/dagRuns", "dag_runs")

    def get_last_dag_run_id(self, dag_id: str) -> str:
        """
        Get the last dag_run_id of a given dag_id
        :param dag_id: the dag_id
        :return: the dag_run_id of the run with the greatest start_date
        :raises ValueError: when the dag has no dag runs
        """
        dag_runs = self._get_dag_runs(dag_id)
        if not dag_runs:
            raise ValueError(f"no dag runs found for dag_id: {dag_id}")
        # A run without a start_date sorts first instead of breaking the
        # comparison between None and str.
        last_dag_run = max(dag_runs, key=lambda run: run["start_date"] or "")
        return last_dag_run["dag_run_id"]

    def get_task_instances(
        self, dag_id: str, dag_run_id: str, state: Optional[List[str]] = None
    ) -> List[Dict]:
        """
        Get a list of task instances
        :param dag_id: the dag_id
        :param dag_run_id: the dag_run_id
        :param state: list of state to filter
        :return: list of task instances
        """
        query_params = {"limit": LIMIT}
        if state:
            query_params["state"] = state
        return self._get_all_pages(
            f"dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances",
            "task_instances",
            query_params,
        )

    def unpause_dags(
        self, dag_id_pattern: str, tags: Optional[List[str]] = None
    ) -> dict:
        """
        Unpause all the dags matching a pattern
        :param dag_id_pattern: dag_id pattern
        :param tags: list of tags to filter
        :return: the request response data
        """
        body = {"is_paused": False}
        query_params = {"update_mask": "is_paused", "dag_id_pattern": dag_id_pattern}
        if tags:
            query_params["tags"] = tags
        return self._request("PATCH", "dags", query_params=query_params, body=body)

    def trigger_dag(
        self, dag_id: str, conf: Optional[Dict] = None
    ) -> Dict[str, Union[int, List[Dict]]]:
        """
        Trigger a dag
        :param dag_id: the dag_id
        :param conf: conf to run the dag with
        :return: the request response data
        """
        # Always send a json object, an empty one when there is no conf, so
        # the request never goes out with a json content type and no body.
        body = {"conf": conf} if conf else {}
        return self._request("POST", f"dags/{dag_id}/dagRuns", body=body)

    def clear_tasks(
        self,
        dag_id: str,
        dag_run_id: str,
        task_ids: List[str],
        include_downstream: Optional[bool] = True,
    ) -> Dict[str, List[Dict]]:
        """
        Clear the state of the given tasks for a specific dag run
        (the request is sent with "only_failed" set to True)
        :param dag_id: the dag_id
        :param dag_run_id: the dag_run_id
        :param task_ids: list of task_id to clear
        :param include_downstream: to include downstream default is True
        :return: the request response data
        """
        body = {
            "dry_run": False,
            "dag_run_id": dag_run_id,
            "task_ids": task_ids,
            "only_failed": True,
            "include_downstream": include_downstream,
        }
        return self._request("POST", f"dags/{dag_id}/clearTaskInstances", body=body)

    def extract_task_id(self, ld: List[Dict]) -> List[str]:
        """
        Extract the task ids from a list of dicts with the key "task_id"
        :param ld: list of dicts
        :return: list of task_id
        """
        return self.__get_values_from_list_of_dicts(ld, "task_id")

    @staticmethod
    def __get_values_from_list_of_dicts(ld: List[Dict], key: str) -> List[str]:
        """
        Get all key values from a list of dicts
        :param ld: list of dicts
        :param key: key name
        :return: list of the values
        """
        return [d[key] for d in ld]
def main():
    """
    Example usage of the client: unpause a dag, trigger it with a conf and
    clear the failed tasks (with their downstream) of a given dag run.
    """
    client = AirflowClientAPI(
        Configuration(
            host="http://<<URL>>:8080/api/v1",
            username="<<USERNAME>>",
            password="<<PASSWORD>>",
        )
    )
    target_dag = "<<DAG-ID>>"
    target_run = "<<DAG-RUN-ID>>"

    # unpause a dag
    client.unpause_dags(target_dag)

    # trigger a dag with conf
    client.trigger_dag(target_dag, conf={"key": "value"})

    # clear failed tasks with downstream
    failed = client.get_task_instances(target_dag, target_run, state=["failed"])
    failed_task_ids = client.extract_task_id(failed)
    client.clear_tasks(target_dag, target_run, failed_task_ids)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment