Created
July 9, 2018 18:18
-
-
Save davidmcclure/c336185b5cb5b69ca9fc2a67e85180d0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import click | |
import pandas as pd | |
from pyspark.sql.functions import udf | |
from pyspark.sql import types as T | |
from geovec_data.utils import get_spark | |
# Column layout of the name tables read by ``read_name_pct``, in file order.
COLS = ('name', 'pct', 'cum_pct', 'rank')


def read_name_pct(path):
    """Read a Census name table into a ``{name: pct}`` mapping.

    The file is parsed as headerless, whitespace-delimited text whose
    columns are labelled by ``COLS``. Names are lower-cased; if a name
    appears on more than one row, the last row wins.

    Args:
        path: Location of the table (anything ``pandas.read_csv`` accepts).

    Returns:
        dict mapping each lower-cased name (``str``) to its ``pct`` value.
    """
    df = pd.read_csv(
        path,
        # Raw string: '\s' is an invalid escape in an ordinary literal.
        sep=r'\s+',
        header=None,
        names=COLS,
        # Names are always text, never numbers.
        dtype={'name': str},
        # Keep tokens such as "NAN", "NA" or "NULL" as literal names instead
        # of letting pandas silently convert them to float NaN.
        keep_default_na=False,
    )

    lowered = df['name'].str.lower()

    # .tolist() yields plain Python objects (str / float), not numpy scalars.
    return dict(zip(lowered.tolist(), df['pct'].tolist()))
class GenderRatios:
    """Relative frequency of a first name between the M and F name tables.

    Holds one ``{name: pct}`` mapping per gender plus the smallest ``pct``
    found in each mapping. The smallest value stands in as the denominator
    when a name is absent from the *other* gender's table.
    """

    def __init__(self, m_path, f_path):
        """Load both name tables, then record each table's minimum pct."""
        self.m_name_pct = read_name_pct(m_path)
        self.f_name_pct = read_name_pct(f_path)

        self.m_min_pct = min(self.m_name_pct.values())
        self.f_min_pct = min(self.f_name_pct.values())

    @staticmethod
    def _share(name, own_table, other_table, other_floor):
        """Divide the name's pct in ``own_table`` by its pct in ``other_table``.

        Returns None when the name is missing from ``own_table`` (or its pct
        there is falsy). ``other_floor`` is the denominator used when the
        name is missing from ``other_table``.
        """
        numerator = own_table.get(name)
        if not numerator:
            return None

        return numerator / other_table.get(name, other_floor)

    def m_ratio(self, name):
        """Return the M pct over the F pct for ``name``, or None if not an M name."""
        return self._share(
            name, self.m_name_pct, self.f_name_pct, self.f_min_pct)

    def f_ratio(self, name):
        """Return the F pct over the M pct for ``name``, or None if not an F name."""
        return self._share(
            name, self.f_name_pct, self.m_name_pct, self.m_min_pct)
def gender_ratio_udf(ratios, m=False, f=False):
    """Build a Spark UDF that scores a display name with the M or F ratio.

    The UDF takes the first whitespace-separated token of the lower-cased,
    stripped name and passes it to ``ratios.m_ratio`` (when ``m`` is truthy)
    or ``ratios.f_ratio`` (when only ``f`` is truthy). It yields null for
    empty / blank names and when neither flag is set.
    """
    @udf(T.FloatType())
    def work(name):
        cleaned = name.strip() if name else ''

        # Null, empty, or whitespace-only names have no first token.
        if not cleaned:
            return None

        first_name = cleaned.lower().split()[0]

        # ``m`` takes precedence if both flags are set.
        if m:
            return ratios.m_ratio(first_name)

        if f:
            return ratios.f_ratio(first_name)

        return None

    return work
@click.command()
@click.argument('m_name_src', type=click.Path())
@click.argument('f_name_src', type=click.Path())
@click.argument('geo_tweet_src', type=click.Path())
@click.argument('dest', type=click.Path())
def main(m_name_src, f_name_src, geo_tweet_src, dest):
    """Infer gender from first names + US census data.
    """
    # Reads the tweets parquet at geo_tweet_src, appends `m_ratio` and
    # `f_ratio` columns computed from `actor_display_name`, and overwrites
    # dest with the result as parquet.
    _, spark = get_spark()

    frame = spark.read.parquet(geo_tweet_src)

    name_table = GenderRatios(m_name_src, f_name_src)

    # (output column, UDF) pairs, applied in this order.
    ratio_columns = (
        ('m_ratio', gender_ratio_udf(name_table, m=True)),
        ('f_ratio', gender_ratio_udf(name_table, f=True)),
    )

    for column, ratio_fn in ratio_columns:
        frame = frame.withColumn(column, ratio_fn(frame.actor_display_name))

    frame.write.mode('overwrite').parquet(dest)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment