Skip to content

Commit cb65397

Browse files
authored
Merge pull request #47 from speechmatics/fix/wait-for-eot-message
Wait for eot on close
2 parents 42943b1 + f8b8a6c commit cb65397

File tree

3 files changed

+52
-6
lines changed

3 files changed

+52
-6
lines changed

examples/rt/async/file/main.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
2+
import asyncio
3+
from speechmatics.rt import AsyncClient, ServerMessageType
4+
5+
6+
async def main():
    """Transcribe the example WAV file and print each final transcript.

    Opens an ``AsyncClient`` as an async context manager, registers a handler
    for ``ADD_TRANSCRIPT`` messages, and passes the opened audio file to
    ``client.transcribe``.
    """
    # Create a client using environment variable SPEECHMATICS_API_KEY
    async with AsyncClient() as client:
        # Register event handlers
        @client.on(ServerMessageType.ADD_TRANSCRIPT)
        def handle_final_transcript(msg):
            print(f"Final: {msg['metadata']['transcript']}")

        # Transcribe audio file
        with open("./examples/example.wav", "rb") as audio_file:
            await client.transcribe(audio_file)


# Run the async function only when executed as a script, so that importing
# this example does not start a transcription as a side effect.
if __name__ == "__main__":
    asyncio.run(main())

sdk/rt/speechmatics/rt/_async_client.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ def __init__(
8181
self._recognition_started_evt,
8282
self._session_done_evt,
8383
) = self._init_session_info()
84-
self._eos_sent = False
8584

8685
transport = self._create_transport_from_config(
8786
auth=auth,
@@ -96,6 +95,7 @@ def __init__(
9695
self.on(ServerMessageType.END_OF_TRANSCRIPT, self._on_eot)
9796
self.on(ServerMessageType.ERROR, self._on_error)
9897
self.on(ServerMessageType.WARNING, self._on_warning)
98+
self.on(ServerMessageType.AUDIO_ADDED, self._on_audio_added)
9999

100100
self._logger.debug("AsyncClient initialized (request_id=%s)", self._session.request_id)
101101

@@ -141,6 +141,26 @@ async def start_session(
141141
ws_headers=ws_headers,
142142
)
143143

144+
async def stop_session(self) -> None:
    """
    End the transcription session gracefully and close the connection.

    Sends the end-of-stream message carrying the current sequence number,
    waits for the session-done event, and then closes the client. The client
    is closed even if sending fails or the wait is interrupted, so a failed
    teardown does not leave the connection open.

    NOTE(review): the wait on the session-done event has no timeout in this
    method; if that event is never set, the call blocks until it is cancelled.

    Raises:
        ConnectionError: If the WebSocket connection fails.
        TranscriptionError: If the server reports an error during teardown.
        TimeoutError: If the connection or teardown times out.
            (Presumably raised by the underlying send/close calls - confirm;
            nothing in this method applies a timeout itself.)

    Examples:
        Basic streaming:
            >>> async with AsyncClient() as client:
            ...     await client.start_session()
            ...     await client.send_audio(frame)
            ...     await client.stop_session()
    """
    try:
        await self._send_eos(self._seq_no)
        # Wait for end of transcript event to indicate we can stop listening
        await self._session_done_evt.wait()
    finally:
        # Always release the connection, including on error or cancellation.
        await self.close()
163+
144164
async def transcribe(
145165
self,
146166
source: BinaryIO,
@@ -233,7 +253,6 @@ async def _audio_producer(self, source: BinaryIO, chunk_size: int) -> None:
233253
chunk_size: Chunk size for audio data
234254
"""
235255
src = FileSource(source, chunk_size=chunk_size)
236-
seq_no = 0
237256

238257
try:
239258
async for frame in src:
@@ -242,13 +261,12 @@ async def _audio_producer(self, source: BinaryIO, chunk_size: int) -> None:
242261

243262
try:
244263
await self.send_audio(frame)
245-
seq_no += 1
246264
except Exception as e:
247265
self._logger.error("Failed to send audio frame: %s", e)
248266
self._session_done_evt.set()
249267
break
250268

251-
await self._send_eos(seq_no)
269+
await self.stop_session()
252270
except asyncio.CancelledError:
253271
raise
254272
except Exception as e:
@@ -286,13 +304,19 @@ def _on_error(self, msg: dict[str, Any]) -> None:
286304
self._session_done_evt.set()
287305
raise TranscriptionError(error)
288306

307+
def _on_audio_added(self, msg: dict[str, Any]) -> None:
308+
"""Handle AudioAdded message from server."""
309+
self._seq_no = msg.get("seq_no", 0)
310+
289311
def _on_warning(self, msg: dict[str, Any]) -> None:
290312
"""Handle Warning message from server."""
291313
self._logger.warning("Server warning: %s", msg.get("reason", "unknown"))
292314

293315
async def close(self) -> None:
294316
"""
295317
Close the client and clean up resources.
318+
WARNING: this closes the client without waiting for remaining messages to be processed.
319+
It is recommended to use stop_session() instead.
296320
297321
Ensures the session is marked as complete and delegates to the base
298322
class for full cleanup including WebSocket connection termination.

sdk/rt/speechmatics/rt/_base_client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def __init__(self, transport: Transport) -> None:
4141
self._transport = transport
4242
self._recv_task: Optional[asyncio.Task[None]] = None
4343
self._closed_evt = asyncio.Event()
44+
self._eos_sent = False
45+
self._seq_no = 0
4446

4547
self._logger = get_logger("speechmatics.rt.base_client")
4648

@@ -112,14 +114,15 @@ async def send_audio(self, payload: bytes) -> None:
112114
>>> audio_chunk = b""
113115
>>> await client.send_audio(audio_chunk)
114116
"""
115-
if self._closed_evt.is_set():
117+
if self._closed_evt.is_set() or self._eos_sent:
116118
raise TransportError("Client is closed")
117119

118120
if not isinstance(payload, bytes):
119121
raise ValueError("Payload must be bytes")
120122

121123
try:
122124
await self._transport.send_message(payload)
125+
self._seq_no += 1
123126
except Exception:
124127
self._closed_evt.set()
125128
raise
@@ -133,7 +136,7 @@ async def send_message(self, message: dict[str, Any]) -> None:
133136
>>> msg = json.dumps({"message": "StartRecognition", ...})
134137
>>> await client.send_message(msg)
135138
"""
136-
if self._closed_evt.is_set():
139+
if self._closed_evt.is_set() or self._eos_sent:
137140
raise TransportError("Client is closed")
138141

139142
if not isinstance(message, dict):

0 commit comments

Comments
 (0)