Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@dabrahams
Last active December 3, 2020 19:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dabrahams/9e1401d79fb0302d10d8abaa5f11c1cc to your computer and use it in GitHub Desktop.
Save dabrahams/9e1401d79fb0302d10d8abaa5f11c1cc to your computer and use it in GitHub Desktop.
Thing for Xihui
#if os(Linux)
import Glibc
#else
import Darwin
#endif
extension Collection {
  /// Returns the position of the first element for which `predicate` returns `true`, or
  /// `endIndex` if there is no such element.
  ///
  /// - Precondition: `self` is already partitioned by `predicate`: every element for which
  ///   `predicate` is `false` precedes every element for which it is `true`, as if
  ///   `self.partition(by: predicate)` had already been called.
  /// - Complexity: O(log n) calls to `predicate`, where n is the length of the collection; index
  ///   movement is O(n) unless `Self` is a `RandomAccessCollection`.
  func partitionPoint(
    where predicate: (Element) throws -> Bool
  ) rethrows -> Index {
    // `low` is the first position not yet known to fail the predicate; `remaining` is how many
    // positions at and after `low` are still undecided.
    var low = startIndex
    var remaining = distance(from: startIndex, to: endIndex)
    while remaining > 0 {
      let step = remaining / 2
      let probe = index(low, offsetBy: step)
      guard try !predicate(self[probe]) else {
        // The answer is at or before `probe`: discard everything after it.
        remaining = step
        continue
      }
      // `probe` and everything before it fail the predicate: move past them.
      low = index(after: probe)
      remaining -= step + 1
    }
    return low
  }
}
/// The abstract shape of a function mapping `Double` into `Double`, independent of any particular
/// domain or range.
public struct Shape {
  /// The function describing the shape; only its values on `0...1` are used.
  private let curve: (Double) -> Double
  /// `curve(0)` and `curve(1)`, sampled once at construction.
  private let curveSampledAt: (Double, Double)

  /// Creates an instance representing the shape of the values of `curve` over the domain `0...1`.
  ///
  /// - Precondition: `curve(0) != curve(1)`.
  public init(_ curve: @escaping (Double) -> Double) {
    let atZero = curve(0)
    let atOne = curve(1)
    precondition(
      atZero != atOne,
      "Curve must have distinct values at 0 and 1.")
    self.curve = curve
    self.curveSampledAt = (atZero, atOne)
  }

  /// Returns a function `f` with `f(domain.lowerBound) == startResult` and
  /// `f(domain.upperBound) == endResult`, whose intermediate values follow the shape of `self`.
  ///
  /// `domain` is mapped linearly onto `0...1`, the stored curve is sampled there, and the sample is
  /// rescaled linearly so that `curve(0)` becomes `startResult` and `curve(1)` becomes `endResult`.
  /// If `domain` has zero width, `f` returns `startResult` at `domain.lowerBound` and `endResult`
  /// everywhere else.
  public func projected(
    intoDomain domain: ClosedRange<Double>, startResult: Double, endResult: Double
  ) -> (Double) -> Double {
    let lower = domain.lowerBound
    let upper = domain.upperBound
    guard lower != upper else {
      // A zero-width domain yields a step, which is useful for modeling discontinuities.
      return { x in x == lower ? startResult : endResult }
    }
    let width = upper - lower
    let (base, top) = curveSampledAt
    let scale = (endResult - startResult) / (top - base)
    let sample = curve
    return { x in
      let unit = (x - lower) / width
      return startResult + (sample(unit) - base) * scale
    }
  }
}
/// A mapping from step number to learning rate, expressed as a collection of learning rates, and as
/// a callable “function instance”.
public struct LearningRateSchedule {
  /// A fragment of the schedule, when paired with known start step.
  ///
  /// A segment covers the steps from the previous segment's `endStep` up to, but not including, its
  /// own `endStep`; `rateAtStep` maps each of those absolute step numbers to a rate.
  fileprivate typealias Segment = (endStep: Int, rateAtStep: (Int) -> Double)

  /// The segments of `self`, in strictly increasing order of `endStep`.
  ///
  /// Always contains at least one segment with an endStep of 0. That initial segment covers no
  /// steps; it gives every appended segment a predecessor whose `endStep` is its first step.
  private var segments: [Segment]

  /// The rate at which the next appended segment starts.
  ///
  /// This is the `endRate` of the most recently appended segment, or the `startRate` passed to
  /// `init` if no segment has been appended yet.
  private var nextSegmentStartRate: Double

  /// Creates a schedule that begins at `startRate`.
  public init(startRate: Double) {
    segments = [(endStep: 0, rateAtStep: { _ in startRate })]
    nextSegmentStartRate = startRate
  }

  /// Returns the learning rate at step `n`.
  ///
  /// - Precondition: `n >= 0 && n < count`
  /// - Complexity: O(log s), where s is the number of segments.
  public func callAsFunction(_ n: Int) -> Double {
    precondition(n >= 0)
    precondition(n < count)
    // Step `n` belongs to the first segment whose (exclusive) `endStep` exceeds `n`. This is the
    // same segment `index(after:)` assigns to step `n`, so `self(n)` agrees with iteration.
    let p = segments.partitionPoint { $0.endStep > n }
    return segments[p].rateAtStep(n)
  }

  /// Appends an `n`-step segment with shape `s` and end rate `r`.
  ///
  /// The start rate of the appended segment is the end rate of the last segment appended, or if
  /// `self.isEmpty`, the `startRate` with which `self` was initialized. The segment's first step
  /// has the start rate and its last step has rate `r`.
  ///
  /// A discontinuous jump from one step to another can be achieved with a segment of one step; the
  /// shape is irrelevant in that case.
  ///
  /// - Precondition: n > 0
  public mutating func appendSegment(stepCount n: Int, shape s: Shape, endRate r: Double) {
    precondition(n > 0)
    let firstStep = count
    let newEnd = firstStep + n
    // Project onto the segment's own steps: `firstStep` maps to the start rate and `newEnd - 1`
    // maps to `r`. A one-step segment has a zero-width domain.
    let curve = s.projected(
      intoDomain: Double(firstStep)...Double(newEnd - 1),
      startResult: nextSegmentStartRate, endResult: r)
    segments.append((endStep: newEnd, rateAtStep: { curve(Double($0)) }))
    // Record `r` itself. Evaluating the new curve at `newEnd` would extrapolate past `r`.
    nextSegmentStartRate = r
  }
}
extension LearningRateSchedule: BidirectionalCollection {
  /// A position in the schedule: an absolute step paired with the segment that generates it.
  public struct Index {
    /// The absolute step number of the element at this position.
    fileprivate let step: Int
    /// The offset in `segments` of the segment that generates the element at this position.
    fileprivate let segment: [Segment].Index
  }

  /// The number of elements in `self`: the `endStep` of the last segment.
  public var count: Int { segments.last!.endStep }

  /// The position of step 0, generated by segment 1 (the first segment after the initial
  /// `endStep == 0` segment).
  public var startIndex: Index { Index(step: 0, segment: 1) }

  /// The position one past the last element.
  public var endIndex: Index { Index(step: count, segment: segments.count) }

  /// Returns the learning rate at `i`.
  public subscript(i: Index) -> Double {
    let generator = segments[i.segment]
    return generator.rateAtStep(i.step)
  }

  /// Returns the position following `i`, moving to the next segment when the new step reaches the
  /// current segment's `endStep`.
  public func index(after i: Index) -> Index {
    let nextStep = i.step + 1
    let crossesBoundary = nextStep == segments[i.segment].endStep
    return Index(step: nextStep, segment: crossesBoundary ? i.segment + 1 : i.segment)
  }

  /// Returns the position preceding `i`, moving to the previous segment when `i` is the first step
  /// of its segment.
  public func index(before i: Index) -> Index {
    let startsSegment = i.step == segments[i.segment - 1].endStep
    return Index(step: i.step - 1, segment: startsSegment ? i.segment - 1 : i.segment)
  }
}
extension LearningRateSchedule.Index: Comparable {
  /// Positions are equal when they refer to the same step; the `segment` field is not consulted.
  public static func == (lhs: Self, rhs: Self) -> Bool { lhs.step == rhs.step }
  /// Positions are ordered by step number alone.
  public static func < (lhs: Self, rhs: Self) -> Bool { lhs.step < rhs.step }
}
// TEST
/// An exponential curve.
let exponential = Shape(exp)
/// A linear curve.
let linear = Shape({ $0 })
// Test that curves project: the projection must hit the requested results at both ends of the
// domain. The upper end involves a division and a multiplication, so compare it with a tolerance.
let e = exponential.projected(intoDomain: 10.0...20.0, startResult: 2.0, endResult: 1.0)
assert(e(10) == 2.0)
assert(abs(e(20) - 1.0) < 1e-12)
// Create a learning rate schedule with two segments
var s = LearningRateSchedule(startRate: 1)
s.appendSegment(stepCount: 6, shape: exponential, endRate: 100)
s.appendSegment(stepCount: 10, shape: linear, endRate: 0)
// Print the whole schedule
for (i, r) in s.enumerated() {
  print("\(i)\t| \(r)")
  assert(s(i) == r) // Demonstrate that callAsFunction does the right thing.
}
// Local Variables:
// fill-column: 100
// End:
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment