mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-19 08:28:10 +02:00
fix: fix interruption handling for Gemini Live
1. Fixes #236. 2. Fixes run_inference for variable extraction with Gemini Live.
This commit is contained in:
parent
14e6f29f2f
commit
e31b38122e
12 changed files with 48 additions and 15 deletions
|
|
@ -120,9 +120,21 @@ class InMemoryLogsBuffer:
|
|||
f"Incremented turn counter to {self._turn_counter} for workflow {self._workflow_run_id}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _event_sort_key(event: dict) -> str:
|
||||
payload_ts = event.get("payload", {}).get("timestamp")
|
||||
return payload_ts or event.get("timestamp", "")
|
||||
|
||||
def _sorted_events(self) -> List[dict]:
|
||||
# Stable sort by the realtime (payload) timestamp when available, falling
|
||||
# back to the buffer-append timestamp. Python's sort is stable, so events
|
||||
# sharing a key retain their original insertion order — this keeps
|
||||
# consecutive bot-text chunks of a single turn contiguous.
|
||||
return sorted(self._events, key=self._event_sort_key)
|
||||
|
||||
def get_events(self) -> List[dict]:
|
||||
"""Get all events for final storage."""
|
||||
return self._events
|
||||
"""Get all events for final storage, ordered by realtime timestamp."""
|
||||
return self._sorted_events()
|
||||
|
||||
def contains_user_speech(self) -> bool:
|
||||
"""Return True if any final user transcription event has non-empty text."""
|
||||
|
|
@ -141,7 +153,7 @@ class InMemoryLogsBuffer:
|
|||
Filters for rtf-user-transcription (final) and rtf-bot-text events,
|
||||
formats them as '[timestamp] user/assistant: text\\n'.
|
||||
"""
|
||||
return _generate_transcript_text(self._events)
|
||||
return _generate_transcript_text(self._sorted_events())
|
||||
|
||||
def write_transcript_to_temp_file(self) -> Optional[str]:
|
||||
"""Write transcript to a temporary text file and return the path.
|
||||
|
|
|
|||
|
|
@ -616,10 +616,15 @@ async def _run_pipeline(
|
|||
llm = create_realtime_llm_service(user_config, audio_config)
|
||||
stt = None
|
||||
tts = None
|
||||
# Realtime services don't implement run_inference, so create a
|
||||
# separate text LLM for variable extraction and other out-of-band
|
||||
# inference calls.
|
||||
inference_llm = create_llm_service(user_config)
|
||||
else:
|
||||
stt = create_stt_service(user_config, audio_config, keyterms=keyterms)
|
||||
tts = create_tts_service(user_config, audio_config)
|
||||
llm = create_llm_service(user_config)
|
||||
inference_llm = None
|
||||
|
||||
workflow_graph = WorkflowGraph(ReactFlowDTO.model_validate(run_workflow_json))
|
||||
|
||||
|
|
@ -703,9 +708,15 @@ async def _run_pipeline(
|
|||
context_compaction_enabled = (workflow.workflow_configurations or {}).get(
|
||||
"context_compaction_enabled", False
|
||||
)
|
||||
# Context compaction doesn't apply in realtime mode: the speech-to-speech
|
||||
# service manages its own conversation state server-side.
|
||||
if is_realtime and context_compaction_enabled:
|
||||
logger.info("Disabling context_compaction_enabled for realtime workflow run")
|
||||
context_compaction_enabled = False
|
||||
|
||||
engine = PipecatEngine(
|
||||
llm=llm,
|
||||
inference_llm=inference_llm,
|
||||
workflow=workflow_graph,
|
||||
call_context_vars=merged_call_context_vars,
|
||||
workflow_run_id=workflow_run_id,
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ class PipecatEngine:
|
|||
*,
|
||||
task: Optional[PipelineTask] = None,
|
||||
llm: Optional["LLMService"] = None,
|
||||
inference_llm: Optional["LLMService"] = None,
|
||||
context: Optional[LLMContext] = None,
|
||||
workflow: WorkflowGraph,
|
||||
call_context_vars: dict,
|
||||
|
|
@ -75,6 +76,12 @@ class PipecatEngine:
|
|||
):
|
||||
self.task = task
|
||||
self.llm = llm
|
||||
# LLM used for out-of-band inference (variable extraction, context
|
||||
# summarization). Falls back to the pipeline LLM when not provided.
|
||||
# In realtime mode the pipeline LLM is a speech-to-speech service
|
||||
# that does not implement run_inference, so a separate text LLM
|
||||
# must be passed in.
|
||||
self.inference_llm = inference_llm or llm
|
||||
self.context = context
|
||||
self.workflow = workflow
|
||||
self._call_context_vars = call_context_vars
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ class ContextSummarizationManager:
|
|||
orphaned tool calls from previous nodes) with a concise summary.
|
||||
"""
|
||||
context = self._engine.context
|
||||
llm = self._engine.llm
|
||||
llm = self._engine.inference_llm
|
||||
current_node = self._engine._current_node
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -203,12 +203,12 @@ class VariableExtractionManager:
|
|||
# current node's system prompt that build_chat_completion_params
|
||||
# would otherwise prepend.
|
||||
# ------------------------------------------------------------------
|
||||
llm_response = await self._engine.llm.run_inference(
|
||||
llm_response = await self._engine.inference_llm.run_inference(
|
||||
extraction_context, system_instruction=system_prompt
|
||||
)
|
||||
|
||||
# Get model name for tracing
|
||||
model_name = getattr(self._engine.llm, "model_name", "unknown")
|
||||
model_name = getattr(self._engine.inference_llm, "model_name", "unknown")
|
||||
|
||||
if ensure_tracing():
|
||||
tracer = trace.get_tracer("pipecat")
|
||||
|
|
@ -221,7 +221,7 @@ class VariableExtractionManager:
|
|||
]
|
||||
add_llm_span_attributes(
|
||||
span,
|
||||
service_name=self._engine.llm.__class__.__name__,
|
||||
service_name=self._engine.inference_llm.__class__.__name__,
|
||||
model=model_name,
|
||||
operation_name="llm-variable-extraction",
|
||||
messages=tracing_messages,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue