Skip to content

Instantly share code, notes, and snippets.

@miketsprague
Created June 25, 2019 13:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save miketsprague/3ea3fd175be75c7c0dd302da92298b27 to your computer and use it in GitHub Desktop.
Save miketsprague/3ea3fd175be75c7c0dd302da92298b27 to your computer and use it in GitHub Desktop.
NTPConnection edited to make a race condition crash likely
//
// NTPConnection.swift
// TrueTime
//
// Created by Michael Sanders on 8/10/16.
// Copyright © 2016 Instacart. All rights reserved.
//
import CTrueTime
import Foundation
typealias NTPConnectionCallback = (NTPConnection, FrozenNetworkTimeResult) -> Void
final class NTPConnection {
let address: SocketAddress
let timeout: TimeInterval
let maxRetries: Int
var logger: LogCallback?
static func query(addresses: [SocketAddress],
config: NTPConfig,
logger: LogCallback?,
callbackQueue: DispatchQueue,
progress: @escaping NTPConnectionCallback) -> [NTPConnection] {
let connections = addresses.flatMap { address in
(0..<config.numberOfSamples).map { _ in
NTPConnection(address: address,
timeout: 0.0001, // force the connection timer to time out immediately
maxRetries: config.maxRetries,
logger: logger)
}
}
var throttleConnections: (() -> Void)?
let onComplete: NTPConnectionCallback = { connection, result in
progress(connection, result)
throttleConnections?()
}
throttleConnections = {
let remainingConnections = connections.filter { $0.canRetry }
let activeConnections = Array(remainingConnections[0..<min(config.maxConnections,
remainingConnections.count)])
activeConnections.forEach { $0.start(callbackQueue, onComplete: onComplete) }
}
throttleConnections?()
return connections
}
required init(address: SocketAddress,
timeout: TimeInterval,
maxRetries: Int,
logger: LogCallback?) {
self.address = address
self.timeout = timeout
self.maxRetries = maxRetries
self.logger = logger
}
deinit {
assert(!self.started, "Unclosed connection")
}
var canRetry: Bool {
var canRetry: Bool = false
lockQueue.sync {
canRetry = self.attempts < self.maxRetries && !self.didTimeout && !self.finished
}
return canRetry
}
func start(_ callbackQueue: DispatchQueue, onComplete: @escaping NTPConnectionCallback) {
lockQueue.async {
guard !self.started else { return }
self.callbackPending = true
var ctx = CFSocketContext(
version: 0,
info: UnsafeMutableRawPointer(Unmanaged.passRetained(self).toOpaque()),
retain: nil,
release: nil,
copyDescription: nil
)
self.attempts += 1
self.callbackQueue = callbackQueue
self.onComplete = onComplete
self.socket = CFSocketCreate(nil,
self.address.family,
SOCK_DGRAM,
IPPROTO_UDP,
NTPConnection.callbackFlags,
self.dataCallback,
&ctx)
if let socket = self.socket {
CFSocketSetSocketFlags(socket, kCFSocketCloseOnInvalidate)
self.source = CFSocketCreateRunLoopSource(nil, socket, 0)
}
if let source = self.source {
CFRunLoopAddSource(CFRunLoopGetMain(), source, CFRunLoopMode.commonModes)
self.startTimer()
}
}
}
func close(waitUntilFinished wait: Bool = false) {
let work = {
self.cancelTimer()
guard let socket = self.socket, let source = self.source else { return }
let disabledFlags = NTPConnection.callbackFlags |
kCFSocketAutomaticallyReenableDataCallBack |
kCFSocketAutomaticallyReenableReadCallBack |
kCFSocketAutomaticallyReenableWriteCallBack |
kCFSocketAutomaticallyReenableAcceptCallBack
CFSocketDisableCallBacks(socket, disabledFlags)
CFSocketInvalidate(socket)
CFRunLoopRemoveSource(CFRunLoopGetMain(), source, CFRunLoopMode.commonModes)
self.socket = nil
self.source = nil
self.debugLog("Connection closed \(self.address)")
}
if wait {
lockQueue.sync(execute: work)
} else {
lockQueue.async(execute: work)
}
}
func debugLog(_ message: @autoclosure () -> String) {
#if DEBUG_LOGGING
logger?(message())
#endif
}
private let dataCallback: CFSocketCallBack = { socket, type, address, data, info in
guard let info = info else { return }
let retainedClient = Unmanaged<NTPConnection>.fromOpaque(info)
let client = retainedClient.takeUnretainedValue()
guard let socket = socket, CFSocketIsValid(socket) else { return }
// Can't use switch here as these aren't defined as an enum.
if type == .dataCallBack {
let data = unsafeBitCast(data, to: CFData.self) as Data
client.callbackPending = false
client.handleResponse(data)
retainedClient.release()
} else if type == .writeCallBack {
client.debugLog("Buffer \(client.address) writable - requesting time")
client.requestTime()
} else {
assertionFailure("Unexpected socket callback")
}
}
var timer: DispatchSourceTimer?
private static let callbackTypes: [CFSocketCallBackType] = [.dataCallBack, .writeCallBack]
private static let callbackFlags: CFOptionFlags = callbackTypes.map {
$0.rawValue
}.reduce(0, |)
private let lockQueue = DispatchQueue(label: "com.instacart.ntp.connection")
private var attempts: Int = 0
private var callbackQueue: DispatchQueue?
private var didTimeout: Bool = false
private var onComplete: NTPConnectionCallback?
private var requestTicks: timeval?
private var socket: CFSocket?
private var source: CFRunLoopSource?
private var startTime: ntp_time_t?
private var finished: Bool = false
private var callbackPending: Bool = false
}
extension NTPConnection: TimedOperation {
var timerQueue: DispatchQueue { return lockQueue }
var started: Bool { return self.socket != nil }
func timeoutError(_ error: NSError) {
self.didTimeout = true
complete(.failure(error))
}
}
private extension NTPConnection {
func complete(_ result: FrozenNetworkTimeResult) {
guard let callbackQueue = callbackQueue, let onComplete = onComplete else {
assertionFailure("Completion callback not initialized")
return
}
close()
switch result {
case let .failure(error) where attempts < maxRetries && !didTimeout:
debugLog("Got error from \(address) (attempt \(attempts)), " +
"trying again. \(error)")
start(callbackQueue, onComplete: onComplete)
case .failure, .success:
finished = true
callbackQueue.async {
onComplete(self, result)
}
}
if callbackPending {
callbackPending = false
Unmanaged.passUnretained(self).release()
}
// the `close()` above actually runs on a separate thread. if it finishes after the socket callback completes, we're in trouble.
// Force that by setting the timeinterval here to long enough for the response to come back:
Thread.sleep(forTimeInterval: 0.5)
}
func requestTime() {
lockQueue.async {
guard let socket = self.socket else {
self.debugLog("Socket closed")
return
}
self.startTime = ntp_time_t(timeSince1970: .now())
self.requestTicks = .uptime()
if let startTime = self.startTime {
let packet = self.requestPacket(startTime).bigEndian
let interval = TimeInterval(milliseconds: startTime.milliseconds)
self.debugLog("Sending time: \(Date(timeIntervalSince1970: interval))")
let err = CFSocketSendData(socket,
self.address.networkData as CFData,
packet.data as CFData,
self.timeout)
if err != .success {
self.complete(.failure(NSError(errno: errno)))
} else {
self.startTimer()
}
}
}
}
func handleResponse(_ data: Data) {
let responseTicks = timeval.uptime()
lockQueue.async {
guard self.started else { return } // Socket closed.
guard data.count == MemoryLayout<ntp_packet_t>.size else { return } // Invalid packet length.
guard let startTime = self.startTime, let requestTicks = self.requestTicks else {
assertionFailure("Uninitialized timestamps")
return
}
let packet = data.withUnsafeBytes { $0.load(as: ntp_packet_t.self) }.nativeEndian
let responseTime = startTime.milliseconds + (responseTicks.milliseconds -
requestTicks.milliseconds)
guard let response = NTPResponse(packet: packet, responseTime: responseTime) else {
self.complete(.failure(NSError(trueTimeError: .badServerResponse)))
return
}
self.debugLog("Buffer \(self.address) has read data!")
self.debugLog("Start time: \(startTime.milliseconds) ms, " +
"response: \(packet.timeDescription)")
self.debugLog("Clock offset: \(response.offset) milliseconds")
self.debugLog("Round-trip delay: \(response.delay) milliseconds")
self.complete(.success(FrozenNetworkTime(time: response.networkDate,
uptime: responseTicks,
serverResponse: response,
startTime: startTime)))
}
}
func requestPacket(_ time: ntp_time_t) -> ntp_packet_t {
var packet = ntp_packet_t()
packet.client_mode = 3
packet.version_number = 3
packet.transmit_time = time
return packet
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment