Skip to content

Commit

Permalink
local-tun: Manage TCP socket state properly
Browse files Browse the repository at this point in the history
- close() only once when shutting down sender
- read() returns EOF when local client closed our receiver
- ref #1138
  • Loading branch information
zonyitoo committed Mar 12, 2023
1 parent 08bf30a commit 549ffd3
Showing 1 changed file with 98 additions and 36 deletions.
134 changes: 98 additions & 36 deletions crates/shadowsocks-service/src/local/tun/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,21 @@ use super::virt_device::VirtTunDevice;
const DEFAULT_TCP_SEND_BUFFER_SIZE: u32 = 0x3FFF * 20;
const DEFAULT_TCP_RECV_BUFFER_SIZE: u32 = 0x3FFF * 20;

#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum TcpSocketState {
Normal,
Close,
Closing,
Closed,
}

struct TcpSocketControl {
send_buffer: RingBuffer<'static, u8>,
send_waker: Option<Waker>,
recv_buffer: RingBuffer<'static, u8>,
recv_waker: Option<Waker>,
is_closed: bool,
recv_state: TcpSocketState,
send_state: TcpSocketState,
}

struct ManagerNotify {
Expand Down Expand Up @@ -89,7 +98,16 @@ struct TcpConnection {
impl Drop for TcpConnection {
fn drop(&mut self) {
let mut control = self.control.lock();
control.is_closed = true;

if matches!(control.recv_state, TcpSocketState::Normal) {
control.recv_state = TcpSocketState::Close;
}

if matches!(control.send_state, TcpSocketState::Normal) {
control.send_state = TcpSocketState::Close;
}

self.manager_notify.notify();
}
}

Expand All @@ -108,7 +126,8 @@ impl TcpConnection {
send_waker: None,
recv_buffer: RingBuffer::new(vec![0u8; recv_buffer_size as usize]),
recv_waker: None,
is_closed: false,
recv_state: TcpSocketState::Normal,
send_state: TcpSocketState::Normal,
}));

let _ = socket_creation_tx.send(TcpSocketCreation {
Expand All @@ -127,14 +146,13 @@ impl AsyncRead for TcpConnection {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
let mut control = self.control.lock();

// If socket is already closed, just return EOF directly.
if control.is_closed {
return Ok(()).into();
}

// Read from buffer

if control.recv_buffer.is_empty() {
// If socket is already closed / half closed, just return EOF directly.
if matches!(control.recv_state, TcpSocketState::Closed) {
return Ok(()).into();
}

// Nothing could be read. Wait for notify.
if let Some(old_waker) = control.recv_waker.replace(cx.waker().clone()) {
if !old_waker.will_wake(cx.waker()) {
Expand All @@ -159,7 +177,9 @@ impl AsyncRead for TcpConnection {
impl AsyncWrite for TcpConnection {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
let mut control = self.control.lock();
if control.is_closed {

// If state == Close | Closing | Closed, the TCP stream WR half is closed.
if !matches!(control.send_state, TcpSocketState::Normal) {
return Err(io::ErrorKind::BrokenPipe.into()).into();
}

Expand Down Expand Up @@ -190,17 +210,22 @@ impl AsyncWrite for TcpConnection {
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let mut control = self.control.lock();

if control.is_closed {
if matches!(control.send_state, TcpSocketState::Closed) {
return Ok(()).into();
}

control.is_closed = true;
// SHUT_WR
if matches!(control.send_state, TcpSocketState::Normal) {
control.send_state = TcpSocketState::Close;
}

if let Some(old_waker) = control.send_waker.replace(cx.waker().clone()) {
if !old_waker.will_wake(cx.waker()) {
old_waker.wake();
}
}

self.manager_notify.notify();
Poll::Pending
}
}
Expand All @@ -219,6 +244,7 @@ pub struct TcpTun {
impl Drop for TcpTun {
fn drop(&mut self) {
self.manager_running.store(false, Ordering::Relaxed);
self.manager_notify.notify();
let _ = self.manager_handle.take().unwrap().join();
}
}
Expand Down Expand Up @@ -297,33 +323,37 @@ impl TcpTun {
let socket = socket_set.get_mut::<TcpSocket>(socket_handle);
let mut control = control.lock();

#[inline]
fn close_socket_control(control: &mut TcpSocketControl) {
control.is_closed = true;
// Remove the socket only when it is in the closed state.
if socket.state() == TcpState::Closed {
sockets_to_remove.push(socket_handle);

control.send_state = TcpSocketState::Closed;
control.recv_state = TcpSocketState::Closed;

if let Some(waker) = control.send_waker.take() {
waker.wake();
}
if let Some(waker) = control.recv_waker.take() {
waker.wake();
}
}

if !socket.is_open() || socket.state() == TcpState::Closed {
sockets_to_remove.push(socket_handle);
close_socket_control(&mut control);
trace!("closed TCP connection");
continue;
}

if control.is_closed {
// Close the socket.
// SHUT_WR
if matches!(control.send_state, TcpSocketState::Close) {
trace!("closing TCP Write Half, {:?}", socket.state());

// Close the socket. Set to FIN state
socket.close();
// sockets_to_remove.push(socket_handle);
// close_socket_control(&mut *control);
continue;
control.send_state = TcpSocketState::Closing;

// We can still process the pending buffer.
}

// Check if readable
let mut has_received = false;
let mut wake_receiver = false;
while socket.can_recv() && !control.recv_buffer.is_full() {
let result = socket.recv(|buffer| {
let n = control.recv_buffer.enqueue_slice(buffer);
Expand All @@ -332,25 +362,49 @@ impl TcpTun {

match result {
Ok(..) => {
has_received = true;
wake_receiver = true;
}
Err(err) => {
error!("socket recv error: {:?}", err);
sockets_to_remove.push(socket_handle);
close_socket_control(&mut control);
error!("socket recv error: {:?}, {:?}", err, socket.state());

// Don't know why. Abort the connection.
socket.abort();

if matches!(control.recv_state, TcpSocketState::Normal) {
control.recv_state = TcpSocketState::Closed;
}
wake_receiver = true;

// The socket will be recycled in the next poll.
break;
}
}
}

if has_received && control.recv_waker.is_some() {
// If socket is not in ESTABLISH, FIN-WAIT-1, FIN-WAIT-2,
// the local client have closed our receiver.
if matches!(control.recv_state, TcpSocketState::Normal)
&& !socket.may_recv()
&& !matches!(
socket.state(),
TcpState::SynReceived | TcpState::Established | TcpState::FinWait1 | TcpState::FinWait2
)
{
trace!("closed TCP Read Half, {:?}", socket.state());

// Let TcpConnection::poll_read returns EOF.
control.recv_state = TcpSocketState::Closed;
wake_receiver = true;
}

if wake_receiver && control.recv_waker.is_some() {
if let Some(waker) = control.recv_waker.take() {
waker.wake();
}
}

// Check if writable
let mut has_sent = false;
let mut wake_sender = false;
while socket.can_send() && !control.send_buffer.is_empty() {
let result = socket.send(|buffer| {
let n = control.send_buffer.dequeue_slice(buffer);
Expand All @@ -359,18 +413,26 @@ impl TcpTun {

match result {
Ok(..) => {
has_sent = true;
wake_sender = true;
}
Err(err) => {
error!("socket send error: {:?}", err);
sockets_to_remove.push(socket_handle);
close_socket_control(&mut control);
error!("socket send error: {:?}, {:?}", err, socket.state());

// Don't know why. Abort the connection.
socket.abort();

if matches!(control.send_state, TcpSocketState::Normal) {
control.send_state = TcpSocketState::Closed;
}
wake_sender = true;

// The socket will be recycled in the next poll.
break;
}
}
}

if has_sent && control.send_waker.is_some() {
if wake_sender && control.send_waker.is_some() {
if let Some(waker) = control.send_waker.take() {
waker.wake();
}
Expand Down

0 comments on commit 549ffd3

Please sign in to comment.