Skip to content

Commit 3353700

Browse files
authored
feat: Fine tune events stream; Most complete request-response Rust types (#28)
* Client with SSE stream functionality; use that in fine-tunes list events stream and in create completions stream * update readme * updated types: the most complete and correct-bug-free types yet * Updated examples with updated types * updated readme and src/lib.rs doc comment * updated comment * remove 'use async_openai as openai' from everywhere
1 parent dd2e624 commit 3353700

File tree

18 files changed

+240
-127
lines changed

18 files changed

+240
-127
lines changed

async-openai/README.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
<img src="https://docs.rs/async-openai/badge.svg" />
1414
</a>
1515
</div>
16+
<div align="center">
17+
<sub>Logo created by this <a href="https://github.com/64bit/async-openai/tree/main/examples/create-image-b64-json">repo itself</a></sub>
18+
</div>
1619

1720
## Overview
1821

@@ -24,7 +27,7 @@
2427
- [x] Edits
2528
- [x] Embeddings
2629
- [x] Files (List, Upload, Delete, Retrieve, Retrieve Content)
27-
- [x] Fine-Tuning
30+
- [x] Fine-Tuning (including SSE streaming Fine-tuning events)
2831
- [x] Images (Generation, Edit, Variation)
2932
- [ ] Microsoft Azure Endpoints / AD Authentication
3033
- [x] Models
@@ -55,8 +58,7 @@ $Env:OPENAI_API_KEY='sk-...'
5558
```rust
5659
use std::error::Error;
5760

58-
use async_openai as openai;
59-
use openai::{
61+
use async_openai::{
6062
types::{CreateImageRequest, ImageSize, ResponseFormat},
6163
Client, Image,
6264
};

async-openai/src/client.rs

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
use std::pin::Pin;
2+
3+
use futures::{stream::StreamExt, Stream};
14
use reqwest::header::HeaderMap;
5+
use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
26
use serde::{de::DeserializeOwned, Serialize};
37

48
use crate::error::{OpenAIError, WrappedError};
@@ -212,4 +216,92 @@ impl Client {
212216
}
213217
}
214218
}
219+
220+
/// Make HTTP POST request to receive SSE
221+
pub(crate) async fn post_stream<I, O>(
222+
&self,
223+
path: &str,
224+
request: I,
225+
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>>>>
226+
where
227+
I: Serialize,
228+
O: DeserializeOwned + std::marker::Send + 'static,
229+
{
230+
let event_source = reqwest::Client::new()
231+
.post(format!("{}{path}", self.api_base()))
232+
.headers(self.headers())
233+
.bearer_auth(self.api_key())
234+
.json(&request)
235+
.eventsource()
236+
.unwrap();
237+
238+
Client::stream(event_source).await
239+
}
240+
241+
/// Make HTTP GET request to receive SSE
242+
pub(crate) async fn get_stream<Q, O>(
243+
&self,
244+
path: &str,
245+
query: &Q,
246+
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>>>>
247+
where
248+
Q: Serialize + ?Sized,
249+
O: DeserializeOwned + std::marker::Send + 'static,
250+
{
251+
let event_source = reqwest::Client::new()
252+
.get(format!("{}{path}", self.api_base()))
253+
.query(query)
254+
.headers(self.headers())
255+
.bearer_auth(self.api_key())
256+
.eventsource()
257+
.unwrap();
258+
259+
Client::stream(event_source).await
260+
}
261+
262+
/// Request which responds with SSE.
263+
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
264+
pub(crate) async fn stream<O>(
265+
mut event_source: EventSource,
266+
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>>>>
267+
where
268+
O: DeserializeOwned + std::marker::Send + 'static,
269+
{
270+
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
271+
272+
tokio::spawn(async move {
273+
while let Some(ev) = event_source.next().await {
274+
match ev {
275+
Err(e) => {
276+
if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
277+
// rx dropped
278+
break;
279+
}
280+
}
281+
Ok(event) => match event {
282+
Event::Message(message) => {
283+
if message.data == "[DONE]" {
284+
break;
285+
}
286+
287+
let response = match serde_json::from_str::<O>(&message.data) {
288+
Err(e) => Err(OpenAIError::JSONDeserialize(e)),
289+
Ok(output) => Ok(output),
290+
};
291+
292+
if let Err(_e) = tx.send(response) {
293+
// rx dropped
294+
break;
295+
}
296+
}
297+
Event::Open => continue,
298+
},
299+
}
300+
}
301+
302+
event_source.close();
303+
});
304+
305+
Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
306+
}
215307
}

async-openai/src/completion.rs

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
use futures::stream::StreamExt;
2-
use reqwest_eventsource::{Event, RequestBuilderExt};
3-
41
use crate::{
52
client::Client,
63
error::OpenAIError,
@@ -45,53 +42,6 @@ impl Completion {
4542

4643
request.stream = Some(true);
4744

48-
let mut event_source = reqwest::Client::new()
49-
.post(format!("{}/completions", client.api_base()))
50-
.bearer_auth(client.api_key())
51-
.json(&request)
52-
.eventsource()
53-
.unwrap();
54-
55-
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
56-
57-
tokio::spawn(async move {
58-
while let Some(ev) = event_source.next().await {
59-
match ev {
60-
Err(e) => {
61-
if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
62-
// rx dropped
63-
break;
64-
}
65-
}
66-
Ok(event) => match event {
67-
Event::Message(message) => {
68-
if message.data == "[DONE]" {
69-
break;
70-
}
71-
72-
let response = match serde_json::from_str::<CreateCompletionResponse>(
73-
&message.data,
74-
) {
75-
Err(e) => Err(OpenAIError::JSONDeserialize(e)),
76-
Ok(ccr) => Ok(ccr),
77-
};
78-
79-
if let Err(_e) = tx.send(response) {
80-
// rx dropped
81-
break;
82-
}
83-
}
84-
Event::Open => continue,
85-
},
86-
}
87-
}
88-
89-
event_source.close();
90-
});
91-
92-
Ok(
93-
Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
94-
as CompletionResponseStream,
95-
)
45+
Ok(client.post_stream("/completions", request).await)
9646
}
9747
}

async-openai/src/embedding.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ mod tests {
7979
input: crate::types::EmbeddingInput::ArrayOfIntegerArray(vec![
8080
vec![1, 2, 3],
8181
vec![4, 5, 6],
82-
vec![7, 8, 9],
82+
vec![7, 8, 100257],
8383
]),
8484
..Default::default()
8585
};

async-openai/src/fine_tune.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use crate::{
22
error::OpenAIError,
3-
types::{CreateFineTuneRequest, FineTune, ListFineTuneEventsResponse, ListFineTuneResponse},
3+
types::{
4+
CreateFineTuneRequest, FineTune, FineTuneEventsResponseStream, ListFineTuneEventsResponse,
5+
ListFineTuneResponse,
6+
},
47
Client,
58
};
69

@@ -52,4 +55,20 @@ impl FineTunes {
5255
.get(format!("/fine-tunes/{fine_tune_id}/events").as_str())
5356
.await
5457
}
58+
59+
/// Get fine-grained status updates for a fine-tune job.
60+
///
61+
/// Stream fine tuning events. [FineTuneEventsResponseStream] is a parsed SSE
62+
/// stream until a \[DONE\] is received from server.
63+
pub async fn list_events_stream(
64+
client: &Client,
65+
fine_tune_id: &str,
66+
) -> Result<FineTuneEventsResponseStream, OpenAIError> {
67+
Ok(client
68+
.get_stream(
69+
format!("/fine-tunes/{fine_tune_id}/events").as_str(),
70+
&[("stream", true)],
71+
)
72+
.await)
73+
}
5574
}

async-openai/src/lib.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,32 +3,32 @@
33
//! ## Creating client
44
//!
55
//! ```
6-
//! use async_openai as openai;
6+
//! use async_openai::Client;
77
//!
88
//! // Create a client with api key from env var OPENAI_API_KEY and default base url.
9-
//! let client = openai::Client::new();
9+
//! let client = Client::new();
1010
//!
1111
//! // OR use API key from different source
1212
//! let api_key = "sk-..."; // This secret could be from a file, or environment variable.
13-
//! let client = openai::Client::new().with_api_key(api_key);
13+
//! let client = Client::new().with_api_key(api_key);
1414
//!
15-
//! // Use organization different then default when making requests
16-
//! let client = openai::Client::new().with_org_id("the-org");
15+
//! // Use organization other than default when making requests
16+
//! let client = Client::new().with_org_id("the-org");
1717
//! ```
1818
//!
1919
//! ## Making requests
2020
//!
2121
//!```
2222
//!# tokio_test::block_on(async {
23-
//! use async_openai as openai;
24-
//! use openai::{Client, Completion, types::{CreateCompletionRequest}};
23+
//!
24+
//! use async_openai::{Client, Completion, types::{CreateCompletionRequest, Prompt}};
2525
//!
2626
//! // Create client
2727
//! let client = Client::new();
2828
//! // Create request
2929
//! let request = CreateCompletionRequest {
3030
//! model: "text-davinci-003".to_owned(),
31-
//! prompt: Some("Tell me a joke about the universe".to_owned()),
31+
//! prompt: Some(Prompt::String("Tell me the recipe of alfredo pasta".to_owned())),
3232
//! ..Default::default()
3333
//! };
3434
//! // Call API

0 commit comments

Comments
 (0)