@bmritz
Last active July 20, 2021 00:45
A way to sample a population to desired characteristics
import pandas as pd
import math, random, copy
CASEID = 'respid'
PATH = '/Users/bridgetlittleton/Box Sync/WWB/Survey Data - CONFIDENTIAL/WWB_Main_Survey_Data_20210629.xls'
profpos_mapping = {
    'Research scientist or similar [INDIA/UK/USA]': 'j',
    'Masters / Ph.D. student [ITALY]': 'g',
    'Masters / Ph.D. student [INDIA/UK/USA]': 'g',
    'Assistant Professor / Lecturer [INDIA/UK/USA]': 'j',
    'Postdoctoral fellow [INDIA/UK/USA]': 'g',
    'Research fellow / post-doc [ITALY]': 'g',
    'Other (Please specify)': 'other',
    'Associate Professor / Senior Lecturer or Reader [INDIA/UK/USA]': 's',
    'Full Professor [INDIA/UK/USA]': 's',
    'Research Professor [INDIA/UK/USA]': 's',
    'Adjunct professor / Professore a contratto [ITALY]': 'j',
    'Non-tenured assistant professor / Ricercatore R. T.D. A [ITALY]': 'j',
    'Associate professor [ITALY]': 's',
    'Professor with named chair [INDIA/UK/USA]': 's',
    'Tenure-track assistant professor / Ricercatore R. T.D. B [ITALY]': 'j',
    'Full, tenured professor [ITALY]': 's',
}
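# The codes above appear to group career level into graduate students / postdocs ('g'),
# junior faculty ('j'), and senior faculty ('s'), plus 'other'.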
well_being_mapping = {
    str(v): 'HIGH' if v >= 8 else ("MEDIUM" if v >= 4 else "LOW")
    for v in range(0, 11)
}
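# The mapping above bins the 0-10 well-being scores: '0'-'3' -> "LOW",
# '4'-'7' -> "MEDIUM", '8'-'10' -> "HIGH".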
job_satisfaction_mapping = {
    'Mostly satisfied': "HIGH",
    'Neither satisfied nor dissatisfied': "NEITHER",
    'Completely satisfied': "HIGH",
    'Mostly dissatisfied': "LOW",
    'Completely dissatisfied': "LOW",
    'I prefer not to answer': "PREFER NOT TO ANSWER",
}
roleaes_mapping = {
    'Mostly agree': "Mostly/Completely Agree",
    'Completely agree': "Mostly/Completely Agree",
    'Neither agree nor disagree': "OTHER",
    'Mostly disagree': "Mostly/Completely Disagree",
    'Completely disagree': "Mostly/Completely Disagree",
    'I prefer not to answer': 'OTHER',
    'Don’t know / Not sure': "OTHER",
}
religion_mapping = {
    'I don’t follow a religion & don’t consider myself to be a spiritual person inter': "Neither Spiritual nor Religious",
    'I follow a religion, but don’t consider myself to be a spiritual person interest': "Spiritual or Religious",
    'I follow a religion & consider myself to be a spiritual person interested in the': "Spiritual or Religious",
    'I don’t follow a religion, but consider myself to be a spiritual person interest': "Spiritual or Religious",
    'I prefer not to answer': "Neither Spiritual nor Religious",
}
def filter_data_to_usable_rows(df):
    msk1 = df['perm'] == 'Yes'
    msk2 = df['status'] == 'Complete'
    return df[msk1 & msk2]
def create_binned_columns(df):
    df['career_level_grp'] = df.profpos.map(profpos_mapping)
    df['life_satisfaction'] = df.hflifesat_1.map(well_being_mapping)
    df['mental_health'] = df.hfmental_1.map(well_being_mapping)
    df['job_satisfaction'] = df.jobsatis.map(job_satisfaction_mapping)
    df['roleaes_5_binned'] = df.roleaes_5.map(roleaes_mapping)
    df['roleaes_3_binned'] = df.roleaes_3.map(roleaes_mapping)
    df['religious_or_spiritual'] = df.sprtlty.map(religion_mapping)
    df['mistreatment'] = (
        df.loc[:, ['mistr_{}'.format(i) for i in range(1, 8)]] == 'Yes, I have experienced this'
    ).any(axis=1).astype(str)
    return df
def import_data(path, caseid_column):
    """Import the data from an Excel file.
    Inputs:
        path (str): Path to the data
        caseid_column (str): column name of the caseid column
    """
    df = pd.read_excel(path, dtype=str)
    df = create_binned_columns(df)
    assert df[caseid_column].nunique() == len(df)
    return filter_data_to_usable_rows(df).to_dict(orient="records")
data = import_data(PATH, CASEID)
target_distributions = {
    's_country': [('India', 0.25), ('UK', 0.25), ('USA', 0.25), ('Italy', 0.25)],
    'S_Field': [('Physics', 0.5), ('Biology', 0.5)],
}
def get_sample_for_col(col):
    # Assumes module-level DataFrames `df` (the full population) and `df_sample`
    # (the selected sample) exist; compares the sample's share of each value
    # in `col` against the population's share.
    vc_sample = df_sample[col].value_counts()
    vc_sample = vc_sample / vc_sample.sum()
    vc_sample = vc_sample.rename(col + "_sample")
    vc = df[col].value_counts()
    vc = vc / vc.sum()
    return pd.concat([vc, vc_sample], axis=1)
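# A minimal sketch of how the `df` / `df_sample` globals expected by
# get_sample_for_col might be built; `sample_case_ids` is a hypothetical name
# for a list of chosen ids (e.g. the output of SamplePool.calculate_sample):
# df = pd.DataFrame(data)
# df_sample = df[df[CASEID].isin(sample_case_ids)]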
MINIMUMS = {
    "career_level_grp": pd.Series({'g': 40, 's': 40, 'j': 40}),
    "S_Gender": pd.Series({'F': 50}),
    "coveff_11": pd.Series({"I was infected with COVID-19": 50}),
    'religious_or_spiritual': pd.Series({"Spiritual or Religious": 45}),
    "beautywork_1": pd.Series({"My workplace": 50}),
    'roleaes_3_binned': pd.Series({"Mostly/Completely Disagree": 50}),
    'roleaes_5_binned': pd.Series({"Mostly/Completely Agree": 50}),
    'bcon_1': pd.Series({"motivated me to pursue a scientific career": 15}),
    'bcon_2': pd.Series({"has been life-changing for me": 15}),
    'bcon_3': pd.Series({"helps me persevere when I experience difficulties or failure in my work": 15}),
    'bcon_4': pd.Series({"improves scientific understanding": 15}),
    'job_satisfaction': pd.Series({"HIGH": 50, "LOW": 50}),
    'life_satisfaction': pd.Series({"HIGH": 50, "LOW": 50}),
    'mental_health': pd.Series({"HIGH": 50, "LOW": 50}),
    'mistreatment': pd.Series({"True": 50}),
}
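# MINIMUMS are absolute respondent counts (e.g. at least 40 'g' career-level
# respondents), whereas target_distributions above are proportions of n_samples;
# SamplePool._process_target_distributions below combines the two into
# per-value target counts.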
class SamplePool(object):
    def __init__(self, target_distributions, n_samples, minimums=None, id_col=CASEID):
        """The sample pool of caseids.
        Inputs:
            target_distributions (dict): {colname: [(value, pct), ...]}
            n_samples (int): total number of cases to sample
            minimums (dict): {colname: pd.Series of minimum counts per value}
            id_col (str): name of the unique-id column in the population records
        """
        self.case_ids = list()
        self._id_col = id_col
        self.n_samples = n_samples
        self.minimums = minimums if minimums is not None else {}
        self.target_n = self._process_target_distributions(target_distributions)
    def _process_target_distributions(self, target_distributions):
        """Calculate the number of samples we want for each value and validate that the pcts sum to 1."""
        ret = copy.deepcopy(self.minimums)
        for k, v in target_distributions.items():
            assert math.isclose(sum(tup[1] for tup in v), 1)
            ret.update({
                k: pd.Series({tup[0]: tup[1] * self.n_samples for tup in v})
            })
        return ret
    def calculate_distributions(self, population, fields=None):
        """Calculate the distributions of the sample pool vs the population."""
        fields = fields if fields is not None else set(list(self.target_n.keys()) + list(self.minimums.keys()))
        pop_ids = [d[self._id_col] for d in population]
        assert len(pop_ids) == len(set(pop_ids)), "Population IDs are not unique."
        assert all(id_ in pop_ids for id_ in self.case_ids), "At least 1 sample id not in population."
        df = pd.DataFrame(population)
        msk = df[self._id_col].isin(self.case_ids)
        df = df[msk]
        return {field: df[field].value_counts().reindex_like(self.target_n[field]).fillna(0)
                for field in fields}
    def distance_from_desired(self, population):
        """Return the outstanding number of values still needed to meet our target_n."""
        distributions = self.calculate_distributions(population)
        return {field: (ser - distributions[field]) for field, ser in self.target_n.items()}
    def values_to_be_filled(self, population):
        """Return a list of (colname, value, n_still_needed) tuples for values that still need to be filled."""
        ret = []
        for field, ser in self.distance_from_desired(population).items():
            vals = ser.where(lambda x: x > 0).dropna()
            for v in vals.index.values:
                ret.append((field, v, vals[v]))
        return ret
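    # Illustrative (hypothetical) return value of values_to_be_filled:
    #   [('s_country', 'India', 12.0), ('mistreatment', 'True', 3.0)]
    # i.e. for each column/value pair, the count still needed to reach target_n.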
    def pick_next_caseid(self, population):
        vals_to_be_filled = self.values_to_be_filled(population)
        accum = []
        for row in population:
            caseid = row[self._id_col]
            if caseid in self.case_ids:
                continue
            n_matches = 0  # score: sum of outstanding counts for the gaps this row fills
            for tup in row.items():
                for t in vals_to_be_filled:
                    if tup == t[:2]:
                        n_matches += t[2]
            accum.append((caseid, n_matches))
        best = sorted(accum, key=lambda tup: (-tup[1], random.random()))[0]
        if best[1] > 0:
            return best[0]
        raise ValueError("No more cases will fill gaps.")
    def print_distributions(self, population):
        for col, ser in self.calculate_distributions(population).items():
            print(col + ":")
            print()
            print(ser.to_string())
            print('----------------------')
    def calculate_sample(self, population):
        """Calculate the case_ids that should be in the sample."""
        n_samples_picked = len(self.case_ids)
        while n_samples_picked < self.n_samples:
            best_case_to_add = self.pick_next_caseid(population)
            self.case_ids.append(best_case_to_add)
            n_samples_picked = len(self.case_ids)
            if n_samples_picked % 10 == 0:
                print("{} samples complete.".format(n_samples_picked))
                self.print_distributions(population)
        return self.case_ids, self.calculate_distributions(population)
def test_sample_pool():
    data = [{'caseid': 1, 'country': 'India', 'sex': 'M'},
            {'caseid': 2, 'country': 'UK', 'sex': 'M'},
            {'caseid': 3, 'country': 'UK', 'sex': 'M'},
            {'caseid': 4, 'country': 'US', 'sex': 'F'},
            {'caseid': 5, 'country': 'US', 'sex': 'F'},
            {'caseid': 6, 'country': 'US', 'sex': 'M'},
            {'caseid': 7, 'country': 'US', 'sex': 'F'},
            {'caseid': 8, 'country': 'UK', 'sex': 'F'},
            {'caseid': 9, 'country': 'UK', 'sex': 'M'},
            {'caseid': 10, 'country': 'Italy', 'sex': 'M'},
            {'caseid': 11, 'country': 'Italy', 'sex': 'M'},
            {'caseid': 12, 'country': 'Italy', 'sex': 'F'}]
    target_distributions = {"country": [(k, .25) for k in ['India', 'UK', 'US', 'Italy']],
                            "sex": [("M", .5), ("F", .5)]}
    n_samples = 6
    p = SamplePool(target_distributions, n_samples, id_col='caseid')
    assert (p.target_n['country'] == 1.5).all()
    assert (p.target_n['sex'] == 3).all()
    assert all((p.calculate_distributions(data)[field] == 0).all() for field in target_distributions)
    assert p.values_to_be_filled(data) == [('country', 'India', 1.5),
                                           ('country', 'UK', 1.5),
                                           ('country', 'US', 1.5),
                                           ('country', 'Italy', 1.5),
                                           ('sex', 'M', 3.0),
                                           ('sex', 'F', 3.0)]
    # Every row fills one country gap and one sex gap, so ties are broken randomly
    # and any caseid can come back first.
    assert p.pick_next_caseid(data) in {d['caseid'] for d in data}
def test_sample_pool2():
    data = [{'caseid': 1, 'country': 'India', 'sex': 'M'},
            {'caseid': 2, 'country': 'UK', 'sex': 'M'},
            {'caseid': 3, 'country': 'UK', 'sex': 'M'},
            {'caseid': 4, 'country': 'US', 'sex': 'F'},
            {'caseid': 5, 'country': 'US', 'sex': 'F'},
            {'caseid': 6, 'country': 'US', 'sex': 'M'},
            {'caseid': 7, 'country': 'US', 'sex': 'F'},
            {'caseid': 8, 'country': 'UK', 'sex': 'F'},
            {'caseid': 9, 'country': 'UK', 'sex': 'M'},
            {'caseid': 10, 'country': 'Italy', 'sex': 'M'},
            {'caseid': 11, 'country': 'Italy', 'sex': 'M'},
            {'caseid': 12, 'country': 'India', 'sex': 'F'}]
    target_distributions = {"country": [(k, .25) for k in ['India', 'UK', 'US', 'Italy']],
                            "sex": [("M", .75), ("F", .25)]}
    n_samples = 8
    p = SamplePool(target_distributions, n_samples, id_col='caseid')
    assert (p.target_n['country'] == 2).all()
    assert p.target_n['sex']['M'] == 6
    assert p.target_n['sex']['F'] == 2
    case_ids, dists = p.calculate_sample(data)
    assert (dists['country'] == 2).all()
    assert (dists['sex']["M"] == 6)
    assert (dists['sex']["F"] == 2)
    # Potential strategies to mitigate the situation where the greedy picker misses the targets:
    #   - try again: give it a hint by adding a few case ids at the beginning and let it run (below)
    #   - pick smarter: take into account the population's distance from target_n to pick "scarce" values first
    p.case_ids = [1, 12]
    p.calculate_distributions(data)
    p.calculate_sample(data)
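# A minimal sketch of how the pool might be run against the survey data loaded
# above; n_samples=200 is an assumed figure chosen only for illustration:
# pool = SamplePool(target_distributions, n_samples=200, minimums=MINIMUMS)
# sampled_ids, final_dists = pool.calculate_sample(data)
# pool.print_distributions(data)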