Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
The TokenAcquisitionService automatically retries requests when it receives an unauthorized (401) error. Complete with unit tests showing that it works correctly.
//
// RetryingTokenNetworkService.swift
//
// Created by Daniel Tartaglia on 16 Jan 2019.
// Copyright © 2019 Daniel Tartaglia. MIT License.
//
import Foundation
import RxSwift
/// A function that performs a network request and emits the HTTP response together with its body,
/// for example `URLSession.shared.rx.response`.
public typealias Response = (_ request: URLRequest) -> Observable<(response: HTTPURLResponse, data: Data)>
/// Builds a request with the service's current token and sends it. If the server answers with
/// 401 Unauthorized, a new token is requested through `tokenAcquisitionService` and the request is
/// rebuilt and retried.
///
/// - Parameters:
///   - response: A function that sends requests to the network and emits responses, for example
///     `URLSession.shared.rx.response`.
///   - tokenAcquisitionService: The object responsible for tracking the auth token. All requests
///     should use the same object.
///   - request: A function that builds the request when given a token. Errors it throws are passed
///     to subscribers.
/// - Returns: The response of a network request that was not rejected with a 401 status.
public func getData<T>(response: @escaping Response, tokenAcquisitionService: TokenAcquisitionService<T>, request: @escaping (T) throws -> URLRequest) -> Observable<(response: HTTPURLResponse, data: Data)> {
    // Deferred so that every (re)subscription made by `retryWhen` reads the token current at that time.
    let currentToken = Observable.deferred { tokenAcquisitionService.token.take(1) }
    let authorizedResponse = currentToken
        .map { token in try request(token) }
        .flatMap { urlRequest in response(urlRequest) }
        .map { result -> (response: HTTPURLResponse, data: Data) in
            // A 401 becomes `.unauthorized` so `renewToken(with:)` can recognize it.
            if result.response.statusCode == 401 {
                throw TokenAcquisitionError.unauthorized
            }
            return result
        }
    return authorizedResponse.retryWhen { errors in errors.renewToken(with: tokenAcquisitionService) }
}
// MARK: -
/// Errors produced or recognized by `TokenAcquisitionService`.
public enum TokenAcquisitionError: Error, Equatable {
    /// A request was rejected as unauthorized. `renewToken(with:)` watches for this error, acquires
    /// a new token, and signals that the request may be retried.
    case unauthorized

    /// The response returned by the service's `getToken` function did not have a 2xx status code.
    /// Carries that response and its body.
    case refusedToken(response: HTTPURLResponse, data: Data)
}
/// Stores the most recently acquired authorization token and acquires a new one when a request
/// reports `TokenAcquisitionError.unauthorized`. All requests sharing a token should use the same
/// instance.
public final class TokenAcquisitionService<T> {
    /// Emits the current token immediately on subscription and emits again whenever a new token is
    /// acquired or set. You can, for example, subscribe to it in order to save the token as it's updated.
    /// If a token request fails (see `TokenAcquisitionError.refusedToken`), this stream terminates
    /// with that error.
    public var token: Observable<T> {
        return _token.asObservable()
    }

    public typealias GetToken = (T) -> Observable<(response: HTTPURLResponse, data: Data)>

    /// Creates a `TokenAcquisitionService` object that will store the most recent authorization token
    /// acquired and will acquire new ones as needed.
    ///
    /// - Parameters:
    ///   - initialToken: The token the service should start with. Provide a token from storage or an
    ///     empty string (an object representing a missing token) if one has not been acquired yet.
    ///   - getToken: A function responsible for acquiring a new token when needed. It receives the
    ///     token that was current when the refresh was requested.
    ///   - extractToken: A function that can extract a token from the data returned by `getToken`.
    ///     It is only called for responses with a 2xx status code.
    public init(initialToken: T, getToken: @escaping GetToken, extractToken: @escaping (Data) throws -> T) {
        relay
            // Refresh requests that arrive while a token request is already in flight are ignored;
            // everyone waiting is released when `_token` emits the new value.
            .flatMapFirst { getToken($0) }
            .map { (urlResponse) -> T in
                guard urlResponse.response.statusCode / 100 == 2 else {
                    throw TokenAcquisitionError.refusedToken(response: urlResponse.response, data: urlResponse.data)
                }
                return try extractToken(urlResponse.data)
            }
            .startWith(initialToken)
            .subscribe(_token)
            .disposed(by: disposeBag)
    }

    /// Allows the token to be set imperatively if necessary.
    /// - Parameter token: The new token the service should use. It will immediately be emitted to any
    ///   subscribers to the service.
    public func setToken(_ token: T) {
        lock.lock()
        defer { lock.unlock() }
        _token.onNext(token)
    }

    /// Monitors the source for `.unauthorized` error events and passes all other errors on. When an
    /// `.unauthorized` error is seen, `self` will get a new token and emit a signal that it's safe to
    /// retry the request.
    ///
    /// - Parameter source: An `Observable` (or like type) that emits errors.
    /// - Returns: A trigger that emits each time the service emits a token after subscription, and
    ///   errors when `source` emits any error other than `.unauthorized`.
    func trackErrors<O: ObservableConvertibleType>(for source: O) -> Observable<Void> where O.Element == Error {
        let lock = self.lock
        let relay = self.relay
        // Captured directly instead of through `[unowned self]`, so a trigger that outlives the
        // service cannot crash on a dangling reference.
        let currentToken = self.token
        let error = source
            .asObservable()
            .map { error in
                guard (error as? TokenAcquisitionError) == .unauthorized else { throw error }
            }
            // Only the token that is current when the error arrives is relayed. Staying subscribed to
            // the (never-completing) token stream would forward every later token, including values
            // passed to `setToken`, into `relay` and start redundant token requests.
            .flatMap { _ in currentToken.take(1) }
            .do(onNext: {
                lock.lock()
                defer { lock.unlock() }
                relay.onNext($0)
            })
            // This branch exists only for its side effect and for propagating errors; the retry
            // signal comes from the new token below.
            .filter { _ in false }
            .map { _ in }
        return Observable.merge(currentToken.skip(1).map { _ in }, error)
    }

    /// Holds the latest token (or the terminating error) and replays it to new subscribers.
    private let _token = ReplaySubject<T>.create(bufferSize: 1)
    /// Receives the token to refresh from; feeds the `getToken` pipeline set up in `init`.
    private let relay = PublishSubject<T>()
    /// Serializes `onNext` calls made by `setToken` and `trackErrors` on the subjects.
    private let lock = NSRecursiveLock()
    private let disposeBag = DisposeBag()
}
extension ObservableConvertibleType where Element == Error {
    /// Watches this stream of errors for `TokenAcquisitionError.unauthorized` and forwards every
    /// other error. On `.unauthorized`, `service` acquires a new token, and the returned trigger emits
    /// once it is safe to retry the request. Intended for use inside `retryWhen`.
    ///
    /// - Parameter service: The `TokenAcquisitionService` that stores the auth token used by the request.
    /// - Returns: A trigger that will emit when it's safe to retry the request.
    public func renewToken<T>(with service: TokenAcquisitionService<T>) -> Observable<Void> {
        let retryTrigger = service.trackErrors(for: self)
        return retryTrigger
    }
}
/// XCTest suite for `TokenAcquisitionService`, run on a virtual-time `TestScheduler`.
// NOTE(review): this file only imports Foundation and RxSwift. `XCTestCase`, `TestScheduler` and
// `TestableObserver` presumably come from XCTest/RxTest imports in the test target — confirm.
class TokenAcquisitionServiceTests: XCTestCase {
var scheduler: TestScheduler!
var tokenResult: TestableObserver<String>!
var triggerResult: TestableObserver<Void>!
var bag: DisposeBag!
/// Creates a fresh scheduler (clock starting at 0), fresh observers and a fresh dispose bag before each test.
override func setUp() {
super.setUp()
scheduler = TestScheduler(initialClock: 0)
tokenResult = scheduler.createObserver(String.self)
triggerResult = scheduler.createObserver(Void.self)
bag = DisposeBag()
}
/// Subscribing to `token` yields the initial token at time 0. Neither `getToken` nor `extractToken`
/// may be called; both local stubs fail the test if invoked.
func testInitial() {
// given
func getToken(_ old: String) -> Observable<(response: HTTPURLResponse, data: Data)> {
XCTFail()
fatalError()
}
func extractToken(_ data: Data) -> String {
XCTFail()
fatalError()
}
let service = TokenAcquisitionService(initialToken: "first", getToken: getToken, extractToken: extractToken)
// when
service.token
.bind(to: tokenResult)
.disposed(by: bag)
scheduler.start()
// then
XCTAssertEqual(tokenResult.events, [.next(0, "first")])
}
/// `setToken` emits the new token to an existing subscriber at the virtual time it is called (t=10).
/// The local `getToken` traps if invoked, so no token request may be made.
func testUpdate() {
// given
func getToken(_ old: String) -> Observable<(response: HTTPURLResponse, data: Data)> {
fatalError()
}
// `extractToken` here is the file-level helper (UTF-8 decoding); no local override in this test.
let service = TokenAcquisitionService<String>(initialToken: "first", getToken: getToken, extractToken: extractToken)
// when
scheduler.scheduleAt(10) {
service.setToken("second")
}
service.token
.bind(to: tokenResult)
.disposed(by: bag)
scheduler.start()
// then
XCTAssertEqual(tokenResult.events, [.next(0, "first"), .next(10, "second")])
}
/// An `.unauthorized` error at t=10 causes a `getToken` call with the current token ("first"). The
/// extracted token ("second") is emitted, and the retry trigger fires, both at t=10.
func testUnauthorized() {
// given
let trigger = scheduler.createColdObservable([.next(10, TokenAcquisitionError.unauthorized as Error)])
func getToken(_ old: String) -> Observable<(response: HTTPURLResponse, data: Data)> {
XCTAssertEqual(old, "first")
let response = HTTPURLResponse(url: URL(fileURLWithPath: ""), statusCode: 200, httpVersion: nil, headerFields: nil)!
let data = "second".data(using: .utf8)!
return Observable.just((response: response, data: data))
}
let service = TokenAcquisitionService<String>(initialToken: "first", getToken: getToken, extractToken: extractToken)
// when
bag.insert(
service.token.bind(to: tokenResult),
trigger.renewToken(with: service).bind(to: triggerResult)
)
scheduler.start()
// then
XCTAssertEqual(tokenResult.events, [.next(0, "first"), .next(10, "second")])
XCTAssertEqual(triggerResult.events.map { $0.time }, [10])
}
/// When the token request answers with a non-2xx status (500 here), `token` terminates with
/// `.refusedToken` carrying that response and body. The retry trigger also receives an event at
/// t=10; only its time is asserted.
func testBadTokenRequest() {
// given
let trigger = scheduler.createColdObservable([.next(10, TokenAcquisitionError.unauthorized as Error)])
let response = HTTPURLResponse(url: URL(fileURLWithPath: ""), statusCode: 500, httpVersion: nil, headerFields: nil)!
let data = "second".data(using: .utf8)!
func getToken(_ old: String) -> Observable<(response: HTTPURLResponse, data: Data)> {
XCTAssertEqual(old, "first")
return Observable.just((response: response, data: data))
}
let service = TokenAcquisitionService<String>(initialToken: "first", getToken: getToken, extractToken: extractToken)
// when
bag.insert(
service.token.bind(to: tokenResult),
trigger.renewToken(with: service).bind(to: triggerResult)
)
scheduler.start()
// then
XCTAssertEqual(tokenResult.events, [.next(0, "first"), .error(10, TokenAcquisitionError.refusedToken(response: response, data: data))])
XCTAssertEqual(triggerResult.events.map { $0.time }, [10])
}
/// An error other than `.unauthorized` is forwarded to the trigger at t=10 (only the event time is
/// asserted), and no new token is emitted.
func testOtherErrorsFallThrough() {
// given
let trigger = scheduler.createColdObservable([.next(10, RxError.unknown as Error)])
func getToken(_ old: String) -> Observable<(response: HTTPURLResponse, data: Data)> {
XCTAssertEqual(old, "first")
let response = HTTPURLResponse(url: URL(fileURLWithPath: ""), statusCode: 200, httpVersion: nil, headerFields: nil)!
let data = "second".data(using: .utf8)!
return Observable.just((response: response, data: data))
}
let service = TokenAcquisitionService<String>(initialToken: "first", getToken: getToken, extractToken: extractToken)
// when
bag.insert(
service.token.bind(to: tokenResult),
trigger.renewToken(with: service).bind(to: triggerResult)
)
scheduler.start()
// then
XCTAssertEqual(tokenResult.events, [.next(0, "first")])
XCTAssertEqual(triggerResult.events.map { $0.time }, [10])
}
/// Two `.unauthorized` errors (t=10 and t=30) arrive while a slow token request (delayed 20 ticks) is
/// in flight. They share a single `getToken` call, and both triggers fire when the new token arrives at t=30.
func testMultipleUnauthsOnlyCauseOneTokenRequest() {
// given
let trigger1 = scheduler.createColdObservable([.next(10, TokenAcquisitionError.unauthorized as Error)])
// NOTE(review): trigger2's error and the delayed token response are both due at t=30. The test
// presumably relies on trigger2's event being processed first (it was scheduled earlier) — confirm.
let trigger2 = scheduler.createColdObservable([.next(30, TokenAcquisitionError.unauthorized as Error)])
let triggerResult2 = scheduler.createObserver(Void.self)
var requestCount = 0
func getToken(_ old: String) -> Observable<(response: HTTPURLResponse, data: Data)> {
XCTAssertEqual(old, "first")
requestCount += 1
let response = HTTPURLResponse(url: URL(fileURLWithPath: ""), statusCode: 200, httpVersion: nil, headerFields: nil)!
let data = "second".data(using: .utf8)!
return Observable.just((response: response, data: data)).delay(.seconds(20), scheduler: scheduler)
}
let service = TokenAcquisitionService<String>(initialToken: "first", getToken: getToken, extractToken: extractToken)
// when
bag.insert(
service.token.bind(to: tokenResult),
trigger1.renewToken(with: service).bind(to: triggerResult),
trigger2.renewToken(with: service).bind(to: triggerResult2)
)
scheduler.start()
// then
XCTAssertEqual(tokenResult.events, [.next(0, "first"), .next(30, "second")])
XCTAssertEqual(triggerResult.events.map { $0.time }, [30])
XCTAssertEqual(triggerResult2.events.map { $0.time }, [30])
XCTAssertEqual(requestCount, 1)
}
}
/// Decodes `data` as a UTF-8 string to use as the token. Returns an empty string when `data` is not
/// valid UTF-8.
func extractToken(_ data: Data) -> String {
    guard let decoded = String(data: data, encoding: .utf8) else {
        return ""
    }
    return decoded
}
@Terens777

This comment has been minimized.

Copy link

Terens777 commented Aug 28, 2019

How do I use it with this code?

            .create { observer in
                Rest.callAPI(model: modelType,
                             paramsType: paramsType,
                             apiEndpoint: endpoint,
                             apiMethod: method,
                             bodyParams: params,
                             queryParams: query,
                             stringParams: strings,
                             apiHeaders: headers,
                             requestTaskName: name,
                             timeOutRequest: time)
                    .observeOn(MainScheduler.instance)
                    .timeout(time, scheduler: MainScheduler.instance)
                    .subscribe(onNext: { model in
                        observer.onNext(model)
                        observer.onCompleted()
                    }, onError: { error in
                        observer.onError(error)
                    })
                
        }
@Terens777

This comment has been minimized.

Copy link

Terens777 commented Aug 28, 2019

callApi it's returned Observable where Model : Mappable

@danielt1263

This comment has been minimized.

Copy link
Owner Author

danielt1263 commented Aug 28, 2019

@Terens777 You don't need to wrap that code in a create call. Replace the whole thing with:

Observable.deferred { tokenAcquisitionService.token.take(1) }
    .flatMap { token -> Observable<T> in // You might have to adjust the return type depending on how your `callAPI` is declared.
        // you need to insert the current token into the headers below.
        Rest.callAPI(
            model: modelType,
            paramsType: paramsType,
            apiEndpoint: endpoint,
            apiMethod: method,
            bodyParams: params,
            queryParams: query,
            stringParams: strings,
            apiHeaders: headers,
            requestTaskName: name,
            timeOutRequest: time
        )
            .catchError { error in
                // you need some way to know if the error was because of an unauthorized attempt (401)
                if userUnauthorized(error) { throw TokenAcquisitionError.unauthorized }
                else { throw error }
        }
    }
    .observeOn(MainScheduler.instance)
    .timeout(time, scheduler: MainScheduler.instance)
    .retryWhen { $0.renewToken(with: tokenAcquisitionService) }
@Terens777

This comment has been minimized.

Copy link

Terens777 commented Aug 29, 2019

Thank you for your help; with this example I’ve come closer to success. Do you have an example of how you use this for authorization, with a request that results in a 401? It’s difficult for me.

@Terens777

This comment has been minimized.

Copy link

Terens777 commented Aug 29, 2019

it,s my call Api

 static func callAPI<Model>(model modelType: Model.Type,
                               paramsType type: ParamsType,
                               apiEndpoint endpoint: RequestProtocol,
                               apiMethod method: HTTPMethod,
                               bodyParams params: Parameters?,
                               queryParams query: [URLQueryItem]?,
                               stringParams strings: [URLStringItem]?,
                               apiHeaders headers: HTTPHeaders?,
                               parameterEncoding encoding: ParameterEncoding = JSONEncoding.default,
                               requestTaskName name: String?,
                               timeOutRequest time: Double) -> Observable<Model> where Model : Mappable {
        return Observable
            .create { observer in
    
            let requestGenerator = RequestGenerator(url: endpoint.requestPath)
            
            var urlRequest: String
            
            switch type {
            case .body, .none:
                urlRequest = endpoint.requestPath
            case .query:
                urlRequest = requestGenerator.requestWithQuery(queryItems: query ?? [])
            case .strings:
                urlRequest = requestGenerator.requestWithStrings(stringItems: strings ?? [])
            }
            
            PrintRequest.printRequest(method: method, params: params, header: headers ?? [:], path: urlRequest)
         
            let request = Rest.sessionManager.request(urlRequest,
                                                      method: method,
                                                      parameters: params,
                                                      encoding: encoding,
                                                      headers: headers)
            .responseObject { (response: DataResponse<Model>) in
                    
                switch response.result {
                case .failure(let error):
                    observer.onError(error)
                case .success(let model):
                    observer.onNext(model)
                    observer.onCompleted()
                }
            }
    
            request.task?.taskDescription = name
            
            return Disposables.create {
                request.cancel()
            }
        }
    }```
@Terens777

This comment has been minimized.

Copy link

Terens777 commented Aug 29, 2019

extension DataFetch {
    
    func callAPI<Model> (model modelType            : Model.Type,
                         paramsType type            : ParamsType,
                         apiEndpoint endpoint       : RequestProtocol,
                         apiMethod method           : HTTPMethod,
                         bodyParams params          : Parameters? = nil,
                         queryParams query          : [URLQueryItem]? = nil,
                         stringParams strings       : [URLStringItem]? = nil,
                         apiHeaders headers         : HTTPHeaders? = nil,
                         requestTaskName name       : String? = nil,
                         timeOutRequest time        : Double = 30) -> Observable<Model> where Model: Mappable {
        
        return Observable
            .create { observer in
                Rest.callAPI(model: modelType,
                             paramsType: type,
                             apiEndpoint: endpoint,
                             apiMethod: method,
                             bodyParams: params,
                             queryParams: query,
                             stringParams: strings,
                             apiHeaders: headers,
                             requestTaskName: name,
                             timeOutRequest: time)
                    .observeOn(MainScheduler.instance)
                    .timeout(time, scheduler: MainScheduler.instance)
                    .subscribe(onNext: { model in
                        observer.onNext(model)
                        observer.onCompleted()
                    }, onError: { error in
                        observer.onError(error)
                    })
                
        }
    }
}



@Terens777

This comment has been minimized.

Copy link

Terens777 commented Aug 29, 2019

DataFetch it's my protocol for Api Calls Classes

@Terens777

This comment has been minimized.

Copy link

Terens777 commented Aug 29, 2019

I can’t quite understand how to create a token entity for TokenAcquisitionService (

@danielt1263

This comment has been minimized.

Copy link
Owner Author

danielt1263 commented Aug 29, 2019

I have no idea why your network API is so complex. URLSession.rx.response is all you should need and any function with the signature (URLRequest) -> Observable<(response: HTTPURLResponse, data: Data)> can serve as a proxy/mock.

As for getting a new token, any function with the signature (T) -> Observable<(response: HTTPURLResponse, data: Data)> will do. This function receives the current token T; it just needs to make a network request to get a new token and return an Observable which will emit the results. The extractToken function receives those results in order to find and return the actual token.

You can change the token service however you want to accommodate your api service, but if you don't know how to make a network request to get a token, then I suggest you talk to whoever is writing your server or read its documentation.

@Terens777

This comment has been minimized.

Copy link

Terens777 commented Aug 29, 2019

I use Alamofire for the network, I never used URLRequest.rx.response, I have a huge social network, there are a lot of parameters for the request, so it’s complicated) you have an example of how you use the token refresh with your any requests, I would really wanted to reconsider my work with api, I understand that this is not a good option, but I don’t know another, and I don’t know where to get it

@danielt1263

This comment has been minimized.

Copy link
Owner Author

danielt1263 commented Aug 29, 2019

Alamofire isn't something I use, but I'm sure it has some way to let you know when there is a 401 error. Find out what way that is and convert it to a TokenAcquisitionError.unauthorized error. I showed that in the custom code sample I gave you.

@Terens777

This comment has been minimized.

Copy link

Terens777 commented Aug 29, 2019

Thanks for your time, I will work on it

@guseducampos

This comment has been minimized.

Copy link

guseducampos commented Oct 2, 2019

Hi @danielt1263 great article and gist, thank you! Just found this in your post in medium:

share(replay: 1) ensures that any observers that subscribe to the token property will immediately get the most recently acquired token value so they can attempt the first request.

But reviewing your code, you aren't using share(replay: 1) anywhere, so it makes me wonder how actually the same token response is shared across multiple observers when an unauthorized event is fired for multiple requests?

@danielt1263

This comment has been minimized.

Copy link
Owner Author

danielt1263 commented Oct 2, 2019

Good catch @guseducampos.

I replaced the share(replay: 1) with a ReplaySubject in order to add the mutable setToken at someone's request and didn't update the article.

@marksands

This comment has been minimized.

Copy link

marksands commented Oct 28, 2019

I tried to Combine-ify this, but it's missing a few key operators. Namely, retryWhen is a no go here. If anyone has already done the brutal work and is willing to share—please do!

@danielt1263

This comment has been minimized.

Copy link
Owner Author

danielt1263 commented Oct 30, 2019

I tried to Combine-ify this, but it's missing a few key operators. Namely, retryWhen is a no go here. If anyone has already done the brutal work and is willing to share—please do!

I'm working on one... Care to help? https://gist.github.com/danielt1263/17ebe60a1c7d9aa87c8b5393639a079e

@cuibeihong

This comment has been minimized.

Copy link

cuibeihong commented Nov 15, 2019

@danielt1263 hi , i use moya , my network request method:

init(endpointClosure: @escaping MoyaProvider<YunguApi>.EndpointClosure = MoyaProvider<YunguApi>.defaultEndpointMapping,
      requestClosure: @escaping MoyaProvider<YunguApi>.RequestClosure = MoyaProvider<YunguApi>.defaultRequestMapping,
      stubClosure: @escaping MoyaProvider<YunguApi>.StubClosure = MoyaProvider<YunguApi>.neverStub,
      session: Session = MoyaProvider<YunguApi>.defaultAlamofireSession(),
      plugins: [PluginType] = [],
      trackInflights: Bool = false,
      online: Observable<Bool> = connectedToInternet()) {
   
   self.online = online
   self.provider = MoyaProvider(endpointClosure: endpointClosure, requestClosure: requestClosure, stubClosure: stubClosure, session: session, plugins: plugins, trackInflights: trackInflights) 
 }
 
 func request(_ token: YunguApi, isSecondTryAfterAuth: Bool = false) -> Observable<Moya.Response> {
   let actualRequest = provider.rx.request(token)
   return online
     .ignore(value: false)  // Wait until we're online
     .take(1)        // Take 1 to make sure we only invoke the API once.
     .flatMap { _ in // Turn the online state into a network request
       return actualRequest
         
   }
 }

I tried using it but it doesn't work
How can I transform it?
thank you

@danielt1263

This comment has been minimized.

Copy link
Owner Author

danielt1263 commented Nov 27, 2019

@marksands

I tried to Combine-ify this, but it's missing a few key operators. Namely, retryWhen is a no go here. If anyone has already done the brutal work and is willing to share—please do!

I was finally able to create a retryWhen operator: https://gist.github.com/danielt1263/17ebe60a1c7d9aa87c8b5393639a079e

@korovyev

This comment has been minimized.

Copy link

korovyev commented Dec 12, 2019

Am I right in saying the line:
.flatMapFirst { getToken($0) }
Will ignore any subsequent incoming .unauthorized requests while the first one is getting the new token?

To be a bit clearer:
Two unauthorised api calls are made in quick succession, the flapMapFirst gets a new token and as it ignores the second api call, only the first one will get reattempted with the new token.

Is this correct? I would hope the second api call would get reattempted with the new token too.

from the flatMapFirst docs:

If element is received while there is some projected observable sequence being merged it will simply be ignored.

Thank you for both this gist & the article by the way, I'm learning a lot from trying to implement it myself.

@danielt1263

This comment has been minimized.

Copy link
Owner Author

danielt1263 commented Dec 12, 2019

@korovyev Yes, the subsequent incoming .unauthorized events will be ignored while the system is waiting for the request to return a new token; however, the subsequent observable chains will sit waiting for an event to be emitted by the renewToken observable and the observable will emit to all of them once the new token arrives. Therefore, the second api call in your example does get reattempted with the new token.

@korovyev

This comment has been minimized.

Copy link

korovyev commented Dec 13, 2019

Yep, my implementation works as you describe. thanks again!

@mesheilah

This comment has been minimized.

Copy link

mesheilah commented Jan 21, 2020

is there any Combine version for this gist ?
got stuck at private let _token = ReplaySubject<T>.create(bufferSize: 1) the closest equivalent in Combine is CurrentValueSubject which requires an initial value to be set in its initializer which can't be accomplished with T (because we don't know what is the type T going to be).

@danielt1263

This comment has been minimized.

Copy link
Owner Author

danielt1263 commented Jan 21, 2020

@mesheilah That's easy to take care of, but Combine doesn't have a .retryWhen operator which is required for this solution. I've been working on creating the operator for Combine but I'm not satisfied with what I have so far.

@mesheilah

This comment has been minimized.

Copy link

mesheilah commented Jan 21, 2020

@mesheilah That's easy to take care of, but Combine doesn't have a .retryWhen operator which is required for this solution. I've been working on creating the operator for Combine but I'm not satisfied with what I have so far.

I've checked your custom retryWhen publisher but didn't take it for a spin yet

@EnasAhmedZaki

This comment has been minimized.

Copy link

EnasAhmedZaki commented Feb 23, 2020

Thanks for your time, I will work on it

@Terens777 have you found a solution? I'm using Alamofire too, and I'm having difficulty implementing it with Daniel's code.

@murad1981

This comment has been minimized.

Copy link

murad1981 commented Feb 23, 2020

Thanks for your time, I will work on it

@Terens777 have you found a solution, I'm using alamofire too and I'm having difficulties implement it with daniel code

I have implemented it with Alamofire without problems. You have to create the loadData method that takes pageIndex and returns an Observable of an array of result objects; this is the key point. Refreshing the token with Alamofire is easy and has nothing to do with the code above: create a type that conforms to Alamofire's RequestInterceptor protocol and implement both of its methods (adapt and retry). There are lots of examples of how to implement them.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.