-
-
Save dq-hustlecoding/af8dead300d7560e732bb9e18bf647f9 to your computer and use it in GitHub Desktop.
Batch inference for Amazon Personalize
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def user_input(
    con=None,
    *,
    table_name="테이블 이름",
    bucket="your destination",
    key="path/to/batch_input.json",
) -> None:
    """Export the batch-inference input table to S3 as newline-delimited JSON.

    Reads ``table_name`` through ``con`` with ``pandas.read_sql_table``.
    It writes every row as one JSON object per line and uploads the text to
    ``s3://{bucket}/{key}``, which is the location the batch inference job
    reads from.

    Args:
        con: SQLAlchemy connectable or database URI string, passed straight to
            ``pd.read_sql_table``. It is required in practice. It defaults to
            ``None`` only so a zero-argument call still fails with TypeError,
            as the original ``read_sql_table`` call without ``con`` did.
        table_name: Table to export. The default is the original placeholder.
        bucket: Destination S3 bucket. The default is the original placeholder.
        key: Destination S3 object key. The default is the original placeholder.

    Raises:
        TypeError: if ``con`` is not supplied.
    """
    if con is None:
        raise TypeError(
            "user_input() requires `con`: a SQLAlchemy connectable or database URI"
        )
    user_df = pd.read_sql_table(table_name, con)

    # TODO: apply the same dtype preprocessing that the `etl_user` step
    # (defined elsewhere, not visible here) applies, so IDs serialise exactly
    # as they did in the training data.

    # `lines=True` emits one JSON object per line directly. Slicing and
    # string-replacing the `[...]` array output broke rows containing "},{".
    out = user_df.to_json(orient="records", lines=True)
    boto3.resource("s3").Object(bucket, key).put(Body=out.encode("utf-8"))
def batch_recommendation(
    *,
    role_arn="your role arn",
    input_path="s3://path/to/batch_input.json",
    output_path="s3://path/to/batch_result/",
    num_results=500,
    poll_interval=30,
    timeout=None,
) -> str:
    """Start an Amazon Personalize batch inference job and return its ARN.

    Selection rules:
        Takes the first dataset group, the first solution in it, and the first
        listed version of that solution. Reuses the first existing filter in
        the dataset group, or creates one if there is none.

    Waiting:
        Polls until both the solution version and the filter are ACTIVE, then
        submits a job that reads ``input_path`` and writes to ``output_path``.

    Args:
        role_arn: IAM role ARN passed as ``roleArn`` (placeholder default).
        input_path: S3 URI of the JSON-lines input file.
        output_path: S3 URI prefix for the job output.
        num_results: Recommendations per input record. The original code
            notes the maximum is 500.
        poll_interval: Seconds to sleep between status checks.
        timeout: Maximum seconds to wait for ACTIVE status. ``None`` waits
            indefinitely, as the original did.

    Returns:
        The ``batchInferenceJobArn`` of the created job.

    Raises:
        IndexError: if no dataset group, solution, or solution version exists.
        RuntimeError: if the solution version or filter ends in a failed state.
        TimeoutError: if ``timeout`` elapses before both are ACTIVE.
    """
    dsg_arn = personalize.list_dataset_groups()["datasetGroups"][0]["datasetGroupArn"]
    solution_arn = personalize.list_solutions(datasetGroupArn=dsg_arn)[
        "solutions"
    ][0]["solutionArn"]
    sv_arn = personalize.list_solution_versions(solutionArn=solution_arn)[
        "solutionVersions"
    ][0]["solutionVersionArn"]

    filter_arn = _get_or_create_filter(dsg_arn)
    print("1. FILTER :: ", filter_arn)

    _wait_until_active(sv_arn, filter_arn, poll_interval=poll_interval, timeout=timeout)

    response = personalize.create_batch_inference_job(
        solutionVersionArn=sv_arn,
        jobName=f"recommendation-batch-{int(time.time())}",
        roleArn=role_arn,
        jobInput={"s3DataSource": {"path": input_path}},
        jobOutput={"s3DataDestination": {"path": output_path}},
        numResults=num_results,  # (maximum is 500)
        filterArn=filter_arn,
    )
    batch_arn = response["batchInferenceJobArn"]
    print("Batch job created ", batch_arn)
    return batch_arn


def _get_or_create_filter(dsg_arn):
    """Return the first filter ARN in the dataset group, creating a filter if none exist."""
    existing = personalize.list_filters(datasetGroupArn=dsg_arn)["Filters"]
    if existing:
        return existing[0]["filterArn"]
    created = personalize.create_filter(
        name=f"filter-{int(time.time())}",
        datasetGroupArn=dsg_arn,
        # NOTE(review): placeholder expression. Replace "조건1" / "조건 2"
        # ("condition 1" / "condition 2") with real filter conditions.
        filterExpression="EXCLUDE ItemID WHERE 조건1 AND 조건 2",
    )
    return created["filterArn"]


def _wait_until_active(sv_arn, filter_arn, *, poll_interval=30, timeout=None):
    """Block until the solution version and the filter both report ACTIVE.

    Raises:
        RuntimeError: if either reaches a terminal failure state. The
            original loop spun forever in that case.
        TimeoutError: if ``timeout`` seconds pass first.
    """
    failed_states = {"CREATE FAILED", "CREATE STOPPED"}
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        sv = personalize.describe_solution_version(solutionVersionArn=sv_arn)[
            "solutionVersion"
        ]
        flt = personalize.describe_filter(filterArn=filter_arn)["filter"]
        sv_status, filter_status = sv["status"], flt["status"]
        if sv_status == "ACTIVE" and filter_status == "ACTIVE":
            return
        for label, desc in (("solution version", sv), ("filter", flt)):
            if desc["status"] in failed_states:
                raise RuntimeError(
                    f"{label} is {desc['status']}: "
                    f"{desc.get('failureReason', 'no failureReason reported')}"
                )
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"not ACTIVE after {timeout}s "
                f"(solution version: {sv_status}, filter: {filter_status})"
            )
        print("waiting SV.... ", sv_status, "\nfilter.... ::", filter_status)
        time.sleep(poll_interval)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment