Skip to content

Instantly share code, notes, and snippets.

@ES-Alexander
Last active December 23, 2022 08:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ES-Alexander/7a4d6719781d13860f9bdd728b611eaf to your computer and use it in GitHub Desktop.
Save ES-Alexander/7a4d6719781d13860f9bdd728b611eaf to your computer and use it in GitHub Desktop.
A Python-based 3D Mandelbulb Ray-Marcher, with phase shifting
''' Mandelbulb.py
Adapted from
https://github.com/AstroKriel/Mandelbulb/blob/main/python/main.py
EXAMPLE OUTPUT:
https://youtube.com/shorts/f8ek83Z9SFo
Original author: AstroKriel
- Mandelbulb raymarcher
- movable camera position
- matplotlib plotting
Improvements: ES-Alexander
- got nerd sniped by AxonMediaSeattle's comment on youtube.com/watch?v=xuLIJ-FNkSI
- added phase shift
- while I was at it, added horizontal rotation
- numba JIT compilation (for faster run-time)
- progress updates with TQDM
- added OpenCV display and image saving (can use as frames for gif/animation)
'''
##############################################
## MODULES
##############################################
import numpy as np
import matplotlib.pyplot as plt
from pylab import cm
from numba import njit, prange
import numba as nb
from tqdm import trange
import math
##############################################
## PARAMETERS
##############################################
## ray marching properties
PI = math.pi
MAX_STEPS = 300      # maximum number of marching steps per ray
MAX_DIST = 150.0     # a ray that has travelled further than this is a miss
SURF_DIST = 0.001    # distance estimate below which the ray counts as a hit
EPSILON = 0.001      # finite-difference step used when estimating normals
## figure / screen properties
resolution = 500     # square image: WIDTH and HEIGHT in pixels
WIDTH = HEIGHT = resolution
## camera properties (angles in radians, vectors are float64 xyz)
CAM_FOV = math.radians(45)
CAM_ANG_VER = math.radians(20)   # applied via rotate_ray_ver (x-axis rotation)
CAM_ANG_HOR = math.radians(10)   # applied via rotate_ray_hor (y-axis rotation)
GLOBAL_UP = np.array((0.0, 1.0, 0.0), dtype=np.float64)
CAM_POS = np.array((0.0, 0.0, 2.5), dtype=np.float64)
LIGHT_POS = np.array((1.0, 1.0, 1.0), dtype=np.float64)
## mandelbulb properties
TARGET_POS = np.zeros(3, dtype=np.float64)   # point the camera looks at
POWER = 8.0                                  # exponent of the Mandelbulb iteration
##############################################
## FUNCTIONS
##############################################
@njit(nb.float64[::1](nb.float64[::1], nb.float64))
def rotate_ray_ver(vec, angle):
    """Rotate a 3-vector about the x-axis (vertical tilt) by ``angle`` radians.

    Builds the 3x3 x-axis rotation matrix and returns a new float64 array;
    ``vec`` itself is not modified.
    """
    cos_a, sin_a = math.cos(angle), math.sin(angle)
    rot_x = np.array(
        [[1.0, 0.0, 0.0],
         [0.0, cos_a, -sin_a],
         [0.0, sin_a, cos_a]],
        dtype=np.float64,
    )
    return rot_x.dot(vec)
@njit(nb.float64[::1](nb.float64[::1], nb.float64))
def rotate_ray_hor(vec, angle):
    """Rotate a 3-vector about the y-axis (horizontal turn) by ``angle`` radians.

    Builds the 3x3 y-axis rotation matrix and returns a new float64 array;
    ``vec`` itself is not modified.
    """
    cos_a, sin_a = math.cos(angle), math.sin(angle)
    rot_y = np.array(
        [[cos_a, 0.0, sin_a],
         [0.0, 1.0, 0.0],
         [-sin_a, 0.0, cos_a]],
        dtype=np.float64,
    )
    return rot_y.dot(vec)
@njit(nb.float64[::1](nb.float64[::1]))
def normVec(vec):
    """Return ``vec`` scaled to unit Euclidean length.

    A zero-length vector cannot be normalised and is returned unchanged
    (the same array object, not a copy).
    """
    length = np.linalg.norm(vec)
    if length == 0:
        return vec
    return vec / length
@njit(nb.float64[:, ::1]())
def view_matrix():
    """Return the camera basis as a 3x3 matrix with rows (right, up, forward).

    ``forward`` points from CAM_POS to TARGET_POS, ``right`` is
    forward x GLOBAL_UP, and ``up`` is right x forward.
    Basis construction follows
    https://www.3dgep.com/understanding-the-view-matrix/

    If forward is parallel to GLOBAL_UP the cross product is zero, and
    normVec returns it unchanged, so the matrix degenerates.
    """
    forward = normVec(TARGET_POS - CAM_POS)
    right = normVec(np.cross(forward, GLOBAL_UP))
    up = np.cross(right, forward)  # already unit length: right is perpendicular to forward
    return np.vstack((right, up, forward))
@njit(nb.float64(nb.float64[::1], nb.float64))
def DEMandelbulb(ray_pos, phase_shift):
    """Estimate the distance from ``ray_pos`` to the phase-shifted Mandelbulb.

    Each of at most 10 iterations performs the following steps:
    - Convert the current point to spherical coordinates.
    - Raise the radius to POWER.
    - Multiply both angles by POWER and add ``phase_shift``.
    - Convert back to Cartesian coordinates and add ``ray_pos``.

    The loop stops early once the radius exceeds 2.0 (escape). The running
    derivative ``dr`` feeds the log-based estimate 0.5 * log(r) * r / dr.

    NOTE(review): ``dr`` starts at 2.0 rather than 1.0. Because the estimate
    divides by ``dr``, a larger start gives a smaller, more conservative
    distance.
    """
    point = ray_pos
    dr = 2.0
    for _ in range(10):
        r = np.linalg.norm(point)
        if r > 2.0:
            break  # escaped: the point is outside the set
        # The derivative update uses r of the point *before* it is transformed.
        dr = POWER * pow(r, POWER - 1.0) * dr + 1.0
        polar = math.acos(point[2] / r) * POWER + phase_shift      # mandelbulb phase shift
        azimuth = math.atan2(point[1], point[0]) * POWER + phase_shift  # horizontal spin
        radius = pow(r, POWER)
        point = ray_pos + np.array(
            [radius * math.sin(polar) * math.cos(azimuth),
             radius * math.sin(polar) * math.sin(azimuth),
             radius * math.cos(polar)],
            dtype=np.float64,
        )
    return 0.5 * np.log(r) * r / dr
@njit(nb.float64[::1](nb.float64[::1], nb.float64))
def getNormal(ray_pos, phase):
    """Unit surface normal at ``ray_pos`` from the gradient of DEMandelbulb.

    Component i is the backward difference DE(p) - DE(p - EPSILON * e_i).
    The resulting vector is normalised with normVec, and a zero gradient is
    returned unchanged.
    """
    here = DEMandelbulb(ray_pos, phase)
    grad = np.empty(3, dtype=np.float64)
    for axis in range(3):
        step = np.zeros(3, dtype=np.float64)
        step[axis] = EPSILON
        grad[axis] = here - DEMandelbulb(ray_pos - step, phase)
    return normVec(grad)
@njit(nb.float64(nb.float64[::1], nb.float64))
def getLight(ray_pos, phase):
    """Diffuse brightness in [0, 1] at the surface point ``ray_pos``.

    LIGHT_POS is rotated by CAM_ANG_VER and then CAM_ANG_HOR, the same
    camera rotations rayMarching uses. The brightness is the dot product of
    the surface normal with the unit vector from ``ray_pos`` towards that
    rotated light, clamped to [0, 1].
    """
    # NOTE(review): the element-wise copy presumably yields a writable
    # float64[::1] array matching rotate_ray_ver's signature (the global is
    # frozen by numba) -- confirm before simplifying.
    light = np.array([LIGHT_POS[0], LIGHT_POS[1], LIGHT_POS[2]], dtype=np.float64)
    light = rotate_ray_hor(rotate_ray_ver(light, CAM_ANG_VER), CAM_ANG_HOR)
    towards_light = normVec(light - ray_pos)
    normal = getNormal(ray_pos, phase)
    return max(0.0, min(1.0, np.dot(normal, towards_light)))
@njit(nb.types.UniTuple(nb.float64, 2)(nb.int64, nb.int64, nb.float64))
def rayMarching(x, y, phase):
    """March the camera ray through pixel (x, y) towards the Mandelbulb.

    Returns ``(closeness, light)``.

    On a hit (distance estimate below SURF_DIST), ``closeness`` is
    ``MAX_DIST - ray_dist``, so larger means nearer. ``light`` is the
    getLight brightness at the hit point.

    On a miss (the ray travels past MAX_DIST or uses up MAX_STEPS), both
    values are 0.0. The background therefore sits at the far end of the
    closeness scale instead of sharing MAX_DIST with a surface at
    distance 0.
    """
    # Pinhole camera: the image plane sits HEIGHT / tan(CAM_FOV) in front of
    # the eye, centred on (WIDTH // 2, HEIGHT // 2).
    # NOTE(review): this makes the full vertical view angle
    # 2 * atan(tan(CAM_FOV) / 2) (~53 deg for 45 deg), not CAM_FOV itself.
    dist_cam2obj = HEIGHT / math.tan(CAM_FOV)
    vec2obj = np.array([WIDTH // 2, HEIGHT // 2, dist_cam2obj], dtype=np.float64)
    vec2pixel = np.array([x, y, 0], dtype=np.int64)
    ray_dir_cam = view_matrix().dot(normVec(vec2obj - vec2pixel))

    # Both rotations are linear, so R(CAM_POS + d * t) == R(CAM_POS) + R(d) * t.
    # Rotate the origin and direction once instead of every sample point
    # (which rebuilt two rotation matrices on every step).
    origin = np.array([CAM_POS[0], CAM_POS[1], CAM_POS[2]], dtype=np.float64)
    origin = rotate_ray_hor(rotate_ray_ver(origin, CAM_ANG_VER), CAM_ANG_HOR)
    direction = rotate_ray_hor(rotate_ray_ver(ray_dir_cam, CAM_ANG_VER), CAM_ANG_HOR)

    ray_dist = 0.0  # distance travelled along the ray
    for _ in range(MAX_STEPS):
        ray_pos = origin + direction * ray_dist
        # the distance estimate is a safe step size
        step = DEMandelbulb(ray_pos, phase)
        if step < SURF_DIST:
            return (MAX_DIST - ray_dist, getLight(ray_pos, phase))
        ray_dist += step
        if ray_dist > MAX_DIST:
            break
    return (0.0, 0.0)
@njit(parallel=WIDTH > 250)
def renderImage(phase):
    """Render one frame of the Mandelbulb at the given phase shift.

    Returns ``(dist_pixels, light_pixels)``. These are two (HEIGHT, WIDTH)
    float64 arrays whose ``[y, x]`` entries hold the first and second values
    returned by ``rayMarching(x, y, phase)``.

    When compiled with ``parallel=True`` (WIDTH > 250), columns are spread
    over ``prange``. Each iteration writes only its own column.
    """
    ## initialise screen pixels
    dist_pixels = np.zeros((HEIGHT, WIDTH), np.float64)
    light_pixels = np.zeros_like(dist_pixels)
    ## calculate the intensity of each pixel on the screen
    for x in prange(WIDTH):
        for y in range(HEIGHT):
            dist, light = rayMarching(x, y, phase)
            # dist_pixels is part of the return value, so it must be filled
            # (it was left as all zeros before).
            dist_pixels[y, x] = dist
            light_pixels[y, x] = light
    return dist_pixels, light_pixels
def _show_gray(pixels):
    """Show a 2-D array as a borderless grayscale matplotlib image."""
    _, ax = plt.subplots()
    ax.imshow(pixels, cmap=cm.gray, origin='upper')
    ax.axis('off')  # remove axis labels
    plt.tight_layout()  # minimise white space
    plt.show()


def drawScene(phase=0.0, *, show_distance=False):
    """Render a single frame with renderImage, report timing, and plot it.

    Parameters
    ----------
    phase : float
        Phase shift in radians passed to renderImage. It defaults to 0.0 so
        that ``drawScene()`` works without arguments.
    show_distance : bool, keyword-only
        If True, first plot the distance buffer returned by renderImage.
        This replaces the previously commented-out plotting block. The
        lighting buffer is always plotted.
    """
    from time import perf_counter
    start = perf_counter()
    dist_pixels, light_pixels = renderImage(phase)
    print(f'Rendered in {perf_counter() - start:.3f}s')
    if show_distance:
        _show_gray(dist_pixels)  # plot distances
    _show_gray(light_pixels)  # plot lighting
##############################################
## MAIN PROGRAM
##############################################
if __name__ == "__main__":
    # Render FRAMES phase-shifted frames and show each in an OpenCV window.
    # Each frame is also saved as fractal_<zero-padded index>.png (usable
    # for a gif/animation).
    import cv2

    FRAMES = 90
    num_width = len(str(FRAMES))
    cv2.namedWindow('fractal', cv2.WINDOW_NORMAL)
    try:
        for frame in trange(FRAMES):
            # 360 * frame / FRAMES never reaches 360 degrees, so the last
            # frame leads straight back into the first when looped
            phase = math.radians(360 * frame / FRAMES)
            _, pixels = renderImage(phase)
            peak = pixels.max()
            if peak > 0:
                converted = np.uint8((pixels / peak) * 255)
            else:
                # all-dark frame (e.g. every ray missed): avoid 0/0 -> NaN
                converted = np.zeros(pixels.shape, dtype=np.uint8)
            cv2.imshow('fractal', converted)
            cv2.waitKey(1)  # give HighGUI a chance to draw the frame
            image_number = str(frame).zfill(num_width)
            cv2.imwrite(f'fractal_{image_number}.png', converted)
    finally:
        cv2.destroyAllWindows()
    # For a single matplotlib preview instead, call drawScene(0.0).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment