Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
A quick demonstration of how to plot multivalued complex functions in Python.
import numpy
import sympy
from mpl_toolkits.mplot3d import Axes3D
import matplotlib
from matplotlib import cm, colors
from matplotlib import pyplot as plt
# Order n of the root being plotted (w**n = z).  The angular domain below winds
# around the origin this many times so that every branch of the root is
# visited exactly once.
branching_number = 2
Nr = 16  # number of samples in the radial direction
Ntheta = 32  # number of samples in the angular direction

# compute the theta,R domain: a polar grid on the unit disc whose angle runs
# over `branching_number` full turns instead of a single one.
theta = numpy.linspace(0, 2 * numpy.pi * branching_number, Ntheta)
r = numpy.linspace(0, 1, Nr)
Theta, R = numpy.meshgrid(theta, r)
z = R * numpy.exp(1j * Theta)

# compute w**branching_number = z. THE KEY IDEA is to pass the exponentiation
# by 1/branching_number into exp(): dividing the *unwrapped* angle Theta keeps
# the branches distinct, whereas taking a root of z directly would only ever
# yield the principal branch.  The exponent follows `branching_number` so the
# surface stays consistent with the domain when that value is changed.
w = R ** (1.0 / branching_number) * numpy.exp(1j * Theta / branching_number)

# color by argument: numpy.angle gives the phase of w, which is rescaled to
# [0, 1] over its observed range and mapped through the jet colormap.
arguments = numpy.angle(w)
norm = colors.Normalize(arguments.min(), arguments.max())
color = cm.jet(norm(arguments))

# NOTE: nothing below calls plt.show(); when running outside an interactive
# or notebook session, display or save `fig` explicitly.
fig = plt.figure(figsize=(16, 8))

# plot the real part (left panel): Re(w) as height over the z-plane
ax_real = fig.add_subplot(1, 2, 1, projection='3d')
ax_real.plot_surface(z.real, z.imag, w.real,
                     rstride=1, cstride=1, alpha=0.5, facecolors=color)

# plot the imaginary part (right panel): Im(w) as height over the z-plane
ax_imag = fig.add_subplot(1, 2, 2, projection='3d')
ax_imag.plot_surface(z.real, z.imag, w.imag,
                     rstride=1, cstride=1, alpha=0.5, facecolors=color)
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Copy link

rakista commented Apr 24, 2019

Thanks for this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment