@@ -1,14 +1,12 @@
 import time
-from typing import TYPE_CHECKING, AsyncGenerator, Dict, Union
+from typing import TYPE_CHECKING, AsyncGenerator, Dict, Generator, Union
 
 from literalai.instrumentation import MISTRALAI_PROVIDER
 from literalai.requirements import check_all_requirements
 
 if TYPE_CHECKING:
     from literalai.client import LiteralClient
 
-from types import GeneratorType
-
 from literalai.context import active_steps_var, active_thread_var
 from literalai.helper import ensure_values_serializable
 from literalai.observability.generation import (
@@ -23,72 +21,72 @@
 
 APIS_TO_WRAP = [
     {
-        "module": "mistralai",
-        "object": "Mistral",
-        "method": "chat.complete",
+        "module": "mistralai.chat",
+        "object": "Chat",
+        "method": "complete",
         "metadata": {
             "type": GenerationType.CHAT,
         },
         "async": False,
     },
     {
-        "module": "mistralai",
-        "object": "Mistral",
-        "method": "chat.stream",
+        "module": "mistralai.chat",
+        "object": "Chat",
+        "method": "stream",
         "metadata": {
             "type": GenerationType.CHAT,
         },
         "async": False,
     },
     {
-        "module": "mistralai",
-        "object": "Mistral",
-        "method": "chat.complete_async",
+        "module": "mistralai.chat",
+        "object": "Chat",
+        "method": "complete_async",
         "metadata": {
             "type": GenerationType.CHAT,
         },
         "async": True,
     },
     {
-        "module": "mistralai",
-        "object": "Mistral",
-        "method": "chat.stream_async",
+        "module": "mistralai.chat",
+        "object": "Chat",
+        "method": "stream_async",
         "metadata": {
             "type": GenerationType.CHAT,
         },
         "async": True,
     },
     {
-        "module": "mistralai",
-        "object": "Mistral",
-        "method": "fim.complete",
+        "module": "mistralai.fim",
+        "object": "Fim",
+        "method": "complete",
         "metadata": {
             "type": GenerationType.COMPLETION,
         },
         "async": False,
     },
     {
-        "module": "mistralai",
-        "object": "Mistral",
-        "method": "fim.stream",
+        "module": "mistralai.fim",
+        "object": "Fim",
+        "method": "stream",
         "metadata": {
             "type": GenerationType.COMPLETION,
         },
         "async": False,
     },
     {
-        "module": "mistralai",
-        "object": "Mistral",
-        "method": "fim.complete_async",
+        "module": "mistralai.fim",
+        "object": "Fim",
+        "method": "complete_async",
        "metadata": {
             "type": GenerationType.COMPLETION,
         },
         "async": True,
     },
     {
-        "module": "mistralai",
-        "object": "Mistral",
-        "method": "fim.stream_async",
+        "module": "mistralai.fim",
+        "object": "Fim",
+        "method": "stream_async",
         "metadata": {
             "type": GenerationType.COMPLETION,
         },
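
Note: the wrapping targets move off the Mistral client facade and onto the Chat and Fim classes (from mistralai.chat and mistralai.fim) that back client.chat and client.fim in the v1 SDK, so the method names lose their "chat."/"fim." prefixes. A minimal sketch of how a spec entry like these can be applied; the wrap_method helper below is hypothetical, the real patching lives in literalai.instrumentation:

import importlib
from functools import wraps

def wrap_method(spec, before, after):
    # Hypothetical helper: resolve e.g. module "mistralai.chat", object "Chat",
    # method "complete", then patch the method in place on the class.
    module = importlib.import_module(spec["module"])
    cls = getattr(module, spec["object"])
    original = getattr(cls, spec["method"])

    @wraps(original)
    def wrapped(self, *args, **kwargs):
        # Sync case only, for illustration; the "async": True entries
        # would need an async wrapper of the same shape.
        context = before(spec["metadata"], *args, **kwargs)
        result = original(self, *args, **kwargs)
        return after(result, context)

    setattr(cls, spec["method"], wrapped)
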
@@ -278,9 +276,11 @@ def process_delta(new_delta: DeltaMessage, message_completion: GenerationMessage
     else:
         return False
 
+from mistralai import models
+
 def streaming_response(
     generation: Union[ChatGeneration, CompletionGeneration],
-    result: GeneratorType,
+    result: Generator[models.CompletionEvent, None, None],
     context: AfterContext,
 ):
     completion = ""
@@ -291,8 +291,8 @@ def streaming_response(
     token_count = 0
    for chunk in result:
         if generation and isinstance(generation, ChatGeneration):
-            if len(chunk.choices) > 0:
-                ok = process_delta(chunk.choices[0].delta, message_completion)
+            if len(chunk.data.choices) > 0:
+                ok = process_delta(chunk.data.choices[0].delta, message_completion)
             if not ok:
                 yield chunk
                 continue
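
In the v1 SDK each streamed item is a models.CompletionEvent that wraps the actual chunk in its data attribute, which is why the delta now lives under chunk.data.choices. A rough sketch of consuming the stream directly under that assumption (model name illustrative):

from mistralai import Mistral

client = Mistral(api_key="...")
stream = client.chat.stream(
    model="mistral-small-latest",
    messages=[{"role": "user", "content": "Hello"}],
)
for event in stream:        # event: models.CompletionEvent
    chunk = event.data      # the chunk payload (choices, usage, ...) sits on .data
    if len(chunk.choices) > 0:
        print(chunk.choices[0].delta.content or "", end="")
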
@@ -363,7 +363,7 @@ def after(result, context: AfterContext, *args, **kwargs):
         generation.model = model
         if generation.settings:
             generation.settings["model"] = model
-    if isinstance(result, GeneratorType):
+    if isinstance(result, Generator):
         return streaming_response(generation, result, context)
     else:
         generation.duration = time.time() - context["start"]
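
types.GeneratorType only matches concrete generator objects, while the unsubscripted typing.Generator aliases collections.abc.Generator and is valid in isinstance checks, matching anything that implements the generator protocol. A quick illustration:

from typing import Generator

def numbers():
    yield 1

# typing.Generator aliases collections.abc.Generator, so this holds.
assert isinstance(numbers(), Generator)
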
@@ -392,7 +392,7 @@ def after(result, context: AfterContext, *args, **kwargs):
 
 async def async_streaming_response(
     generation: Union[ChatGeneration, CompletionGeneration],
-    result: AsyncGenerator,
+    result: AsyncGenerator[models.CompletionEvent, None],
     context: AfterContext,
 ):
     completion = ""
@@ -403,8 +403,8 @@ async def async_streaming_response(
     token_count = 0
     async for chunk in result:
         if generation and isinstance(generation, ChatGeneration):
-            if len(chunk.choices) > 0:
-                ok = process_delta(chunk.choices[0].delta, message_completion)
+            if len(chunk.data.choices) > 0:
+                ok = process_delta(chunk.data.choices[0].delta, message_completion)
             if not ok:
                 yield chunk
                 continue
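
The async path mirrors the sync one; per the annotation above, stream_async yields the same models.CompletionEvent shape. A rough usage sketch, assuming the awaited result can be iterated with async for as the type annotation in this diff suggests:

import asyncio
from mistralai import Mistral

async def main():
    client = Mistral(api_key="...")
    stream = await client.chat.stream_async(
        model="mistral-small-latest",
        messages=[{"role": "user", "content": "Hello"}],
    )
    async for event in stream:  # event: models.CompletionEvent
        if len(event.data.choices) > 0:
            print(event.data.choices[0].delta.content or "", end="")

asyncio.run(main())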