Skip to content

Instantly share code, notes, and snippets.

@classAndrew
Last active July 3, 2024 16:21
Show Gist options
  • Save classAndrew/c526070f9b9518a1a2814385b86d1076 to your computer and use it in GitHub Desktop.
Save classAndrew/c526070f9b9518a1a2814385b86d1076 to your computer and use it in GitHub Desktop.
Ray tracing using Torch tensors
import time
import time
import numpy as np
import torch
# Render on the GPU when one is available; otherwise fall back to the CPU so the
# script still runs (slowly) on machines without CUDA. Every tensor created
# without an explicit device lands on this default device.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(device)
n = 320  # base resolution: the image is 2n x n pixels
width, height = n * 2, n
aspect_ratio = width / height
samples_per_pixel = 10  # jittered rays averaged per pixel (anti-aliasing)
def _float_tensor(values):
    """Return a 1-D float tensor on ``device`` holding ``values`` as Python floats.

    Converting element-wise via ``float`` accepts plain numbers as well as 0-d
    tensors (e.g. the elements of an existing ``Vec``).
    """
    return torch.tensor([float(value) for value in values], device=device)


def Vec(x, y, z):
    """Return the 3-component float vector ``[x, y, z]`` on ``device``."""
    return _float_tensor((x, y, z))


def Ray(start, dir):
    """Return a ray packed as a ``(2, 3)`` tensor: row 0 origin, row 1 direction.

    ``start`` and ``dir`` may be 3-element sequences or 3-element tensors
    (stacking avoids building a tensor from a list of multi-element tensors).
    """
    return torch.stack([_float_tensor(start), _float_tensor(dir)])


def Sphere(center, radius):
    """Return a sphere packed as ``[cx, cy, cz, radius]`` on ``device``."""
    return _float_tensor((*center, radius))
class HitRecord:
    """Struct-of-arrays intersection record holding one row per ray.

    All fields start zeroed and are allocated on torch's current default device.
    """

    def __init__(self, num_rays):
        vec_shape = (num_rays, 3)
        # Intersection point and the stored (face-oriented) surface normal.
        self.hit_points = torch.zeros(vec_shape)
        self.hit_normals = torch.zeros(vec_shape)
        # Ray parameter t of the intersection.
        self.hit_t = torch.zeros(num_rays)
        # True when the ray direction opposed the outward normal.
        self.is_front_face = torch.zeros(num_rays, dtype=torch.bool)
        # Material id of the surface that was hit.
        self.material_id = torch.zeros(num_rays, dtype=torch.int64)
def prod(v):
    """Multiply all elements of tensor ``v`` and return the product as an ``int``.

    Floating-point products are truncated toward zero by ``int()``.
    """
    product = v.prod()
    return int(product)
def ray_at(ray, t):
return ray[..., 0, :] + ray[..., 1, :]*t[..., None]
def vec_norm(v):
    """Euclidean length computed from components 0..2 of the last axis of ``v``."""
    x, y, z = v[..., 0], v[..., 1], v[..., 2]
    squared_length = x**2 + y**2 + z**2
    return squared_length ** 0.5
def vec_unit(v):
    """Divide every vector along the last axis of ``v`` by its ``vec_norm`` length.

    A zero-length vector yields NaN components (0 / 0).
    """
    lengths = vec_norm(v).unsqueeze(-1)
    return v / lengths
def set_face_normal(hit_record: "HitRecord", ray, hit_idx, hit_outward_normals):
    """Record face orientation and face-oriented normals for the rays in ``hit_idx``.

    A hit is front-facing when the ray direction opposes the outward normal
    (negative dot product). The stored normal is flipped for back-face hits so
    it never points along the ray.

    Args:
        hit_record: object whose ``is_front_face`` and ``hit_normals`` are written.
        ray: ``(N, 2, 3)`` rays; row 1 of each ray is its direction.
        hit_idx: indices of the rays that hit, either 0-d (one index) or 1-D.
        hit_outward_normals: outward normals, one 3-vector per index. Not modified.
    """
    # ``nonzero().squeeze()`` turns a single hit into a 0-d index; flatten it so
    # that lone hit is processed like any other instead of being skipped.
    idx = hit_idx.reshape(-1)
    outward = hit_outward_normals.reshape(-1, 3)
    dirs = ray[idx][:, 1, :]
    front_faces = torch.einsum("ij,ij->i", dirs, outward) < 0
    hit_record.is_front_face[idx] = front_faces
    # Build a new tensor instead of flipping the caller's normals in place.
    hit_record.hit_normals[idx] = torch.where(front_faces[:, None], outward, -outward)
def gamma_correct(colors):
    """Apply gamma-2 correction (square root) to ``colors`` in place.

    Negative entries are clamped to 0 before the square root, so they end up 0;
    NaN entries stay NaN. Returns ``colors`` for convenience (callers may ignore it).
    """
    # One fused in-place clamp + sqrt instead of three boolean-mask passes.
    return colors.clamp_(min=0.0).sqrt_()
def get_many_rng_unit_vecs(size, target_device=None):
    """Return ``size`` random unit 3-vectors, uniform on the sphere, as a ``(size, 3)`` tensor.

    Normalizing isotropic Gaussian samples gives uniformly distributed
    directions in one vectorized step: no rejection loop and no host-to-device
    copy, which matters when millions of bounce directions are drawn per call.

    Args:
        size: number of vectors; 0 returns an empty ``(0, 3)`` tensor.
        target_device: where to allocate; defaults to the module-level ``device``.
    """
    samples = torch.randn((size, 3), device=device if target_device is None else target_device)
    lengths = torch.linalg.vector_norm(samples, dim=-1, keepdim=True)
    # An exactly-zero Gaussian sample is vanishingly unlikely; the floor keeps
    # it from ever producing NaN.
    return samples / lengths.clamp_min(1e-12)
def get_reflected(ray_dirs, normals):
    """Mirror each row of ``ray_dirs`` about the matching row of ``normals``.

    Computes ``d - 2 (d . n) n`` row-wise for ``(N, 3)`` inputs.
    """
    projections = torch.einsum("ij,ij->i", ray_dirs, normals)
    return ray_dirs - normals * (2 * projections)[:, None]
def get_refracted(uv, normals, etai_over_etat):
    """Refract incident directions ``uv`` through surfaces with ``normals`` (Snell's law).

    Args:
        uv: ``(N, 3)`` incident directions (assumed unit length).
        normals: ``(N, 3)`` normals on the incident side (assumed unit length).
        etai_over_etat: ``(N,)`` ratio of refractive indices, incident / transmitted.

    Returns:
        ``(N, 3)`` refracted directions. Total internal reflection is not
        detected here; the ``abs`` only keeps the square root real, so callers
        must screen those rays out first.
    """
    cos_in = torch.einsum("ij,ij->i", -uv, normals)
    # ones_like keeps the bound on the same device/dtype as the input.
    cos_theta = torch.fmin(cos_in, torch.ones_like(cos_in))
    out_perp = etai_over_etat[:, None] * (uv + cos_theta[:, None] * normals)
    # Squared length directly, instead of sqrt followed by squaring.
    perp_len_sq = torch.einsum("ij,ij->i", out_perp, out_perp)
    out_parallel = -torch.sqrt(torch.abs(1.0 - perp_len_sq))[:, None] * normals
    return out_perp + out_parallel
def get_reflectance(cosine, refractive_index):
    """Schlick's approximation of Fresnel reflectance.

    ``R = r0 + (1 - r0) * (1 - cos)^5`` with ``r0 = ((1 - n) / (1 + n))^2``,
    where ``n`` is the refractive-index ratio. Works element-wise on tensors
    or on plain floats.
    """
    r0 = ((1 - refractive_index) / (1 + refractive_index)) ** 2
    # Additive blend toward full reflection at grazing angles (the previous
    # product form r0 * (1 - r0) * ... was not Schlick's formula).
    return r0 + (1 - r0) * (1.0 - cosine) ** 5
def sphere_hit(sphere, ray, ray_tmin, ray_tmax, hit_record: "HitRecord"):
    """Intersect every ray with one sphere, accepting roots strictly inside ``(ray_tmin, ray_tmax)``.

    Args:
        sphere: packed ``[cx, cy, cz, radius]`` tensor.
        ray: ``(N, 2, 3)`` rays, row 0 origin and row 1 direction.
        ray_tmin, ray_tmax: ``(N,)`` bounds on the accepted ray parameter.
        hit_record: for rays with a positive discriminant, ``hit_t`` receives the
            chosen root (even when it is out of range); for accepted hits,
            ``hit_points``, ``hit_normals`` and ``is_front_face`` are written too.

    Returns:
        ``(N,)`` bool mask of rays that hit the sphere inside the interval.
    """
    sphere = sphere.to(ray.device)
    center, radius = sphere[:3], sphere[3]
    dirs = ray[:, 1, :]
    oc = center - ray[:, 0, :]
    # Half-b form of the quadratic a t^2 - 2 h t + c = 0.
    a = torch.einsum("ij,ij->i", dirs, dirs)
    h = torch.einsum("ij,ij->i", dirs, oc)
    c = torch.einsum("ij,ij->i", oc, oc) - radius * radius
    discriminant = h * h - a * c
    has_hit = discriminant > 0
    # as_tuple=True always yields a 1-D index, so a single candidate ray is
    # range-checked and recorded like any other.
    cand = has_hit.nonzero(as_tuple=True)[0]
    if cand.numel() == 0:
        return has_hit
    sqrtd = torch.sqrt(discriminant[cand])
    h, a = h[cand], a[cand]
    tmin, tmax = ray_tmin[cand], ray_tmax[cand]
    # Prefer the nearer root; fall back to the farther one when it is out of range.
    roots = (h - sqrtd) / a
    near_out = (roots <= tmin) | (roots >= tmax)
    roots = torch.where(near_out, (h + sqrtd) / a, roots)
    # Written as an in-range test so NaN roots (degenerate directions) count as misses.
    valid = (roots > tmin) & (roots < tmax)
    has_hit[cand] = valid
    hit_record.hit_t[cand] = roots
    hit_idx = cand[valid]
    points = ray_at(ray[hit_idx], roots[valid])
    hit_record.hit_points[hit_idx] = points
    set_face_normal(hit_record, ray, hit_idx, (points - center) / radius)
    return has_hit
def hit_all(hittable_list, rays, ray_tmin, ray_tmax, recs: "HitRecord"):
    """Find, for each ray, the closest hit among all ``(sphere, material_id)`` entries.

    Args:
        hittable_list: iterable of ``(sphere tensor, material id)`` pairs.
        rays: ``(N, 2, 3)`` rays.
        ray_tmin, ray_tmax: ``(N,)`` initial parameter interval; not modified.
        recs: record receiving the closest hit's normal, point, t, face flag and
            material id for every ray that hit something.

    Returns:
        ``(N,)`` bool mask of rays that hit any object.
    """
    num_rays = rays.shape[0]
    has_hit_anything = torch.zeros(num_rays, dtype=torch.bool)
    # Clone so shrinking the search interval never mutates the caller's tensor.
    closest_t = ray_tmax.clone()
    temp_recs = HitRecord(num_rays)
    for sphere, material_id in hittable_list:
        # Using closest_t as the upper bound means only hits nearer than the
        # best one so far are accepted, so later writes win only when closer.
        has_hit = sphere_hit(sphere, rays, ray_tmin, closest_t, temp_recs)
        has_hit_anything |= has_hit
        closest_t[has_hit] = temp_recs.hit_t[has_hit]
        recs.hit_normals[has_hit] = temp_recs.hit_normals[has_hit]
        recs.hit_points[has_hit] = temp_recs.hit_points[has_hit]
        recs.hit_t[has_hit] = temp_recs.hit_t[has_hit]
        recs.is_front_face[has_hit] = temp_recs.is_front_face[has_hit]
        recs.material_id[has_hit] = material_id
    return has_hit_anything
def get_indices_by_loop(material_ids, has_hit, num_materials=None):
    """Group the indices of hit rays by material id.

    Args:
        material_ids: ``(N,)`` material id per ray.
        has_hit: ``(N,)`` bool mask. It is needed because rays that hit nothing
            still carry the default material id 0.
        num_materials: number of ids to group; defaults to ``MAX_MATERIALS``.

    Returns:
        List whose entry ``k`` is a 1-D tensor of indices of rays that hit material ``k``.
    """
    if num_materials is None:
        num_materials = MAX_MATERIALS
    # nonzero(as_tuple=True)[0] is always 1-D. squeeze() would turn a single
    # match into a 0-d tensor, which the scatter functions skip (black pixel).
    return [
        ((material_ids == k) & has_hit).nonzero(as_tuple=True)[0]
        for k in range(num_materials)
    ]
def scatter_rays_diffuse(hittable_list, rays, depth, colors, hit_recs: "HitRecord", material_idx, attenuation):
    """Shade diffuse hits: bounce along normal + random unit vector, then recurse.

    Writes ``attenuation * ray_color(bounced rays, depth - 1)`` into
    ``colors[material_idx]``. ``rays`` is unused here; it keeps the signature
    in line with the other ``scatter_rays_*`` functions.
    """
    # Accept a 0-d index (a single ray) as well as a 1-D index tensor.
    material_idx = material_idx.reshape(-1)
    if material_idx.numel() == 0:
        return
    normals = hit_recs.hit_normals[material_idx]
    bounce_dirs = normals + get_many_rng_unit_vecs(material_idx.shape[0])
    # Degenerate case: the random vector cancels the normal, so fall back to the normal.
    degenerate = (torch.abs(bounce_dirs) < 1e-8).all(dim=1)
    bounce_dirs[degenerate] = normals[degenerate]
    bounced_rays = torch.stack([hit_recs.hit_points[material_idx], bounce_dirs], dim=1)
    colors[material_idx] = attenuation * ray_color(hittable_list, bounced_rays, depth - 1)
def scatter_rays_metal(hittable_list, rays, depth, colors, hit_recs: "HitRecord", material_idx, attenuation):
    """Shade mirror-metal hits: reflect the incoming direction about the normal and recurse.

    Writes ``attenuation * ray_color(bounced rays, depth - 1)`` into ``colors[material_idx]``.
    """
    # Accept a 0-d index (a single ray) as well as a 1-D index tensor.
    material_idx = material_idx.reshape(-1)
    if material_idx.numel() == 0:
        return
    reflected = get_reflected(rays[material_idx, 1, :], hit_recs.hit_normals[material_idx])
    bounced_rays = torch.stack([hit_recs.hit_points[material_idx], reflected], dim=1)
    colors[material_idx] = attenuation * ray_color(hittable_list, bounced_rays, depth - 1)
def scatter_rays_fuzz(hittable_list, rays, depth, colors, hit_recs: "HitRecord", material_idx, attenuation):
    """Shade fuzzy-metal hits: mirror reflection plus a random unit vector, then recurse.

    Writes ``attenuation * ray_color(bounced rays, depth - 1)`` into ``colors[material_idx]``.
    """
    # Accept a 0-d index (a single ray) as well as a 1-D index tensor.
    material_idx = material_idx.reshape(-1)
    if material_idx.numel() == 0:
        return
    normals = hit_recs.hit_normals[material_idx]
    reflected = get_reflected(rays[material_idx, 1, :], normals)
    bounce_dirs = reflected + get_many_rng_unit_vecs(material_idx.shape[0])
    # Degenerate case: the perturbation cancels the reflection, so fall back to the normal.
    degenerate = (torch.abs(bounce_dirs) < 1e-8).all(dim=1)
    bounce_dirs[degenerate] = normals[degenerate]
    bounced_rays = torch.stack([hit_recs.hit_points[material_idx], bounce_dirs], dim=1)
    colors[material_idx] = attenuation * ray_color(hittable_list, bounced_rays, depth - 1)
def scatter_rays_dialectric(hittable_list, rays, depth, colors, hit_recs: "HitRecord", material_idx, attenuation, ri):
    """Shade dielectric (glass-like) hits: reflect or refract each ray, then recurse.

    Front-face hits use the index ratio ``1 / ri`` and back-face hits ``ri``. A
    ray reflects when refraction is impossible (ratio * sin(theta) > 1) or when
    ``get_reflectance`` exceeds a uniform random draw; otherwise it refracts.
    Writes ``attenuation * ray_color(bounced rays, depth - 1)`` into ``colors[material_idx]``.
    """
    # Accept a 0-d index (a single ray) as well as a 1-D index tensor.
    material_idx = material_idx.reshape(-1)
    if material_idx.numel() == 0:
        return
    front = hit_recs.is_front_face[material_idx]
    normals = hit_recs.hit_normals[material_idx]
    refractive_index = torch.full(front.shape, float(ri))
    refractive_index[front] = 1.0 / ri
    unit_direction = vec_unit(rays[material_idx][:, 1, :])
    cos_in = torch.einsum("ij,ij->i", -unit_direction, normals)
    cos_theta = torch.fmin(cos_in, torch.ones_like(cos_in))
    # Clamp before sqrt: rounding can make cos_theta^2 slightly exceed 1.
    sin_theta = torch.sqrt(torch.clamp(1.0 - cos_theta * cos_theta, min=0.0))
    # Draw directly on the working device instead of a legacy CPU FloatTensor + copy.
    random_draw = torch.rand(material_idx.shape[0], device=unit_direction.device)
    reflect_mask = (refractive_index * sin_theta > 1.0) | (get_reflectance(cos_theta, refractive_index) > random_draw)
    refract_mask = ~reflect_mask
    directions = torch.empty_like(unit_direction)
    directions[reflect_mask] = get_reflected(unit_direction[reflect_mask], normals[reflect_mask])
    directions[refract_mask] = get_refracted(
        unit_direction[refract_mask], normals[refract_mask], refractive_index[refract_mask]
    )
    bounced_rays = torch.stack([hit_recs.hit_points[material_idx], directions], dim=1)
    colors[material_idx] = attenuation * ray_color(hittable_list, bounced_rays, depth - 1)
# Scene: (sphere, material id) pairs. ray_color unpacks the material groups in
# id order as ground (0), yellow metal (1), blue diffuse (2), yellow fuzz (3)
# and glass (4).
hittable_list = [
    (Sphere(Vec(0.5, 1, -1), 0.5), 3),
    (Sphere(Vec(-0.5, 1, -1), 0.5), 2),
    (Sphere(Vec(-0.5, 0, -1), 0.5), 1),
    (Sphere(Vec(0, -100.5, -1), 100), 0),  # huge sphere used as the ground
    (Sphere(Vec(0.5, 0, -1), 0.5), 4),
]
# One past the largest id, so range(MAX_MATERIALS) covers every material.
MAX_MATERIALS = 1 + max(material_id for _sphere, material_id in hittable_list)
def ray_color(hittable_list, rays, depth=10):
    """Trace an ``(N, 2, 3)`` batch of rays and return their ``(N, 3)`` colors.

    Rays that miss everything get a blend of (1, 1, 1) and (0.5, 0.7, 1)
    weighted by the y component of their unit direction. Hits are shaded by
    the scatter function for their material, which recurses with ``depth - 1``;
    at depth 0 every ray is black.
    """
    num_rays = rays.shape[0]
    if depth <= 0:
        return torch.zeros((num_rays, 3))
    hit_recs = HitRecord(num_rays)
    t_max = torch.full((num_rays,), torch.inf)
    t_min = torch.full((num_rays,), 0.001)
    has_hits = hit_all(hittable_list, rays, t_min, t_max, hit_recs)
    colors = torch.zeros((num_rays, 3))
    missed = ~has_hits
    sky_dirs = vec_unit(rays[missed][:, 1, :])
    alpha = 0.5 * (sky_dirs[:, 1] + 1.)
    colors[missed] = torch.outer(1 - alpha, Vec(1., 1., 1.)) + torch.outer(alpha, Vec(0.5, 0.7, 1.))
    ground, yellow_metal, blue_diffuse, yellow_fuzz, glass = get_indices_by_loop(hit_recs.material_id, has_hits)
    shared = (hittable_list, rays, depth, colors, hit_recs)
    scatter_rays_diffuse(*shared, ground, Vec(0.8, 0.8, 0))
    scatter_rays_diffuse(*shared, blue_diffuse, Vec(0.1, 0.2, 0.5))
    scatter_rays_metal(*shared, yellow_metal, Vec(0.8, 0.6, 0.2))
    scatter_rays_fuzz(*shared, yellow_fuzz, Vec(0.8, 0.6, 0.2))
    scatter_rays_dialectric(*shared, glass, Vec(1., .8, 1.), 0.8)
    return colors
# Camera placement: eye position, the point it looks at, and the world "up"
# hint that paint() uses to build its camera basis. Vec already places these
# on the target device.
camera_origin = Vec(2, 0, 0)
camera_lookat = Vec(0, 0, -1)
camera_up = Vec(0, 1., 0.)
def paint():
    """Render the scene and return it as a NumPy array of shape ``(width, height, 3)``.

    The first axis is x and the second is y (index 0 at the bottom). Each pixel
    holds the gamma-corrected mean of ``samples_per_pixel`` jittered rays.
    """
    # Camera basis: w points from the target back toward the eye, u is
    # camera-right (up x w), v is the recomputed camera-up (w x u).
    w = -vec_unit(camera_lookat - camera_origin)
    u = vec_unit(torch.linalg.cross(camera_up, w))
    v = torch.linalg.cross(w, u)
    # Pixel grid on the z = 0 image plane, laid out (width, height, 3), with x
    # spanning [-aspect/2, aspect/2] and y spanning [-0.5, 0.5].
    xs = torch.linspace(-.5, .5, width) * aspect_ratio
    ys = torch.linspace(-.5, .5, height)
    grid_x, grid_y = torch.meshgrid(xs, ys, indexing="ij")
    plane = torch.stack([grid_x, grid_y, torch.zeros_like(grid_x)], dim=-1)
    # Per-sample jitter of about one pixel in x and y (anti-aliasing); z stays 0.
    jitter = torch.rand(samples_per_pixel, width, height, 3)
    jitter[..., 0] *= aspect_ratio / width
    jitter[..., 1] *= 1 / height
    jitter[..., 2] = 0.
    plane_positions = plane.unsqueeze(0) + jitter
    # Camera-space directions toward the plane one unit in front of the eye,
    # rotated to world space (row vectors times the [u; v; w] basis).
    dirs = (plane_positions - Vec(0, 0, 1)).reshape(-1, 3) @ torch.vstack([u, v, w])
    origins = camera_origin.expand_as(dirs)
    rays = torch.stack([origins, dirs], dim=1)
    colors = ray_color(hittable_list, rays)
    render = colors.reshape(samples_per_pixel, width, height, 3).mean(dim=0)
    gamma_correct(render)
    # .numpy() only accepts CPU tensors, and the render lives on the default device.
    return render.cpu().numpy()
# Time one full render. perf_counter is a monotonic, high-resolution clock
# intended for measuring intervals; time.time() can jump with wall-clock changes.
# paint() ends by copying the image to the CPU, which waits for queued GPU
# work, so the interval covers the whole render.
start = time.perf_counter()
paint()
end = time.perf_counter()
print(f"Render Taken: {end-start:.4}s")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment