Created March 13, 2024 20:05
-
-
Save JadenGeller/94a6d444444068daa626386f0881e2eb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Accelerate | |
/// A 512-component vector of `Float32` values whose arithmetic is implemented
/// with Accelerate's vDSP routines.
///
/// `init(components:)` does not validate the array length; instances are
/// expected to hold exactly `dimension` components (as `zero` does).
public struct Vector512Float32: VectorProtocol, Hashable {
    public typealias Scalar = Float32

    /// The raw component storage; expected to contain `dimension` elements.
    public var components: [Scalar]

    /// Wraps `components` as-is, without checking its length.
    public init(components: [Scalar]) {
        self.components = components
    }

    /// The number of components a vector of this type is meant to hold.
    static var dimension: UInt { 512 }

    /// A vector of `dimension` zeros.
    static var zero: Self {
        .init(components: .init(repeating: 0, count: Int(dimension)))
    }

    /// The elementwise negation of this vector.
    var negated: Self {
        .init(components: vDSP.negative(components))
    }

    /// Adds `rhs` to `lhs` elementwise, in place.
    static func += (_ lhs: inout Self, _ rhs: Self) {
        vDSP.add(lhs.components, rhs.components, result: &lhs.components)
    }

    /// Multiplies every component of `lhs` by the scalar `rhs`, in place.
    static func *= (_ lhs: inout Self, _ rhs: Scalar) {
        vDSP.multiply(rhs, lhs.components, result: &lhs.components)
    }

    /// Adds the scaled vector `lhs * rhs` to this vector, in place.
    mutating func addProduct(_ lhs: Scalar, _ rhs: Self) {
        vDSP.add(multiplication: (rhs.components, lhs), components, result: &components)
    }

    /// Returns the sum of the elementwise products of `lhs` and `rhs`.
    static func dotProduct(_ lhs: Self, _ rhs: Self) -> Scalar {
        vDSP.dot(lhs.components, rhs.components)
    }

    /// The Euclidean length: the square root of the sum of squared components.
    var magnitude: Scalar {
        sqrt(vDSP.sumOfSquares(components))
    }

    /// Moves this vector linearly toward `destination` by the fraction `progress`.
    mutating func lerp(towards destination: Self, by progress: Scalar) {
        vDSP.linearInterpolate(components, destination.components, using: progress, result: &components)
    }

    /// Spherically interpolates this vector toward `destination` by `progress`.
    ///
    /// The dot product is used directly as cos(omega), so both vectors are
    /// presumably unit-length. When the vectors are nearly parallel or
    /// antiparallel (|cos(omega)| >= 0.9995), sin(omega) is close to zero and
    /// the method falls back to `lerp(towards:by:)`.
    ///
    /// - Precondition: `components` and `destination.components` have the same count.
    mutating func slerp(towards destination: Self, by progress: Scalar) {
        // vDSP_vsmsma is a raw C routine with no bounds checks, so the element
        // count must come from the real buffers rather than `Self.dimension`.
        precondition(
            components.count == destination.components.count,
            "slerp requires vectors with the same number of components"
        )
        let cosOmega = Self.dotProduct(self, destination)
        guard abs(cosOmega) < 0.9995 else { return lerp(towards: destination, by: progress) }
        let omega = acos(cosOmega)
        let sinOmega = sin(omega)
        let progressOmega = progress * omega
        var startFactor = sin(omega - progressOmega) / sinOmega
        var endFactor = sin(progressOmega) / sinOmega
        let count = vDSP_Length(components.count)
        // components = components * startFactor + destination * endFactor
        vDSP_vsmsma(components, 1, &startFactor, destination.components, 1, &endFactor, &components, 1, count)
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import RealModule | |
/// The constraints a vector component type must satisfy: a floating-point
/// algebraic field with real-valued elementary functions (RealModule's
/// `AlgebraicField` and `RealFunctions`), constructible from float literals.
typealias VectorScalar = FloatingPoint & ExpressibleByFloatLiteral & Comparable & AlgebraicField & RealFunctions
/// Requirements for a fixed-dimension vector whose components are `Scalar`s.
///
/// Conformers provide the primitive arithmetic below; `normalize()` also has a
/// default implementation in an extension of this protocol.
protocol VectorProtocol {
    /// The type of each component.
    associatedtype Scalar: VectorScalar

    // MARK: Type-level requirements

    /// The number of components in a vector of this type.
    static var dimension: UInt { get }

    /// The zero vector; used as the starting value when summing vectors.
    static var zero: Self { get }

    /// Returns the dot (inner) product of `lhs` and `rhs`.
    static func dotProduct(_ lhs: Self, _ rhs: Self) -> Scalar

    // MARK: Operators

    /// Adds `rhs` to `lhs` in place.
    static func += (_ lhs: inout Self, _ rhs: Self)

    /// Scales `lhs` by the scalar `rhs` in place.
    static func *= (_ lhs: inout Self, _ rhs: Scalar)

    // MARK: Instance requirements

    /// This vector with every component negated.
    var negated: Self { get }

    /// The length (norm) of this vector.
    var magnitude: Scalar { get }

    /// Adds the scaled vector `lhs * rhs` to the receiver in place.
    mutating func addProduct(_ lhs: Scalar, _ rhs: Self)

    /// Rescales the receiver in place to unit magnitude.
    mutating func normalize()
}
extension VectorProtocol {
    /// Scales this vector in place so that its magnitude becomes 1.
    ///
    /// Normalizing a zero-length vector is a programmer error. In debug builds
    /// it trips an assertion with an explicit message. In release builds
    /// (where `assert` is compiled out), the division by zero typically leaves
    /// non-finite components.
    mutating func normalize() {
        let length = magnitude
        assert(!length.isZero, "Cannot normalize a zero-length vector")
        self *= 1 / length
        assert(isNormalized(), "normalize() produced a vector whose magnitude is not 1")
    }

    /// Returns a unit-magnitude copy of this vector, leaving `self` unchanged.
    func normalized() -> Self {
        var copy = self
        copy.normalize()
        return copy
    }

    /// Returns `true` when `magnitude` lies strictly within `tolerance` of 1.
    ///
    /// - Parameter tolerance: The allowed absolute deviation from 1 (default `1e-3`).
    func isNormalized(tolerance: Scalar = 1e-3) -> Bool {
        abs(magnitude - 1) < tolerance
    }

    /// Returns the elementwise sum of `vectors`; an empty array yields `zero`.
    static func sum(_ vectors: [Self]) -> Self {
        vectors.reduce(into: Self.zero) { sum, vector in
            sum += vector
        }
    }
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.