Skip to content

Instantly share code, notes, and snippets.

@matsuda
Created September 15, 2020 09:50
Show Gist options
  • Save matsuda/f6196460a84e2daff551bc3c8db0d796 to your computer and use it in GitHub Desktop.
Save matsuda/f6196460a84e2daff551bc3c8db0d796 to your computer and use it in GitHub Desktop.
Custom URLProtocol in Swift
import Foundation
/// Namespace for helpers that load bundled mock response files.
open class MockResponse {
    /// Domain of every `NSError` thrown by the loaders below.
    static let errorDomain: String = "API Mock Error"

    /// Failure reasons reported by the loaders, with their user-facing messages.
    enum ErrorCode: Int {
        case fileNotFound = -10001
        case disableReadFile = -10002

        /// Message stored under `NSLocalizedDescriptionKey`.
        var description: String {
            let message: String
            switch self {
            case .fileNotFound: message = "Not found file"
            case .disableReadFile: message = "Can't read file"
            }
            return message
        }
    }
}
private extension MockResponse {
    // swiftlint:disable type_name
    /// Marker class used only to locate the bundle that contains the mock resources.
    private class __Mock__ {}
    // swiftlint:enable type_name

    /// Resolves the path of a mock resource inside the bundle that contains this file's code.
    /// - Parameter name: Resource file name. If it has no extension, `json` is assumed.
    /// - Returns: The resource path, or `nil` if the bundle has no such resource.
    static func bundlePath(forResource name: String) -> String? {
        // NSString path parsing treats `name` as a plain file name. URL(string:) returns nil
        // for names containing spaces or other non-URL characters on older OS versions,
        // which previously made the extension look absent and appended ".json" wrongly.
        let hasExtension = !(name as NSString).pathExtension.isEmpty
        return Bundle(for: __Mock__.self).path(forResource: name,
                                              ofType: hasExtension ? nil : "json")
    }
}
public extension MockResponse {
    /// Reads the raw bytes of a bundled mock resource.
    /// - Parameter name: Resource file name; `json` is assumed when it has no extension.
    /// - Returns: The file contents.
    /// - Throws: `NSError` in `errorDomain` with `ErrorCode.fileNotFound` when the resource is
    ///   missing (the same code `loadJSON` uses), or the error raised by `Data(contentsOf:)`.
    static func loadJSONData(_ name: String) throws -> Data {
        guard let path = bundlePath(forResource: name) else {
            let code = ErrorCode.fileNotFound
            throw NSError(domain: errorDomain, code: code.rawValue,
                          userInfo: [NSLocalizedDescriptionKey: code.description])
        }
        do {
            return try Data(contentsOf: URL(fileURLWithPath: path))
        } catch {
            // Log for debugging, then propagate unchanged.
            print("error >>>", error)
            throw error
        }
    }
}
public extension MockResponse {
    /// Loads a bundled mock resource and parses it with `JSONSerialization`.
    /// - Parameters:
    ///   - name: Resource file name; `json` is assumed when it has no extension.
    ///   - encoding: Text encoding of the file (UTF-8 by default).
    /// - Returns: The parsed JSON object graph.
    /// - Throws: `NSError` in `errorDomain` (`fileNotFound` / `disableReadFile`), or the error
    ///   from reading the file or from `JSONSerialization`.
    static func loadJSON(_ name: String, encoding: String.Encoding = .utf8) throws -> Any {
        guard let resourcePath = bundlePath(forResource: name) else {
            let missing = ErrorCode.fileNotFound
            throw NSError(domain: errorDomain, code: missing.rawValue,
                          userInfo: [NSLocalizedDescriptionKey: missing.description])
        }
        do {
            let decoded = try String(contentsOfFile: resourcePath, encoding: encoding)
            // For Shift_JIS input, every pair of backslashes is rewritten to a yen sign.
            let text = encoding == .shiftJIS
                ? decoded.replacingOccurrences(of: "\\\\", with: "¥")
                : decoded
            guard let utf8Bytes = text.data(using: .utf8) else {
                let unreadable = ErrorCode.disableReadFile
                throw NSError(domain: errorDomain, code: unreadable.rawValue,
                              userInfo: [NSLocalizedDescriptionKey: unreadable.description])
            }
            return try JSONSerialization.jsonObject(with: utf8Bytes, options: [])
        } catch {
            print("error >>>", error)
            throw error
        }
    }
}
import Foundation
// MARK: - MockURLParams
public struct MockURLParams {
    /// Absolute URL string of the intercepted request.
    let path: String
    /// Regex capture strings (whole match first), or `nil` when there are none.
    var captures: [String]?

    /// Same shape as the synthesized memberwise initializer.
    init(path: String, captures: [String]? = nil) {
        self.path = path
        self.captures = captures
    }
}
// MARK: - MockURLResponse
/// Canned HTTP result produced by a mock action: the status code and an optional body.
/// The elements are labeled for readability; `.0`/`.1` and unlabeled tuple literals still work.
public typealias MockURLResponse = (statusCode: Int, data: Data?)
// MARK: - MockURLAction
/// Pairs a URL pattern with the handler that produces its canned response.
public class MockURLAction {
    /// Builds the response for a matched request; may throw to simulate a failure.
    typealias ActionHandler = (URLRequest, MockURLParams) throws -> MockURLResponse

    /// Regular-expression pattern matched against the request's absolute URL string.
    let path: String
    /// Invoked when `path` matches.
    let handler: ActionHandler

    init(path pattern: String, handler body: @escaping ActionHandler) {
        path = pattern
        handler = body
    }
}
// MARK: - MockURLProtocol
open class MockURLProtocol: URLProtocol {
    /// Seconds to wait before delivering a canned response or error.
    public static var delay: TimeInterval = 0.5
    /// Registered actions keyed by their regular-expression pattern (`MockURLAction.path`).
    static var actions: [String: MockURLAction] = [:]
    /// Header fields attached to every synthesized `HTTPURLResponse`.
    let cannedHeaders = ["Content-Type": "application/json;"]
    /// Delivery scheduled by `startLoading()`; cancelled by `stopLoading()`.
    private var pendingDelivery: DispatchWorkItem?

    /// Registers `action`, replacing any action already registered with the same pattern.
    class func register(action: MockURLAction) {
        actions[action.path] = action
    }

    /// Required. Claims http(s) requests whose absolute URL matches a registered pattern.
    override open class func canInit(with request: URLRequest) -> Bool {
        guard let url = request.url,
              let scheme = url.scheme?.lowercased(),
              scheme == "http" || scheme == "https" else {
            return false
        }
        return isMatchingURL(url.absoluteString)
    }

    /// Required. Requests are used as-is.
    override open class func canonicalRequest(for request: URLRequest) -> URLRequest {
        return request
    }

    /// Required. Runs the matching action's handler and delivers its result to the client
    /// after `delay` seconds on the main queue.
    override open func startLoading() {
        print(#function)
        let request = self.request
        guard let url = request.url else {
            // Returning silently would leave the task waiting forever.
            client?.urlProtocol(self, didFailWithError: URLError(.badURL))
            return
        }
        let path = url.absoluteString
        print("request URL >>> \(path)")
        guard let match = type(of: self).firstMatch(for: path) else {
            // The action table can change between canInit(with:) and here; fail rather than hang.
            client?.urlProtocol(self, didFailWithError: URLError(.unsupportedURL))
            return
        }
        print("matched path >>>>>>>>>> \(path)")
        print("captures >>>", match.captures)
        let params = MockURLParams(
            path: path,
            // captures[0] is the whole match; nil when the pattern has no groups.
            captures: match.captures.count > 1 ? match.captures : nil
        )
        let deliver: (MockURLProtocol) -> Void
        do {
            let response = try match.action.handler(request, params)
            deliver = { $0.loadResponse(response: response) }
        } catch {
            deliver = { $0.loadError(error: error) }
        }
        let workItem = DispatchWorkItem { [weak self] in
            guard let strongSelf = self else { return }
            deliver(strongSelf)
        }
        pendingDelivery = workItem
        DispatchQueue.main.asyncAfter(deadline: .now() + type(of: self).delay, execute: workItem)
    }

    /// Required. Cancels a delivery that has not run yet so the client is not
    /// messaged after loading has stopped.
    override open func stopLoading() {
        print(#function)
        pendingDelivery?.cancel()
        pendingDelivery = nil
    }

    /// Returns whether any registered pattern matches `path` (an absolute URL string).
    class func isMatchingURL(_ path: String) -> Bool {
        return firstMatch(for: path) != nil
    }

    /// Finds the first registered action whose pattern matches `urlString`.
    /// Also returns the capture strings of the last match (index 0 is the whole match).
    /// Patterns that fail to compile are skipped, so they cannot disable the other actions.
    private class func firstMatch(for urlString: String) -> (action: MockURLAction, captures: [String])? {
        // UTF-16 based range, as NSRegularExpression requires.
        let fullRange = NSRange(urlString.startIndex..., in: urlString)
        for (pattern, action) in actions {
            guard let regexp = try? NSRegularExpression(pattern: pattern, options: [.anchorsMatchLines]) else {
                continue
            }
            guard let last = regexp.matches(in: urlString, options: [], range: fullRange).last else {
                continue
            }
            let captures = (0..<last.numberOfRanges).map { index -> String in
                // Optional groups that did not participate have no range; report them as "".
                guard let range = Range(last.range(at: index), in: urlString) else { return "" }
                return String(urlString[range])
            }
            return (action, captures)
        }
        return nil
    }

    /// Sends `response` to the client as an HTTP/1.1 response, followed by its body if any.
    func loadResponse(response: MockURLResponse) {
        guard let url = request.url,
              let urlResponse = HTTPURLResponse(url: url, statusCode: response.0,
                                                httpVersion: "HTTP/1.1", headerFields: cannedHeaders) else {
            client?.urlProtocol(self, didFailWithError: URLError(.badServerResponse))
            return
        }
        client?.urlProtocol(self, didReceive: urlResponse, cacheStoragePolicy: .notAllowed)
        if let data = response.1 {
            client?.urlProtocol(self, didLoad: data)
        }
        client?.urlProtocolDidFinishLoading(self)
    }

    /// Sends a 500 response and then fails the load with `error`.
    func loadError(error: Error) {
        if let url = request.url,
           let urlResponse = HTTPURLResponse(url: url, statusCode: 500,
                                             httpVersion: "HTTP/1.1", headerFields: cannedHeaders) {
            client?.urlProtocol(self, didReceive: urlResponse, cacheStoragePolicy: .notAllowed)
        }
        client?.urlProtocol(self, didFailWithError: error)
    }
}
import Foundation
// MARK: - SampleURLProtocolManager
public final class SampleURLProtocolManager {
    /// Lazily created singleton. Creating it installs `URLSessionConfiguration.setupStubHandler`,
    /// which registers every action whenever the stubbed default configuration is built.
    private static let shared: SampleURLProtocolManager = {
        let manager = SampleURLProtocolManager()
        URLSessionConfiguration.setupStubHandler = {
            manager.useAll()
        }
        return manager
    }()

    /// Whether `URLSessionConfiguration.default` is currently exchanged with the stubbed version.
    static var isStubbed: Bool = false

    /// Installs or uninstalls the stub.
    /// - Parameter flag: `true` to install, `false` to uninstall.
    class func inject(_ flag: Bool) {
        _ = shared
        flag ? swizzle() : unswizzle()
    }

    /// Exchanges `URLSessionConfiguration.default` with the stubbed implementation.
    /// Only configurations created afterwards are affected; existing sessions keep theirs.
    private class func swizzle() {
        guard !isStubbed else { return }
        URLSessionConfiguration.swizzlingAPIStub()
        isStubbed = true
    }

    /// Restores the original `URLSessionConfiguration.default`. Exchanging the two
    /// implementations again undoes the swizzle. Clearing only the flag would let the next
    /// `swizzle()` exchange them back and silently disable the stub.
    /// Existing sessions keep their configuration.
    private class func unswizzle() {
        guard isStubbed else { return }
        URLSessionConfiguration.swizzlingAPIStub()
        isStubbed = false
    }

    // MARK: - action

    /// Canned actions available for registration, keyed by `SampleURLActionKey`.
    let actions: [SampleURLActionKey: MockURLAction]

    private init() {
        actions = [
            .items: MockURLAction(path: "/items") { (_: URLRequest, _: MockURLParams) -> MockURLResponse in
                return (200, try? MockResponse.loadJSONData("items.json"))
                // return (500, try? MockResponse.loadJSONData("error.json"))
            },
            /*
            .user_items: MockURLAction(path: "/users/([^\\s\\/]+)/items") { (_: URLRequest, _: MockURLParams) throws -> MockURLResponse in
                return (200, try MockResponse.loadJSONData("user_items.json"))
            },
            */
        ]
    }

    /// Registers the action for `key` with `MockURLProtocol`, if one is defined.
    private func use(_ key: SampleURLActionKey) {
        guard let action = actions[key] else { return }
        MockURLProtocol.register(action: action)
    }

    /// Registers every defined action with `MockURLProtocol`.
    private func useAll() {
        for (_, action) in actions {
            MockURLProtocol.register(action: action)
        }
    }

    /// Registers the action for `key` on the shared manager.
    public class func use(_ key: SampleURLActionKey) {
        shared.use(key)
    }

    /// Registers all actions of the shared manager.
    public class func useAll() {
        shared.useAll()
    }
}
// MARK: - SampleURLActionKey
/// Identifiers for the canned actions defined by `SampleURLProtocolManager`.
/// Raw values equal the case names.
public enum SampleURLActionKey: String {
    case items, user_items
}
import Foundation
public extension URLSessionConfiguration {
    /// Called each time the stubbed default configuration is created.
    static var setupStubHandler: (() -> Void)?

    /// Exchanges the implementations of the `default` class getter and
    /// `stubbedDefaultConfiguration()`. Calling it a second time exchanges them back.
    @objc class func swizzlingAPIStub() {
        guard
            let defaultConfig = class_getClassMethod(URLSessionConfiguration.self,
                                                     #selector(getter: URLSessionConfiguration.default)),
            let stubbedConfig = class_getClassMethod(URLSessionConfiguration.self,
                                                     #selector(URLSessionConfiguration.stubbedDefaultConfiguration))
        else {
            assertionFailure("URLSessionConfiguration methods to swizzle were not found")
            return
        }
        method_exchangeImplementations(defaultConfig, stubbedConfig)
    }

    /// Replacement for `URLSessionConfiguration.default`. It puts `MockURLProtocol` first
    /// in `protocolClasses` and runs `setupStubHandler`.
    @objc private class func stubbedDefaultConfiguration() -> URLSessionConfiguration {
        // After the exchange, this selector points at the original `default` getter,
        // so this call returns the system configuration instead of recursing.
        let config = stubbedDefaultConfiguration()
        let mocks: [AnyClass] = [MockURLProtocol.self]
        config.protocolClasses = mocks + (config.protocolClasses ?? [])
        setupStubHandler?()
        return config
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment