Last active
December 3, 2020 19:48
-
-
Save dabrahams/9e1401d79fb0302d10d8abaa5f11c1cc to your computer and use it in GitHub Desktop.
Thing for Xihui
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#if os(Linux) | |
import Glibc | |
#else | |
import Darwin | |
#endif | |
extension Collection {
  /// Returns the position of the first element satisfying `predicate`, or `endIndex` if no
  /// element does.
  ///
  /// - Precondition: `self` is partitioned by `predicate`: every element that fails it precedes
  ///   every element that satisfies it, as though `self.partition(by: predicate)` had run.
  /// - Note: The search range is halved on every iteration, so `predicate` is evaluated
  ///   O(log(`count`)) times.
  func partitionPoint(
    where predicate: (Element) throws -> Bool
  ) rethrows -> Index {
    // Invariant: the answer lies within the `remaining` positions starting at `low`.
    var remaining = count
    var low = startIndex
    while remaining > 0 {
      let step = remaining / 2
      let probe = index(low, offsetBy: step)
      guard try predicate(self[probe]) else {
        // `probe` fails the predicate; the answer is strictly to its right.
        low = index(after: probe)
        remaining -= step + 1
        continue
      }
      // `probe` satisfies the predicate; the answer is `probe` or something to its left.
      remaining = step
    }
    return low
  }
}
/// The abstract shape of a function from `Double` to `Double`, independent of any particular
/// domain or range.
public struct Shape {
  /// Creates an instance describing the shape traced by `curve` over the domain 0...1.
  ///
  /// - Precondition: `curve(0) != curve(1)`.
  public init(_ curve: @escaping (Double)->Double) {
    let atZero = curve(0)
    let atOne = curve(1)
    precondition(
      atZero != atOne,
      "Curve must have distinct values at 0 and 1.")
    self.curve = curve
    curveSampledAt = (atZero, atOne)
  }

  /// Returns a function `f` for which `f(domain.lowerBound) == startResult` and
  /// `f(domain.upperBound) == endResult`, whose intermediate values follow the shape of `self`.
  ///
  /// Positions in `domain` are mapped linearly onto 0...1, the underlying curve is sampled
  /// there, and the sample is scaled linearly so the endpoints meet the requested results.
  ///
  /// When `domain` is a single point, `f` yields `startResult` at that point and `endResult`
  /// everywhere else.
  public func projected(intoDomain domain: ClosedRange<Double>, startResult: Double, endResult: Double)
    -> (Double)->Double
  {
    let lower = domain.lowerBound
    let upper = domain.upperBound

    // A zero-width domain cannot be scaled; this form is useful for modeling discontinuities.
    guard lower != upper else {
      return { $0 == lower ? startResult : endResult }
    }

    let width = upper - lower
    let (base, top) = curveSampledAt
    let rangeScale = (endResult - startResult) / (top - base)
    let shape = curve
    return { x in
      let fraction = (x - lower) / width
      return (shape(fraction) - base) * rangeScale + startResult
    }
  }

  /// The function whose values over 0.0...1.0 define the shape.
  private let curve: (Double)->Double

  /// The values of `curve` at 0 and at 1, in that order.
  private let curveSampledAt: (Double, Double)
}
/// A mapping from step number to learning rate, expressed as a collection of learning rates, and as
/// a callable “function instance”.
public struct LearningRateSchedule {
  /// A fragment of the schedule: the step one past its last step, paired with a function giving
  /// the rate at an absolute step number.
  fileprivate typealias Segment = (endStep: Int, rateAtStep: (Int)->Double)

  /// The entire step-to-rate mapping of `self`.
  ///
  /// Always begins with a zero-length segment (`endStep == 0`) that yields the start rate;
  /// `endStep` strictly increases from each segment to the next.
  private var segments: [Segment]

  /// The rate at which the next appended segment begins: the `endRate` of the most recently
  /// appended segment, or the initial start rate if nothing has been appended.
  private var lastEndRate: Double

  /// Creates a schedule that begins at `startRate`.
  public init(startRate: Double) {
    segments = [(endStep: 0, rateAtStep: { _ in startRate })]
    lastEndRate = startRate
  }

  /// Returns the learning rate at step `n`.
  ///
  /// Steps `0..<count` are the elements of `self`. For backward compatibility `n == count` is
  /// also accepted; it evaluates the last segment's rate function at `count`.
  ///
  /// - Precondition: `n >= 0 && n <= count`
  /// - Complexity: O(log(k)), where k is the number of segments.
  public func callAsFunction(_ n: Int) -> Double {
    precondition(n >= 0, "Step number must be non-negative.")
    precondition(n <= count, "Step number must not exceed count.")
    // Each segment owns the half-open step range ending at its `endStep`, so the owner of `n` is
    // the first segment whose `endStep` is strictly greater than `n`.
    let owner = segments.partitionPoint { $0.endStep > n }
    // `n == count` lies past every segment; fall back to the last one.
    let p = owner == segments.count ? owner - 1 : owner
    return segments[p].rateAtStep(n)
  }

  /// Appends an `n`-step segment with shape `s` and end rate `r`.
  ///
  /// The start rate of the appended segment is the end rate of the last segment appended, or if
  /// `self.isEmpty`, the `startRate` with which `self` was initialized. The segment reaches `r`
  /// on its final step.
  ///
  /// A discontinuous jump from one step to another can be achieved with a segment of one step:
  /// that step keeps the previous end rate and the next segment starts at `r`. The shape is
  /// irrelevant in that case.
  ///
  /// - Precondition: n > 0
  public mutating func appendSegment(stepCount n: Int, shape s: Shape, endRate r: Double) {
    precondition(n > 0, "A segment must span at least one step.")
    let firstStep = count
    let newEnd = firstStep + n
    // Start from the recorded end rate rather than sampling the previous segment's curve at its
    // `endStep`: that step lies outside the domain the curve was projected onto, so sampling
    // there would extrapolate past the previous end rate.
    let curve = s.projected(
      intoDomain: Double(firstStep)...Double(newEnd - 1),
      startResult: lastEndRate, endResult: r)
    segments.append((endStep: newEnd, rateAtStep: { curve(Double($0)) }))
    lastEndRate = r
  }
}
extension LearningRateSchedule: BidirectionalCollection {
  /// A position in a `LearningRateSchedule`.
  public struct Index {
    /// The absolute step number of the element at this position.
    fileprivate let step: Int

    /// The offset, within the schedule's segment storage, of the segment that produces the
    /// element at this position.
    fileprivate let segment: Array<Segment>.Index
  }

  /// The position of the first element.
  ///
  /// Segment 0 is the zero-length segment holding the start rate, so step 0 belongs to segment 1.
  public var startIndex: Index {
    return Index(step: 0, segment: 1)
  }

  /// The position one past the last element.
  public var endIndex: Index {
    return Index(step: count, segment: segments.count)
  }

  /// The number of elements in `self`, i.e. the end step of the final segment.
  public var count: Int {
    return segments.last!.endStep
  }

  /// Returns the learning rate at position `i`.
  public subscript(i: Index) -> Double {
    let owner = segments[i.segment]
    return owner.rateAtStep(i.step)
  }

  /// Returns the position in `self` following `i`.
  public func index(after i: Index) -> Index {
    let next = i.step + 1
    // Stay in the current segment unless the next step is that segment's end.
    guard next == segments[i.segment].endStep else {
      return Index(step: next, segment: i.segment)
    }
    return Index(step: next, segment: i.segment + 1)
  }

  /// Returns the position in `self` preceding `i`.
  public func index(before i: Index) -> Index {
    // When `i` is the first step of its segment, the preceding step belongs to the segment before.
    let isFirstStepOfSegment = i.step == segments[i.segment - 1].endStep
    let previousSegment = isFirstStepOfSegment ? i.segment - 1 : i.segment
    return Index(step: i.step - 1, segment: previousSegment)
  }
}
extension LearningRateSchedule.Index: Comparable {
  /// Returns `true` iff `lhs` and `rhs` denote the same step; the segment offset is not compared.
  public static func == (lhs: Self, rhs: Self) -> Bool {
    lhs.step == rhs.step
  }

  /// Returns `true` iff `lhs` denotes an earlier step than `rhs`.
  public static func < (lhs: Self, rhs: Self) -> Bool {
    lhs.step < rhs.step
  }
}
// TEST

/// An exponential curve.
let exponential = Shape(exp)
/// A linear curve.
let linear = Shape({ $0 })

// Test that curves project: the domain's endpoints must land on the requested results, and an
// interior point must fall between them. A tolerance is used because the upper endpoint is
// computed through a scale-and-offset that may round.
let e = exponential.projected(intoDomain: 10.0...20.0, startResult: 2.0, endResult: 1.0)
assert(abs(e(10.0) - 2.0) < 1e-12)
assert(abs(e(20.0) - 1.0) < 1e-12)
let eMid = e(15.0)
assert(eMid < 2.0 && eMid > 1.0)

// Create a learning rate schedule with two segments
var s = LearningRateSchedule(startRate: 1)
s.appendSegment(stepCount: 6, shape: exponential, endRate: 100)
s.appendSegment(stepCount: 10, shape: linear, endRate: 0)
assert(s.count == 16)       // 6 + 10 steps.
assert(s(0) == 1)           // The schedule begins at the start rate.
assert(abs(s(15)) < 1e-9)   // The final step reaches the last end rate.

// Print the whole schedule
for (i, r) in s.enumerated() {
  print("\(i)\t| \(r)")
  assert(s(i) == r) // Demonstrate that callAsFunction does the right thing.
}
// Local Variables: | |
// fill-column: 100 | |
// End: |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment