Skip to content

Instantly share code, notes, and snippets.

@MiroK
Created April 8, 2014 09:21
Show Gist options
  • Save MiroK/10103879 to your computer and use it in GitHub Desktop.
Save MiroK/10103879 to your computer and use it in GitHub Desktop.
'Investigating higher-order CR like elements.'
import warnings

from numpy import zeros, linspace, meshgrid, sqrt, where
from numpy.ma import masked_where
from numpy.linalg import inv, matrix_rank
import matplotlib.pyplot as plt
def get_value(q, midpoint=(1./3, 1./3), plot=True):
    """Build the nodal basis of the CR_q-like element and evaluate it.

    The element lives on the reference triangle with vertices (0, 0),
    (1, 0), (0, 1). Its degrees of freedom are point evaluations at points
    on the triangle's edges. The nodal basis is obtained by inverting the
    Vandermonde matrix of the monomial basis at those points. The value of
    every nodal basis function at `midpoint` is printed.

    Parameters
    ----------
    q : int
        Polynomial degree, 1 or 2.
    midpoint : pair of float
        Point at which each nodal basis function is evaluated. The default
        is the centroid of the reference triangle.
    plot : bool
        If True (default), plot every nodal basis function over the
        reference triangle and show the figure.

    Returns
    -------
    list of float
        The i-th entry is the value of the i-th nodal basis function at
        `midpoint`.

    Raises
    ------
    ValueError
        If `q` is not 1 or 2.
    numpy.linalg.LinAlgError
        If `inv` finds the Vandermonde matrix exactly singular.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if q not in (1, 2):
        raise ValueError('q must be 1 or 2, got %r' % (q,))

    # Points where polynomials are evaluated; the degrees of freedom are
    # point evaluations at these points, all on edges of the reference
    # triangle.
    if q == 1:
        # Edge midpoints.
        points = [(0.5, 0.5), (0.5, 0), (0, 0.5)]
    else:
        # Two points per edge, at 1/3 and 2/3 along it.
        points = [(1./3, 0), (2./3, 0), (0, 1./3), (2./3, 1./3), (0, 2./3),
                  (1/3., 2/3.)]

    # Monomials x**i * y**j with i + j <= q span the degree-q polynomials.
    # i and j are bound as defaults so each lambda keeps its own exponents.
    m_basis = [lambda x, i=i, j=j: x[0]**i*x[1]**j
               for i in range(q+1) for j in range(q+1-i)]

    # Vandermonde matrix: B[i, j] = j-th monomial evaluated at i-th point.
    B = zeros((len(points), len(m_basis)))
    for i, point in enumerate(points):
        for j, base_f in enumerate(m_basis):
            B[i, j] = base_f(point)

    # For q = 2 the six points all lie on the conic
    # x**2 + x*y + y**2 - x - y + 2/9 = 0, so B is singular up to rounding
    # and the point evaluations do not determine a unique quadratic.
    if matrix_rank(B) < len(points):
        warnings.warn('Point evaluations are not unisolvent for q=%d: the '
                      'Vandermonde matrix is (numerically) singular, so the '
                      'computed nodal basis is meaningless.' % q)

    # Row i of A holds the monomial coefficients of the i-th nodal basis
    # function: A @ B.T = I, i.e. phi_i(point_k) = delta_ik.
    A = inv(B).T

    values = []
    for i, coefs in enumerate(A):
        # coef_ij * basis_j summed over j is the i-th nodal basis function.
        value = sum(coef*basis(midpoint)
                    for coef, basis in zip(coefs, m_basis))
        print('%d-th nodal basis value at [%g, %g], is %g'
              % (i, midpoint[0], midpoint[1], value))
        values.append(value)

    if plot:
        _plot_nodal_basis(q, points, A, m_basis)
    return values


def _plot_nodal_basis(q, points, A, m_basis):
    """Plot each nodal basis function (row of `A`) over the reference triangle."""
    n_plots = A.shape[0]
    # Grid with n_rows * n_cols > n_plots. Floor division keeps n_cols an
    # int on Python 3, where `/` would hand a float to plt.subplot.
    n_rows = int(sqrt(n_plots)) + 1
    n_cols = n_plots // n_rows + 1

    fig = plt.figure()
    x = linspace(0, 1, 60)
    X, Y = meshgrid(x, x)
    # Hide the part of the unit square outside the triangle x + y < 1. The
    # mask is geometric, so basis values are neither shifted nor hidden.
    outside = X + Y >= 1

    for i, (coefs, point) in enumerate(zip(A, points)):
        plt.subplot(n_rows, n_cols, i + 1)
        Z = zeros(X.shape)
        for coef, basis in zip(coefs, m_basis):
            Z += coef*basis([X, Y])
        # A colormap name string works on all Matplotlib versions
        # (plt.cm.get_cmap was removed in 3.9).
        plt.pcolor(X, Y, masked_where(outside, Z), cmap='coolwarm')
        # Mark the point whose evaluation this basis function is dual to.
        plt.plot(point[0], point[1], 'ko', markersize=10)
        plt.xticks([])
        plt.yticks([])
        plt.axis([0, 1, 0, 1])
    fig.suptitle('Nodal basis functions of $CR_{%d}$' % q)
    plt.show()
if __name__ == '__main__':
    # Parse the degree with argparse so a missing or malformed argument
    # yields a usage message instead of an IndexError/ValueError traceback.
    import argparse
    parser = argparse.ArgumentParser(
        description='Plot nodal basis functions of CR_q-like elements.')
    parser.add_argument('q', type=int, choices=(1, 2),
                        help='polynomial degree of the element')
    args = parser.parse_args()
    get_value(args.q)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment