
Commit b326ca7

Support for ChatGPT API (#43)
* Add support for Chat requests
* make Message buildable
* Ensure proper casting of MessageRole
1 parent 4bec9f9 commit b326ca7

7 files changed: 192 additions & 14 deletions

async-openai/README.md

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@
 
 - It's based on [OpenAI OpenAPI spec](https://github.com/openai/openai-openapi)
 - Current features:
-  - [x] Completions (including SSE streaming)
+  - [x] Completions (including SSE streaming & Chat)
   - [x] Edits
   - [x] Embeddings
   - [x] Files
@@ -35,7 +35,7 @@
 - Non-streaming requests are retried with exponential backoff when [rate limited](https://platform.openai.com/docs/guides/rate-limits) by the API server.
 - Ergonomic Rust library with builder pattern for all request objects.
 
-*Being a young project there could be rough edges.*
+_Being a young project there could be rough edges._
 
 ## Usage
 

async-openai/src/chat.rs

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+use crate::{
+    client::Client,
+    error::OpenAIError,
+    types::{ChatResponseStream, CreateChatRequest, CreateChatResponse},
+};
+
+/// Given a series of messages, the model will return one or more predicted
+/// completion messages.
+pub struct Chat<'c> {
+    client: &'c Client,
+}
+
+impl<'c> Chat<'c> {
+    pub fn new(client: &'c Client) -> Self {
+        Self { client }
+    }
+
+    /// Creates a completion for the provided messages and parameters
+    pub async fn create(
+        &self,
+        request: CreateChatRequest,
+    ) -> Result<CreateChatResponse, OpenAIError> {
+        if request.stream.is_some() && request.stream.unwrap() {
+            return Err(OpenAIError::InvalidArgument(
+                "When stream is true, use Chat::create_stream".into(),
+            ));
+        }
+        self.client.post("/chat/completions", request).await
+    }
+
+    /// Creates a completion request for the provided messages and parameters
+    ///
+    /// Stream back partial progress. Tokens will be sent as data-only
+    /// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
+    /// as they become available, with the stream terminated by a data: \[DONE\] message.
+    ///
+    /// [ChatResponseStream] is a parsed SSE stream until a \[DONE\] is received from server.
+    pub async fn create_stream(
+        &self,
+        mut request: CreateChatRequest,
+    ) -> Result<ChatResponseStream, OpenAIError> {
+        if request.stream.is_some() && !request.stream.unwrap() {
+            return Err(OpenAIError::InvalidArgument(
+                "When stream is false, use Chat::create".into(),
+            ));
+        }
+
+        request.stream = Some(true);
+
+        Ok(self.client.post_stream("/chat/completions", request).await)
+    }
+}
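
For orientation, here is a minimal usage sketch of the new endpoint (not part of this commit): it builds a CreateChatRequest with the builders added in types/types.rs and calls Chat::create through the new Client::chat() accessor. The #[tokio::main] wrapper and the "gpt-3.5-turbo" model id are assumptions for illustration; note that Chat::create rejects requests with stream set to true, which is what routes streaming callers to create_stream.

use async_openai::types::{CreateChatRequestArgs, MessageArgs, MessageRole};
use async_openai::Client;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Client::new() picks up the API key from the environment.
    let client = Client::new();

    // Build the request with the builders introduced in this commit.
    let request = CreateChatRequestArgs::default()
        .model("gpt-3.5-turbo") // assumed chat-capable model id
        .messages(vec![
            MessageArgs::default()
                .role(MessageRole::System)
                .content("You are a helpful assistant.".to_string())
                .build()?,
            MessageArgs::default()
                .role(MessageRole::User)
                .content("Say hello in one short sentence.".to_string())
                .build()?,
        ])
        .build()?;

    // Non-streaming call; `stream` is left unset, so the guard in Chat::create passes.
    let response = client.chat().create(request).await?;

    for choice in response.choices {
        println!("{}", choice.message.content);
    }

    Ok(())
}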

async-openai/src/client.rs

Lines changed: 6 additions & 0 deletions
@@ -6,6 +6,7 @@ use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
 use serde::{de::DeserializeOwned, Serialize};
 
 use crate::{
+    chat::Chat,
     edit::Edits,
     error::{OpenAIError, WrappedError},
     file::Files,
@@ -91,6 +92,11 @@ impl Client {
         Completions::new(self)
     }
 
+    /// To call [Chat] group related APIs using this client.
+    pub fn chat(&self) -> Chat {
+        Chat::new(self)
+    }
+
     /// To call [Edits] group related APIs using this client.
     pub fn edits(&self) -> Edits {
         Edits::new(self)

async-openai/src/file.rs

Lines changed: 1 addition & 1 deletion
@@ -104,6 +104,6 @@ mod tests {
         let delete_response = client.files().delete(&openai_file.id).await.unwrap();
 
         assert_eq!(openai_file.id, delete_response.id);
-        assert_eq!(delete_response.deleted, true);
+        assert!(delete_response.deleted);
     }
 }

async-openai/src/lib.rs

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@
 //! ## Examples
 //! For full working examples for all supported features see [examples](https://github.com/64bit/async-openai/tree/main/examples) directory in the repository.
 //!
+mod chat;
 mod client;
 mod completion;
 mod download;

async-openai/src/types/types.rs

Lines changed: 121 additions & 1 deletion
@@ -1,4 +1,9 @@
-use std::{collections::HashMap, path::PathBuf, pin::Pin};
+use std::{
+    collections::HashMap,
+    fmt::{Display, Formatter},
+    path::PathBuf,
+    pin::Pin,
+};
 
 use derive_builder::Builder;
 use futures::Stream;
@@ -134,6 +139,101 @@ pub struct CreateCompletionRequest {
     pub user: Option<String>,
 }
 
+#[derive(Clone, Serialize, Debug, Deserialize)]
+pub enum MessageRole {
+    #[serde(rename = "assistant")]
+    Assistant,
+    #[serde(rename = "system")]
+    System,
+    #[serde(rename = "user")]
+    User,
+}
+
+#[derive(Clone, Serialize, Deserialize, Debug, Builder)]
+#[builder(name = "MessageArgs")]
+#[builder(pattern = "mutable")]
+#[builder(derive(Debug))]
+#[builder(build_fn(error = "OpenAIError"))]
+pub struct Message {
+    pub role: MessageRole,
+    pub content: String,
+}
+
+#[derive(Clone, Serialize, Default, Debug, Builder)]
+#[builder(name = "CreateChatRequestArgs")]
+#[builder(pattern = "mutable")]
+#[builder(setter(into, strip_option), default)]
+#[builder(derive(Debug))]
+#[builder(build_fn(error = "OpenAIError"))]
+pub struct CreateChatRequest {
+    /// ID of the model to use. You can use the [List models](https://platform.openai.com/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](https://platform.openai.com/docs/models/overview) for descriptions of them.
+    pub model: String,
+
+    /// The message(s) to generate a response to, encoded as an array of the message type.
+    ///
+    /// Note that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub messages: Option<Vec<Message>>,
+
+    /// What [sampling temperature](https://towardsdatascience.com/how-to-sample-from-language-models-682bceb97277) to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
+    ///
+    /// We generally recommend altering this or `top_p` but not both.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub temperature: Option<f32>, // min: 0, max: 2, default: 1,
+
+    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
+    ///
+    /// We generally recommend altering this or `temperature` but not both.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_p: Option<f32>, // min: 0, max: 1, default: 1
+
+    /// How many completions to generate for each prompt.
+
+    /// **Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.
+    ///
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub n: Option<u8>, // min:1 max: 128, default: 1
+
+    /// Whether to stream back partial progress. If set, tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
+    /// as they become available, with the stream terminated by a `data: [DONE]` message.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stream: Option<bool>, // nullable: true
+
+    /// Include the log probabilities on the `logprobs` most likely tokens, as well the chosen tokens. For example, if `logprobs` is 5, the API will return a list of the 5 most likely tokens. The API will always return the `logprob` of the sampled token, so there may be up to `logprobs+1` elements in the response.
+
+    /// The maximum value for `logprobs` is 5. If you need more than this, please contact us through our [Help center](https://help.openai.com) and describe your use case.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub logprobs: Option<u8>, // min:0 , max: 5, default: null, nullable: true
+
+    /// Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stop: Option<Stop>,
+
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
+    ///
+    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub presence_penalty: Option<f32>, // min: -2.0, max: 2.0, default 0
+
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
+    ///
+    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub frequency_penalty: Option<f32>, // min: -2.0, max: 2.0, default: 0
+
+    /// Modify the likelihood of specified tokens appearing in the completion.
+    ///
+    /// Accepts a json object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. You can use this [tokenizer tool](/tokenizer?view=bpe) (which works for both GPT-2 and GPT-3) to convert text to token IDs. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
+    ///
+    /// As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token from being generated.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub logit_bias: Option<HashMap<String, serde_json::Value>>, // default: null
+
+    /// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/usage-policies/end-user-ids).
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub user: Option<String>,
+}
+
 #[derive(Debug, Deserialize)]
 pub struct Logprobs {
     pub tokens: Vec<String>,
@@ -150,6 +250,13 @@ pub struct Choice {
     pub finish_reason: Option<String>,
 }
 
+#[derive(Debug, Deserialize)]
+pub struct ChatChoice {
+    pub message: Message,
+    pub index: u32,
+    pub finish_reason: Option<String>,
+}
+
 #[derive(Debug, Deserialize)]
 pub struct Usage {
     pub prompt_tokens: u32,
@@ -167,10 +274,23 @@ pub struct CreateCompletionResponse {
     pub usage: Option<Usage>,
 }
 
+#[derive(Debug, Deserialize)]
+pub struct CreateChatResponse {
+    pub id: String,
+    pub object: String,
+    pub created: u32,
+    pub choices: Vec<ChatChoice>,
+    pub usage: Option<Usage>,
+}
+
 /// Parsed server side events stream until an \[DONE\] is received from server.
 pub type CompletionResponseStream =
     Pin<Box<dyn Stream<Item = Result<CreateCompletionResponse, OpenAIError>> + Send>>;
 
+/// Parsed server side events stream until an \[DONE\] is received from server.
+pub type ChatResponseStream =
+    Pin<Box<dyn Stream<Item = Result<CreateChatResponse, OpenAIError>> + Send>>;
+
 #[derive(Debug, Clone, Serialize, Default, Builder)]
 #[builder(name = "CreateEditRequestArgs")]
 #[builder(pattern = "mutable")]
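
As a quick illustration of how the new types fit together (not part of the commit): MessageArgs and CreateChatRequestArgs are the derive_builder builders declared above, optional fields use strip_option setters, and every unset Option is dropped from the serialized JSON by the skip_serializing_if attributes. The model id is a placeholder, and serde_json is assumed as a direct dependency of the example.

use async_openai::types::{CreateChatRequestArgs, MessageArgs, MessageRole};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let request = CreateChatRequestArgs::default()
        .model("gpt-3.5-turbo") // placeholder model id
        .messages(vec![
            MessageArgs::default()
                .role(MessageRole::System)
                .content("You are a terse assistant.".to_string())
                .build()?,
            MessageArgs::default()
                .role(MessageRole::User)
                .content("Name a prime number.".to_string())
                .build()?,
        ])
        .temperature(0.2_f32) // stored as Some(0.2) via the strip_option setter
        .n(1_u8)
        .build()?;

    // Unset options (top_p, stop, logit_bias, ...) are omitted from the JSON,
    // and MessageRole serializes to "system"/"user"/"assistant" via #[serde(rename)].
    println!("{}", serde_json::to_string_pretty(&request)?);
    Ok(())
}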

async-openai/tests/boxed_future.rs

Lines changed: 9 additions & 10 deletions
@@ -1,13 +1,11 @@
-
-use futures::StreamExt;
 use futures::future::{BoxFuture, FutureExt};
+use futures::StreamExt;
 
-use async_openai::Client;
 use async_openai::types::{CompletionResponseStream, CreateCompletionRequestArgs};
+use async_openai::Client;
 
 #[tokio::test]
 async fn boxed_future_test() {
-
     fn interpret_bool(token_stream: &mut CompletionResponseStream) -> BoxFuture<'_, bool> {
         async move {
             while let Some(response) = token_stream.next().await {
@@ -17,12 +15,13 @@ async fn boxed_future_test() {
                         if !token_str.is_empty() {
                             return token_str.contains("yes") || token_str.contains("Yes");
                         }
-                    },
+                    }
                     Err(e) => eprintln!("Error: {e}"),
                 }
            }
            false
-        }.boxed()
+        }
+        .boxed()
     }
 
     let client = Client::new();
@@ -34,11 +33,11 @@
         .stream(true)
         .logprobs(3)
         .max_tokens(64_u16)
-        .build().unwrap();
+        .build()
+        .unwrap();
 
     let mut stream = client.completions().create_stream(request).await.unwrap();
 
     let result = interpret_bool(&mut stream).await;
-    assert_eq!(result, true);
-
-}
+    assert!(result);
+}
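
The test above drives a CompletionResponseStream; the ChatResponseStream added in this commit can be consumed the same way with futures::StreamExt. A minimal sketch (not part of the commit), with the model id assumed and each stream item typed, per this commit, as a full CreateChatResponse:

use async_openai::types::{CreateChatRequestArgs, MessageArgs, MessageRole};
use async_openai::Client;
use futures::StreamExt;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = Client::new();

    let request = CreateChatRequestArgs::default()
        .model("gpt-3.5-turbo") // assumed model id
        .messages(vec![MessageArgs::default()
            .role(MessageRole::User)
            .content("Count from 1 to 5.".to_string())
            .build()?])
        .build()?;

    // create_stream forces `stream = Some(true)` and returns a ChatResponseStream:
    // a pinned Stream of Result<CreateChatResponse, OpenAIError> parsed from SSE.
    let mut stream = client.chat().create_stream(request).await?;

    while let Some(item) = stream.next().await {
        match item {
            Ok(response) => {
                for choice in response.choices {
                    print!("{}", choice.message.content);
                }
            }
            Err(e) => eprintln!("stream error: {e}"),
        }
    }
    println!();

    Ok(())
}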
