Skip to content

Instantly share code, notes, and snippets.

@guillefix
Created March 3, 2020 23:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save guillefix/a9b71f2551e1a2240dc206db9a2730b7 to your computer and use it in GitHub Desktop.
Save guillefix/a9b71f2551e1a2240dc206db9a2730b7 to your computer and use it in GitHub Desktop.
Some useful plots for generalization-error / complexity data, made with Plotly.
import plotly.plotly as py
from plotly.graph_objs import *
def nice_2dhist(x,y,nbins,title='title',xlabel='x',ylabel='y',filename='nice-hist.png'):
    """Save a scatter plot over a filled 2-D density contour, with marginal histograms.

    The main panel shows the raw (x, y) points as small white markers on top
    of a filled ``histogram2dcontour``.  A histogram of ``x`` is drawn above
    the main panel (against axis ``y2``) and a histogram of ``y`` to its right
    (against axis ``x2``).  The figure is rendered through the Plotly cloud
    image API (``py.image.save_as``) and written to ``filename``.

    Args:
        x, y: Equal-length, non-empty numeric sequences (``min``/``max`` are
            taken over them).
        nbins: Divisor for the bin width, ``(max - min) / nbins``, used by the
            density and both marginal histograms.
        title: Figure title.
        xlabel, ylabel: Titles of the main panel's axes.
        filename: Output image path.
    """
    # NOTE(review): Plotly credentials are hard-coded in source; consider
    # loading them from configuration or the environment.
    py.sign_in(username='guillefix', api_key='mflFpFhvUHtAfpGXbKkv')

    x_lo, x_hi = min(x), max(x)
    y_lo, y_hi = min(y), max(y)

    def bin_spec(lo, hi):
        # A fresh dict per trace: equal-width bins covering [lo, hi].
        return {"start": lo, "end": hi, "size": (hi - lo) / nbins}

    def axis_spec(domain, lo, hi, label):
        return {
            "autorange": True,
            "domain": domain,
            "range": [lo, hi],
            "showgrid": False,
            "title": label,
            "type": "linear",
            "zeroline": False,
        }

    contour_colorscale = [
        [0, "rgb(8, 29, 88)"],
        [0.125, "rgb(37, 52, 148)"],
        [0.25, "rgb(34, 94, 168)"],
        [0.375, "rgb(29, 145, 192)"],
        [0.5, "rgb(65, 182, 196)"],
        [0.625, "rgb(127, 205, 187)"],
        [0.75, "rgb(199, 233, 180)"],
        [0.875, "rgb(237, 248, 217)"],
        [1, "rgb(255, 255, 217)"],
    ]

    points = {
        "type": "scatter",
        "uid": "eb94b3",
        "name": "points",
        "mode": "markers",
        "x": x,
        "y": y,
        "opacity": 0.75,
        "text": [],
        "marker": {
            "color": "rgb(255, 255, 255)",
            "line": {"width": 0.5},
            "opacity": 0.4,
            "size": 4,
        },
    }
    density = {
        "type": "histogram2dcontour",
        "uid": "b20cd7",
        "name": "density",
        "x": x,
        "y": y,
        "autocolorscale": False,
        "colorscale": contour_colorscale,
        "reversescale": False,
        "showscale": False,
        "ncontours": 20,
        "contours": {
            "coloring": "fill",
            "start": 5,
            "end": 80.05,
            "size": 5,
            "showlines": True,
        },
        "xbins": bin_spec(x_lo, x_hi),
        "ybins": bin_spec(y_lo, y_hi),
        "zmin": 0,
        "zmax": 83,
    }
    # Marginal histogram of x, plotted against the top strip's axis y2.
    x_marginal = {
        "type": "histogram",
        "uid": "70efa7",
        "name": "x density",
        "x": x,
        "marker": {"color": "rgb(31, 119, 180)"},
        "xbins": bin_spec(x_lo, x_hi),
        "yaxis": "y2",
    }
    # Marginal histogram of y, plotted against the right strip's axis x2.
    y_marginal = {
        "type": "histogram",
        "uid": "73ca31",
        "name": "y density",
        "y": y,
        "marker": {"color": "rgb(33, 113, 181)"},
        "xaxis": "x2",
        "ybins": bin_spec(y_lo, y_hi),
    }

    layout = {
        "title": title,
        "autosize": False,
        "width": 800,
        "height": 700,
        "margin": {"t": 50},
        "bargap": 0,
        "hovermode": "closest",
        "showlegend": False,
        "paper_bgcolor": "rgb(249, 249, 249)",
        "plot_bgcolor": "rgb(249, 249, 249)",
        "xaxis": axis_spec([0, 0.85], x_lo, x_hi, xlabel),
        "xaxis2": axis_spec([0.85, 1], x_lo, x_hi, ""),
        "yaxis": axis_spec([0, 0.85], y_lo, y_hi, ylabel),
        "yaxis2": axis_spec([0.85, 1], y_lo, y_hi, ""),
    }

    figure = Figure(data=Data([points, density, x_marginal, y_marginal]), layout=layout)
    py.image.save_as(figure, filename=filename)
# import matplotlib.pyplot as plt
#
# %matplotlib inline
#
# plt.scatter(final_LZs[0], gen_errors[0])
#
# %matplotlib
# plt.clf()
#
# def forceAspect(ax,aspect=1):
# im = ax.get_images()
# extent = im[0].get_extent()
# ax.set_aspect(abs((extent[1]-extent[0])/(extent[3]-extent[2]))/aspect)
#
#
# idx=11
# heatmap, xedges, yedges = np.histogram2d(final_entss[idx], gen_errors[idx], bins=20)
# extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
#
# fig = plt.figure()
# ax = fig.add_subplot(111)
#
# # plt.ylim(yedges[0], yedges[-1])
# # plt.figure(figsize=(3,3))
# ax.imshow(heatmap.T, extent=extent, origin='lower')
# forceAspect(ax,aspect=1)
# fig.show()
#
# from matplotlib import cm as CM
#
# idx=3
# plt.hexbin(final_entss[idx], gen_errors[idx],gridsize=20,cmap=CM.jet, bins=None)
# plt.show()
# x=final_entss[idx]
# y=gen_errors[idx]
# nbins=10
def nice_2dhist_double(xa,ya,xb,yb,nbins,title='title',xlabel='x',ylabel='y',filename='nice-hist.png'):
    """Save a two-population scatter plot over a pooled density, with marginal histograms.

    Population *a* ("Neural network", blue circles) and population *b*
    ("Unbiased learner", red diamonds) are drawn over a grey filled
    ``histogram2dcontour`` of the pooled points.  Per-population histograms
    of the x values sit above the main panel (axis ``y2``) and histograms of
    the y values to its right (axis ``x2``).  The figure is rendered through
    the Plotly cloud image API (``py.image.save_as``) and written to
    ``filename``.

    Args:
        xa, ya: x / y values of population a (equal length).
        xb, yb: x / y values of population b (equal length).  Lists, tuples
            or NumPy arrays are accepted; the pooled data must be non-empty.
        nbins: Divisor for the bin width, ``(max - min) / nbins`` of the
            pooled data, used by the density and all marginal histograms.
        title: Figure title.
        xlabel, ylabel: Titles of the main panel's axes.
        filename: Output image path.
    """
    # Pool both populations.  Convert to lists first so ``+`` concatenates:
    # with NumPy arrays ``xa + xb`` adds element-wise (or raises on a length
    # mismatch), which would silently corrupt the pooled density and ranges.
    x = list(xa) + list(xb)
    y = list(ya) + list(yb)

    # NOTE(review): Plotly credentials are hard-coded in source; consider
    # loading them from configuration or the environment.
    py.sign_in(username='guillefix4', api_key='Z9EakzS6cVPL0DN6SJWJ')
    fontsize = 26

    x_lo, x_hi = min(x), max(x)
    y_lo, y_hi = min(y), max(y)
    x_step = (x_hi - x_lo) / nbins
    y_step = (y_hi - y_lo) / nbins

    def scatter_trace(xs, ys, name, symbol, color, edge_width):
        return {
            "x": xs,
            "y": ys,
            "marker": {
                "symbol": symbol,
                "color": color,
                "line": {"width": edge_width},
                "opacity": 0.4,
                "size": 8,
            },
            "mode": "markers",
            "name": name,
            "opacity": 0.75,
            "text": [],
            "type": "scatter",
            "uid": "eb94b3",
        }

    def x_histogram(xs, name, color):
        # Marginal histogram of x values, drawn against the top strip's axis y2.
        return {
            "x": xs,
            "marker": {"color": color},
            "name": name,
            "type": "histogram",
            "showlegend": False,
            "uid": "70efa7",
            "xbins": {"end": x_hi, "size": x_step, "start": x_lo},
            "yaxis": "y2",
        }

    def y_histogram(ys, name, color):
        # Marginal histogram of y values, drawn against the right strip's axis x2.
        return {
            "y": ys,
            "marker": {"color": color},
            "name": name,
            "type": "histogram",
            "showlegend": False,
            "uid": "73ca31",
            "xaxis": "x2",
            "ybins": {"end": y_hi, "size": y_step, "start": y_lo},
        }

    points_a = scatter_trace(xa, ya, "Neural network", "circle", "rgb(31, 119, 180)", 1)
    points_b = scatter_trace(xb, yb, "Unbiased learner", "diamond", "red", 0.1)

    # Grey density of the pooled points; its bins extend past the data range
    # so the filled contours are not cut off at the edges.
    density = {
        "x": x,
        "y": y,
        "autocolorscale": False,
        "colorscale": 'Greys',
        "contours": {
            "coloring": "fill",
            "end": 100,
            "showlines": False,
            "size": 5,
            "start": 5,
        },
        "name": "density",
        "ncontours": 50,
        "reversescale": True,
        "showscale": False,
        "type": "histogram2dcontour",
        "uid": "b20cd7",
        "xbins": {"end": x_hi*1.1, "size": x_step, "start": x_lo - x_hi*0.1},
        "ybins": {"end": y_hi*1.1, "size": y_step, "start": y_lo - 0.05},
        "zmax": 50,
        "zmin": 0,
    }

    # Trace order fixes draw order and legend order.
    data = Data([
        points_a,
        x_histogram(xa, "Entropy histogram, NN", "rgb(31, 119, 180)"),
        y_histogram(ya, "Error histogram, NN", "rgb(33, 113, 181)"),
        points_b,
        density,
        x_histogram(xb, "Entropy histogram, unbiased", "red"),
        y_histogram(yb, "Error histogram, unbiased", "red"),
    ])

    layout = {
        "autosize": False,
        "bargap": 0,
        "height": 700,
        "hovermode": "closest",
        "margin": {"t": 40},
        "showlegend": True,
        "font": dict(size=fontsize, color='black'),
        "legend": {
            "x": 0.07,
            "y": 0.8,
            "bgcolor": '#E2E2E2',
            "bordercolor": '#FFFFFF',
            "borderwidth": 2,
        },
        "title": title,
        "width": 800,
        "xaxis": {
            "autorange": True,
            "domain": [0, 0.85],
            "range": [x_lo, x_hi],
            "showgrid": False,
            "title": xlabel,
            "type": "linear",
            "zeroline": False,
            "showline": True,
            "titlefont": dict(size=fontsize, color='black'),
        },
        "xaxis2": {
            "autorange": True,
            "domain": [0.85, 1],
            "range": [x_lo, x_hi],
            "showgrid": False,
            "showticklabels": False,
            "title": "",
            "type": "linear",
            "zeroline": True,
        },
        "yaxis": {
            "autorange": True,
            "domain": [0, 0.85],
            "range": [y_lo, y_hi],
            "showgrid": False,
            "title": ylabel,
            "type": "linear",
            "zeroline": False,
            "showline": True,
            "titlefont": dict(size=fontsize, color='black'),
        },
        "yaxis2": {
            "autorange": True,
            "domain": [0.85, 1],
            "range": [y_lo, y_hi],
            "showgrid": False,
            "showticklabels": False,
            "title": "",
            "type": "linear",
            "zeroline": True,
        },
    }

    fig = Figure(data=data, layout=layout)
    py.image.save_as(fig, filename=filename)
