diff --git a/CHANGELOG.md b/CHANGELOG.md
index 855f900..3bd177f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,7 @@
+## 0.1.2
+
+- Add exponential backoff when rate limited
+
 ## 0.1.1

 - Add context length management features
diff --git a/Cargo.lock b/Cargo.lock
index 26dbce9..9fcb920 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1162,6 +1162,7 @@ dependencies = [
  "dirs",
  "dotenvy",
  "error-stack",
+ "fastrand",
  "flume",
  "itertools 0.11.0",
  "liquid",
diff --git a/Cargo.toml b/Cargo.toml
index 957caaf..c8f03a9 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -17,6 +17,7 @@ clap = { version = "4.4.7", features = ["derive", "env", "string"] }
 dirs = "5.0.1"
 dotenvy = "0.15.7"
 error-stack = "0.4.1"
+fastrand = "2.0.1"
 flume = "0.11.0"
 itertools = "0.11.0"
 liquid = "0.26.4"
diff --git a/src/main.rs b/src/main.rs
index aa0dee9..7cf0542 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -20,6 +20,7 @@ mod model;
 mod ollama;
 mod openai;
 mod option;
+mod requests;
 mod template;
 #[cfg(test)]
 mod tests;
diff --git a/src/openai.rs b/src/openai.rs
index 12cbee5..b22bb7f 100644
--- a/src/openai.rs
+++ b/src/openai.rs
@@ -4,7 +4,10 @@ use error_stack::{Report, ResultExt};
 use serde::Deserialize;
 use serde_json::json;

-use crate::model::{map_model_response_err, ModelComms, ModelError, ModelOptions};
+use crate::{
+    model::{map_model_response_err, ModelComms, ModelError, ModelOptions},
+    requests::request_with_retry,
+};

 pub const OPENAI_HOST: &str = "https://api.openai.com";

@@ -98,12 +101,13 @@ pub fn send_chat_request(
         body["max_tokens"] = json!(max_tokens);
     }

-    let mut response: ChatCompletion = create_base_request(&options, "v1/chat/completions")
-        .timeout(Duration::from_secs(30))
-        .send_json(body)
-        .map_err(map_model_response_err)?
-        .into_json()
-        .change_context(ModelError::Deserialize)?;
+    let mut response: ChatCompletion = request_with_retry(
+        create_base_request(&options, "v1/chat/completions").timeout(Duration::from_secs(30)),
+        body,
+    )
+    .map_err(map_model_response_err)?
+    .into_json()
+    .change_context(ModelError::Deserialize)?;

     // TODO streaming
     let result = response
diff --git a/src/requests.rs b/src/requests.rs
new file mode 100644
index 0000000..3daf865
--- /dev/null
+++ b/src/requests.rs
@@ -0,0 +1,32 @@
+use serde::Serialize;
+
+pub fn request_with_retry(
+    req: ureq::Request,
+    body: impl Serialize,
+) -> Result<ureq::Response, ureq::Error> {
+    const MAX_TRIES: u32 = 4;
+    let mut try_num = 0;
+    let delay = 1000;
+    loop {
+        let response = req.clone().send_json(&body);
+        match response {
+            Ok(res) => return Ok(res),
+            Err(ureq::Error::Status(code, response)) => {
+                if code != 429 || try_num > MAX_TRIES {
+                    return Err(ureq::Error::Status(code, response));
+                }
+
+                // This is potentially retryable. We don't do anything smart right now, just a
+                // random exponential backoff.
+
+                let perturb = fastrand::i32(-100..100);
+                let this_delay = 2i32.pow(try_num) * delay + perturb;
+
+                eprintln!("Rate limited... waiting {this_delay}ms to retry");
+                std::thread::sleep(std::time::Duration::from_millis(this_delay as u64));
+                try_num += 1;
+            }
+            e @ Err(_) => return e,
+        }
+    }
+}
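
With a 1000 ms base delay, the helper sleeps roughly 1s, 2s, 4s, 8s, then 16s (each perturbed by up to ±100 ms of jitter) across successive 429 responses before giving up. For illustration, a minimal sketch of another call site routed through the new helper; the `fetch_completion` name and request body here are hypothetical, not part of the diff:

```rust
use std::time::Duration;

use crate::requests::request_with_retry;

// Hypothetical caller, for illustration only: any ureq::Request can go
// through request_with_retry the same way send_chat_request now does.
fn fetch_completion(body: serde_json::Value) -> Result<ureq::Response, ureq::Error> {
    let req = ureq::post("https://api.openai.com/v1/chat/completions")
        .timeout(Duration::from_secs(30));
    // Retries apply only to HTTP 429; other status codes and transport
    // errors are returned to the caller immediately.
    request_with_retry(req, body)
}
```

The random perturbation is standard jitter: it desynchronizes concurrent clients so their retries don't all land on the rate limiter at the same instant.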