Skip to content

Instantly share code, notes, and snippets.

@jeffmylife
Created January 4, 2024 21:20
Show Gist options
  • Save jeffmylife/92b7b0f7aed77c3686876b79beea316f to your computer and use it in GitHub Desktop.
Save jeffmylife/92b7b0f7aed77c3686876b79beea316f to your computer and use it in GitHub Desktop.
Sampling without replacement to achieve the most balanced representation across skewed distributions
import pandas as pd
def sample_dataframe(df, date_column, N):
    """
    Samples N distinct rows of a pandas DataFrame, balanced across date groups.

    Date groups smaller than the fair share (N // number_of_groups) are kept in
    full. The remaining budget is split evenly between the larger groups, each
    of which is sampled uniformly without replacement. Integer division (and
    large groups that are smaller than their allotment) can leave the total
    short of N; the shortfall is drawn at random from rows that have NOT been
    selected yet, so no row ever appears twice in the result.
    The final sample is shuffled to randomize the order.

    Parameters:
        df (pd.DataFrame): The DataFrame to be sampled.
        date_column (str): The column name in df that contains date values.
        N (int): The desired total size of the sample.

    Returns:
        pd.DataFrame: N distinct rows of df in random order, with a fresh
        0..N-1 index (the original index is not preserved).

    Raises:
        ValueError: If N is negative or larger than the number of rows in df
            (sampling is without replacement, so N rows must exist).
    """
    if N < 0:
        raise ValueError(f"N must be non-negative, got {N}")
    if N > len(df):
        raise ValueError(
            f"Cannot sample {N} rows without replacement from a DataFrame "
            f"with {len(df)} rows"
        )

    # Work on a positional index: it uniquely identifies every row (even when
    # df's own index has duplicate labels), which is what lets the top-up step
    # below exclude rows that were already selected.
    work = df.reset_index(drop=True)

    # Group by date and calculate group sizes
    grouped = work.groupby(date_column)
    group_sizes = grouped.size()

    # max(1, ...) guards the divisions when there are no groups at all
    # (empty frame, or every key missing).
    fair_share = N // max(1, len(group_sizes))
    is_small = group_sizes < fair_share

    # Budget left for the larger groups once every small group is kept whole.
    remaining_sample_size = N - int(group_sizes[is_small].sum())
    large_group_sample_size = remaining_sample_size // max(
        1, int((~is_small).sum())
    )

    # Small groups are taken whole; larger ones are sampled without
    # replacement, capped at their own size.
    parts = [
        group
        if len(group) < fair_share
        else group.sample(n=min(large_group_sample_size, len(group)))
        for _, group in grouped
    ]
    sampled = pd.concat(parts) if parts else work.iloc[:0]

    # The balanced pass can never exceed N (small groups plus an even split of
    # the remainder), so only a shortfall needs handling. Top up strictly from
    # the rows not chosen yet to keep the sample free of duplicates.
    shortfall = N - len(sampled)
    if shortfall > 0:
        leftovers = work.drop(index=sampled.index)
        sampled = pd.concat([sampled, leftovers.sample(n=shortfall)])

    # Shuffle so rows are not ordered by group.
    return sampled.sample(frac=1).reset_index(drop=True)
# Mock DataFrame creation
def create_mock_dataframe():
    """
    Builds a small, heavily skewed demo DataFrame.
    The frame has a single "date" column holding 102 rows: 50, 30 and 20 rows
    for three consecutive days in January 2023, followed by 2 rows for
    2024-01-01.
    Returns:
    pd.DataFrame: The mock frame, rows ordered by date.
    """
    rows_per_date = (
        ("2023-01-01", 50),
        ("2023-01-02", 30),
        ("2023-01-03", 20),
        ("2024-01-01", 2),
    )
    # Repeat each date string as many times as its row count, in table order.
    dates = [day for day, count in rows_per_date for _ in range(count)]
    return pd.DataFrame({"date": dates})
# Test function
def test_sample_dataframe():
    """
    Demonstrates sample_dataframe on the mock data and checks the sample size.
    Prints the per-date row counts before and after drawing a sample of 60 rows,
    then asserts that exactly 60 rows came back.
    """
    target_size = 60
    frame = create_mock_dataframe()

    print("Value counts before sampling:")
    counts_before = frame["date"].value_counts()
    print(counts_before)

    result = sample_dataframe(frame, "date", target_size)

    print("\nValue counts after sampling:")
    counts_after = result["date"].value_counts()
    print(counts_after)

    # Only the total size is verified here; the balance is shown by the printout.
    assert len(result) == target_size, "Sample size does not match the desired size"
    print("\nTest passed: The sample is correctly sized and balanced.")
# Run the test. NOTE: this call sits at module level, so it executes (and
# prints its output) whenever this file is run or imported.
test_sample_dataframe()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment