@sararob
Last active August 21, 2022 11:09
This Cloud Function kicks off a Vertex Pipelines run whenever a specified amount of new BigQuery data is available for training. Deploy it as an HTTP function and set up a Cloud Scheduler job to invoke it on a recurring schedule. See this blog post for details: https://cloud.google.com/blog/topics/developers-practitioners/lets-get-it-s…
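As a rough illustration of the trigger, the entrypoint below expects a JSON body with `bq_dataset` and `bq_table` keys. A Cloud Scheduler job (or a manual test) would POST something like the following sketch; the function URL, dataset, and table names here are placeholders, not part of the gist, and the call assumes the caller is allowed to invoke the function:

import requests

# Hypothetical trigger URL for the deployed function; replace with your own.
FUNCTION_URL = "https://us-central1-your-project.cloudfunctions.net/check_table_size"

resp = requests.post(
    FUNCTION_URL,
    json={"bq_dataset": "your_dataset", "bq_table": "your_training_table"},
)
print(resp.status_code, resp.text)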
# Copyright 2021 Google LLC.
# SPDX-License-Identifier: Apache-2.0
import json
import time

from google.cloud import bigquery
from google.cloud.exceptions import NotFound
from kfp.v2.google.client import AIPlatformClient

client = bigquery.Client()

RETRAIN_THRESHOLD = 1000  # Change this based on your use case

def insert_bq_data(table_id, num_rows):
    """Records the current row count and timestamp in the `count` table."""
    rows_to_insert = [
        {"num_rows_last_retraining": num_rows, "last_retrain_time": time.time()}
    ]
    errors = client.insert_rows_json(table_id, rows_to_insert)
    if errors == []:
        print("New rows have been added.")
    else:
        print(f"Encountered errors while inserting rows: {errors}")

def create_count_table(table_id, num_rows):
    """Creates the `count` table that tracks the row count at each pipeline run."""
    schema = [
        bigquery.SchemaField("num_rows_last_retraining", "INTEGER", mode="REQUIRED"),
        bigquery.SchemaField("last_retrain_time", "TIMESTAMP", mode="REQUIRED"),
    ]
    table = bigquery.Table(table_id, schema=schema)
    table = client.create_table(table)
    print(f"Created table {table.project}.{table.dataset_id}.{table.table_id}")
    insert_bq_data(table_id, num_rows)

def create_pipeline_run():
    """Submits the compiled pipeline spec to Vertex Pipelines."""
    print('Kicking off a pipeline run...')
    REGION = "us-central1"  # Change this to the region you want to run in
    api_client = AIPlatformClient(
        project_id=client.project,
        region=REGION,
    )
    try:
        response = api_client.create_run_from_job_spec(
            "compiled_pipeline.json",
            pipeline_root="gs://your-gcs-bucket/pipeline_root/",  # Change this to your bucket
            parameter_values={"project": client.project, "display_name": "pipeline_gcf_trigger"},
        )
        return response
    except Exception:
        print("Error trying to run the pipeline")
        raise

# This should be the entrypoint for your Cloud Function
def check_table_size(request):
    """HTTP entrypoint: kicks off a pipeline run if enough new rows have been added."""
    request_data = request.get_data()
    try:
        request_json = json.loads(request_data.decode())
    except ValueError as e:
        print(f"Error decoding JSON: {e}")
        return "JSON Error", 400

    if request_json and 'bq_dataset' in request_json:
        dataset = request_json['bq_dataset']
        table = request_json['bq_table']
        data_table = client.get_table(f"{client.project}.{dataset}.{table}")
        current_rows = data_table.num_rows
        print(f"{table} table has {current_rows} rows")

        # See if the `count` table exists in this dataset
        try:
            client.get_table(f"{client.project}.{dataset}.count")
            print("Count table exists, querying to see how many rows at last pipeline run")
        except NotFound:
            print("No count table found, creating one...")
            create_count_table(f"{client.project}.{dataset}.count", current_rows)
            return f"Created count table with a baseline of {current_rows} rows"

        query_job = client.query(
            f"""
            SELECT num_rows_last_retraining FROM `{client.project}.{dataset}.count`
            ORDER BY last_retrain_time DESC
            LIMIT 1"""
        )
        results = query_job.result()
        for row in results:
            last_retrain_count = row[0]

        rows_added_since_last_pipeline_run = current_rows - last_retrain_count
        print(f"{rows_added_since_last_pipeline_run} rows have been added since we last ran the pipeline")

        if rows_added_since_last_pipeline_run >= RETRAIN_THRESHOLD:
            create_pipeline_run()
            insert_bq_data(f"{client.project}.{dataset}.count", current_rows)
            return f"Kicked off a pipeline run at {current_rows} rows"
        return f"Only {rows_added_since_last_pipeline_run} new rows since the last run, below the retrain threshold"
    else:
        return "No BigQuery data given", 400
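For a quick local smoke test of the entrypoint, note that the only method it calls on `request` is `get_data()`, so any object exposing that will do. A minimal sketch, assuming application default credentials with BigQuery access and a hypothetical dataset and table that already exist:

# FakeRequest and the dataset/table names below are hypothetical stand-ins for testing.
class FakeRequest:
    def __init__(self, payload):
        self._body = json.dumps(payload).encode()

    def get_data(self):
        return self._body

print(check_table_size(FakeRequest({"bq_dataset": "my_dataset", "bq_table": "my_table"})))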