Skip to content

Commit 807e468

Browse files
authored
feat: Ergonomics: builder pattern for API groups and request objects (#29)
* update dependencies * make Client cloneable * Add derive_builder dependency * Ergonomics: Completions API is fully supported by builder pattern * Update all Completions example to use ergonomic builder pattern * Ergonomics: Edits API is fully supported by builder pattern * Update Edits example to use ergonomic builder pattern * Ergonomics: Images group API is fully supported by builder pattern * Update Images example to use ergonomic builder pattern * Ergonomics: Moderations group API is fully supported by builder pattern * Update Moderations example to use ergonomic builder pattern * Ergonomics: Files group API is fully supported by builder pattern * Ergonomics: FineTunes group API is fully supported by builder pattern * Ergonomics: Embeddings group API is fully supported by builder pattern * Update Completions Prompt macro test * Ergonomics: Models group API is fully supported by builder pattern * Update Models example to use ergonomic builder pattern * update readme * updated file test
1 parent 6f624ad commit 807e468

File tree

42 files changed

+975
-498
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

42 files changed

+975
-498
lines changed

async-openai/Cargo.toml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,19 @@ repository = "https://github.com/64bit/async-openai"
1717

1818
[dependencies]
1919
backoff = {version = "0.4.0", features = ["tokio"] }
20-
base64 = "0.13.1"
20+
base64 = "0.20.0"
2121
futures = "0.3.25"
2222
rand = "0.8.5"
2323
reqwest = { version = "0.11.13", features = ["json", "stream", "multipart"] }
2424
reqwest-eventsource = "0.4.0"
25-
serde = { version = "1.0.148", features = ["derive", "rc"] }
26-
serde_json = "1.0.87"
27-
thiserror = "1.0.37"
28-
tokio = { version = "1.22.0", features = ["fs", "macros"] }
25+
serde = { version = "1.0.152", features = ["derive", "rc"] }
26+
serde_json = "1.0.91"
27+
thiserror = "1.0.38"
28+
tokio = { version = "1.23.0", features = ["fs", "macros"] }
2929
tokio-stream = "0.1.11"
3030
tokio-util = { version = "0.7.4", features = ["codec", "io-util"] }
3131
tracing = "0.1.37"
32+
derive_builder = "0.12.0"
3233

3334
[dev-dependencies]
3435
tokio-test = "0.4.2"

async-openai/README.md

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
- [x] Models
3434
- [x] Moderations
3535
- Non-streaming requests are retried with exponential backoff when [rate limited](https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits) by the API server.
36+
- Ergonomic Rust library with builder pattern for all request objects.
3637

3738
*Being a young project there are rough edges*
3839

@@ -56,31 +57,30 @@ $Env:OPENAI_API_KEY='sk-...'
5657
## Image Generation Example
5758

5859
```rust
59-
use std::error::Error;
60-
6160
use async_openai::{
62-
types::{CreateImageRequest, ImageSize, ResponseFormat},
63-
Client, Image,
61+
types::{CreateImageRequestArgs, ImageSize, ResponseFormat},
62+
Client,
6463
};
64+
use std::error::Error;
6565

6666
#[tokio::main]
6767
async fn main() -> Result<(), Box<dyn Error>> {
6868
// create client, reads OPENAI_API_KEY environment variable for API key.
6969
let client = Client::new();
7070

71-
let request = CreateImageRequest {
72-
prompt: "cats on sofa and carpet in living room".to_owned(),
73-
n: Some(2),
74-
response_format: Some(ResponseFormat::Url),
75-
size: Some(ImageSize::S256x256),
76-
user: Some("async-openai".to_owned()),
77-
};
71+
let request = CreateImageRequestArgs::default()
72+
.prompt("cats on sofa and carpet in living room")
73+
.n(2)
74+
.response_format(ResponseFormat::Url)
75+
.size(ImageSize::S256x256)
76+
.user("async-openai")
77+
.build()?;
7878

79-
let response = Image::create(&client, request).await?;
79+
let response = client.images().create(request).await?;
8080

81-
// Download and save images to ./data directory
82-
// Each url download and save happens in dedicated Tokio task
83-
// (creates directory when it doesn't exist)
81+
// Download and save images to ./data directory.
82+
// Each url is downloaded and saved in dedicated Tokio task.
83+
// Directory is created if it doesn't exist.
8484
response.save("./data").await?;
8585

8686
Ok(())

async-openai/src/client.rs

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,16 @@ use reqwest::header::HeaderMap;
55
use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
66
use serde::{de::DeserializeOwned, Serialize};
77

8-
use crate::error::{OpenAIError, WrappedError};
9-
10-
#[derive(Debug, Default)]
8+
use crate::{
9+
edit::Edits,
10+
error::{OpenAIError, WrappedError},
11+
file::Files,
12+
image::Images,
13+
moderation::Moderations,
14+
Completions, Embeddings, FineTunes, Models,
15+
};
16+
17+
#[derive(Debug, Default, Clone)]
1118
/// Client is a container for api key, base url, organization id, and backoff
1219
/// configuration used to make API calls.
1320
pub struct Client {
@@ -64,6 +71,48 @@ impl Client {
6471
&self.api_key
6572
}
6673

74+
// API groups
75+
76+
/// To call [Models] group related APIs using this client.
77+
pub fn models(&self) -> Models {
78+
Models::new(self)
79+
}
80+
81+
/// To call [Completions] group related APIs using this client.
82+
pub fn completions(&self) -> Completions {
83+
Completions::new(self)
84+
}
85+
86+
/// To call [Edits] group related APIs using this client.
87+
pub fn edits(&self) -> Edits {
88+
Edits::new(self)
89+
}
90+
91+
/// To call [Images] group related APIs using this client.
92+
pub fn images(&self) -> Images {
93+
Images::new(self)
94+
}
95+
96+
/// To call [Moderations] group related APIs using this client.
97+
pub fn moderations(&self) -> Moderations {
98+
Moderations::new(self)
99+
}
100+
101+
/// To call [Files] group related APIs using this client.
102+
pub fn files(&self) -> Files {
103+
Files::new(self)
104+
}
105+
106+
/// To call [FineTunes] group related APIs using this client.
107+
pub fn fine_tunes(&self) -> FineTunes {
108+
FineTunes::new(self)
109+
}
110+
111+
/// To call [Embeddings] group related APIs using this client.
112+
pub fn embeddings(&self) -> Embeddings {
113+
Embeddings::new(self)
114+
}
115+
67116
fn headers(&self) -> HeaderMap {
68117
let mut headers = HeaderMap::new();
69118
if !self.org_id.is_empty() {

async-openai/src/completion.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,26 @@ use crate::{
77
/// Given a prompt, the model will return one or more predicted
88
/// completions, and can also return the probabilities of alternative
99
/// tokens at each position.
10-
pub struct Completion;
10+
pub struct Completions<'c> {
11+
client: &'c Client,
12+
}
13+
14+
impl<'c> Completions<'c> {
15+
pub fn new(client: &'c Client) -> Self {
16+
Self { client }
17+
}
1118

12-
impl Completion {
1319
/// Creates a completion for the provided prompt and parameters
1420
pub async fn create(
15-
client: &Client,
21+
&self,
1622
request: CreateCompletionRequest,
1723
) -> Result<CreateCompletionResponse, OpenAIError> {
1824
if request.stream.is_some() && request.stream.unwrap() {
1925
return Err(OpenAIError::InvalidArgument(
2026
"When stream is true, use Completion::create_stream".into(),
2127
));
2228
}
23-
client.post("/completions", request).await
29+
self.client.post("/completions", request).await
2430
}
2531

2632
/// Creates a completion request for the provided prompt and parameters
@@ -31,7 +37,7 @@ impl Completion {
3137
///
3238
/// [CompletionResponseStream] is a parsed SSE stream until a \[DONE\] is received from server.
3339
pub async fn create_stream(
34-
client: &Client,
40+
&self,
3541
mut request: CreateCompletionRequest,
3642
) -> Result<CompletionResponseStream, OpenAIError> {
3743
if request.stream.is_some() && !request.stream.unwrap() {
@@ -42,6 +48,6 @@ impl Completion {
4248

4349
request.stream = Some(true);
4450

45-
Ok(client.post_stream("/completions", request).await)
51+
Ok(self.client.post_stream("/completions", request).await)
4652
}
4753
}

async-openai/src/edit.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,20 @@ use crate::{
66

77
/// Given a prompt and an instruction, the model will return
88
/// an edited version of the prompt.
9-
pub struct Edit;
9+
pub struct Edits<'c> {
10+
client: &'c Client,
11+
}
12+
13+
impl<'c> Edits<'c> {
14+
pub fn new(client: &'c Client) -> Self {
15+
Self { client }
16+
}
1017

11-
impl Edit {
1218
/// Creates a new edit for the provided input, instruction, and parameters
1319
pub async fn create(
14-
client: &Client,
20+
&self,
1521
request: CreateEditRequest,
1622
) -> Result<CreateEditResponse, OpenAIError> {
17-
client.post("/edits", request).await
23+
self.client.post("/edits", request).await
1824
}
1925
}

async-openai/src/embedding.rs

Lines changed: 62 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -8,84 +8,100 @@ use crate::{
88
/// consumed by machine learning models and algorithms.
99
///
1010
/// Related guide: [Embeddings](https://beta.openai.com/docs/guides/embeddings/what-are-embeddings)
11-
pub struct Embeddings;
11+
pub struct Embeddings<'c> {
12+
client: &'c Client,
13+
}
14+
15+
impl<'c> Embeddings<'c> {
16+
pub fn new(client: &'c Client) -> Self {
17+
Self { client }
18+
}
1219

13-
impl Embeddings {
1420
/// Creates an embedding vector representing the input text.
1521
pub async fn create(
16-
client: &Client,
22+
&self,
1723
request: CreateEmbeddingRequest,
1824
) -> Result<CreateEmbeddingResponse, OpenAIError> {
19-
client.post("/embeddings", request).await
25+
self.client.post("/embeddings", request).await
2026
}
2127
}
2228

2329
#[cfg(test)]
2430
mod tests {
25-
use crate::{types::CreateEmbeddingRequest, Client, Embeddings};
31+
use crate::{types::CreateEmbeddingRequestArgs, Client};
2632

2733
#[tokio::test]
2834
async fn test_embedding_string() {
2935
let client = Client::new();
30-
let request = CreateEmbeddingRequest {
31-
model: "text-embedding-ada-002".to_owned(),
32-
input: crate::types::EmbeddingInput::String(
33-
"The food was delicious and the waiter...".to_owned(),
34-
),
35-
..Default::default()
36-
};
3736

38-
let response = Embeddings::create(&client, request).await;
37+
let request = CreateEmbeddingRequestArgs::default()
38+
.model("text-embedding-ada-002")
39+
.input("The food was delicious and the waiter...")
40+
.build()
41+
.unwrap();
3942

40-
println!("{:#?}", response);
43+
let response = client.embeddings().create(request).await;
44+
45+
assert!(response.is_ok());
4146
}
4247

4348
#[tokio::test]
4449
async fn test_embedding_string_array() {
4550
let client = Client::new();
46-
let request = CreateEmbeddingRequest {
47-
model: "text-embedding-ada-002".to_owned(),
48-
input: crate::types::EmbeddingInput::StringArray(vec![
49-
"The food was delicious".to_owned(),
50-
"The waiter was good".to_owned(),
51-
]),
52-
..Default::default()
53-
};
54-
55-
let response = Embeddings::create(&client, request).await;
56-
57-
println!("{:#?}", response);
51+
52+
let request = CreateEmbeddingRequestArgs::default()
53+
.model("text-embedding-ada-002")
54+
.input(["The food was delicious", "The waiter was good"])
55+
.build()
56+
.unwrap();
57+
58+
let response = client.embeddings().create(request).await;
59+
60+
assert!(response.is_ok());
5861
}
5962

6063
#[tokio::test]
6164
async fn test_embedding_integer_array() {
6265
let client = Client::new();
63-
let request = CreateEmbeddingRequest {
64-
model: "text-embedding-ada-002".to_owned(),
65-
input: crate::types::EmbeddingInput::IntegerArray(vec![1, 2, 3]),
66-
..Default::default()
67-
};
6866

69-
let response = Embeddings::create(&client, request).await;
67+
let request = CreateEmbeddingRequestArgs::default()
68+
.model("text-embedding-ada-002")
69+
.input([1, 2, 3])
70+
.build()
71+
.unwrap();
72+
73+
let response = client.embeddings().create(request).await;
74+
75+
assert!(response.is_ok());
76+
}
77+
78+
#[tokio::test]
79+
async fn test_embedding_array_of_integer_array_matrix() {
80+
let client = Client::new();
81+
82+
let request = CreateEmbeddingRequestArgs::default()
83+
.model("text-embedding-ada-002")
84+
.input([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
85+
.build()
86+
.unwrap();
7087

71-
println!("{:#?}", response);
88+
let response = client.embeddings().create(request).await;
89+
90+
assert!(response.is_ok());
7291
}
7392

7493
#[tokio::test]
7594
async fn test_embedding_array_of_integer_array() {
7695
let client = Client::new();
77-
let request = CreateEmbeddingRequest {
78-
model: "text-embedding-ada-002".to_owned(),
79-
input: crate::types::EmbeddingInput::ArrayOfIntegerArray(vec![
80-
vec![1, 2, 3],
81-
vec![4, 5, 6],
82-
vec![7, 8, 100257],
83-
]),
84-
..Default::default()
85-
};
86-
87-
let response = Embeddings::create(&client, request).await;
88-
89-
println!("{:#?}", response);
96+
97+
let request = CreateEmbeddingRequestArgs::default()
98+
.model("text-embedding-ada-002")
99+
.input([vec![1, 2, 3], vec![4, 5, 6, 7], vec![7, 8, 10, 11, 100257]])
100+
.build()
101+
.unwrap();
102+
103+
let response = client.embeddings().create(request).await;
104+
105+
assert!(response.is_ok());
90106
}
91107
}

async-openai/src/error.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ pub enum OpenAIError {
2121
/// Error when trying to stream completions SSE
2222
#[error("stream failed: {0}")]
2323
StreamError(String),
24-
/// Error from client side validation before making API call
24+
/// Error from client side validation
25+
/// or when builder fails to build request before making API call
2526
#[error("invalid args: {0}")]
2627
InvalidArgument(String),
2728
}

0 commit comments

Comments
 (0)