Skip to content

Commit

Permalink
Clean up error handling and add "PathTooLong" test
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcusTomlinson committed Jan 9, 2024
1 parent 13f6b16 commit f84e961
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 34 deletions.
26 changes: 13 additions & 13 deletions src/IpcClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include <IpcCommon.h>

#include <iostream>
#include <mutex>

using namespace Ipc;
Expand Down Expand Up @@ -85,7 +84,6 @@ class ClientImpl
#endif
{
result.clear();
std::cerr << "recv failed (error: " << errorCode << ")" << std::endl;
}
}
} while ( recvResult > 0 );
Expand All @@ -98,6 +96,7 @@ class ClientImpl
return result;
}

std::string initError;
std::string socketPath = "";
sockaddr_un socketAddr;
unsigned char recvBytes[c_recvBufferSize] = {};
Expand All @@ -114,14 +113,14 @@ Client::Client( const std::filesystem::path& socketPath )
WSADATA wsd;
if ( WSAStartup( WINSOCK_VERSION, &wsd ) != 0 )
{
std::cerr << "WSAStartup failed" << std::endl;
p->initError = "WSAStartup() failed";
return;
}
#endif

if ( p->socketPath.length() > sizeof( sockaddr_un::sun_path ) )
{
std::cerr << "socket path too long: " << socketPath << std::endl;
p->initError = "socket path too long: " + p->socketPath;
return;
}

Expand All @@ -140,6 +139,11 @@ Client::~Client()

Message Client::Send( const Message& header, const Message& message )
{
if ( !p->initError.empty() )
{
return Message( p->initError, true );
}

std::lock_guard<std::mutex> lock( p->sendMutex );

if ( header.Size() == 0 )
Expand All @@ -154,7 +158,7 @@ Message Client::Send( const Message& header, const Message& message )
SOCKET clientSocket = socket( AF_UNIX, SOCK_STREAM, PF_UNSPEC );
if ( clientSocket == INVALID_SOCKET )
{
return Message( "socket failed", true );
return Message( "socket() failed (error: " + std::to_string( lastError() ) + ")", true );
}

#ifdef _WIN32
Expand All @@ -172,34 +176,30 @@ Message Client::Send( const Message& header, const Message& message )
if ( connect( clientSocket, reinterpret_cast<const sockaddr*>( &p->socketAddr ), sizeof( p->socketAddr ) ) ==
SOCKET_ERROR )
{
std::cerr << "connect failed (error: " << lastError() << ")" << std::endl;
closesocket( clientSocket );
return Message( "connect failed", true );
return Message( "connect() failed (error: " + std::to_string( lastError() ) + ")", true );
}

// Send header data
if ( !p->Send( clientSocket, header ) )
{
std::cerr << "send header failed (error: " << lastError() << ")" << std::endl;
closesocket( clientSocket );
return Message( "send header failed", true );
return Message( "header send() failed (error: " + std::to_string( lastError() ) + ")", true );
}

// Receive ack
auto recvBytes = p->Receive( clientSocket );
if ( recvBytes.empty() || recvBytes[0] != 1 )
{
std::cerr << "receive ack failed (error: " << lastError() << ")" << std::endl;
closesocket( clientSocket );
return Message( "receive ack failed", true );
return Message( "ack recv() failed (error: " + std::to_string( lastError() ) + ")", true );
}

// Send message data
if ( !p->Send( clientSocket, message ) )
{
std::cerr << "send message failed (error: " << lastError() << ")" << std::endl;
closesocket( clientSocket );
return Message( "send message failed", true );
return Message( "message send() failed (error: " + std::to_string( lastError() ) + ")", true );
}

// Receive some data
Expand Down
37 changes: 19 additions & 18 deletions src/IpcServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <IpcCommon.h>
#include <IpcMessage.h>

#include <iostream>

using namespace Ipc;

namespace Ipc::Private
Expand All @@ -51,22 +49,22 @@ class ServerImpl final
WSADATA wsd;
if ( WSAStartup( WINSOCK_VERSION, &wsd ) != 0 )
{
std::cerr << "WSAStartup failed" << std::endl;
initError = "WSAStartup() failed";
return;
}
#endif

if ( socketPath.length() > sizeof( sockaddr_un::sun_path ) )
{
std::cerr << "socket path too long: " << socketPath << std::endl;
initError = "socket path too long: " + socketPath;
return;
}

// Create a AF_UNIX stream server socket
serverSocket = socket( AF_UNIX, SOCK_STREAM, 0 );
if ( serverSocket == INVALID_SOCKET )
{
std::cerr << "socket failed (error: " << lastError() << ")" << std::endl;
initError = "socket() failed (error: " + std::to_string( lastError() ) + ")";
return;
}

Expand All @@ -93,7 +91,7 @@ class ServerImpl final
if ( bind( serverSocket, reinterpret_cast<const sockaddr*>( &socketAddr ), sizeof( socketAddr ) ) ==
SOCKET_ERROR )
{
std::cerr << "bind failed (error: " << lastError() << ")" << std::endl;
initError = "bind() failed (error: " + std::to_string( lastError() ) + ")";
closesocket( serverSocket );
serverSocket = INVALID_SOCKET;
return;
Expand All @@ -102,13 +100,11 @@ class ServerImpl final
// Listen to start accepting connections
if ( listen( serverSocket, SOMAXCONN ) == SOCKET_ERROR )
{
std::cerr << "listen failed (error: " << lastError() << ")" << std::endl;
initError = "listen() failed (error: " + std::to_string( lastError() ) + ")";
closesocket( serverSocket );
serverSocket = INVALID_SOCKET;
return;
}

std::cout << "Accepting connections" << std::endl;
}

~ServerImpl()
Expand All @@ -128,7 +124,7 @@ class ServerImpl final
{
if ( serverSocket == INVALID_SOCKET )
{
std::cerr << "Server socket is invalid" << std::endl;
callback( Message( "", true ), Message( initError, true ) );
return false;
}

Expand All @@ -143,13 +139,16 @@ class ServerImpl final

if ( select( (int)serverSocket + 1, &fd, nullptr, nullptr, &timeout ) <= 0 )
{
callback( Message( "", true ),
Message( "select() failed (error: " + std::to_string( lastError() ) + ")", true ) );
return false;
}

SOCKET clientSocket = accept( serverSocket, NULL, NULL );
if ( clientSocket == INVALID_SOCKET )
{
std::cerr << "accept failed (error: " << lastError() << ")" << std::endl;
callback( Message( "", true ),
Message( "accept() failed (error: " + std::to_string( lastError() ) + ")", true ) );
return false;
}

Expand All @@ -170,7 +169,8 @@ class ServerImpl final
{
if ( lastError() != EINVAL )
{
std::cerr << "recieve header failed (error: " << lastError() << ")" << std::endl;
callback( Message( "", true ),
Message( "header recv() failed (error: " + std::to_string( lastError() ) + ")", true ) );
}
closesocket( clientSocket );
return false;
Expand All @@ -179,7 +179,8 @@ class ServerImpl final
// Send ack
if ( !Send( clientSocket, std::vector<unsigned char>{ 1 } ) )
{
std::cerr << "send ack failed (error: " << lastError() << ")" << std::endl;
callback( Message( "", true ),
Message( "ack send() failed (error: " + std::to_string( lastError() ) + ")", true ) );
closesocket( clientSocket );
return false;
}
Expand All @@ -188,7 +189,8 @@ class ServerImpl final
auto recvMessageBytes = Receive( clientSocket );
if ( recvMessageBytes.empty() )
{
std::cerr << "recieve message failed (error: " << lastError() << ")" << std::endl;
callback( Message( "", true ),
Message( "message recv() failed (error: " + std::to_string( lastError() ) + ")", true ) );
closesocket( clientSocket );
return false;
}
Expand All @@ -197,7 +199,8 @@ class ServerImpl final
auto sendMessage = callback( recvHeaderBytes, recvMessageBytes );
if ( !Send( clientSocket, sendMessage ) )
{
std::cerr << "send response failed (error: " << lastError() << ")" << std::endl;
callback( Message( "", true ),
Message( "response send() failed (error: " + std::to_string( lastError() ) + ")", true ) );
closesocket( clientSocket );
return false;
}
Expand All @@ -211,14 +214,12 @@ class ServerImpl final
SOCKET clientSocket = socket( AF_UNIX, SOCK_STREAM, PF_UNSPEC );
if ( clientSocket == INVALID_SOCKET )
{
std::cerr << "socket failed (error: " << lastError() << ")" << std::endl;
return false;
}

if ( connect( clientSocket, reinterpret_cast<const sockaddr*>( &socketAddr ), sizeof( socketAddr ) ) ==
SOCKET_ERROR )
{
std::cerr << "connect failed (error: " << lastError() << ")" << std::endl;
closesocket( clientSocket );
return false;
}
Expand Down Expand Up @@ -266,7 +267,6 @@ class ServerImpl final
#endif
{
result.clear();
std::cerr << "recv failed (error: " << errCode << ")" << std::endl;
}
}
} while ( recvResult > 0 );
Expand All @@ -279,6 +279,7 @@ class ServerImpl final
return result;
}

std::string initError;
SOCKET serverSocket = INVALID_SOCKET;
std::string socketPath = "";
sockaddr_un socketAddr;
Expand Down
36 changes: 33 additions & 3 deletions tests/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Ipc::Message RecvCallback( const Ipc::Message& recvHeader, const Ipc::Message& r
return std::string( "Unix Domain Sockets!" );
}

TEST( Client, SameProcess )
TEST( Ipc, SameProcess )
{
Ipc::Server server( c_serverSocket );
auto listenThread = std::thread(
Expand All @@ -73,7 +73,7 @@ TEST( Client, SameProcess )
listenThread.join();
}

TEST( Simple, SeparateProcess )
TEST( Ipc, SeparateProcess )
{
std::promise<void> ready;
std::thread(
Expand Down Expand Up @@ -112,7 +112,7 @@ TEST( Simple, SeparateProcess )
testing::ExitedWithCode( 0 ), "" );
}

TEST( Server, StopListening )
TEST( Ipc, StopListening )
{
Ipc::Server server( c_serverSocket );
auto listenThread = std::thread( [&server] { ASSERT_FALSE( server.Listen( RecvCallback ) ); } );
Expand All @@ -122,6 +122,36 @@ TEST( Server, StopListening )
listenThread.join();
}

TEST( Ipc, PathTooLong )
{
auto longPath =
"really/really/really/really/really/really/really/really/really/really/really/really/really/really/really/"
"really/really/really/really/really/really/really/really/really/really/really/really/really/really/long/path";

// Server
Ipc::Server server( longPath );

bool callbackCalled = false;

ASSERT_FALSE( server.Listen(
[&longPath, &callbackCalled]( const Ipc::Message& header, const Ipc::Message& message ) -> Ipc::Message
{
callbackCalled = true;
EXPECT_TRUE( header.IsError() );
EXPECT_EQ( header.AsString(), "" );
EXPECT_TRUE( message.IsError() );
EXPECT_EQ( message.AsString(), std::string( "socket path too long: " ) + longPath );
return Ipc::Message( "" );
} ) );

ASSERT_TRUE( callbackCalled );

// Client
Ipc::Client client( longPath );
auto response = client.Send( Ipc::Message( "" ), Ipc::Message( "" ) );
ASSERT_EQ( response.AsString(), std::string( "socket path too long: " ) + longPath );
}

int main( int argc, char** argv )
{
::testing::InitGoogleTest( &argc, argv );
Expand Down

0 comments on commit f84e961

Please sign in to comment.