Skip to content

Commit

Permalink
Don't hardcode HTTP protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasReschke committed Apr 23, 2023
1 parent 97c9503 commit e627693
Showing 1 changed file with 42 additions and 47 deletions.
89 changes: 42 additions & 47 deletions Request Ranger/Networking/ProxyHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,6 @@ import NIOHTTP1
import Logging
import Atomics

enum HttpProtocol {
case HTTP
case HTTPS
}

func convertToClientRequestPart(_ reqPart: HTTPServerRequestPart) -> HTTPClientRequestPart {
switch reqPart {
case .head(let head):
return .head(head)
case .body(let buffer):
return .body(.byteBuffer(buffer))
case .end(let headers):
return .end(headers)
}
}

final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equatable {
static func == (lhs: ProxyHandler, rhs: ProxyHandler) -> Bool {
return false
Expand All @@ -40,6 +24,7 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
private var logger: Logger
private var targetHost: String?
private var targetPort: Int?
private var targetProtocol: HttpProtocol?
private static let globalRequestID = ManagedAtomic<Int>(0) // FIXME: should initialize with latest saved ID
public var requestParts: [HTTPClientRequestPart] = []
private var waitingContext: ChannelHandlerContext?
Expand All @@ -62,11 +47,20 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
case connectRequested
}

public func forwardRequestForProtocol(httpProtocol: HttpProtocol, context: ChannelHandlerContext, requestParts: [HTTPClientRequestPart]) {
enum HttpProtocol {
case HTTP
case HTTPS
}

public func forwardRequestForProtocol(context: ChannelHandlerContext, requestParts: [HTTPClientRequestPart]) {
guard let port = self.targetPort else {
fatalError("Port was not passed")
}

guard let httpProtocol = self.targetProtocol else {
fatalError("Targer protocol was not passed")
}

if httpProtocol == .HTTPS {
forwardRequestForHttps(context: context, port: port, requestParts: requestParts)
} else {
Expand All @@ -81,18 +75,18 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
switch upgradeState {
case .idle:
let channelFuture = clientBootstrap.channelInitializer { channel in
let sslContext = try! NIOSSLContext(configuration: .makeClientConfiguration())
let sslHandler = try! NIOSSLClientHandler(context: sslContext, serverHostname: self.targetHost!)

return channel.pipeline.addHandler(sslHandler).flatMap {
channel.pipeline.addHandler(HTTPRequestEncoder()).flatMap {
channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)))
.flatMap {
channel.pipeline.addHandler(ResponseHandler(context: context, preUpgradedRequest: self.preUpgradedRequest, request: self.loggedRequest!))
}
}
let sslContext = try! NIOSSLContext(configuration: .makeClientConfiguration())
let sslHandler = try! NIOSSLClientHandler(context: sslContext, serverHostname: self.targetHost!)

return channel.pipeline.addHandler(sslHandler).flatMap {
channel.pipeline.addHandler(HTTPRequestEncoder()).flatMap {
channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)))
.flatMap {
channel.pipeline.addHandler(ResponseHandler(context: context, preUpgradedRequest: self.preUpgradedRequest, request: self.loggedRequest!))
}
}
}
}
.connect(host: targetHost!, port: port)

channelFuture.whenSuccess { channel in
Expand All @@ -117,13 +111,13 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
switch upgradeState {
case .idle:
let channelFuture = clientBootstrap.channelInitializer { channel in
channel.pipeline.addHandler(HTTPRequestEncoder()).flatMap {
channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)))
.flatMap {
return channel.pipeline.addHandler(ResponseHandler(context: context, preUpgradedRequest: self.preUpgradedRequest, request: self.loggedRequest!))
}
}
channel.pipeline.addHandler(HTTPRequestEncoder()).flatMap {
channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)))
.flatMap {
return channel.pipeline.addHandler(ResponseHandler(context: context, preUpgradedRequest: self.preUpgradedRequest, request: self.loggedRequest!))
}
}
}
.connect(host: targetHost!, port: port)

channelFuture.whenSuccess { channel in
Expand Down Expand Up @@ -206,7 +200,7 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
}
}

func handleRequest(context: ChannelHandlerContext, reqPart: HTTPServerRequestPart, httpProtocol: HttpProtocol, port: Int) {
func handleRequest(context: ChannelHandlerContext, reqPart: HTTPServerRequestPart, httpProtocol: HttpProtocol) {
switch(reqPart) {
case .head(var head):
// Remove the Accept-Encoding header
Expand All @@ -223,7 +217,7 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
requestParts.append(clientReqPart)
case .end(let headers):
requestParts.append(HTTPClientRequestPart.end(headers))

if AppState.shared.isInterceptEnabled {
waitingContext = context

Expand All @@ -238,20 +232,20 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
}
}
} else {
forwardRequestForProtocol(httpProtocol: httpProtocol, context: context, requestParts: requestParts)
forwardRequestForProtocol(context: context, requestParts: requestParts)
}
}
}

func channelReadForHttps(context: ChannelHandlerContext, data: NIOAny) {
let reqPart = self.unwrapInboundIn(data)
handleRequest(context: context, reqPart: reqPart, httpProtocol: .HTTPS, port: 443)
handleRequest(context: context, reqPart: reqPart, httpProtocol: .HTTPS)
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let reqPart = self.unwrapInboundIn(data)
print("Invoked channel read")

switch(upgradeState) {
case .connectRequested:
print("Connect return")
Expand Down Expand Up @@ -279,38 +273,40 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
}
self.targetHost = newHost
self.targetPort = originalURI.port ?? 80
self.targetProtocol = .HTTP

head.headers.replaceOrAdd(name: "Host", value: newHost)
head.uri = originalURI.relativePath

handleRequest(context: context, reqPart: reqPart, httpProtocol: .HTTP, port: 80)
handleRequest(context: context, reqPart: reqPart, httpProtocol: .HTTP)
}
default:
handleRequest(context: context, reqPart: reqPart, httpProtocol: .HTTP, port: 80)
handleRequest(context: context, reqPart: reqPart, httpProtocol: .HTTP)
}
}

private func handleConnectRequest(context: ChannelHandlerContext, head: inout HTTPRequestHead) {
let uriComponents = head.uri.split(separator: ":", maxSplits: 1, omittingEmptySubsequences: false)
self.targetHost = String(uriComponents.first!)
self.targetPort = Int(uriComponents.last!)
self.targetProtocol = .HTTPS

guard let targetHost = self.targetHost,
let _ = self.targetPort else {
sendHttpResponse(ctx: context, status: .badRequest)
context.close(promise: nil)
return
}

let selfSignedCertAndKey = CertificateManager.shared.certificateForDomain(String(targetHost))
let selfSignedRootCa = try! CertificateManager.shared.loadRootCAFromKeychain()

var serializer = DER.Serializer()
try! serializer.serialize(selfSignedCertAndKey!.certificate)

var selfSignedRootCaSerializer = DER.Serializer()
try! selfSignedRootCaSerializer.serialize(selfSignedRootCa.rootCertificate)


let certificate = try! NIOSSLCertificate(bytes: serializer.serializedBytes, format: .der)
let privateKey = try! NIOSSLPrivateKey(bytes: [UInt8](selfSignedCertAndKey!.privateKey.derRepresentation), format: .der)
Expand All @@ -319,13 +315,13 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
let rootCert = NIOSSLCertificateSource.certificate(rootCertificate)
let tlsConfiguration = TLSConfiguration.makeServerConfiguration(certificateChain: [serverCert, rootCert], privateKey: .privateKey(privateKey))
let sslContext = try! NIOSSLContext(configuration: tlsConfiguration)

let sslHandler = NIOSSLServerHandler(context: sslContext)

let responseHead = HTTPResponseHead(version: .http1_1, status: .ok)
let responsePart = HTTPServerResponsePart.head(responseHead)
self.preUpgradedRequest = context

context.writeAndFlush(self.wrapOutboundOut(responsePart))
.flatMap { _ in
context.pipeline.removeHandler(name: "HTTPResponseEncoder")
Expand Down Expand Up @@ -386,8 +382,7 @@ final class ProxyHandler: ChannelInboundHandler, RemovableChannelHandler, Equata
let updatedRequestParts = parseRawRequest(rawRequest: rawRequest)

// Forward the updated request
// FIXME: don't hardcode used protocol
forwardRequestForProtocol(httpProtocol: .HTTP, context: waitingContext!, requestParts: updatedRequestParts)
forwardRequestForProtocol(context: waitingContext!, requestParts: updatedRequestParts)
}

private func parseRawRequest(rawRequest: String) -> [HTTPClientRequestPart] {
Expand Down

0 comments on commit e627693

Please sign in to comment.