Skip to content

Instantly share code, notes, and snippets.

@gallir
Last active December 25, 2021 21:02
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gallir/6cabc4580865da71ef4c0725e8fb1b13 to your computer and use it in GitHub Desktop.
Save gallir/6cabc4580865da71ef4c0725e8fb1b13 to your computer and use it in GitHub Desktop.
A basic pipeline for AWS Forecast
from botocore.hooks import _PrefixTrie
import util
import boto3
from .s3utils import upload_csv
import time
import pprint
class Pipeline:
    """Drive an Amazon Forecast workflow end to end for one named experiment.

    The pipeline uploads the target (and optional related) time series to S3,
    then finds-or-creates, in order: the Forecast datasets, a dataset group,
    the dataset import jobs, a predictor, a forecast and (via :meth:`export`)
    a forecast export job. Every resource name is derived from ``name``, so a
    re-run reuses what already exists instead of creating duplicates, and
    :meth:`delete_all` removes whatever ARNs this instance has collected.

    ``target`` and ``related`` are project data objects; from their use below
    they must expose ``name``, ``get_df()``, ``get_schema()`` and
    ``timestamp_format`` (assumption drawn from usage -- confirm against their
    class).
    """

    # Statuses after which a Forecast resource no longer changes on its own,
    # so polling can stop (CREATE_STOPPED: the job was stopped by a user).
    _TERMINAL_STATUSES = ('ACTIVE', 'CREATE_FAILED', 'CREATE_STOPPED')
    # Seconds between two status polls of a long-running Forecast job.
    _POLL_SECONDS = 10

    def __init__(self, name, target, freq, horizon, s3_bucket, related=None, domain="RETAIL",
                 s3_prefix="forecast", role_name="ForecastExecution", aws_region="eu-west-1",
                 algorithm="arn:aws:forecast:::algorithm/CNN-QR", autoML=False):
        """Store the configuration and open boto3 Forecast and S3 clients.

        Args:
            name: Base name; every Forecast resource name and S3 key derives from it.
            target: Target time-series data object (see class docstring).
            freq: Data/forecast frequency string passed to Forecast.
            horizon: Number of ``freq`` steps to forecast; also used as the
                backtest window offset.
            s3_bucket: Bucket the CSVs are uploaded to and exported into.
            related: Optional related time-series data object.
            domain: Forecast dataset / dataset-group domain.
            s3_prefix: Key prefix under which inputs and outputs are stored.
            role_name: IAM role Forecast uses to access S3.
            aws_region: Region of the boto3 session.
            algorithm: Algorithm ARN; only used when ``autoML`` is false.
            autoML: Let Forecast choose the algorithm itself.
        """
        self.name = name
        self.target = target
        self.freq = freq
        self.horizon = horizon
        self.s3_bucket = s3_bucket
        self.related = related
        self.domain = domain
        self.s3_prefix = s3_prefix
        self.role_name = role_name
        self.aws_region = aws_region
        self.algorithm = algorithm
        self.autoML = autoML
        self.key_target = f"{s3_prefix}/{name}/target/{target.name}.csv"
        self.key_related = f"{s3_prefix}/{name}/related/{related.name}.csv" if related else None
        self.session = boto3.Session(region_name=self.aws_region)
        self.forecast = self.session.client(service_name='forecast')
        self.s3 = self.session.client('s3')
        # ARNs of the resources found or created by the pipeline steps.
        self.role_arn = None
        self.ds_arns = []          # target dataset first, then related (if any)
        self.dg_arn = None
        self.ds_import_arns = []   # target import first, then related (if any)
        self.predictor_arn = None
        self.forecast_arn = None
        self.export_arn = None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _list_all(list_method, result_key):
        """Return every item of a paginated Forecast ``list_*`` call.

        Follows ``NextToken`` until the service stops returning one, so
        resources beyond the first page are not missed.
        """
        items = []
        kwargs = {}
        while True:
            resp = list_method(**kwargs)
            items.extend(resp.get(result_key, []))
            token = resp.get('NextToken')
            if not token:
                return items
            kwargs['NextToken'] = token

    @staticmethod
    def _find_arn(items, name_field, arn_field, wanted):
        """Return ``item[arn_field]`` of the first item named ``wanted``, else None."""
        return next((item[arn_field] for item in items if item[name_field] == wanted), None)

    def _wait(self, label, describe):
        """Poll ``describe()`` until its ``Status`` is terminal; return the last response.

        Progress is printed every ``_POLL_SECONDS`` seconds. A terminal status
        other than ACTIVE is reported together with the service ``Message``
        (if present) rather than being ignored silently.
        """
        wait_start = time.time()
        while True:
            info = describe()
            status = info['Status']
            if status in self._TERMINAL_STATUSES:
                if status != 'ACTIVE':
                    print(f"{label} ended with status {status}: {info.get('Message', '')}")
                return info
            print(f"Waiting for {label}, {int(time.time()-wait_start)} secs: {status}")
            time.sleep(self._POLL_SECONDS)

    # ---------------------------------------------------------------- pipeline

    def run(self, create=True, dont_upload=False, delete_after=False, delete_all=False):
        """Execute the whole pipeline (everything except :meth:`export`).

        Args:
            create: Create missing resources and wait for them; when false,
                existing resources are only looked up.
            dont_upload: Skip uploading the CSVs to S3.
            delete_after: Delete all collected resources at the end.
            delete_all: Cleanup-only mode: forces ``create=False``, skips the
                upload, collects the existing resources and deletes them.
        """
        if delete_all:
            create = False
        self.set_role(create=create)
        if not dont_upload and not delete_all:
            self.upload_data()
        self.set_datasets(create=create)
        self.set_dataset_group(create=create)
        self.do_import(create=create)
        self.set_predictor(create=create)
        self.do_forecast(create=create)
        if delete_after or delete_all:
            self.delete_all()

    def set_role(self, create=True):
        """Resolve the IAM role ARN via the project ``util`` helper.

        When ``create`` is false nothing is done and ``role_arn`` is unchanged.
        """
        if not create:
            return
        self.role_arn = util.get_or_create_iam_role(role_name=self.role_name)

    def upload_data(self):
        """Upload the target (and related, if any) CSVs to ``s3_bucket``."""
        print("Uploading data to", self.s3_bucket)
        upload_csv(self.target.get_df(), self.s3_bucket, self.key_target, s3_client=self.s3)
        if self.related:
            upload_csv(self.related.get_df(), self.s3_bucket, self.key_related, s3_client=self.s3)

    def set_datasets(self, create=True):
        """Find or create the target (and related) Forecast datasets.

        ``self.ds_arns`` is reset and refilled in target-then-related order
        with the datasets that exist or were created, so repeated calls do not
        accumulate duplicate ARNs.
        """
        self.ds_arns = []
        datasets = self._list_all(self.forecast.list_datasets, 'Datasets')
        self._find_or_create_dataset(datasets, "target", 'TARGET_TIME_SERIES', self.target, create)
        if self.related:
            self._find_or_create_dataset(datasets, "related", 'RELATED_TIME_SERIES',
                                         self.related, create)

    def _find_or_create_dataset(self, datasets, kind, dataset_type, data, create):
        """Append the ARN of dataset ``<name>_<kind>_ds``, creating it when allowed."""
        ds_name = f"{self.name}_{kind}_ds"
        print(f"Processing {kind} dataset", ds_name)
        found = self._find_arn(datasets, 'DatasetName', 'DatasetArn', ds_name)
        if found:
            self.ds_arns.append(found)
        elif create:
            response = self.forecast.create_dataset(
                Domain=self.domain,
                DatasetType=dataset_type,
                DatasetName=ds_name,
                DataFrequency=self.freq,
                Schema=data.get_schema(),
            )
            self.ds_arns.append(response['DatasetArn'])
            print(f"Created {kind} dataset", response['DatasetArn'])

    def set_dataset_group(self, create=True):
        """Find or create the dataset group ``<name>_dg`` holding ``self.ds_arns``."""
        dg_name = f"{self.name}_dg"
        print("Processing Dataset Group", dg_name)
        groups = self._list_all(self.forecast.list_dataset_groups, 'DatasetGroups')
        found = self._find_arn(groups, 'DatasetGroupName', 'DatasetGroupArn', dg_name)
        if found:
            self.dg_arn = found
        elif create:
            response = self.forecast.create_dataset_group(DatasetGroupName=dg_name,
                                                          Domain=self.domain,
                                                          DatasetArns=self.ds_arns)
            self.dg_arn = response['DatasetGroupArn']
            print("Created Dataset Group", self.dg_arn)

    def do_import(self, create=True):
        """Find or start the dataset import jobs, then wait for them when ``create``.

        ``self.ds_import_arns`` is reset and refilled. The related import is
        only considered when related data is configured (as in
        :meth:`set_datasets`), and the wait covers however many jobs there
        are, so it terminates with or without related data.
        """
        self.ds_import_arns = []
        jobs = self._list_all(self.forecast.list_dataset_import_jobs, 'DatasetImportJobs')
        self._find_or_create_import(jobs, "target", 0, self.key_target, self.target, create)
        if self.related:
            self._find_or_create_import(jobs, "related", 1, self.key_related, self.related, create)
        wait_start = time.time()
        while create and self.ds_import_arns:
            statuses = [
                self.forecast.describe_dataset_import_job(DatasetImportJobArn=arn)['Status']
                for arn in self.ds_import_arns
            ]
            if all(status in self._TERMINAL_STATUSES for status in statuses):
                break
            print(f"Waiting for import jobs, {int(time.time()-wait_start)} secs: {statuses}")
            time.sleep(self._POLL_SECONDS)

    def _find_or_create_import(self, jobs, kind, ds_index, key, data, create):
        """Record the ARN of import job ``<kind>_<name>``, starting it when allowed.

        ``ds_index`` selects the destination dataset in ``self.ds_arns``
        (0 = target, 1 = related); it is only read when a job is created.
        """
        job_name = f"{kind}_{self.name}"
        print(f"Importing {kind} data", job_name)
        job_arn = self._find_arn(jobs, 'DatasetImportJobName', 'DatasetImportJobArn', job_name)
        if job_arn:
            print(f"Skipping existing {kind} import: {job_arn}")
        elif create:
            resp = self.forecast.create_dataset_import_job(
                DatasetImportJobName=job_name,
                DatasetArn=self.ds_arns[ds_index],
                DataSource={
                    "S3Config": {
                        "Path": f"s3://{self.s3_bucket}/{key}",
                        "RoleArn": self.role_arn,
                    }
                },
                TimestampFormat=data.timestamp_format,
            )
            job_arn = resp['DatasetImportJobArn']
            print(f"Created {kind} import: {job_arn}")
        if job_arn:
            self.ds_import_arns.append(job_arn)

    def set_predictor(self, create=True):
        """Find or train the predictor ``<name>_predictor`` and wait for it when ``create``.

        With ``autoML`` the algorithm ARN, HPO flag and training parameters are
        omitted, because CreatePredictor disallows them when AutoML is enabled.
        """
        predictor_name = self.name + '_predictor'
        print("Creating predictor", predictor_name)
        predictors = self._list_all(self.forecast.list_predictors, 'Predictors')
        found = self._find_arn(predictors, 'PredictorName', 'PredictorArn', predictor_name)
        if found:
            self.predictor_arn = found
            print(f"Skipping existing predictor creation: {self.predictor_arn}")
        if not self.predictor_arn and create:
            params = dict(
                PredictorName=predictor_name,
                ForecastHorizon=self.horizon,
                PerformAutoML=self.autoML,
                EvaluationParameters={
                    "NumberOfBacktestWindows": 1,
                    "BackTestWindowOffset": self.horizon,
                },
                InputDataConfig={"DatasetGroupArn": self.dg_arn},
                FeaturizationConfig={"ForecastFrequency": self.freq},
            )
            if not self.autoML:
                params.update(
                    AlgorithmArn=self.algorithm,
                    PerformHPO=False,
                    TrainingParameters={'use_related_data': 'ALL'},
                )
            resp = self.forecast.create_predictor(**params)
            self.predictor_arn = resp['PredictorArn']
        pred_info = {}
        if create and self.predictor_arn:
            pred_info = self._wait(
                "predictor",
                lambda: self.forecast.describe_predictor(PredictorArn=self.predictor_arn))
            if 'AlgorithmArn' in pred_info:
                alg = pred_info['AlgorithmArn']
            elif 'AutoMLAlgorithmArns' in pred_info:
                alg = pred_info['AutoMLAlgorithmArns']
            else:
                alg = "unknown"
            print(f"Ready, algorithm: {alg}, forecast types: {pred_info.get('ForecastTypes')}")
        if pred_info:
            pprint.pprint(pred_info)

    def do_forecast(self, create=True):
        """Find or generate the forecast ``<name>_forecast`` and wait for it when ``create``."""
        forecast_name = self.name + '_forecast'
        print("Executing forecast", forecast_name)
        forecasts = self._list_all(self.forecast.list_forecasts, 'Forecasts')
        found = self._find_arn(forecasts, 'ForecastName', 'ForecastArn', forecast_name)
        if found:
            self.forecast_arn = found
            print(f"Skipping existing forecast creation: {self.forecast_arn}")
        if not self.forecast_arn and create:
            resp = self.forecast.create_forecast(ForecastName=forecast_name,
                                                 PredictorArn=self.predictor_arn)
            self.forecast_arn = resp['ForecastArn']
        if create and self.forecast_arn:
            self._wait("forecast",
                       lambda: self.forecast.describe_forecast(ForecastArn=self.forecast_arn))

    def export(self, create=True):
        """Find or start the export job ``<name>_export`` and wait for it when ``create``.

        Results are written to ``s3://<bucket>/<s3_prefix>/<name>/output/``,
        next to the uploaded inputs. The job ARN is kept in ``self.export_arn``
        so :meth:`delete_all` can remove it.
        """
        export_name = self.name + '_export'
        print("Processing export", export_name)
        export_prefix = f"{self.s3_prefix}/{self.name}/output"
        export_path = f"s3://{self.s3_bucket}/{export_prefix}/"
        exports = self._list_all(self.forecast.list_forecast_export_jobs, 'ForecastExportJobs')
        found = self._find_arn(exports, 'ForecastExportJobName', 'ForecastExportJobArn',
                               export_name)
        if found:
            self.export_arn = found
            print(f"Skipping existing export creation: {self.export_arn}")
        if not self.export_arn and create:
            resp = self.forecast.create_forecast_export_job(
                ForecastExportJobName=export_name,
                ForecastArn=self.forecast_arn,
                Destination={
                    "S3Config": {
                        "Path": export_path,
                        "RoleArn": self.role_arn,
                    }
                },
            )
            self.export_arn = resp['ForecastExportJobArn']
        if create and self.export_arn:
            self._wait(
                "export",
                lambda: self.forecast.describe_forecast_export_job(
                    ForecastExportJobArn=self.export_arn))

    def delete_all(self):
        """Delete every resource whose ARN this instance collected, newest first.

        Order: export job, forecast, predictor, import jobs, datasets, dataset
        group. Each ARN is cleared once its deletion call returns, so a later
        call does not try to delete it again. The IAM role is not deleted.
        """
        if self.export_arn:
            print("Deleting", self.export_arn)
            util.wait_till_delete(lambda arn=self.export_arn: self.forecast.delete_forecast_export_job(
                ForecastExportJobArn=arn))
            self.export_arn = None
        if self.forecast_arn:
            print("Deleting", self.forecast_arn)
            util.wait_till_delete(
                lambda arn=self.forecast_arn: self.forecast.delete_forecast(ForecastArn=arn))
            self.forecast_arn = None
        if self.predictor_arn:
            print("Deleting", self.predictor_arn)
            util.wait_till_delete(
                lambda arn=self.predictor_arn: self.forecast.delete_predictor(PredictorArn=arn))
            self.predictor_arn = None
        for arn in self.ds_import_arns:
            print("Deleting", arn)
            # Bind arn as a default so the callback never sees a later loop value.
            util.wait_till_delete(
                lambda arn=arn: self.forecast.delete_dataset_import_job(DatasetImportJobArn=arn))
        self.ds_import_arns = []
        for arn in self.ds_arns:
            print("Deleting", arn)
            util.wait_till_delete(lambda arn=arn: self.forecast.delete_dataset(DatasetArn=arn))
        self.ds_arns = []
        if self.dg_arn:
            print("Deleting", self.dg_arn)
            util.wait_till_delete(
                lambda arn=self.dg_arn: self.forecast.delete_dataset_group(DatasetGroupArn=arn))
            self.dg_arn = None
        # The IAM role is intentionally kept; to remove it as well:
        # util.delete_iam_role(self.role_name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment