Skip to content

Instantly share code, notes, and snippets.

@mrandri19
Created February 28, 2021 18:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mrandri19/0c58e108bb3021132ad1487210d4b6ce to your computer and use it in GitHub Desktop.
Save mrandri19/0c58e108bb3021132ad1487210d4b6ce to your computer and use it in GitHub Desktop.
# From the editor: Ctrl-e to execute, Ctrl-l to clear console (My VsCode custom keybindings)
using Images
using LinearAlgebra
import LinearAlgebra: dot
using StaticArrays
using Printf
# Fixed-size 3-component Float64 vector (StaticArrays), used for points and directions.
const Vec3 = SVector{3,Float64}
"""
Abstract supertype for surface materials.

`ray_color` calls `scatter(r, hit_record.material, hit_record)`, dispatching on the
concrete material type, so every subtype needs its own `scatter` method.
"""
abstract type Material end
"""
    Lambertian(albedo)

Diffuse material. Its `scatter` method always scatters, returning `albedo` as the
attenuation that is multiplied into the colour of the bounced ray.
"""
struct Lambertian <: Material
albedo::RGB{N0f8}
end
"""
    Metal(albedo)

Mirror-like material. Its `scatter` method reflects the ray about the surface normal
and returns `albedo` as the attenuation (tint) of the reflected light.
"""
struct Metal <: Material
albedo::RGB{N0f8}
end
"""
    Sphere(center, radius, material)

Sphere primitive: `center` point, `radius`, and the surface `material`.

NOTE(review): `material` is typed with the abstract `Material`, so the field is not
concretely typed (values are boxed). This is presumably deliberate, so that a single
`Array{Sphere}` world can mix materials. Confirm before parametrizing.
"""
struct Sphere
center::Vec3
radius::Float64
material::Material
end
"""
    Ray(origin, direction)

A ray with start point `origin` and direction vector `direction`.

`direction` is not required to be unit length: `hit` works with its squared norm, and
the sky shading in `ray_color` normalizes it before use.
"""
struct Ray
origin::Vec3
direction::Vec3
end
"""
    at(r::Ray, t::Float64)

Return the point reached after travelling parameter `t` along ray `r`, i.e.
`r.origin + t * r.direction`.

This is the same parameterization that `hit` solves its ray/sphere quadratic for,
so the returned point is the actual intersection point for a root `t`.
"""
function at(r::Ray, t::Float64)
    # `direction` is a direction vector, not a target point. Subtracting the origin
    # from it (as `(direction - origin) * t` did) is only correct when the origin is
    # zero. That breaks every scattered ray, which starts at the previous hit point.
    return r.origin + t * r.direction
end
"""
    HitRecord(t, p, normal, material)

Result of a successful ray/sphere intersection:
- `t`: ray parameter of the hit.
- `p`: hit point, as returned by `at(r, t)`.
- `normal`: surface normal, computed in `hit` as `(p - center) / radius`.
- `material`: the material of the sphere that was hit.
"""
struct HitRecord
t::Float64
p::Vec3
normal::Vec3
material::Material
end
"""
    hit(r::Ray, t_min::Float64, t_max::Float64, s::Sphere)

Intersect ray `r` with sphere `s`.

Return a `HitRecord` for the nearer root of the ray/sphere quadratic that lies strictly
inside `(t_min, t_max)`. Return `nothing` in either of these cases:
- the ray misses the sphere or only touches it tangentially (non-positive discriminant);
- both roots fall outside the interval.
"""
function hit(r::Ray, t_min::Float64, t_max::Float64, s::Sphere)::Union{Nothing,HitRecord}
    # Quadratic a*t^2 + 2*half_b*t + c = 0 for the point origin + t*direction.
    to_origin = r.origin - s.center
    a = norm(r.direction)^2
    half_b = dot(to_origin, r.direction)
    c = norm(to_origin)^2 - s.radius^2
    disc = half_b^2 - a * c
    disc > 0 || return nothing
    sq = sqrt(disc)
    # Try the nearer root first, then the farther one.
    for root in ((-half_b - sq) / a, (-half_b + sq) / a)
        t_min < root < t_max || continue
        point = at(r, root)
        return HitRecord(root, point, (point - s.center) / s.radius, s.material)
    end
    return nothing
end
"""
    hit(r::Ray, t_min::Float64, t_max::Float64, world::Array{Sphere})

Return the `HitRecord` of the closest sphere in `world` intersected by `r` within
`(t_min, t_max)`, or `nothing` if no sphere is hit.
"""
function hit(r::Ray, t_min::Float64, t_max::Float64, world::Array{Sphere})::Union{Nothing,HitRecord}
    nearest::Union{Nothing,HitRecord} = nothing
    for sphere in world
        # Tighten the upper bound to the closest hit so far, so that only strictly
        # nearer intersections can replace it.
        upper = isnothing(nearest) ? t_max : nearest.t
        candidate = hit(r, t_min, upper, sphere)
        if candidate !== nothing
            nearest = candidate
        end
    end
    return nearest
end
"""
    random_in_unit_sphere()::Vec3

Draw a point uniformly from the unit ball.

A vector of three standard normals has a uniformly distributed direction. Scaling its
unit version by `cbrt(rand())` gives the radius distribution that makes the volume
density uniform.
"""
function random_in_unit_sphere()::Vec3
    # https://karthikkaranth.me/blog/generating-random-points-in-a-sphere/
    # Keep the draw order stable (one rand, then three randn) so seeded runs reproduce.
    radius = cbrt(rand())
    gx, gy, gz = randn(), randn(), randn()
    len = sqrt(gx^2 + gy^2 + gz^2)
    return SA_F64[gx, gy, gz] ./ len .* radius
end
"""
    scatter(r::Ray, material::Lambertian, hit_record::HitRecord)

Diffuse bounce. The new ray leaves the hit point in direction
`normal + random point in the unit ball`.

Always scatters (`true`), with the material albedo as attenuation. The incoming ray `r`
is unused; it is part of the shared `scatter` signature used for dispatch.
"""
function scatter(r::Ray, material::Lambertian, hit_record::HitRecord)::Tuple{Bool,Ray,RGB{N0f8}}
    bounce = hit_record.normal + random_in_unit_sphere()
    return (true, Ray(hit_record.p, bounce), material.albedo)
end
"""
    reflect_normal(v::Vec3, n::Vec3)::Vec3

Mirror `v` about the plane with normal `n`, computing `v - 2 * dot(v, n) * n`.

This is a true reflection only when `n` has unit length; callers are assumed to pass a
unit normal.
"""
reflect_normal(v::Vec3, n::Vec3)::Vec3 = v - 2 * dot(v, n) * n
"""
    scatter(r::Ray, material::Metal, hit_record::HitRecord)

Specular bounce. The normalized incoming direction is reflected about the surface
normal, and the result is reported as scattered only if it points away from the
surface (positive dot product with the normal).
"""
function scatter(r::Ray, material::Metal, hit_record::HitRecord)::Tuple{Bool,Ray,RGB{N0f8}}
    n = hit_record.normal
    outgoing = Ray(hit_record.p, reflect_normal(normalize(r.direction), n))
    # Reflections that would head into the surface are treated as absorbed.
    leaves_surface = dot(outgoing.direction, n) > 0
    return leaves_surface, outgoing, material.albedo
end
"""
    mul(c1::RGB{N0f8}, c2::RGB{N0f8})

Channel-wise product of two 8-bit colours, returned as `RGB{N0f8}`.
"""
function mul(c1::RGB{N0f8}, c2::RGB{N0f8})
    red_ch = c1.r * c2.r
    green_ch = c1.g * c2.g
    blue_ch = c1.b * c2.b
    return RGB{N0f8}(red_ch, green_ch, blue_ch)
end
"""
    ray_color(r::Ray, world, depth)::RGB{N0f8}

Recursively shade ray `r` against `world`:
- Return black once `depth` exceeds 10, or when the hit surface's `scatter` absorbs
  the ray.
- On a hit that scatters, multiply the attenuation by the colour of the scattered ray.
- When nothing is hit, return a vertical sky gradient based on the ray's y direction.
"""
function ray_color(r::Ray, world, depth)::RGB{N0f8}
    depth > 10 && return RGB{N0f8}(0, 0, 0)
    # Small t_min avoids re-hitting the surface the ray just left.
    rec = hit(r, 0.0001, Inf, world)
    if isnothing(rec)
        # Sky: white when looking straight down, light blue when looking straight up.
        unit_dir = normalize(r.direction)
        blend = 0.5 * (unit_dir[2] + 1)
        return (1 - blend) * RGB{N0f8}(1, 1, 1) + blend * RGB{N0f8}(0.5, 0.7, 1.0)
    end
    bounces, next_ray, attenuation = scatter(r, rec.material, rec)
    bounces || return RGB{N0f8}(0, 0, 0)
    return mul(attenuation, ray_color(next_ray, world, depth + 1))
end
"""
    trace(; width::Integer = 1000, height::Integer = 500) -> Matrix{RGB{N0f8}}

Render the hard-coded four-sphere scene with one primary ray per pixel. Return a
`height × width` image in which row 1 is the top of the picture.

The camera sits at the origin, looking down -z at a viewport 2 units tall placed at
z = -1. The viewport width follows `width / height`, so the defaults reproduce the
4 × 2 viewport with lower-left corner (-2, -1, -1).

# Throws
- `ArgumentError` if `width` or `height` is not positive.
"""
function trace(; width::Integer = 1000, height::Integer = 500)
    width > 0 && height > 0 ||
        throw(ArgumentError("width and height must be positive, got $width × $height"))
    img = Array{RGB{N0f8},2}(undef, height, width)

    # Derive the viewport from the aspect ratio so non-2:1 images are not stretched.
    viewport_height = 2.0
    viewport_width = viewport_height * width / height
    origin = SA_F64[0, 0, 0]
    horizontal = SA_F64[viewport_width, 0, 0]
    vertical = SA_F64[0, viewport_height, 0]
    lower_left_corner = origin - horizontal / 2 - vertical / 2 - SA_F64[0, 0, 1]

    world::Array{Sphere} = [
        Sphere(SA_F64[0; 0; -1], 0.5, Lambertian(RGB{N0f8}(0.7, 0.3, 0.3))),
        Sphere(SA_F64[1; 0; -1], 0.5, Metal(RGB{N0f8}(0.8, 0.6, 0.2))),
        Sphere(SA_F64[-1; 0; -1], 0.5, Metal(RGB{N0f8}(0.8, 0.8, 0.8))),
        Sphere(SA_F64[0; -100.5; -1], 100, Lambertian(RGB{N0f8}(0.8, 0.8, 0.0))),
    ]

    # Column-major traversal: the row index `i` varies fastest.
    for j in 1:width, i in 1:height
        # Sample pixel centres so the image is symmetric about the viewport centre.
        # The mapping `j / width` would reach the right edge but never the left one,
        # shifting the whole picture by a pixel.
        u = (j - 0.5) / width
        v = 1 - (i - 0.5) / height
        r = Ray(origin, lower_left_corner + u * horizontal + v * vertical)
        img[i, j] = ray_color(r, world, 0)
    end
    return img
end
# Render once and report throughput. The first call in a fresh session also includes
# JIT compilation time; run it again for a steady-state figure.
val, t, bytes, gctime, memallocs = @timed trace();
# One primary ray per pixel. Derive the count from the image itself so the figure stays
# correct if the resolution used by `trace` changes.
@printf("t: %.2f s, %.2f MRays/s\n", t, length(val) / (1e6 * t))
val
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment