diff --git a/Request Ranger/Networking/ProxyHandler.swift b/Request Ranger/Networking/ProxyHandler.swift index 6fce56a..3fbff47 100644 --- a/Request Ranger/Networking/ProxyHandler.swift +++ b/Request Ranger/Networking/ProxyHandler.swift @@ -11,22 +11,6 @@ import NIOHTTP1 import Logging import Atomics -enum HttpProtocol { - case HTTP - case HTTPS -} - -func convertToClientRequestPart(_ reqPart: HTTPServerRequestPart) -> HTTPClientRequestPart { - switch reqPart { - case .head(let head): - return .head(head) - case .body(let buffer): - return .body(.byteBuffer(buffer)) - case .end(let headers): - return .end(headers) - } -} - final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equatable { static func == (lhs: ProxyHandler, rhs: ProxyHandler) -> Bool { return false @@ -40,6 +24,7 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata private var logger: Logger private var targetHost: String? private var targetPort: Int? + private var targetProtocol: HttpProtocol? private static let globalRequestID = ManagedAtomic(0) // FIXME: should initialize with latest saved ID public var requestParts: [HTTPClientRequestPart] = [] private var waitingContext: ChannelHandlerContext? @@ -62,11 +47,20 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata case connectRequested } - public func forwardRequestForProtocol(httpProtocol: HttpProtocol, context: ChannelHandlerContext, requestParts: [HTTPClientRequestPart]) { + enum HttpProtocol { + case HTTP + case HTTPS + } + + public func forwardRequestForProtocol(context: ChannelHandlerContext, requestParts: [HTTPClientRequestPart]) { guard let port = self.targetPort else { fatalError("Port was not passed") } + guard let httpProtocol = self.targetProtocol else { + fatalError("Targer protocol was not passed") + } + if httpProtocol == .HTTPS { forwardRequestForHttps(context: context, port: port, requestParts: requestParts) } else { @@ -81,18 +75,18 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata switch upgradeState { case .idle: let channelFuture = clientBootstrap.channelInitializer { channel in - let sslContext = try! NIOSSLContext(configuration: .makeClientConfiguration()) - let sslHandler = try! NIOSSLClientHandler(context: sslContext, serverHostname: self.targetHost!) - - return channel.pipeline.addHandler(sslHandler).flatMap { - channel.pipeline.addHandler(HTTPRequestEncoder()).flatMap { - channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes))) - .flatMap { - channel.pipeline.addHandler(ResponseHandler(context: context, preUpgradedRequest: self.preUpgradedRequest, request: self.loggedRequest!)) - } - } + let sslContext = try! NIOSSLContext(configuration: .makeClientConfiguration()) + let sslHandler = try! NIOSSLClientHandler(context: sslContext, serverHostname: self.targetHost!) + + return channel.pipeline.addHandler(sslHandler).flatMap { + channel.pipeline.addHandler(HTTPRequestEncoder()).flatMap { + channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes))) + .flatMap { + channel.pipeline.addHandler(ResponseHandler(context: context, preUpgradedRequest: self.preUpgradedRequest, request: self.loggedRequest!)) + } } } + } .connect(host: targetHost!, port: port) channelFuture.whenSuccess { channel in @@ -117,13 +111,13 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata switch upgradeState { case .idle: let channelFuture = clientBootstrap.channelInitializer { channel in - channel.pipeline.addHandler(HTTPRequestEncoder()).flatMap { - channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes))) - .flatMap { - return channel.pipeline.addHandler(ResponseHandler(context: context, preUpgradedRequest: self.preUpgradedRequest, request: self.loggedRequest!)) - } - } + channel.pipeline.addHandler(HTTPRequestEncoder()).flatMap { + channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes))) + .flatMap { + return channel.pipeline.addHandler(ResponseHandler(context: context, preUpgradedRequest: self.preUpgradedRequest, request: self.loggedRequest!)) + } } + } .connect(host: targetHost!, port: port) channelFuture.whenSuccess { channel in @@ -206,7 +200,7 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata } } - func handleRequest(context: ChannelHandlerContext, reqPart: HTTPServerRequestPart, httpProtocol: HttpProtocol, port: Int) { + func handleRequest(context: ChannelHandlerContext, reqPart: HTTPServerRequestPart, httpProtocol: HttpProtocol) { switch(reqPart) { case .head(var head): // Remove the Accept-Encoding header @@ -223,7 +217,7 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata requestParts.append(clientReqPart) case .end(let headers): requestParts.append(HTTPClientRequestPart.end(headers)) - + if AppState.shared.isInterceptEnabled { waitingContext = context @@ -238,20 +232,20 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata } } } else { - forwardRequestForProtocol(httpProtocol: httpProtocol, context: context, requestParts: requestParts) + forwardRequestForProtocol(context: context, requestParts: requestParts) } } } func channelReadForHttps(context: ChannelHandlerContext, data: NIOAny) { let reqPart = self.unwrapInboundIn(data) - handleRequest(context: context, reqPart: reqPart, httpProtocol: .HTTPS, port: 443) + handleRequest(context: context, reqPart: reqPart, httpProtocol: .HTTPS) } func channelRead(context: ChannelHandlerContext, data: NIOAny) { let reqPart = self.unwrapInboundIn(data) print("Invoked channel read") - + switch(upgradeState) { case .connectRequested: print("Connect return") @@ -279,14 +273,15 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata } self.targetHost = newHost self.targetPort = originalURI.port ?? 80 + self.targetProtocol = .HTTP head.headers.replaceOrAdd(name: "Host", value: newHost) head.uri = originalURI.relativePath - handleRequest(context: context, reqPart: reqPart, httpProtocol: .HTTP, port: 80) + handleRequest(context: context, reqPart: reqPart, httpProtocol: .HTTP) } default: - handleRequest(context: context, reqPart: reqPart, httpProtocol: .HTTP, port: 80) + handleRequest(context: context, reqPart: reqPart, httpProtocol: .HTTP) } } @@ -294,6 +289,7 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata let uriComponents = head.uri.split(separator: ":", maxSplits: 1, omittingEmptySubsequences: false) self.targetHost = String(uriComponents.first!) self.targetPort = Int(uriComponents.last!) + self.targetProtocol = .HTTPS guard let targetHost = self.targetHost, let _ = self.targetPort else { @@ -301,16 +297,16 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata context.close(promise: nil) return } - + let selfSignedCertAndKey = CertificateManager.shared.certificateForDomain(String(targetHost)) let selfSignedRootCa = try! CertificateManager.shared.loadRootCAFromKeychain() var serializer = DER.Serializer() try! serializer.serialize(selfSignedCertAndKey!.certificate) - + var selfSignedRootCaSerializer = DER.Serializer() try! selfSignedRootCaSerializer.serialize(selfSignedRootCa.rootCertificate) - + let certificate = try! NIOSSLCertificate(bytes: serializer.serializedBytes, format: .der) let privateKey = try! NIOSSLPrivateKey(bytes: [UInt8](selfSignedCertAndKey!.privateKey.derRepresentation), format: .der) @@ -319,13 +315,13 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata let rootCert = NIOSSLCertificateSource.certificate(rootCertificate) let tlsConfiguration = TLSConfiguration.makeServerConfiguration(certificateChain: [serverCert, rootCert], privateKey: .privateKey(privateKey)) let sslContext = try! NIOSSLContext(configuration: tlsConfiguration) - + let sslHandler = NIOSSLServerHandler(context: sslContext) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) let responsePart = HTTPServerResponsePart.head(responseHead) self.preUpgradedRequest = context - + context.writeAndFlush(self.wrapOutboundOut(responsePart)) .flatMap { _ in context.pipeline.removeHandler(name: "HTTPResponseEncoder") @@ -386,8 +382,7 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata let updatedRequestParts = parseRawRequest(rawRequest: rawRequest) // Forward the updated request - // FIXME: don't hardcode used protocol - forwardRequestForProtocol(httpProtocol: .HTTP, context: waitingContext!, requestParts: updatedRequestParts) + forwardRequestForProtocol(context: waitingContext!, requestParts: updatedRequestParts) } private func parseRawRequest(rawRequest: String) -> [HTTPClientRequestPart] {