-from typing import Optional, AsyncGenerator
-
+import anyio
 import httpx
 import orjson
 
+async def chat_with_endpoint():
+    async with httpx.AsyncClient() as client:
+        while True:
+            # Get user input
+            prompt = input("\nYou: ")
+            if prompt.lower() == "exit":
+                break
 
-class StreamLLMService:
-    def __init__(self, base_url: str = "http://localhost:11434/v1"):
-        self.base_url = base_url
-        self.model = "llama3.2"
-
-    async def stream_chat(self, prompt: str) -> AsyncGenerator[bytes, None]:
-        """Stream chat completion responses from LLM."""
-        # Send user message first
-        user_msg = {
-            "role": "user",
-            "content": prompt,
-        }
-        yield orjson.dumps(user_msg) + b"\n"
-
-        # Open client as context manager and stream responses
-        async with httpx.AsyncClient(base_url=self.base_url) as client:
+            # Send request to the API
+            print("\nModel: ", end="", flush=True)
             async with client.stream(
                 "POST",
-                "/chat/completions",
-                json={
-                    "model": self.model,
-                    "messages": [{"role": "user", "content": prompt}],
-                    "stream": True,
-                },
-                timeout=60.0,
+                "http://localhost:8000/chat/",
+                data={"prompt": prompt},
+                timeout=60
             ) as response:
-                async for line in response.aiter_lines():
-                    print(line)
-                    if line.startswith("data: ") and line != "data: [DONE]":
+                async for chunk in response.aiter_lines():
+                    if chunk:
                         try:
-                            json_line = line[6:]  # Remove "data: " prefix
-                            data = orjson.loads(json_line)
-                            content = (
-                                data.get("choices", [{}])[0]
-                                .get("delta", {})
-                                .get("content", "")
-                            )
-                            if content:
-                                model_msg = {"role": "model", "content": content}
-                                yield orjson.dumps(model_msg) + b"\n"
-                        except Exception:
-                            pass
-
+                            data = orjson.loads(chunk)
+                            print(data["content"], end="", flush=True)
+                        except Exception as e:
+                            print(f"\nError parsing chunk: {e}")
 
-# FastAPI dependency
-def get_llm_service(base_url: Optional[str] = None) -> StreamLLMService:
-    return StreamLLMService(base_url=base_url or "http://localhost:11434/v1")
+if __name__ == "__main__":
+    anyio.run(chat_with_endpoint)
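The new client assumes a server is already running at http://localhost:8000/chat/ that accepts a form-encoded "prompt" field and streams back newline-delimited JSON objects, each carrying a "content" key. That server is not part of this change; the sketch below is only an assumed, minimal stand-in (the fake_stream generator and the /chat/ route are illustrative, shaped the way the removed StreamLLMService output suggests, not the project's actual endpoint).

# Hypothetical server-side sketch (not part of this commit): a FastAPI app
# exposing POST /chat/ that reads the form field "prompt" and streams
# newline-delimited JSON objects with a "content" key, which is the shape
# the client above parses. Requires python-multipart for Form() parsing.
from typing import Annotated, AsyncGenerator

import orjson
from fastapi import FastAPI, Form
from fastapi.responses import StreamingResponse

app = FastAPI()


async def fake_stream(prompt: str) -> AsyncGenerator[bytes, None]:
    # Stand-in for a real LLM stream (e.g. the removed StreamLLMService):
    # echo the prompt back one word per NDJSON line.
    for word in prompt.split():
        yield orjson.dumps({"role": "model", "content": word + " "}) + b"\n"


@app.post("/chat/")
async def chat(prompt: Annotated[str, Form()]) -> StreamingResponse:
    # Stream the NDJSON lines to the client as they are produced.
    return StreamingResponse(fake_stream(prompt), media_type="application/x-ndjson")

With a server of that shape running, the client prints each streamed "content" fragment as it arrives and exits when the user types "exit". anyio.run() starts the coroutine on the default asyncio backend, so no extra event-loop setup is needed.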