Commit 3bd411c by Alina Lozovskaya
Parent(s): 21147c0

Handle abrupt websocket closures with retry and safe shutdown
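start_up now wraps the realtime connection in a small retry loop: when the underlying websocket closes abruptly (websockets' ConnectionClosedError, e.g. "no close frame received or sent"), it reconnects up to three times with a short jittered backoff, while a clean exit from the receive loop stops retrying; a try/finally clears the stale self.connection reference in every case, and shutdown now closes the connection defensively. The control flow, as a minimal sketch (the connect and consume callables here are stand-ins for client.realtime.connect(...) and the event-handling loop in the diff below):

    import asyncio
    import random

    from websockets.exceptions import ConnectionClosedError

    async def run_with_retry(connect, consume, max_attempts: int = 3) -> None:
        # connect(): async context manager yielding a connection
        # consume(conn): iterates events; raises ConnectionClosedError on abrupt close
        for attempt in range(1, max_attempts + 1):
            async with connect() as conn:
                try:
                    await consume(conn)
                except ConnectionClosedError:
                    if attempt < max_attempts:
                        # small jittered backoff before reconnecting
                        await asyncio.sleep(1.0 + random.uniform(0, 0.5))
                        continue
                    raise  # attempts exhausted: propagate
                return  # receive loop ended normally: stop retrying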
src/reachy_mini_conversation_app/openai_realtime.py
CHANGED
@@ -1,5 +1,6 @@
 import json
 import base64
+import random
 import asyncio
 import logging
 from typing import Any, Tuple, Literal, cast
@@ -10,6 +11,7 @@ import gradio as gr
 from openai import AsyncOpenAI
 from fastrtc import AdditionalOutputs, AsyncStreamHandler, wait_for_item
 from numpy.typing import NDArray
+from websockets.exceptions import ConnectionClosedError
 
 from reachy_mini_conversation_app.tools import (
     ALL_TOOL_SPECS,
@@ -68,206 +70,227 @@ class OpenaiRealtimeHandler(AsyncStreamHandler):
         return cast(NDArray[np.int16], resampled.astype(np.int16))
 
     async def start_up(self) -> None:
-        """Start the handler."""
+        """Start the handler with minimal retries on unexpected websocket closure."""
         self.client = AsyncOpenAI(api_key=config.OPENAI_API_KEY)
-        async with self.client.realtime.connect(model=config.MODEL_NAME) as conn:
-            try:
-                await conn.session.update(
-                    session={
-                        "type": "realtime",
-                        "instructions": SESSION_INSTRUCTIONS,
-                        "audio": {
-                            "input": {
-                                "format": {
-                                    "type": "audio/pcm",
-                                    "rate": self.target_input_rate,
-                                },
-                                "transcription": {
-                                    "model": "whisper-1",
-                                    "language": "en"
-                                },
-                                "turn_detection": {
-                                    "type": "server_vad",
-                                    "interrupt_response": True,
-                                },
-                            },
-                            "output": {
-                                "format": {
-                                    "type": "audio/pcm",
-                                    "rate": self.output_sample_rate,
-                                },
-                                "voice": "cedar",
-                            },
-                        },
-                        "tools": ALL_TOOL_SPECS,  # type: ignore[typeddict-item]
-                        "tool_choice": "auto",
-                    },
-                )
-            except Exception:
-                logger.exception("Realtime session.update failed; aborting startup")
-                return
-
-            logger.info("Realtime session updated successfully")
-
-            # Manage event received from the openai server
-            self.connection = conn
-            async for event in self.connection:
-                logger.debug(f"OpenAI event: {event.type}")
-                if event.type == "input_audio_buffer.speech_started":
-                    if hasattr(self, "_clear_queue") and callable(self._clear_queue):
-                        self._clear_queue()
-                    if self.deps.head_wobbler is not None:
-                        self.deps.head_wobbler.reset()
-                    self.deps.movement_manager.set_listening(True)
-                    logger.debug("User speech started")
-
-                if event.type == "input_audio_buffer.speech_stopped":
-                    self.deps.movement_manager.set_listening(False)
-                    logger.debug("User speech stopped - server will auto-commit with VAD")
-
-                if event.type in (
-                    "response.audio.done",  # GA
-                    "response.output_audio.done",  # GA alias
-                    "response.audio.completed",  # legacy (for safety)
-                    "response.completed",  # text-only completion
-                ):
-                    logger.debug("response completed")
-
-                if event.type == "response.created":
-                    logger.debug("Response created")
-
-                if event.type == "response.done":
-                    # Doesn't mean the audio is done playing
-                    logger.debug("Response done")
-
-                # Handle partial transcription (user speaking in real-time)
-                if event.type == "conversation.item.input_audio_transcription.partial":
-                    logger.debug(f"User partial transcript: {event.transcript}")
-                    await self.output_queue.put(
-                        AdditionalOutputs({"role": "user_partial", "content": event.transcript})
-                    )
-
-                # Handle completed transcription (user finished speaking)
-                if event.type == "conversation.item.input_audio_transcription.completed":
-                    logger.debug(f"User transcript: {event.transcript}")
-                    await self.output_queue.put(AdditionalOutputs({"role": "user", "content": event.transcript}))
-
-                # Handle assistant transcription
-                if event.type in ("response.audio_transcript.done", "response.output_audio_transcript.done"):
-                    logger.debug(f"Assistant transcript: {event.transcript}")
-                    await self.output_queue.put(AdditionalOutputs({"role": "assistant", "content": event.transcript}))
-
-                # Handle audio delta
-                if event.type in ("response.audio.delta", "response.output_audio.delta"):
-                    if self.deps.head_wobbler is not None:
-                        self.deps.head_wobbler.feed(event.delta)
-                    self.last_activity_time = asyncio.get_event_loop().time()
-                    logger.debug("last activity time updated to %s", self.last_activity_time)
-                    await self.output_queue.put(
-                        (
-                            self.output_sample_rate,
-                            np.frombuffer(base64.b64decode(event.delta), dtype=np.int16).reshape(1, -1),
-                        ),
-                    )
-
-
-                # ---- tool-calling plumbing ----
-                if event.type == "response.function_call_arguments.done":
-                    tool_name = getattr(event, "name", None)
-                    args_json_str = getattr(event, "arguments", None)
-                    call_id = getattr(event, "call_id", None)
-
-                    if not isinstance(tool_name, str) or not isinstance(args_json_str, str):
-                        logger.error("Invalid tool call: tool_name=%s, args=%s", tool_name, args_json_str)
-                        continue
-
-                    try:
-                        tool_result = await dispatch_tool_call(tool_name, args_json_str, self.deps)
-                        logger.debug("Tool '%s' executed successfully", tool_name)
-                        logger.debug("Tool result: %s", tool_result)
-                    except Exception as e:
-                        logger.error("Tool '%s' failed", tool_name)
-                        tool_result = {"error": str(e)}
-
-                    # send the tool result back
-                    if isinstance(call_id, str):
-                        await self.connection.conversation.item.create(
-                            item={
-                                "type": "function_call_output",
-                                "call_id": call_id,
-                                "output": json.dumps(tool_result),
-                            },
-                        )
-
-                    await self.output_queue.put(
-                        AdditionalOutputs(
-                            {
-                                "role": "assistant",
-                                "content": json.dumps(tool_result),
-                                "metadata": {"title": f"🛠️ Used tool {tool_name}", "status": "done"},
-                            },
-                        ),
-                    )
-
-                    if tool_name == "camera" and "b64_im" in tool_result:
-                        # use raw base64, don't json.dumps (which adds quotes)
-                        b64_im = tool_result["b64_im"]
-                        if not isinstance(b64_im, str):
-                            logger.warning("Unexpected type for b64_im: %s", type(b64_im))
-                            b64_im = str(b64_im)
-                        await self.connection.conversation.item.create(
-                            item={
-                                "type": "message",
-                                "role": "user",
-                                "content": [
-                                    {
-                                        "type": "input_image",
-                                        "image_url": f"data:image/jpeg;base64,{b64_im}",
-                                    },
-                                ],
-                            },
-                        )
-                        logger.info("Added camera image to conversation")
-
-
-                        np_img = self.deps.camera_worker.get_latest_frame()
-                        img = gr.Image(value=np_img)
-
-                        await self.output_queue.put(
-                            AdditionalOutputs(
-                                {
-                                    "role": "assistant",
-                                    "content": img,
-                                },
-                            ),
-                        )
-
-                    # if this tool call was triggered by an idle signal, don't make the robot speak
-                    # for other tool calls, let the robot reply out loud
-                    if self.is_idle_tool_call:
-                        self.is_idle_tool_call = False
-                    else:
-                        await self.connection.response.create(
-                            response={
-                                "instructions": "Use the tool result just returned and answer concisely in speech.",
-                            },
-                        )
-
-                    # re synchronize the head wobble after a tool call that may have taken some time
-                    if self.deps.head_wobbler is not None:
-                        self.deps.head_wobbler.reset()
-
-                # server error
-                if event.type == "error":
-                    err = getattr(event, "error", None)
-                    msg = getattr(err, "message", str(err) if err else "unknown error")
-                    code = getattr(err, "code", "")
-
-                    logger.error("Realtime error [%s]: %s (raw=%s)", code, msg, err)
-
-                    # Only show user-facing errors, not internal state errors
-                    if code not in ("input_audio_buffer_commit_empty", "conversation_already_has_active_response"):
-                        await self.output_queue.put(AdditionalOutputs({"role": "assistant", "content": f"[error] {msg}"}))
+
+        max_attempts = 3
+        for attempt in range(1, max_attempts + 1):
+            try:
+                async with self.client.realtime.connect(model=config.MODEL_NAME) as conn:
+                    try:
+                        await conn.session.update(
+                            session={
+                                "type": "realtime",
+                                "instructions": SESSION_INSTRUCTIONS,
+                                "audio": {
+                                    "input": {
+                                        "format": {
+                                            "type": "audio/pcm",
+                                            "rate": self.target_input_rate,
+                                        },
+                                        "transcription": {
+                                            "model": "whisper-1",
+                                            "language": "en"
+                                        },
+                                        "turn_detection": {
+                                            "type": "server_vad",
+                                            "interrupt_response": True,
+                                        },
+                                    },
+                                    "output": {
+                                        "format": {
+                                            "type": "audio/pcm",
+                                            "rate": self.output_sample_rate,
+                                        },
+                                        "voice": "cedar",
+                                    },
+                                },
+                                "tools": ALL_TOOL_SPECS,  # type: ignore[typeddict-item]
+                                "tool_choice": "auto",
+                            },
+                        )
+                    except Exception:
+                        logger.exception("Realtime session.update failed; aborting startup")
+                        return
+
+                    logger.info("Realtime session updated successfully")
+
+                    # Manage event received from the openai server
+                    self.connection = conn
+                    try:
+                        async for event in self.connection:
+                            logger.debug(f"OpenAI event: {event.type}")
+                            if event.type == "input_audio_buffer.speech_started":
+                                if hasattr(self, "_clear_queue") and callable(self._clear_queue):
+                                    self._clear_queue()
+                                if self.deps.head_wobbler is not None:
+                                    self.deps.head_wobbler.reset()
+                                self.deps.movement_manager.set_listening(True)
+                                logger.debug("User speech started")
+
+                            if event.type == "input_audio_buffer.speech_stopped":
+                                self.deps.movement_manager.set_listening(False)
+                                logger.debug("User speech stopped - server will auto-commit with VAD")
+
+                            if event.type in (
+                                "response.audio.done",  # GA
+                                "response.output_audio.done",  # GA alias
+                                "response.audio.completed",  # legacy (for safety)
+                                "response.completed",  # text-only completion
+                            ):
+                                logger.debug("response completed")
+
+                            if event.type == "response.created":
+                                logger.debug("Response created")
+
+                            if event.type == "response.done":
+                                # Doesn't mean the audio is done playing
+                                logger.debug("Response done")
+
+                            # Handle partial transcription (user speaking in real-time)
+                            if event.type == "conversation.item.input_audio_transcription.partial":
+                                logger.debug(f"User partial transcript: {event.transcript}")
+                                await self.output_queue.put(
+                                    AdditionalOutputs({"role": "user_partial", "content": event.transcript})
+                                )
+
+                            # Handle completed transcription (user finished speaking)
+                            if event.type == "conversation.item.input_audio_transcription.completed":
+                                logger.debug(f"User transcript: {event.transcript}")
+                                await self.output_queue.put(AdditionalOutputs({"role": "user", "content": event.transcript}))
+
+                            # Handle assistant transcription
+                            if event.type in ("response.audio_transcript.done", "response.output_audio_transcript.done"):
+                                logger.debug(f"Assistant transcript: {event.transcript}")
+                                await self.output_queue.put(AdditionalOutputs({"role": "assistant", "content": event.transcript}))
+
+                            # Handle audio delta
+                            if event.type in ("response.audio.delta", "response.output_audio.delta"):
+                                if self.deps.head_wobbler is not None:
+                                    self.deps.head_wobbler.feed(event.delta)
+                                self.last_activity_time = asyncio.get_event_loop().time()
+                                logger.debug("last activity time updated to %s", self.last_activity_time)
+                                await self.output_queue.put(
+                                    (
+                                        self.output_sample_rate,
+                                        np.frombuffer(base64.b64decode(event.delta), dtype=np.int16).reshape(1, -1),
+                                    ),
+                                )
+
+                            # ---- tool-calling plumbing ----
+                            if event.type == "response.function_call_arguments.done":
+                                tool_name = getattr(event, "name", None)
+                                args_json_str = getattr(event, "arguments", None)
+                                call_id = getattr(event, "call_id", None)
+
+                                if not isinstance(tool_name, str) or not isinstance(args_json_str, str):
+                                    logger.error("Invalid tool call: tool_name=%s, args=%s", tool_name, args_json_str)
+                                    continue
+
+                                try:
+                                    tool_result = await dispatch_tool_call(tool_name, args_json_str, self.deps)
+                                    logger.debug("Tool '%s' executed successfully", tool_name)
+                                    logger.debug("Tool result: %s", tool_result)
+                                except Exception as e:
+                                    logger.error("Tool '%s' failed", tool_name)
+                                    tool_result = {"error": str(e)}
+
+                                # send the tool result back
+                                if isinstance(call_id, str):
+                                    await self.connection.conversation.item.create(
+                                        item={
+                                            "type": "function_call_output",
+                                            "call_id": call_id,
+                                            "output": json.dumps(tool_result),
+                                        },
+                                    )
+
+                                await self.output_queue.put(
+                                    AdditionalOutputs(
+                                        {
+                                            "role": "assistant",
+                                            "content": json.dumps(tool_result),
+                                            "metadata": {"title": f"🛠️ Used tool {tool_name}", "status": "done"},
+                                        },
+                                    ),
+                                )
+
+                                if tool_name == "camera" and "b64_im" in tool_result:
+                                    # use raw base64, don't json.dumps (which adds quotes)
+                                    b64_im = tool_result["b64_im"]
+                                    if not isinstance(b64_im, str):
+                                        logger.warning("Unexpected type for b64_im: %s", type(b64_im))
+                                        b64_im = str(b64_im)
+                                    await self.connection.conversation.item.create(
+                                        item={
+                                            "type": "message",
+                                            "role": "user",
+                                            "content": [
+                                                {
+                                                    "type": "input_image",
+                                                    "image_url": f"data:image/jpeg;base64,{b64_im}",
+                                                },
+                                            ],
+                                        },
+                                    )
+                                    logger.info("Added camera image to conversation")
+
+                                    if self.deps.camera_worker is not None:
+                                        np_img = self.deps.camera_worker.get_latest_frame()
+                                        img = gr.Image(value=np_img)
+
+                                        await self.output_queue.put(
+                                            AdditionalOutputs(
+                                                {
+                                                    "role": "assistant",
+                                                    "content": img,
+                                                },
+                                            ),
+                                        )
+
+                                # if this tool call was triggered by an idle signal, don't make the robot speak
+                                # for other tool calls, let the robot reply out loud
+                                if self.is_idle_tool_call:
+                                    self.is_idle_tool_call = False
+                                else:
+                                    await self.connection.response.create(
+                                        response={
+                                            "instructions": "Use the tool result just returned and answer concisely in speech.",
+                                        },
+                                    )
+
+                                # re synchronize the head wobble after a tool call that may have taken some time
+                                if self.deps.head_wobbler is not None:
+                                    self.deps.head_wobbler.reset()
+
+                            # server error
+                            if event.type == "error":
+                                err = getattr(event, "error", None)
+                                msg = getattr(err, "message", str(err) if err else "unknown error")
+                                code = getattr(err, "code", "")
+
+                                logger.error("Realtime error [%s]: %s (raw=%s)", code, msg, err)
+
+                                # Only show user-facing errors, not internal state errors
+                                if code not in ("input_audio_buffer_commit_empty", "conversation_already_has_active_response"):
+                                    await self.output_queue.put(AdditionalOutputs({"role": "assistant", "content": f"[error] {msg}"}))
+
+                    except ConnectionClosedError as e:
+                        # Abrupt close (e.g., "no close frame received or sent") → retry
+                        logger.warning(
+                            "Realtime websocket closed unexpectedly (attempt %d/%d): %s",
+                            attempt, max_attempts, e
+                        )
+                        if attempt < max_attempts:
+                            # small jittered backoff
+                            await asyncio.sleep(1.0 + random.uniform(0, 0.5))
+                            continue
+                        raise
+                    # Normal exit from the receive loop, stop retrying
+                    return
+            finally:
+                # never keep a stale reference
+                self.connection = None
 
     # Microphone receive
     async def receive(self, frame: Tuple[int, NDArray[np.int16]]) -> None:
@@ -305,8 +328,14 @@ class OpenaiRealtimeHandler(AsyncStreamHandler):
     async def shutdown(self) -> None:
         """Shutdown the handler."""
        if self.connection:
-            await self.connection.close()
-            self.connection = None
+            try:
+                await self.connection.close()
+            except ConnectionClosedError:
+                pass
+            except Exception as e:
+                logger.debug(f"connection.close() ignored: {e}")
+            finally:
+                self.connection = None
 
         # Clear any remaining items in the output queue
         while not self.output_queue.empty():
tests/test_openai_realtime.py
CHANGED
@@ -1,7 +1,12 @@
 import asyncio
+import logging
+from typing import Any
 from datetime import datetime, timezone
 from unittest.mock import MagicMock
 
+import pytest
+
+import reachy_mini_conversation_app.openai_realtime as rt_mod
 from reachy_mini_conversation_app.tools import ToolDependencies
 from reachy_mini_conversation_app.openai_realtime import OpenaiRealtimeHandler
 
@@ -27,3 +32,86 @@ def test_format_timestamp_uses_wall_clock() -> None:
     # Extract year from "[YYYY-MM-DD ...]"
     year = int(formatted[1:5])
     assert year == datetime.now(timezone.utc).year
+
+@pytest.mark.asyncio
+async def test_start_up_retries_on_abrupt_close(monkeypatch: Any, caplog: Any) -> None:
+    """First connection dies with ConnectionClosedError during iteration -> retried.
+
+    Second connection iterates cleanly (no events) -> start_up returns without raising.
+    Ensures handler clears self.connection at the end.
+    """
+    caplog.set_level(logging.WARNING)
+
+    # Use a local Exception as the module's ConnectionClosedError to avoid ws dependency
+    FakeCCE = type("FakeCCE", (Exception,), {})
+    monkeypatch.setattr(rt_mod, "ConnectionClosedError", FakeCCE)
+
+    # Make asyncio.sleep return immediately (for backoff)
+    async def _fast_sleep(*_a: Any, **_kw: Any) -> None: return None
+    monkeypatch.setattr(asyncio, "sleep", _fast_sleep, raising=False)
+
+    attempt_counter = {"n": 0}
+
+    class FakeConn:
+        """Minimal realtime connection stub."""
+
+        def __init__(self, mode: str):
+            self._mode = mode
+
+            class _Session:
+                async def update(self, **_kw: Any) -> None: return None
+            self.session = _Session()
+
+            class _InputAudioBuffer:
+                async def append(self, **_kw: Any) -> None: return None
+            self.input_audio_buffer = _InputAudioBuffer()
+
+            class _Item:
+                async def create(self, **_kw: Any) -> None: return None
+
+            class _Conversation:
+                item = _Item()
+            self.conversation = _Conversation()
+
+            class _Response:
+                async def create(self, **_kw: Any) -> None: return None
+                async def cancel(self, **_kw: Any) -> None: return None
+            self.response = _Response()
+
+        async def __aenter__(self) -> "FakeConn": return self
+        async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> bool: return False
+        async def close(self) -> None: return None
+
+        # Async iterator protocol
+        def __aiter__(self) -> "FakeConn": return self
+        async def __anext__(self) -> None:
+            if self._mode == "raise_on_iter":
+                raise FakeCCE("abrupt close (simulated)")
+            raise StopAsyncIteration  # clean exit (no events)
+
+    class FakeRealtime:
+        def connect(self, **_kw: Any) -> FakeConn:
+            attempt_counter["n"] += 1
+            mode = "raise_on_iter" if attempt_counter["n"] == 1 else "clean"
+            return FakeConn(mode)
+
+    class FakeClient:
+        def __init__(self, **_kw: Any) -> None: self.realtime = FakeRealtime()
+
+    # Patch the OpenAI client used by the handler
+    monkeypatch.setattr(rt_mod, "AsyncOpenAI", FakeClient)
+
+    # Build handler with minimal deps
+    deps = ToolDependencies(reachy_mini=MagicMock(), movement_manager=MagicMock())
+    handler = rt_mod.OpenaiRealtimeHandler(deps)
+
+    # Run: should retry once and exit cleanly
+    await handler.start_up()
+
+    # Validate: two attempts total (fail -> retry -> succeed), and connection cleared
+    assert attempt_counter["n"] == 2
+    assert handler.connection is None
+
+    # Optional: confirm we logged the unexpected close once
+    warnings = [r for r in caplog.records if r.levelname == "WARNING" and "closed unexpectedly" in r.msg]
+    assert len(warnings) == 1