mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: enable context summarization
This commit is contained in:
parent
501d06c00d
commit
56763a4527
7 changed files with 232 additions and 5 deletions
|
|
@ -679,6 +679,10 @@ async def _run_pipeline(
|
|||
workflow_id, workflow.organization_id
|
||||
)
|
||||
|
||||
context_compaction_enabled = (workflow.workflow_configurations or {}).get(
|
||||
"context_compaction_enabled", False
|
||||
)
|
||||
|
||||
engine = PipecatEngine(
|
||||
llm=llm,
|
||||
workflow=workflow_graph,
|
||||
|
|
@ -689,6 +693,7 @@ async def _run_pipeline(
|
|||
embeddings_model=embeddings_model,
|
||||
embeddings_base_url=embeddings_base_url,
|
||||
has_recordings=has_recordings,
|
||||
context_compaction_enabled=context_compaction_enabled,
|
||||
)
|
||||
|
||||
# Create pipeline components
|
||||
|
|
|
|||
|
|
@ -159,8 +159,7 @@ def create_stt_service(
|
|||
)
|
||||
elif user_config.stt.provider == ServiceProviders.ASSEMBLYAI.value:
|
||||
language = getattr(user_config.stt, "language", None)
|
||||
pipecat_language = _to_language_enum(language, default=Language.EN)
|
||||
settings_kwargs = {"model": user_config.stt.model, "language": pipecat_language}
|
||||
settings_kwargs = {"model": user_config.stt.model, "language": language}
|
||||
if keyterms:
|
||||
settings_kwargs["keyterms_prompt"] = keyterms
|
||||
return AssemblyAISTTService(
|
||||
|
|
|
|||
|
|
@ -38,6 +38,9 @@ from api.services.workflow.pipecat_engine_context_composer import (
|
|||
compose_functions_for_node,
|
||||
compose_system_prompt_for_node,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine_context_summarizer import (
|
||||
ContextSummarizationManager,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine_custom_tools import (
|
||||
CustomToolManager,
|
||||
)
|
||||
|
|
@ -67,6 +70,7 @@ class PipecatEngine:
|
|||
embeddings_model: Optional[str] = None,
|
||||
embeddings_base_url: Optional[str] = None,
|
||||
has_recordings: bool = False,
|
||||
context_compaction_enabled: bool = False,
|
||||
):
|
||||
self.task = task
|
||||
self.llm = llm
|
||||
|
|
@ -114,6 +118,12 @@ class PipecatEngine:
|
|||
# response mode instructions on all nodes for in-context learning.
|
||||
self._has_recordings: bool = has_recordings
|
||||
|
||||
# Background context summarization on node transitions
|
||||
self._context_compaction_enabled: bool = context_compaction_enabled
|
||||
self._context_summarization_manager: Optional[ContextSummarizationManager] = (
|
||||
None
|
||||
)
|
||||
|
||||
async def _get_organization_id(self) -> Optional[int]:
|
||||
"""Get and cache the organization ID from workflow run."""
|
||||
if self._custom_tool_manager:
|
||||
|
|
@ -148,6 +158,10 @@ class PipecatEngine:
|
|||
# Helper that encapsulates custom tool management
|
||||
self._custom_tool_manager = CustomToolManager(self)
|
||||
|
||||
# Helper that encapsulates context summarization
|
||||
if self._context_compaction_enabled:
|
||||
self._context_summarization_manager = ContextSummarizationManager(self)
|
||||
|
||||
await self.set_node(self.workflow.start_node_id)
|
||||
|
||||
logger.debug(f"{self.__class__.__name__} initialized")
|
||||
|
|
@ -500,6 +514,11 @@ class PipecatEngine:
|
|||
else:
|
||||
await self._handle_agent_node(node)
|
||||
|
||||
# Summarize context in background after non-start node transitions
|
||||
# to clean up tool calls from previous nodes
|
||||
if previous_node_id is not None and self._context_summarization_manager:
|
||||
self._context_summarization_manager.start()
|
||||
|
||||
async def _handle_start_node(self, node: Node) -> None:
|
||||
"""Handle start node execution."""
|
||||
# Check if delayed start is enabled
|
||||
|
|
@ -714,3 +733,7 @@ class PipecatEngine:
|
|||
and not self._user_response_timeout_task.done()
|
||||
):
|
||||
self._user_response_timeout_task.cancel()
|
||||
|
||||
# Cancel any in-flight background summarization
|
||||
if self._context_summarization_manager:
|
||||
await self._context_summarization_manager.cleanup()
|
||||
|
|
|
|||
173
api/services/workflow/pipecat_engine_context_summarizer.py
Normal file
173
api/services/workflow/pipecat_engine_context_summarizer.py
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from loguru import logger
|
||||
from opentelemetry import trace
|
||||
|
||||
from api.services.pipecat.tracing_config import ensure_tracing
|
||||
from pipecat.frames.frames import LLMContextSummaryRequestFrame
|
||||
from pipecat.utils.context.llm_context_summarization import (
|
||||
LLMContextSummarizationUtil,
|
||||
LLMContextSummaryConfig,
|
||||
)
|
||||
from pipecat.utils.tracing.service_attributes import add_llm_span_attributes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
class ContextSummarizationManager:
|
||||
"""Manages background context summarization on node transitions.
|
||||
|
||||
Replaces old messages (including orphaned tool calls from previous nodes)
|
||||
with a concise summary to keep the context window manageable.
|
||||
"""
|
||||
|
||||
def __init__(self, engine: "PipecatEngine") -> None:
|
||||
self._engine = engine
|
||||
self._summarization_task: Optional[asyncio.Task] = None
|
||||
self._config = LLMContextSummaryConfig(
|
||||
target_context_tokens=4000,
|
||||
min_messages_after_summary=2,
|
||||
summarization_timeout=30.0,
|
||||
)
|
||||
|
||||
@property
|
||||
def config(self) -> LLMContextSummaryConfig:
|
||||
return self._config
|
||||
|
||||
def start(self) -> None:
|
||||
"""Kick off background context summarization, cancelling any in-flight one."""
|
||||
if self._summarization_task and not self._summarization_task.done():
|
||||
self._summarization_task.cancel()
|
||||
|
||||
current_node = self._engine._current_node
|
||||
self._summarization_task = asyncio.create_task(
|
||||
self._summarize_context_in_background(),
|
||||
name=f"ctx-summarize:{current_node.name}",
|
||||
)
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
"""Cancel any in-flight background summarization."""
|
||||
if self._summarization_task and not self._summarization_task.done():
|
||||
self._summarization_task.cancel()
|
||||
|
||||
async def _summarize_context_in_background(self) -> None:
|
||||
"""Summarize conversation context after a node transition.
|
||||
|
||||
Runs as a fire-and-forget background task so it doesn't block
|
||||
the new node from speaking. Replaces old messages (including
|
||||
orphaned tool calls from previous nodes) with a concise summary.
|
||||
"""
|
||||
context = self._engine.context
|
||||
llm = self._engine.llm
|
||||
current_node = self._engine._current_node
|
||||
|
||||
try:
|
||||
messages = context.messages
|
||||
# Not worth summarizing if context is small
|
||||
if len(messages) <= 6:
|
||||
return
|
||||
|
||||
config = self._config
|
||||
request_frame = LLMContextSummaryRequestFrame(
|
||||
request_id=f"node-transition-{current_node.id}",
|
||||
context=context,
|
||||
min_messages_to_keep=config.min_messages_after_summary,
|
||||
target_context_tokens=config.target_context_tokens,
|
||||
summarization_prompt=config.summary_prompt,
|
||||
summarization_timeout=config.summarization_timeout,
|
||||
)
|
||||
|
||||
# Capture parent OTel context before the await
|
||||
parent_ctx = self._engine._get_otel_context()
|
||||
|
||||
summary_text, last_index = await asyncio.wait_for(
|
||||
llm._generate_summary(request_frame),
|
||||
timeout=config.summarization_timeout,
|
||||
)
|
||||
|
||||
if not summary_text or last_index < 0:
|
||||
logger.warning(
|
||||
"Context summarization returned empty result, keeping full context"
|
||||
)
|
||||
return
|
||||
|
||||
# Trace the LLM call — mirror what _generate_summary sends to
|
||||
# run_inference: system prompt + formatted transcript as user msg.
|
||||
model_name = getattr(llm, "model_name", "unknown")
|
||||
if ensure_tracing():
|
||||
summarize_result = (
|
||||
LLMContextSummarizationUtil.get_messages_to_summarize(
|
||||
context, config.min_messages_after_summary
|
||||
)
|
||||
)
|
||||
transcript = LLMContextSummarizationUtil.format_messages_for_summary(
|
||||
summarize_result.messages
|
||||
)
|
||||
tracer = trace.get_tracer("pipecat")
|
||||
with tracer.start_as_current_span(
|
||||
"llm-context-summarization", context=parent_ctx
|
||||
) as span:
|
||||
tracing_messages = [
|
||||
{"role": "system", "content": config.summary_prompt},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Conversation history:\n{transcript}",
|
||||
},
|
||||
]
|
||||
add_llm_span_attributes(
|
||||
span,
|
||||
service_name=llm.__class__.__name__,
|
||||
model=model_name,
|
||||
operation_name="llm-context-summarization",
|
||||
messages=tracing_messages,
|
||||
output=json.dumps({"content": summary_text}),
|
||||
stream=False,
|
||||
parameters={
|
||||
"target_context_tokens": config.target_context_tokens,
|
||||
},
|
||||
)
|
||||
|
||||
# Snapshot current messages at apply-time (not request-time)
|
||||
# to preserve anything added while the summary was generating
|
||||
current_messages = context.messages
|
||||
recent_messages = current_messages[last_index + 1 :]
|
||||
|
||||
summary_message = {
|
||||
"role": "user",
|
||||
"content": config.summary_message_template.format(summary=summary_text),
|
||||
}
|
||||
|
||||
# Preserve the current system message (already set by the new node)
|
||||
first_system_msg = next(
|
||||
(
|
||||
m
|
||||
for m in current_messages
|
||||
if isinstance(m, dict) and m.get("role") == "system"
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
new_messages = []
|
||||
if first_system_msg:
|
||||
new_messages.append(first_system_msg)
|
||||
new_messages.append(summary_message)
|
||||
new_messages.extend(recent_messages)
|
||||
|
||||
context.set_messages(new_messages)
|
||||
logger.info(
|
||||
f"Background context summarization applied: "
|
||||
f"{len(current_messages)} -> {len(new_messages)} messages"
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Context summarization cancelled (new transition started)")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"Context summarization timed out after {self._config.summarization_timeout}s"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Background context summarization failed: {e}")
|
||||
|
|
@ -413,12 +413,12 @@ async def _execute_webhook_node(
|
|||
return True
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
logger.warning(
|
||||
f"Webhook '{webhook_name}' failed: {e.response.status_code} - {e.response.text[:200]}"
|
||||
)
|
||||
return False
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Webhook '{webhook_name}' request error: {e}")
|
||||
logger.warning(f"Webhook '{webhook_name}' request error: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Webhook '{webhook_name}' unexpected error: {e}")
|
||||
|
|
|
|||
|
|
@ -44,6 +44,9 @@ export const ConfigurationsDialog = ({
|
|||
const [turnStopStrategy, setTurnStopStrategy] = useState<TurnStopStrategy>(
|
||||
workflowConfigurations?.turn_stop_strategy || 'transcription'
|
||||
);
|
||||
const [contextCompactionEnabled, setContextCompactionEnabled] = useState<boolean>(
|
||||
workflowConfigurations?.context_compaction_enabled ?? false
|
||||
);
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
|
||||
const handleSave = async () => {
|
||||
|
|
@ -54,7 +57,8 @@ export const ConfigurationsDialog = ({
|
|||
max_call_duration: maxCallDuration,
|
||||
max_user_idle_timeout: maxUserIdleTimeout,
|
||||
smart_turn_stop_secs: smartTurnStopSecs,
|
||||
turn_stop_strategy: turnStopStrategy
|
||||
turn_stop_strategy: turnStopStrategy,
|
||||
context_compaction_enabled: contextCompactionEnabled,
|
||||
}, name);
|
||||
onOpenChange(false);
|
||||
} catch (error) {
|
||||
|
|
@ -73,6 +77,7 @@ export const ConfigurationsDialog = ({
|
|||
setMaxUserIdleTimeout(workflowConfigurations?.max_user_idle_timeout || 10);
|
||||
setSmartTurnStopSecs(workflowConfigurations?.smart_turn_stop_secs || 2);
|
||||
setTurnStopStrategy(workflowConfigurations?.turn_stop_strategy || 'transcription');
|
||||
setContextCompactionEnabled(workflowConfigurations?.context_compaction_enabled ?? false);
|
||||
}
|
||||
}, [open, workflowName, workflowConfigurations]);
|
||||
|
||||
|
|
@ -215,6 +220,27 @@ export const ConfigurationsDialog = ({
|
|||
)}
|
||||
</div>
|
||||
|
||||
{/* Context Management Section */}
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<h3 className="text-sm font-semibold mb-1">Context Compaction</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Automatically summarize conversation context when transitioning between nodes. Removes stale tool calls and keeps the context clean for the new node.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center justify-between">
|
||||
<Label htmlFor="context-compaction-enabled" className="text-sm">
|
||||
Enable Context Compaction
|
||||
</Label>
|
||||
<Switch
|
||||
id="context-compaction-enabled"
|
||||
checked={contextCompactionEnabled}
|
||||
onCheckedChange={setContextCompactionEnabled}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Call Management Section */}
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ export interface WorkflowConfigurations {
|
|||
turn_stop_strategy: TurnStopStrategy; // Strategy for detecting end of user turn
|
||||
dictionary?: string; // Comma-separated words for voice agent to listen for
|
||||
voicemail_detection?: VoicemailDetectionConfiguration;
|
||||
context_compaction_enabled?: boolean; // Summarize context on node transitions to remove stale tool calls
|
||||
[key: string]: unknown; // Allow additional properties for future configurations
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue