From c0ebec144591aa4038f210bcf5df0f37dee987c4 Mon Sep 17 00:00:00 2001 From: lhw Date: Mon, 2 Dec 2024 22:26:46 +0800 Subject: [PATCH 1/3] add unixsocket without real inode --- Cargo.lock | 19 +- api/ruxos_posix_api/build.rs | 2 + api/ruxos_posix_api/ctypes.h | 1 + api/ruxos_posix_api/src/imp/net.rs | 237 ++++++++- modules/ruxfs/src/root.rs | 9 +- modules/ruxnet/Cargo.toml | 5 + modules/ruxnet/src/lib.rs | 4 + modules/ruxnet/src/lwip_impl/tcp.rs | 6 + modules/ruxnet/src/smoltcp_impl/tcp.rs | 11 +- modules/ruxnet/src/unix.rs | 702 +++++++++++++++++++++++++ 10 files changed, 966 insertions(+), 30 deletions(-) create mode 100644 modules/ruxnet/src/unix.rs diff --git a/Cargo.lock b/Cargo.lock index 997c8ce9c..599b3a91a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -47,6 +47,12 @@ dependencies = [ "slab_allocator", ] +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "android-tzdata" version = "0.1.1" @@ -888,9 +894,13 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.14.0" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", + "allocator-api2", +] [[package]] name = "heapless" @@ -1758,17 +1768,22 @@ name = "ruxnet" version = "0.1.0" dependencies = [ "axerrno", + "axfs_vfs", "axio", "axlog", "axsync", "cfg-if", "cty", "driver_net", + "flatten_objects", + "hashbrown", "lazy_init", + "lazy_static", "log", "lwip_rust", "printf-compat", "ruxdriver", + "ruxfs", "ruxhal", "ruxtask", "smoltcp", diff --git a/api/ruxos_posix_api/build.rs b/api/ruxos_posix_api/build.rs index 1c72378ed..0529c64e6 100644 --- a/api/ruxos_posix_api/build.rs +++ b/api/ruxos_posix_api/build.rs @@ -111,7 +111,9 @@ typedef struct {{ let allow_vars = [ "O_.*", "AF_.*", + "SO_.*", "SOCK_.*", + "SOL_.*", "IPPROTO_.*", "FD_.*", "F_.*", diff --git a/api/ruxos_posix_api/ctypes.h b/api/ruxos_posix_api/ctypes.h index 6298ce6c3..1d43015d5 100644 --- a/api/ruxos_posix_api/ctypes.h +++ b/api/ruxos_posix_api/ctypes.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include diff --git a/api/ruxos_posix_api/src/imp/net.rs b/api/ruxos_posix_api/src/imp/net.rs index c284ad06f..7cafda5f4 100644 --- a/api/ruxos_posix_api/src/imp/net.rs +++ b/api/ruxos_posix_api/src/imp/net.rs @@ -11,19 +11,36 @@ use alloc::{sync::Arc, vec, vec::Vec}; use core::ffi::{c_char, c_int, c_void}; use core::mem::size_of; use core::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}; +use core::sync::atomic::AtomicIsize; use axerrno::{LinuxError, LinuxResult}; use axio::PollState; use axsync::Mutex; use ruxfdtable::{FileLike, RuxStat}; -use ruxnet::{TcpSocket, UdpSocket}; +use ruxnet::{SocketAddrUnix, TcpSocket, UdpSocket, UnixSocket, UnixSocketType}; +use ruxtask::fs::RUX_FILE_LIMIT; use crate::ctypes; use crate::utils::char_ptr_to_str; +fn addrun_convert(addr: *const ctypes::sockaddr_un) -> SocketAddrUnix { + unsafe { + SocketAddrUnix { + sun_family: (*addr).sun_family, + sun_path: (*addr).sun_path, + } + } +} + +pub enum UnifiedSocketAddress { + Net(SocketAddr), + Unix(SocketAddrUnix), +} + pub enum Socket { Udp(Mutex), Tcp(Mutex), + Unix(Mutex), } impl Socket { @@ -42,6 +59,7 @@ impl Socket { match self { Socket::Udp(udpsocket) => Ok(udpsocket.lock().send(buf)?), Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().send(buf)?), + Socket::Unix(socket) => Ok(socket.lock().send(buf)?), } } @@ -49,6 +67,7 @@ impl Socket { match self { Socket::Udp(udpsocket) => Ok(udpsocket.lock().recv_from(buf).map(|e| e.0)?), Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().recv(buf, flags)?), + Socket::Unix(socket) => Ok(socket.lock().recv(buf, flags)?), } } @@ -56,6 +75,7 @@ impl Socket { match self { Socket::Udp(udpsocket) => Ok(udpsocket.lock().poll()?), Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().poll()?), + Socket::Unix(socket) => Ok(socket.lock().poll()?), } } @@ -63,27 +83,73 @@ impl Socket { match self { Socket::Udp(udpsocket) => Ok(udpsocket.lock().local_addr()?), Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().local_addr()?), + Socket::Unix(_) => Err(LinuxError::EOPNOTSUPP), } } - fn peer_addr(&self) -> LinuxResult { + fn peer_addr(&self) -> LinuxResult { match self { - Socket::Udp(udpsocket) => Ok(udpsocket.lock().peer_addr()?), - Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().peer_addr()?), + Socket::Udp(udpsocket) => Ok(UnifiedSocketAddress::Net(udpsocket.lock().peer_addr()?)), + Socket::Tcp(tcpsocket) => Ok(UnifiedSocketAddress::Net(tcpsocket.lock().peer_addr()?)), + Socket::Unix(unixsocket) => { + Ok(UnifiedSocketAddress::Unix(unixsocket.lock().peer_addr()?)) + } } } - fn bind(&self, addr: SocketAddr) -> LinuxResult { + fn bind( + &self, + socket_addr: *const ctypes::sockaddr, + addrlen: ctypes::socklen_t, + ) -> LinuxResult { match self { - Socket::Udp(udpsocket) => Ok(udpsocket.lock().bind(addr)?), - Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().bind(addr)?), + Socket::Udp(udpsocket) => { + let addr = from_sockaddr(socket_addr, addrlen)?; + Ok(udpsocket.lock().bind(addr)?) + } + Socket::Tcp(tcpsocket) => { + let addr = from_sockaddr(socket_addr, addrlen)?; + Ok(tcpsocket.lock().bind(addr)?) + } + Socket::Unix(socket) => { + if socket_addr.is_null() { + return Err(LinuxError::EFAULT); + } + if addrlen != size_of::() as _ { + return Err(LinuxError::EINVAL); + } + Ok(socket + .lock() + .bind(addrun_convert(socket_addr as *const ctypes::sockaddr_un))?) + } } } - fn connect(&self, addr: SocketAddr) -> LinuxResult { + fn connect( + &self, + socket_addr: *const ctypes::sockaddr, + addrlen: ctypes::socklen_t, + ) -> LinuxResult { match self { - Socket::Udp(udpsocket) => Ok(udpsocket.lock().connect(addr)?), - Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().connect(addr)?), + Socket::Udp(udpsocket) => { + let addr = from_sockaddr(socket_addr, addrlen)?; + Ok(udpsocket.lock().connect(addr)?) + } + Socket::Tcp(tcpsocket) => { + let addr = from_sockaddr(socket_addr, addrlen)?; + Ok(tcpsocket.lock().connect(addr)?) + } + Socket::Unix(socket) => { + if socket_addr.is_null() { + return Err(LinuxError::EFAULT); + } + if addrlen != size_of::() as _ { + return Err(LinuxError::EINVAL); + } + Ok(socket + .lock() + .connect(addrun_convert(socket_addr as *const ctypes::sockaddr_un))?) + } } } @@ -92,6 +158,7 @@ impl Socket { // diff: must bind before sendto Socket::Udp(udpsocket) => Ok(udpsocket.lock().send_to(buf, addr)?), Socket::Tcp(_) => Err(LinuxError::EISCONN), + Socket::Unix(_) => Err(LinuxError::EISCONN), } } @@ -103,6 +170,7 @@ impl Socket { .recv_from(buf) .map(|res| (res.0, Some(res.1)))?), Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().recv(buf, 0).map(|res| (res, None))?), + Socket::Unix(socket) => Ok(socket.lock().recv(buf, 0).map(|res| (res, None))?), } } @@ -110,13 +178,15 @@ impl Socket { match self { Socket::Udp(_) => Err(LinuxError::EOPNOTSUPP), Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().listen()?), + Socket::Unix(socket) => Ok(socket.lock().listen()?), } } - fn accept(&self) -> LinuxResult { + fn accept(&self) -> LinuxResult { match self { Socket::Udp(_) => Err(LinuxError::EOPNOTSUPP), - Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().accept()?), + Socket::Tcp(tcpsocket) => Ok(Socket::Tcp(Mutex::new(tcpsocket.lock().accept()?))), + Socket::Unix(unixsocket) => Ok(Socket::Unix(Mutex::new(unixsocket.lock().accept()?))), } } @@ -135,6 +205,12 @@ impl Socket { tcpsocket.shutdown()?; Ok(()) } + Socket::Unix(socket) => { + let socket = socket.lock(); + socket.peer_addr()?; + socket.shutdown()?; + Ok(()) + } } } } @@ -179,6 +255,7 @@ impl FileLike for Socket { match self { Socket::Udp(udpsocket) => udpsocket.lock().set_nonblocking(nonblock), Socket::Tcp(tcpsocket) => tcpsocket.lock().set_nonblocking(nonblock), + Socket::Unix(unixsocket) => unixsocket.lock().set_nonblocking(nonblock), } Ok(()) } @@ -199,6 +276,15 @@ impl From for ctypes::sockaddr_in { } } +impl From for ctypes::sockaddr_un { + fn from(addr: SocketAddrUnix) -> ctypes::sockaddr_un { + ctypes::sockaddr_un { + sun_family: addr.sun_family, + sun_path: addr.sun_path, + } + } +} + impl From for SocketAddrV4 { fn from(addr: ctypes::sockaddr_in) -> SocketAddrV4 { SocketAddrV4::new( @@ -208,8 +294,16 @@ impl From for SocketAddrV4 { } } +fn un_into_sockaddr(addr: SocketAddrUnix) -> (ctypes::sockaddr, ctypes::socklen_t) { + debug!("convert unixsocket address {:?} into ctypes sockaddr", addr); + ( + unsafe { *(&ctypes::sockaddr_un::from(addr) as *const _ as *const ctypes::sockaddr) }, + size_of::() as _, + ) +} + fn into_sockaddr(addr: SocketAddr) -> (ctypes::sockaddr, ctypes::socklen_t) { - debug!(" Sockaddr: {}", addr); + debug!("convert socket address {} into ctypes sockaddr", addr); match addr { SocketAddr::V4(addr) => ( unsafe { *(&ctypes::sockaddr_in::from(addr) as *const _ as *const ctypes::sockaddr) }, @@ -262,6 +356,10 @@ pub fn sys_socket(domain: c_int, socktype: c_int, protocol: c_int) -> c_int { tcp_socket.set_nonblocking(true); Socket::Tcp(Mutex::new(tcp_socket)).add_to_fd_table() } + (ctypes::AF_UNIX, ctypes::SOCK_STREAM, 0) => { + Socket::Unix(Mutex::new(UnixSocket::new(UnixSocketType::SockStream))) + .add_to_fd_table() + } _ => Err(LinuxError::EINVAL), } }) @@ -297,8 +395,7 @@ pub fn sys_bind( socket_fd, socket_addr as usize, addrlen ); syscall_body!(sys_bind, { - let addr = from_sockaddr(socket_addr, addrlen)?; - Socket::from_fd(socket_fd)?.bind(addr)?; + Socket::from_fd(socket_fd)?.bind(socket_addr, addrlen)?; Ok(0) }) } @@ -316,8 +413,7 @@ pub fn sys_connect( socket_fd, socket_addr as usize, addrlen ); syscall_body!(sys_connect, { - let addr = from_sockaddr(socket_addr, addrlen)?; - Socket::from_fd(socket_fd)?.connect(addr)?; + Socket::from_fd(socket_fd)?.connect(socket_addr, addrlen)?; Ok(0) }) } @@ -464,10 +560,16 @@ pub unsafe fn sys_accept( let socket = Socket::from_fd(socket_fd)?; let new_socket = socket.accept()?; let addr = new_socket.peer_addr()?; - let new_fd = Socket::add_to_fd_table(Socket::Tcp(Mutex::new(new_socket)))?; - unsafe { - (*socket_addr, *socket_len) = into_sockaddr(addr); + let new_fd = Socket::add_to_fd_table(new_socket)?; + match addr { + UnifiedSocketAddress::Net(addr) => unsafe { + (*socket_addr, *socket_len) = into_sockaddr(addr); + }, + UnifiedSocketAddress::Unix(addr) => unsafe { + (*socket_addr, *socket_len) = un_into_sockaddr(addr); + }, } + Ok(new_fd) }) } @@ -601,6 +703,87 @@ pub unsafe fn sys_getsockname( }) } +/// get socket option +/// +/// TODO: some options not impl, just return 0, like SO_RCVBUF SO_SNDBUF +pub fn sys_getsockopt( + socket_fd: c_int, + level: c_int, + optname: c_int, + optval: *mut c_void, + optlen: *mut ctypes::socklen_t, +) -> c_int { + unsafe { + info!( + "sys_getsockopt <= fd: {}, level: {}, optname: {}, optlen: {}, IGNORED", + socket_fd, + level, + optname, + core::ptr::read(optlen as *mut usize) + ); + } + syscall_body!(sys_getsockopt, { + return Ok(0); + if optval.is_null() { + return Err(LinuxError::EFAULT); + } + let socket = Socket::from_fd(socket_fd)?; + match level as u32 { + ctypes::SOL_SOCKET => { + let val = match optname as u32 { + ctypes::SO_ACCEPTCONN => match &*socket { + Socket::Udp(_) => 0, + Socket::Tcp(tcpsocket) => { + if tcpsocket.lock().is_listening() { + 1 + } else { + 0 + } + } + Socket::Unix(unixsocket) => { + if unixsocket.lock().is_listening() { + 1 + } else { + 0 + } + } + }, + ctypes::SO_TYPE => match &*socket { + Socket::Udp(_) => ctypes::SOCK_DGRAM, + Socket::Tcp(_) => ctypes::SOCK_STREAM, + Socket::Unix(unixsocket) => match unixsocket.lock().get_sockettype() { + UnixSocketType::SockStream => ctypes::SOCK_STREAM, + UnixSocketType::SockDgram | UnixSocketType::SockSeqpacket => { + ctypes::SOCK_DGRAM + } + }, + }, + ctypes::SO_RCVLOWAT | ctypes::SO_SNDLOWAT | ctypes::SO_BROADCAST => 1, + ctypes::SO_ERROR + | ctypes::SO_DONTROUTE + | ctypes::SO_KEEPALIVE + | ctypes::SO_LINGER + | ctypes::SO_OOBINLINE + | ctypes::SO_RCVBUF + | ctypes::SO_RCVTIMEO + | ctypes::SO_REUSEADDR + | ctypes::SO_SNDBUF + | ctypes::SO_SNDTIMEO => 0, + _ => return Err(LinuxError::ENOPROTOOPT), + }; + + unsafe { + core::ptr::write(optlen as *mut usize, core::mem::size_of::()); + core::ptr::write(optval as *mut i32, val as i32); + } + + Ok(0) + } + _ => Err(LinuxError::ENOSYS), + } + }) +} + /// Get peer address to which the socket sockfd is connected. pub unsafe fn sys_getpeername( sock_fd: c_int, @@ -618,8 +801,14 @@ pub unsafe fn sys_getpeername( if unsafe { *addrlen } < size_of::() as u32 { return Err(LinuxError::EINVAL); } - unsafe { - (*addr, *addrlen) = into_sockaddr(Socket::from_fd(sock_fd)?.peer_addr()?); + let sockaddr = Socket::from_fd(sock_fd)?.peer_addr()?; + match sockaddr { + UnifiedSocketAddress::Net(netaddr) => unsafe { + (*addr, *addrlen) = into_sockaddr(netaddr); + }, + UnifiedSocketAddress::Unix(unixaddr) => unsafe { + (*addr, *addrlen) = un_into_sockaddr(unixaddr); + }, } Ok(0) }) @@ -658,6 +847,10 @@ pub unsafe fn sys_sendmsg( from_sockaddr(msg.msg_name as *const ctypes::sockaddr, msg.msg_namelen)?, )?, Socket::Tcp(tcpsocket) => tcpsocket.lock().send(buf)?, + Socket::Unix(unixsocket) => unixsocket.lock().sendto( + buf, + addrun_convert(msg.msg_name as *const ctypes::sockaddr_un), + )?, }; } Ok(ret) diff --git a/modules/ruxfs/src/root.rs b/modules/ruxfs/src/root.rs index 7ffb6a520..117dca0b3 100644 --- a/modules/ruxfs/src/root.rs +++ b/modules/ruxfs/src/root.rs @@ -23,6 +23,7 @@ pub struct MountPoint { pub fs: Arc, } +/// fs root directory pub struct RootDirectory { main_fs: Arc, mounts: Vec, @@ -44,6 +45,7 @@ impl Drop for MountPoint { } impl RootDirectory { + /// Creates a new `RootDirectory` with the specified main filesystem. pub const fn new(main_fs: Arc) -> Self { Self { main_fs, @@ -51,6 +53,7 @@ impl RootDirectory { } } + /// Mounts a new filesystem at the specified path within the root directory. pub fn mount(&mut self, path: &'static str, fs: Arc) -> AxResult { if path == "/" { return ax_err!(InvalidInput, "cannot mount root filesystem"); @@ -75,10 +78,12 @@ impl RootDirectory { Ok(()) } + /// Unmounts a filesystem at the specified path, if it exists. pub fn _umount(&mut self, path: &str) { self.mounts.retain(|mp| mp.path != path); } + /// Checks if a given path is a mount point in the root directory. pub fn contains(&self, path: &str) -> bool { self.mounts.iter().any(|mp| mp.path == path) } @@ -156,6 +161,7 @@ impl VfsNodeOps for RootDirectory { } } +/// Looks up a node in the virtual file system by its path. pub fn lookup(dir: Option<&VfsNodeRef>, path: &str) -> AxResult { if path.is_empty() { return ax_err!(NotFound); @@ -168,7 +174,8 @@ pub fn lookup(dir: Option<&VfsNodeRef>, path: &str) -> AxResult { } } -pub(crate) fn create_file(dir: Option<&VfsNodeRef>, path: &str) -> AxResult { +/// Creates a file in the virtual file system at the specified path. +pub fn create_file(dir: Option<&VfsNodeRef>, path: &str) -> AxResult { if path.is_empty() { return ax_err!(NotFound); } else if path.ends_with('/') { diff --git a/modules/ruxnet/Cargo.toml b/modules/ruxnet/Cargo.toml index bfc19c5f5..0ec18fc25 100644 --- a/modules/ruxnet/Cargo.toml +++ b/modules/ruxnet/Cargo.toml @@ -16,18 +16,23 @@ smoltcp = [] default = ["smoltcp", "loopback"] [dependencies] +hashbrown = "0.14.5" log = "0.4" cfg-if = "1.0" spin = "0.9" driver_net = { path = "../../crates/driver_net" } +flatten_objects = { path = "../../crates/flatten_objects" } lazy_init = { path = "../../crates/lazy_init" } +lazy_static = { version = "1.4", features = ["spin_no_std"] } lwip_rust = { path = "../../crates/lwip_rust", optional = true } printf-compat = { version = "0.1", default-features = false, optional = true } axerrno = { path = "../../crates/axerrno" } +axfs_vfs = { path = "../../crates/axfs_vfs" } ruxhal = { path = "../ruxhal" } axsync = { path = "../axsync" } axlog = { path = "../axlog" } ruxtask = { path = "../ruxtask" } +ruxfs = { path = "../ruxfs" } ruxdriver = { path = "../ruxdriver", features = ["net"] } cty = { version = "0.2.2", optional = true } axio = { path = "../../crates/axio" } diff --git a/modules/ruxnet/src/lib.rs b/modules/ruxnet/src/lib.rs index f7304e199..c2b48311b 100644 --- a/modules/ruxnet/src/lib.rs +++ b/modules/ruxnet/src/lib.rs @@ -37,6 +37,9 @@ extern crate log; extern crate alloc; +mod unix; +pub use unix::{SocketAddrUnix, UnixSocket, UnixSocketType}; + cfg_if::cfg_if! { if #[cfg(feature = "lwip")] { mod lwip_impl; @@ -73,6 +76,7 @@ pub fn init_network(mut net_devs: AxDeviceContainer) { } } net_impl::init(); + unix::init_unix(); while !net_devs.is_empty() { let dev = net_devs.take_one().expect("No NIC device found!"); info!(" use NIC: {:?}", dev.device_name()); diff --git a/modules/ruxnet/src/lwip_impl/tcp.rs b/modules/ruxnet/src/lwip_impl/tcp.rs index 2c621f0eb..15a09961f 100644 --- a/modules/ruxnet/src/lwip_impl/tcp.rs +++ b/modules/ruxnet/src/lwip_impl/tcp.rs @@ -191,6 +191,12 @@ impl TcpSocket { } } + /// Returens if this socket is listening + #[inline] + pub fn is_listening(&self) -> bool { + unsafe { (*self.pcb.get()).state == tcp_state_LISTEN } + } + /// Returns whether this socket is in nonblocking mode. #[inline] pub fn is_nonblocking(&self) -> bool { diff --git a/modules/ruxnet/src/smoltcp_impl/tcp.rs b/modules/ruxnet/src/smoltcp_impl/tcp.rs index c20a2424c..2634b8e25 100644 --- a/modules/ruxnet/src/smoltcp_impl/tcp.rs +++ b/modules/ruxnet/src/smoltcp_impl/tcp.rs @@ -108,6 +108,12 @@ impl TcpSocket { } } + /// Returens if this socket is listening + #[inline] + pub fn is_listening(&self) -> bool { + self.get_state() == STATE_LISTENING + } + /// Returns whether this socket is in nonblocking mode. #[inline] pub fn is_nonblocking(&self) -> bool { @@ -432,11 +438,6 @@ impl TcpSocket { self.get_state() == STATE_CONNECTED } - #[inline] - fn is_listening(&self) -> bool { - self.get_state() == STATE_LISTENING - } - fn bound_endpoint(&self) -> AxResult { // SAFETY: no other threads can read or write `self.local_addr`. let local_addr = unsafe { self.local_addr.get().read() }; diff --git a/modules/ruxnet/src/unix.rs b/modules/ruxnet/src/unix.rs new file mode 100644 index 000000000..5436ab237 --- /dev/null +++ b/modules/ruxnet/src/unix.rs @@ -0,0 +1,702 @@ +/* Copyright (c) [2023] [Syswonder Community] +* [Ruxos] is licensed under Mulan PSL v2. +* You can use this software according to the terms and conditions of the Mulan PSL v2. +* You may obtain a copy of Mulan PSL v2 at: +* http://license.coscl.org.cn/MulanPSL2 +* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +* See the Mulan PSL v2 for more details. +*/ + +use alloc::{sync::Arc, vec}; +use axerrno::{ax_err, AxError, AxResult, LinuxError, LinuxResult}; +use axio::PollState; +use axsync::Mutex; +use core::ffi::{c_char, c_int}; +use core::net::SocketAddr; +use core::sync::atomic::{AtomicBool, Ordering}; +use spin::RwLock; + +use lazy_init::LazyInit; + +use smoltcp::socket::tcp::SocketBuffer; + +use hashbrown::HashMap; + +use ruxfs::root::{create_file, lookup}; +use ruxtask::yield_now; + +const SOCK_ADDR_UN_PATH_LEN: usize = 108; +const UNIX_SOCKET_BUFFER_SIZE: usize = 4096; + +/// rust form for ctype sockaddr_un +#[derive(Clone, Copy, Debug)] +pub struct SocketAddrUnix { + /// AF_UNIX + pub sun_family: u16, + /// socket path + pub sun_path: [c_char; SOCK_ADDR_UN_PATH_LEN], /* Pathname */ +} + +impl SocketAddrUnix { + /// Sets the socket address to the specified new address. + pub fn set_addr(&mut self, new_addr: &SocketAddrUnix) { + self.sun_family = new_addr.sun_family; + self.sun_path = new_addr.sun_path; + } +} + +//To avoid owner question of FDTABLE outside and UnixTable in this crate we split the unixsocket +struct UnixSocketInner<'a> { + pub addr: Mutex, + pub buf: SocketBuffer<'a>, + pub peer_socket: Option, + pub status: UnixSocketStatus, +} + +impl<'a> UnixSocketInner<'a> { + pub fn new() -> Self { + Self { + addr: Mutex::new(SocketAddrUnix { + sun_family: 1, //AF_UNIX + sun_path: [0; SOCK_ADDR_UN_PATH_LEN], + }), + buf: SocketBuffer::new(vec![0; 64 * 1024]), + peer_socket: None, + status: UnixSocketStatus::Closed, + } + } + + pub fn get_addr(&self) -> SocketAddrUnix { + self.addr.lock().clone() + } + + pub fn get_peersocket(&self) -> Option { + self.peer_socket + } + + pub fn set_peersocket(&mut self, peer: usize) { + self.peer_socket = Some(peer) + } + + pub fn get_state(&self) -> UnixSocketStatus { + self.status + } + + pub fn set_state(&mut self, state: UnixSocketStatus) { + self.status = state + } + + pub fn can_accept(&mut self) -> bool { + match self.status { + UnixSocketStatus::Listening => !self.buf.is_empty(), + _ => false, + } + } + + pub fn may_recv(&mut self) -> bool { + match self.status { + UnixSocketStatus::Connected => true, + //State::FinWait1 | State::FinWait2 => true, + _ if !self.buf.is_empty() => true, + _ => false, + } + } + + pub fn can_recv(&mut self) -> bool { + if !self.may_recv() { + return false; + } + + !self.buf.is_empty() + } + + pub fn may_send(&mut self) -> bool { + match self.status { + UnixSocketStatus::Connected => true, + //State::CloseWait => true, + _ => false, + } + } + + pub fn can_send(&mut self) -> bool { + self.may_send() + } +} + +/// unix domain socket. +pub struct UnixSocket { + sockethandle: Option, + unixsocket_type: UnixSocketType, + nonblock: AtomicBool, +} + +// now there is no real inode, this func is to check whether file exists +// TODO: if inode impl, this should return inode +fn get_inode(addr: SocketAddrUnix) -> AxResult { + let slice = unsafe { core::slice::from_raw_parts(addr.sun_path.as_ptr(), addr.sun_path.len()) }; + + let socket_path = unsafe { + core::ffi::CStr::from_ptr(slice.as_ptr()) + .to_str() + .expect("Invalid UTF-8 string") + }; + let _vfsnode = match lookup(None, socket_path) { + Ok(node) => node, + Err(_) => { + return Err(AxError::NotFound); + } + }; + + Err(AxError::Unsupported) +} + +fn create_socket_file(addr: SocketAddrUnix) -> AxResult { + let slice = unsafe { core::slice::from_raw_parts(addr.sun_path.as_ptr(), addr.sun_path.len()) }; + + let socket_path = unsafe { + core::ffi::CStr::from_ptr(slice.as_ptr()) + .to_str() + .expect("Invalid UTF-8 string") + }; + let _vfsnode = create_file(None, socket_path)?; + Err(AxError::Unsupported) +} + +struct HashMapWarpper<'a> { + inner: HashMap>>>, + index_allcator: Mutex, +} +impl<'a> HashMapWarpper<'a> { + pub fn new() -> Self { + Self { + inner: HashMap::new(), + index_allcator: Mutex::new(0), + } + } + pub fn find(&self, predicate: F) -> Option<(&usize, &Arc>>)> + where + F: Fn(&Arc>>) -> bool, + { + self.inner.iter().find(|(_k, v)| predicate(v)) + } + + pub fn add(&mut self, value: Arc>>) -> Option { + let index_allcator = self.index_allcator.get_mut(); + while self.inner.contains_key(index_allcator) { + *index_allcator += 1; + } + self.inner.insert(*index_allcator, value); + Some(*index_allcator) + } + + pub fn replace_handle(&mut self, old: usize, new: usize) -> Option { + if let Some(value) = self.inner.remove(&old) { + self.inner.insert(new, value); + } + Some(new) + } + + pub fn get(&self, id: usize) -> Option<&Arc>>> { + self.inner.get(&id) + } + + pub fn get_mut(&mut self, id: usize) -> Option<&mut Arc>>> { + self.inner.get_mut(&id) + } +} +static UNIX_TABLE: LazyInit> = LazyInit::new(); + +/// unix socket type +#[derive(Debug, Clone, Copy)] +pub enum UnixSocketType { + /// A stream-oriented Unix domain socket. + SockStream, + /// A datagram-oriented Unix domain socket. + SockDgram, + /// A sequenced packet Unix domain socket. + SockSeqpacket, +} + +// State transitions: +// CLOSED -(connect)-> BUSY -> CONNECTING -> CONNECTED -(shutdown)-> BUSY -> CLOSED +// | +// |-(listen)-> BUSY -> LISTENING -(shutdown)-> BUSY -> CLOSED +// | +// -(bind)-> BUSY -> CLOSED +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum UnixSocketStatus { + Closed, + Busy, + Connecting, + Connected, + Listening, +} + +impl UnixSocket { + /// create a new socket + /// only support sock_stream + pub fn new(_type: UnixSocketType) -> Self { + match _type { + UnixSocketType::SockDgram | UnixSocketType::SockSeqpacket => unimplemented!(), + UnixSocketType::SockStream => { + let mut unixsocket = UnixSocket { + sockethandle: None, + unixsocket_type: _type, + nonblock: AtomicBool::new(false), + }; + let handle = UNIX_TABLE + .write() + .add(Arc::new(Mutex::new(UnixSocketInner::new()))) + .unwrap(); + unixsocket.set_sockethandle(handle); + unixsocket + } + } + } + + /// Sets the socket handle. + pub fn set_sockethandle(&mut self, fd: usize) { + self.sockethandle = Some(fd); + } + + /// Returns the socket handle. + pub fn get_sockethandle(&self) -> usize { + self.sockethandle.unwrap() + } + + /// Returns the peer socket handle, if available. + pub fn get_peerhandle(&self) -> Option { + UNIX_TABLE + .read() + .get(self.get_sockethandle()) + .unwrap() + .lock() + .get_peersocket() + } + + /// Returns the current state of the socket. + pub fn get_state(&self) -> UnixSocketStatus { + UNIX_TABLE + .read() + .get(self.get_sockethandle()) + .unwrap() + .lock() + .status + } + + /// Enqueues data into the socket buffer. + /// returns the number of bytes enqueued, or an error if the socket is closed. + pub fn enqueue_buf(&mut self, data: &[u8]) -> AxResult { + match self.get_state() { + UnixSocketStatus::Closed => Err(AxError::BadState), + _ => Ok(UNIX_TABLE + .write() + .get_mut(self.get_sockethandle()) + .unwrap() + .lock() + .buf + .enqueue_slice(data)), + } + } + + /// Dequeues data from the socket buffer. + /// return the number of bytes dequeued, or a BadState error if the socket is closed or a WouldBlock error if buffer is empty. + pub fn dequeue_buf(&mut self, data: &mut [u8]) -> AxResult { + match self.get_state() { + UnixSocketStatus::Closed => Err(AxError::BadState), + _ => { + if UNIX_TABLE + .write() + .get_mut(self.get_sockethandle()) + .unwrap() + .lock() + .buf + .is_empty() + { + return Err(AxError::WouldBlock); + } + Ok(UNIX_TABLE + .write() + .get_mut(self.get_sockethandle()) + .unwrap() + .lock() + .buf + .dequeue_slice(data)) + } + } + } + + /// Binds the socket to a specified address, get inode number of the address as handle + // TODO: bind to file system + pub fn bind(&mut self, addr: SocketAddrUnix) -> LinuxResult { + let now_state = self.get_state(); + match now_state { + UnixSocketStatus::Closed => { + { + match get_inode(addr) { + Ok(inode_addr) => { + UNIX_TABLE + .write() + .replace_handle(self.get_sockethandle(), inode_addr); + self.set_sockethandle(inode_addr); + } + Err(AxError::NotFound) => match create_socket_file(addr) { + Ok(inode_addr) => { + UNIX_TABLE + .write() + .replace_handle(self.get_sockethandle(), inode_addr); + self.set_sockethandle(inode_addr); + } + _ => { + warn!("unix socket can not get real inode"); + } + }, + _ => { + warn!("unix socket can not get real inode"); + } + } + } + let mut binding = UNIX_TABLE.write(); + let mut socket_inner = binding.get_mut(self.get_sockethandle()).unwrap().lock(); + socket_inner.addr.lock().set_addr(&addr); + socket_inner.set_state(UnixSocketStatus::Busy); + Ok(()) + } + _ => Err(LinuxError::EINVAL), + } + } + + /// Sends data through the socket to the connected peer, push data into buffer of peer socket + /// this will block if not connected by default + pub fn send(&self, buf: &[u8]) -> LinuxResult { + match self.unixsocket_type { + UnixSocketType::SockDgram | UnixSocketType::SockSeqpacket => Err(LinuxError::ENOTCONN), + UnixSocketType::SockStream => loop { + let now_state = self.get_state(); + match now_state { + UnixSocketStatus::Connecting => { + if self.is_nonblocking() { + return Err(LinuxError::EINPROGRESS); + } else { + yield_now(); + } + } + UnixSocketStatus::Connected => { + let peer_handle = UNIX_TABLE + .read() + .get(self.get_sockethandle()) + .unwrap() + .lock() + .get_peersocket() + .unwrap(); + return Ok(UNIX_TABLE + .write() + .get_mut(peer_handle) + .unwrap() + .lock() + .buf + .enqueue_slice(buf)); + } + _ => { + return Err(LinuxError::ENOTCONN); + } + } + }, + } + } + + /// Receives data from the socket, check if there any data in buffer + /// this will block if not connected or buffer is empty by default + pub fn recv(&self, buf: &mut [u8], _flags: i32) -> LinuxResult { + match self.unixsocket_type { + UnixSocketType::SockDgram | UnixSocketType::SockSeqpacket => Err(LinuxError::ENOTCONN), + UnixSocketType::SockStream => loop { + let now_state = self.get_state(); + match now_state { + UnixSocketStatus::Connecting => { + if self.is_nonblocking() { + return Err(LinuxError::EAGAIN); + } else { + yield_now(); + } + } + UnixSocketStatus::Connected => { + if UNIX_TABLE + .read() + .get(self.get_sockethandle()) + .unwrap() + .lock() + .buf + .is_empty() + { + if self.is_nonblocking() { + return Err(LinuxError::EAGAIN); + } else { + yield_now(); + } + } else { + return Ok(UNIX_TABLE + .read() + .get(self.get_sockethandle()) + .unwrap() + .lock() + .buf + .dequeue_slice(buf)); + } + } + _ => { + return Err(LinuxError::ENOTCONN); + } + } + }, + } + } + + /// Polls the socket's readiness for connection. + fn poll_connect(&self) -> LinuxResult { + let writable = { + let mut binding = UNIX_TABLE.write(); + let mut socket_inner = binding.get_mut(self.get_sockethandle()).unwrap().lock(); + if !socket_inner.get_peersocket().is_none() { + socket_inner.set_state(UnixSocketStatus::Connected); + true + } else { + false + } + }; + Ok(PollState { + readable: false, + writable, + }) + } + + /// Polls the socket's readiness for reading or writing. + pub fn poll(&self) -> LinuxResult { + let now_state = self.get_state(); + match now_state { + UnixSocketStatus::Connecting => self.poll_connect(), + UnixSocketStatus::Connected => { + let mut binding = UNIX_TABLE.write(); + let mut socket_inner = binding.get_mut(self.get_sockethandle()).unwrap().lock(); + Ok(PollState { + readable: !socket_inner.may_recv() || socket_inner.can_recv(), + writable: !socket_inner.may_send() || socket_inner.can_send(), + }) + } + UnixSocketStatus::Listening => { + let mut binding = UNIX_TABLE.write(); + let mut socket_inner = binding.get_mut(self.get_sockethandle()).unwrap().lock(); + Ok(PollState { + readable: socket_inner.can_accept(), + writable: false, + }) + } + _ => Ok(PollState { + readable: false, + writable: false, + }), + } + } + + /// Returns the local address of the socket. + pub fn local_addr(&self) -> LinuxResult { + unimplemented!() + } + + /// Returns the file descriptor for the socket. + fn fd(&self) -> c_int { + UNIX_TABLE + .write() + .get_mut(self.get_sockethandle()) + .unwrap() + .lock() + .addr + .lock() + .sun_path[0] as _ + } + + /// Returns the peer address of the socket. + pub fn peer_addr(&self) -> AxResult { + let now_state = self.get_state(); + match now_state { + UnixSocketStatus::Connected | UnixSocketStatus::Listening => { + let peer_sockethandle = self.get_peerhandle().unwrap(); + Ok(UNIX_TABLE + .read() + .get(peer_sockethandle) + .unwrap() + .lock() + .get_addr()) + } + _ => Err(AxError::NotConnected), + } + } + + /// Connects the socket to a specified address, push info into remote socket + pub fn connect(&mut self, addr: SocketAddrUnix) -> LinuxResult { + let now_state = self.get_state(); + if now_state != UnixSocketStatus::Connecting && now_state != UnixSocketStatus::Connected { + //a new block is needed to free rwlock + { + match get_inode(addr) { + Ok(inode_addr) => { + let binding = UNIX_TABLE.write(); + let remote_socket = binding.get(inode_addr).unwrap(); + if remote_socket.lock().get_state() != UnixSocketStatus::Listening { + error!("unix conncet error: remote socket not listening"); + return Err(LinuxError::EFAULT); + } + let data = &self.get_sockethandle().to_ne_bytes(); + let _res = remote_socket.lock().buf.enqueue_slice(data); + } + Err(AxError::NotFound) => return Err(LinuxError::ENOENT), + _ => { + warn!("unix socket can not get real inode"); + let binding = UNIX_TABLE.write(); + let (_remote_sockethandle, remote_socket) = binding + .find(|socket| socket.lock().addr.lock().sun_path == addr.sun_path) + .unwrap(); + let data = &self.get_sockethandle().to_ne_bytes(); + let _res = remote_socket.lock().buf.enqueue_slice(data); + } + } + } + { + let mut binding = UNIX_TABLE.write(); + let mut socket_inner = binding.get_mut(self.get_sockethandle()).unwrap().lock(); + socket_inner.set_state(UnixSocketStatus::Connecting); + } + } + + loop { + let PollState { writable, .. } = self.poll_connect()?; + if !writable { + // When set to non_blocking, directly return inporgress + if self.is_nonblocking() { + return Err(LinuxError::EINPROGRESS); + } else { + yield_now(); + } + } else if self.get_state() == UnixSocketStatus::Connected { + return Ok(()); + } else { + // When set to non_blocking, directly return inporgress + if self.is_nonblocking() { + return Err(LinuxError::EINPROGRESS); + } + warn!("socket connect() failed") + } + } + } + + /// Sends data to a specified address. + pub fn sendto(&self, buf: &[u8], addr: SocketAddrUnix) -> LinuxResult { + unimplemented!() + } + + /// Receives data from the socket and returns the sender's address. + pub fn recvfrom(&self, buf: &mut [u8]) -> LinuxResult<(usize, Option)> { + unimplemented!() + } + + /// Listens for incoming connections on the socket. + // TODO: check file system + pub fn listen(&mut self) -> LinuxResult { + let now_state = self.get_state(); + match now_state { + UnixSocketStatus::Busy => { + let mut binding = UNIX_TABLE.write(); + let mut socket_inner = binding.get_mut(self.get_sockethandle()).unwrap().lock(); + socket_inner.set_state(UnixSocketStatus::Listening); + Ok(()) + } + _ => { + Ok(()) //ignore simultaneous `listen`s. + } + } + } + + /// Accepts a new connection from a listening socket, get info from self buffer + pub fn accept(&mut self) -> AxResult { + let now_state = self.get_state(); + match now_state { + UnixSocketStatus::Listening => { + //buf dequeue as handle to get socket + loop { + let data: &mut [u8] = &mut [0u8; core::mem::size_of::()]; + let res = self.dequeue_buf(data); + match res { + Ok(_len) => { + let mut array = [0u8; core::mem::size_of::()]; + array.copy_from_slice(data); + let remote_handle = usize::from_ne_bytes(array); + let unix_socket = UnixSocket::new(UnixSocketType::SockStream); + { + let mut binding = UNIX_TABLE.write(); + let remote_socket = binding.get_mut(remote_handle).unwrap(); + remote_socket + .lock() + .set_peersocket(unix_socket.get_sockethandle()); + } + let mut binding = UNIX_TABLE.write(); + let mut socket_inner = binding + .get_mut(unix_socket.get_sockethandle()) + .unwrap() + .lock(); + socket_inner.set_peersocket(remote_handle); + socket_inner.set_state(UnixSocketStatus::Connected); + return Ok(unix_socket); + } + Err(AxError::WouldBlock) => { + if self.is_nonblocking() { + return Err(AxError::WouldBlock); + } else { + yield_now(); + } + } + Err(e) => { + return Err(e); + } + } + } + } + _ => ax_err!(InvalidInput, "socket accept() failed: not listen"), + } + } + + //TODO + /// Shuts down the socket. + pub fn shutdown(&self) -> LinuxResult { + unimplemented!() + } + + /// Returns whether this socket is in nonblocking mode. + #[inline] + pub fn is_nonblocking(&self) -> bool { + self.nonblock.load(Ordering::Acquire) + } + + /// Sets the nonblocking mode for the socket. + pub fn set_nonblocking(&self, nonblocking: bool) { + self.nonblock.store(nonblocking, Ordering::Release); + } + + /// Checks if the socket is in a listening state. + pub fn is_listening(&self) -> bool { + let now_state = self.get_state(); + match now_state { + UnixSocketStatus::Listening => true, + _ => false, + } + } + + /// Returns the socket type of the `UnixSocket`. + pub fn get_sockettype(&self) -> UnixSocketType { + self.unixsocket_type + } +} + +/// Initializes the global UNIX socket table, `UNIX_TABLE`, for managing Unix domain sockets. +pub(crate) fn init_unix() { + UNIX_TABLE.init_by(RwLock::new(HashMapWarpper::new())); +} From e149cde883352f2e6eccf9b41c1fe5cf55f4d4a0 Mon Sep 17 00:00:00 2001 From: WuZheng Date: Tue, 3 Dec 2024 00:44:53 +0800 Subject: [PATCH 2/3] fix bug for unexpected pagefault when nested fork. --- api/ruxos_posix_api/src/imp/pthread/mod.rs | 1 - crates/driver_net/src/loopback.rs | 8 +++----- modules/ruxnet/src/smoltcp_impl/mod.rs | 4 +--- modules/ruxtask/src/run_queue.rs | 1 + modules/ruxtask/src/task.rs | 21 +++++++++++++++------ 5 files changed, 20 insertions(+), 15 deletions(-) diff --git a/api/ruxos_posix_api/src/imp/pthread/mod.rs b/api/ruxos_posix_api/src/imp/pthread/mod.rs index da2694ff8..48a205c2e 100644 --- a/api/ruxos_posix_api/src/imp/pthread/mod.rs +++ b/api/ruxos_posix_api/src/imp/pthread/mod.rs @@ -310,7 +310,6 @@ pub unsafe fn sys_clone( } else if (flags as u32 & ctypes::SIGCHLD) != 0 { TID_TO_PTHREAD.read(); let pid = if let Some(task_ref) = ruxtask::fork_task() { - warn!("fork_task success, pid: {}", task_ref.id().as_u64()); task_ref.id().as_u64() } else { let children_ref = ruxtask::current(); diff --git a/crates/driver_net/src/loopback.rs b/crates/driver_net/src/loopback.rs index 3574cee75..072bd0d6f 100644 --- a/crates/driver_net/src/loopback.rs +++ b/crates/driver_net/src/loopback.rs @@ -57,8 +57,6 @@ impl BaseDriverOps for LoopbackDevice { } } -use log::info; - impl NetDriverOps for LoopbackDevice { #[inline] fn mac_address(&self) -> EthernetAddress { @@ -85,11 +83,11 @@ impl NetDriverOps for LoopbackDevice { self.queue.len() } - fn fill_rx_buffers(&mut self, buf_pool: &Arc) -> DevResult { + fn fill_rx_buffers(&mut self, _buf_pool: &Arc) -> DevResult { Ok(()) } - fn recycle_rx_buffer(&mut self, rx_buf: NetBufPtr) -> DevResult { + fn recycle_rx_buffer(&mut self, _rx_buf: NetBufPtr) -> DevResult { Ok(()) } @@ -97,7 +95,7 @@ impl NetDriverOps for LoopbackDevice { Ok(()) } - fn prepare_tx_buffer(&self, tx_buf: &mut NetBuf, pkt_len: usize) -> DevResult { + fn prepare_tx_buffer(&self, _tx_buf: &mut NetBuf, _pkt_len: usize) -> DevResult { Ok(()) } diff --git a/modules/ruxnet/src/smoltcp_impl/mod.rs b/modules/ruxnet/src/smoltcp_impl/mod.rs index 192fe2d93..50bfbfe9d 100644 --- a/modules/ruxnet/src/smoltcp_impl/mod.rs +++ b/modules/ruxnet/src/smoltcp_impl/mod.rs @@ -36,8 +36,6 @@ pub use self::dns::dns_query; pub use self::tcp::TcpSocket; pub use self::udp::UdpSocket; -pub use driver_net::loopback::LoopbackDevice; - macro_rules! env_or_default { ($key:literal) => { match option_env!($key) { @@ -347,7 +345,7 @@ pub fn bench_receive() { } pub(crate) fn init() { - let mut socketset = SocketSetWrapper::new(); + let socketset = SocketSetWrapper::new(); IFACE_LIST.init_by(Mutex::new(vec::Vec::new())); SOCKET_SET.init_by(socketset); diff --git a/modules/ruxtask/src/run_queue.rs b/modules/ruxtask/src/run_queue.rs index 82fa5ffd1..95de4926f 100644 --- a/modules/ruxtask/src/run_queue.rs +++ b/modules/ruxtask/src/run_queue.rs @@ -129,6 +129,7 @@ impl AxRunQueue { assert!(!curr.is_idle()); // we must not block current task with preemption disabled. + // only allow blocking current task with run_queue lock held. #[cfg(feature = "preempt")] assert!(curr.can_preempt(1)); diff --git a/modules/ruxtask/src/task.rs b/modules/ruxtask/src/task.rs index cc1a45d20..7311d53ad 100644 --- a/modules/ruxtask/src/task.rs +++ b/modules/ruxtask/src/task.rs @@ -81,6 +81,7 @@ pub struct TaskInner { exit_code: AtomicI32, wait_for_exit: WaitQueue, + stack_map_addr: SpinNoIrq, kstack: SpinNoIrq>>, ctx: UnsafeCell, @@ -235,6 +236,7 @@ impl TaskInner { preempt_disable_count: AtomicUsize::new(0), exit_code: AtomicI32::new(0), wait_for_exit: WaitQueue::new(), + stack_map_addr: SpinNoIrq::new(VirtAddr::from(0)), // should be set later kstack: SpinNoIrq::new(Arc::new(None)), ctx: UnsafeCell::new(TaskContext::new()), #[cfg(feature = "tls")] @@ -279,6 +281,7 @@ impl TaskInner { preempt_disable_count: AtomicUsize::new(0), exit_code: AtomicI32::new(0), wait_for_exit: WaitQueue::new(), + stack_map_addr: SpinNoIrq::new(VirtAddr::from(0)), kstack: SpinNoIrq::new(Arc::new(None)), ctx: UnsafeCell::new(TaskContext::new()), #[cfg(feature = "tls")] @@ -299,6 +302,7 @@ impl TaskInner { pub fn set_stack_top(&self, begin: usize, size: usize) { debug!("set_stack_top: begin={:#x}, size={:#x}", begin, size); + *self.stack_map_addr.lock() = VirtAddr::from(begin); *self.kstack.lock() = Arc::new(Some(TaskStack { ptr: NonNull::new(begin as *mut u8).unwrap(), layout: Layout::from_size_align(size, PAGE_SIZE_4K).unwrap(), @@ -406,14 +410,14 @@ impl TaskInner { // Note: the stack region is mapped to the same position as the parent process's stack, be careful when update the stack region for the forked process. let (_, prev_flag, _) = cloned_page_table - .query(current_stack.end()) + .query(*current().stack_map_addr.lock()) .expect("failed to query stack region when forking"); cloned_page_table - .unmap_region(current_stack.end(), align_up_4k(stack_size)) + .unmap_region(*current().stack_map_addr.lock(), align_up_4k(stack_size)) .expect("failed to unmap stack region when forking"); cloned_page_table .map_region( - current_stack.end(), + *current().stack_map_addr.lock(), stack_paddr, stack_size, prev_flag, @@ -477,10 +481,11 @@ impl TaskInner { need_resched: AtomicBool::new(current_task.need_resched.load(Ordering::Relaxed)), #[cfg(feature = "preempt")] preempt_disable_count: AtomicUsize::new( - current_task.preempt_disable_count.load(Ordering::Relaxed), + current_task.preempt_disable_count.load(Ordering::Acquire), ), exit_code: AtomicI32::new(0), wait_for_exit: WaitQueue::new(), + stack_map_addr: SpinNoIrq::new(*current().stack_map_addr.lock()), kstack: SpinNoIrq::new(Arc::new(Some(new_stack))), ctx: UnsafeCell::new(TaskContext::new()), #[cfg(feature = "tls")] @@ -515,6 +520,7 @@ impl TaskInner { .lock() .insert(new_pid.as_u64(), task_ref.clone()); + warn!("forked task: save_current_content {}", task_ref.id_name()); unsafe { // copy the stack content from current stack to new stack (*task_ref.ctx_mut_ptr()).save_current_content( @@ -554,6 +560,7 @@ impl TaskInner { preempt_disable_count: AtomicUsize::new(0), exit_code: AtomicI32::new(0), wait_for_exit: WaitQueue::new(), + stack_map_addr: SpinNoIrq::new(VirtAddr::from(0)), // set in set_stack_top kstack: SpinNoIrq::new(Arc::new(None)), ctx: UnsafeCell::new(TaskContext::new()), #[cfg(feature = "tls")] @@ -590,6 +597,7 @@ impl TaskInner { let bindings = PROCESS_MAP.lock(); let (&_parent_id, &ref task_ref) = bindings.first_key_value().unwrap(); let idle_kstack = TaskStack::alloc(align_up_4k(IDLE_STACK_SIZE)); + let idle_kstack_top = idle_kstack.top(); let mut t = Self { parent_process: Some(Arc::downgrade(task_ref)), @@ -609,7 +617,8 @@ impl TaskInner { preempt_disable_count: AtomicUsize::new(0), exit_code: AtomicI32::new(0), wait_for_exit: WaitQueue::new(), - kstack: SpinNoIrq::new(Arc::new(None)), + stack_map_addr: SpinNoIrq::new(idle_kstack.end()), + kstack: SpinNoIrq::new(Arc::new(Some(idle_kstack))), ctx: UnsafeCell::new(TaskContext::new()), #[cfg(feature = "tls")] tls: TlsArea::alloc(), @@ -633,7 +642,7 @@ impl TaskInner { debug!("new idle task: {}", t.id_name()); t.ctx .get_mut() - .init(task_entry as usize, idle_kstack.top(), tls); + .init(task_entry as usize, idle_kstack_top, tls); let task_ref = Arc::new(AxTask::new(t)); From bd02cdd8227819e8a3d2e07e62b108cea32cba4d Mon Sep 17 00:00:00 2001 From: lhw Date: Wed, 11 Dec 2024 20:06:30 +0800 Subject: [PATCH 3/3] add unix socket drop --- modules/ruxnet/src/unix.rs | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/modules/ruxnet/src/unix.rs b/modules/ruxnet/src/unix.rs index 5436ab237..9136090b8 100644 --- a/modules/ruxnet/src/unix.rs +++ b/modules/ruxnet/src/unix.rs @@ -203,6 +203,10 @@ impl<'a> HashMapWarpper<'a> { pub fn get_mut(&mut self, id: usize) -> Option<&mut Arc>>> { self.inner.get_mut(&id) } + + pub fn remove(&mut self, id: usize) -> Option>>> { + self.inner.remove(&id) + } } static UNIX_TABLE: LazyInit> = LazyInit::new(); @@ -664,10 +668,12 @@ impl UnixSocket { } } - //TODO /// Shuts down the socket. pub fn shutdown(&self) -> LinuxResult { - unimplemented!() + let mut binding = UNIX_TABLE.write(); + let mut socket_inner = binding.get_mut(self.get_sockethandle()).unwrap().lock(); + socket_inner.set_state(UnixSocketStatus::Closed); + Ok(()) } /// Returns whether this socket is in nonblocking mode. @@ -696,6 +702,13 @@ impl UnixSocket { } } +impl Drop for UnixSocket { + fn drop(&mut self) { + self.shutdown(); + UNIX_TABLE.write().remove(self.get_sockethandle()); + } +} + /// Initializes the global UNIX socket table, `UNIX_TABLE`, for managing Unix domain sockets. pub(crate) fn init_unix() { UNIX_TABLE.init_by(RwLock::new(HashMapWarpper::new()));