This repository was archived by the owner on Aug 5, 2025. It is now read-only.

Commit 98b3571

Merge pull request #108 from Chainlit/willy/eng-1754-fix-mistralai-instrumentation-for-100
Willy/eng 1754 fix mistralai instrumentation for 100
2 parents 5011c6d + 4a0b063 commit 98b3571

File tree

4 files changed: +66 −62 lines

  literalai/helper.py
  literalai/instrumentation/mistralai.py
  requirements-dev.txt
  tests/e2e/test_mistralai.py

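In short, this merge moves the Mistral AI instrumentation from the pre-1.0 MistralClient / MistralAsyncClient classes to the unified Mistral client of mistralai 1.0. A minimal sketch of the client surface the new wrappers target, assuming mistralai >= 1.0 is installed (the API key is a placeholder):

# Sketch (assumes mistralai >= 1.0; placeholder API key).
from mistralai import Mistral

mai_client = Mistral(api_key="YOUR_API_KEY")

# Old 0.x surface              ->  new 1.0 surface wrapped by this PR
# MistralClient.chat           ->  mai_client.chat.complete(...)
# MistralClient.chat_stream    ->  mai_client.chat.stream(...)
# MistralClient.completion     ->  mai_client.fim.complete(...)
# MistralAsyncClient.chat      ->  await mai_client.chat.complete_async(...)
response = mai_client.chat.complete(
    model="open-mistral-7b",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)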

literalai/helper.py

Lines changed: 2 additions & 2 deletions
@@ -18,9 +18,9 @@ def ensure_values_serializable(data):
         pass
 
     try:
-        from mistralai.models.chat_completion import ChatMessage
+        from mistralai import UserMessage
 
-        if isinstance(data, ChatMessage):
+        if isinstance(data, UserMessage):
             return filter_none_values(data.model_dump())
     except ImportError:
         pass
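Note: in mistralai >= 1.0 the chat message classes are Pydantic models exported from the package root, which is why the helper now imports UserMessage and serializes it with model_dump(). A minimal sketch (the printed dict is indicative, not exact):

# Sketch (assumes mistralai >= 1.0): UserMessage is a Pydantic model, so
# model_dump() returns a plain dict that filter_none_values() can clean up.
from mistralai import UserMessage

msg = UserMessage(content="Hello!")
print(msg.model_dump())  # roughly {"content": "Hello!", "role": "user"}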

literalai/instrumentation/mistralai.py

Lines changed: 53 additions & 48 deletions
@@ -1,89 +1,92 @@
 import time
-from typing import TYPE_CHECKING, AsyncGenerator, Dict, Union
+from typing import TYPE_CHECKING, AsyncGenerator, Dict, Generator, Union
 
 from literalai.instrumentation import MISTRALAI_PROVIDER
 from literalai.requirements import check_all_requirements
 
 if TYPE_CHECKING:
     from literalai.client import LiteralClient
 
-from types import GeneratorType
-
 from literalai.context import active_steps_var, active_thread_var
 from literalai.helper import ensure_values_serializable
-from literalai.observability.generation import GenerationMessage, CompletionGeneration, ChatGeneration, GenerationType
+from literalai.observability.generation import (
+    ChatGeneration,
+    CompletionGeneration,
+    GenerationMessage,
+    GenerationType,
+)
 from literalai.wrappers import AfterContext, BeforeContext, wrap_all
 
-REQUIREMENTS = ["mistralai>=0.2.0"]
+REQUIREMENTS = ["mistralai>=1.0.0"]
 
 APIS_TO_WRAP = [
     {
-        "module": "mistralai.client",
-        "object": "MistralClient",
-        "method": "chat",
+        "module": "mistralai.chat",
+        "object": "Chat",
+        "method": "complete",
         "metadata": {
             "type": GenerationType.CHAT,
         },
         "async": False,
     },
     {
-        "module": "mistralai.client",
-        "object": "MistralClient",
-        "method": "chat_stream",
+        "module": "mistralai.chat",
+        "object": "Chat",
+        "method": "stream",
         "metadata": {
             "type": GenerationType.CHAT,
         },
         "async": False,
     },
     {
-        "module": "mistralai.async_client",
-        "object": "MistralAsyncClient",
-        "method": "chat",
+        "module": "mistralai.chat",
+        "object": "Chat",
+        "method": "complete_async",
         "metadata": {
             "type": GenerationType.CHAT,
         },
         "async": True,
     },
     {
-        "module": "mistralai.async_client",
-        "object": "MistralAsyncClient",
-        "method": "chat_stream",
+        "module": "mistralai.chat",
+        "object": "Chat",
+        "method": "stream_async",
         "metadata": {
             "type": GenerationType.CHAT,
         },
         "async": True,
     },
     {
-        "module": "mistralai.client",
-        "object": "MistralClient",
-        "method": "completion",
+        "module": "mistralai.fim",
+        "object": "Fim",
+        "method": "complete",
         "metadata": {
             "type": GenerationType.COMPLETION,
         },
         "async": False,
     },
     {
-        "module": "mistralai.client",
-        "object": "MistralClient",
-        "method": "completion_stream",
+        "module": "mistralai.fim",
+        "object": "Fim",
+        "method": "stream",
        "metadata": {
             "type": GenerationType.COMPLETION,
         },
         "async": False,
     },
     {
-        "module": "mistralai.async_client",
-        "object": "MistralAsyncClient",
-        "method": "completion",
+        "module": "mistralai.fim",
+        "object": "Fim",
+        "method": "complete_async",
         "metadata": {
             "type": GenerationType.COMPLETION,
         },
         "async": True,
     },
     {
-        "module": "mistralai.async_client",
-        "object": "MistralAsyncClient",
-        "method": "completion_stream",
+        "module": "mistralai.fim",
+        "object": "Fim",
+        "method": "stream_async",
         "metadata": {
             "type": GenerationType.COMPLETION,
         },
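Each APIS_TO_WRAP entry above names the module, object, and method to patch; with mistralai >= 1.0 these resolve to the Chat and Fim namespaces. A quick sanity-check sketch (not part of the diff), assuming mistralai >= 1.0:

# Sketch (assumes mistralai >= 1.0): the new wrap targets exist as attributes
# of mistralai.chat.Chat and mistralai.fim.Fim.
from mistralai.chat import Chat
from mistralai.fim import Fim

for target, methods in [
    (Chat, ["complete", "stream", "complete_async", "stream_async"]),
    (Fim, ["complete", "stream", "complete_async", "stream_async"]),
]:
    missing = [m for m in methods if not hasattr(target, m)]
    assert not missing, f"{target.__name__} is missing {missing}"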
@@ -239,13 +242,13 @@ async def before(context: BeforeContext, *args, **kwargs):
 
     return before
 
-from mistralai.models.chat_completion import DeltaMessage
+from mistralai import DeltaMessage
 
 def process_delta(new_delta: DeltaMessage, message_completion: GenerationMessage):
     if new_delta.tool_calls:
         if "tool_calls" not in message_completion:
             message_completion["tool_calls"] = []
-        delta_tool_call = new_delta.tool_calls[0]
+        delta_tool_call = new_delta.tool_calls[0]  # type: ignore
         delta_function = delta_tool_call.function
         if not delta_function:
             return False
@@ -273,9 +276,11 @@ def process_delta(new_delta: DeltaMessage, message_completion: GenerationMessage):
     else:
         return False
 
+from mistralai import models
+
 def streaming_response(
     generation: Union[ChatGeneration, CompletionGeneration],
-    result: GeneratorType,
+    result: Generator[models.CompletionEvent, None, None],
     context: AfterContext,
 ):
     completion = ""
@@ -286,8 +291,8 @@ def streaming_response(
     token_count = 0
     for chunk in result:
         if generation and isinstance(generation, ChatGeneration):
-            if len(chunk.choices) > 0:
-                ok = process_delta(chunk.choices[0].delta, message_completion)
+            if len(chunk.data.choices) > 0:
+                ok = process_delta(chunk.data.choices[0].delta, message_completion)
                 if not ok:
                     yield chunk
                     continue
@@ -298,22 +303,22 @@ def streaming_response(
             token_count += 1
         elif generation and isinstance(generation, CompletionGeneration):
             if (
-                len(chunk.choices) > 0
-                and chunk.choices[0].message.content is not None
+                len(chunk.data.choices) > 0
+                and chunk.data.choices[0].delta.content is not None
             ):
                 if generation.tt_first_token is None:
                     generation.tt_first_token = (
                         time.time() - context["start"]
                     ) * 1000
                 token_count += 1
-                completion += chunk.choices[0].message.content
+                completion += chunk.data.choices[0].delta.content
 
         if (
             generation
             and getattr(chunk, "model", None)
-            and generation.model != chunk.model
+            and generation.model != chunk.data.model
         ):
-            generation.model = chunk.model
+            generation.model = chunk.data.model
 
         yield chunk
 
@@ -358,7 +363,7 @@ def after(result, context: AfterContext, *args, **kwargs):
         generation.model = model
         if generation.settings:
             generation.settings["model"] = model
-        if isinstance(result, GeneratorType):
+        if isinstance(result, Generator):
             return streaming_response(generation, result, context)
         else:
             generation.duration = time.time() - context["start"]
@@ -387,7 +392,7 @@ def after(result, context: AfterContext, *args, **kwargs):
 
 async def async_streaming_response(
     generation: Union[ChatGeneration, CompletionGeneration],
-    result: AsyncGenerator,
+    result: AsyncGenerator[models.CompletionEvent, None],
     context: AfterContext,
 ):
     completion = ""
@@ -398,8 +403,8 @@ async def async_streaming_response(
     token_count = 0
     async for chunk in result:
         if generation and isinstance(generation, ChatGeneration):
-            if len(chunk.choices) > 0:
-                ok = process_delta(chunk.choices[0].delta, message_completion)
+            if len(chunk.data.choices) > 0:
+                ok = process_delta(chunk.data.choices[0].delta, message_completion)
                 if not ok:
                     yield chunk
                     continue
@@ -410,22 +415,22 @@ async def async_streaming_response(
             token_count += 1
         elif generation and isinstance(generation, CompletionGeneration):
             if (
-                len(chunk.choices) > 0
-                and chunk.choices[0].message.content is not None
+                len(chunk.data.choices) > 0
+                and chunk.data.choices[0].delta is not None
             ):
                 if generation.tt_first_token is None:
                     generation.tt_first_token = (
                         time.time() - context["start"]
                     ) * 1000
                 token_count += 1
-                completion += chunk.choices[0].message.content
+                completion += chunk.data.choices[0].delta.content or ""
 
         if (
             generation
             and getattr(chunk, "model", None)
-            and generation.model != chunk.model
+            and generation.model != chunk.data.model
         ):
-            generation.model = chunk.model
+            generation.model = chunk.data.model
 
         yield chunk
 
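The streaming changes above all follow from the new event shape: in mistralai >= 1.0 the stream yields CompletionEvent wrappers, so choices and model live under .data rather than on the chunk itself. A minimal consumer sketch (placeholder API key) mirroring what streaming_response now reads:

# Sketch (assumes mistralai >= 1.0): stream() yields CompletionEvent objects,
# so the payload is under event.data, matching chunk.data.choices above.
from mistralai import Mistral

mai_client = Mistral(api_key="YOUR_API_KEY")  # placeholder key
completion = ""
for event in mai_client.chat.stream(
    model="open-mistral-7b",
    messages=[{"role": "user", "content": "Say hi"}],
):
    if event.data.choices:
        delta = event.data.choices[0].delta
        if isinstance(delta.content, str):  # content can be None on tool-call deltas
            completion += delta.content
print(completion)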

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
@@ -8,4 +8,4 @@ mypy
 langchain
 llama-index
 pytest_httpx
-mistralai < 1.0.0
+mistralai

tests/e2e/test_mistralai.py

Lines changed: 10 additions & 11 deletions
@@ -3,12 +3,11 @@
 from asyncio import sleep
 
 import pytest
-from mistralai.async_client import MistralAsyncClient
-from mistralai.client import MistralClient
+from mistralai import Mistral
 from pytest_httpx import HTTPXMock
 
 from literalai.client import LiteralClient
-from literalai.observability.generation import CompletionGeneration, ChatGeneration
+from literalai.observability.generation import ChatGeneration, CompletionGeneration
 
 
 @pytest.fixture
@@ -63,13 +62,13 @@ async def test_chat(self, client: "LiteralClient", httpx_mock: "HTTPXMock"):
                 },
             }
         )
-        mai_client = MistralClient(api_key="j3s4V1z4")
+        mai_client = Mistral(api_key="j3s4V1z4")
         thread_id = None
 
         @client.thread
         def main():
             # https://docs.mistral.ai/api/#operation/createChatCompletion
-            mai_client.chat(
+            mai_client.chat.complete(
                 model="open-mistral-7b",
                 messages=[
                     {
@@ -124,13 +123,13 @@ async def test_completion(self, client: "LiteralClient", httpx_mock: "HTTPXMock"):
             },
         )
 
-        mai_client = MistralClient(api_key="j3s4V1z4")
+        mai_client = Mistral(api_key="j3s4V1z4")
         thread_id = None
 
         @client.thread
         def main():
             # https://docs.mistral.ai/api/#operation/createFIMCompletion
-            mai_client.completion(
+            mai_client.fim.complete(
                 model="codestral-2405",
                 prompt="1+1=",
                 temperature=0,
@@ -183,13 +182,13 @@ async def test_async_chat(self, client: "LiteralClient", httpx_mock: "HTTPXMock"):
             },
         )
 
-        mai_client = MistralAsyncClient(api_key="j3s4V1z4")
+        mai_client = Mistral(api_key="j3s4V1z4")
         thread_id = None
 
         @client.thread
         async def main():
             # https://docs.mistral.ai/api/#operation/createChatCompletion
-            await mai_client.chat(
+            await mai_client.chat.complete_async(
                 model="open-mistral-7b",
                 messages=[
                     {
@@ -246,13 +245,13 @@ async def test_async_completion(
             },
         )
 
-        mai_client = MistralAsyncClient(api_key="j3s4V1z4")
+        mai_client = Mistral(api_key="j3s4V1z4")
         thread_id = None
 
         @client.thread
         async def main():
             # https://docs.mistral.ai/api/#operation/createFIMCompletion
-            await mai_client.completion(
+            await mai_client.fim.complete_async(
                 model="codestral-2405",
                 prompt="1+1=",
                 temperature=0,
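The async tests now use the same unified client, so MistralAsyncClient disappears entirely. A standalone sketch of that usage outside the test harness (placeholder key), assuming mistralai >= 1.0:

# Sketch (assumes mistralai >= 1.0): one Mistral client serves sync and async
# code; the async paths go through the *_async methods wrapped by this PR.
import asyncio

from mistralai import Mistral


async def main() -> None:
    mai_client = Mistral(api_key="YOUR_API_KEY")  # placeholder key
    chat = await mai_client.chat.complete_async(
        model="open-mistral-7b",
        messages=[{"role": "user", "content": "Hello!"}],
    )
    fim = await mai_client.fim.complete_async(
        model="codestral-2405", prompt="1+1=", temperature=0
    )
    print(chat.choices[0].message.content)
    print(fim.choices[0].message.content)  # FIM response shape assumed analogous


asyncio.run(main())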
