Created
January 4, 2024 21:20
-
-
Save jeffmylife/92b7b0f7aed77c3686876b79beea316f to your computer and use it in GitHub Desktop.
Sampling without replacement to achieve the most balanced representation across skewed distributions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
def sample_dataframe(df, date_column, N):
    """
    Draw up to N rows from df without replacement, balanced across date groups.

    Date groups whose row count is below the per-group fair share
    (N // number of groups) are kept in full. The rest of the budget is split
    evenly between the remaining groups, each contributing at most its own
    size. If flooring or capped groups leave the sample short of N, the
    shortfall is topped up at random from rows that have NOT been selected
    yet, so no row of df can appear twice in the result. The final sample is
    shuffled and given a fresh 0..len-1 index.

    Rows whose date is missing are not part of any group (pandas groupby
    drops NaN keys by default) and can only enter through the top-up step.

    Parameters:
    df (pd.DataFrame): The DataFrame to be sampled. It is not modified.
    date_column (str): The column name in df that contains date values.
    N (int): The desired total size of the sample; must be non-negative.

    Returns:
    pd.DataFrame: A shuffled DataFrame with min(N, len(df)) rows. Fewer than
    N rows are returned only when df itself has fewer than N rows.

    Raises:
    ValueError: If N is negative.
    """
    if N < 0:
        raise ValueError("N must be non-negative")

    # Work on a copy with a unique 0..n-1 index so selected rows can be
    # tracked by label even when df's own index contains duplicates. The
    # original index is discarded in the result anyway.
    pool = df.reset_index(drop=True)
    grouped = pool.groupby(date_column)
    group_sizes = grouped.size()
    n_groups = len(group_sizes)

    chosen = []  # labels (== positions) of the rows selected from pool
    if n_groups:  # guards the divisions below against an empty frame
        fair_share = N // n_groups
        small_sizes = group_sizes[group_sizes < fair_share]
        n_large = n_groups - len(small_sizes)
        budget = N - int(small_sizes.sum())
        # Even split of what is left after the small groups are kept whole.
        per_large = budget // max(1, n_large)

        for _, group in grouped:
            if len(group) >= fair_share:
                group = group.sample(n=min(per_large, len(group)))
            chosen.extend(group.index)

    # The small groups plus the floored even split never exceed N, so the
    # sample can only be short here, never too long.
    shortfall = N - len(chosen)
    if shortfall > 0:
        leftover = pool.drop(index=chosen)
        take = min(shortfall, len(leftover))
        if take:
            chosen.extend(leftover.sample(n=take).index)

    if not chosen:
        return pool.iloc[0:0]

    # Shuffle so the rows are not ordered by date group.
    return pool.iloc[chosen].sample(frac=1).reset_index(drop=True)
# Mock DataFrame creation
def create_mock_dataframe():
    """
    Build a small demo DataFrame with a deliberately skewed "date" column.

    Returns:
    pd.DataFrame: 102 rows with a single "date" column of date strings:
    50, 30 and 20 rows for three consecutive days in January 2023, plus
    2 rows for 2024-01-01.
    """
    rows_per_date = (
        ("2023-01-01", 50),
        ("2023-01-02", 30),
        ("2023-01-03", 20),
        ("2024-01-01", 2),
    )
    # Repeat each date string as many times as its configured row count.
    dates = [day for day, count in rows_per_date for _ in range(count)]
    return pd.DataFrame({"date": dates})
# Test function
def test_sample_dataframe():
    """
    Demo/smoke test: sample the mock DataFrame and verify the sample size.

    Prints the per-date row counts before and after sampling, then asserts
    that the sample has exactly the requested number of rows. Only the size
    is asserted; the balance of the sample is shown in the printed counts.
    """
    target_size = 60  # Let's take a sample of size 60
    frame = create_mock_dataframe()

    print("Value counts before sampling:")
    print(frame["date"].value_counts())

    result = sample_dataframe(frame, "date", target_size)

    print("\nValue counts after sampling:")
    print(result["date"].value_counts())

    # Check that the total sample size matches the request
    assert len(result) == target_size, "Sample size does not match the desired size"
    print("\nTest passed: The sample is correctly sized and balanced.")
# Run the demo test only when this file is executed as a script, so that
# importing sample_dataframe from another module does not trigger it.
if __name__ == "__main__":
    test_sample_dataframe()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment