diff --git a/agents-core/vision_agents/core/utils/video_utils.py b/agents-core/vision_agents/core/utils/video_utils.py
index 5572f265..9fad5d7a 100644
--- a/agents-core/vision_agents/core/utils/video_utils.py
+++ b/agents-core/vision_agents/core/utils/video_utils.py
@@ -1,10 +1,13 @@
 """Video frame utilities."""
 
 import io
+import logging
 
 import av
-from PIL.Image import Resampling
 from PIL import Image
+from PIL.Image import Resampling
+
+logger = logging.getLogger(__name__)
 
 
 def ensure_even_dimensions(frame: av.VideoFrame) -> av.VideoFrame:
@@ -83,3 +86,40 @@ def frame_to_png_bytes(frame: av.VideoFrame) -> bytes:
     buf = io.BytesIO()
     img.save(buf, format="PNG")
     return buf.getvalue()
+
+
+def resize_frame(track, frame: av.VideoFrame) -> av.VideoFrame:
+    """
+    Resizes a video frame to the target dimensions of ``track`` while maintaining the aspect ratio. The resized
+    image is centered on a black background if the target dimensions do not match the original aspect ratio.
+
+    Parameters:
+        track: Object exposing the target ``width`` and ``height`` attributes (e.g. a video track).
+        frame (av.VideoFrame): The input video frame to be resized.
+
+    Returns:
+        av.VideoFrame: The output video frame after resizing, maintaining the original aspect ratio.
+    """
+    img = frame.to_image()
+
+    # Calculate scaling to maintain aspect ratio
+    src_width, src_height = img.size
+    target_width, target_height = track.width, track.height
+
+    # Calculate scale factor (fit within target dimensions)
+    scale = min(target_width / src_width, target_height / src_height)
+    new_width = int(src_width * scale)
+    new_height = int(src_height * scale)
+
+    # Resize with aspect ratio maintained
+    resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+    # Create black background at target resolution
+    result = Image.new("RGB", (target_width, target_height), (0, 0, 0))
+
+    # Paste resized image centered
+    x_offset = (target_width - new_width) // 2
+    y_offset = (target_height - new_height) // 2
+    result.paste(resized, (x_offset, y_offset))
+
+    return av.VideoFrame.from_image(result)
diff --git a/plugins/decart/README.md b/plugins/decart/README.md
new file mode 100644
index 00000000..0292dfed
--- /dev/null
+++ b/plugins/decart/README.md
@@ -0,0 +1,4 @@
+# Decart Plugin
+
+Decart plugin for Vision Agents: provides `RestylingProcessor`, which streams the user's video to Decart's Realtime API and publishes the restyled frames back to the call. See `example/decart_example.py` for a full agent setup.
+
diff --git a/plugins/decart/example/__init__.py b/plugins/decart/example/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/plugins/decart/example/decart_example.py b/plugins/decart/example/decart_example.py
new file mode 100644
index 00000000..4ece59c7
--- /dev/null
+++ b/plugins/decart/example/decart_example.py
@@ -0,0 +1,58 @@
+import logging
+
+from dotenv import load_dotenv
+
+from vision_agents.core import User, Agent, cli
+from vision_agents.core.agents import AgentLauncher
+from vision_agents.plugins import decart, getstream, openai, elevenlabs, deepgram
+
+logger = logging.getLogger(__name__)
+
+load_dotenv()
+
+
+async def create_agent(**kwargs) -> Agent:
+    processor = decart.RestylingProcessor(
+        initial_prompt="A cute animated movie with vibrant colours", model="mirage_v2"
+    )
+    llm = openai.LLM(model="gpt-4o-mini")
+
+    agent = Agent(
+        edge=getstream.Edge(),
+        agent_user=User(name="Story teller", id="agent"),
+        instructions="You are a story teller. You will tell a short story to the user. You will use the Decart processor to change the style of the video and user's background. You can embed audio tags in your responses for added effect. Emotional tone: [EXCITED], [NERVOUS], [FRUSTRATED], [TIRED] Reactions: [GASP], [SIGH], [LAUGHS], [GULPS] Volume & energy: [WHISPERING], [SHOUTING], [QUIETLY], [LOUDLY] Pacing & rhythm: [PAUSES], [STAMMERS], [RUSHED]",
+        llm=llm,
+        tts=elevenlabs.TTS(voice_id="N2lVS1w4EtoT3dr4eOWO"),
+        stt=deepgram.STT(),
+        processors=[processor],
+    )
+
+    @llm.register_function(
+        description="This function changes the prompt of the Decart processor which in turn changes the style of the video and user's background"
+    )
+    async def change_prompt(prompt: str) -> str:
+        await processor.update_prompt(prompt)
+        return f"Prompt changed to {prompt}"
+
+    return agent
+
+
+async def join_call(agent: Agent, call_type: str, call_id: str, **kwargs) -> None:
+    """Join the call and start the agent."""
+    # Ensure the agent user is created
+    await agent.create_user()
+    # Create a call
+    call = await agent.create_call(call_type, call_id)
+
+    logger.info("🤖 Starting Agent...")
+
+    # Have the agent join the call/room
+    with await agent.join(call):
+        logger.info("Joining call")
+        logger.info("LLM ready")
+
+        await agent.finish()  # Run till the call ends
+
+
+if __name__ == "__main__":
+    cli(AgentLauncher(create_agent=create_agent, join_call=join_call))
diff --git a/plugins/decart/example/pyproject.toml b/plugins/decart/example/pyproject.toml
new file mode 100644
index 00000000..da56fe3c
--- /dev/null
+++ b/plugins/decart/example/pyproject.toml
@@ -0,0 +1,23 @@
+[project]
+name = "decart-example"
+version = "0.0.0"
+requires-python = ">=3.10"
+
+dependencies = [
+    "vision-agents",
+    "python-dotenv",
+    "vision-agents-plugins-openai",
+    "vision-agents-plugins-decart",
+    "vision-agents-plugins-elevenlabs",
+    "vision-agents-plugins-getstream",
+    "vision-agents-plugins-deepgram",
+
+]
+
+[tool.uv.sources]
+vision-agents = { workspace = true }
+vision-agents-plugins-getstream = { editable=true }
+vision-agents-plugins-openai = { editable=true }
+vision-agents-plugins-elevenlabs = { editable=true }
+vision-agents-plugins-deepgram = { editable=true }
+vision-agents-plugins-decart = { editable=true }
diff --git a/plugins/decart/py.typed b/plugins/decart/py.typed
new file mode 100644
index 00000000..e69de29b
diff --git a/plugins/decart/pyproject.toml b/plugins/decart/pyproject.toml
new file mode 100644
index 00000000..c95bedd2
--- /dev/null
+++ b/plugins/decart/pyproject.toml
@@ -0,0 +1,41 @@
+[build-system]
+requires = ["hatchling", "hatch-vcs"]
+build-backend = "hatchling.build"
+
+[project]
+name = "vision-agents-plugins-decart"
+dynamic = ["version"]
+description = "Decart plugin for Vision Agents"
+readme = "README.md"
+keywords = ["decart", "AI", "voice agents", "agents"]
+requires-python = ">=3.10"
+license = "MIT"
+dependencies = [
+    "vision-agents",
+    "decart",
+]
+
+[project.urls]
+Documentation = "https://visionagents.ai/"
+Website = "https://visionagents.ai/"
+Source = "https://github.com/GetStream/Vision-Agents"
+
+[tool.hatch.version]
+source = "vcs"
+raw-options = { root = "..", search_parent_directories = true, fallback_version = "0.0.0" }
+
+[tool.hatch.build.targets.wheel]
+packages = ["."]
+
+[tool.hatch.build.targets.sdist]
+include = ["/vision_agents"]
+
+[tool.uv.sources]
+vision-agents = { workspace = true }
+
+[dependency-groups]
+dev = [
+    "pytest>=8.4.1",
+    "pytest-asyncio>=1.0.0",
+]
+
diff --git a/plugins/decart/tests/test_decart_restyling.py b/plugins/decart/tests/test_decart_restyling.py
new file mode 100644
index 00000000..12f96c00
--- /dev/null
+++ b/plugins/decart/tests/test_decart_restyling.py @@ -0,0 +1,446 @@ +"""Tests for RestylingProcessor.""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import av +import pytest +from aiortc import MediaStreamTrack +from decart import DecartSDKError +import websockets + +from vision_agents.plugins.decart import RestylingProcessor +from vision_agents.plugins.decart.decart_video_track import DecartVideoTrack + + +@pytest.fixture +def mock_video_track(): + """Mock video track.""" + track = MagicMock(spec=MediaStreamTrack) + return track + + +@pytest.fixture +def sample_frame(): + """Test av.VideoFrame fixture.""" + from PIL import Image + + image = Image.new("RGB", (1280, 720), color="blue") + return av.VideoFrame.from_image(image) + + +@pytest.fixture +def mock_decart_client(): + """Mock DecartClient with async close method.""" + with patch( + "vision_agents.plugins.decart.decart_restyling_processor.DecartClient" + ) as mock_client: + mock_instance = MagicMock() + mock_instance.close = AsyncMock() + mock_instance.base_url = "https://api.decart.ai" + mock_instance.api_key = "test_key" + mock_client.return_value = mock_instance + yield mock_client + + +class TestRestylingProcessor: + """Tests for RestylingProcessor.""" + + def test_publish_video_track(self, mock_decart_client): + """Test that publish_video_track returns DecartVideoTrack.""" + processor = RestylingProcessor(api_key="test_key") + track = processor.publish_video_track() + assert isinstance(track, DecartVideoTrack) + processor.close() + + @pytest.mark.asyncio + async def test_process_video_triggers_connection( + self, mock_video_track, mock_decart_client + ): + """Test that process_video triggers connection to Decart.""" + with patch( + "vision_agents.plugins.decart.decart_restyling_processor.RealtimeClient" + ) as mock_realtime: + mock_client_instance = AsyncMock() + mock_client_instance.connect = AsyncMock(return_value=mock_client_instance) + mock_realtime.connect = AsyncMock(return_value=mock_client_instance) + + processor = RestylingProcessor(api_key="test_key") + await processor.process_video(mock_video_track, None) + + assert processor._current_track == mock_video_track + assert mock_realtime.connect.called + processor.close() + + @pytest.mark.asyncio + async def test_process_video_prevents_duplicate_connections( + self, mock_video_track, mock_decart_client + ): + """Test that process_video prevents duplicate connections.""" + with patch( + "vision_agents.plugins.decart.decart_restyling_processor.RealtimeClient" + ) as mock_realtime: + mock_client_instance = AsyncMock() + mock_client_instance.connect = AsyncMock(return_value=mock_client_instance) + mock_realtime.connect = AsyncMock(return_value=mock_client_instance) + + processor = RestylingProcessor(api_key="test_key") + processor._connecting = True + + await processor.process_video(mock_video_track, None) + + assert not mock_realtime.connect.called + processor.close() + + @pytest.mark.asyncio + async def test_update_prompt_when_connected(self, mock_decart_client): + """Test update_prompt updates prompt when connected.""" + with patch( + "vision_agents.plugins.decart.decart_restyling_processor.RealtimeClient" + ): + processor = RestylingProcessor(api_key="test_key") + mock_client = AsyncMock() + processor._realtime_client = mock_client + processor._connected = True + + await processor.update_prompt("new style", enrich=False) + + mock_client.set_prompt.assert_called_once_with("new style", enrich=False) + assert processor.initial_prompt == "new 
style" + processor.close() + + @pytest.mark.asyncio + async def test_update_prompt_noop_when_disconnected(self, mock_decart_client): + """Test update_prompt is no-op when disconnected.""" + with patch( + "vision_agents.plugins.decart.decart_restyling_processor.RealtimeClient" + ): + processor = RestylingProcessor( + api_key="test_key", initial_prompt="original" + ) + processor._realtime_client = None + + await processor.update_prompt("new style") + + assert processor.initial_prompt == "original" + processor.close() + + @pytest.mark.asyncio + async def test_update_prompt_uses_default_enrich(self, mock_decart_client): + """Test update_prompt uses default enrich value when not specified.""" + with patch( + "vision_agents.plugins.decart.decart_restyling_processor.RealtimeClient" + ): + processor = RestylingProcessor(api_key="test_key", enrich=True) + mock_client = AsyncMock() + processor._realtime_client = mock_client + + await processor.update_prompt("new style") + + mock_client.set_prompt.assert_called_once_with("new style", enrich=True) + processor.close() + + @pytest.mark.asyncio + async def test_set_mirror_when_connected(self, mock_decart_client): + """Test set_mirror updates mirror mode when connected.""" + with patch( + "vision_agents.plugins.decart.decart_restyling_processor.RealtimeClient" + ): + processor = RestylingProcessor(api_key="test_key", mirror=True) + mock_client = AsyncMock() + processor._realtime_client = mock_client + + await processor.set_mirror(False) + + mock_client.set_mirror.assert_called_once_with(False) + assert processor.mirror is False + processor.close() + + @pytest.mark.asyncio + async def test_set_mirror_noop_when_disconnected(self, mock_decart_client): + """Test set_mirror is no-op when disconnected.""" + with patch( + "vision_agents.plugins.decart.decart_restyling_processor.RealtimeClient" + ): + processor = RestylingProcessor(api_key="test_key", mirror=True) + processor._realtime_client = None + + await processor.set_mirror(False) + + assert processor.mirror is True + processor.close() + + +class TestConnectionManagement: + """Tests for connection management.""" + + @pytest.mark.asyncio + async def test_connection_lifecycle(self, mock_video_track, mock_decart_client): + """Test connection lifecycle (connecting -> connected).""" + with patch( + "vision_agents.plugins.decart.decart_restyling_processor.RealtimeClient" + ) as mock_realtime: + mock_client_instance = AsyncMock() + mock_client_instance.connect = AsyncMock(return_value=mock_client_instance) + mock_realtime.connect = AsyncMock(return_value=mock_client_instance) + + processor = RestylingProcessor(api_key="test_key") + assert not processor._connected + assert not processor._connecting + + await processor._connect_to_decart(mock_video_track) + + assert processor._connected + assert not processor._connecting + assert processor._realtime_client is not None + processor.close() + + @pytest.mark.asyncio + async def test_reconnection_on_connection_error( + self, mock_video_track, mock_decart_client + ): + """Test reconnection on connection errors.""" + with patch( + "vision_agents.plugins.decart.decart_restyling_processor.RealtimeClient" + ) as mock_realtime: + mock_client_instance = AsyncMock() + mock_client_instance.connect = AsyncMock(return_value=mock_client_instance) + mock_realtime.connect = AsyncMock(return_value=mock_client_instance) + + processor = RestylingProcessor(api_key="test_key") + processor._current_track = mock_video_track + + error = DecartSDKError("connection timeout") + 
processor._on_error(error) + + await asyncio.sleep(0.1) + assert mock_realtime.connect.called + processor.close() + + @pytest.mark.asyncio + async def test_reconnection_on_websocket_error( + self, mock_video_track, mock_decart_client + ): + """Test reconnection on websocket connection errors.""" + with patch( + "vision_agents.plugins.decart.decart_restyling_processor.RealtimeClient" + ) as mock_realtime: + mock_client_instance = AsyncMock() + mock_client_instance.connect = AsyncMock(return_value=mock_client_instance) + mock_realtime.connect = AsyncMock(return_value=mock_client_instance) + + processor = RestylingProcessor(api_key="test_key") + processor._current_track = mock_video_track + + error = websockets.ConnectionClosedError(None, None) + processor._on_error(error) + + await asyncio.sleep(0.1) + assert mock_realtime.connect.called + processor.close() + + @pytest.mark.asyncio + async def test_no_reconnection_on_non_connection_error( + self, mock_video_track, mock_decart_client + ): + """Test no reconnection on non-connection errors.""" + with patch( + "vision_agents.plugins.decart.decart_restyling_processor.RealtimeClient" + ): + processor = RestylingProcessor(api_key="test_key") + processor._current_track = mock_video_track + + error = DecartSDKError("invalid api key") + processor._on_error(error) + + await asyncio.sleep(0.1) + assert not processor._connected + processor.close() + + @pytest.mark.asyncio + async def test_connection_change_updates_state(self, mock_decart_client): + """Test that connection change events update state.""" + with patch( + "vision_agents.plugins.decart.decart_restyling_processor.RealtimeClient" + ): + processor = RestylingProcessor(api_key="test_key") + + processor._on_connection_change("connecting") + assert processor._connected + + processor._on_connection_change("connected") + assert processor._connected + + processor._on_connection_change("disconnected") + assert not processor._connected + + processor._on_connection_change("error") + assert not processor._connected + processor.close() + + @pytest.mark.asyncio + async def test_disconnect_cleans_up(self, mock_video_track, mock_decart_client): + """Test that disconnect cleans up properly.""" + with patch( + "vision_agents.plugins.decart.decart_restyling_processor.RealtimeClient" + ) as mock_realtime: + mock_client_instance = AsyncMock() + mock_client_instance.disconnect = AsyncMock() + mock_realtime.connect = AsyncMock(return_value=mock_client_instance) + + processor = RestylingProcessor(api_key="test_key") + await processor._connect_to_decart(mock_video_track) + assert processor._connected + + await processor._disconnect_from_decart() + + assert not processor._connected + assert processor._realtime_client is None + mock_client_instance.disconnect.assert_called_once() + processor.close() + + @pytest.mark.asyncio + async def test_processing_loop_reconnects( + self, mock_video_track, mock_decart_client + ): + """Test that processing loop reconnects when connection is lost.""" + with patch( + "vision_agents.plugins.decart.decart_restyling_processor.RealtimeClient" + ) as mock_realtime: + mock_client_instance = AsyncMock() + mock_client_instance.connect = AsyncMock(return_value=mock_client_instance) + mock_realtime.connect = AsyncMock(return_value=mock_client_instance) + + processor = RestylingProcessor(api_key="test_key") + processor._current_track = mock_video_track + processor._connected = False + processor._connecting = False + + task = asyncio.create_task(processor._processing_loop()) + await 
asyncio.sleep(1.5) + + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + assert mock_realtime.connect.called + processor.close() + + +class TestFrameHandling: + """Tests for frame handling.""" + + @pytest.mark.asyncio + async def test_frames_received_from_decart_forwarded_to_track( + self, sample_frame, mock_decart_client + ): + """Test that frames received from Decart are forwarded to video track.""" + with patch( + "vision_agents.plugins.decart.decart_restyling_processor.RealtimeClient" + ): + processor = RestylingProcessor(api_key="test_key") + + call_count = 0 + + async def mock_recv(): + nonlocal call_count + call_count += 1 + if call_count > 2: + raise asyncio.CancelledError() + return sample_frame + + mock_transformed_stream = AsyncMock() + mock_transformed_stream.recv = mock_recv + + task = asyncio.create_task( + processor._receive_frames_from_decart(mock_transformed_stream) + ) + await asyncio.sleep(0.1) + processor._video_track.stop() + + try: + await task + except asyncio.CancelledError: + pass + + assert processor._video_track.frame_queue.qsize() > 0 + processor.close() + + @pytest.mark.asyncio + async def test_frame_receiving_task_cancelled_on_close( + self, sample_frame, mock_decart_client + ): + """Test that frame receiving task is cancelled on close.""" + with patch( + "vision_agents.plugins.decart.decart_restyling_processor.RealtimeClient" + ): + processor = RestylingProcessor(api_key="test_key") + + mock_transformed_stream = AsyncMock() + mock_transformed_stream.recv = AsyncMock(return_value=sample_frame) + + task = asyncio.create_task( + processor._receive_frames_from_decart(mock_transformed_stream) + ) + await asyncio.sleep(0.05) + + processor.close() + await asyncio.sleep(0.1) + + assert task.done() + processor.close() + + @pytest.mark.asyncio + async def test_on_remote_stream_starts_frame_receiving( + self, sample_frame, mock_decart_client + ): + """Test that on_remote_stream starts frame receiving task.""" + with patch( + "vision_agents.plugins.decart.decart_restyling_processor.RealtimeClient" + ): + processor = RestylingProcessor(api_key="test_key") + + mock_transformed_stream = AsyncMock() + mock_transformed_stream.recv = AsyncMock(return_value=sample_frame) + + processor._on_remote_stream(mock_transformed_stream) + + assert processor._frame_receiving_task is not None + assert not processor._frame_receiving_task.done() + + await asyncio.sleep(0.1) + processor._video_track.stop() + await asyncio.sleep(0.1) + + processor.close() + + @pytest.mark.asyncio + async def test_on_remote_stream_cancels_previous_task( + self, sample_frame, mock_decart_client + ): + """Test that on_remote_stream cancels previous frame receiving task.""" + with patch( + "vision_agents.plugins.decart.decart_restyling_processor.RealtimeClient" + ): + processor = RestylingProcessor(api_key="test_key") + + mock_stream1 = AsyncMock() + mock_stream1.recv = AsyncMock(return_value=sample_frame) + + mock_stream2 = AsyncMock() + mock_stream2.recv = AsyncMock(return_value=sample_frame) + + processor._on_remote_stream(mock_stream1) + task1 = processor._frame_receiving_task + + await asyncio.sleep(0.05) + processor._on_remote_stream(mock_stream2) + + assert task1.done() + assert processor._frame_receiving_task != task1 + + processor._video_track.stop() + await asyncio.sleep(0.1) + processor.close() diff --git a/plugins/decart/tests/test_decart_video_track.py b/plugins/decart/tests/test_decart_video_track.py new file mode 100644 index 00000000..a1dfbb7c --- /dev/null +++ 
b/plugins/decart/tests/test_decart_video_track.py @@ -0,0 +1,124 @@ +"""Tests for DecartVideoTrack.""" + +import av +import pytest +from PIL import Image + +from vision_agents.plugins.decart.decart_video_track import DecartVideoTrack + + +@pytest.fixture +def sample_image(): + """Test image fixture.""" + return Image.new("RGB", (640, 480), color="blue") + + +@pytest.fixture +def sample_frame(sample_image): + """Test av.VideoFrame fixture.""" + return av.VideoFrame.from_image(sample_image) + + +@pytest.fixture +def sample_frame_large(): + """Test av.VideoFrame fixture with different size.""" + image = Image.new("RGB", (1920, 1080), color="red") + return av.VideoFrame.from_image(image) + + +class TestDecartVideoTrack: + """Tests for DecartVideoTrack.""" + + def test_init_default_dimensions(self): + """Test initialization with default dimensions.""" + track = DecartVideoTrack() + assert track.width == 1280 + assert track.height == 720 + assert not track._stopped + track.stop() + + def test_init_custom_dimensions(self): + """Test initialization with custom dimensions.""" + track = DecartVideoTrack(width=1920, height=1080) + assert track.width == 1920 + assert track.height == 1080 + assert not track._stopped + track.stop() + + @pytest.mark.asyncio + async def test_add_frame_correct_size(self, sample_frame): + """Test adding frame with correct size.""" + track = DecartVideoTrack(width=640, height=480) + await track.add_frame(sample_frame) + assert track.frame_queue.qsize() == 1 + track.stop() + + @pytest.mark.asyncio + async def test_add_frame_requires_resize(self, sample_frame_large): + """Test adding frame that requires resize.""" + track = DecartVideoTrack(width=1280, height=720) + await track.add_frame(sample_frame_large) + assert track.frame_queue.qsize() == 1 + received_frame = await track.frame_queue.get() + assert received_frame.width == 1280 + assert received_frame.height == 720 + track.stop() + + @pytest.mark.asyncio + async def test_add_frame_ignored_when_stopped(self, sample_frame): + """Test that add_frame is ignored when track is stopped.""" + track = DecartVideoTrack() + track.stop() + await track.add_frame(sample_frame) + assert track.frame_queue.qsize() == 0 + track.stop() + + @pytest.mark.asyncio + async def test_recv_returns_frame(self, sample_frame): + """Test that recv returns a frame when available.""" + track = DecartVideoTrack(width=640, height=480) + await track.add_frame(sample_frame) + received_frame = await track.recv() + assert received_frame is not None + assert received_frame.width == 640 + assert received_frame.height == 480 + assert received_frame.pts is not None + assert received_frame.time_base is not None + track.stop() + + @pytest.mark.asyncio + async def test_recv_returns_placeholder_when_no_frames(self): + """Test that recv returns placeholder frame when no frames available.""" + track = DecartVideoTrack() + received_frame = await track.recv() + assert received_frame is not None + assert received_frame.width == 1280 + assert received_frame.height == 720 + track.stop() + + @pytest.mark.asyncio + async def test_recv_raises_when_stopped(self): + """Test that recv raises exception when track is stopped.""" + track = DecartVideoTrack() + track.stop() + with pytest.raises(Exception, match="Track stopped"): + await track.recv() + + @pytest.mark.asyncio + async def test_recv_returns_latest_frame(self, sample_frame): + """Test that recv returns the latest frame.""" + track = DecartVideoTrack(width=640, height=480) + await track.add_frame(sample_frame) + frame1 = 
await track.recv() + await track.add_frame(sample_frame) + frame2 = await track.recv() + assert frame1 is not None + assert frame2 is not None + track.stop() + + def test_stop(self): + """Test stopping the video track.""" + track = DecartVideoTrack() + assert not track._stopped + track.stop() + assert track._stopped diff --git a/plugins/decart/vision_agents/plugins/decart/__init__.py b/plugins/decart/vision_agents/plugins/decart/__init__.py new file mode 100644 index 00000000..a92595d9 --- /dev/null +++ b/plugins/decart/vision_agents/plugins/decart/__init__.py @@ -0,0 +1,3 @@ +from .decart_restyling_processor import RestylingProcessor + +__all__ = ["RestylingProcessor"] diff --git a/plugins/decart/vision_agents/plugins/decart/decart_restyling_processor.py b/plugins/decart/vision_agents/plugins/decart/decart_restyling_processor.py new file mode 100644 index 00000000..ec016920 --- /dev/null +++ b/plugins/decart/vision_agents/plugins/decart/decart_restyling_processor.py @@ -0,0 +1,299 @@ +import asyncio +import logging +import os +from asyncio import CancelledError +from typing import Any, Optional, cast + +import aiortc +import av +import websockets +from aiortc import MediaStreamTrack, VideoStreamTrack +from decart import DecartClient, models +from decart import DecartSDKError +from decart.realtime import RealtimeClient, RealtimeConnectOptions +from decart.types import ModelState, Prompt +from decart.models import RealTimeModels + +from vision_agents.core.processors.base_processor import ( + AudioVideoProcessor, + VideoProcessorMixin, + VideoPublisherMixin, +) +from .decart_video_track import DecartVideoTrack + +logger = logging.getLogger(__name__) + + +def _should_reconnect(exc: Exception) -> bool: + if isinstance(exc, websockets.ConnectionClosedError): + return True + + if isinstance(exc, DecartSDKError): + error_msg = str(exc).lower() + if ( + "connection" in error_msg + or "disconnect" in error_msg + or "timeout" in error_msg + ): + return True + + return False + + +class RestylingProcessor(AudioVideoProcessor, VideoProcessorMixin, VideoPublisherMixin): + """Decart Realtime restyling processor for transforming user video tracks. + + This processor accepts the user's local video track, sends it to Decart's + Realtime API via websocket, receives transformed frames, and publishes them + as a new video track. + + Example: + agent = Agent( + edge=getstream.Edge(), + agent_user=User(name="Styled AI"), + instructions="Be helpful", + llm=gemini.Realtime(), + processors=[ + decart.RestylingProcessor( + initial_prompt="Studio Ghibli animation style", + model="mirage_v2" + ) + ] + ) + """ + + name = "decart_restyling" + + def __init__( + self, + api_key: Optional[str] = None, + model: RealTimeModels = "mirage_v2", + initial_prompt: str = "Cyberpunk city", + enrich: bool = True, + mirror: bool = True, + width: int = 1280, # Model preferred + height: int = 720, + **kwargs, + ): + """Initialize the Decart restyling processor. + + Args: + api_key: Decart API key. Uses DECART_API_KEY env var if not provided. + model: Decart model name (default: "mirage_v2"). + initial_prompt: Initial style prompt text. + enrich: Whether to enrich prompt (default: True). + mirror: Mirror mode for front camera (default: True). + width: Output video width (default: 1280). + height: Output video height (default: 720). + **kwargs: Additional arguments passed to parent class. 
+ """ + super().__init__( + interval=0, + receive_audio=False, + receive_video=True, + **kwargs, + ) + + self.api_key = api_key or os.getenv("DECART_API_KEY") + if not self.api_key: + raise ValueError( + "Decart API key is required. Set DECART_API_KEY environment variable " + "or pass api_key parameter." + ) + + self.model_name = model + self.initial_prompt = initial_prompt + self.enrich = enrich + self.mirror = mirror + self.width = width + self.height = height + + self.model = models.realtime(self.model_name) + + self._decart_client = DecartClient(api_key=self.api_key, **kwargs) + self._video_track = DecartVideoTrack(width=width, height=height) + self._realtime_client: Optional[RealtimeClient] = None + + self._connected = False + self._connecting = False + self._processing_task: Optional[asyncio.Task] = None + self._frame_receiving_task: Optional[asyncio.Task] = None + self._current_track: Optional[MediaStreamTrack] = None + self._on_connection_change_callback = None + + logger.info( + f"Decart RestylingProcessor initialized (model: {self.model_name}, prompt: {self.initial_prompt[:50]}...)" + ) + + async def process_video( + self, + incoming_track: aiortc.mediastreams.MediaStreamTrack, + participant: Any, + shared_forwarder=None, + ): + logger.debug("Processing video track, connecting to Decart") + self._current_track = incoming_track + if not self._connected and not self._connecting: + await self._connect_to_decart(incoming_track) + + def publish_video_track(self) -> VideoStreamTrack: + return self._video_track + + async def update_prompt( + self, prompt_text: str, enrich: Optional[bool] = None + ) -> None: + """ + Updates the prompt used for the Decart real-time client. This method allows + changing the current prompt and optionally specifies whether to enrich the + prompt content. The operation is performed asynchronously and requires an + active connection to the Decart client. + + If the `enrich` parameter is not provided, the method uses the default + `self.enrich` value. + + Parameters: + prompt_text: str + The text of the new prompt to be applied. + enrich: Optional[bool] + Specifies whether to enrich the prompt content. If not provided, + defaults to the object's `enrich` attribute. 
+ + Returns: + None + """ + if not self._realtime_client: + logger.debug("Cannot set prompt: not connected to Decart") + return + + enrich_value = enrich if enrich is not None else self.enrich + await self._realtime_client.set_prompt(prompt_text, enrich=enrich_value) + self.initial_prompt = prompt_text + logger.info(f"Updated Decart prompt: {prompt_text[:50]}...") + + async def set_mirror(self, enabled: bool) -> None: + if not self._realtime_client: + logger.debug("Cannot set mirror: not connected to Decart") + return + + await self._realtime_client.set_mirror(enabled) + self.mirror = enabled + logger.debug(f"Updated Decart mirror mode: {enabled}") + + async def _connect_to_decart(self, local_track: MediaStreamTrack) -> None: + if self._connecting: + logger.debug("Already connecting to Decart, skipping") + return + + logger.info(f"Connecting to Decart Realtime API (model: {self.model_name})") + self._connecting = True + + try: + if self._realtime_client: + await self._disconnect_from_decart() + + initial_state = ModelState( + prompt=Prompt( + text=self.initial_prompt, + enrich=self.enrich, + ), + mirror=self.mirror, + ) + + self._realtime_client = await RealtimeClient.connect( + base_url=self._decart_client.base_url, + api_key=self._decart_client.api_key, + local_track=local_track, + options=RealtimeConnectOptions( + model=self.model, + on_remote_stream=self._on_remote_stream, + initial_state=initial_state, + ), + ) + + self._realtime_client.on("connection_change", self._on_connection_change) + self._realtime_client.on("error", self._on_error) + + self._connected = True + logger.info("Connected to Decart Realtime API") + + if self._processing_task is None or self._processing_task.done(): + self._processing_task = asyncio.create_task(self._processing_loop()) + + except Exception as e: + self._connected = False + logger.error(f"Failed to connect to Decart: {e}") + raise + finally: + self._connecting = False + + def _on_remote_stream(self, transformed_stream: MediaStreamTrack) -> None: + if self._frame_receiving_task and not self._frame_receiving_task.done(): + self._frame_receiving_task.cancel() + + self._frame_receiving_task = asyncio.create_task( + self._receive_frames_from_decart(transformed_stream) + ) + logger.debug("Started receiving frames from Decart transformed stream") + + async def _receive_frames_from_decart( + self, transformed_stream: MediaStreamTrack + ) -> None: + try: + while not self._video_track.is_stopped: + frame = await transformed_stream.recv() + await self._video_track.add_frame(cast(av.VideoFrame, frame)) + except asyncio.CancelledError: + logger.debug("Frame receiving from Decart cancelled") + + def _on_connection_change(self, state: str) -> None: + logger.debug(f"Decart connection state changed: {state}") + if state in ("connected", "connecting"): + self._connected = True + elif state in ("disconnected", "error"): + self._connected = False + if state == "disconnected": + logger.info("Disconnected from Decart Realtime API") + elif state == "error": + logger.warning("Decart connection error occurred") + + if self._on_connection_change_callback: + self._on_connection_change_callback(state) + + def _on_error(self, error: DecartSDKError) -> None: + logger.warning(f"Decart error: {error}") + if _should_reconnect(error) and self._current_track: + logger.info("Attempting to reconnect to Decart...") + asyncio.create_task(self._connect_to_decart(self._current_track)) + + # Reconnect to Decart if the connection is dropped + async def _processing_loop(self) -> None: + try: + 
while True: + if not self._connected and not self._connecting and self._current_track: + logger.debug("Connection lost, attempting to reconnect...") + await self._connect_to_decart(self._current_track) + + await asyncio.sleep(1.0) + except CancelledError: + logger.debug("Decart processing loop cancelled") + + async def _disconnect_from_decart(self) -> None: + if self._realtime_client: + logger.debug("Disconnecting from Decart Realtime API") + await self._realtime_client.disconnect() + self._realtime_client = None + self._connected = False + + async def close(self) -> None: + if self._video_track: + self._video_track.stop() + + if self._frame_receiving_task and not self._frame_receiving_task.done(): + self._frame_receiving_task.cancel() + + if self._processing_task and not self._processing_task.done(): + self._processing_task.cancel() + else: + if self._realtime_client or self._decart_client: + await self._disconnect_from_decart() + await self._decart_client.close() diff --git a/plugins/decart/vision_agents/plugins/decart/decart_video_track.py b/plugins/decart/vision_agents/plugins/decart/decart_video_track.py new file mode 100644 index 00000000..7761b270 --- /dev/null +++ b/plugins/decart/vision_agents/plugins/decart/decart_video_track.py @@ -0,0 +1,83 @@ +import asyncio +import logging +from typing import Optional + +import av +from PIL import Image +from aiortc import MediaStreamTrack, VideoStreamTrack + +from vision_agents.core.utils.video_queue import VideoLatestNQueue +from vision_agents.core.utils.video_utils import resize_frame + +logger = logging.getLogger(__name__) + + +class DecartVideoTrack(VideoStreamTrack): + """Video track that forwards Decart restyled video frames. + + Receives video frames from Decart's Realtime API and provides + them through the standard VideoStreamTrack interface for publishing + to the call. + """ + + def __init__(self, width: int = 1280, height: int = 720): + """Initialize the Decart video track. + + Args: + width: Video frame width. + height: Video frame height. 
+ """ + super().__init__() + + self.width = width + self.height = height + + self.frame_queue: VideoLatestNQueue[av.VideoFrame] = VideoLatestNQueue(maxlen=2) + placeholder = Image.new("RGB", (self.width, self.height), color=(30, 30, 40)) + self.placeholder_frame = av.VideoFrame.from_image(placeholder) + self.last_frame: av.VideoFrame = self.placeholder_frame + + self._stopped = False + self._source_track: Optional[MediaStreamTrack] = None + + logger.debug(f"DecartVideoTrack initialized ({width}x{height})") + + async def add_frame(self, frame: av.VideoFrame | av.AudioFrame | av.Packet) -> None: + if self._stopped: + return + if not isinstance(frame, av.VideoFrame): + return + if frame.width != self.width or frame.height != self.height: + frame = await asyncio.to_thread(resize_frame, self, frame) + self.frame_queue.put_latest_nowait(frame) + + async def recv(self) -> av.VideoFrame: + if self._stopped: + raise Exception("Track stopped") + + try: + frame = await asyncio.wait_for( + self.frame_queue.get(), + timeout=0.033, + ) + if frame: + self.last_frame = frame + except asyncio.TimeoutError: + pass + + pts, time_base = await self.next_timestamp() + + output_frame = self.last_frame + output_frame.pts = pts + output_frame.time_base = time_base + + return output_frame + + @property + def is_stopped(self) -> bool: + """Check if the video track is stopped.""" + return self._stopped + + def stop(self) -> None: + self._stopped = True + super().stop() diff --git a/pyproject.toml b/pyproject.toml index 9dde8ab9..1d141fb8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ vision-agents-plugins-inworld = { workspace = true } vision-agents-plugins-moondream = { workspace = true } vision-agents-plugins-vogent = { workspace = true } vision-agents-plugins-fast-whisper = { workspace = true } +vision-agents-plugins-decart = { workspace = true } [tool.uv] # Workspace-level override to resolve numpy version conflicts @@ -55,7 +56,8 @@ members = [ "plugins/inworld", "plugins/vogent", "plugins/moondream", - "plugins/fast_whisper" + "plugins/fast_whisper", + "plugins/decart" ] exclude = [ "**/__pycache__", diff --git a/uv.lock b/uv.lock index 6b0929f5..61ada853 100644 --- a/uv.lock +++ b/uv.lock @@ -14,6 +14,7 @@ members = [ "vision-agents-plugins-anthropic", "vision-agents-plugins-aws", "vision-agents-plugins-cartesia", + "vision-agents-plugins-decart", "vision-agents-plugins-deepgram", "vision-agents-plugins-elevenlabs", "vision-agents-plugins-fast-whisper", @@ -1017,6 +1018,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c3/be/d0d44e092656fe7a06b55e6103cbce807cdbdee17884a5367c68c9860853/dataclasses_json-0.6.7-py3-none-any.whl", hash = "sha256:0dbf33f26c8d5305befd61b39d2b3414e8a407bedc2834dea9b8d642666fb40a", size = 28686, upload-time = "2024-06-09T16:20:16.715Z" }, ] +[[package]] +name = "decart" +version = "0.0.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiofiles" }, + { name = "aiohttp" }, + { name = "numpy" }, + { name = "opencv-python" }, + { name = "pydantic" }, + { name = "websockets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/86/86/75e55da0dc56528b232e31186c02bf409ba9bce888e5520507a9737cdc36/decart-0.0.8.tar.gz", hash = "sha256:95ea6fc0af5896198ce76abf7d2ecf39feec77e3ebf7c5fa55e2f239879d5545", size = 198620, upload-time = "2025-11-12T17:02:42.627Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ee/e2/e128f098aaf0981849da4684a30d53b9a84be84fc372709de14557d135e9/decart-0.0.8-py3-none-any.whl", 
hash = "sha256:3f1ab43ced32edcd5a494673bc8f899a559f6a5c59796da88d71c5f48b8422dc", size = 20263, upload-time = "2025-11-12T17:02:40.913Z" }, +] + [[package]] name = "decorator" version = "5.2.1" @@ -5556,6 +5574,32 @@ dev = [ { name = "pytest-asyncio", specifier = ">=1.0.0" }, ] +[[package]] +name = "vision-agents-plugins-decart" +source = { editable = "plugins/decart" } +dependencies = [ + { name = "decart" }, + { name = "vision-agents" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pytest" }, + { name = "pytest-asyncio" }, +] + +[package.metadata] +requires-dist = [ + { name = "decart" }, + { name = "vision-agents", editable = "agents-core" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pytest", specifier = ">=8.4.1" }, + { name = "pytest-asyncio", specifier = ">=1.0.0" }, +] + [[package]] name = "vision-agents-plugins-deepgram" source = { editable = "plugins/deepgram" }