# GitHub gist ZelphirKaltstahl/270df543207174f8ed82c0970d67b62e
# (last active August 31, 2016): customized matplotlib / pyplot plotting code.
# use the scratch environment
# NOTE(review): presumably the name of the Python environment this script was
# run in; it has no effect on the code.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import scipy as sp  # NOTE(review): not used anywhere in the visible script

# create figure and axes
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(1, 1, 1)

#########
# TITLE #
#########
# y > 1 lifts the title a little above the top edge of the axes.
ax.set_title('Linear Neuron (bias = 1)', y=1.05)

###############
# AXIS LIMITS #
###############
ax.set_xlim(-2, 3)
ax.set_ylim(-1, 3)

#################
# AXIS POSITION #
#################
# Run the left and bottom spines through data coordinate 0 so they act as
# the x and y axes of a textbook-style plot; hide the top and right spines.
ax.spines['left'].set_position('zero')
ax.spines['bottom'].set_position('zero')
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)

##############
# AXIS TICKS #
##############
# Draw ticks only on the two spines that remain visible.
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')
def get_formatter_function(allowed_values, replacement_values=None, datatype='float', show_pos_sign=False):
    """Return a tick-label function that only labels ``allowed_values``.

    The returned callable has the ``(value, pos)`` signature used by
    ``matplotlib.ticker.FuncFormatter``. Any tick position that is not
    (exactly) equal to one of ``allowed_values`` gets an empty label.

    Args:
        allowed_values: container of tick positions that should be labelled.
            It must support ``.index`` when ``replacement_values`` is given.
        replacement_values: optional sequence parallel to ``allowed_values``.
            When it is non-empty, the label for ``allowed_values[i]`` is
            ``replacement_values[i]``, or '' if that index does not exist.
        datatype: ``'int'`` renders the value as ``str(int(value))``, which
            truncates toward zero. Any other value (``'float'``, ``'str'``, ...)
            renders ``str(value)`` unchanged.
        show_pos_sign: if true, strictly positive values are prefixed with
            '+'. This does not apply to replacement labels.

    Returns:
        A function ``hide_others(value, pos) -> str``. The ``pos`` argument
        is accepted for FuncFormatter compatibility and ignored.
    """
    def hide_others(value, pos):
        if value not in allowed_values:
            return ''
        if replacement_values:
            try:
                return replacement_values[allowed_values.index(value)]
            except IndexError:
                # Fewer replacements than allowed values: leave the tick unlabelled.
                return ''
        sign = '+' if show_pos_sign and value > 0 else ''
        shown = int(value) if datatype == 'int' else value
        return sign + str(shown)

    return hide_others
# Place a major tick at every multiple of `tick_spacing` on both axes.
# The locators are set before the labels are aligned below, so the aligned
# labels already belong to the final tick positions.
tick_spacing = 1.0
ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))
ax.yaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))

# Label only a few ticks: the x axis shows "-1" and "0", and the y axis
# shows "0" and "+1". Every other tick gets an empty label.
# NOTE(review): no minor locator is set in this script, so the minor
# formatters are presumably never used; they are kept for when one is added.
ax.xaxis.set_minor_formatter(ticker.FuncFormatter(
    get_formatter_function([0], datatype='int', show_pos_sign=True)))
ax.yaxis.set_minor_formatter(ticker.FuncFormatter(
    get_formatter_function([0], datatype='int', show_pos_sign=True)))
ax.xaxis.set_major_formatter(ticker.FuncFormatter(
    get_formatter_function([-1, 0], datatype='int', show_pos_sign=True)))
ax.yaxis.set_major_formatter(ticker.FuncFormatter(
    get_formatter_function([0, 1], datatype='int', show_pos_sign=True)))

# Shift the labels off the spines that cross at the origin: x labels start at
# their tick, and y labels sit above theirs.
for tick in ax.xaxis.get_majorticklabels():
    tick.set_horizontalalignment('left')
for tick in ax.yaxis.get_majorticklabels():
    tick.set_verticalalignment('bottom')

# Zero-length tick marks: only the labels remain visible.
ax.tick_params(axis='both', which='both', length=0)
###############
# AXIS LABELS #
###############
y_label = ax.set_ylabel('output')
y_label.set_rotation(0)  # horizontal text instead of the default vertical y label
x_label = ax.set_xlabel('input')
x_label.set_rotation(0)

# Put each label just inside the far end of its spine. The left and bottom
# spines are placed at data coordinate 0 earlier in this script. The data
# origin is therefore converted to axes-fraction coordinates from the current
# limits; this gives 0.4 / 0.25 for xlim (-2, 3) and ylim (-1, 3).
# NOTE(review): the conversion assumes linear axis scales.
label_inset = 0.05
x_lo, x_hi = ax.get_xlim()
y_lo, y_hi = ax.get_ylim()
origin_x = (0 - x_lo) / (x_hi - x_lo)
origin_y = (0 - y_lo) / (y_hi - y_lo)
ax.yaxis.set_label_coords(origin_x - label_inset, 1 - label_inset)
ax.xaxis.set_label_coords(1 - label_inset, origin_y - label_inset)
########
# GRID #
########
ax.grid(True)

########
# PLOT #
########
# Output of a linear neuron with unit weight and bias b, y = x + b, sampled
# at 50 evenly spaced inputs on [-2, 2]. With b = 1 the line ends at y = 3,
# which is the top of the y limits.
b = 1.0
X = np.linspace(-2, 2, num=50, endpoint=True)
Y = b + X
ax.plot(X, Y, '-', linewidth=1.5)

# Tighten the layout, then write the figure to plot.pdf cropped to its content.
fig.tight_layout()
fig.savefig('plot.pdf', bbox_inches='tight')
# (End of gist. Commenting on it requires signing in to GitHub.)