-
Notifications
You must be signed in to change notification settings - Fork 242
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* attribute proc macro to bring your own types * keep original fn as it is add new with _byot suffix * update macro * update macro * use macro in main crate + add test * byot: assistants * byot: vector_stores * add where_clause attribute arg * remove print * byot: files * byot: images * add stream arg to attribute * byot: chat * byot: completions * fix comment * fix * byot: audio * byot: embeddings * byot: Fine Tunning * add byot tests * byot: moderations * byot tests: moderations * byot: threads * byot tests: threads * byot: messages * byot tests: messages * byot: runs * byot tests: runs * byot: steps * byot tests: run steps * byot: vector store files * byot test: vector store files * byot: vector store file batches * byot test: vector store file batches * cargo fmt * byot: batches * byot tests: batches * format * remove AssistantFiles and related apis (/assistants/assistant_id/files/..) * byot: audit logs * byot tests: audit logs * keep non byot code checks * byot: invites * byot tests: invites * remove message files API * byot: project api keys * byot tests: project api keys * byot: project service accounts * byot tests: project service accounts * byot: project users * byot tests: project users * byot: projects * byot tests: projects * byot: uploads * byot tests: uploads * byot: users * byot tests: users * add example to demonstrate bring-your-own-types * update README * update doc * cargo fmt * update doc in lib.rs * tests passing * fix for complier warning * fix compiler #[allow(unused_mut)] * cargo fix * fix all warnings * add Voices * publish = false for all examples * specify versions
- Loading branch information
Showing
49 changed files
with
1,133 additions
and
317 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,8 @@ | ||
[workspace] | ||
members = [ "async-openai", "examples/*" ] | ||
members = [ "async-openai", "async-openai-*", "examples/*" ] | ||
# Only check / build main crates by default (check all with `--workspace`) | ||
default-members = ["async-openai"] | ||
default-members = ["async-openai", "async-openai-*"] | ||
resolver = "2" | ||
|
||
[workspace.package] | ||
rust-version = "1.75" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
[package] | ||
name = "async-openai-macros" | ||
version = "0.1.0" | ||
authors = ["Himanshu Neema"] | ||
keywords = ["openai", "macros", "ai"] | ||
description = "Macros for async-openai" | ||
edition = "2021" | ||
license = "MIT" | ||
homepage = "https://github.com/64bit/async-openai" | ||
repository = "https://github.com/64bit/async-openai" | ||
rust-version = { workspace = true } | ||
|
||
[lib] | ||
proc-macro = true | ||
|
||
[dependencies] | ||
syn = { version = "2.0", features = ["full"] } | ||
quote = "1.0" | ||
proc-macro2 = "1.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
use proc_macro::TokenStream; | ||
use quote::{quote, ToTokens}; | ||
use syn::{ | ||
parse::{Parse, ParseStream}, | ||
parse_macro_input, | ||
punctuated::Punctuated, | ||
token::Comma, | ||
FnArg, GenericParam, Generics, ItemFn, Pat, PatType, TypeParam, WhereClause, | ||
}; | ||
|
||
// Parse attribute arguments like #[byot(T0: Display + Debug, T1: Clone, R: Serialize)] | ||
struct BoundArgs { | ||
bounds: Vec<(String, syn::TypeParamBound)>, | ||
where_clause: Option<String>, | ||
stream: bool, // Add stream flag | ||
} | ||
|
||
impl Parse for BoundArgs { | ||
fn parse(input: ParseStream) -> syn::Result<Self> { | ||
let mut bounds = Vec::new(); | ||
let mut where_clause = None; | ||
let mut stream = false; // Default to false | ||
let vars = Punctuated::<syn::MetaNameValue, Comma>::parse_terminated(input)?; | ||
|
||
for var in vars { | ||
let name = var.path.get_ident().unwrap().to_string(); | ||
match name.as_str() { | ||
"where_clause" => { | ||
where_clause = Some(var.value.into_token_stream().to_string()); | ||
} | ||
"stream" => { | ||
stream = var.value.into_token_stream().to_string().contains("true"); | ||
} | ||
_ => { | ||
let bound: syn::TypeParamBound = | ||
syn::parse_str(&var.value.into_token_stream().to_string())?; | ||
bounds.push((name, bound)); | ||
} | ||
} | ||
} | ||
Ok(BoundArgs { | ||
bounds, | ||
where_clause, | ||
stream, | ||
}) | ||
} | ||
} | ||
|
||
#[proc_macro_attribute] | ||
pub fn byot_passthrough(_args: TokenStream, item: TokenStream) -> TokenStream { | ||
item | ||
} | ||
|
||
#[proc_macro_attribute] | ||
pub fn byot(args: TokenStream, item: TokenStream) -> TokenStream { | ||
let bounds_args = parse_macro_input!(args as BoundArgs); | ||
let input = parse_macro_input!(item as ItemFn); | ||
let mut new_generics = Generics::default(); | ||
let mut param_count = 0; | ||
|
||
// Process function arguments | ||
let mut new_params = Vec::new(); | ||
let args = input | ||
.sig | ||
.inputs | ||
.iter() | ||
.map(|arg| { | ||
match arg { | ||
FnArg::Receiver(receiver) => receiver.to_token_stream(), | ||
FnArg::Typed(PatType { pat, .. }) => { | ||
if let Pat::Ident(pat_ident) = &**pat { | ||
let generic_name = format!("T{}", param_count); | ||
let generic_ident = | ||
syn::Ident::new(&generic_name, proc_macro2::Span::call_site()); | ||
|
||
// Create type parameter with optional bounds | ||
let mut type_param = TypeParam::from(generic_ident.clone()); | ||
if let Some((_, bound)) = bounds_args | ||
.bounds | ||
.iter() | ||
.find(|(name, _)| name == &generic_name) | ||
{ | ||
type_param.bounds.extend(vec![bound.clone()]); | ||
} | ||
|
||
new_params.push(GenericParam::Type(type_param)); | ||
param_count += 1; | ||
quote! { #pat_ident: #generic_ident } | ||
} else { | ||
arg.to_token_stream() | ||
} | ||
} | ||
} | ||
}) | ||
.collect::<Vec<_>>(); | ||
|
||
// Add R type parameter with optional bounds | ||
let generic_r = syn::Ident::new("R", proc_macro2::Span::call_site()); | ||
let mut return_type_param = TypeParam::from(generic_r.clone()); | ||
if let Some((_, bound)) = bounds_args.bounds.iter().find(|(name, _)| name == "R") { | ||
return_type_param.bounds.extend(vec![bound.clone()]); | ||
} | ||
new_params.push(GenericParam::Type(return_type_param)); | ||
|
||
// Add all generic parameters | ||
new_generics.params.extend(new_params); | ||
|
||
let fn_name = &input.sig.ident; | ||
let byot_fn_name = syn::Ident::new(&format!("{}_byot", fn_name), fn_name.span()); | ||
let vis = &input.vis; | ||
let block = &input.block; | ||
let attrs = &input.attrs; | ||
let asyncness = &input.sig.asyncness; | ||
|
||
// Parse where clause if provided | ||
let where_clause = if let Some(where_str) = bounds_args.where_clause { | ||
match syn::parse_str::<WhereClause>(&format!("where {}", where_str.replace("\"", ""))) { | ||
Ok(where_clause) => quote! { #where_clause }, | ||
Err(e) => return TokenStream::from(e.to_compile_error()), | ||
} | ||
} else { | ||
quote! {} | ||
}; | ||
|
||
// Generate return type based on stream flag | ||
let return_type = if bounds_args.stream { | ||
quote! { Result<::std::pin::Pin<Box<dyn ::futures::Stream<Item = Result<R, OpenAIError>> + Send>>, OpenAIError> } | ||
} else { | ||
quote! { Result<R, OpenAIError> } | ||
}; | ||
|
||
let expanded = quote! { | ||
#(#attrs)* | ||
#input | ||
|
||
#(#attrs)* | ||
#vis #asyncness fn #byot_fn_name #new_generics (#(#args),*) -> #return_type #where_clause #block | ||
}; | ||
|
||
expanded.into() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.