fix: add barrier for better opencl memory fencing (#30)
stringhandler authored Dec 9, 2024
1 parent c537428 commit d81697e
Showing 3 changed files with 135 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/config_file.rs
@@ -19,6 +19,8 @@ pub(crate) struct ConfigFile {
pub single_grid_size: u32,
pub per_device_grid_sizes: Vec<u32>,
pub template_timeout_secs: u64,
#[serde(default = "default_max_template_failures")]
pub max_template_failures: u64,
}

impl Default for ConfigFile {
@@ -36,6 +38,7 @@ impl Default for ConfigFile {
single_grid_size: 1024,
per_device_grid_sizes: vec![],
template_timeout_secs: 1,
max_template_failures: 10,
}
}
}
@@ -54,3 +57,7 @@ impl ConfigFile {
Ok(())
}
}

fn default_max_template_failures() -> u64 {
10
}
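
Aside on the config change above: `#[serde(default = "default_max_template_failures")]` lets config files written before this field existed keep deserializing, with serde calling the named function whenever the key is absent. A minimal sketch of that serde pattern (the struct, field, and use of `serde_json` are illustrative only, not taken from this repository):

```rust
use serde::Deserialize;

// Illustrative struct mirroring the `#[serde(default = "...")]` pattern above.
#[derive(Debug, Deserialize)]
struct ExampleConfig {
    // If the key is missing from the input, serde calls `default_retries()`.
    #[serde(default = "default_retries")]
    retries: u64,
}

fn default_retries() -> u64 {
    10
}

fn main() {
    // A document with no `retries` key still parses and picks up the default.
    let cfg: ExampleConfig = serde_json::from_str("{}").unwrap();
    assert_eq!(cfg.retries, 10);
}
```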
127 changes: 127 additions & 0 deletions src/main.rs
@@ -192,6 +192,9 @@ struct Cli {

#[arg(short, long)]
template_timeout_secs: Option<u64>,

#[arg(long)]
max_template_failures: Option<usize>,
}

async fn main_inner() -> Result<(), anyhow::Error> {
@@ -446,6 +449,122 @@ async fn main_inner() -> Result<(), anyhow::Error> {
return Ok(());
}

if let Some(max_template_failures) = cli.max_template_failures {
config.max_template_failures = max_template_failures as u64;
}
// create a list of devices (by index) to use
let devices_to_use: Vec<u32> = (0..num_devices)
.filter(|x| {
if let Some(use_devices) = &cli.use_devices {
use_devices.contains(x)
} else {
true
}
})
.filter(|x| {
if let Some(excluded_devices) = &cli.exclude_devices {
!excluded_devices.contains(x)
} else {
true
}
})
.collect();

info!(target: LOG_TARGET, "Device indexes to use: {:?} from the total number of devices: {:?}", devices_to_use, num_devices);

println!(
"Device indexes to use: {:?} from the total number of devices: {:?}",
devices_to_use, num_devices
);

if cli.find_optimal {
let mut best_hashrate = 0;
let mut best_grid_size = 1;
let mut current_grid_size = 32;
let mut is_doubling_stage = true;
let mut last_grid_size_increase = 0;
let mut prev_hashrate = 0;

while true {
dbg!("here");
let mut config = config.clone();
config.single_grid_size = current_grid_size;
// config.block_size = ;
let mut threads = vec![];
let (tx, rx) = tokio::sync::broadcast::channel(100);
for i in 0..num_devices {
if !devices_to_use.contains(&i) {
continue;
}
let c = config.clone();
let gpu = gpu_engine.clone();
let x = tx.clone();
threads.push(thread::spawn(move || {
run_thread(gpu, num_devices as u64, i as u32, c, true, x)
}));
}
let thread_len = threads.len();
let mut thread_hashrate = Vec::with_capacity(thread_len);
for t in threads {
match t.join() {
Ok(res) => match res {
Ok(hashrate) => {
info!(target: LOG_TARGET, "Thread join succeeded: {}", hashrate.to_formatted_string(&Locale::en));
thread_hashrate.push(hashrate);
},
Err(err) => {
eprintln!("Thread join succeeded but result failed: {:?}", err);
error!(target: LOG_TARGET, "Thread join succeeded but result failed: {:?}", err);
},
},
Err(err) => {
eprintln!("Thread join failed: {:?}", err);
error!(target: LOG_TARGET, "Thread join failed: {:?}", err);
},
}
}
let total_hashrate: u64 = thread_hashrate.iter().sum();
if total_hashrate > best_hashrate {
best_hashrate = total_hashrate;
best_grid_size = current_grid_size;
// best_grid_size = config.single_grid_size;
// best_block_size = config.block_size;
println!(
"Best hashrate: {} grid_size: {}, current_grid: {} block_size: {} Prev Hash {}",
best_hashrate, best_grid_size, current_grid_size, config.block_size, prev_hashrate
);
}
// if total_hashrate < prev_hashrate {
// println!("total decreased, breaking");
// break;
// }
if is_doubling_stage {
if total_hashrate > prev_hashrate {
last_grid_size_increase = current_grid_size;
current_grid_size = current_grid_size * 2;
} else {
is_doubling_stage = false;
last_grid_size_increase = last_grid_size_increase / 2;
current_grid_size = current_grid_size.saturating_sub(last_grid_size_increase);
}
} else {
// Bisecting stage
if last_grid_size_increase < 2 {
break;
}
if total_hashrate > prev_hashrate {
last_grid_size_increase = last_grid_size_increase / 2;
current_grid_size += last_grid_size_increase;
} else {
last_grid_size_increase = last_grid_size_increase / 2;
current_grid_size = current_grid_size.saturating_sub(last_grid_size_increase);
}
}
prev_hashrate = total_hashrate;
}
return Ok(());
}

let (stats_tx, stats_rx) = tokio::sync::broadcast::channel(100);
if config.http_server_enabled {
let mut stats_collector = stats_collector::StatsCollector::new(shutdown.to_signal(), stats_rx);
@@ -536,6 +655,7 @@ fn run_thread<T: EngineImpl>(
} else {
ClientType::BaseNode
};
let mut template_fetch_failures = 0;
let coinbase_extra = config.coinbase_extra.clone();
let node_client = Arc::new(RwLock::new(runtime.block_on(async move {
node_client::create_client(client_type, &tari_node_url, coinbase_extra).await
@@ -586,6 +706,7 @@
let mining_hash: FixedHash;
match runtime.block_on(async move { get_template(clone_config, clone_node_client, rounds, benchmark).await }) {
Ok((res_target_difficulty, res_block, res_header, res_mining_hash)) => {
template_fetch_failures = 0;
info!(target: LOG_TARGET, "Getting next block...");
println!("Getting next block...{}", res_header.height);
target_difficulty = res_target_difficulty;
@@ -595,6 +716,12 @@
previous_template = Some((target_difficulty, block.clone(), header.clone(), mining_hash.clone()));
},
Err(error) => {
template_fetch_failures += 1;
if template_fetch_failures > config.max_template_failures {
eprintln!("Too many template fetch failures, exiting");
error!(target: LOG_TARGET, "Too many template fetch failures, exiting");
return Err(error);
}
println!("Error during getting next block: {error:?}");
error!(target: LOG_TARGET, "Error during getting next block: {:?}", error);
if previous_template.is_none() {
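The `--find-optimal` path added to main.rs above is a coarse-to-fine search over `single_grid_size`: it keeps doubling the grid size while the measured total hashrate improves, then switches to bisection, halving the last step and probing up or down until the step drops below 2. The sketch below restates that search in isolation, with a stand-in `measure` closure in place of spawning real mining threads; `tune_grid_size` and the toy hashrate curve are hypothetical, not part of the commit:

```rust
/// Doubling-then-bisecting search for the grid size with the best hashrate,
/// following the same control flow as the `--find-optimal` loop above.
/// `measure` stands in for benchmarking one full mining round.
fn tune_grid_size(mut measure: impl FnMut(u32) -> u64) -> (u32, u64) {
    let mut best = (1u32, 0u64);
    let mut grid = 32u32;
    let mut step = 0u32;
    let mut doubling = true;
    let mut prev = 0u64;

    loop {
        let rate = measure(grid);
        if rate > best.1 {
            best = (grid, rate);
        }
        if doubling {
            if rate > prev {
                step = grid;      // remember the last jump that still helped
                grid *= 2;
            } else {
                doubling = false; // overshot: switch to bisection
                step /= 2;
                grid = grid.saturating_sub(step);
            }
        } else {
            if step < 2 {
                break;            // step size exhausted, stop searching
            }
            step /= 2;
            if rate > prev {
                grid += step;
            } else {
                grid = grid.saturating_sub(step);
            }
        }
        prev = rate;
    }
    best
}

fn main() {
    // Toy hashrate curve peaking near a grid size of 1024.
    let (grid, rate) =
        tune_grid_size(|g| 1_000_000u64.saturating_sub(u64::from(g.abs_diff(1024)) * 500));
    println!("best grid size {grid} at {rate} H/s");
}
```

As in the commit's loop, the search tracks the best result seen so far rather than assuming the final probe is optimal, so a late unlucky measurement does not discard an earlier peak.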
1 change: 1 addition & 0 deletions src/opencl_sha3.cl
@@ -147,6 +147,7 @@ kernel void sha3(global ulong *buffer, ulong nonce_start, ulong difficulty,

// check difficulty
ulong swap = swap_endian_64(state[0]);
barrier(CLK_GLOBAL_MEM_FENCE);
if (swap < difficulty) {
if (output_1[1] == 0 || output_1[1] > swap) {
output_1[0] = nonce_start + get_global_id(0) + i * get_global_size(0);
