Skip to content

Instantly share code, notes, and snippets.

@jllanfranchi
Forked from michaelaye/raytracing.py
Last active May 31, 2018 13:25
Show Gist options
  • Save jllanfranchi/8c00c34d0084280e835214e7ff0ab7b7 to your computer and use it in GitHub Desktop.
Save jllanfranchi/8c00c34d0084280e835214e7ff0ab7b7 to your computer and use it in GitHub Desktop.
Numba optimized raytracing. From 23 sec to 9 sec to 0.45 s.
#!/usr/bin/env python
"""
For 400 x 300
Disabling Numba (i.e. Python looping, NOT numpy optimized) speed is 28.6 sec
Enabling Numba is 0.45 sec (further optimizations including removing some function calls gets this down to 0.23 sec)
Note that a Numpy version rt3.py at http://www.excamera.com/sphinx/article-ray.html is still fastest at ~0.11 sec.
Note that I have removed the checker pattern from the original code
"""
from __future__ import absolute_import, division, print_function
from timeit import default_timer as timer
from math import sqrt
import numpy as np
import numba
import matplotlib as mpl
mpl.use('agg')
import matplotlib.pyplot as plt
# Output image size and ray-tracing depth.
WIDTH = 400 # pixels
HEIGHT = 300 # pixels
DEPTH_MAX = 5 # Maximum number of ray segments (primary ray + reflections) per pixel.
# Scalar type used for every geometry/colour array and constant below.
FLOAT_T = np.float32
# Offset used to nudge secondary-ray origins off a surface along its normal.
EPS = FLOAT_T(0.0001)
# Threshold below which a ray is treated as parallel to a plane.
EPS_SQ = FLOAT_T(EPS**2)
# "No intersection" distance returned by the intersection routines.
INF = FLOAT_T(np.inf)
# Pre-typed scalar constants.
ZERO = FLOAT_T(0)
ONE = FLOAT_T(1)
TWO = FLOAT_T(2)
# Object type tags stored in OBJECT_T['type'].  Only planes and spheres are
# handled; the SCENE_ARRAY fill loop raises ValueError for any other tag
# (including TYPE_CYLINDER).
TYPE_PLANE = 1
TYPE_SPHERE = 2
TYPE_CYLINDER = 3
# Record layout of one scene object.  'normal' is only filled for planes;
# 'radius'/'radius_sq' only for spheres.  align=True pads fields like a C struct.
OBJECT_T = np.dtype([
    ('type', np.int8),
    ('position', FLOAT_T, 3),
    ('normal', FLOAT_T, 3),
    ('radius', FLOAT_T),
    ('radius_sq', FLOAT_T),
    ('color', FLOAT_T, 3),
    ('reflection', FLOAT_T),
    ('diffuse_c', FLOAT_T),
    ('specular_c', FLOAT_T),
    ('specular_k', FLOAT_T),
], align=True)
# Plane colours (white / black).  Not referenced elsewhere in this file --
# apparently left over from the removed checker pattern.
COLOR_PLANE0 = np.full(shape=3, fill_value=1, dtype=FLOAT_T)
COLOR_PLANE1 = np.full(shape=3, fill_value=0, dtype=FLOAT_T)
# Light position and color.
LIGHT_POS = np.array([5, 5, -10], dtype=FLOAT_T)
LIGHT_COLOR = np.ones(3, dtype=FLOAT_T)
# Default light and material parameters.
AMBIENT = FLOAT_T(0.05)  # ambient term added to every shaded sample
DFLT_DIFFUSE_C = FLOAT_T(1)  # Lambert (diffuse) coefficient
DFLT_SPECULAR_C = FLOAT_T(1)  # Blinn-Phong specular coefficient
DFLT_SPECULAR_K = FLOAT_T(50)  # Blinn-Phong specular exponent
# Keyword options shared by every @numba.jit below: compile in nopython mode,
# release the GIL, allow fast-math (relaxed IEEE-754) optimisations, and cache
# the compiled code on disk.
NJIT_KW = dict(
    nopython=True,
    nogil=True,
    fastmath=True,
    cache=True
)
def add_sphere(
    position,
    radius,
    color,
    reflection,
    diffuse_c=DFLT_DIFFUSE_C,
    specular_c=DFLT_SPECULAR_C,
    specular_k=DFLT_SPECULAR_K
):
    """Build the description of a sphere for inclusion in `SCENE`.

    Despite its name this does not modify any scene: it only returns a dict,
    which the module-level fill loop later copies into `SCENE_ARRAY`.

    Parameters
    ----------
    position : sequence of 3 floats
        Centre of the sphere.
    radius : float
        Sphere radius.
    color : sequence of 3 floats
        RGB colour of the surface.
    reflection : float
        Factor by which the contribution of subsequent reflected rays is scaled.
    diffuse_c, specular_c, specular_k : float
        Lambert coefficient, Blinn-Phong coefficient and Blinn-Phong exponent.

    Returns
    -------
    dict
        Object description with every numeric value converted to FLOAT_T.
    """
    return dict(
        type=TYPE_SPHERE,
        position=np.array(position, dtype=FLOAT_T),
        # A FLOAT_T scalar, like the other scalar fields (previously a 0-d array).
        radius=FLOAT_T(radius),
        color=np.array(color, dtype=FLOAT_T),
        reflection=FLOAT_T(reflection),
        diffuse_c=FLOAT_T(diffuse_c),
        specular_c=FLOAT_T(specular_c),
        specular_k=FLOAT_T(specular_k),
    )
def add_plane(
    position,
    normal,
    color,
    reflection,
    diffuse_c=DFLT_DIFFUSE_C,
    specular_c=DFLT_SPECULAR_C,
    specular_k=DFLT_SPECULAR_K
):
    """Build the description of a plane for inclusion in `SCENE`.

    Despite its name this does not modify any scene: it only returns a dict,
    which the module-level fill loop later copies into `SCENE_ARRAY`.

    Parameters
    ----------
    position : sequence of 3 floats
        Any point lying in the plane.
    normal : sequence of 3 floats
        Normal vector of the plane.  It is rescaled to unit length here,
        because the shading and reflection code uses it directly as a unit
        normal.
    color : sequence of 3 floats
        RGB colour of the surface.
    reflection : float
        Factor by which the contribution of subsequent reflected rays is scaled.
    diffuse_c, specular_c, specular_k : float
        Lambert coefficient, Blinn-Phong coefficient and Blinn-Phong exponent.

    Returns
    -------
    dict
        Object description with every numeric value converted to FLOAT_T.

    Raises
    ------
    ValueError
        If `normal` has zero (or non-finite) length.
    """
    normal = np.array(normal, dtype=FLOAT_T)
    length = np.linalg.norm(normal)
    # `not length > 0` also rejects NaN lengths.
    if not length > 0:
        raise ValueError(
            "plane normal must be a non-zero vector, got {!r}".format(normal)
        )
    return dict(
        type=TYPE_PLANE,
        position=np.array(position, dtype=FLOAT_T),
        normal=(normal / length).astype(FLOAT_T),
        color=np.array(color, dtype=FLOAT_T),
        reflection=FLOAT_T(reflection),
        diffuse_c=FLOAT_T(diffuse_c),
        specular_c=FLOAT_T(specular_c),
        specular_k=FLOAT_T(specular_k),
    )
# Scene description: a list of object dicts built by add_sphere/add_plane.
# Three spheres of radius 0.6 centred at y = 0.1, so their lowest points lie
# at y = -0.5, on the red ground plane y = -0.5, whose normal points along +y.
# Every object uses reflection=0.999 and the default diffuse/specular
# coefficients.
SCENE = [
    # White sphere.
    add_sphere(
        position=[0.75, 0.1, 1],
        radius=0.6,
        color=[1, 1, 1],
        reflection=0.999
    ),
    # Purple sphere.
    add_sphere(
        position=[-0.75, 0.1, 2.25],
        radius=0.6,
        color=[0.5, 0.223, 0.5],
        reflection=0.999
    ),
    # Orange sphere.
    add_sphere(
        position=[-2.75, 0.1, 3.5],
        radius=0.6,
        color=[1, 0.572, 0.184],
        reflection=0.999
    ),
    # Red ground plane.
    add_plane(
        position=[0, -0.5, 0],
        normal=[0, 1, 0],
        color=[1, 0, 0],
        reflection=0.999
    ),
]
# Fill array: pack the SCENE dicts into the structured array that the jitted
# functions read as the global SCENE_ARRAY.
# np.zeros (rather than np.empty) so that fields an object type does not use
# ('normal' for spheres, 'radius'/'radius_sq' for planes) hold 0 instead of
# uninitialised memory.
SCENE_ARRAY = np.zeros(shape=len(SCENE), dtype=OBJECT_T)
for idx, obj in enumerate(SCENE):
    # Indexing a structured array yields a view, so writes through `item`
    # land in SCENE_ARRAY.
    item = SCENE_ARRAY[idx]
    item['position'][:] = obj['position']
    item['color'][:] = obj['color']
    item['reflection'] = obj['reflection']
    item['diffuse_c'] = obj['diffuse_c']
    item['specular_c'] = obj['specular_c']
    item['specular_k'] = obj['specular_k']
    type_ = item['type'] = obj['type']
    # Type-specific geometry; any other tag (e.g. TYPE_CYLINDER) is rejected.
    if type_ == TYPE_PLANE:
        item['normal'][:] = obj['normal']
    elif type_ == TYPE_SPHERE:
        item['radius'] = obj['radius']
        item['radius_sq'] = np.square(obj['radius'])
    else:
        raise ValueError(str(type_))
@numba.jit(**NJIT_KW)
def normalize(x):
    """Scale the 3-vector `x` to unit length in place and return it.

    A zero-length vector has no direction. It is returned unchanged instead of
    being divided by zero, which would produce inf/NaN; under ``fastmath``,
    the later handling of such values is not reliable.
    """
    n = FLOAT_T(0)
    for i in range(3):
        xi = x[i]
        n += xi * xi
    if n > 0:
        n = sqrt(n)
        for i in range(3):
            x[i] /= n
    return x
@numba.jit(**NJIT_KW)
def intersect_plane(ray_vertex, ray_direction, plane_point, plane_normal):
    """Distance along a ray to the point where it meets a plane.

    Parameters
    ----------
    ray_vertex : shape (3,) array
        Origin of the ray.
    ray_direction : shape (3,) array
        Unit direction of the ray.
    plane_point : shape (3,) array
        Any point lying on the plane.
    plane_normal : shape (3,) array
        Unit normal of the plane.

    Returns
    -------
    dist : float
        Distance from `ray_vertex` to the intersection.  INF when the ray is
        (nearly) parallel to the plane or the plane lies behind the origin.
    """
    cos_angle = FLOAT_T(np.dot(ray_direction, plane_normal))
    # A (near-)zero projection means the ray runs parallel to the plane.
    if abs(cos_angle) < EPS_SQ:
        return INF
    offset = plane_point - ray_vertex
    dist = FLOAT_T(np.dot(offset, plane_normal)) / cos_angle
    # Negative distances lie behind the ray origin and do not count.
    return INF if dist < 0 else dist
@numba.jit(**NJIT_KW)
def intersect_sphere(ray_vertex, ray_direction, sphere_vertex, sphere_radius_sq):
    """Return the distance from ray_vertex to the intersection of the ray
    (ray_vertex, ray_direction) with the sphere (sphere_vertex,
    sphere_radius_sq), or +inf if there is no intersection.

    Parameters
    ----------
    ray_vertex : shape (3,) array
        Origin of the ray.
    ray_direction : shape (3,) array
        Direction of the ray (normally unit length).
    sphere_vertex : shape (3,) array
        Centre of the sphere.
    sphere_radius_sq : float
        Square of the sphere radius.

    Returns
    -------
    dist : float
        Distance to the nearest intersection point in front of `ray_vertex`
        (the far one if the origin is inside the sphere).  If no intersection
        is found, infinity is returned.

    Notes
    -----
    Solves ``a t**2 + b t + c = 0`` in the numerically stable form
    ``q = -(b + sign(b) * sqrt(disc)) / 2``, with roots ``q / a`` and ``c / q``.
    Giving the square root the sign of ``b`` adds magnitudes instead of
    subtracting them. This avoids cancellation and guarantees ``q != 0``
    whenever ``disc > 0``.
    """
    a = FLOAT_T(np.dot(ray_direction, ray_direction))
    vert_vec = ray_vertex - sphere_vertex
    b = TWO * FLOAT_T(np.dot(ray_direction, vert_vec))
    c = FLOAT_T(np.dot(vert_vec, vert_vec)) - sphere_radius_sq
    discriminant = b * b - 4 * a * c
    if discriminant > 0:
        sqrt_dis = sqrt(discriminant)
        # The sign of sqrt_dis follows b (the previous code had it inverted).
        q = (-b - sqrt_dis) / TWO if b >= 0 else (-b + sqrt_dis) / TWO
        t0 = q / a
        t1 = c / q
        t_near = min(t0, t1)
        t_far = max(t0, t1)
        if t_far >= 0:
            # If the near root is behind the origin, the origin is inside the
            # sphere and only the far root is ahead of it.
            return t_far if t_near < 0 else t_near
    return INF
@numba.jit(**NJIT_KW)
def intersect(ray_vertex, ray_direction, obj):
    """Distance along the ray to `obj`, dispatching on ``obj['type']``.

    Planes go to `intersect_plane`; every other type is treated as a sphere
    and goes to `intersect_sphere`.  INF means no hit.
    """
    is_plane = obj['type'] == TYPE_PLANE
    if not is_plane:
        return intersect_sphere(ray_vertex, ray_direction, obj['position'], obj['radius_sq'])
    return intersect_plane(ray_vertex, ray_direction, obj['position'], obj['normal'])
@numba.jit(**NJIT_KW)
def get_normal(obj, point_of_intersection):
    """Return a new array holding the surface normal of `obj` at a point.

    Spheres use the normalised direction from the centre to the point.
    Any other type returns a copy of its stored ``'normal'`` field.
    """
    out = np.empty(shape=3, dtype=FLOAT_T)
    if obj['type'] != TYPE_SPHERE:
        out[:] = obj['normal']
        return out
    # Radial direction from the sphere centre to the hit point.
    out[:] = point_of_intersection - obj['position']
    return normalize(out)
@numba.jit(**NJIT_KW)
def find_intersecting_obj(ray_vertex, ray_direction):
    """Find the nearest object of SCENE_ARRAY hit by a ray.

    Returns
    -------
    t : float
        Distance to the nearest hit (INF if nothing is hit).
    obj_idx : int
        Index of that object in SCENE_ARRAY, or -1 if nothing is hit.
    """
    nearest_t = INF
    nearest_idx = -1
    for k in range(SCENE_ARRAY.size):
        candidate = SCENE_ARRAY[k]
        position = candidate['position']
        # Intersection tests are called directly here rather than via
        # `intersect`.
        if candidate['type'] != TYPE_PLANE:
            hit_t = intersect_sphere(ray_vertex, ray_direction, position, candidate['radius_sq'])
        else:
            hit_t = intersect_plane(ray_vertex, ray_direction, position, candidate['normal'])
        if hit_t < nearest_t:
            nearest_t, nearest_idx = hit_t, k
    return nearest_t, nearest_idx
@numba.jit(**NJIT_KW)
def find_shadow(point_of_intersection, normal, toL, obj_idx):
    """Cast shadow rays from a surface point toward the light.

    Parameters
    ----------
    point_of_intersection : shape (3,) array
        Surface point being shaded.
    normal : shape (3,) array
        Surface normal at that point.  The shadow-ray origin is offset along
        it by EPS, to avoid re-hitting the surface.
    toL : shape (3,) array
        Unit direction from the point toward LIGHT_POS.
    obj_idx : int
        Index in SCENE_ARRAY of the object the point lies on.  It is skipped.

    Returns
    -------
    ls : shape (SCENE_ARRAY.size,) array
        For each object, the distance at which the shadow ray hits it.  INF
        when it cannot block the light: no hit, a hit beyond the light, or
        the object is `obj_idx`.
    ct : int
        Number of objects tested (every object except `obj_idx`).
    """
    ls = np.zeros(SCENE_ARRAY.size, dtype=FLOAT_T)
    ct = 0
    shadow_origin = point_of_intersection + normal * EPS
    # Only hits nearer than the light can block it; an object behind the
    # light, along the shadow ray, must not cast a shadow.
    light_vec = LIGHT_POS - shadow_origin
    light_dist = sqrt(FLOAT_T(np.dot(light_vec, light_vec)))
    for k, obj_sh in enumerate(SCENE_ARRAY):
        if k == obj_idx:
            ls[k] = INF
            continue
        t_hit = intersect(shadow_origin, toL, obj_sh)
        ls[k] = t_hit if t_hit < light_dist else INF
        ct += 1
    return ls, ct
@numba.jit(**NJIT_KW)
def loop(height, width, depth_max):
    """Ray-trace SCENE_ARRAY into an RGB image.

    Parameters
    ----------
    height, width : int
        Image size in pixels.
    depth_max : int
        Maximum number of ray segments (the primary ray plus its reflections)
        that contribute colour to a pixel.

    Returns
    -------
    img : shape (height, width, 3) FLOAT_T array
        Colours clipped to [0, 1]; row 0 is the top of the image.
    """
    # Screen coordinates: x0, y0, x1, y1.  The aspect ratio is taken from the
    # requested image size (not the module-level WIDTH/HEIGHT), so images of
    # other sizes are not distorted.
    aspect_ratio = float(width) / height
    screen_coords = np.array(
        [-1, -1 / aspect_ratio + 0.25, 1, 1 / aspect_ratio + 0.25],
        dtype=FLOAT_T
    )
    camera = np.array([0, 0.35, -1], dtype=FLOAT_T)  # Camera position.
    img = np.empty(shape=(height, width, 3), dtype=FLOAT_T)
    x_coords = np.linspace(screen_coords[0], screen_coords[2], width).astype(FLOAT_T)
    y_coords = np.linspace(screen_coords[1], screen_coords[3], height).astype(FLOAT_T)

    # Scratch buffers reused for every pixel.
    color = np.empty(shape=3, dtype=FLOAT_T)  # Accumulated pixel colour.
    ray_vertex = np.empty(shape=3, dtype=FLOAT_T)
    ray_direction = np.empty(shape=3, dtype=FLOAT_T)
    point_of_intersection = np.empty(shape=3, dtype=FLOAT_T)
    normal = np.empty(shape=3, dtype=FLOAT_T)
    col_ray = np.empty(shape=3, dtype=FLOAT_T)

    for pix_num in range(height * width):
        # i indexes columns (x), j indexes rows (y).
        i, j = divmod(pix_num, height)
        color[:] = 0
        reflection = ONE

        # Primary ray: from the camera toward the screen point (x, y, 0).
        ray_vertex[:] = camera
        ray_direction[0] = x_coords[i] - camera[0]
        ray_direction[1] = y_coords[j] - camera[1]
        ray_direction[2] = ZERO - camera[2]
        normalize(ray_direction)

        # Trace the primary ray and its reflections.  Checking the depth
        # before tracing avoids tracing and shading one extra segment whose
        # colour would be discarded.
        depth = 0
        while depth < depth_max:
            # Find first point of intersection with an object in the scene.
            t, obj_idx = find_intersecting_obj(ray_vertex, ray_direction)
            if obj_idx < 0:
                break  # The ray leaves the scene.
            obj = SCENE_ARRAY[obj_idx]
            point_of_intersection[:] = ray_vertex + ray_direction * t
            normal[:] = get_normal(obj, point_of_intersection)
            toL = normalize(LIGHT_POS - point_of_intersection)
            to_vertex = normalize(ray_vertex - point_of_intersection)

            # Shadow: stop tracing when another object blocks the light.
            shadow_ts, n_tested = find_shadow(point_of_intersection, normal, toL, obj_idx)
            if n_tested and np.min(shadow_ts) < INF:
                break

            # Lambert shading (diffuse) on top of the ambient term.
            col_ray[:] = AMBIENT
            col_ray += obj['diffuse_c'] * max(FLOAT_T(np.dot(normal, toL)), 0) * obj['color']
            # Blinn-Phong shading (specular).
            col_ray += obj['specular_c'] * LIGHT_COLOR * max(
                FLOAT_T(np.dot(normal, normalize(toL + to_vertex))),
                FLOAT_T(0)
            ) ** obj['specular_k']

            color += reflection * col_ray
            reflection *= obj['reflection']
            depth += 1

            # Reflection: create the mirrored ray just above the surface.
            ray_vertex[:] = point_of_intersection + normal * EPS
            ray_direction[:] = normalize(
                ray_direction - TWO * FLOAT_T(np.dot(ray_direction, normal)) * normal
            )

        # Clip to [0, 1] and flip vertically so that the largest y is row 0.
        for k in range(3):
            img[height - j - 1, i, k] = min(ONE, max(ZERO, color[k]))
    return img
def main():
    """Render the scene at WIDTH x HEIGHT, print timings and save fig.png."""
    # Force compilation of jitted code so the timing below excludes it.
    loop(height=1, width=1, depth_max=1)
    ts = timer()
    img = loop(height=HEIGHT, width=WIDTH, depth_max=DEPTH_MAX)
    te = timer()
    tot_pix = HEIGHT * WIDTH
    print("Total: {:.6f} s to render {} pixels".format(te - ts, tot_pix))
    print("Time per pixel: {:.3f} us".format(1e6*(te - ts) / tot_pix))
    # Announce the file only after it has actually been written.
    plt.imsave('fig.png', img)
    print("Image written to fig.png")


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment