Skip to content

Instantly share code, notes, and snippets.

@warrenm
Created July 22, 2018 09:25
Show Gist options
  • Save warrenm/7e87f44087132312c70ff4395a8bae9a to your computer and use it in GitHub Desktop.
Save warrenm/7e87f44087132312c70ff4395a8bae9a to your computer and use it in GitHub Desktop.
Drawing 2D shapes with Metal
#include <metal_stdlib>
using namespace metal;
// Vertex-stage input assembled by the vertex descriptor built in
// ViewController.makeRenderPipelineState(). Each [[attribute(n)]] index must
// match the corresponding *AttributeIndex constant on the Swift side.
struct VertexIn {
float2 position [[attribute(0)]]; // varies per-vertex: point of the unit-sized shape (triangle or circle fan)
float4 color [[attribute(1)]]; // varies per-instance: RGBA color of the whole shape
float3 center [[attribute(2)]]; // varies per-instance: shape center; z ends up as the depth value
float radius [[attribute(3)]]; // varies per-instance: scale applied to the unit-shape position
};
// Per-frame constants bound at [[buffer(2)]] of vertex_main. The host uploads a
// single float4x4 with setVertexBytes, so this struct must stay exactly one
// float4x4 for the layouts to agree.
struct FrameUniforms {
float4x4 projectionMatrix; // orthographic projection rebuilt in draw(in:) from the drawable size
};
// Payload handed from vertex_main to fragment_main.
struct VertexOut {
float4 position [[position]]; // clip-space position produced by the projection
float4 color; // per-instance color, passed through unchanged to the fragment stage
};
// Positions one vertex of an instanced shape: the unit-shape point is scaled by
// the instance radius, offset by the instance center (whose z becomes depth),
// then multiplied by the per-frame projection into clip space.
vertex VertexOut vertex_main(VertexIn in [[stage_in]],
                             constant FrameUniforms &uniforms [[buffer(2)]])
{
    float2 scaled = in.position * in.radius;
    float4 placed = float4(float3(scaled, 0) + in.center, 1);

    VertexOut out;
    out.color = in.color;
    out.position = uniforms.projectionMatrix * placed;
    return out;
}
// Emits the color carried over from the vertex stage, converted to half
// precision for the color attachment.
fragment half4 fragment_main(VertexOut in [[stage_in]])
{
    const half4 shaded = half4(in.color);
    return shaded;
}
import Cocoa
import MetalKit
extension Float {
    /// Returns a pseudo-random value between `min` and `max`.
    ///
    /// The sample comes from `drand48()`, so the sequence is reproducible after seeding
    /// with `srand48(_:)`. `drand48()` yields values in [0, 1), but converting that
    /// `Double` to `Float` (and the subsequent scaling) can round up, so callers should
    /// treat the result as lying in the closed interval [min, max].
    ///
    /// - Parameters:
    ///   - min: Lower end of the range (defaults to 0).
    ///   - max: Upper end of the range (defaults to 1).
    /// - Returns: `min + u * (max - min)`, where `u` is the `Float`-converted `drand48()` sample.
    static func random(_ min: Float = 0, _ max: Float = 1) -> Float {
        let unit = Float(drand48())
        return min + unit * (max - min)
    }
}
extension float4 {
    /// Converts a color stored as (hue, saturation, value, alpha) into (red, green, blue, alpha).
    ///
    /// - `x` (hue) is a fraction of a full turn around the color wheel. It is wrapped into
    ///   [0, 1) first, so a hue of exactly 1.0 gives the same color as 0.0 (red at full
    ///   saturation/value) instead of falling through to the gray branch. Hues outside
    ///   [0, 1) wrap around the wheel.
    /// - `y` (saturation) and `z` (value) are used as given; presumably they are in [0, 1].
    /// - `w` (alpha) is copied through unchanged.
    ///
    /// A NaN or infinite hue still lands in the final branch and yields the gray `v * (1 - s)`.
    var rgbaFromHSVA: float4 {
        let s = self.y
        let v = self.z
        let a = self.w

        // Wrap the hue into [0, 1). The second check catches tiny negative hues whose
        // fractional part rounds up to exactly 1.0 in single precision.
        var h = self.x - self.x.rounded(.down)
        if h >= 1 { h = 0 }

        var r, g, b: Float
        let c = v * s                                   // chroma
        let x = c * (1.0 - abs(fmod(h * 6, 2) - 1))     // secondary component within the 60-degree sector
        let m = v - c                                   // offset added to every channel so the max equals v
        if (h < 1 / 6.0) { r = c + m; g = x + m; b = m }
        else if (h < 1 / 3.0) { r = x + m; g = c + m; b = m }
        else if (h < 1 / 2.0) { r = m; g = c + m; b = x + m }
        else if (h < 2 / 3.0) { r = m; g = x + m; b = c + m }
        else if (h < 5 / 6.0) { r = x + m; g = m; b = c + m }
        else if (h < 1.0) { r = c + m; g = m; b = x + m }
        else { r = m; g = m; b = m }                    // only reachable for NaN hue after wrapping
        return float4(r, g, b, a)
    }
}
/// Draws randomly placed, randomly colored triangles and circles with two instanced
/// Metal draw calls into the controller's `MTKView`, redrawing only when needed.
class ViewController: NSViewController, MTKViewDelegate {
    // MARK: - Binding indices

    // Vertex-buffer slots and vertex-attribute indices. The attribute indices must match
    // the [[attribute(n)]] values in VertexIn, and frameUniformsBufferIndex must match the
    // [[buffer(2)]] binding of vertex_main.
    let vertexBufferIndex = 0
    let instanceBufferIndex = 1
    let frameUniformsBufferIndex = 2
    let positionAttributeIndex = 0
    let colorAttributeIndex = 1
    let centerAttributeIndex = 2
    let radiusAttributeIndex = 3

    /// Floats per instance record: RGBA color (4), center xyz (3), radius (1).
    let instanceFloatCount = 8

    // MARK: - Geometry counts

    let triangleCount = 120
    let triangleVertexCount = 3
    let circleCount = 200
    /// One center vertex plus `circleVertexCount - 1` rim vertices.
    let circleVertexCount = 64
    /// Three indices (center, rim i, rim i+1) per rim segment of the circle fan.
    lazy var circleIndexCount = (circleVertexCount - 1) * 3

    // MARK: - Metal objects

    let device: MTLDevice
    let commandQueue: MTLCommandQueue
    var renderPipelineState: MTLRenderPipelineState!
    var depthStencilState: MTLDepthStencilState!
    var triangleVertexBuffer: MTLBuffer!
    var triangleUniformBuffer: MTLBuffer!
    var circleVertexBuffer: MTLBuffer!
    var circleUniformBuffer: MTLBuffer!
    var circleIndexBuffer: MTLBuffer!

    /// The controller's root view, which is required to be an `MTKView`.
    var mtkView: MTKView {
        return view as! MTKView
    }

    // MARK: - Initialization

    override init(nibName nibNameOrNil: NSNib.Name?, bundle nibBundleOrNil: Bundle?) {
        fatalError("init(nibName:nibBundleOrNil:) has not been implemented")
    }

    required init?(coder: NSCoder) {
        device = MTLCreateSystemDefaultDevice()!
        commandQueue = device.makeCommandQueue()!
        super.init(coder: coder)
    }

    override func viewDidLoad() {
        super.viewDidLoad()

        // Seed drand48 (the source behind Float.random) so layouts differ between runs.
        srand48(Int(mach_absolute_time()))

        mtkView.device = device
        mtkView.sampleCount = 4
        mtkView.colorPixelFormat = .bgra8Unorm_srgb
        mtkView.depthStencilPixelFormat = .depth32Float
        // Draw on demand only: the instance data is generated once in makeResources().
        mtkView.isPaused = true
        mtkView.enableSetNeedsDisplay = true
        mtkView.delegate = self

        // The pipeline copies sample count and pixel formats from the view, so configure it first.
        renderPipelineState = makeRenderPipelineState()

        let depthStencilDescriptor = MTLDepthStencilDescriptor()
        depthStencilDescriptor.isDepthWriteEnabled = true
        depthStencilDescriptor.depthCompareFunction = .less
        depthStencilState = device.makeDepthStencilState(descriptor: depthStencilDescriptor)!

        makeResources()
        mtkView.needsDisplay = true
    }

    // MARK: - Pipeline

    /// Builds the single pipeline shared by both shape types.
    ///
    /// Buffer `vertexBufferIndex` supplies one `float2` per vertex. Buffer
    /// `instanceBufferIndex` supplies one record of `instanceFloatCount` Floats per
    /// instance: [r, g, b, a, cx, cy, cz, radius].
    func makeRenderPipelineState() -> MTLRenderPipelineState {
        let descriptor = MTLRenderPipelineDescriptor()
        descriptor.sampleCount = mtkView.sampleCount
        descriptor.colorAttachments[0].pixelFormat = mtkView.colorPixelFormat
        descriptor.depthAttachmentPixelFormat = mtkView.depthStencilPixelFormat

        let library = device.makeDefaultLibrary()!
        let vertexFunction = library.makeFunction(name: "vertex_main")!
        let fragmentFunction = library.makeFunction(name: "fragment_main")!
        descriptor.vertexFunction = vertexFunction
        descriptor.fragmentFunction = fragmentFunction

        let floatStride = MemoryLayout<Float>.stride
        let vertexDescriptor = MTLVertexDescriptor()

        // Per-vertex stream: unit-shape positions.
        vertexDescriptor.layouts[vertexBufferIndex].stepRate = 1
        vertexDescriptor.layouts[vertexBufferIndex].stepFunction = .perVertex
        vertexDescriptor.layouts[vertexBufferIndex].stride = MemoryLayout<float2>.stride
        vertexDescriptor.attributes[positionAttributeIndex].bufferIndex = vertexBufferIndex
        vertexDescriptor.attributes[positionAttributeIndex].offset = 0
        vertexDescriptor.attributes[positionAttributeIndex].format = .float2

        // Per-instance stream: color at Float 0, center at Float 4, radius at Float 7.
        vertexDescriptor.layouts[instanceBufferIndex].stepRate = 1
        vertexDescriptor.layouts[instanceBufferIndex].stepFunction = .perInstance
        vertexDescriptor.layouts[instanceBufferIndex].stride = floatStride * instanceFloatCount
        vertexDescriptor.attributes[colorAttributeIndex].bufferIndex = instanceBufferIndex
        vertexDescriptor.attributes[colorAttributeIndex].offset = 0
        vertexDescriptor.attributes[colorAttributeIndex].format = .float4
        vertexDescriptor.attributes[centerAttributeIndex].bufferIndex = instanceBufferIndex
        vertexDescriptor.attributes[centerAttributeIndex].offset = floatStride * 4
        vertexDescriptor.attributes[centerAttributeIndex].format = .float3
        vertexDescriptor.attributes[radiusAttributeIndex].bufferIndex = instanceBufferIndex
        vertexDescriptor.attributes[radiusAttributeIndex].offset = floatStride * 7
        vertexDescriptor.attributes[radiusAttributeIndex].format = .float

        descriptor.vertexDescriptor = vertexDescriptor
        return try! device.makeRenderPipelineState(descriptor: descriptor)
    }

    // MARK: - Resources

    /// Creates the unit-shape vertex buffers, the circle index buffer, and the random
    /// per-instance data for both shape types.
    func makeResources() {
        let width: Float = 3200
        let height: Float = 2000
        // Start the first rim vertex at the top of the unit circle.
        let anglePhi = Float.pi / 2

        // Triangle: vertices evenly spaced on the unit circle.
        triangleVertexBuffer = makeSharedBuffer(length: MemoryLayout<float2>.stride * triangleVertexCount)
        let triangleVertices = triangleVertexBuffer.contents().bindMemory(to: float2.self, capacity: triangleVertexCount)
        let triangleAngleDelta = 2 * .pi / Float(triangleVertexCount)
        for i in 0..<triangleVertexCount {
            let angle = Float(i) * triangleAngleDelta + anglePhi
            triangleVertices[i] = float2(cos(angle), sin(angle))
        }

        triangleUniformBuffer = makeInstanceBuffer(count: triangleCount, minRadius: 20, maxRadius: 100,
                                                   width: width, height: height)

        // Circle: a center vertex followed by `sideCount` rim vertices on the unit circle.
        circleVertexBuffer = makeSharedBuffer(length: MemoryLayout<float2>.stride * circleVertexCount)
        let circleVertices = circleVertexBuffer.contents().bindMemory(to: float2.self, capacity: circleVertexCount)
        let sideCount = circleVertexCount - 1
        let circleAngleDelta = 2 * .pi / Float(sideCount)
        circleVertices[0] = float2(0, 0)
        for i in 1..<circleVertexCount {
            let angle = Float(i) * circleAngleDelta + anglePhi
            circleVertices[i] = float2(cos(angle), sin(angle))
        }

        // Fan triangulation: (center, rim k, rim k+1), wrapping the last segment back to rim 1.
        circleIndexBuffer = makeSharedBuffer(length: MemoryLayout<UInt16>.stride * circleIndexCount)
        let circleIndices = circleIndexBuffer.contents().bindMemory(to: UInt16.self, capacity: circleIndexCount)
        for side in 0..<(circleIndexCount / 3) {
            let baseIdx = side * 3
            circleIndices[baseIdx + 0] = 0
            circleIndices[baseIdx + 1] = UInt16(side + 1)
            circleIndices[baseIdx + 2] = UInt16((side + 1) % sideCount + 1)
        }

        circleUniformBuffer = makeInstanceBuffer(count: circleCount, minRadius: 20, maxRadius: 500,
                                                 width: width, height: height)
    }

    /// Allocates a shared-storage (CPU-writable) Metal buffer of `length` bytes,
    /// stopping with a descriptive message if the device cannot allocate it.
    private func makeSharedBuffer(length: Int) -> MTLBuffer {
        guard let buffer = device.makeBuffer(length: length, options: .storageModeShared) else {
            fatalError("Failed to allocate a \(length)-byte Metal buffer")
        }
        return buffer
    }

    /// Builds a buffer of `count` instance records laid out as [r, g, b, a, cx, cy, cz, radius].
    ///
    /// Each record gets a fully saturated, full-value color with random hue; a center drawn
    /// from a `width` x `height` rectangle centered on the origin with z from `Float.random()`;
    /// and a radius between `minRadius` and `maxRadius`.
    private func makeInstanceBuffer(count: Int, minRadius: Float, maxRadius: Float,
                                    width: Float, height: Float) -> MTLBuffer {
        let floatCount = instanceFloatCount * count
        let buffer = makeSharedBuffer(length: MemoryLayout<Float>.stride * floatCount)
        let instances = buffer.contents().bindMemory(to: Float.self, capacity: floatCount)
        for i in 0..<count {
            // Random draws happen in a fixed order (hue, x, y, z, radius) so a given seed
            // always reproduces the same layout.
            let color = float4(Float.random(), 1.0, 1.0, 1.0).rgbaFromHSVA
            let center = float3(Float.random(-width / 2, width / 2),
                                Float.random(-height / 2, height / 2),
                                Float.random())
            let radius = Float.random(minRadius, maxRadius)

            let base = i * instanceFloatCount
            instances[base + 0] = color.x
            instances[base + 1] = color.y
            instances[base + 2] = color.z
            instances[base + 3] = color.w
            instances[base + 4] = center.x
            instances[base + 5] = center.y
            instances[base + 6] = center.z
            instances[base + 7] = radius
        }
        return buffer
    }

    // MARK: - MTKViewDelegate

    func mtkView(_ view: MTKView, drawableSizeWillChange size: CGSize) {
        // The projection depends on the drawable size, so request a redraw.
        view.needsDisplay = true
    }

    /// Encodes one instanced draw for the triangles and one indexed, instanced draw for the circles.
    func draw(in view: MTKView) {
        guard let commandBuffer = commandQueue.makeCommandBuffer(),
              let renderPassDescriptor = view.currentRenderPassDescriptor else { return }
        renderPassDescriptor.colorAttachments[0].clearColor = MTLClearColorMake(0.95, 0.95, 0.95, 1.0)
        guard let renderCommandEncoder = commandBuffer.makeRenderCommandEncoder(descriptor: renderPassDescriptor) else { return }

        renderCommandEncoder.setRenderPipelineState(renderPipelineState)
        renderCommandEncoder.setDepthStencilState(depthStencilState)
        renderCommandEncoder.setFrontFacing(.counterClockwise)
        renderCommandEncoder.setCullMode(.back)

        var projectionMatrix = makeOrthographicProjection(width: Float(view.drawableSize.width),
                                                          height: Float(view.drawableSize.height))
        renderCommandEncoder.setVertexBytes(&projectionMatrix,
                                            length: MemoryLayout<float4x4>.size,
                                            index: frameUniformsBufferIndex)

        renderCommandEncoder.setVertexBuffer(triangleVertexBuffer, offset: 0, index: vertexBufferIndex)
        renderCommandEncoder.setVertexBuffer(triangleUniformBuffer, offset: 0, index: instanceBufferIndex)
        renderCommandEncoder.drawPrimitives(type: .triangle,
                                            vertexStart: 0,
                                            vertexCount: triangleVertexCount,
                                            instanceCount: triangleCount)

        renderCommandEncoder.setVertexBuffer(circleVertexBuffer, offset: 0, index: vertexBufferIndex)
        renderCommandEncoder.setVertexBuffer(circleUniformBuffer, offset: 0, index: instanceBufferIndex)
        renderCommandEncoder.drawIndexedPrimitives(type: .triangle,
                                                   indexCount: circleIndexCount,
                                                   indexType: .uint16,
                                                   indexBuffer: circleIndexBuffer,
                                                   indexBufferOffset: 0,
                                                   instanceCount: circleCount)
        renderCommandEncoder.endEncoding()

        if let drawable = view.currentDrawable {
            commandBuffer.present(drawable)
        }
        commandBuffer.commit()
    }

    /// Orthographic projection for a `width` x `height` region centered on the origin
    /// with +y up, mapping z in [near, far] linearly onto [0, 1] (Metal's clip-space depth range).
    private func makeOrthographicProjection(width: Float, height: Float,
                                            near: Float = 0, far: Float = 1) -> float4x4 {
        let left = -width / 2
        let right = width / 2
        let top = height / 2
        let bottom = -height / 2
        // Columns of the matrix; the last column carries the translation terms.
        return float4x4([float4(2 / (right - left), 0.0, 0.0, 0.0),
                         float4(0.0, 2 / (top - bottom), 0.0, 0.0),
                         float4(0.0, 0.0, 1 / (far - near), 0.0),
                         float4((left + right) / (left - right), (top + bottom) / (bottom - top), near / (near - far), 1.0)])
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment