Skip to content

Instantly share code, notes, and snippets.

@c74p
Created January 30, 2019 19:36
Show Gist options
  • Save c74p/9bd2d706a71fde07a7141dfdf92eadec to your computer and use it in GitHub Desktop.
Save c74p/9bd2d706a71fde07a7141dfdf92eadec to your computer and use it in GitHub Desktop.
@composite
def create_dataframes(draw) -> Dict[str, pd.DataFrame]:
    """Generate dataframes for property-based testing.

    Builds one hard-coded frame (``state_names_df``) and draws five
    randomised frames (``google_df``, ``stores_df``, ``store_states_df``,
    ``train_df`` and ``weather_df``) from Hypothesis strategies.  Every
    generated column is able to produce NaN.

    Args:
        draw: The draw callable that ``@composite`` supplies as the first
            argument; it pulls a concrete value out of a strategy.

    Returns:
        Annotated as a mapping of names to dataframes.
        NOTE(review): the code visible here ends after ``weather_df`` is
        drawn, without a ``return`` statement -- confirm against the full
        file that the frames are actually returned.
    """
    # ------------------------------------------------------------------
    # Element strategies used to fill the dataframe columns
    # ------------------------------------------------------------------

    # 'plus_nan' wraps a strategy so that NaN is explicitly one of the
    # values it can produce.  np.nan is used rather than the np.NaN alias
    # because the alias was removed in NumPy 2.0.
    @composite
    def plus_nan(draw, strat: SearchStrategy) -> object:
        """Draw either NaN or a value from `strat`."""
        return draw(one_of(just(np.nan), strat))

    stores = integers(min_value=0, max_value=2000)
    stores_plus_nan = plus_nan(stores)

    states = sampled_from(state_abbreviations)
    states_plus_nan = plus_nan(states)

    dates = datetimes(min_value=datetime.datetime(2013, 1, 1),
                      max_value=datetime.datetime(2015, 12, 12))
    dates_plus_nan = plus_nan(dates)

    integers_plus_nan = plus_nan(integers())

    # Strategies that were previously spelled out once per column.
    # 'floats' allows NaN by default, but it is requested explicitly here
    # (PEP 20: explicit is better than implicit).
    floats_with_nan = floats(allow_infinity=False, allow_nan=True)
    flags_with_nan = sampled_from([0, 1, np.nan])

    # Take the 'states' strategy and prepend 'Rossmann_DE_' to what it
    # gives you, then add in NaN as a possibility for good measure.
    google_files = states.map(lambda state: 'Rossmann_DE_' + state)
    google_files_plus_nan = plus_nan(google_files)

    # Strategy for spelling out a google 'week' entry (NaN is added below)
    @composite
    def create_google_weeks(draw, strat: SearchStrategy) -> str:
        """Draw a datetime from `strat` and format its Sunday-Saturday week.

        The result looks like 'YYYY-MM-DD - YYYY-MM-DD', running from the
        Sunday on or before the drawn day to the following Saturday.
        """
        day = draw(strat)
        # weekday() is 0 for Monday and 6 for Sunday, so this is the number
        # of days elapsed since the most recent Sunday (0 on a Sunday).
        days_since_sunday = (day.weekday() + 1) % 7
        last_sun = day - datetime.timedelta(days=days_since_sunday)
        next_sat = last_sun + datetime.timedelta(days=6)
        return ' - '.join((last_sun.strftime('%Y-%m-%d'),
                           next_sat.strftime('%Y-%m-%d')))

    google_weeks_plus_nan = plus_nan(create_google_weeks(dates))

    # ------------------------------------------------------------------
    # Dataframes built from the strategies above
    # ------------------------------------------------------------------
    # Each generated column gets its possible NaN values in one of three
    # ways:
    #   1) It explicitly uses the 'plus_nan' wrapper
    #   2) It's sampled_from a list that explicitly includes NaN
    #   3) It uses the 'floats' strategy with allow_nan=True
    google_df = draw(data_frames([
        column('file', elements=google_files_plus_nan),
        column('week', elements=google_weeks_plus_nan),
        column('trend',
               elements=plus_nan(integers(min_value=0, max_value=100))),
    ]))

    # Since this file is crucial to structuring the merged pdf (presumably
    # the merged pandas dataframe), it's hard-coded rather than generated.
    state_names_df = pd.DataFrame({'StateName': state_names,
                                   'State': state_abbreviations})

    # Note index gives min and max sizes for this dataframe (not empty)
    stores_df = draw(data_frames(
        columns=[
            column('Store', elements=stores_plus_nan, unique=True),
            column('StoreType',
                   elements=sampled_from(['a', 'b', 'c', 'd', np.nan])),
            column('Assortment',
                   elements=sampled_from(['a', 'b', 'c', np.nan])),
            column('CompetitionDistance', elements=floats_with_nan),
            column('CompetitionOpenSinceMonth', elements=floats_with_nan),
            column('CompetitionOpenSinceYear', elements=floats_with_nan),
            column('Promo2', elements=flags_with_nan),
            column('Promo2SinceWeek', elements=floats_with_nan),
            column('Promo2SinceYear', elements=floats_with_nan),
            column('PromoInterval',
                   elements=sampled_from(['Feb,May,Aug,Nov',
                                          'Jan,Apr,Jul,Oct',
                                          'Mar,Jun,Sept,Dec', np.nan])),
        ],
        index=range_indexes(min_size=10, max_size=1000)))

    store_states_df = draw(data_frames([
        column('Store', elements=stores_plus_nan, unique=True),
        column('State', elements=states_plus_nan),
    ]))

    # Note index gives min and max sizes for this dataframe (not empty)
    train_df = draw(data_frames(
        columns=[
            column('Store', elements=stores_plus_nan),
            column('DayOfWeek', elements=integers_plus_nan),
            column('Date', elements=dates_plus_nan),
            column('Sales', elements=integers_plus_nan),
            column('Customers', elements=integers_plus_nan),
            column('Open', elements=flags_with_nan),
            column('Promo', elements=flags_with_nan),
            column('StateHoliday',
                   elements=sampled_from(['0', 'a', 'b', 'c', np.nan])),
            column('SchoolHoliday', elements=flags_with_nan),
        ],
        index=range_indexes(min_size=10, max_size=10000)))

    # Note that there are a lot of integer-valued columns in here; that's
    # what came out of the original dataframe. May need to revisit whether
    # it's better to code these as floats from the beginning.
    weather_events = [
        'Rain', 'Fog-Rain-Snow', 'Snow', 'Rain-Snow', 'Fog-Snow',
        'Rain-Thunderstorm', 'Rain-Snow-Hail', 'Fog-Rain', 'Fog',
        'Fog-Snow-Hail', 'Thunderstorm', 'Fog-Rain-Thunderstorm',
        'Rain-Snow-Hail-Thunderstorm', 'Fog-Rain-Hail', 'Rain-Hail',
        'Rain-Hail-Thunderstorm', 'Fog-Rain-Snow-Hail',
        'Fog-Thunderstorm', 'Rain-Snow-Thunderstorm',
        'Fog-Rain-Hail-Thunderstorm', 'Snow-Hail',
    ]
    weather_df = draw(data_frames([
        column('file', elements=sampled_from([np.nan] + state_names)),
        column('date', elements=dates_plus_nan),
        column('Max_TemperatureC', elements=integers_plus_nan),
        column('Mean_TemperatureC', elements=integers_plus_nan),
        column('Min_TemperatureC', elements=integers_plus_nan),
        column('Dew_PointC', elements=integers_plus_nan),
        column('MeanDew_PointC', elements=integers_plus_nan),
        column('Min_DewpointC', elements=integers_plus_nan),
        column('Max_Humidity', elements=integers_plus_nan),
        column('Mean_Humidity', elements=integers_plus_nan),
        column('Min_Humidity', elements=integers_plus_nan),
        column('Max_Sea_Level_PressurehPa', elements=integers_plus_nan),
        column('Mean_Sea_Level_PressurehPa', elements=integers_plus_nan),
        column('Min_Sea_Level_PressurehPa', elements=integers_plus_nan),
        column('Max_VisibilityKm', elements=floats_with_nan),
        column('Mean_VisibilityKm', elements=floats_with_nan),
        column('Min_VisibilitykM', elements=floats_with_nan),
        column('Max_Wind_SpeedKm_h', elements=integers_plus_nan),
        column('Mean_Wind_SpeedKm_h', elements=integers_plus_nan),
        column('Max_Gust_SpeedKm_h', elements=floats_with_nan),
        column('Precipitationmm', elements=floats_with_nan),
        column('CloudCover',
               elements=sampled_from(['NA', np.nan] +
                                     [str(i) for i in range(0, 9)])),
        column('Events', elements=sampled_from([np.nan] + weather_events)),
        column('WindDirDegrees', elements=integers_plus_nan),
    ]))
    # NOTE(review): no `return` follows in the visible excerpt even though
    # the signature promises Dict[str, pd.DataFrame]; the dict keys are not
    # visible here, so none has been invented -- confirm with the full file.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment