Skip to content

Commit

Permalink
When processing forwarded HTTP request, don't process until we have received the whole body
Browse files Browse the repository at this point in the history

It might happen that the proxied request arrives in two chunks - first
just the headers and then the body. In that case we would have sent just
the headers with an empty body through the pipeline which might trigger
errors like "400 Client Error", then the body would arrive in the next
chunk.

This patch changes the logic to keep the whole body in the buffer until
it is complete and can be processed by the pipeline and sent off to the
server.

Fixes: stacklok#517
  • Loading branch information
jhrozek committed Jan 9, 2025
1 parent 7fa31a8 commit d4c1f3c
Showing 1 changed file with 83 additions and 2 deletions.
85 changes: 83 additions & 2 deletions src/codegate/providers/copilot/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,82 @@ async def _forward_data_to_target(self, data: bytes) -> None:
pipeline_output = pipeline_output.reconstruct()
self.target_transport.write(pipeline_output)

def _has_complete_body(self) -> bool:
"""
Check if we have received the complete request body based on Content-Length header.
We check the headers from the buffer instead of using self.request.headers on purpose
because with CONNECT requests, the whole request arrives in the data and is stored in
the buffer.
"""
try:
# For the initial CONNECT request
if not self.headers_parsed and self.request and self.request.method == "CONNECT":
return True

# For subsequent requests or non-CONNECT requests, parse the method from the buffer
try:
first_line = self.buffer[: self.buffer.index(b"\r\n")].decode("utf-8")
method = first_line.split()[0]
except (ValueError, IndexError):
# Haven't received the complete request line yet
return False

if method != "POST": # do we need to check for other methods? PUT?
return True

# Parse headers from the buffer instead of using self.request.headers
headers_dict = {}
try:
headers_end = self.buffer.index(b"\r\n\r\n")
if headers_end <= 0: # Ensure we have a valid headers section
return False

headers = self.buffer[:headers_end].split(b"\r\n")
if len(headers) <= 1: # Ensure we have headers after the request line
return False

for header in headers[1:]: # Skip the request line
if not header: # Skip empty lines
continue
try:
name, value = header.decode("utf-8").split(":", 1)
headers_dict[name.strip().lower()] = value.strip()
except ValueError:
# Skip malformed headers
continue
except ValueError:
# Haven't received the complete headers yet
return False

# TODO: Add proper support for chunked transfer encoding
# For now, just pass through and let the pipeline handle it
if "transfer-encoding" in headers_dict:
return True

try:
content_length = int(headers_dict.get("content-length"))
except (ValueError, TypeError):
# Content-Length header is required for POST requests without chunked encoding
logger.error("Missing or invalid Content-Length header in POST request")
return False

body_start = headers_end + 4 # Add safety check for buffer length
if body_start >= len(self.buffer):
return False

current_body_length = len(self.buffer) - body_start
return current_body_length >= content_length
except Exception as e:
logger.error(f"Error checking body completion: {e}")
return False

def data_received(self, data: bytes) -> None:
"""Handle received data from client"""
"""
Handle received data from client. Since we need to process the complete body
through our pipeline before forwarding, we accumulate the entire request first.
"""
logger.info(f"Received data from {self.peername}: {data}")
try:
if not self._check_buffer_size(data):
self.send_error_response(413, b"Request body too large")
Expand All @@ -370,10 +444,17 @@ def data_received(self, data: bytes) -> None:
if self.headers_parsed:
if self.request.method == "CONNECT":
self.handle_connect()
self.buffer.clear()
else:
# Only process the request once we have the complete body
asyncio.create_task(self.handle_http_request())
else:
asyncio.create_task(self._forward_data_to_target(data))
if self._has_complete_body():
# Process the complete request through the pipeline
complete_request = bytes(self.buffer)
logger.debug(f"Complete request: {complete_request}")
self.buffer.clear()
asyncio.create_task(self._forward_data_to_target(complete_request))

except Exception as e:
logger.error(f"Error processing received data: {e}")
Expand Down

0 comments on commit d4c1f3c

Please sign in to comment.