def nice_2dhist_double_tight(xa,ya,xb,yb,nbins,title='title',xlabel='x',ylabel='y',filename='nice-hist.png'):
    """Save a compact two-population scatter/density plot with marginal histograms.

    Same figure as the two-population histogram plot: population *a*
    ("Neural network", blue circles) and population *b* ("Unbiased learner",
    red diamonds) over a grey filled ``histogram2dcontour`` of the pooled
    points, with per-population x histograms above (axis ``y2``) and y
    histograms to the right (axis ``x2``).  This variant hides the legend,
    uses a larger font (30), a larger top margin, and a narrower right-hand
    strip.  The figure is rendered through the Plotly cloud image API
    (``py.image.save_as``) and written to ``filename``.

    Args:
        xa, ya: x / y values of population a (equal length).
        xb, yb: x / y values of population b (equal length).  Lists, tuples
            or NumPy arrays are accepted; the pooled data must be non-empty.
        nbins: Divisor for the bin width, ``(max - min) / nbins`` of the
            pooled data, used by the density and all marginal histograms.
        title: Figure title.
        xlabel, ylabel: Titles of the main panel's axes.
        filename: Output image path.
    """
    # Pool both populations.  Convert to lists first so ``+`` concatenates:
    # with NumPy arrays ``xa + xb`` adds element-wise (or raises on a length
    # mismatch), which would silently corrupt the pooled density and ranges.
    x = list(xa) + list(xb)
    y = list(ya) + list(yb)

    # NOTE(review): Plotly credentials are hard-coded in source; consider
    # loading them from configuration or the environment.
    py.sign_in(username='guillefix4', api_key='Z9EakzS6cVPL0DN6SJWJ')
    fontsize = 30

    x_lo, x_hi = min(x), max(x)
    y_lo, y_hi = min(y), max(y)
    x_step = (x_hi - x_lo) / nbins
    y_step = (y_hi - y_lo) / nbins

    def scatter_trace(xs, ys, name, symbol, color, edge_width):
        return {
            "x": xs,
            "y": ys,
            "marker": {
                "symbol": symbol,
                "color": color,
                "line": {"width": edge_width},
                "opacity": 0.4,
                "size": 8,
            },
            "mode": "markers",
            "name": name,
            "opacity": 0.75,
            "text": [],
            "type": "scatter",
            "uid": "eb94b3",
        }

    def x_histogram(xs, name, color):
        # Marginal histogram of x values, drawn against the top strip's axis y2.
        return {
            "x": xs,
            "marker": {"color": color},
            "name": name,
            "type": "histogram",
            "showlegend": False,
            "uid": "70efa7",
            "xbins": {"end": x_hi, "size": x_step, "start": x_lo},
            "yaxis": "y2",
        }

    def y_histogram(ys, name, color):
        # Marginal histogram of y values, drawn against the right strip's axis x2.
        return {
            "y": ys,
            "marker": {"color": color},
            "name": name,
            "type": "histogram",
            "showlegend": False,
            "uid": "73ca31",
            "xaxis": "x2",
            "ybins": {"end": y_hi, "size": y_step, "start": y_lo},
        }

    points_a = scatter_trace(xa, ya, "Neural network", "circle", "rgb(31, 119, 180)", 1)
    points_b = scatter_trace(xb, yb, "Unbiased learner", "diamond", "red", 0.1)

    # Grey density of the pooled points; its bins extend past the data range
    # so the filled contours are not cut off at the edges.
    density = {
        "x": x,
        "y": y,
        "autocolorscale": False,
        "colorscale": 'Greys',
        "contours": {
            "coloring": "fill",
            "end": 100,
            "showlines": False,
            "size": 5,
            "start": 5,
        },
        "name": "density",
        "ncontours": 50,
        "reversescale": True,
        "showscale": False,
        "type": "histogram2dcontour",
        "uid": "b20cd7",
        "xbins": {"end": x_hi*1.1, "size": x_step, "start": x_lo - x_hi*0.1},
        "ybins": {"end": y_hi*1.1, "size": y_step, "start": y_lo - 0.05},
        "zmax": 50,
        "zmin": 0,
    }

    # Trace order fixes draw order.
    data = Data([
        points_a,
        x_histogram(xa, "Entropy histogram, NN", "rgb(31, 119, 180)"),
        y_histogram(ya, "Error histogram, NN", "rgb(33, 113, 181)"),
        points_b,
        density,
        x_histogram(xb, "Entropy histogram, unbiased", "red"),
        y_histogram(yb, "Error histogram, unbiased", "red"),
    ])

    layout = {
        "autosize": False,
        "bargap": 0,
        "height": 700,
        "hovermode": "closest",
        "margin": {"t": 70},
        # The legend is styled but hidden in this compact variant.
        "showlegend": False,
        "font": dict(size=fontsize, color='black'),
        "legend": {
            "x": 0.07,
            "y": 0.8,
            "bgcolor": '#E2E2E2',
            "bordercolor": '#FFFFFF',
            "borderwidth": 2,
        },
        "title": title,
        "width": 800,
        "xaxis": {
            "autorange": True,
            "domain": [0.05, 0.9],
            "range": [x_lo, x_hi],
            "showgrid": False,
            "title": xlabel,
            "type": "linear",
            "zeroline": False,
            "showline": True,
            "titlefont": dict(size=fontsize, color='black'),
        },
        "xaxis2": {
            "autorange": True,
            "domain": [0.9, 1],
            "range": [x_lo, x_hi],
            "showgrid": False,
            "showticklabels": False,
            "title": "",
            "type": "linear",
            "zeroline": True,
        },
        "yaxis": {
            "autorange": True,
            "domain": [0, 0.85],
            "range": [y_lo, y_hi],
            "showgrid": False,
            "title": ylabel,
            "type": "linear",
            "zeroline": False,
            "showline": True,
            "titlefont": dict(size=fontsize, color='black'),
        },
        "yaxis2": {
            "autorange": True,
            "domain": [0.85, 1],
            "range": [y_lo, y_hi],
            "showgrid": False,
            "showticklabels": False,
            "title": "",
            "type": "linear",
            "zeroline": True,
        },
    }

    fig = Figure(data=data, layout=layout)
    py.image.save_as(fig, filename=filename)
import numpy as np
def shaded_std_plot(x1,y1s,filename,xlabel,ylabel,plotline=False):
    """Save a plot of per-position means with a shaded +/- one-std band.

    ``y1s[i]`` is a collection of samples observed at position ``x1[i]``.
    Positions are sorted by x (stable for ties).  The mean of each group is
    drawn as markers, and the band ``[mean - std, mean + std]`` (``np.std``,
    population std) as a translucent filled polygon.  If ``plotline`` is true,
    a straight line from the origin to ``(1.05*max(x), 1.05*max(mean))`` is
    also drawn.  The figure is rendered through the Plotly cloud image API
    (``py.image.save_as``) and written to ``filename``.

    Args:
        x1: Non-empty sequence of x positions.
        y1s: Sequence of sample collections, one per x position.
        filename: Output image path.
        xlabel, ylabel: Axis titles.
        plotline: Whether to add the origin-anchored reference line.
    """
    # NOTE(review): Plotly credentials are hard-coded in source; consider
    # loading them from configuration or the environment.
    py.sign_in(username='guillefix4', api_key='Z9EakzS6cVPL0DN6SJWJ')

    # Sort (x, samples) pairs by x, then split them back into two columns.
    pairs = sorted(zip(x1, y1s), key=lambda pair: pair[0])
    columns = list(zip(*pairs))
    xs = list(columns[0])
    groups = list(columns[1])

    means, upper, lower = [], [], []
    for group in groups:
        mu = np.mean(group)
        means.append(mu)
        upper.append(mu + np.std(group))
        lower.append(mu - np.std(group))

    # Closed outline of the band: the upper edge left-to-right, then the
    # lower edge right-to-left.
    band = Scatter(
        x=xs + xs[::-1],
        y=upper + lower[::-1],
        fill='tozerox',
        fillcolor='rgba(31, 119, 180,0.2)',
        line=scatter.Line(color='rgba(31, 119, 180,0)'),
        showlegend=False,
    )
    mean_points = Scatter(
        x=xs,
        y=means,
        line=Line(color='rgb(31, 119, 180)'),
        mode='markers',
    )
    reference_line = Scatter(
        x=[0, 1.05*max(xs)],
        y=[0, 1.05*max(means)],
        mode='lines',
        name='lines',
    )

    traces = [band, mean_points]
    if plotline:
        traces.append(reference_line)

    x_axis = XAxis(
        title=xlabel,
        domain=[0.05, 1],
        dtick=20,
        range=[0.9*min(xs), 1.05*max(xs)],
        gridcolor='rgb(127,127,127)',
        showgrid=False,
        showline=True,
        showticklabels=True,
        tickcolor='rgb(127,127,127)',
        ticks='outside',
        zeroline=False,
    )
    y_axis = YAxis(
        title=ylabel,
        range=[0.9*min(means), 1.05*max(upper)],
        gridcolor='rgb(127,127,127)',
        showgrid=False,
        showline=True,
        showticklabels=True,
        tickcolor='rgb(127,127,127)',
        ticks='outside',
        zeroline=False,
    )
    layout = Layout(
        showlegend=False,
        width=600,
        height=500,
        font=dict(size=22, color='black'),
        xaxis=x_axis,
        yaxis=y_axis,
    )

    figure = Figure(data=Data(traces), layout=layout)
    py.image.save_as(figure, filename=filename)
def shaded_std_plot_double(x1,y1s,x2,y2s,filename,xlabel,ylabel):
    """Save two mean +/- one-std band plots (neural network vs unbiased learner).

    For each population, ``ys[i]`` is a collection of samples observed at
    position ``xs[i]``.  Positions are sorted by x (stable for ties).  Means
    are drawn as markers and ``[mean - std, mean + std]`` (``np.std``,
    population std) as a translucent band.  Population 1 is blue and labelled
    "Neural network"; population 2 is red and labelled "Unbiased learner".
    The figure is rendered through the Plotly cloud image API
    (``py.image.save_as``) and written to ``filename``.

    Args:
        x1, y1s: Positions and sample collections of population 1 (non-empty).
        x2, y2s: Positions and sample collections of population 2 (non-empty).
        filename: Output image path.
        xlabel, ylabel: Axis titles.
    """
    # NOTE(review): Plotly credentials are hard-coded in source; consider
    # loading them from configuration or the environment.
    py.sign_in(username='guillefix4', api_key='Z9EakzS6cVPL0DN6SJWJ')

    def sorted_band(xs, groups):
        # Returns (sorted xs, means, upper edge, lower edge reversed).
        pairs = sorted(zip(xs, groups), key=lambda pair: pair[0])
        columns = list(zip(*pairs))
        xs_sorted = list(columns[0])
        groups_sorted = list(columns[1])
        means = [np.mean(g) for g in groups_sorted]
        spreads = [np.std(g) for g in groups_sorted]
        upper = [m + s for m, s in zip(means, spreads)]
        lower = [m - s for m, s in zip(means, spreads)]
        return xs_sorted, means, upper, lower[::-1]

    xs_a, means_a, upper_a, lower_rev_a = sorted_band(x1, y1s)
    xs_b, means_b, upper_b, lower_rev_b = sorted_band(x2, y2s)

    band_a = Scatter(
        x=xs_a + xs_a[::-1],
        y=upper_a + lower_rev_a,
        fill='tozerox',
        fillcolor='rgba(31, 119, 180,0.2)',
        line=Line(color='transparent'),
        showlegend=False,
    )
    points_a = Scatter(
        x=xs_a,
        y=means_a,
        line=Line(color='rgb(31, 119, 180)'),
        mode='markers',
        name='Neural network',
    )
    band_b = Scatter(
        x=xs_b + xs_b[::-1],
        y=upper_b + lower_rev_b,
        fill='tozerox',
        fillcolor='rgba(255,0,0,0.2)',
        line=Line(color='transparent'),
        showlegend=False,
    )
    points_b = Scatter(
        x=xs_b,
        y=means_b,
        line=Line(color='rgb(255,0,0)'),
        mode='markers',
        name='Unbiased learner',
    )

    # Axis styling shared by both axes.
    axis_style = dict(
        gridcolor='rgb(127,127,127)',
        showgrid=False,
        showline=True,
        showticklabels=True,
        tickcolor='rgb(127,127,127)',
        ticks='outside',
        zeroline=False,
    )
    all_x = xs_a + xs_b
    layout = Layout(
        legend=dict(x=0.75, y=0.08),
        font=dict(size=18, color='black'),
        xaxis=XAxis(title=xlabel, dtick=20, range=[0.9*min(all_x), max(all_x)*1.1], **axis_style),
        yaxis=YAxis(title=ylabel, **axis_style),
    )

    figure = Figure(data=Data([band_a, points_a, band_b, points_b]), layout=layout)
    py.image.save_as(figure, filename=filename)
def shaded_std_plot_double_scatter(x1,y1s,x2,y2s,x3,y3,filename,xlabel,ylabel):
    """Save two mean +/- one-std band plots plus a scatter of predicted bound points.

    For populations 1 ("Neural network", blue) and 2 ("Unbiased learner",
    red), ``ys[i]`` is a collection of samples observed at ``xs[i]``.
    Positions are sorted by x (stable for ties).  Means are drawn as small
    'x' markers and ``[mean - std, mean + std]`` (``np.std``, population std)
    as a translucent band.  The points ``(x3, y3)`` are drawn as black markers
    labelled "Predicted bound".  The figure is rendered through the Plotly
    cloud image API (``py.image.save_as``) and written to ``filename``.

    Args:
        x1, y1s: Positions and sample collections of population 1 (non-empty).
        x2, y2s: Positions and sample collections of population 2 (non-empty).
        x3, y3: Coordinates of the predicted-bound points.
        filename: Output image path.
        xlabel, ylabel: Axis titles.
    """
    # NOTE(review): Plotly credentials are hard-coded in source; consider
    # loading them from configuration or the environment.
    py.sign_in(username='guillefix4', api_key='Z9EakzS6cVPL0DN6SJWJ')

    def sorted_band(xs, groups):
        # Returns (sorted xs, means, upper edge, lower edge reversed).
        pairs = sorted(zip(xs, groups), key=lambda pair: pair[0])
        columns = list(zip(*pairs))
        xs_sorted = list(columns[0])
        groups_sorted = list(columns[1])
        means = [np.mean(g) for g in groups_sorted]
        spreads = [np.std(g) for g in groups_sorted]
        upper = [m + s for m, s in zip(means, spreads)]
        lower = [m - s for m, s in zip(means, spreads)]
        return xs_sorted, means, upper, lower[::-1]

    xs_a, means_a, upper_a, lower_rev_a = sorted_band(x1, y1s)
    xs_b, means_b, upper_b, lower_rev_b = sorted_band(x2, y2s)

    band_a = Scatter(
        x=xs_a + xs_a[::-1],
        y=upper_a + lower_rev_a,
        fill='tozerox',
        fillcolor='rgba(31, 119, 180,0.2)',
        line=Line(color='transparent'),
        showlegend=False,
    )
    points_a = Scatter(
        x=xs_a,
        y=means_a,
        line=Line(color='rgb(31, 119, 180)'),
        mode='markers',
        marker=dict(size=4, symbol='x'),
        name='Neural network',
    )
    band_b = Scatter(
        x=xs_b + xs_b[::-1],
        y=upper_b + lower_rev_b,
        fill='tozerox',
        fillcolor='rgba(255,0,0,0.2)',
        line=Line(color='transparent'),
        showlegend=False,
    )
    points_b = Scatter(
        x=xs_b,
        y=means_b,
        line=Line(color='rgb(255,0,0)'),
        mode='markers',
        marker=dict(size=4, symbol='x'),
        name='Unbiased learner',
    )
    bound_points = Scatter(
        x=x3,
        y=y3,
        line=Line(color='rgb(0,0,0)'),
        mode='markers',
        name='Predicted bound',
    )

    # Axis styling shared by both axes; both ranges start at 0.
    axis_style = dict(
        gridcolor='rgb(127,127,127)',
        showgrid=False,
        showline=True,
        showticklabels=True,
        tickcolor='rgb(127,127,127)',
        ticks='outside',
        zeroline=False,
    )
    layout = Layout(
        legend=dict(x=0.01, y=1),
        font=dict(size=18, color='black'),
        xaxis=XAxis(title=xlabel, range=[0, max(xs_a + xs_b)*1.1], **axis_style),
        yaxis=YAxis(title=ylabel, range=[0, max(means_a + means_b)*1.3], **axis_style),
    )

    traces = [band_a, points_a, band_b, points_b, bound_points]
    figure = Figure(data=Data(traces), layout=layout)
    py.image.save_as(figure, filename=filename)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment