Skip to content

Commit

Permalink
feat(net): add more async fun to tcp and udp
Browse files Browse the repository at this point in the history
  • Loading branch information
Stone749990226 committed Jul 21, 2024
1 parent a1ee2ac commit 61bc03f
Show file tree
Hide file tree
Showing 15 changed files with 373 additions and 233 deletions.
7 changes: 4 additions & 3 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
"./kernel/Cargo.toml",
],
"rust-analyzer.checkOnSave": true,
"rust-analyzer.files.excludeDirs": [
"third-party"
],
"files.watcherExclude": {
"**/third-party": true
},
// not work
"rust-analyzer.files.excludeDirs": [
"**/third-party/**"
],
}
13 changes: 2 additions & 11 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions kernel/src/net/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub trait ProtoOps: Sync + Send + Any + DowncastSync {
async fn recvfrom(&self, _buf: &mut [u8]) -> SysResult<(usize, SockAddr)> {
Err(SysError::EOPNOTSUPP)
}
fn poll(&self) -> NetPollState {
async fn poll(&self) -> NetPollState {
log::error!("[net poll] unimplemented");
NetPollState {
readable: false,
Expand Down Expand Up @@ -136,7 +136,8 @@ impl File for Socket {

async fn base_poll(&self, events: PollEvents) -> PollEvents {
let mut res = PollEvents::empty();
let netstate = self.sk.poll();
poll_interfaces();
let netstate = self.sk.poll().await;
if events.contains(PollEvents::IN) {
if netstate.readable {
res |= PollEvents::IN;
Expand Down
4 changes: 2 additions & 2 deletions kernel/src/net/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ impl ProtoOps for TcpSock {
Ok((bytes, peer_addr))
}

fn poll(&self) -> NetPollState {
self.tcp.poll()
async fn poll(&self) -> NetPollState {
self.tcp.poll().await
}

fn shutdown(&self, _how: SocketShutdownFlag) -> SysResult<()> {
Expand Down
4 changes: 2 additions & 2 deletions kernel/src/net/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ impl ProtoOps for UdpSock {
.map(|(len, addr)| (len, addr.into()))
}

fn poll(&self) -> NetPollState {
self.udp.poll()
async fn poll(&self) -> NetPollState {
self.udp.poll().await
}

fn shutdown(&self, _how: SocketShutdownFlag) -> SysResult<()> {
Expand Down
11 changes: 11 additions & 0 deletions kernel/src/syscall/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,17 @@ impl Syscall<'_> {
socket.sk.shutdown(how)?;
Ok(0)
}

pub fn sys_socketpair(
&self,
domain: usize,
types: usize,
protocol: usize,
sv: UserWritePtr<[u32; 2]>,
) -> SyscallResult {
log::error!("[sys_socketpair] unsupport syscall now");
Ok(0)
}
}

impl Task {
Expand Down
2 changes: 1 addition & 1 deletion kernel/src/task/tid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub struct TidAddress {
}

impl TidAddress {
pub fn new() -> Self {
pub const fn new() -> Self {
Self {
set_child_tid: None,
clear_child_tid: None,
Expand Down
1 change: 1 addition & 0 deletions modules/net/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ features = [
"socket-tcp",
"socket-dns",
"proto-ipv6",
"async",
# "fragmentation-buffer-size-65536", "proto-ipv4-fragmentation",
# "reassembly-buffer-size-65536", "reassembly-buffer-count-32",
# "assembler-max-segment-count-32",
Expand Down
30 changes: 29 additions & 1 deletion modules/net/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#![feature(new_uninit)]
extern crate alloc;
use alloc::{boxed::Box, vec, vec::Vec};
use core::{cell::RefCell, ops::DerefMut, panic};
use core::{cell::RefCell, future::Future, ops::DerefMut, panic};

use arch::time::get_time_us;
use device_core::{error::DevError, NetBufPtrOps, NetDriverOps};
Expand Down Expand Up @@ -122,6 +122,34 @@ impl<'a> SocketSetWrapper<'a> {
f(socket)
}

pub async fn with_socket_async<T: AnySocket<'a>, R, F, Fut>(
&self,
handle: SocketHandle,
f: F,
) -> R
where
F: FnOnce(&T) -> Fut,
Fut: Future<Output = R>,
{
let set = self.0.lock();
let socket = set.get(handle);
f(socket).await
}

pub async fn with_socket_mut_async<T: AnySocket<'a>, R, F, Fut>(
&self,
handle: SocketHandle,
f: F,
) -> R
where
F: FnOnce(&mut T) -> Fut,
Fut: Future<Output = R>,
{
let mut set = self.0.lock();
let socket = set.get_mut(handle);
f(socket).await
}

pub fn with_socket_mut<T: AnySocket<'a>, R, F>(&self, handle: SocketHandle, f: F) -> R
where
F: FnOnce(&mut T) -> R,
Expand Down
79 changes: 58 additions & 21 deletions modules/net/src/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use alloc::boxed::Box;
use core::{
cell::UnsafeCell,
future::Future,
net::SocketAddr,
sync::atomic::{AtomicBool, AtomicU8, Ordering},
};

use async_utils::yield_now;
use async_utils::{get_waker, suspend_now, yield_now};
use log::*;
use smoltcp::{
iface::SocketHandle,
Expand Down Expand Up @@ -176,8 +178,8 @@ impl TcpSocket {
if self.is_nonblocking() {
Err(SysError::EINPROGRESS)
} else {
self.block_on(|| {
let NetPollState { writable, .. } = self.poll_connect();
self.block_on_async(|| async {
let NetPollState { writable, .. } = self.poll_connect().await;
if !writable {
warn!("socket connect() failed: invalid state");
Err(SysError::EAGAIN)
Expand Down Expand Up @@ -255,6 +257,7 @@ impl TcpSocket {
// SAFETY: `self.local_addr` should be initialized after `bind()`.
let local_port = unsafe { self.local_addr.get().read().port };
self.block_on(|| {
// TODO: 这里waker还没有注册到Socket上,可能会丢失 waker
let (handle, (local_addr, peer_addr)) = LISTEN_TABLE.accept(local_port)?;
debug!("TCP socket accepted a new connection {}", peer_addr);
Ok(TcpSocket::new_connected(handle, local_addr, peer_addr))
Expand Down Expand Up @@ -368,10 +371,10 @@ impl TcpSocket {
}

/// Whether the socket is readable or writable.
pub fn poll(&self) -> NetPollState {
pub async fn poll(&self) -> NetPollState {
match self.get_state() {
STATE_CONNECTING => self.poll_connect(),
STATE_CONNECTED => self.poll_stream(),
STATE_CONNECTING => self.poll_connect().await,
STATE_CONNECTED => self.poll_stream().await,
STATE_LISTENING => self.poll_listener(),
_ => NetPollState {
readable: false,
Expand Down Expand Up @@ -464,12 +467,17 @@ impl TcpSocket {
///
/// Returning `true` indicates that the socket has entered a stable
/// state(connected or failed) and can proceed to the next step
fn poll_connect(&self) -> NetPollState {
async fn poll_connect(&self) -> NetPollState {
// SAFETY: `self.handle` should be initialized above.
let handle = unsafe { self.handle.get().read().unwrap() };
let writable =
SOCKET_SET.with_socket::<tcp::Socket, _, _>(handle, |socket| match socket.state() {
State::SynSent => false, // The connection request has been sent but no response
let waker = get_waker().await;
let writable = SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
match socket.state() {
State::SynSent => {
// The connection request has been sent but no response
socket.register_recv_waker(&waker);
false
}
// has been received yet
State::Established => {
self.set_state(STATE_CONNECTED); // connected
Expand All @@ -488,25 +496,32 @@ impl TcpSocket {
self.set_state(STATE_CLOSED); // connection failed
true
}
});
}
});
NetPollState {
readable: false,
writable,
}
}

fn poll_stream(&self) -> NetPollState {
async fn poll_stream(&self) -> NetPollState {
// SAFETY: `self.handle` should be initialized in a connected socket.
let handle = unsafe { self.handle.get().read().unwrap() };
SOCKET_SET.with_socket::<tcp::Socket, _, _>(handle, |socket| {
NetPollState {
// readable 本质上是是否应该继续阻塞,因此为 true 时的条件可以理解为:
// 1. 套接字已经关闭接收:在这种情况下,即使没有新数据到达,读取操作也不会阻塞,
// 因为读取会立即返回
// 2. 套接字中有数据可读:这是最常见的可读情况,表示可以从套接字中读取到数据
readable: !socket.may_recv() || socket.can_recv(),
writable: !socket.may_send() || socket.can_send(),
let waker = get_waker().await;
SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
// readable 本质上是是否应该继续阻塞,因此为 true 时的条件可以理解为:
// 1. 套接字已经关闭接收:在这种情况下,即使没有新数据到达,读取操作也不会阻塞,
// 因为读取会立即返回
// 2. 套接字中有数据可读:这是最常见的可读情况,表示可以从套接字中读取到数据
let readable = !socket.may_recv() || socket.can_recv();
let writable = !socket.may_send() || socket.can_send();
if !readable {
socket.register_recv_waker(&waker);
}
if !writable {
socket.register_send_waker(&waker);
}
NetPollState { readable, writable }
})
}

Expand Down Expand Up @@ -538,7 +553,29 @@ impl TcpSocket {
Err(SysError::EAGAIN) => {
// TODO:判断是否有信号

yield_now().await
suspend_now().await
}
Err(e) => return Err(e),
}
}
}
}

async fn block_on_async<F, T, Fut>(&self, mut f: F) -> SysResult<T>
where
F: FnMut() -> Fut,
Fut: Future<Output = SysResult<T>>,
{
if self.is_nonblocking() {
f().await
} else {
loop {
SOCKET_SET.poll_interfaces();
match f().await {
Ok(t) => return Ok(t),
Err(SysError::EAGAIN) => {
// TODO:判断是否有信号
suspend_now().await
}
Err(e) => return Err(e),
}
Expand Down
Loading

0 comments on commit 61bc03f

Please sign in to comment.