Skip to content

Commit

Permalink
feat: Bring your own types (#342)
Browse files Browse the repository at this point in the history
* attribute proc macro to bring your own types

* keep original fn as it is add new with _byot suffix

* update macro

* update macro

* use macro in main crate + add test

* byot: assistants

* byot: vector_stores

* add where_clause attribute arg

* remove print

* byot: files

* byot: images

* add stream arg to attribute

* byot: chat

* byot: completions

* fix comment

* fix

* byot: audio

* byot: embeddings

* byot: Fine Tunning

* add byot tests

* byot: moderations

* byot tests: moderations

* byot: threads

* byot tests: threads

* byot: messages

* byot tests: messages

* byot: runs

* byot tests: runs

* byot: steps

* byot tests: run steps

* byot: vector store files

* byot test: vector store files

* byot: vector store file batches

* byot test: vector store file batches

* cargo fmt

* byot: batches

* byot tests: batches

* format

* remove AssistantFiles and related apis (/assistants/assistant_id/files/..)

* byot: audit logs

* byot tests: audit logs

* keep non byot code checks

* byot: invites

* byot tests: invites

* remove message files API

* byot: project api keys

* byot tests: project api keys

* byot: project service accounts

* byot tests: project service accounts

* byot: project users

* byot tests: project users

* byot: projects

* byot tests: projects

* byot: uploads

* byot tests: uploads

* byot: users

* byot tests: users

* add example to demonstrate bring-your-own-types

* update README

* update doc

* cargo fmt

* update doc in lib.rs

* tests passing

* fix for complier warning

* fix compiler #[allow(unused_mut)]

* cargo fix

* fix all warnings

* add Voices

* publish = false for all examples

* specify versions
  • Loading branch information
64bit authored Mar 3, 2025
1 parent c48e62e commit 638bf75
Show file tree
Hide file tree
Showing 49 changed files with 1,133 additions and 317 deletions.
7 changes: 5 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
[workspace]
members = [ "async-openai", "examples/*" ]
members = [ "async-openai", "async-openai-*", "examples/*" ]
# Only check / build main crates by default (check all with `--workspace`)
default-members = ["async-openai"]
default-members = ["async-openai", "async-openai-*"]
resolver = "2"

[workspace.package]
rust-version = "1.75"
19 changes: 19 additions & 0 deletions async-openai-macros/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[package]
name = "async-openai-macros"
version = "0.1.0"
authors = ["Himanshu Neema"]
keywords = ["openai", "macros", "ai"]
description = "Macros for async-openai"
edition = "2021"
license = "MIT"
homepage = "https://github.com/64bit/async-openai"
repository = "https://github.com/64bit/async-openai"
rust-version = { workspace = true }

[lib]
proc-macro = true

[dependencies]
syn = { version = "2.0", features = ["full"] }
quote = "1.0"
proc-macro2 = "1.0"
141 changes: 141 additions & 0 deletions async-openai-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
use proc_macro::TokenStream;
use quote::{quote, ToTokens};
use syn::{
parse::{Parse, ParseStream},
parse_macro_input,
punctuated::Punctuated,
token::Comma,
FnArg, GenericParam, Generics, ItemFn, Pat, PatType, TypeParam, WhereClause,
};

// Parse attribute arguments like #[byot(T0: Display + Debug, T1: Clone, R: Serialize)]
struct BoundArgs {
bounds: Vec<(String, syn::TypeParamBound)>,
where_clause: Option<String>,
stream: bool, // Add stream flag
}

impl Parse for BoundArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut bounds = Vec::new();
let mut where_clause = None;
let mut stream = false; // Default to false
let vars = Punctuated::<syn::MetaNameValue, Comma>::parse_terminated(input)?;

for var in vars {
let name = var.path.get_ident().unwrap().to_string();
match name.as_str() {
"where_clause" => {
where_clause = Some(var.value.into_token_stream().to_string());
}
"stream" => {
stream = var.value.into_token_stream().to_string().contains("true");
}
_ => {
let bound: syn::TypeParamBound =
syn::parse_str(&var.value.into_token_stream().to_string())?;
bounds.push((name, bound));
}
}
}
Ok(BoundArgs {
bounds,
where_clause,
stream,
})
}
}

#[proc_macro_attribute]
pub fn byot_passthrough(_args: TokenStream, item: TokenStream) -> TokenStream {
item
}

#[proc_macro_attribute]
pub fn byot(args: TokenStream, item: TokenStream) -> TokenStream {
let bounds_args = parse_macro_input!(args as BoundArgs);
let input = parse_macro_input!(item as ItemFn);
let mut new_generics = Generics::default();
let mut param_count = 0;

// Process function arguments
let mut new_params = Vec::new();
let args = input
.sig
.inputs
.iter()
.map(|arg| {
match arg {
FnArg::Receiver(receiver) => receiver.to_token_stream(),
FnArg::Typed(PatType { pat, .. }) => {
if let Pat::Ident(pat_ident) = &**pat {
let generic_name = format!("T{}", param_count);
let generic_ident =
syn::Ident::new(&generic_name, proc_macro2::Span::call_site());

// Create type parameter with optional bounds
let mut type_param = TypeParam::from(generic_ident.clone());
if let Some((_, bound)) = bounds_args
.bounds
.iter()
.find(|(name, _)| name == &generic_name)
{
type_param.bounds.extend(vec![bound.clone()]);
}

new_params.push(GenericParam::Type(type_param));
param_count += 1;
quote! { #pat_ident: #generic_ident }
} else {
arg.to_token_stream()
}
}
}
})
.collect::<Vec<_>>();

// Add R type parameter with optional bounds
let generic_r = syn::Ident::new("R", proc_macro2::Span::call_site());
let mut return_type_param = TypeParam::from(generic_r.clone());
if let Some((_, bound)) = bounds_args.bounds.iter().find(|(name, _)| name == "R") {
return_type_param.bounds.extend(vec![bound.clone()]);
}
new_params.push(GenericParam::Type(return_type_param));

// Add all generic parameters
new_generics.params.extend(new_params);

let fn_name = &input.sig.ident;
let byot_fn_name = syn::Ident::new(&format!("{}_byot", fn_name), fn_name.span());
let vis = &input.vis;
let block = &input.block;
let attrs = &input.attrs;
let asyncness = &input.sig.asyncness;

// Parse where clause if provided
let where_clause = if let Some(where_str) = bounds_args.where_clause {
match syn::parse_str::<WhereClause>(&format!("where {}", where_str.replace("\"", ""))) {
Ok(where_clause) => quote! { #where_clause },
Err(e) => return TokenStream::from(e.to_compile_error()),
}
} else {
quote! {}
};

// Generate return type based on stream flag
let return_type = if bounds_args.stream {
quote! { Result<::std::pin::Pin<Box<dyn ::futures::Stream<Item = Result<R, OpenAIError>> + Send>>, OpenAIError> }
} else {
quote! { Result<R, OpenAIError> }
};

let expanded = quote! {
#(#attrs)*
#input

#(#attrs)*
#vis #asyncness fn #byot_fn_name #new_generics (#(#args),*) -> #return_type #where_clause #block
};

expanded.into()
}
12 changes: 10 additions & 2 deletions async-openai/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
[package]
name = "async-openai"
version = "0.27.2"
version = "0.28.0"
authors = ["Himanshu Neema"]
categories = ["api-bindings", "web-programming", "asynchronous"]
keywords = ["openai", "async", "openapi", "ai"]
description = "Rust library for OpenAI"
edition = "2021"
rust-version = "1.75"
rust-version = { workspace = true }
license = "MIT"
readme = "README.md"
homepage = "https://github.com/64bit/async-openai"
Expand All @@ -23,8 +23,11 @@ native-tls = ["reqwest/native-tls"]
# Remove dependency on OpenSSL
native-tls-vendored = ["reqwest/native-tls-vendored"]
realtime = ["dep:tokio-tungstenite"]
# Bring your own types
byot = []

[dependencies]
async-openai-macros = { path = "../async-openai-macros", version = "0.1.0" }
backoff = { version = "0.4.0", features = ["tokio"] }
base64 = "0.22.1"
futures = "0.3.31"
Expand All @@ -50,6 +53,11 @@ tokio-tungstenite = { version = "0.26.1", optional = true, default-features = fa

[dev-dependencies]
tokio-test = "0.4.4"
serde_json = "1.0"

[[test]]
name = "bring-your-own-type"
required-features = ["byot"]

[package.metadata.docs.rs]
all-features = true
Expand Down
36 changes: 35 additions & 1 deletion async-openai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
- [x] Organizations | Administration (partially implemented)
- [x] Realtime (Beta) (partially implemented)
- [x] Uploads
- Bring your own custom types for Request or Response objects.
- SSE streaming on available APIs
- Requests (except SSE streaming) including form submissions are retried with exponential backoff when [rate limited](https://platform.openai.com/docs/guides/rate-limits).
- Ergonomic builder pattern for all request objects.
Expand All @@ -62,7 +63,7 @@ $Env:OPENAI_API_KEY='sk-...'
## Realtime API

Only types for Realtime API are implemented, and can be enabled with feature flag `realtime`.
These types may change if/when OpenAI releases official specs for them.
These types were written before OpenAI released official specs.

## Image Generation Example

Expand Down Expand Up @@ -108,6 +109,39 @@ async fn main() -> Result<(), Box<dyn Error>> {
<sub>Scaled up for README, actual size 256x256</sub>
</div>

## Bring Your Own Types

Enable methods whose input and outputs are generics with `byot` feature. It creates a new method with same name and `_byot` suffix.

For example, to use `serde_json::Value` as request and response type:
```rust
let response: Value = client
.chat()
.create_byot(json!({
"messages": [
{
"role": "developer",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": "What do you think about life?"
}
],
"model": "gpt-4o",
"store": false
}))
.await?;
```

This can be useful in many scenarios:
- To use this library with other OpenAI compatible APIs whose types don't exactly match OpenAI.
- Extend existing types in this crate with new fields with `serde`.
- To avoid verbose types.
- To escape deserialization errors.

Visit [examples/bring-your-own-type](https://github.com/64bit/async-openai/tree/main/examples/bring-your-own-type) directory to learn more.

## Contributing

Thank you for taking the time to contribute and improve the project. I'd be happy to have you!
Expand Down
66 changes: 0 additions & 66 deletions async-openai/src/assistant_files.rs

This file was deleted.

14 changes: 7 additions & 7 deletions async-openai/src/assistants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
AssistantObject, CreateAssistantRequest, DeleteAssistantResponse, ListAssistantsResponse,
ModifyAssistantRequest,
},
AssistantFiles, Client,
Client,
};

/// Build assistants that can call models and use tools to perform tasks.
Expand All @@ -22,12 +22,8 @@ impl<'c, C: Config> Assistants<'c, C> {
Self { client }
}

/// Assistant [AssistantFiles] API group
pub fn files(&self, assistant_id: &str) -> AssistantFiles<C> {
AssistantFiles::new(self.client, assistant_id)
}

/// Create an assistant with a model and instructions.
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn create(
&self,
request: CreateAssistantRequest,
Expand All @@ -36,13 +32,15 @@ impl<'c, C: Config> Assistants<'c, C> {
}

/// Retrieves an assistant.
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
pub async fn retrieve(&self, assistant_id: &str) -> Result<AssistantObject, OpenAIError> {
self.client
.get(&format!("/assistants/{assistant_id}"))
.await
}

/// Modifies an assistant.
#[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn update(
&self,
assistant_id: &str,
Expand All @@ -54,17 +52,19 @@ impl<'c, C: Config> Assistants<'c, C> {
}

/// Delete an assistant.
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
pub async fn delete(&self, assistant_id: &str) -> Result<DeleteAssistantResponse, OpenAIError> {
self.client
.delete(&format!("/assistants/{assistant_id}"))
.await
}

/// Returns a list of assistants.
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn list<Q>(&self, query: &Q) -> Result<ListAssistantsResponse, OpenAIError>
where
Q: Serialize + ?Sized,
{
self.client.get_with_query("/assistants", query).await
self.client.get_with_query("/assistants", &query).await
}
}
Loading

0 comments on commit 638bf75

Please sign in to comment.