Skip to content

Commit

Permalink
Add causal sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
leng-yue committed Dec 28, 2023
1 parent dcc5e80 commit e8e366e
Showing 1 changed file with 35 additions and 11 deletions.
46 changes: 35 additions & 11 deletions data_server/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use clap::Parser;
use log::info;
use prost::Message;
use rand::prelude::IteratorRandom;
use rand::seq::SliceRandom;
use rand::thread_rng;
use rand::{thread_rng, Rng};
use std::fs::File;
use std::io::{self, BufReader, Read, Result as IoResult};
use std::vec;
Expand All @@ -21,6 +20,7 @@ use text_data::{
/// State held by the data service: the loaded text groups plus the sampling configuration.
#[derive(Default)]
pub struct MyDataService {
/// Text groups collected by `new`; the sampling code visible in this file picks one
/// of them per request via `choose_weighted`.
groups: Vec<TextData>,
/// Selects the sampling strategy. When `true`, a request is answered with one
/// contiguous run of sentences taken from the chosen group starting at a random
/// offset; when `false`, sentences are drawn with `choose_multiple`.
causual_sampling: bool,
/// Built by `new` alongside `groups`; how it is filled is not visible here.
/// NOTE(review): the visible sampling code weights groups by sentence count
/// directly rather than reading this field — confirm whether it is still used.
weights: Vec<f32>,
}

Expand Down Expand Up @@ -56,7 +56,7 @@ fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<TextData>
}

impl MyDataService {
pub fn new(files: Vec<String>) -> IoResult<Self> {
pub fn new(files: Vec<String>, causual_sampling: bool) -> IoResult<Self> {
let mut groups = Vec::new();
let mut weights = Vec::new();

Expand All @@ -73,7 +73,7 @@ impl MyDataService {

info!("Loaded {} groups", groups.len());

Ok(MyDataService { groups, weights })
Ok(MyDataService { groups, weights, causual_sampling })
}
}

Expand All @@ -90,15 +90,36 @@ impl DataService for MyDataService {
.groups
.choose_weighted(&mut rng, |item| item.sentences.len() as f32);

if group.is_ok() {
let group = group.unwrap();
if group.is_err() {
return Err(Status::internal("Failed to select a group"));
}

let group = group.unwrap();

if self.causual_sampling {
if num_samples > group.sentences.len() {
num_samples = group.sentences.len();
}

// Random number between 0 and group.sentences.len() - num_samples
let max = group.sentences.len() - num_samples;
if max <= 0 {
return Ok(Response::new(SampledData {
name: group.name.clone(),
source: group.source.clone(),
samples: group.sentences.clone(),
}));
}

let start = rng.gen_range(0..max);
Ok(Response::new(SampledData {
name: group.name.clone(),
source: group.source.clone(),
samples: group.sentences[start..start + num_samples].to_vec(),
}))
} else {
let sentences_ref = group
.sentences
.iter()
.choose_multiple(&mut rng, num_samples);

let sentences: Vec<Sentence> = sentences_ref
Expand All @@ -109,10 +130,8 @@ impl DataService for MyDataService {
Ok(Response::new(SampledData {
name: group.name.clone(),
source: group.source.clone(),
samples: sentences
samples: sentences,
}))
} else {
Err(Status::internal("Failed to select a group"))
}
}
}
Expand All @@ -124,6 +143,10 @@ struct Args {
/// Files to process
#[clap(short, long, value_name = "FILE", required = true)]
files: Vec<String>,

/// Causal sampling
#[clap(short, long, default_value = "false")]
causal: bool
}

#[tokio::main]
Expand All @@ -132,9 +155,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

// Parse command-line arguments
let args = Args::parse();
info!("Arguments: {:?}", args);

let addr = "127.0.0.1:50051".parse()?;
let data_service = MyDataService::new(args.files)?;
let data_service = MyDataService::new(args.files, args.causal)?;

info!("Starting server at {}", addr);

Expand Down

0 comments on commit e8e366e

Please sign in to comment.