Skip to content

Instantly share code, notes, and snippets.

@kreeger
Created December 27, 2019 02:14
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kreeger/7423a72a6b22f98c1d22954c93f062d0 to your computer and use it in GitHub Desktop.
Save kreeger/7423a72a6b22f98c1d22954c93f062d0 to your computer and use it in GitHub Desktop.
A convenient way to mock out URL requests in Swift tests.
//
// HTTPRecording.swift
//
import Foundation
/// A canned HTTP response parsed from a raw text recording.
///
/// The expected layout mirrors an HTTP/1.x response message:
///
///     HTTP/1.1 200 OK
///     Content-Type: application/json
///
///     {"key": "value"}
///
/// The first line supplies the HTTP version and status code. The following `Name: value` lines
/// become headers, and everything after the first blank line is the body. LF, CRLF and bare-CR
/// line endings are all accepted.
struct HTTPRecording {
    /// Numeric status code taken from the second field of the status line.
    let statusCode: Int
    /// Version token taken from the first field of the status line, e.g. "HTTP/1.1".
    let httpVersion: String
    /// Header fields; a repeated header name keeps its last value.
    let headers: [String: String]
    /// Body bytes, or `nil` when the recording has no body section.
    let body: Data?

    /// Parses a recording.
    ///
    /// - Parameters:
    ///   - data: The raw recording bytes.
    ///   - encoding: Encoding used both to decode `data` and to re-encode the body.
    /// - Returns: `nil` if `data` cannot be decoded with `encoding`, or if the status line does not
    ///   contain a numeric status code in its second field.
    init?(data: Data, encoding: String.Encoding) {
        guard let content = String(data: data, encoding: encoding) else { return nil }

        // Normalize line endings before splitting. Splitting on `.newlines` treats "\r\n" as two
        // separators and yields a spurious empty line, which would end the header section
        // immediately for CRLF recordings.
        let normalized = content
            .replacingOccurrences(of: "\r\n", with: "\n")
            .replacingOccurrences(of: "\r", with: "\n")
        // Always non-empty: splitting a string yields at least one component.
        var lines = normalized.components(separatedBy: "\n")

        // Status line, e.g. "HTTP/1.1 200 OK". Empty fields are dropped so repeated spaces are tolerated.
        let statusFields = lines.removeFirst()
            .components(separatedBy: .whitespaces)
            .filter { !$0.isEmpty }
        guard statusFields.count >= 2, let statusCode = Int(statusFields[1]) else { return nil }
        self.httpVersion = statusFields[0]
        self.statusCode = statusCode

        // Header lines run until the first blank line. A line that is not "Name: value" also ends
        // the header section, but it is kept as the first body line rather than discarded.
        var headers = [String: String]()
        var bodyStart: Int? = nil
        for (index, rawLine) in lines.enumerated() {
            let line = rawLine.trimmingCharacters(in: .whitespaces)
            if line.isEmpty {
                bodyStart = index + 1
                break
            }
            guard let separator = line.range(of: ": ") else {
                bodyStart = index
                break
            }
            // Split only on the first ": " so values may themselves contain ": ".
            headers[String(line[..<separator.lowerBound])] = String(line[separator.upperBound...])
        }
        self.headers = headers

        // Body lines are kept verbatim (no trimming) so indentation and trailing spaces survive.
        if let start = bodyStart, start < lines.count {
            self.body = lines[start...].joined(separator: "\n").data(using: encoding)
        } else {
            self.body = nil
        }
    }

    /// Builds an `HTTPURLResponse` for `url` from the recorded status code, version and headers.
    ///
    /// - Parameter url: The URL the response should claim to answer.
    /// - Returns: The response, or `nil` if `HTTPURLResponse` rejects the values.
    func response(for url: URL) -> HTTPURLResponse? {
        return HTTPURLResponse(url: url, statusCode: statusCode, httpVersion: httpVersion, headerFields: headers)
    }
}
//
// MockURLProtocol.swift
//
import Foundation
/// A `URLProtocol` that serves canned results for registered URLs instead of touching the network.
///
/// Register results with `mock(result:url:)` and install the class via
/// `URLSessionConfiguration.protocolClasses`. A request is handled when its URL matches a
/// registration exactly, or when the URL with its query string removed does.
class MockURLProtocol: URLProtocol {
    /// What a mocked request should produce.
    enum Result {
        /// Report a redirect to `request`, passing `response` as the redirect response.
        case redirect(request: URLRequest, response: URLResponse)
        /// Deliver the recorded status, headers and body.
        case recording(HTTPRecording)
        /// Fail the load with the given error.
        case fail(Error)
        /// Hand the authentication challenge to the client.
        case authChallenge(URLAuthenticationChallenge)
    }

    /// Registered results, keyed by the URL passed to `mock(result:url:)`.
    private static var testURLs = [URL: Result]()
    /// Number of loads started, keyed by the matched registration URL (or the raw URL if none matched).
    private static var requestCounts = [URL: Int]()
    /// Serializes access to the dictionaries above. `canInit` and `startLoading` are invoked on the
    /// URL loading system's threads, while tests register and clear mocks from their own thread.
    private static let lock = NSLock()

    // MARK: - URLProtocol overrides

    /// Claims a request only when a mock is registered for its URL (exactly or without the query).
    override class func canInit(with request: URLRequest) -> Bool {
        guard let url = request.url else { return false }
        return findResponseMock(url: url) != nil
    }

    /// Returns `request` unchanged; mock lookup needs no canonicalization.
    override class func canonicalRequest(for request: URLRequest) -> URLRequest {
        return request
    }

    /// Counts the request, then replays its registered result to the client synchronously.
    override func startLoading() {
        guard let url = request.url else {
            client?.urlProtocolDidFinishLoading(self)
            return
        }
        let mockType = type(of: self)
        mockType.incrementRequestCount(url: url)

        switch mockType.findResponseMock(url: url) {
        case .redirect(let redirectRequest, let redirectResponse)?:
            client?.urlProtocol(self, wasRedirectedTo: redirectRequest, redirectResponse: redirectResponse)
            client?.urlProtocolDidFinishLoading(self)
        case .recording(let recording)?:
            // URLProtocolClient expects the response before any body data.
            if let response = recording.response(for: url) {
                client?.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed)
            }
            if let data = recording.body {
                client?.urlProtocol(self, didLoad: data)
            }
            client?.urlProtocolDidFinishLoading(self)
        case .fail(let error)?:
            // Failure is terminal; also calling urlProtocolDidFinishLoading would report completion twice.
            client?.urlProtocol(self, didFailWithError: error)
        case .authChallenge(let challenge)?:
            client?.urlProtocol(self, didReceive: challenge)
            client?.urlProtocolDidFinishLoading(self)
        case nil:
            // canInit only claims URLs with a registered mock, so this presumably means the mock
            // was removed after the request was claimed.
            assertionFailure("Unexpected network call to url \(url)")
            client?.urlProtocolDidFinishLoading(self)
        }
    }

    /// Nothing to cancel: every result is delivered synchronously inside `startLoading`.
    override func stopLoading() { }

    // MARK: - Public class functions

    /// Registers `result` for `url` and resets that URL's request count to zero.
    ///
    /// Traps if `url` is not a valid URL string; that is a mistake in the test, not a runtime condition.
    class func mock(result: MockURLProtocol.Result, url urlString: String) {
        guard let url = URL(string: urlString) else {
            preconditionFailure("Invalid mock URL string: \(urlString)")
        }
        lock.lock()
        defer { lock.unlock() }
        testURLs[url] = result
        requestCounts[url] = 0
    }

    /// Number of loads started for `url` since it was registered, or 0 if none were recorded.
    class func requestCount(url urlString: String) -> Int {
        guard let url = URL(string: urlString) else { return 0 }
        lock.lock()
        defer { lock.unlock() }
        return requestCounts[url] ?? 0
    }

    /// Removes every registered mock and request count.
    class func clear() {
        lock.lock()
        defer { lock.unlock() }
        testURLs.removeAll()
        requestCounts.removeAll()
    }

    // MARK: - Private functions

    /// Returns `url` with its query string removed, or `nil` if it cannot be decomposed and rebuilt.
    private class func urlWithoutQuery(_ url: URL) -> URL? {
        guard var components = URLComponents(url: url, resolvingAgainstBaseURL: false) else { return nil }
        components.queryItems = nil
        return components.url
    }

    /// Increments the count for the registration matching `url`.
    ///
    /// Prefers an exact match, then a match on the URL without its query. Otherwise the count is
    /// recorded under `url` itself.
    private class func incrementRequestCount(url: URL) {
        lock.lock()
        defer { lock.unlock() }
        let key: URL
        if requestCounts[url] != nil {
            key = url
        } else if let stripped = urlWithoutQuery(url), requestCounts[stripped] != nil {
            key = stripped
        } else {
            key = url
        }
        requestCounts[key, default: 0] += 1
    }

    /// Looks up the mock for `url`: an exact match first, then the URL without its query string.
    private class func findResponseMock(url: URL) -> Result? {
        lock.lock()
        defer { lock.unlock() }
        if let exact = testURLs[url] {
            return exact
        }
        guard let stripped = urlWithoutQuery(url) else { return nil }
        return testURLs[stripped]
    }
}
//
// NetworkTestable.swift
//
import Foundation
/// Errors reported by the `NetworkTestable` helpers.
enum NetworkTestableError: Error, CustomStringConvertible, LocalizedError {
    /// An invalid URL was supplied.
    case invalidURL
    /// No bundle with the requested identifier could be found.
    case unknownBundle
    /// The requested resource file was not found in the bundle.
    case missingFile
    /// The file's contents could not be parsed as an `HTTPRecording`.
    case invalidRecording

    /// Short human-readable explanation of the error.
    var description: String {
        switch self {
        case .invalidURL: return "invalid URL"
        case .unknownBundle: return "unknown bundle identifier"
        case .missingFile: return "missing file"
        case .invalidRecording: return "invalid recording"
        }
    }

    /// Exposes `description` through `localizedDescription`, so failure messages read
    /// "missing file" rather than the generic "The operation couldn't be completed" text.
    var errorDescription: String? {
        return description
    }
}
/// Mix-in for test cases that stub network traffic through `MockURLProtocol`.
protocol NetworkTestable: AnyObject { }

extension NetworkTestable {
    /// Loads a raw HTTP recording from a bundle and registers it as the mock result for `url`.
    ///
    /// - Parameters:
    ///   - name: Resource file name, e.g. `"user.http"`. Only the text after the last `.` is treated
    ///     as the extension; a name without a `.` is looked up with no extension.
    ///   - url: URL string whose requests should receive the recording.
    ///   - bundle: Identifier of the bundle containing the resource.
    ///   - encoding: Text encoding of the recording file. Defaults to UTF-8.
    /// - Throws: `NetworkTestableError.invalidURL` if `url` cannot be parsed, `.unknownBundle`,
    ///   `.missingFile`, `.invalidRecording`, or any error raised while reading the file.
    func loadRecording(named name: String, for url: String, bundle: String,
                       encoding: String.Encoding = .utf8) throws {
        // Validate up front so a bad URL string throws here instead of trapping when the mock is registered.
        guard URL(string: url) != nil else { throw NetworkTestableError.invalidURL }

        // Split on the last "." only. Without a dot, look the resource up with no extension rather
        // than passing an empty resource name (which can match an unrelated file).
        let filename: String
        let fileExt: String?
        if let dot = name.lastIndex(of: ".") {
            filename = String(name[..<dot])
            fileExt = String(name[name.index(after: dot)...])
        } else {
            filename = name
            fileExt = nil
        }

        guard let resourceBundle = Bundle(identifier: bundle) else { throw NetworkTestableError.unknownBundle }
        guard let fileURL = resourceBundle.url(forResource: filename, withExtension: fileExt) else {
            throw NetworkTestableError.missingFile
        }
        let data = try Data(contentsOf: fileURL)
        guard let recording = HTTPRecording(data: data, encoding: encoding) else {
            throw NetworkTestableError.invalidRecording
        }
        mockResult(.recording(recording), url: url)
    }

    /// Registers `result` as the mocked outcome for requests to `url`.
    func mockResult(_ result: MockURLProtocol.Result, url: String) {
        MockURLProtocol.mock(result: result, url: url)
    }

    /// Removes every registered mock (and its request count).
    func removeAllURLMocks() {
        MockURLProtocol.clear()
    }

    /// Builds an ephemeral session whose only protocol class is `MockURLProtocol`.
    ///
    /// - Returns: The session, and the serial dispatch queue backing its delegate operation queue
    ///   (the queue on which its callbacks run).
    func vendURLSession() -> (session: URLSession, queue: DispatchQueue) {
        let configuration = URLSessionConfiguration.ephemeral
        configuration.protocolClasses = [MockURLProtocol.self]
        let queue = DispatchQueue(label: "URLSession.Mock.DispatchQueue")
        let opQueue = OperationQueue()
        opQueue.name = "URLSession.Mock.OperationQueue"
        opQueue.underlyingQueue = queue
        return (URLSession(configuration: configuration, delegate: nil, delegateQueue: opQueue), queue)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment