Last active
July 10, 2022 07:34
-
-
Save Shinichi-Nakagawa/8033a6558098a50821b9a62a7264ca90 to your computer and use it in GitHub Desktop.
DataprocでBaseball savantデータを集計する(保存済みのCSVがGCSにあったとして)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
import os
import sys
from datetime import datetime
from typing import Optional

from pyspark.sql import SparkSession
from pyspark.sql import DataFrame as SparkDataFrame
from pyspark.sql.utils import AnalysisException
from pyspark.sql.types import StructType, StructField, DoubleType, DateType, StringType, LongType
# Module-wide Spark session used by every function in this script.
# `spark.jars` points at the spark-bigquery connector jar, and
# `spark.sql.debug.maxToStringFields` is set to 2000 (presumably so the wide
# Statcast schema is not truncated in debug output).
spark = (
    SparkSession.builder
    .appName('app_334')
    .config('spark.jars', 'gs://spark-lib/bigquery/spark-bigquery-with-dependencies_2.12-0.25.2.jar')
    .config('spark.sql.debug.maxToStringFields', 2000)
    .getOrCreate()
)
# Root GCS location under which the per-date CSV directories live.
# Override with the GCS_ROOT environment variable.
GCS_ROOT = os.getenv('GCS_ROOT', 'gs://hogefuga/statcast')
# Bucket handed to the BigQuery writer as its `temporaryGcsBucket`.
GCS_BUCKET = os.getenv('GCS_BUCKET', 'baseball-savant-temporary-bucket')
# Prefix placed in front of the table name when writing to BigQuery.
GCP_PROJECT = os.getenv('GCP_PROJECT', 'sample-project')
# strftime/strptime pattern used to turn a run date into its directory name.
# Override with the DATE_FORMAT environment variable.
DATE_FORMAT = os.getenv('DATE_FORMAT', '%Y-%m-%d')
# Explicit schema for the Statcast batter CSV (long: one entry per column).
# Each (column name, Spark type) pair below becomes a StructField declared with
# nullable=False, in exactly this order.
STATCAST_SCHEMA = StructType([
    StructField(column_name, spark_type(), False)
    for column_name, spark_type in (
        ("pitch_type", StringType),
        ("game_date", DateType),
        ("release_speed", DoubleType),
        ("release_pos_x", DoubleType),
        ("release_pos_z", DoubleType),
        ("player_name", StringType),
        ("batter", StringType),
        ("pitcher", StringType),
        ("events", StringType),
        ("description", StringType),
        ("spin_dir", DoubleType),
        ("spin_rate_deprecated", DoubleType),
        ("break_angle_deprecated", DoubleType),
        ("break_length_deprecated", DoubleType),
        ("zone", LongType),
        ("des", StringType),
        ("game_type", StringType),
        ("stand", StringType),
        ("p_throws", StringType),
        ("home_team", StringType),
        ("away_team", StringType),
        ("type", StringType),
        ("hit_location", LongType),
        ("bb_type", StringType),
        ("balls", LongType),
        ("strikes", LongType),
        ("game_year", LongType),
        ("pfx_x", DoubleType),
        ("pfx_z", DoubleType),
        ("plate_x", DoubleType),
        ("plate_z", DoubleType),
        ("on_3b", StringType),
        ("on_2b", StringType),
        ("on_1b", StringType),
        ("outs_when_up", DoubleType),
        ("inning", DoubleType),
        ("inning_topbot", StringType),
        ("hc_x", DoubleType),
        ("hc_y", DoubleType),
        ("tfs_deprecated", StringType),
        ("tfs_zulu_deprecated", StringType),
        ("fielder_2", StringType),
        ("umpire", StringType),
        ("sv_id", StringType),
        ("vx0", DoubleType),
        ("vy0", DoubleType),
        ("vz0", DoubleType),
        ("ax", DoubleType),
        ("ay", DoubleType),
        ("az", DoubleType),
        ("sz_top", DoubleType),
        ("sz_bot", DoubleType),
        ("hit_distance_sc", DoubleType),
        ("launch_speed", DoubleType),
        ("launch_angle", DoubleType),
        ("effective_speed", DoubleType),
        ("release_spin_rate", DoubleType),
        ("release_extension", DoubleType),
        ("game_pk", StringType),
        ("pitcher.1", StringType),
        ("fielder_2.1", StringType),
        ("fielder_3", StringType),
        ("fielder_4", StringType),
        ("fielder_5", StringType),
        ("fielder_6", StringType),
        ("fielder_7", StringType),
        ("fielder_8", StringType),
        ("fielder_9", StringType),
        ("release_pos_y", DoubleType),
        ("estimated_ba_using_speedangle", DoubleType),
        ("estimated_woba_using_speedangle", DoubleType),
        ("woba_value", DoubleType),
        ("woba_denom", DoubleType),
        ("babip_value", DoubleType),
        ("iso_value", DoubleType),
        ("launch_speed_angle", DoubleType),
        ("at_bat_number", LongType),
        ("pitch_number", LongType),
        ("pitch_name", StringType),
        ("home_score", LongType),
        ("away_score", LongType),
        ("bat_score", LongType),
        ("fld_score", LongType),
        ("post_away_score", LongType),
        ("post_home_score", LongType),
        ("post_bat_score", LongType),
        ("post_fld_score", LongType),
        ("if_fielding_alignment", StringType),
        ("of_fielding_alignment", StringType),
        ("spin_axis", DoubleType),
        ("delta_home_win_exp", DoubleType),
        ("delta_run_exp", DoubleType),
    )
])
# Tracking data.
# Spark SQL run against the `batterCsv` temp view. It selects the per-pitch
# columns (everything in the batter schema except `player_name`), renames the
# fielder columns (`pitcher.1`, `fielder_2.1`, fielder_3 .. fielder_9) to
# position-based names (fielder_p, fielder_c, fielder_1b, ... fielder_rf), and
# orders the rows by game date, game, inning, half inning (descending),
# at-bat number and pitch number.
QUERY_TRACKING_DATA = '''
select
pitch_type,
game_date,
release_speed,
release_pos_x,
release_pos_z,
batter,
pitcher,
events,
description,
spin_dir,
spin_rate_deprecated,
break_angle_deprecated,
break_length_deprecated,
zone,
des,
game_type,
stand,
p_throws,
home_team,
away_team,
type,
hit_location,
bb_type,
balls,
strikes,
game_year,
pfx_x,
pfx_z,
plate_x,
plate_z,
on_3b,
on_2b,
on_1b,
outs_when_up,
inning,
inning_topbot,
hc_x,
hc_y,
tfs_deprecated,
tfs_zulu_deprecated,
fielder_2,
umpire,
sv_id,
vx0,
vy0,
vz0,
ax,
ay,
az,
sz_top,
sz_bot,
hit_distance_sc,
launch_speed,
launch_angle,
effective_speed,
release_spin_rate,
release_extension,
game_pk,
`pitcher.1` as fielder_p,
`fielder_2.1` as fielder_c,
fielder_3 as fielder_1b,
fielder_4 as fielder_2b,
fielder_5 as fielder_3b,
fielder_6 as fielder_ss,
fielder_7 as fielder_lf,
fielder_8 as fielder_cf,
fielder_9 as fielder_rf,
release_pos_y,
estimated_ba_using_speedangle,
estimated_woba_using_speedangle,
woba_value,
woba_denom,
babip_value,
iso_value,
launch_speed_angle,
at_bat_number,
pitch_number,
pitch_name,
home_score,
away_score,
bat_score,
fld_score,
post_away_score,
post_home_score,
post_bat_score,
post_fld_score,
if_fielding_alignment,
of_fielding_alignment,
spin_axis,
delta_home_win_exp,
delta_run_exp
from batterCsv
order by game_date, game_pk, inning, inning_topbot desc, at_bat_number, pitch_number
'''
def read_csv(date: str, filename: str, schema: Optional[StructType] = None) -> Optional[SparkDataFrame]:
    """Load ``{GCS_ROOT}/{date}/{filename}`` as a CSV DataFrame.

    The file is read with a header row. ``inferSchema`` is requested, and
    ``schema`` (when given) is passed to the reader as the explicit schema.

    Args:
        date: Name of the per-date directory under ``GCS_ROOT``.
        filename: CSV file name inside that directory.
        schema: Optional explicit schema for the file.

    Returns:
        The loaded DataFrame, or ``None`` when Spark raises
        ``AnalysisException`` for the load (presumably e.g. a missing path).
    """
    path = f'{GCS_ROOT}/{date}/{filename}'
    try:
        return spark.read.format('csv').options(header="true", inferSchema="true").load(path, schema=schema)
    except AnalysisException as exc:
        # Keep the best-effort contract (callers get None), but record why the
        # load failed instead of discarding the error silently.
        logging.getLogger(__name__).warning('Could not load %s: %s', path, exc)
        return None
def get_dataframe(date: datetime, filename: str, schema: StructType = None) -> SparkDataFrame:
    """Read ``filename`` from the directory named after ``date``.

    The directory name is ``date`` formatted with ``DATE_FORMAT``; the actual
    load is delegated to ``read_csv`` and its result is returned unchanged.

    Args:
        date: Run date identifying the directory to read from.
        filename: CSV file name inside that directory.
        schema: Optional explicit schema forwarded to ``read_csv``.
    """
    return read_csv(date.strftime(DATE_FORMAT), filename, schema)
def save_bigquery(sdf: SparkDataFrame, table_name: str):
    """Append ``sdf`` to a BigQuery table via the ``bigquery`` data source.

    Rows are written in ``append`` mode using ``GCS_BUCKET`` as the temporary
    GCS bucket, with ``createDisposition`` set to ``CREATE_NEVER``.

    Args:
        sdf: DataFrame to write.
        table_name: Table name appended to ``GCP_PROJECT`` to form the target.
    """
    # NOTE(review): the reference is '<GCP_PROJECT>.<table_name>' with no
    # separate dataset component - confirm GCP_PROJECT carries the qualifier
    # the connector expects.
    destination = f'{GCP_PROJECT}.{table_name}'
    writer = (
        sdf.write
        .mode('append')
        .format('bigquery')
        .option('table', destination)
        .option('temporaryGcsBucket', GCS_BUCKET)
        .option('createDisposition', 'CREATE_NEVER')
    )
    writer.save()
def execute(run_date: datetime):
    """Load one day's Statcast CSVs and append the results to BigQuery.

    Reads ``batter.csv`` (with ``STATCAST_SCHEMA``) and ``pitcher.csv`` (no
    explicit schema) for ``run_date`` and registers them as temp views. It
    then appends the result of ``QUERY_TRACKING_DATA`` to the ``tracking``
    table and the de-duplicated player list to the ``player`` table.

    Args:
        run_date: Date whose CSV directory should be processed.

    Raises:
        RuntimeError: If either CSV could not be loaded. Raised before any
            view is registered or any table is written.
    """
    # Source DataFrames for the run date.
    sdf_batter = get_dataframe(run_date, 'batter.csv', schema=STATCAST_SCHEMA)
    sdf_pitcher = get_dataframe(run_date, 'pitcher.csv')

    # get_dataframe yields None when the load failed; fail with a clear message
    # instead of an AttributeError on None further down.
    missing = [
        name
        for name, sdf in (('batter.csv', sdf_batter), ('pitcher.csv', sdf_pitcher))
        if sdf is None
    ]
    if missing:
        raise RuntimeError(
            f"Could not load {', '.join(missing)} for {run_date.strftime(DATE_FORMAT)}"
        )

    # Temporary views referenced by the SQL below.
    sdf_batter.createOrReplaceTempView('batterCsv')
    sdf_pitcher.createOrReplaceTempView('pitcherCSV')

    # Tracking (per-pitch) data.
    sdf_tracking_data = spark.sql(QUERY_TRACKING_DATA)
    sdf_tracking_data.createOrReplaceTempView('trackingDataset')
    save_bigquery(sdf_tracking_data, 'tracking')

    # Player list: batters and pitchers combined, then de-duplicated.
    sdf_player_b = spark.sql('select distinct game_year, game_date, batter as player_id, player_name from batterCsv')
    sdf_player_p = spark.sql('select distinct game_year, game_date, pitcher as player_id, player_name from pitcherCSV')
    sdf_player = sdf_player_b.unionAll(sdf_player_p)
    sdf_player.createOrReplaceTempView('playerDataset')
    sdf_player = spark.sql('select distinct game_year, game_date, player_id, player_name from playerDataset')
    # pitcher.csv is read without an explicit schema, so cast the columns to
    # fixed types before writing.
    sdf_player = sdf_player.withColumn("game_year", sdf_player["game_year"].cast(LongType()))
    sdf_player = sdf_player.withColumn("game_date", sdf_player["game_date"].cast(DateType()))
    save_bigquery(sdf_player, 'player')
if __name__ == "__main__":
    # TODO: ideally the run date is supplied by an event trigger.
    # Until then, accept an optional run date (formatted per DATE_FORMAT) as
    # the first command-line argument; with no argument the original fixed
    # date is used, so existing invocations behave as before.
    _date: datetime = (
        datetime.strptime(sys.argv[1], DATE_FORMAT)
        if len(sys.argv) > 1
        else datetime(2022, 7, 8)
    )
    execute(run_date=_date)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment