Skip to content

Commit 323d430

Browse files
committed
新增 ChatMessageAccumulatorWrapper,以便于能够在使用chatCompletion时使用原始 chunk 数据。这一需求主要适用于chat API的转发场景。
1 parent 05f4673 commit 323d430

File tree

10 files changed

+109
-13
lines changed

10 files changed

+109
-13
lines changed

README-zh.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@ OpenAi4J是一个非官方的Java库,旨在帮助java开发者与OpenAI的GPT
2121
## 导入依赖
2222
### Gradle
2323

24-
`implementation 'io.github.lambdua:<api|client|service>:0.22.3'`
24+
`implementation 'io.github.lambdua:<api|client|service>:0.22.4'`
2525
### Maven
2626
```xml
2727

2828
<dependency>
2929
<groupId>io.github.lambdua</groupId>
3030
<artifactId>service</artifactId>
31-
<version>0.22.3</version>
31+
<version>0.22.4</version>
3232
</dependency>
3333
```
3434

@@ -61,7 +61,7 @@ static void simpleChat() {
6161
<dependency>
6262
<groupId>io.github.lambdua</groupId>
6363
<artifactId>api</artifactId>
64-
<version>0.22.3</version>
64+
<version>0.22.4</version>
6565
</dependency>
6666
```
6767

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ applications effortlessly.
2525
## Import
2626
### Gradle
2727

28-
`implementation 'io.github.lambdua:<api|client|service>:0.22.3'`
28+
`implementation 'io.github.lambdua:<api|client|service>:0.22.4'`
2929
### Maven
3030
```xml
3131

3232
<dependency>
3333
<groupId>io.github.lambdua</groupId>
3434
<artifactId>service</artifactId>
35-
<version>0.22.3</version>
35+
<version>0.22.4</version>
3636
</dependency>
3737
```
3838

@@ -67,7 +67,7 @@ To utilize pojos, import the api module:
6767
<dependency>
6868
<groupId>io.github.lambdua</groupId>
6969
<artifactId>api</artifactId>
70-
<version>0.22.3</version>
70+
<version>0.22.4</version>
7171
</dependency>
7272
```
7373

api/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
<parent>
77
<groupId>io.github.lambdua</groupId>
88
<artifactId>openai-java</artifactId>
9-
<version>0.22.3</version>
9+
<version>0.22.4</version>
1010
</parent>
1111
<packaging>jar</packaging>
1212
<artifactId>api</artifactId>

api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionChunk.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package com.theokanning.openai.completion.chat;
22

3+
import com.fasterxml.jackson.annotation.JsonIgnore;
34
import com.fasterxml.jackson.annotation.JsonProperty;
45
import com.theokanning.openai.Usage;
56
import lombok.Data;
@@ -51,4 +52,13 @@ public class ChatCompletionChunk {
5152
*/
5253
Usage usage;
5354

55+
/**
56+
* The original data packet returned by chat completion.
57+
* the value like this:
58+
* <pre>
59+
* data:{"id":"chatcmpl-A0QiHfuacgBSbvd8Ld1Por1HojY31","object":"chat.completion.chunk","created":1724666049,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}]}
60+
* </pre>
61+
*/
62+
@JsonIgnore
63+
String source;
5464
}

client/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
<parent>
77
<groupId>io.github.lambdua</groupId>
88
<artifactId>openai-java</artifactId>
9-
<version>0.22.3</version>
9+
<version>0.22.4</version>
1010
</parent>
1111
<packaging>jar</packaging>
1212

example/pom.xml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
<groupId>io.github.lambdua</groupId>
88
<artifactId>example</artifactId>
9-
<version>0.22.3</version>
9+
<version>0.22.4</version>
1010
<name>example</name>
1111

1212
<properties>
@@ -17,7 +17,7 @@
1717
<dependency>
1818
<groupId>io.github.lambdua</groupId>
1919
<artifactId>service</artifactId>
20-
<version>0.22.3</version>
20+
<version>0.22.4</version>
2121
</dependency>
2222

2323
</dependencies>

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
<groupId>io.github.lambdua</groupId>
77
<artifactId>openai-java</artifactId>
8-
<version>0.22.3</version>
8+
<version>0.22.4</version>
99
<packaging>pom</packaging>
1010
<description>openai java 版本</description>
1111
<url>https://github.com/Lambdua/openai-java</url>

service/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
<parent>
77
<groupId>io.github.lambdua</groupId>
88
<artifactId>openai-java</artifactId>
9-
<version>0.22.3</version>
9+
<version>0.22.4</version>
1010
</parent>
1111
<packaging>jar</packaging>
1212

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package com.theokanning.openai.service;
2+
3+
import com.theokanning.openai.completion.chat.ChatCompletionChunk;
4+
5+
/**
6+
* Wrapper class of ChatMessageAccumulator
7+
*
8+
* @author Allen Hu
9+
* @date 2024/10/18
10+
*/
11+
public class ChatMessageAccumulatorWrapper {
12+
13+
private final ChatMessageAccumulator chatMessageAccumulator;
14+
private final ChatCompletionChunk chatCompletionChunk;
15+
16+
public ChatMessageAccumulatorWrapper(ChatMessageAccumulator chatMessageAccumulator, ChatCompletionChunk chatCompletionChunk) {
17+
this.chatMessageAccumulator = chatMessageAccumulator;
18+
this.chatCompletionChunk = chatCompletionChunk;
19+
}
20+
21+
public ChatMessageAccumulator getChatMessageAccumulator() {
22+
return chatMessageAccumulator;
23+
}
24+
25+
public ChatCompletionChunk getChatCompletionChunk() {
26+
return chatCompletionChunk;
27+
}
28+
}

service/src/main/java/com/theokanning/openai/service/OpenAiService.java

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.theokanning.openai.service;
22

33
import com.fasterxml.jackson.annotation.JsonInclude;
4+
import com.fasterxml.jackson.core.JsonProcessingException;
45
import com.fasterxml.jackson.core.type.TypeReference;
56
import com.fasterxml.jackson.databind.DeserializationFeature;
67
import com.fasterxml.jackson.databind.ObjectMapper;
@@ -72,6 +73,8 @@
7273
import java.util.*;
7374
import java.util.concurrent.ExecutorService;
7475
import java.util.concurrent.TimeUnit;
76+
import java.util.function.BiConsumer;
77+
import java.util.function.Supplier;
7578

7679
public class OpenAiService {
7780

@@ -190,7 +193,17 @@ public ChatCompletionResult createChatCompletion(ChatCompletionRequest request)
190193

191194
public Flowable<ChatCompletionChunk> streamChatCompletion(ChatCompletionRequest request) {
192195
request.setStream(true);
193-
return stream(api.createChatCompletionStream(request), ChatCompletionChunk.class);
196+
return stream(api.createChatCompletionStream(request), ChatCompletionChunk.class, new BiConsumer<ChatCompletionChunk, SSE>() {
197+
@Override
198+
public void accept(ChatCompletionChunk chatCompletionChunk, SSE sse) {
199+
chatCompletionChunk.setSource(sse.getData());
200+
}
201+
}, new Supplier<ChatCompletionChunk>() {
202+
@Override
203+
public ChatCompletionChunk get() {
204+
return new ChatCompletionChunk();
205+
}
206+
});
194207
}
195208

196209

@@ -692,6 +705,31 @@ public static <T> Flowable<T> stream(Call<ResponseBody> apiCall, Class<T> cl) {
692705
return stream(apiCall).map(sse -> mapper.readValue(sse.getData(), cl));
693706
}
694707

708+
/**
709+
* Calls the Open AI api and returns a Flowable of type T for streaming
710+
* omitting the last message.
711+
* @param apiCall The api call
712+
* @param cl Class of type T to return
713+
* @param consumer After the instance creation is complete
714+
* @param newInstance If the serialization fails, call this interface to get an instance
715+
*/
716+
public static <T> Flowable<T> stream(Call<ResponseBody> apiCall, Class<T> cl, BiConsumer<T, SSE> consumer,
717+
Supplier<T> newInstance) {
718+
return stream(apiCall, true).map(sse -> {
719+
try {
720+
T t = mapper.readValue(sse.getData(), cl);
721+
if (Objects.nonNull(consumer)) {
722+
consumer.accept(t, sse);
723+
}
724+
return t;
725+
} catch (JsonProcessingException e) {
726+
T t = newInstance.get();
727+
consumer.accept(t, sse);
728+
return t;
729+
}
730+
});
731+
}
732+
695733
/**
696734
* Shuts down the OkHttp ExecutorService.
697735
* The default behaviour of OkHttp's ExecutorService (ConnectionPool)
@@ -758,6 +796,26 @@ public Flowable<ChatMessageAccumulator> mapStreamToAccumulator(Flowable<ChatComp
758796
});
759797
}
760798

799+
public Flowable<ChatMessageAccumulatorWrapper> mapStreamToAccumulatorWrapper(Flowable<ChatCompletionChunk> flowable) {
800+
ChatFunctionCall functionCall = new ChatFunctionCall(null, null);
801+
AssistantMessage accumulatedMessage = new AssistantMessage();
802+
return flowable.map(chunk -> {
803+
List<ChatCompletionChoice> choices = chunk.getChoices();
804+
AssistantMessage messageChunk = null;
805+
if (null != choices && !choices.isEmpty()) {
806+
ChatCompletionChoice firstChoice = choices.get(0);
807+
messageChunk = firstChoice.getMessage();
808+
appendContent(messageChunk, accumulatedMessage);
809+
processFunctionCall(messageChunk, functionCall, accumulatedMessage);
810+
processToolCalls(messageChunk, accumulatedMessage);
811+
if (firstChoice.getFinishReason() != null) {
812+
handleFinishReason(firstChoice.getFinishReason(), functionCall, accumulatedMessage);
813+
}
814+
}
815+
ChatMessageAccumulator chatMessageAccumulator = new ChatMessageAccumulator(messageChunk, accumulatedMessage, chunk.getUsage());
816+
return new ChatMessageAccumulatorWrapper(chatMessageAccumulator, chunk);
817+
});
818+
}
761819

762820
/**
763821
* 处理消息块中的函数调用。

0 commit comments

Comments
 (0)