Skip to content

Instantly share code, notes, and snippets.

@ckhung
Created May 15, 2022
Embed
What would you like to do?
juxtaposed "binary" Marimekko charts, good for visualizing Simpson's paradox involving some "hit" ratios
year Derek Derek_hit David David_hit
1995 48 12 411 104
1996 582 183 140 45
# juxtaposed "binary" Marimekko charts
# https://towardsdatascience.com/marimekko-charts-with-pythons-matplotlib-6b9784ae73a1
# good for visualizing Simpson's paradox involving some "hit" ratios:
# https://towardsdatascience.com/simpsons-paradox-d2f4d8f08d42
# https://towardsdatascience.com/what-is-simpsons-paradox-4a53cd4e9ee2
# python3 jxbmm.py -N Derek,David -T 1,3 -H 2,4 batting.csv
# python3 jxbmm.py -N Male,Female -T 1,4 -H 3,6 ucb.csv
# python3 jxbmm.py -N 'Treatment,No Treatment' -T 1,3 -H 2,4 treatment.csv
# https://matplotlib.org/stable/gallery/lines_bars_and_markers/horizontal_barchart_distribution.html
import numpy as np
import matplotlib.pyplot as plt
import argparse, sys
def Marimekko(ax, seg_name, seg_total, seg_hit):
    """Draw one horizontal "binary" Marimekko chart on *ax*.

    Segments are stacked bottom-to-top. Each bar's height is its total
    count, so bar areas are proportional to raw counts. Each bar is split
    at its hit percentage: the left ("hit") part uses the default colour
    and the right ("miss") part is pale yellow.

    Args:
        ax: matplotlib Axes to draw on.
        seg_name: one label per segment, used as the y tick labels.
        seg_total: total count of each segment (bar heights).
        seg_hit: hit count of each segment; same length as seg_total.
    """
    totals = np.asarray(seg_total, dtype=float)
    hits = np.asarray(seg_hit, dtype=float)
    # Each bar starts where the previous one ends (exclusive prefix sum).
    bottoms = np.cumsum(totals) - totals
    # A segment with zero total would divide by zero; show it as 0% hit.
    safe_totals = np.where(totals > 0, totals, 1.0)
    hit_percent = np.where(totals > 0, hits / safe_totals * 100, 0.0)
    ax.barh(bottoms, hit_percent, height=totals, align='edge', edgecolor='black')
    ax.barh(bottoms, 100 - hit_percent, left=hit_percent, height=totals,
            align='edge', edgecolor='black', color=(1, 1, 0.8))
    # Put each label at the vertical middle of its bar.
    ax.set_yticks(bottoms + totals / 2)
    ax.set_yticklabels(seg_name)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.xaxis.grid(True)
# Command-line interface: which CSV to read and which columns hold what.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description='multiple Marimekko chart',
)
# 0-based column index of the segment (bar) labels.
parser.add_argument(
    '-s', '--segment',
    default=0, type=int,
    help='column for bar names',
)
# Comma-separated display names, one per entity (one subplot each).
parser.add_argument(
    '-N', '--entity_name',
    default='', type=str,
    help='entity names separated by comma',
)
# Comma-separated column indices holding each entity's segment totals.
parser.add_argument(
    '-T', '--seg_total',
    default='', type=str,
    help='columns for segment totals, entities separated by comma',
)
# Comma-separated column indices holding each entity's segment hits.
parser.add_argument(
    '-H', '--seg_hit',
    default='', type=str,
    help='columns for segment hits, entities separated by comma',
)
parser.add_argument('csvfile', help='xxx.csv')
args = parser.parse_args()
# Load the CSV and draw one Marimekko chart per entity, stacked vertically.
fn = args.csvfile
# names=True: the first row is a header. dtype=None / encoding=None let numpy
# infer a type per column, so the label column is read as text.
data = np.genfromtxt(fn, delimiter=',', names=True, dtype=None, encoding=None)
# A CSV with exactly one data row yields a 0-d structured array, which has no
# len(); promote it so the per-row code below works for any row count.
data = np.atleast_1d(data)
if len(args.seg_total) < 1:
    sys.exit('need --seg_total')
i_seg_total = [int(k) for k in args.seg_total.split(',')]
if len(args.seg_hit) < 1:
    sys.exit('need --seg_hit')
i_seg_hit = [int(k) for k in args.seg_hit.split(',')]
entity_name = args.entity_name.split(',')
Nentity = len(entity_name)
if not (len(i_seg_hit) == Nentity and len(i_seg_total) == Nentity):
    sys.exit('# of entities in --entity_name and --seg_total and --seg_hit must match')
Nseg = len(data)
seg_name = [data[r][args.segment] for r in range(Nseg)]
# Per-entity column values, kept as lists indexed by entity position so that
# repeated entity names cannot overwrite each other when plotting.
seg_total_by_entity = [[data[r][col] for r in range(Nseg)] for col in i_seg_total]
seg_hit_by_entity = [[data[r][col] for r in range(Nseg)] for col in i_seg_hit]
SegTotal = dict(zip(entity_name, seg_total_by_entity))
SegHit = dict(zip(entity_name, seg_hit_by_entity))
print('Segment Total: ', SegTotal)
print('Segment Hit: ', SegHit)
entity_total = [sum(x) for x in seg_total_by_entity]
entity_hit = [sum(x) for x in seg_hit_by_entity]
# All subplots share this y range so bar heights are comparable across entities.
entity_limit = max(entity_total)
# squeeze=False always returns a 2-D array of Axes. With a single entity the
# default would return a bare Axes object, and ax[g] would fail.
fig, ax = plt.subplots(Nentity, sharex='all', squeeze=False)
ax = ax[:, 0]
plt.subplots_adjust(hspace=0.5)
for g in range(Nentity):
    total, hit = entity_total[g], entity_hit[g]
    ratio = hit / total * 100 if total else 0
    ax[g].title.set_text('{}: overall {}/{} = {:.0f}%'.format(entity_name[g], hit, total, ratio))
    ax[g].set_ylim(0, entity_limit)
    Marimekko(ax[g], seg_name, seg_total_by_entity[g], seg_hit_by_entity[g])
plt.show()
gender T Ti N Ni
M 50 45 350 280
F 150 90 50 20
Dept M Mr Ma F Fr Fa
A 825 62 512 108 82 89
B 560 63 353 25 68 17
C 325 37 120 593 34 202
D 417 33 138 375 35 131
E 191 28 53 393 24 94
F 373 6 22 341 7 24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment