Skip to content

Instantly share code, notes, and snippets.

@kielgillard
Created May 8, 2020 06:35
Show Gist options
  • Save kielgillard/59066191d2f78a24b05782cec0271d3c to your computer and use it in GitHub Desktop.
Save kielgillard/59066191d2f78a24b05782cec0271d3c to your computer and use it in GitHub Desktop.
A rough, ungeneralised custom publisher/operator for Combine. When the upstream data task responds with 401 Unauthorized, it subscribes to the re-authentication publisher and, if that responds with 200 OK, re-subscribes to the upstream.
//
// ViewController.swift
// Streams
//
// Created by Kiel Gillard on 17/4/20.
// Copyright © 2020 Streams. All rights reserved.
//
import UIKit
import Combine
class ViewController: UIViewController {
    var session: URLSession?
    var cancellables = Set<AnyCancellable>(minimumCapacity: 1)

    /// Builds an ephemeral session whose traffic is served by `Stubber`, then issues one request that
    /// re-authenticates (via a second request to the same URL) if the first answer is a 401.
    override func viewDidLoad() {
        super.viewDidLoad()

        let configuration = URLSessionConfiguration.ephemeral
        configuration.protocolClasses = [Stubber.self]
        let stubbedSession = URLSession(configuration: configuration)
        session = stubbedSession

        let endpoint = URL(string: "https://localhost")!
        let reauthentication = stubbedSession.dataTaskPublisher(for: endpoint)
        let textPublisher = stubbedSession
            .dataTaskPublisher(for: endpoint)
            .retryIfAuthorizationExpired(using: reauthentication)
            .compactMap { String(data: $0.data, encoding: .utf8) }

        textPublisher
            .sink(
                receiveCompletion: { completion in print(completion) },
                receiveValue: { text in print("the text is: ", text) }
            )
            .store(in: &cancellables)
    }
}
/// URL protocol stub that answers every request locally: 200 with a fixed body, or an empty 401 when
/// `shouldFail` is set. `shouldFail` is cleared after each request.
class Stubber: URLProtocol {
    override class func canInit(with request: URLRequest) -> Bool {
        true
    }

    override class func canonicalRequest(for request: URLRequest) -> URLRequest {
        request
    }

    // Changing this value simulates the request failing the first or second time.
    private static var shouldFail = false

    override func startLoading() {
        defer {
            // Changing this value simulates the request failing when attempting to reauthenticate.
            Stubber.shouldFail = false
        }
        let failing = Stubber.shouldFail
        let statusCode = failing ? 401 : 200
        client?.urlProtocol(self, didReceive: HTTPURLResponse(url: request.url!, statusCode: statusCode, httpVersion: "1.1", headerFields: nil)!, cacheStoragePolicy: .notAllowed)
        if !failing {
            // Only the successful response carries a body.
            client?.urlProtocol(self, didLoad: "You have authenticated!".data(using: .utf8)!)
        }
        client?.urlProtocolDidFinishLoading(self)
    }

    override func stopLoading() {
        // Nothing to tear down: startLoading() reports everything to the client synchronously.
    }
}
/// Publisher that runs `upstreamPublisher`; if its response is 401 Unauthorized it runs `reauthenticatingPublisher`,
/// and if that response is 200 OK it runs `upstreamPublisher` a second time.
struct Reauthenticate: Publisher {
    typealias Upstream = URLSession.DataTaskPublisher
    typealias Output = Upstream.Output
    typealias Failure = Upstream.Failure

    private let upstreamPublisher: Upstream
    private let reauthenticatingPublisher: Upstream

    init(upstreamPublisher: Upstream, reauthenticatingPublisher: Upstream) {
        self.upstreamPublisher = upstreamPublisher
        self.reauthenticatingPublisher = reauthenticatingPublisher
    }

    /// Hands `subscriber` a `Requestor` in its initial stage; nothing is requested until the subscriber signals demand.
    func receive<S>(subscriber: S) where S: Subscriber, S.Input == Output, S.Failure == Failure {
        subscriber.receive(subscription: Requestor(
            stage: .initial,
            downstream: subscriber,
            upstreamPublisher: upstreamPublisher,
            reauthenticatePublisher: reauthenticatingPublisher
        ))
    }
}
extension URLSession.DataTaskPublisher {
    /// Wraps this data task so that a 401 Unauthorized response triggers `publisher` (the re-authentication
    /// request) and, if that answers 200 OK, a second attempt of this data task.
    ///
    /// - Parameter publisher: The request that re-establishes authorization.
    /// - Returns: A `Reauthenticate` publisher combining both requests.
    func retryIfAuthorizationExpired(using publisher: URLSession.DataTaskPublisher) -> Reauthenticate {
        Reauthenticate(upstreamPublisher: self, reauthenticatingPublisher: publisher)
    }
}
/// Drives one stage of the re-authentication flow (initial request, re-authentication, or retry).
/// It acts as the subscription handed downstream and as the subscriber of that stage's data task.
class Requestor<Downstream: Subscriber>: CustomCombineIdentifierConvertible, Cancellable {
    typealias Upstream = URLSession.DataTaskPublisher

    /// Which request this instance is responsible for.
    enum Stage {
        /// The first attempt of the upstream request.
        case initial
        /// The re-authentication request; carries the initial 401 output, which is forwarded
        /// downstream if re-authentication does not come back as an HTTP 200.
        case reauthenticating(Upstream.Output)
        /// The second attempt of the upstream request after a successful re-authentication.
        case retrying
    }

    let stage: Stage
    var downstream: Downstream?
    let upstreamPublisher: Upstream
    let reauthenticatePublisher: Upstream

    // Demand requested by (or copied for) the downstream; requested from each data task subscription.
    private var demand = Subscribers.Demand.none
    private var upstreamSubscription: Subscription?
    private var reauthenticateSubscription: Subscription?
    // Set once this stage has forwarded a value downstream, allowing its completion through.
    private var canComplete = false

    init(stage: Stage, downstream: Downstream?, upstreamPublisher: Upstream, reauthenticatePublisher: Upstream) {
        self.stage = stage
        self.downstream = downstream
        self.upstreamPublisher = upstreamPublisher
        self.reauthenticatePublisher = reauthenticatePublisher
    }

    /// Cancels whichever subscriptions this requestor holds (upstream first) and forgets them.
    func cancel() {
        let held: [Subscription?] = [upstreamSubscription, reauthenticateSubscription]
        for subscription in held {
            subscription?.cancel()
        }
        upstreamSubscription = nil
        reauthenticateSubscription = nil
    }
}
extension Requestor: Subscription where Downstream.Failure == URLSession.DataTaskPublisher.Failure, Downstream.Input == URLSession.DataTaskPublisher.Output {
    /// Records additional downstream demand and starts this stage's request the first time demand arrives.
    ///
    /// Demand is cumulative in Combine, so repeated calls add to the recorded demand. Once this stage holds a
    /// subscription, further demand is forwarded to it. Previously every call re-subscribed to the publisher,
    /// which issued another network request each time the downstream asked for more.
    ///
    /// - Parameter demand: The additional number of values the downstream is willing to receive.
    func request(_ demand: Subscribers.Demand) {
        self.demand += demand

        let existing: Subscription?
        switch stage {
        case .initial, .retrying:
            existing = upstreamSubscription
        case .reauthenticating:
            existing = reauthenticateSubscription
        }
        if let existing = existing {
            existing.request(demand)
            return
        }

        // First demand for this stage: subscribe; receive(subscription:) then requests the recorded demand.
        switch stage {
        case .initial, .retrying:
            upstreamPublisher.receive(subscriber: self)
        case .reauthenticating:
            reauthenticatePublisher.receive(subscriber: self)
        }
    }
}
extension Requestor: Subscriber where Downstream.Failure == URLSession.DataTaskPublisher.Failure, Downstream.Input == URLSession.DataTaskPublisher.Output {
    typealias Input = URLSession.DataTaskPublisher.Output
    typealias Failure = URLSession.DataTaskPublisher.Failure

    /// Stores the data task subscription for this stage and requests the demand recorded so far.
    ///
    /// The downstream is deliberately not sent `self` here. It already received the top-level requestor from
    /// `Reauthenticate.receive(subscriber:)`, and a subscriber must receive exactly one subscription. Re-sending one
    /// per stage violated that contract, and a downstream that answered by calling `request(_:)` recursed forever.
    func receive(subscription: Subscription) {
        switch stage {
        case .initial, .retrying:
            upstreamSubscription = subscription
        case .reauthenticating:
            reauthenticateSubscription = subscription
        }
        Swift.print("Requesting demand at \(stage).")
        subscription.request(demand)
    }

    /// Handles the response produced by this stage's data task.
    ///
    /// - `.initial` with HTTP 401: start a `.reauthenticating` requestor and emit nothing.
    /// - `.reauthenticating` with HTTP 200: start a `.retrying` requestor and emit nothing.
    /// - Otherwise: emit downstream. When re-authentication did not return 200, the original 401 output is
    ///   emitted; in the other stages `input` itself is emitted.
    ///
    /// Each follow-up requestor is stored in a subscription slot this stage does not use for its own data task.
    /// That way, `cancel()` on this requestor also cancels the follow-up request.
    ///
    /// - Returns: The downstream's additional demand, or `.none` after a hand-off.
    func receive(_ input: URLSession.DataTaskPublisher.Output) -> Subscribers.Demand {
        switch stage {
        case .initial:
            if let httpResponse = input.response as? HTTPURLResponse, httpResponse.statusCode == 401 {
                let reauthenticator = makeFollowUp(stage: .reauthenticating(input))
                // `.initial` keeps its own data task in upstreamSubscription, so this slot is free.
                reauthenticateSubscription = reauthenticator
                reauthenticatePublisher.subscribe(reauthenticator)
                return .none
            }
        case .reauthenticating:
            if let httpResponse = input.response as? HTTPURLResponse, httpResponse.statusCode == 200 {
                let retrier = makeFollowUp(stage: .retrying)
                // `.reauthenticating` keeps its own data task in reauthenticateSubscription, so this slot is free.
                upstreamSubscription = retrier
                upstreamPublisher.subscribe(retrier)
                return .none
            }
        case .retrying:
            break
        }

        let downstreamInput: Downstream.Input
        switch stage {
        case .reauthenticating(let initialInput):
            downstreamInput = initialInput
        case .initial, .retrying:
            downstreamInput = input
        }
        canComplete = true
        Swift.print("Received value at \(stage). Sending value: \(downstreamInput)")
        return downstream?.receive(downstreamInput) ?? .none
    }

    /// Forwards this stage's completion downstream when appropriate.
    ///
    /// Failures are always forwarded. Previously they were dropped unless a value had been sent, so a network
    /// error left the downstream waiting forever. `.finished` is forwarded only once this stage has emitted a
    /// value. A stage that handed off to a follow-up requestor stays quiet and lets the follow-up complete the
    /// stream.
    func receive(completion: Subscribers.Completion<URLSession.DataTaskPublisher.Failure>) {
        switch completion {
        case .failure:
            break
        case .finished:
            guard canComplete else {
                return
            }
        }
        downstream?.receive(completion: completion)
        Swift.print("Completed at \(stage).")
    }

    /// Creates the requestor for `nextStage`, sharing this stage's downstream, publishers and recorded demand.
    private func makeFollowUp(stage nextStage: Stage) -> Requestor<Downstream> {
        let requestor = Requestor(stage: nextStage, downstream: downstream, upstreamPublisher: upstreamPublisher, reauthenticatePublisher: reauthenticatePublisher)
        requestor.demand = demand
        return requestor
    }
}
@jbennett
Copy link

jbennett commented May 8, 2020

Hey @kielgillard, watch out for storing the cancellable in `cancellables` (line 30). Entries in that set are only removed when the view controller is deallocated. Here you are probably fine, since you only load once. But if you refreshed the data, you would accumulate (leak) one cancellable per request. The solution I've seen for this is:

var cancellable: AnyCancellable!
cancellable = session
    .dataTaskPublisher(for: url)
    .stuff()
    .sink(receiveCompletion: { completion in
        print(completion)
        cancellables.remove(cancellable)
    }, receiveValue: { text in
        print("the text is: ", text)
    })
cancellables.insert(cancellable)

Alternatively, it may be clearer to drop the set and store a single `AnyCancellable?` instead. Assigning a new value releases the previous `AnyCancellable`, which cancels its subscription.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment