This repository was archived by the owner on Aug 5, 2025. It is now read-only.

Commit 89ddcfd

fix: mistral instrumentation
1 parent 9b2159e

File tree

1 file changed (+34 −34)

literalai/instrumentation/mistralai.py

Lines changed: 34 additions & 34 deletions
@@ -1,14 +1,12 @@
 import time
-from typing import TYPE_CHECKING, AsyncGenerator, Dict, Union
+from typing import TYPE_CHECKING, AsyncGenerator, Dict, Generator, Union
 
 from literalai.instrumentation import MISTRALAI_PROVIDER
 from literalai.requirements import check_all_requirements
 
 if TYPE_CHECKING:
     from literalai.client import LiteralClient
 
-from types import GeneratorType
-
 from literalai.context import active_steps_var, active_thread_var
 from literalai.helper import ensure_values_serializable
 from literalai.observability.generation import (
@@ -23,72 +21,72 @@
 
 APIS_TO_WRAP = [
     {
-        "module": "mistralai",
-        "object": "Mistral",
-        "method": "chat.complete",
+        "module": "mistralai.chat",
+        "object": "Chat",
+        "method": "complete",
         "metadata": {
             "type": GenerationType.CHAT,
         },
         "async": False,
     },
     {
-        "module": "mistralai",
-        "object": "Mistral",
-        "method": "chat.stream",
+        "module": "mistralai.chat",
+        "object": "Chat",
+        "method": "stream",
         "metadata": {
             "type": GenerationType.CHAT,
         },
         "async": False,
     },
     {
-        "module": "mistralai",
-        "object": "Mistral",
-        "method": "chat.complete_async",
+        "module": "mistralai.chat",
+        "object": "Chat",
+        "method": "complete_async",
         "metadata": {
             "type": GenerationType.CHAT,
        },
         "async": True,
     },
     {
-        "module": "mistralai",
-        "object": "Mistral",
-        "method": "chat.stream_async",
+        "module": "mistralai.chat",
+        "object": "Chat",
+        "method": "stream_async",
         "metadata": {
             "type": GenerationType.CHAT,
         },
         "async": True,
     },
     {
-        "module": "mistralai",
-        "object": "Mistral",
-        "method": "fim.complete",
+        "module": "mistralai.fim",
+        "object": "Fim",
+        "method": "complete",
         "metadata": {
             "type": GenerationType.COMPLETION,
         },
         "async": False,
     },
     {
-        "module": "mistralai",
-        "object": "Mistral",
-        "method": "fim.stream",
+        "module": "mistralai.fim",
+        "object": "Fim",
+        "method": "stream",
         "metadata": {
             "type": GenerationType.COMPLETION,
         },
         "async": False,
     },
     {
-        "module": "mistralai",
-        "object": "Mistral",
-        "method": "fim.complete_async",
+        "module": "mistralai.fim",
+        "object": "Fim",
+        "method": "complete_async",
         "metadata": {
             "type": GenerationType.COMPLETION,
         },
         "async": True,
     },
     {
-        "module": "mistralai",
-        "object": "Mistral",
-        "method": "fim.stream_async",
+        "module": "mistralai.fim",
+        "object": "Fim",
+        "method": "stream_async",
         "metadata": {
             "type": GenerationType.COMPLETION,
         },
@@ -278,9 +276,11 @@ def process_delta(new_delta: DeltaMessage, message_completion: GenerationMessage
     else:
         return False
 
+from mistralai import models
+
 def streaming_response(
     generation: Union[ChatGeneration, CompletionGeneration],
-    result: GeneratorType,
+    result: Generator[models.CompletionEvent, None, None],
     context: AfterContext,
 ):
     completion = ""
@@ -291,8 +291,8 @@ def streaming_response(
     token_count = 0
     for chunk in result:
         if generation and isinstance(generation, ChatGeneration):
-            if len(chunk.choices) > 0:
-                ok = process_delta(chunk.choices[0].delta, message_completion)
+            if len(chunk.data.choices) > 0:
+                ok = process_delta(chunk.data.choices[0].delta, message_completion)
             if not ok:
                 yield chunk
                 continue
@@ -363,7 +363,7 @@ def after(result, context: AfterContext, *args, **kwargs):
         generation.model = model
         if generation.settings:
             generation.settings["model"] = model
-    if isinstance(result, GeneratorType):
+    if isinstance(result, Generator):
         return streaming_response(generation, result, context)
     else:
         generation.duration = time.time() - context["start"]
@@ -392,7 +392,7 @@ def after(result, context: AfterContext, *args, **kwargs):
 
 async def async_streaming_response(
     generation: Union[ChatGeneration, CompletionGeneration],
-    result: AsyncGenerator,
+    result: AsyncGenerator[models.CompletionEvent, None],
    context: AfterContext,
 ):
     completion = ""
@@ -403,8 +403,8 @@ async def async_streaming_response(
     token_count = 0
     async for chunk in result:
         if generation and isinstance(generation, ChatGeneration):
-            if len(chunk.choices) > 0:
-                ok = process_delta(chunk.choices[0].delta, message_completion)
+            if len(chunk.data.choices) > 0:
+                ok = process_delta(chunk.data.choices[0].delta, message_completion)
             if not ok:
                 yield chunk
                 continue
