Skip to content

Instantly share code, notes, and snippets.

@brandosha
Last active May 29, 2024 05:21
Show Gist options
  • Save brandosha/133ff52829adde4b97f74bacff5ccc7d to your computer and use it in GitHub Desktop.
Save brandosha/133ff52829adde4b97f74bacff5ccc7d to your computer and use it in GitHub Desktop.
Real time reflections in SceneKit using Metal

Using the ReflectionSession Class for Real Time Reflections in SceneKit

First, create a new ReflectionSession instance with your SCNView object. Note: You only need one ReflectionSession object at a time.

let reflectionSession = ReflectionSession(inView: scnView)


Then, add the nodes that you want to be reflective to the ReflectionSession using the addNode(_:roughness:) method. Because .physicallyBased materials won't work, the roughness value will determine how much to blur the cubemap images, simulating a roughness effect.

reflectionSession.addNode(node1)
reflectionSession.addNode(node2, roughness: 0.2)


Now, all you need to do is call the update() method within the render loop of your scene.

// Example view-controller setup: registers the controller as the SCNView's delegate.
override func viewDidLoad() {
  
  super.viewDidLoad()
  
  // Become the view's delegate so the per-frame renderer callback is delivered here.
  scnView.delegate = self
  
  // Other setup
  
}

// Render-loop callback for the delegate set in viewDidLoad.
func renderer(_ renderer: SCNSceneRenderer, willRenderScene scene: SCNScene, atTime time: TimeInterval) {
  
  // Advance the reflection cubemaps by one step for this frame.
  reflectionSession.update()
  
}


Optional: You can use reflectionUpdateSpeed to set the frame rate of the reflections.

reflectionSession.reflectionUpdateSpeed = .slow      // Update reflection every six frames
reflectionSession.reflectionUpdateSpeed = .normal    // Default, update reflection every three frames
reflectionSession.reflectionUpdateSpeed = .fast      // Update reflection every two frames
reflectionSession.reflectionUpdateSpeed = .reallyFast // Update reflection every frame

Note: The frame rate of the reflections will be lower if there are a lot of reflective objects on the screen at one time, because only one object's reflection is rendered at a time. Faster reflectionUpdateSpeed settings render more cubemap faces per frame and can slow down the actual frame rate of your game.

//
// ReflectionSession.swift
//
import Foundation
import SceneKit
import MetalPerformanceShaders
/// Renders cubemap reflections for a set of nodes in an `SCNView`'s scene.
///
/// One node's cubemap is built at a time: each call to `update()` renders one or more
/// cube faces (depending on `reflectionUpdateSpeed`) into an off-screen Metal cube texture.
/// When all six faces are done, the texture is optionally box-blurred and assigned to the
/// `reflective` property of the node's first material, and the session moves on to the next node.
public class ReflectionSession {
    /// Edge length, in pixels, of every cubemap face.
    let reflectionMapSize = 200

    /// Largest box-blur kernel edge used for a roughness of 1.
    private static let maxBlurKernelSize = 95

    private let renderer: SCNRenderer
    private let scene: SCNScene?
    private let view: SCNView
    private let device: MTLDevice
    /// Reflective nodes, in registration order.
    private var nodes: [SCNNode] = []
    private var commandQueue: MTLCommandQueue
    /// Blur kernel for each entry of `nodes` (same index); `nil` means no blur.
    private var blurHandlers: [MPSImageBox?] = []

    /// How many cube faces are rendered per call to `update()`.
    public enum UpdateSpeed {
        /// One face per update: a full cubemap every six updates.
        case slow
        /// Two faces per update: a full cubemap every three updates.
        case normal
        /// Three faces per update: a full cubemap every two updates.
        case fast
        /// All six faces per update: a full cubemap every update.
        case reallyFast
    }

    /// Controls how many cube faces each `update()` call renders. Defaults to `.normal`.
    public var reflectionUpdateSpeed: UpdateSpeed = .normal

    /// Creates a session that renders reflections of the scene currently attached to `view`.
    ///
    /// The scene is captured at creation time, and the internal cube-camera rig is added to
    /// its root node. Stops with a descriptive message if no Metal device or command queue
    /// can be created.
    ///
    /// - Parameter view: The view whose scene and point of view drive the reflections.
    public init(inView view: SCNView) {
        guard let device = MTLCreateSystemDefaultDevice() else {
            fatalError("ReflectionSession: no Metal device is available")
        }
        guard let commandQueue = device.makeCommandQueue() else {
            fatalError("ReflectionSession: could not create a Metal command queue")
        }
        self.device = device
        self.commandQueue = commandQueue
        renderer = SCNRenderer(device: device, options: nil)
        scene = view.scene
        self.view = view
        pendingTexture = ReflectionSession.makeCubemapTexture(device: device, size: reflectionMapSize)
        renderer.scene = scene
        scene?.rootNode.addChildNode(camera360Node)
    }

    /// Registers `node` as reflective.
    ///
    /// This overwrites the node's `categoryBitMask` with a value derived from its registration
    /// order; the cube cameras later use the complement of that mask while rendering this
    /// node's cubemap (presumably so the node does not appear in its own reflection).
    ///
    /// - Parameters:
    ///   - node: The node that should receive a reflection cubemap.
    ///   - roughness: Blur amount. Values are clamped to `0...1`; `0` (the default) disables
    ///     blurring, `1` uses the largest kernel.
    public func addNode(_ node: SCNNode, roughness: Float = 0) {
        nodes.append(node)

        // Clamp before converting: Int(_:) traps on NaN and on values outside Int's range.
        let clampedRoughness: Float = roughness.isNaN ? 0 : min(max(roughness, 0), 1)
        var kernelSize = Int(clampedRoughness * Float(ReflectionSession.maxBlurKernelSize))

        if kernelSize > 0 {
            // MPSImageBox expects odd kernel dimensions, so round even sizes down.
            if kernelSize % 2 == 0 { kernelSize -= 1 }
            let blurHandler = MPSImageBox(device: device, kernelWidth: kernelSize, kernelHeight: kernelSize)
            blurHandler.edgeMode = .clamp
            blurHandlers.append(blurHandler)
        } else {
            blurHandlers.append(nil)
        }

        node.categoryBitMask = 2 << nodes.count
    }

    /// Cube-face suffixes; the index of each entry is also the cube texture slice it renders into.
    private let angles = ["x+", "x-", "y+", "y-", "z+", "z-"]
    /// Index into `angles` of the next face to render.
    private var currentAngle = 0
    /// Index into `nodes` of the node whose cubemap is being built.
    private var currentNode = 0

    /// Rig holding six child camera nodes named `camera_<face>`, each with a 90° field of view.
    private let camera360Node: SCNNode = {
        let rig = SCNNode()

        // Builds one face camera, applies its rotation and attaches it to the rig.
        func addFaceCamera(_ face: String, orient: (SCNNode) -> Void) {
            let cameraNode = SCNNode()
            cameraNode.name = "camera_" + face
            let camera = SCNCamera()
            camera.fieldOfView = 90
            cameraNode.camera = camera
            orient(cameraNode)
            rig.addChildNode(cameraNode)
        }

        addFaceCamera("x+") { $0.eulerAngles.y = GLKMathDegreesToRadians(90) }
        addFaceCamera("x-") { $0.eulerAngles.y = GLKMathDegreesToRadians(-90) }
        addFaceCamera("y+") { $0.eulerAngles.x = GLKMathDegreesToRadians(90) }
        addFaceCamera("y-") { $0.eulerAngles.x = GLKMathDegreesToRadians(-90) }
        addFaceCamera("z+") { $0.eulerAngles.y = GLKMathDegreesToRadians(180) }
        addFaceCamera("z-") { _ in } // keeps the default orientation

        return rig
    }()

    /// Cube texture currently being filled; replaced with a fresh one after each hand-off to a material.
    private var pendingTexture: MTLTexture

    /// `true` while a cubemap is partially rendered (between its first and last face).
    private var texturePending = false

    /// Creates an empty cube texture usable both as a render target and as a shader input.
    private static func makeCubemapTexture(device: MTLDevice, size: Int) -> MTLTexture {
        let descriptor = MTLTextureDescriptor.textureCubeDescriptor(pixelFormat: .rgba8Unorm, size: size, mipmapped: false)
        descriptor.usage = [.renderTarget, .shaderRead]
        guard let texture = device.makeTexture(descriptor: descriptor) else {
            fatalError("ReflectionSession: could not allocate a \(size)x\(size) cube texture")
        }
        return texture
    }

    /// Box-blurs every face of `pendingTexture` in place using `blur`.
    private func blurPendingTexture(with blur: MPSImageBox) {
        let faceSize = MTLSize(width: reflectionMapSize, height: reflectionMapSize, depth: 1)
        let origin = MTLOrigin(x: 0, y: 0, z: 0)

        // The scratch textures are identical for every face, so allocate them once per pass.
        let sharpDescriptor = MTLTextureDescriptor.texture2DDescriptor(pixelFormat: .rgba8Unorm, width: reflectionMapSize, height: reflectionMapSize, mipmapped: false)
        let blurredDescriptor = MTLTextureDescriptor.texture2DDescriptor(pixelFormat: .rgba8Unorm, width: reflectionMapSize, height: reflectionMapSize, mipmapped: false)
        blurredDescriptor.usage = .shaderWrite
        guard let sharpFace = device.makeTexture(descriptor: sharpDescriptor),
              let blurredFace = device.makeTexture(descriptor: blurredDescriptor) else {
            return // leave the cubemap unblurred rather than crash
        }

        for slice in 0..<angles.count {
            guard let blurBuffer = commandQueue.makeCommandBuffer(),
                  let copyToEditable = blurBuffer.makeBlitCommandEncoder() else {
                continue
            }
            // Cube face -> 2D scratch texture.
            copyToEditable.copy(from: pendingTexture, sourceSlice: slice, sourceLevel: 0, sourceOrigin: origin, sourceSize: faceSize, to: sharpFace, destinationSlice: 0, destinationLevel: 0, destinationOrigin: origin)
            copyToEditable.endEncoding()

            blur.encode(commandBuffer: blurBuffer, sourceTexture: sharpFace, destinationTexture: blurredFace)

            // Blurred result -> back into the same cube face.
            if let copyBackToCube = blurBuffer.makeBlitCommandEncoder() {
                copyBackToCube.copy(from: blurredFace, sourceSlice: 0, sourceLevel: 0, sourceOrigin: origin, sourceSize: faceSize, to: pendingTexture, destinationSlice: slice, destinationLevel: 0, destinationOrigin: origin)
                copyBackToCube.endEncoding()
            }

            blurBuffer.commit()
            // Wait so the scratch textures can safely be reused for the next face.
            blurBuffer.waitUntilCompleted()
        }
    }

    /// Renders the next cube face for the current node; after the sixth face, blurs the
    /// cubemap (if requested), assigns it to the node's material and advances to the next node.
    private func updateReflections() {
        let node = nodes[currentNode]

        // Position the rig whenever a new cubemap starts. This covers the very first cubemap
        // and the case where update() switched to a different node via its frustum check.
        if currentAngle == 0 {
            camera360Node.worldPosition = node.worldPosition
        }

        guard let faceCamera = camera360Node.childNode(withName: "camera_" + angles[currentAngle], recursively: false) else {
            fatalError("ReflectionSession: missing cube camera for face \(angles[currentAngle])")
        }
        // If no command buffer is available, leave all state untouched and retry on the next call.
        guard let commandBuffer = commandQueue.makeCommandBuffer() else { return }

        faceCamera.camera?.categoryBitMask = ~node.categoryBitMask
        renderer.pointOfView = faceCamera
        texturePending = true

        let renderPassDescriptor = MTLRenderPassDescriptor()
        if let colorAttachment = renderPassDescriptor.colorAttachments[0] {
            colorAttachment.texture = pendingTexture
            colorAttachment.slice = currentAngle
            colorAttachment.loadAction = .clear
            colorAttachment.clearColor = MTLClearColorMake(0, 0, 0, 1.0)
            colorAttachment.storeAction = .store
        }

        // The viewport must match the cube face size, so derive it from reflectionMapSize.
        let viewport = CGRect(x: 0, y: 0, width: reflectionMapSize, height: reflectionMapSize)
        renderer.render(withViewport: viewport, commandBuffer: commandBuffer, passDescriptor: renderPassDescriptor)
        commandBuffer.commit()
        commandBuffer.waitUntilCompleted()

        currentAngle = (currentAngle + 1) % angles.count
        guard currentAngle == 0 else { return }

        // All six faces are rendered: finish the cubemap and hand it to the material.
        if let blurHandler = blurHandlers[currentNode] {
            blurPendingTexture(with: blurHandler)
        }
        node.geometry?.firstMaterial?.reflective.contents = pendingTexture

        currentNode = (currentNode + 1) % nodes.count
        // The finished texture now belongs to the material; start the next cubemap in a new one.
        pendingTexture = ReflectionSession.makeCubemapTexture(device: device, size: reflectionMapSize)
        texturePending = false
    }

    /// Advances the reflections; call once per rendered frame.
    ///
    /// Does nothing when no nodes are registered. When no cubemap is in progress, the next node
    /// to process is the first one, starting from the current position in round-robin order,
    /// that lies inside the view's frustum; if none does, the current node is kept.
    @objc public func update() {
        guard !nodes.isEmpty else { return }

        if !texturePending, let pointOfView = view.pointOfView {
            for _ in 0..<nodes.count {
                if view.isNode(nodes[currentNode], insideFrustumOf: pointOfView) { break }
                currentNode = (currentNode + 1) % nodes.count
            }
        }

        let facesThisUpdate: Int
        switch reflectionUpdateSpeed {
        case .slow: facesThisUpdate = 1
        case .normal: facesThisUpdate = 2
        case .fast: facesThisUpdate = 3
        case .reallyFast: facesThisUpdate = 6
        }
        for _ in 0..<facesThisUpdate { updateReflections() }
    }
}
@brandosha
Copy link
Author

It seems to be working for me even with that error.

@ivan-ushakov
Copy link

Yes, I already found that with disabled SSAO problem disappear.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment