Skip to content

Commit

Permalink
KTOR-8136 Introduce ServerSocket.port to simplify port access for the…
Browse files Browse the repository at this point in the history
… bound server
  • Loading branch information
e5l committed Jan 31, 2025
1 parent 55106bc commit dd99e7c
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 10 deletions.
22 changes: 22 additions & 0 deletions ktor-network/common/src/io/ktor/network/sockets/SocketAddress.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,30 @@

package io.ktor.network.sockets

/**
* Represents a socket address abstraction.
*
* This sealed class serves as the base type for different kinds of socket addresses,
* such as Internet-specific or other platform-dependent address types.
* Implementations of this class are expected to be platform-specific and provide
* details necessary to work with socket connections or bindings.
*/
public expect sealed class SocketAddress

/**
* Retrieves the port number associated with this socket address.
*
* If the `SocketAddress` instance is of type `InetSocketAddress`, the associated port is returned.
* Otherwise, an `UnsupportedOperationException` is thrown, as the provided address type does not support ports.
*
* @return the port number of the socket address if available.
* @throws UnsupportedOperationException if the socket address type does not support a port.
*/
public fun SocketAddress.port(): Int = when (this) {
is InetSocketAddress -> port
else -> throw UnsupportedOperationException("SocketAddress $this does not have a port")
}

public expect class InetSocketAddress(
hostname: String,
port: Int
Expand Down
9 changes: 8 additions & 1 deletion ktor-network/common/src/io/ktor/network/sockets/Sockets.kt
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,17 @@ public fun AWritable.openWriteChannel(autoFlush: Boolean = false): ByteWriteChan
public interface Socket : ReadWriteSocket, ABoundSocket, AConnectedSocket, CoroutineScope

/**
* Represents a server bound socket ready for accepting connections
* Represents a server-bound socket ready for accepting connections
*/
public interface ServerSocket : ASocket, ABoundSocket, Acceptable<Socket>

/**
* The port number of the current server.
*
* @throws UnsupportedOperationException if the local socket address does not support a port.
*/
public val ServerSocket.port: Int get() = localAddress.port()

public expect class SocketTimeoutException(message: String) : IOException

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ class TCPSocketTest {
@Test
fun testEcho() = testSockets { selector ->
val tcp = aSocket(selector).tcp()
val server = tcp.bind("127.0.0.1", 8000)
val server: ServerSocket = tcp.bind("127.0.0.1", port = 0)

val serverConnectionPromise = async {
server.accept()
}

val clientConnection = tcp.connect("127.0.0.1", 8000)
val clientConnection = tcp.connect("127.0.0.1", port = server.port)
val serverConnection = serverConnectionPromise.await()

val clientOutput = clientConnection.openWriteChannel()
Expand Down Expand Up @@ -115,8 +115,7 @@ class TCPSocketTest {
server.accept()
}

val port = (server.localAddress as InetSocketAddress).port
val client: Socket = tcp.connect("127.0.0.1", port)
val client: Socket = tcp.connect("127.0.0.1", server.port)
val readChannel = client.openReadChannel()
serverConnection.await()

Expand All @@ -130,29 +129,33 @@ class TCPSocketTest {

@Test
fun testConnectToNonExistingSocket() = testSockets(timeout = 10.seconds) { selector ->
val tcp = aSocket(selector).tcp()
val server = tcp.bind("127.0.0.1")
server.close()

assertFailsWith<IOException> {
aSocket(selector)
.tcp()
.connect("127.0.0.1", 8001) // there should be no server active on this port
.connect("127.0.0.1", server.port) // trying to connect to a port that was available but now closed
}
}

@Test
fun testDisconnect() = testSockets { selector ->
val tcp = aSocket(selector).tcp()
val server = tcp.bind("127.0.0.1", 8003)
val server: ServerSocket = tcp.bind("127.0.0.1", port = 0)

val serverConnectionPromise = async {
server.accept()
}

val clientConnection = tcp.connect("127.0.0.1", 8003)
val clientConnection = tcp.connect("127.0.0.1", port = server.port)
val serverConnection = serverConnectionPromise.await()

val serverInput = serverConnection.openReadChannel()

// Need to make sure reading from server is done first, which will suspend because there is nothing to read.
// Then close the connection from client side, which should cancel the reading because the socket disconnected.
// Need to make sure reading from the server is done first, which will suspend because there is nothing to read.
// Then close the connection from the client side, which should cancel the reading because the socket is disconnected.
launch {
delay(100)
clientConnection.close()
Expand Down

0 comments on commit dd99e7c

Please sign in to comment.