Skip to content

Instantly share code, notes, and snippets.

@FlorianCassayre
Last active March 10, 2018 13:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save FlorianCassayre/21f428371018cd27c318a96bda4d10ae to your computer and use it in GitHub Desktop.
Save FlorianCassayre/21f428371018cd27c318a96bda4d10ae to your computer and use it in GitHub Desktop.
A basic ray tracer in Scala. "What I cannot create, I do not understand" (R. Feynman)
import java.awt.Color
import java.awt.image.BufferedImage
import java.io.File
import javax.imageio.ImageIO
/** A basic ray tracer: "What I cannot create, I do not understand" (R. Feynman).
  *
  * Rays are marched in fixed increments of `RayStep` world units for at most `RayIterations` steps. A ray stops on
  * the ground plane `z = 0` (a checkerboard), bounces off any sphere it penetrates while heading into it (every
  * bounce darkens the pixel), and otherwise takes a sky colour derived from its final direction.
  *
  * Run via `main`; `rayTracing` renders to a PNG file and `render` returns the image in memory.
  */
object RayTracing {
  private def sq(x: Double): Double = x * x

  /** Ray-marching step length, in world units. */
  private val RayStep = 1.0

  /** Maximum number of steps before a ray is treated as having escaped into the sky. */
  private val RayIterations = 10000

  /** Side length, in world units, of one checkerboard square on the ground. */
  private val SquareSpacing = 100.0

  /** Resolution multiplier used by `main`: 1 takes a few seconds on a multicore processor, 4 (max quality) more than a minute. */
  val quality: Int = 1

  /** A position in world space; the ground is the plane `z = 0`. */
  case class Point(x: Double, y: Double, z: Double) {
    def +(v: Vector): Point = Point(x + v.x, y + v.y, z + v.z)
    def distanceSq(p: Point): Double = sq(x - p.x) + sq(y - p.y) + sq(z - p.z)
    def distance(p: Point): Double = Math.sqrt(distanceSq(p))
  }

  /** A displacement or direction in world space. */
  case class Vector(x: Double, y: Double, z: Double) {
    def +(v: Vector): Vector = Vector(x + v.x, y + v.y, z + v.z)
    def *(d: Double): Vector = Vector(x * d, y * d, z * d)
    def /(d: Double): Vector = Vector(x / d, y / d, z / d)
    def dot(v: Vector): Double = x * v.x + y * v.y + z * v.z

    /** Cross product `this × v` (right-handed). */
    def cross(v: Vector): Vector = Vector(y * v.z - z * v.y, z * v.x - x * v.z, x * v.y - y * v.x)

    def opposite: Vector = Vector(-x, -y, -z)
    def isZero: Boolean = x == 0.0 && y == 0.0 && z == 0.0

    /** Unit vector with the same direction; the zero vector is returned unchanged. */
    def normalize: Vector = if (!isZero) this / length else this

    def lengthSq: Double = sq(x) + sq(y) + sq(z)
    def length: Double = Math.sqrt(lengthSq)
  }

  object Vector {
    /** The vector going from `from` to `to`. */
    def apply(to: Point, from: Point): Vector = Vector(to.x - from.x, to.y - from.y, to.z - from.z)
  }

  /** A reflective sphere; `radius` must be strictly positive. */
  case class Sphere(center: Point, radius: Double) {
    require(radius > 0.0, s"Sphere radius must be positive, got $radius")

    // Plain val: read on every marching step, so avoid the lazy-val initialization check.
    val radiusSq: Double = sq(radius)
  }

  /** The two-sphere scene rendered by `main`. */
  val DefaultSpheres: List[Sphere] = List(Sphere(Point(500, 1000, 200), 200), Sphere(Point(1000, 700, 300), 200))

  /** World "up" direction. */
  private val Up = Vector(0.0, 0.0, 1.0)

  /** Maps pixel index `i` in `[0, 2 * half]` to `[-1, 1]`; a single-pixel axis (`half == 0`) maps to its centre, 0. */
  private def screenCoordinate(i: Double, half: Double): Double = if (half == 0.0) 0.0 else (i - half) / half

  /** Checkerboard colour of the ground square containing `p` (only `p.x` and `p.y` matter). */
  private def groundColor(p: Point): Color = {
    val even = (Math.floor(p.x / SquareSpacing) + Math.floor(p.y / SquareSpacing)) % 2 == 0
    if (even) Color.BLACK else Color.WHITE
  }

  /** Sky gradient: the blue channel grows with the upward slope of `direction`, clamped to `[0, 255]`. */
  private def skyColor(direction: Vector): Color = {
    val blue = 200 + (100 * Math.atan(direction.normalize.z) / (Math.PI / 2)).toInt
    new Color(50, 150, Math.max(Math.min(blue, 255), 0))
  }

  /** Lowers the HSB brightness of `base` by `0.2` per unit of `darkness`, never below 0. */
  private def shade(base: Color, darkness: Double): Color = {
    val hsb = Color.RGBtoHSB(base.getRed, base.getGreen, base.getBlue, null)
    Color.getHSBColor(hsb(0), hsb(1), Math.max(hsb(2) - darkness.toFloat * 0.2f, 0f))
  }

  /** Marches a single ray and returns its shaded colour.
    *
    * @param origin  first point of the ray (on the screen)
    * @param step    displacement applied at every iteration
    * @param spheres spheres the ray may bounce off
    */
  private def trace(origin: Point, step: Vector, spheres: List[Sphere]): Color = {
    var position = origin
    var direction = step
    var darkness = 0.0
    var hit: Option[Color] = None
    var i = 0
    while (hit.isEmpty && i < RayIterations) {
      position = position + direction
      if (position.z < 0.0) {
        hit = Some(groundColor(position))
      } else {
        for (sphere <- spheres) {
          if (position.distanceSq(sphere.center) <= sphere.radiusSq) {
            val normal = Vector(position, sphere.center).normalize
            val approach = direction.dot(normal)
            // Reflect only while heading into the sphere. Right after a bounce the ray can still be inside for a
            // step; reflecting again there would send it back in and darken the pixel a second time.
            if (approach < 0.0) {
              direction = direction + normal.opposite * (2 * approach)
              // Upper hemisphere: darker the further the surface normal leans away from straight up.
              // Lower hemisphere: full darkness.
              darkness += (if (position.z - sphere.center.z >= 0) 0.95 * (1 - normal.dot(Up)) + 0.05 else 1.0)
            }
          }
        }
      }
      i += 1
    }
    shade(hit.getOrElse(skyColor(direction)), darkness)
  }

  /** Renders the scene into an in-memory image.
    *
    * The screen is a `widthSize`-wide rectangle centred at `focal + screenToFocal` and perpendicular to
    * `screenToFocal`; its height follows the image aspect ratio. Its horizontal axis stays parallel to the ground
    * while its vertical axis tilts with the view direction, so the camera may look up or down.
    *
    * @param width         image width in pixels (> 0)
    * @param height        image height in pixels (> 0)
    * @param widthSize     screen width in world units (> 0)
    * @param focal         eye position
    * @param screenToFocal vector from the eye to the centre of the screen (non-zero)
    * @param spheres       reflective spheres in the scene
    * @return the rendered `TYPE_INT_RGB` image
    * @throws IllegalArgumentException if a size is not positive or `screenToFocal` is the zero vector
    */
  def render(width: Int, height: Int, widthSize: Double, focal: Point, screenToFocal: Vector,
             spheres: List[Sphere] = DefaultSpheres): BufferedImage = {
    import java.util.function.IntConsumer
    import java.util.stream.IntStream

    require(width > 0 && height > 0, s"image size must be positive, got ${width}x$height")
    require(widthSize > 0.0, s"widthSize must be positive, got $widthSize")
    require(!screenToFocal.isZero, "screenToFocal must not be the zero vector")

    // Orthonormal camera basis. Looking straight up or down leaves "right" undefined, so fall back to the x axis.
    val forward = screenToFocal.normalize
    val sideways = forward.cross(Up)
    val right = if (sideways.isZero) Vector(1.0, 0.0, 0.0) else sideways.normalize
    val up = right.cross(forward).normalize

    // Pixel-centre offsets, and the matching half extents of the screen in world units.
    val halfWidth = (width - 1) / 2.0
    val halfHeight = (height - 1) / 2.0
    val halfWidthSize = widthSize / 2.0
    val halfHeightSize = widthSize * height / width / 2.0

    val pixels = new Array[Int](width * height)

    def renderRow(y: Int): Unit =
      for (x <- 0 until width) {
        val horizontal = right * (halfWidthSize * screenCoordinate(x, halfWidth))
        // Image rows grow downwards while world "up" grows upwards, hence the flip.
        val vertical = up * (halfHeightSize * screenCoordinate(height - 1 - y, halfHeight))
        val toScreen = screenToFocal + horizontal + vertical
        pixels(y * width + x) = trace(focal + toScreen, toScreen.normalize * RayStep, spheres).getRGB
      }

    // Rows are independent and each writes a disjoint slice of `pixels`. The image itself is only touched from
    // this thread, after the parallel pass has completed.
    IntStream.range(0, height).parallel().forEach(new IntConsumer {
      override def accept(y: Int): Unit = renderRow(y)
    })

    val image = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB)
    image.setRGB(0, 0, width, height, pixels, 0, width)
    image
  }

  /** Renders the scene (see `render`) and writes it as a PNG file.
    *
    * @param output destination file; `ray.png` in the working directory by default
    * @throws IllegalArgumentException on invalid sizes or a zero `screenToFocal`
    * @throws java.io.IOException if writing fails or no PNG writer is available
    */
  def rayTracing(width: Int, height: Int, widthSize: Double, focal: Point, screenToFocal: Vector,
                 spheres: List[Sphere] = DefaultSpheres, output: File = new File("ray.png")): Unit = {
    val image = render(width, height, widthSize, focal, screenToFocal, spheres)
    if (!ImageIO.write(image, "png", output)) throw new java.io.IOException(s"No PNG writer available for $output")
  }

  /** Renders the default scene to `ray.png`. An explicit `main` (rather than `App`) keeps the parallel rendering
    * out of the object initializer, where worker threads would block on class initialization.
    */
  def main(args: Array[String]): Unit =
    rayTracing(500 * quality, 300 * quality, 200, Point(10, -2, 100), Vector(100, 120, 0))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment