Skip to content

Commit d96faab

Browse files
luketpeterson64bit
andauthored
Adding Send bound to boxed future types (#37)
* Adding Send bound to boxed future types so that they will work in a multithreaded runtime * Update async-openai/tests/boxed_future.rs --------- Co-authored-by: Himanshu Neema <himanshun.iitkgp@gmail.com>
1 parent 132e1c5 commit d96faab

File tree

3 files changed

+49
-5
lines changed

3 files changed

+49
-5
lines changed

async-openai/src/client.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ impl Client {
279279
&self,
280280
path: &str,
281281
request: I,
282-
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>>>>
282+
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
283283
where
284284
I: Serialize,
285285
O: DeserializeOwned + std::marker::Send + 'static,
@@ -300,7 +300,7 @@ impl Client {
300300
&self,
301301
path: &str,
302302
query: &Q,
303-
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>>>>
303+
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
304304
where
305305
Q: Serialize + ?Sized,
306306
O: DeserializeOwned + std::marker::Send + 'static,
@@ -320,7 +320,7 @@ impl Client {
320320
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
321321
pub(crate) async fn stream<O>(
322322
mut event_source: EventSource,
323-
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>>>>
323+
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
324324
where
325325
O: DeserializeOwned + std::marker::Send + 'static,
326326
{

async-openai/src/types/types.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ pub struct CreateCompletionResponse {
169169

170170
/// Parsed server side events stream until an \[DONE\] is received from server.
171171
pub type CompletionResponseStream =
172-
Pin<Box<dyn Stream<Item = Result<CreateCompletionResponse, OpenAIError>>>>;
172+
Pin<Box<dyn Stream<Item = Result<CreateCompletionResponse, OpenAIError>> + Send>>;
173173

174174
#[derive(Debug, Clone, Serialize, Default, Builder)]
175175
#[builder(name = "CreateEditRequestArgs")]
@@ -616,7 +616,7 @@ pub struct ListFineTuneEventsResponse {
616616

617617
/// Parsed server side events stream until an \[DONE\] is received from server.
618618
pub type FineTuneEventsResponseStream =
619-
Pin<Box<dyn Stream<Item = Result<ListFineTuneEventsResponse, OpenAIError>>>>;
619+
Pin<Box<dyn Stream<Item = Result<ListFineTuneEventsResponse, OpenAIError>> + Send>>;
620620

621621
#[derive(Debug, Deserialize)]
622622
pub struct DeleteModelResponse {

async-openai/tests/boxed_future.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
2+
use futures::StreamExt;
3+
use futures::future::{BoxFuture, FutureExt};
4+
5+
use async_openai::Client;
6+
use async_openai::types::{CompletionResponseStream, CreateCompletionRequestArgs};
7+
8+
#[tokio::test]
9+
async fn boxed_future_test() {
10+
11+
fn interpret_bool(token_stream: &mut CompletionResponseStream) -> BoxFuture<'_, bool> {
12+
async move {
13+
while let Some(response) = token_stream.next().await {
14+
match response {
15+
Ok(response) => {
16+
let token_str = &response.choices[0].text.trim();
17+
if !token_str.is_empty() {
18+
return token_str.contains("yes") || token_str.contains("Yes");
19+
}
20+
},
21+
Err(e) => eprintln!("Error: {e}"),
22+
}
23+
}
24+
false
25+
}.boxed()
26+
}
27+
28+
let client = Client::new();
29+
30+
let request = CreateCompletionRequestArgs::default()
31+
.model("text-babbage-001")
32+
.n(1)
33+
.prompt("does 2 and 2 add to four? (yes/no):\n")
34+
.stream(true)
35+
.logprobs(3)
36+
.max_tokens(64_u16)
37+
.build().unwrap();
38+
39+
let mut stream = client.completions().create_stream(request).await.unwrap();
40+
41+
let result = interpret_bool(&mut stream).await;
42+
assert_eq!(result, true);
43+
44+
}

0 commit comments

Comments
 (0)