Skip to content

Instantly share code, notes, and snippets.

@Shinichi-Nakagawa
Last active July 10, 2022 07:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Shinichi-Nakagawa/8033a6558098a50821b9a62a7264ca90 to your computer and use it in GitHub Desktop.
Save Shinichi-Nakagawa/8033a6558098a50821b9a62a7264ca90 to your computer and use it in GitHub Desktop.
DataprocでBaseball savantデータを集計する(保存済みのCSVがGCSにあったとして)
import logging
import os
from datetime import datetime
from typing import Optional

from pyspark.sql import DataFrame as SparkDataFrame
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, DoubleType, DateType, StringType, LongType
from pyspark.sql.utils import AnalysisException
# Module-wide Spark session. The spark-bigquery connector jar is registered
# through `spark.jars`, and `maxToStringFields` is raised to 2000.
spark = (
    SparkSession.builder
    .appName('app_334')
    .config('spark.jars', 'gs://spark-lib/bigquery/spark-bigquery-with-dependencies_2.12-0.25.2.jar')
    .config('spark.sql.debug.maxToStringFields', 2000)
    .getOrCreate()
)
# GCS paths.
# GCS_ROOT is the prefix under which the dated CSV directories are read; it can
# be overridden through the environment like the other settings below.
GCS_ROOT = os.getenv('GCS_ROOT', 'gs://hogefuga/statcast')
# Bucket passed to the BigQuery writer as `temporaryGcsBucket`.
GCS_BUCKET = os.getenv('GCS_BUCKET', 'baseball-savant-temporary-bucket')
# Project
GCP_PROJECT = os.getenv('GCP_PROJECT', 'sample-project')
# TODO: move to an environment variable
# Date format used to name the per-day directory
DATE_FORMAT = '%Y-%m-%d'
# Schema definition (a bit long).
# Every field is declared non-nullable; the (name, type) pairs below are kept
# in exactly the order of the original declaration.
STATCAST_SCHEMA = StructType(
    [
        StructField(column_name, column_type(), nullable=False)
        for column_name, column_type in (
            ("pitch_type", StringType),
            ("game_date", DateType),
            ("release_speed", DoubleType),
            ("release_pos_x", DoubleType),
            ("release_pos_z", DoubleType),
            ("player_name", StringType),
            ("batter", StringType),
            ("pitcher", StringType),
            ("events", StringType),
            ("description", StringType),
            ("spin_dir", DoubleType),
            ("spin_rate_deprecated", DoubleType),
            ("break_angle_deprecated", DoubleType),
            ("break_length_deprecated", DoubleType),
            ("zone", LongType),
            ("des", StringType),
            ("game_type", StringType),
            ("stand", StringType),
            ("p_throws", StringType),
            ("home_team", StringType),
            ("away_team", StringType),
            ("type", StringType),
            ("hit_location", LongType),
            ("bb_type", StringType),
            ("balls", LongType),
            ("strikes", LongType),
            ("game_year", LongType),
            ("pfx_x", DoubleType),
            ("pfx_z", DoubleType),
            ("plate_x", DoubleType),
            ("plate_z", DoubleType),
            ("on_3b", StringType),
            ("on_2b", StringType),
            ("on_1b", StringType),
            ("outs_when_up", DoubleType),
            ("inning", DoubleType),
            ("inning_topbot", StringType),
            ("hc_x", DoubleType),
            ("hc_y", DoubleType),
            ("tfs_deprecated", StringType),
            ("tfs_zulu_deprecated", StringType),
            ("fielder_2", StringType),
            ("umpire", StringType),
            ("sv_id", StringType),
            ("vx0", DoubleType),
            ("vy0", DoubleType),
            ("vz0", DoubleType),
            ("ax", DoubleType),
            ("ay", DoubleType),
            ("az", DoubleType),
            ("sz_top", DoubleType),
            ("sz_bot", DoubleType),
            ("hit_distance_sc", DoubleType),
            ("launch_speed", DoubleType),
            ("launch_angle", DoubleType),
            ("effective_speed", DoubleType),
            ("release_spin_rate", DoubleType),
            ("release_extension", DoubleType),
            ("game_pk", StringType),
            ("pitcher.1", StringType),
            ("fielder_2.1", StringType),
            ("fielder_3", StringType),
            ("fielder_4", StringType),
            ("fielder_5", StringType),
            ("fielder_6", StringType),
            ("fielder_7", StringType),
            ("fielder_8", StringType),
            ("fielder_9", StringType),
            ("release_pos_y", DoubleType),
            ("estimated_ba_using_speedangle", DoubleType),
            ("estimated_woba_using_speedangle", DoubleType),
            ("woba_value", DoubleType),
            ("woba_denom", DoubleType),
            ("babip_value", DoubleType),
            ("iso_value", DoubleType),
            ("launch_speed_angle", DoubleType),
            ("at_bat_number", LongType),
            ("pitch_number", LongType),
            ("pitch_name", StringType),
            ("home_score", LongType),
            ("away_score", LongType),
            ("bat_score", LongType),
            ("fld_score", LongType),
            ("post_away_score", LongType),
            ("post_home_score", LongType),
            ("post_bat_score", LongType),
            ("post_fld_score", LongType),
            ("if_fielding_alignment", StringType),
            ("of_fielding_alignment", StringType),
            ("spin_axis", DoubleType),
            ("delta_home_win_exp", DoubleType),
            ("delta_run_exp", DoubleType),
        )
    ]
)
# Tracking data.
# Spark SQL run against the `batterCsv` temp view: selects the listed columns
# (player_name is not selected), renames `pitcher.1`, `fielder_2.1` and
# fielder_3..fielder_9 to position-named fielder_* columns, and orders the
# rows by game_date, game_pk, inning, inning_topbot (descending),
# at_bat_number and pitch_number.
QUERY_TRACKING_DATA = '''
select
pitch_type,
game_date,
release_speed,
release_pos_x,
release_pos_z,
batter,
pitcher,
events,
description,
spin_dir,
spin_rate_deprecated,
break_angle_deprecated,
break_length_deprecated,
zone,
des,
game_type,
stand,
p_throws,
home_team,
away_team,
type,
hit_location,
bb_type,
balls,
strikes,
game_year,
pfx_x,
pfx_z,
plate_x,
plate_z,
on_3b,
on_2b,
on_1b,
outs_when_up,
inning,
inning_topbot,
hc_x,
hc_y,
tfs_deprecated,
tfs_zulu_deprecated,
fielder_2,
umpire,
sv_id,
vx0,
vy0,
vz0,
ax,
ay,
az,
sz_top,
sz_bot,
hit_distance_sc,
launch_speed,
launch_angle,
effective_speed,
release_spin_rate,
release_extension,
game_pk,
`pitcher.1` as fielder_p,
`fielder_2.1` as fielder_c,
fielder_3 as fielder_1b,
fielder_4 as fielder_2b,
fielder_5 as fielder_3b,
fielder_6 as fielder_ss,
fielder_7 as fielder_lf,
fielder_8 as fielder_cf,
fielder_9 as fielder_rf,
release_pos_y,
estimated_ba_using_speedangle,
estimated_woba_using_speedangle,
woba_value,
woba_denom,
babip_value,
iso_value,
launch_speed_angle,
at_bat_number,
pitch_number,
pitch_name,
home_score,
away_score,
bat_score,
fld_score,
post_away_score,
post_home_score,
post_bat_score,
post_fld_score,
if_fielding_alignment,
of_fielding_alignment,
spin_axis,
delta_home_win_exp,
delta_run_exp
from batterCsv
order by game_date, game_pk, inning, inning_topbot desc, at_bat_number, pitch_number
'''
def read_csv(date: str, filename: str, schema: Optional[StructType] = None) -> Optional[SparkDataFrame]:
    """Load one CSV file from a dated directory under ``GCS_ROOT``.

    The file is read from ``{GCS_ROOT}/{date}/{filename}`` with the
    ``header`` and ``inferSchema`` options enabled; ``schema`` is passed
    straight through to the loader.

    Args:
        date: Name of the dated sub-directory, already formatted as text.
        filename: Name of the CSV file inside that directory.
        schema: Optional explicit schema handed to ``load``.

    Returns:
        The loaded DataFrame, or ``None`` when Spark raises
        ``AnalysisException`` for the read. Callers must handle ``None``.
    """
    path = f'{GCS_ROOT}/{date}/{filename}'
    reader = spark.read.format('csv').options(header="true", inferSchema="true")
    try:
        return reader.load(path, schema=schema)
    except AnalysisException as err:
        # Keep the best-effort contract (None on failure) but leave a trace,
        # so a failed read is not completely silent.
        logging.getLogger(__name__).warning('Could not read %s: %s', path, err)
        return None
def get_dataframe(date: datetime, filename: str, schema: StructType = None) -> SparkDataFrame:
    """Load ``filename`` from the directory named after ``date``.

    The directory name is ``date`` rendered with ``DATE_FORMAT``. The read
    itself is delegated to ``read_csv`` and its result is returned unchanged,
    including ``None`` when ``read_csv`` gives up on the file.
    """
    return read_csv(date.strftime(DATE_FORMAT), filename, schema)
def save_bigquery(sdf: SparkDataFrame, table_name: str):
    """Append ``sdf`` to a BigQuery table through the ``bigquery`` format.

    The target is ``{GCP_PROJECT}.{table_name}``, written in ``append`` mode
    with ``GCS_BUCKET`` as the temporary GCS bucket and a create disposition
    of ``CREATE_NEVER``.

    NOTE(review): the table id carries only two parts; confirm that
    ``GCP_PROJECT`` resolves to what the connector expects in front of the
    table name (e.g. that it includes or is the dataset).
    """
    writer = sdf.write.mode('append').format('bigquery')
    writer = writer.option('table', f'{GCP_PROJECT}.{table_name}')
    writer = writer.option('temporaryGcsBucket', GCS_BUCKET)
    writer = writer.option('createDisposition', 'CREATE_NEVER')
    writer.save()
def execute(run_date: datetime):
    """Load one day of CSVs and append tracking and player rows to BigQuery.

    Reads ``batter.csv`` (with ``STATCAST_SCHEMA``) and ``pitcher.csv`` for
    ``run_date``, appends the result of ``QUERY_TRACKING_DATA`` to the
    ``tracking`` table, then appends the distinct players found in both files
    to the ``player`` table.

    Args:
        run_date: Day whose CSV directory is processed.

    Raises:
        RuntimeError: If either CSV could not be loaded, i.e.
            ``get_dataframe`` returned ``None``. Nothing is written then.
    """
    # One DataFrame per input file
    sdf_batter = get_dataframe(run_date, 'batter.csv', schema=STATCAST_SCHEMA)
    sdf_pitcher = get_dataframe(run_date, 'pitcher.csv')
    # A failed read comes back as None; fail with a clear message here instead
    # of an AttributeError on the first DataFrame method call below.
    loaded = (('batter.csv', sdf_batter), ('pitcher.csv', sdf_pitcher))
    missing = [name for name, sdf in loaded if sdf is None]
    if missing:
        raise RuntimeError(
            f"could not load {', '.join(missing)} for {run_date.strftime(DATE_FORMAT)}"
        )
    # Temporary tables
    sdf_batter.createOrReplaceTempView('batterCsv')
    sdf_pitcher.createOrReplaceTempView('pitcherCSV')
    sdf_tracking_data = spark.sql(QUERY_TRACKING_DATA)
    sdf_tracking_data.createOrReplaceTempView('trackingDataset')
    save_bigquery(sdf_tracking_data, 'tracking')
    # Player list: batters and pitchers combined, then de-duplicated
    sdf_player_b = spark.sql('select distinct game_year, game_date, batter as player_id, player_name from batterCsv')
    sdf_player_p = spark.sql('select distinct game_year, game_date, pitcher as player_id, player_name from pitcherCSV')
    sdf_player = sdf_player_b.unionAll(sdf_player_p)
    sdf_player.createOrReplaceTempView('playerDataset')
    sdf_player = spark.sql('select distinct game_year, game_date, player_id, player_name from playerDataset')
    # Cast explicitly: pitcher.csv is read without the explicit schema, so the
    # column types of the combined result are not fixed by STATCAST_SCHEMA.
    sdf_player = sdf_player.withColumn("game_year", sdf_player["game_year"].cast(LongType()))
    sdf_player = sdf_player.withColumn("game_date", sdf_player["game_date"].cast(DateType()))
    save_bigquery(sdf_player, 'player')
if __name__ == "__main__":
    # TODO: ideally the run date would be received from an event trigger
    _date: datetime = datetime(year=2022, month=7, day=8)
    execute(run_date=_date)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment