1
1
import time
2
- from typing import TYPE_CHECKING , AsyncGenerator , Dict , Union
2
+ from typing import TYPE_CHECKING , AsyncGenerator , Dict , Generator , Union
3
3
4
4
from literalai .instrumentation import MISTRALAI_PROVIDER
5
5
from literalai .requirements import check_all_requirements
6
6
7
7
if TYPE_CHECKING :
8
8
from literalai .client import LiteralClient
9
9
10
- from types import GeneratorType
11
-
12
10
from literalai .context import active_steps_var , active_thread_var
13
11
from literalai .helper import ensure_values_serializable
14
- from literalai .observability .generation import GenerationMessage , CompletionGeneration , ChatGeneration , GenerationType
12
+ from literalai .observability .generation import (
13
+ ChatGeneration ,
14
+ CompletionGeneration ,
15
+ GenerationMessage ,
16
+ GenerationType ,
17
+ )
15
18
from literalai .wrappers import AfterContext , BeforeContext , wrap_all
16
19
17
- REQUIREMENTS = ["mistralai>=0.2 .0" ]
20
+ REQUIREMENTS = ["mistralai>=1.0 .0" ]
18
21
19
22
APIS_TO_WRAP = [
20
23
{
21
- "module" : "mistralai.client " ,
22
- "object" : "MistralClient " ,
23
- "method" : "chat " ,
24
+ "module" : "mistralai.chat " ,
25
+ "object" : "Chat " ,
26
+ "method" : "complete " ,
24
27
"metadata" : {
25
28
"type" : GenerationType .CHAT ,
26
29
},
27
30
"async" : False ,
28
31
},
29
32
{
30
- "module" : "mistralai.client " ,
31
- "object" : "MistralClient " ,
32
- "method" : "chat_stream " ,
33
+ "module" : "mistralai.chat " ,
34
+ "object" : "Chat " ,
35
+ "method" : "stream " ,
33
36
"metadata" : {
34
37
"type" : GenerationType .CHAT ,
35
38
},
36
39
"async" : False ,
37
40
},
38
41
{
39
- "module" : "mistralai.async_client " ,
40
- "object" : "MistralAsyncClient " ,
41
- "method" : "chat " ,
42
+ "module" : "mistralai.chat " ,
43
+ "object" : "Chat " ,
44
+ "method" : "complete_async " ,
42
45
"metadata" : {
43
46
"type" : GenerationType .CHAT ,
44
47
},
45
48
"async" : True ,
46
49
},
47
50
{
48
- "module" : "mistralai.async_client " ,
49
- "object" : "MistralAsyncClient " ,
50
- "method" : "chat_stream " ,
51
+ "module" : "mistralai.chat " ,
52
+ "object" : "Chat " ,
53
+ "method" : "stream_async " ,
51
54
"metadata" : {
52
55
"type" : GenerationType .CHAT ,
53
56
},
54
57
"async" : True ,
55
58
},
56
59
{
57
- "module" : "mistralai.client " ,
58
- "object" : "MistralClient " ,
59
- "method" : "completion " ,
60
+ "module" : "mistralai.fim " ,
61
+ "object" : "Fim " ,
62
+ "method" : "complete " ,
60
63
"metadata" : {
61
64
"type" : GenerationType .COMPLETION ,
62
65
},
63
66
"async" : False ,
64
67
},
65
68
{
66
- "module" : "mistralai.client " ,
67
- "object" : "MistralClient " ,
68
- "method" : "completion_stream " ,
69
+ "module" : "mistralai.fim " ,
70
+ "object" : "Fim " ,
71
+ "method" : "stream " ,
69
72
"metadata" : {
70
73
"type" : GenerationType .COMPLETION ,
71
74
},
72
75
"async" : False ,
73
76
},
74
77
{
75
- "module" : "mistralai.async_client " ,
76
- "object" : "MistralAsyncClient " ,
77
- "method" : "completion " ,
78
+ "module" : "mistralai.fim " ,
79
+ "object" : "Fim " ,
80
+ "method" : "complete_async " ,
78
81
"metadata" : {
79
82
"type" : GenerationType .COMPLETION ,
80
83
},
81
84
"async" : True ,
82
85
},
83
86
{
84
- "module" : "mistralai.async_client " ,
85
- "object" : "MistralAsyncClient " ,
86
- "method" : "completion_stream " ,
87
+ "module" : "mistralai.fim " ,
88
+ "object" : "Fim " ,
89
+ "method" : "stream_async " ,
87
90
"metadata" : {
88
91
"type" : GenerationType .COMPLETION ,
89
92
},
@@ -239,13 +242,13 @@ async def before(context: BeforeContext, *args, **kwargs):
239
242
240
243
return before
241
244
242
- from mistralai . models . chat_completion import DeltaMessage
245
+ from mistralai import DeltaMessage
243
246
244
247
def process_delta (new_delta : DeltaMessage , message_completion : GenerationMessage ):
245
248
if new_delta .tool_calls :
246
249
if "tool_calls" not in message_completion :
247
250
message_completion ["tool_calls" ] = []
248
- delta_tool_call = new_delta .tool_calls [0 ]
251
+ delta_tool_call = new_delta .tool_calls [0 ] # type: ignore
249
252
delta_function = delta_tool_call .function
250
253
if not delta_function :
251
254
return False
@@ -273,9 +276,11 @@ def process_delta(new_delta: DeltaMessage, message_completion: GenerationMessage
273
276
else :
274
277
return False
275
278
279
+ from mistralai import models
280
+
276
281
def streaming_response (
277
282
generation : Union [ChatGeneration , CompletionGeneration ],
278
- result : GeneratorType ,
283
+ result : Generator [ models . CompletionEvent , None , None ] ,
279
284
context : AfterContext ,
280
285
):
281
286
completion = ""
@@ -286,8 +291,8 @@ def streaming_response(
286
291
token_count = 0
287
292
for chunk in result :
288
293
if generation and isinstance (generation , ChatGeneration ):
289
- if len (chunk .choices ) > 0 :
290
- ok = process_delta (chunk .choices [0 ].delta , message_completion )
294
+ if len (chunk .data . choices ) > 0 :
295
+ ok = process_delta (chunk .data . choices [0 ].delta , message_completion )
291
296
if not ok :
292
297
yield chunk
293
298
continue
@@ -298,22 +303,22 @@ def streaming_response(
298
303
token_count += 1
299
304
elif generation and isinstance (generation , CompletionGeneration ):
300
305
if (
301
- len (chunk .choices ) > 0
302
- and chunk .choices [0 ].message .content is not None
306
+ len (chunk .data . choices ) > 0
307
+ and chunk .data . choices [0 ].delta .content is not None
303
308
):
304
309
if generation .tt_first_token is None :
305
310
generation .tt_first_token = (
306
311
time .time () - context ["start" ]
307
312
) * 1000
308
313
token_count += 1
309
- completion += chunk .choices [0 ].message .content
314
+ completion += chunk .data . choices [0 ].delta .content
310
315
311
316
if (
312
317
generation
313
318
and getattr (chunk , "model" , None )
314
- and generation .model != chunk .model
319
+ and generation .model != chunk .data . model
315
320
):
316
- generation .model = chunk .model
321
+ generation .model = chunk .data . model
317
322
318
323
yield chunk
319
324
@@ -358,7 +363,7 @@ def after(result, context: AfterContext, *args, **kwargs):
358
363
generation .model = model
359
364
if generation .settings :
360
365
generation .settings ["model" ] = model
361
- if isinstance (result , GeneratorType ):
366
+ if isinstance (result , Generator ):
362
367
return streaming_response (generation , result , context )
363
368
else :
364
369
generation .duration = time .time () - context ["start" ]
@@ -387,7 +392,7 @@ def after(result, context: AfterContext, *args, **kwargs):
387
392
388
393
async def async_streaming_response (
389
394
generation : Union [ChatGeneration , CompletionGeneration ],
390
- result : AsyncGenerator ,
395
+ result : AsyncGenerator [ models . CompletionEvent , None ] ,
391
396
context : AfterContext ,
392
397
):
393
398
completion = ""
@@ -398,8 +403,8 @@ async def async_streaming_response(
398
403
token_count = 0
399
404
async for chunk in result :
400
405
if generation and isinstance (generation , ChatGeneration ):
401
- if len (chunk .choices ) > 0 :
402
- ok = process_delta (chunk .choices [0 ].delta , message_completion )
406
+ if len (chunk .data . choices ) > 0 :
407
+ ok = process_delta (chunk .data . choices [0 ].delta , message_completion )
403
408
if not ok :
404
409
yield chunk
405
410
continue
@@ -410,22 +415,22 @@ async def async_streaming_response(
410
415
token_count += 1
411
416
elif generation and isinstance (generation , CompletionGeneration ):
412
417
if (
413
- len (chunk .choices ) > 0
414
- and chunk .choices [0 ].message . content is not None
418
+ len (chunk .data . choices ) > 0
419
+ and chunk .data . choices [0 ].delta is not None
415
420
):
416
421
if generation .tt_first_token is None :
417
422
generation .tt_first_token = (
418
423
time .time () - context ["start" ]
419
424
) * 1000
420
425
token_count += 1
421
- completion += chunk .choices [0 ].message .content
426
+ completion += chunk .data . choices [0 ].delta .content or ""
422
427
423
428
if (
424
429
generation
425
430
and getattr (chunk , "model" , None )
426
- and generation .model != chunk .model
431
+ and generation .model != chunk .data . model
427
432
):
428
- generation .model = chunk .model
433
+ generation .model = chunk .data . model
429
434
430
435
yield chunk
431
436
0 commit comments