Last active: March 10, 2018 13:15
-
-
Save FlorianCassayre/21f428371018cd27c318a96bda4d10ae to your computer and use it in GitHub Desktop.
A basic ray tracer in Scala. "What I cannot create, I do not understand" (R. Feynman)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.awt.Color | |
import java.awt.image.BufferedImage | |
import java.io.File | |
import javax.imageio.ImageIO | |
/** A basic ray marcher: one ray per pixel is cast from a pinhole camera into a scene made of a
  * checkered ground plane (z < 0) and mirror-like spheres, and the result is written as a PNG.
  *
  * "What I cannot create, I do not understand" (R. Feynman)
  */
object RayTracing extends App {

  private def sq(x: Double): Double = x * x

  /** A position in world space. */
  case class Point(x: Double, y: Double, z: Double) {
    /** Translates this point by `v`. */
    def +(v: Vector): Point = Point(x + v.x, y + v.y, z + v.z)
    /** Squared Euclidean distance to `p` (no square root; enough for comparisons). */
    def distanceSq(p: Point): Double = sq(x - p.x) + sq(y - p.y) + sq(z - p.z)
    /** Euclidean distance to `p`. */
    def distance(p: Point): Double = Math.sqrt(distanceSq(p))
  }

  /** A displacement or direction in world space. */
  case class Vector(x: Double, y: Double, z: Double) {
    def +(v: Vector): Vector = Vector(x + v.x, y + v.y, z + v.z)
    def *(d: Double): Vector = Vector(x * d, y * d, z * d)
    def /(d: Double): Vector = Vector(x / d, y / d, z / d)
    def dot(v: Vector): Double = x * v.x + y * v.y + z * v.z
    /** Cross product `this × v`: perpendicular to both operands, oriented by the right-hand rule. */
    def cross(v: Vector): Vector = Vector(y * v.z - z * v.y, z * v.x - x * v.z, x * v.y - y * v.x)
    def opposite: Vector = Vector(-x, -y, -z)
    def isZero: Boolean = x == 0.0 && y == 0.0 && z == 0.0
    /** Unit vector with the same direction; the zero vector is returned unchanged. */
    def normalize: Vector = if (!isZero) this / length else this
    def lengthSq: Double = sq(x) + sq(y) + sq(z)
    def length: Double = Math.sqrt(lengthSq)
  }

  object Vector {
    /** The displacement that moves `from` onto `to`. */
    def apply(to: Point, from: Point): Vector = Vector(to.x - from.x, to.y - from.y, to.z - from.z)
  }

  /** A solid sphere; a ray that enters it is mirrored about the surface normal. */
  case class Sphere(center: Point, radius: Double) {
    assert(radius > 0.0)
    // A plain val: it is read on every marching step (from many threads), so laziness only costs.
    val radiusSq: Double = sq(radius)
  }

  /** Signed offset, in world units, from the screen's center to the center of pixel `index` on an
    * axis of `count` pixels that spans `size` world units. Pixel centers sit half a pixel in from
    * each edge, so a single pixel (`count == 1`) lies exactly at the center.
    */
  def pixelOffset(index: Int, count: Int, size: Double): Double =
    size * (2.0 * index + 1 - count) / (2.0 * count)

  /** Checkerboard color of the ground at `position` (tiles of 100 × 100 world units). */
  private def groundColorAt(position: Point): Color = {
    val squareSpacing = 100
    val square = (Math.floor(position.x / squareSpacing) + Math.floor(position.y / squareSpacing)) % 2 == 0
    if (square) Color.BLACK else Color.WHITE
  }

  /** Color of the sky seen along `direction`: the more the ray points upward, the bluer. */
  private def skyColor(direction: Vector): Color = {
    val blue = 200 + (100 * Math.atan(direction.normalize.z) / (Math.PI / 2)).toInt
    new Color(50, 150, Math.max(Math.min(blue, 255), 0))
  }

  /** Lowers the HSB brightness of `color` by 0.2 per unit of `darkness`, clamped at 0. */
  private def darken(color: Color, darkness: Double): Color = {
    val hsb = Color.RGBtoHSB(color.getRed, color.getGreen, color.getBlue, new Array[Float](3))
    Color.getHSBColor(hsb(0), hsb(1), Math.max(hsb(2) - darkness.toFloat * 0.2f, 0f))
  }

  /** Renders the built-in scene and writes it to `output` as a PNG.
    *
    * @param width         image width in pixels (must be > 0)
    * @param height        image height in pixels (must be > 0)
    * @param widthSize     width of the virtual screen in world units; its height follows the image
    *                      aspect ratio
    * @param focal         position of the eye
    * @param screenToFocal vector from the eye to the center of the virtual screen; its direction is
    *                      the viewing direction. Must be non-zero and not vertical, because the
    *                      screen's horizontal axis is derived from its x/y components.
    * @param output        destination file (defaults to `ray.png` in the working directory)
    */
  def rayTracing(width: Int, height: Int, widthSize: Double, focal: Point, screenToFocal: Vector,
                 output: File = new File("ray.png")): Unit = {
    assert(width > 0 && height > 0)
    assert(!screenToFocal.isZero)
    val rayStep = 1.0         // distance travelled per marching step, in world units
    val rayIterations = 10000 // after this many steps without hitting the ground, the ray sees sky

    // Camera basis: `right` is horizontal and perpendicular to the viewing direction, and `up` is
    // perpendicular to both, so the camera may also look up or down (screenToFocal.z != 0).
    val right = Vector(screenToFocal.y, -screenToFocal.x, 0.0).normalize
    assert(!right.isZero, "screenToFocal must not be vertical")
    val up = right.cross(screenToFocal).normalize
    val heightSize = widthSize * height / width

    val spheres = List(Sphere(Point(500, 1000, 200), 200), Sphere(Point(1000, 700, 300), 200))

    // Rows are rendered in parallel; each task writes only its own slots of this array, and the
    // image itself (not documented as thread-safe) is filled afterwards on this thread.
    val pixels = Array.ofDim[Int](width * height)

    for {
      y <- (0 until height).par
      x <- 0 until width
    } {
      // Image row 0 is the top of the screen, hence the flipped vertical index.
      val vectorPointOnScreen = screenToFocal +
        right * pixelOffset(x, width, widthSize) +
        up * pixelOffset(height - 1 - y, height, heightSize)
      var position = focal + vectorPointOnScreen
      var direction = vectorPointOnScreen.normalize * rayStep
      var groundColor: Option[Color] = None
      var darkness = 0.0
      var i = 0
      while (groundColor.isEmpty && i < rayIterations) {
        position += direction
        if (position.z < 0.0) {
          groundColor = Some(groundColorAt(position))
        } else {
          for (sphere <- spheres if position.distanceSq(sphere.center) <= sphere.radiusSq) {
            val normal = Vector(position, sphere.center).normalize
            // Mirror reflection: d' = d - 2 (d · n) n
            direction = direction + normal.opposite * 2 * direction.dot(normal)
            // Upper hemisphere: 0.05 where the normal points straight up, up to 1 at the equator;
            // lower hemisphere: always 1.
            darkness += (if (position.z - sphere.center.z >= 0) 0.95 * (1 - normal.z) + 0.05 else 1)
          }
        }
        i += 1
      }
      // getOrElse is by-name: the sky color is computed only for rays that never hit the ground.
      pixels(y * width + x) = darken(groundColor.getOrElse(skyColor(direction)), darkness).getRGB
    }

    val image = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB)
    image.setRGB(0, 0, width, height, pixels, 0, width)
    ImageIO.write(image, "png", output)
  }

  val quality: Int = 1 // 1: a few seconds on a multicore processor / 4: max quality, more than a minute
  rayTracing(500 * quality, 300 * quality, 200, Point(10, -2, 100), Vector(100, 120, 0))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.