Skip to content

Commit

Permalink
fix: fixed dllama-api bug with java's HttpURLConnection (#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
jkeegan authored Feb 6, 2025
1 parent fde35c9 commit 9f04161
Showing 1 changed file with 66 additions and 21 deletions.
87 changes: 66 additions & 21 deletions src/apps/dllama-api/dllama-api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,35 +102,80 @@ class HttpRequest {
}

std::vector<char> readHttpRequest() {
std::vector<char> httpRequest;
char buffer[1024 * 1024]; // TODO: this should be refactored asap
std::string httpRequest;
char buffer[1024 * 64];
ssize_t bytesRead;

// Peek into the socket buffer to check available data
bytesRead = recv(serverSocket, buffer, sizeof(buffer), MSG_PEEK);
if (bytesRead <= 0) {
// No data available or error occurred
if (bytesRead == 0) {
// No more data to read
return httpRequest;
} else {
// Error while peeking
throw std::runtime_error("Error while peeking into socket");
// First, read all headers
std::string headerData;
size_t headerEnd;
bool headerDone = false;
std::string extraReadPastHeader;
while (!headerDone) {
bytesRead = recv(serverSocket, buffer, sizeof(buffer) - 1, 0);
if (bytesRead <= 0) {
throw std::runtime_error("Error while reading headers from socket");
}
buffer[bytesRead] = '\0';
headerData.append(buffer);

// Check for end of headers (http spec says "\r\n\r\n")
headerEnd = headerData.find("\r\n\r\n");
if (headerEnd != std::string::npos) {
headerDone = true;
if (headerEnd < headerData.size()-4) {
// We read something past the header
extraReadPastHeader = headerData.substr(headerEnd+4);
}
}
}

// Resize buffer according to the amount of data available
std::vector<char> peekBuffer(bytesRead);
bytesRead = recv(serverSocket, peekBuffer.data(), bytesRead, 0);
if (bytesRead <= 0) {
// Error while reading
throw std::runtime_error("Error while reading from socket");
httpRequest.append(headerData);

// Next, find Content-Length header for body length
std::istringstream headerStream(headerData);
std::string line;
ssize_t contentLength = 0;
while (std::getline(headerStream, line) && line != "\r") {
size_t pos = line.find(':');
if (pos != std::string::npos) {
std::string key = line.substr(0, pos);
std::string value = line.substr(pos + 2); // Skip ': ' after key
if (key == "Content-Length") {
try {
contentLength = std::stoi(value); // stoi ignores any whitespace
} catch (const std::invalid_argument& e) {
throw std::runtime_error("Bad Content-Length header - not a number");
}
break;
}
}
}

// Append data to httpRequest
httpRequest.insert(httpRequest.end(), peekBuffer.begin(), peekBuffer.end());
// Now read the full content body
if (contentLength > 0) {
// If we read any extra past the header before, read that much less now
// But first, sanity check to make sure Content-Length isn't lying and there is actually more
if (extraReadPastHeader.size() > static_cast<size_t>(contentLength)) {
throw std::runtime_error("Received more body data than Content-Length header said");
}
contentLength -= extraReadPastHeader.size();

std::vector<char> body(contentLength);
ssize_t totalRead = 0;
while (totalRead < contentLength) {
bytesRead = recv(serverSocket, body.data() + totalRead, contentLength - totalRead, 0);
if (bytesRead <= 0) {
throw std::runtime_error("Error while reading body from socket");
}
totalRead += bytesRead;
}
if (body.size() > 0) {
httpRequest.append(body.data(), contentLength);
}
}

return httpRequest;
return std::vector<char>(httpRequest.begin(), httpRequest.end());
}

std::string getMethod() {
Expand Down

0 comments on commit 9f04161

Please sign in to comment.