Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the HTTP client #530

Merged
merged 3 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,17 @@ All notable changes to this project will be documented in this file. Take a look
#### Shared

* The default `ZIPArchiveOpener` is now using ZIPFoundation instead of Minizip, with improved performances when reading ranges of `stored` ZIP entries.
* Improvements in the HTTP client:
* The `consume` closure of `HTTPClient.stream()` can now return an error to abort the HTTP request.
* `HTTPError` has been refactored for improved type safety and a clearer separation of connection errors versus HTTP errors.
* `DefaultHTTPClient` no longer automatically restarts a failed `HEAD` request as a `GET` to retrieve the response body. If you relied on this behavior, you can implement it using a custom `DefaultHTTPClientDelegate.httpClient(_:recoverRequest:fromError:)`.

### Fixed

#### Shared

* Fixed a crash using `HTTPClient.download()` when the device storage is full.

#### OPDS

* Fixed a data race in the OPDS 1 parser.
Expand Down
9 changes: 8 additions & 1 deletion Sources/Adapters/GCDWebServer/GCDHTTPServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,14 @@ public class GCDHTTPServer: HTTPServer, Loggable {
log(.warning, "Resource not found for request \(request)")
completion(
HTTPServerRequest(url: url, href: nil),
HTTPServerResponse(error: .notFound),
HTTPServerResponse(error: .errorResponse(HTTPResponse(
request: HTTPRequest(url: url),
url: url,
status: .notFound,
headers: [:],
mediaType: nil,
body: nil
))),
nil
)
}
Expand Down
4 changes: 2 additions & 2 deletions Sources/LCP/LCPError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public enum RenewError: Error {
// Incorrect renewal period, your publication could not be renewed.
case invalidRenewalPeriod(maxRenewDate: Date?)
// An unexpected error has occurred on the licensing server.
case unexpectedServerError
case unexpectedServerError(HTTPError)
}

/// Errors while returning a loan.
Expand All @@ -96,7 +96,7 @@ public enum ReturnError: Error {
// Your publication has already been returned before or is expired.
case alreadyReturnedOrExpired
// An unexpected error has occurred on the licensing server.
case unexpectedServerError
case unexpectedServerError(HTTPError)
}

/// Errors while parsing the License or Status JSON Documents.
Expand Down
34 changes: 22 additions & 12 deletions Sources/LCP/License/License.swift
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,18 @@ extension License: LCPLicense {
return try await httpClient.fetch(HTTPRequest(url: url, method: .put))
.map { $0.body ?? Data() }
.mapError { error -> RenewError in
switch error.kind {
case .badRequest:
return .renewFailed
case .forbidden:
return .invalidRenewalPeriod(maxRenewDate: self.maxRenewDate)
switch error {
case let .errorResponse(response):
switch response.status {
case .badRequest:
return .renewFailed
case .forbidden:
return .invalidRenewalPeriod(maxRenewDate: self.maxRenewDate)
default:
return .unexpectedServerError(error)
}
default:
return .unexpectedServerError
return .unexpectedServerError(error)
}
}
.get()
Expand Down Expand Up @@ -260,13 +265,18 @@ extension License: LCPLicense {
do {
let data = try await httpClient.fetch(HTTPRequest(url: url, method: .put))
.mapError { error -> ReturnError in
switch error.kind {
case .badRequest:
return .returnFailed
case .forbidden:
return .alreadyReturnedOrExpired
switch error {
case let .errorResponse(response):
switch response.status {
case .badRequest:
return .returnFailed
case .forbidden:
return .alreadyReturnedOrExpired
default:
return .unexpectedServerError(error)
}
default:
return .unexpectedServerError
return .unexpectedServerError(error)
}
}
.map { $0.body ?? Data() }
Expand Down
7 changes: 1 addition & 6 deletions Sources/LCP/Services/DeviceService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,7 @@ final class DeviceService {
}

let data = await httpClient.fetch(HTTPRequest(url: url, method: .post))
.map { response -> Data? in
guard 100 ..< 400 ~= response.statusCode else {
return nil
}
return response.body
}
.map(\.body)

try await repository.registerDevice(for: license.id)

Expand Down
130 changes: 80 additions & 50 deletions Sources/Shared/Toolkit/HTTP/DefaultHTTPClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ public final class DefaultHTTPClient: HTTPClient, Loggable {

public func stream(
request: any HTTPRequestConvertible,
consume: @escaping (Data, Double?) -> Void
consume: @escaping (Data, Double?) -> HTTPResult<Void>
) async -> HTTPResult<HTTPResponse> {
await request.httpRequest()
.asyncFlatMap(willStartRequest)
Expand Down Expand Up @@ -321,11 +321,7 @@ public final class DefaultHTTPClient: HTTPClient, Loggable {
typealias Continuation = CheckedContinuation<HTTPResult<HTTPResponse>, Never>
typealias ReceiveResponse = (HTTPResponse) -> Void
typealias ReceiveChallenge = (URLAuthenticationChallenge) async -> URLAuthenticationChallengeResponse
typealias Consume = (Data, Double?) -> Void

enum TaskError: Error {
case byteRangesNotSupported(url: HTTPURL)
}
typealias Consume = (Data, Double?) -> HTTPResult<Void>

private let request: HTTPRequest
fileprivate let task: URLSessionTask
Expand All @@ -339,13 +335,20 @@ public final class DefaultHTTPClient: HTTPClient, Loggable {
private enum State {
/// Waiting to start the task.
case initializing

/// Waiting for the HTTP response.
case start(continuation: Continuation)
/// We received a success response, the data will be sent to `consume` progressively.

/// We received a success response, the data will be sent to
/// `consume` progressively.
case stream(continuation: Continuation, response: HTTPResponse, readBytes: Int64)
/// We received an error response, the data will be accumulated in `response.body` to make the final
/// `HTTPError`. The body is needed for example when the response is an OPDS Authentication Document.
case failure(continuation: Continuation, kind: HTTPError.Kind, cause: Error?, response: HTTPResponse?)

/// We received an error response, the data will be accumulated in
/// `response.body` if the error is an `HTTPError.errorResponse`, as
/// it could be needed for example when the response is an OPDS
/// Authentication Document.
case failure(continuation: Continuation, error: HTTPError)

/// The request is terminated.
case finished

Expand All @@ -357,7 +360,7 @@ public final class DefaultHTTPClient: HTTPClient, Loggable {
return continuation
case let .stream(continuation, _, _):
return continuation
case let .failure(continuation, _, _, _):
case let .failure(continuation, _):
return continuation
}
}
Expand Down Expand Up @@ -394,14 +397,15 @@ public final class DefaultHTTPClient: HTTPClient, Loggable {
private func finish() {
switch state {
case let .start(continuation):
continuation.resume(returning: .failure(HTTPError(kind: .cancelled)))
continuation.resume(returning: .failure(.cancelled))

case let .stream(continuation, response, _):
continuation.resume(returning: .success(response))

case let .failure(continuation, kind, cause, response):
let error = HTTPError(kind: kind, cause: cause, response: response)
log(.error, "\(request.method) \(request.url) failed with: \(error.localizedDescription)")
case let .failure(continuation, error):
var errorDescription = ""
dump(error, to: &errorDescription)
log(.error, "\(request.method) \(request.url) failed with:\n\(errorDescription)")
continuation.resume(returning: .failure(error))

case .initializing, .finished:
Expand All @@ -427,34 +431,22 @@ public final class DefaultHTTPClient: HTTPClient, Loggable {

var response = HTTPResponse(request: request, response: urlResponse, url: url)

if let kind = HTTPError.Kind(statusCode: response.statusCode) {
state = .failure(continuation: continuation, kind: kind, cause: nil, response: response)

// It was a HEAD request? We need to query the resource again to get the error body. The body is needed
// for example when the response is an OPDS Authentication Document.
if request.method == .head {
var modifiedRequest = request
modifiedRequest.method = .get
session.dataTask(with: modifiedRequest.urlRequest) { data, _, error in
response.body = data
self.state = .failure(continuation: continuation, kind: kind, cause: error, response: response)
completionHandler(.cancel)
}.resume()
return
}

} else {
guard !request.hasHeader("Range") || response.acceptsByteRanges else {
log(.error, "Streaming ranges requires the remote HTTP server to support byte range requests: \(url)")
state = .failure(continuation: continuation, kind: .other, cause: TaskError.byteRangesNotSupported(url: url), response: response)
completionHandler(.cancel)
return
}
guard response.status.isSuccess else {
state = .failure(continuation: continuation, error: .errorResponse(response))
completionHandler(.allow)
return
}

state = .stream(continuation: continuation, response: response, readBytes: 0)
receiveResponse(response)
guard !request.hasHeader("Range") || response.acceptsByteRanges else {
log(.error, "Streaming ranges requires the remote HTTP server to support byte range requests: \(url)")
state = .failure(continuation: continuation, error: .rangeNotSupported)
completionHandler(.cancel)
return
}

state = .stream(continuation: continuation, response: response, readBytes: 0)
receiveResponse(response)

completionHandler(.allow)
}

Expand All @@ -469,14 +461,23 @@ public final class DefaultHTTPClient: HTTPClient, Loggable {
if let expectedBytes = response.contentLength {
progress = Double(min(readBytes, expectedBytes)) / Double(expectedBytes)
}
consume(data, progress)
state = .stream(continuation: continuation, response: response, readBytes: readBytes)

case .failure(let continuation, let kind, let cause, var response):
var body = response?.body ?? Data()
body.append(data)
response?.body = body
state = .failure(continuation: continuation, kind: kind, cause: cause, response: response)

switch consume(data, progress) {
case .success:
state = .stream(continuation: continuation, response: response, readBytes: readBytes)
case let .failure(error):
state = .failure(continuation: continuation, error: error)
}

case .failure(let continuation, var error):
if case var .errorResponse(response) = error {
var body = response.body ?? Data()
body.append(data)
response.body = body
error = .errorResponse(response)
}

state = .failure(continuation: continuation, error: error)
}
}

Expand All @@ -485,7 +486,7 @@ public final class DefaultHTTPClient: HTTPClient, Loggable {
if case .failure = state {
// No-op, we don't want to overwrite the failure state in this case.
} else if let continuation = state.continuation {
state = .failure(continuation: continuation, kind: HTTPError.Kind(error: error), cause: error, response: nil)
state = .failure(continuation: continuation, error: HTTPError(error: error))
} else {
state = .finished
}
Expand All @@ -511,6 +512,35 @@ public final class DefaultHTTPClient: HTTPClient, Loggable {
}
}

private extension HTTPError {
/// Maps a native `URLError` to `HTTPError`.
init(error: Error) {
switch error {
case let error as URLError:
switch error.code {
case .httpTooManyRedirects, .redirectToNonExistentLocation:
self = .redirection(error)
case .secureConnectionFailed, .clientCertificateRejected, .clientCertificateRequired, .appTransportSecurityRequiresSecureConnection, .userAuthenticationRequired:
self = .security(error)
case .badServerResponse, .zeroByteResource, .cannotDecodeContentData, .cannotDecodeRawData, .dataLengthExceedsMaximum:
self = .malformedResponse(error)
case .notConnectedToInternet, .networkConnectionLost:
self = .offline(error)
case .cannotConnectToHost, .cannotFindHost:
self = .unreachable(error)
case .timedOut:
self = .timeout(error)
case .cancelled, .userCancelledAuthentication:
self = .cancelled
default:
self = .other(error)
}
default:
self = .other(error)
}
}
}

private extension HTTPRequest {
var urlRequest: URLRequest {
var request = URLRequest(url: url.url)
Expand Down Expand Up @@ -545,7 +575,7 @@ private extension HTTPResponse {
self.init(
request: request,
url: url,
statusCode: response.statusCode,
status: HTTPStatus(rawValue: response.statusCode),
headers: headers,
mediaType: response.mimeType.flatMap { MediaType($0) },
body: body
Expand Down
Loading