Skip to content

Commit

Permalink
Allow DmaFileStream to be polled multiple times (#611)
Browse files Browse the repository at this point in the history
Co-authored-by: Glauber Costa <[email protected]>
  • Loading branch information
tontinton and Glauber Costa authored Nov 23, 2023
1 parent c6ca6f2 commit a315d9c
Showing 1 changed file with 127 additions and 17 deletions.
144 changes: 127 additions & 17 deletions glommio/src/io/dma_file_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -827,8 +827,6 @@ macro_rules! ensure_not_closed {

impl DmaStreamWriterState {
fn add_waker(&mut self, waker: Waker) {
// Linear file stream, not supposed to have parallel writers!!
assert!(self.waker.is_none());
self.waker = Some(waker);
}

Expand All @@ -839,13 +837,7 @@ impl DmaStreamWriterState {
self.flushed_pos() > self.aligned_pos
}

fn initiate_close(
&mut self,
waker: Waker,
state: Rc<RefCell<Self>>,
file: Rc<DmaFile>,
do_close: bool,
) {
fn initiate_close(&mut self, state: Rc<RefCell<Self>>, file: Rc<DmaFile>, do_close: bool) {
self.file_status = FileStatus::Closing;
let final_pos = self.current_pos();
self.flush_padded(state.clone(), file.clone());
Expand All @@ -854,7 +846,11 @@ impl DmaStreamWriterState {
futures_lite::future::yield_now().await;

defer! {
waker.wake();
let mut state = state.borrow_mut();
if let Some(waker) = state.waker.take() {
drop(state);
waker.wake();
}
}

for flush in pending.drain(..) {
Expand Down Expand Up @@ -960,9 +956,14 @@ impl DmaStreamWriterState {
if !collect_error!(state, res) {
state.flush_state.on_complete(flush_pos);
}
if let Some(waker) = state.waker.take() {
drop(state);
waker.wake();

// When the file is closing registers a waker, we don't want to
// wake it now.
if let FileStatus::Open = state.file_status {
if let Some(waker) = state.waker.take() {
drop(state);
waker.wake();
}
}
})
.detach();
Expand Down Expand Up @@ -1272,13 +1273,24 @@ impl DmaStreamWriter {
cx: &Context<'_>,
) -> Poll<io::Result<DmaStreamReaderBuilder>> {
let mut state = self.state.borrow_mut();
let previous_waker = state.waker.take();
match state.file_status {
FileStatus::Open => {
assert!(
previous_waker.is_none(),
"Cannot seal while flushing / writing"
);
let file = ensure_not_closed!(self.file.clone());
state.initiate_close(cx.waker().clone(), self.state.clone(), file, false);
state.add_waker(cx.waker().clone());
state.initiate_close(self.state.clone(), file, false);
Poll::Pending
}
FileStatus::Closing => {
if previous_waker.is_some() {
state.add_waker(cx.waker().clone());
return Poll::Pending;
}

state.file_status = FileStatus::Closed;
let file = ensure_not_closed!(self.file);

Expand Down Expand Up @@ -1325,10 +1337,17 @@ impl<'a> AsyncWrite for &'a DmaStreamWriter {
buf: &[u8],
) -> Poll<io::Result<usize>> {
let mut state = self.state.borrow_mut();
let previous_waker = state.waker.take();

if let Some(err) = current_error!(state) {
return Poll::Ready(err);
}

if previous_waker.is_some() {
state.add_waker(cx.waker().clone());
return Poll::Pending;
}

let mut written = 0;
while written < buf.len() {
match state.current_buffer.take() {
Expand Down Expand Up @@ -1374,26 +1393,40 @@ impl<'a> AsyncWrite for &'a DmaStreamWriter {

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let mut state = self.state.borrow_mut();
let previous_waker = state.waker.take();
if let Some(err) = current_error!(state) {
return Poll::Ready(err);
}
if state.flushed_pos() == state.current_pos() {
return Poll::Ready(Ok(()));
}
state.flush_padded(self.state.clone(), self.file.borrow().clone().unwrap());
state.add_waker(cx.waker().clone());
if previous_waker.is_none() {
state.flush_padded(self.state.clone(), self.file.borrow().clone().unwrap());
}
Poll::Pending
}

fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let file = self.file.borrow_mut().take();
let mut state = self.state.borrow_mut();
let previous_waker = state.waker.take();
match state.file_status {
FileStatus::Open => {
state.initiate_close(cx.waker().clone(), self.state.clone(), file.unwrap(), true);
assert!(
previous_waker.is_none(),
"Cannot close while flushing / writing"
);
state.add_waker(cx.waker().clone());
let file = self.file.borrow_mut().take();
state.initiate_close(self.state.clone(), file.unwrap(), true);
Poll::Pending
}
FileStatus::Closing => {
if previous_waker.is_some() {
state.add_waker(cx.waker().clone());
return Poll::Pending;
}

state.file_status = FileStatus::Closed;
match current_error!(state) {
Some(err) => Poll::Ready(err),
Expand All @@ -1414,6 +1447,8 @@ mod test {
use futures::{task::noop_waker_ref, AsyncRead, AsyncReadExt, AsyncWriteExt};
use std::{io::ErrorKind, path::Path, time::Duration};

const NUM_CONCURRENT_WRITERS: usize = 10;

macro_rules! file_stream_read_test {
( $name:ident, $dir:ident, $kind:ident, $file:ident, $file_size:ident: $size:tt, $code:block) => {
#[test]
Expand Down Expand Up @@ -1464,6 +1499,32 @@ mod test {
};
}

macro_rules! multiple_file_streams_write_test {
( $name:ident, $dir:ident, $kind:ident, $filenames:ident, $files:ident, $code:block) => {
#[test]
fn $name() {
for dir in make_test_directories(stringify!($name)) {
let $dir = dir.path.clone();
let $kind = dir.kind;
test_executor!(async move {
let $filenames = (0..NUM_CONCURRENT_WRITERS)
.into_iter()
.map(|i| $dir.join(format!("testfile-{}", i)))
.collect::<Vec<_>>();

let mut $files = Vec::with_capacity($filenames.len());
for filename in $filenames.iter() {
$files.push(DmaFile::create(filename).await.unwrap());
}
let $files = $files; // remove mut

$code
});
}
}
};
}

macro_rules! check_contents {
( $buf:expr, $start:expr ) => {
for (idx, i) in $buf.iter().enumerate() {
Expand Down Expand Up @@ -2106,4 +2167,53 @@ mod test {
assert_eq!(state.flushes.len(), 0);
});
}

multiple_file_streams_write_test!(join_write_multiple_writers, path, _k, filenames, files, {
let mut writers = files
.into_iter()
.map(|file| {
DmaStreamWriterBuilder::new(file)
.with_buffer_size(4096)
.with_write_behind(1)
.build()
})
.collect::<Vec<_>>();

let buffer = [0u8; 5000];
let results = join_all(writers.iter_mut().map(|w| w.write_all(&buffer))).await;
for result in results {
result.unwrap();
}

for mut writer in writers {
writer.close().await.unwrap();
}
});

multiple_file_streams_write_test!(join_close_multiple_writers, path, _k, filenames, files, {
let mut writers = files
.into_iter()
.map(|file| {
DmaStreamWriterBuilder::new(file)
.with_buffer_size(8)
.with_write_behind(16)
.build()
})
.collect::<Vec<_>>();

let buffer = [0u8; 5000];
for writer in &mut writers {
writer.write_all(&buffer).await.unwrap();
}

let results = join_all(writers.iter_mut().map(|w| w.close())).await;
for result in results {
result.unwrap();
}

for (writer, filename) in writers.into_iter().zip(filenames) {
assert_eq!(writer.current_flushed_pos(), 5000);
assert_eq!(file_size(&filename), 5000);
}
});
}

0 comments on commit a315d9c

Please sign in to comment.