mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-19 08:28:10 +02:00
feat: add call tags extraction in workflow
This commit is contained in:
parent
7a102026fb
commit
15809e03a4
8 changed files with 345 additions and 7 deletions
|
|
@ -51,6 +51,8 @@ class NodeDataDTO(BaseModel):
|
|||
extraction_enabled: bool = False
|
||||
extraction_prompt: Optional[str] = None
|
||||
extraction_variables: Optional[list[ExtractionVariableDTO]] = None
|
||||
call_tags_enabled: bool = False
|
||||
call_tags_prompt: Optional[str] = None
|
||||
add_global_prompt: bool = True
|
||||
wait_for_user_response: bool = False
|
||||
wait_for_user_response_timeout: Optional[float] = None
|
||||
|
|
|
|||
|
|
@ -207,8 +207,9 @@ class PipecatEngine:
|
|||
)
|
||||
logger.info(f"Arguments: {function_call_params.arguments}")
|
||||
|
||||
# Perform variable extraction before transitioning to new node
|
||||
# Perform variable extraction and call tags extraction before transitioning to new node
|
||||
await self._perform_variable_extraction_if_needed(self._current_node)
|
||||
await self._perform_call_tags_extraction_if_needed(self._current_node)
|
||||
|
||||
# Set context for the new node, so that when the function call result
|
||||
# frame is received by LLMContextAggregator and an LLM generation
|
||||
|
|
@ -413,6 +414,54 @@ class PipecatEngine:
|
|||
)
|
||||
await _do_extraction()
|
||||
|
||||
async def _perform_call_tags_extraction_if_needed(
|
||||
self, node: Optional[Node], run_in_background: bool = True
|
||||
) -> None:
|
||||
"""Perform call tags extraction if the node has it enabled.
|
||||
|
||||
Extracted tags are merged into gathered_context["call_tags"].
|
||||
|
||||
Args:
|
||||
node: The node to extract call tags from.
|
||||
run_in_background: If True, runs extraction as a fire-and-forget task.
|
||||
If False, awaits the extraction synchronously.
|
||||
"""
|
||||
if not (node and node.call_tags_enabled):
|
||||
return
|
||||
|
||||
parent_context = get_current_turn_context()
|
||||
call_tags_prompt = self._format_prompt(node.call_tags_prompt or "")
|
||||
|
||||
async def _do_extraction():
|
||||
try:
|
||||
extracted_tags = (
|
||||
await self._variable_extraction_manager._perform_call_tags_extraction(
|
||||
parent_context, call_tags_prompt
|
||||
)
|
||||
)
|
||||
# Merge into existing call_tags (no duplicates)
|
||||
existing_tags = self._gathered_context.get("call_tags", [])
|
||||
for tag in extracted_tags:
|
||||
if tag not in existing_tags:
|
||||
existing_tags.append(tag)
|
||||
self._gathered_context["call_tags"] = existing_tags
|
||||
logger.debug(
|
||||
f"Call tags extraction completed. Tags: {existing_tags}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during call tags extraction: {str(e)}")
|
||||
|
||||
if run_in_background:
|
||||
logger.debug(
|
||||
f"Scheduling background call tags extraction for node: {node.name}"
|
||||
)
|
||||
asyncio.create_task(_do_extraction())
|
||||
else:
|
||||
logger.debug(
|
||||
f"Performing synchronous call tags extraction for node: {node.name}"
|
||||
)
|
||||
await _do_extraction()
|
||||
|
||||
async def _setup_llm_context(self, node: Node) -> None:
|
||||
"""Common method to set up LLM context"""
|
||||
# Set node name for tracing
|
||||
|
|
@ -527,10 +576,13 @@ class PipecatEngine:
|
|||
# Mute the pipeline
|
||||
self._mute_pipeline = True
|
||||
|
||||
# Perform final variable extraction synchronously before ending
|
||||
# Perform final variable extraction and call tags extraction synchronously before ending
|
||||
await self._perform_variable_extraction_if_needed(
|
||||
self._current_node, run_in_background=False
|
||||
)
|
||||
await self._perform_call_tags_extraction_if_needed(
|
||||
self._current_node, run_in_background=False
|
||||
)
|
||||
|
||||
frame_to_push = CancelFrame() if abort_immediately else EndFrame()
|
||||
|
||||
|
|
|
|||
|
|
@ -119,8 +119,8 @@ class VariableExtractionManager:
|
|||
conversation_history = "\n".join(conversation_lines)
|
||||
|
||||
system_prompt = (
|
||||
"You are an assistant tasked with extracting structured data from the conversation. "
|
||||
"Return ONLY a valid JSON object with the requested variables as top-level keys. Do not wrap the JSON in markdown." # noqa: E501
|
||||
"You are an assistant tasked with extracting structured data from a conversation. "
|
||||
"Return ONLY a valid JSON object with the requested variables as top-level keys. Do not wrap the JSON in markdown."
|
||||
)
|
||||
# Use provided extraction_prompt as system prompt, or default
|
||||
system_prompt = (
|
||||
|
|
@ -184,3 +184,134 @@ class VariableExtractionManager:
|
|||
|
||||
logger.debug(f"Extracted variables: {extracted}")
|
||||
return extracted
|
||||
|
||||
async def _perform_call_tags_extraction(
|
||||
self,
|
||||
parent_ctx: Any,
|
||||
call_tags_prompt: str = "",
|
||||
) -> list[str]:
|
||||
"""Run a chat completion to extract call tags from the conversation.
|
||||
|
||||
Returns a list of tag strings.
|
||||
"""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Build a normalized conversation history (reuses the same helper
|
||||
# logic from variable extraction).
|
||||
# ------------------------------------------------------------------
|
||||
def _get_role_and_content(msg: Any) -> tuple[str | None, str | None]:
|
||||
if isinstance(msg, dict):
|
||||
role = msg.get("role")
|
||||
content_field = msg.get("content")
|
||||
if isinstance(content_field, str):
|
||||
content = content_field
|
||||
elif isinstance(content_field, list):
|
||||
texts = [
|
||||
segment.get("text", "")
|
||||
for segment in content_field
|
||||
if isinstance(segment, dict) and segment.get("type") == "text"
|
||||
]
|
||||
content = " ".join(texts) if texts else None
|
||||
else:
|
||||
content = None
|
||||
return role, content
|
||||
|
||||
role_attr = getattr(msg, "role", None)
|
||||
parts_attr = getattr(msg, "parts", None)
|
||||
if role_attr is None or parts_attr is None:
|
||||
return None, None
|
||||
role = "assistant" if role_attr == "model" else role_attr
|
||||
texts: list[str] = []
|
||||
for part in parts_attr:
|
||||
text_val = getattr(part, "text", None)
|
||||
if text_val:
|
||||
texts.append(text_val)
|
||||
content = " ".join(texts) if texts else None
|
||||
return role, content
|
||||
|
||||
conversation_lines: list[str] = []
|
||||
for msg in self._context.messages:
|
||||
role, content = _get_role_and_content(msg)
|
||||
if role in ("assistant", "user") and content:
|
||||
conversation_lines.append(f"{role}: {content}")
|
||||
|
||||
conversation_history = "\n".join(conversation_lines)
|
||||
|
||||
system_prompt = (
|
||||
"You are an assistant tasked with extracting call tags from a conversation. "
|
||||
"Return ONLY a valid JSON array of short tag strings. Do not wrap the JSON in markdown. "
|
||||
"Example: [\"interested\", \"follow_up_needed\", \"pricing_discussed\"]"
|
||||
)
|
||||
if call_tags_prompt:
|
||||
system_prompt = system_prompt + "\n\n" + call_tags_prompt
|
||||
|
||||
user_prompt = (
|
||||
"\n\nExtract relevant call tags from the following conversation:\n"
|
||||
f"{conversation_history}"
|
||||
)
|
||||
|
||||
extraction_context = LLMContext()
|
||||
extraction_messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
extraction_context.set_messages(extraction_messages)
|
||||
|
||||
llm_response = await self._engine.llm.run_inference(extraction_context)
|
||||
|
||||
model_name = getattr(self._engine.llm, "model_name", "unknown")
|
||||
|
||||
if is_tracing_enabled():
|
||||
tracer = trace.get_tracer("pipecat")
|
||||
with tracer.start_as_current_span(
|
||||
"llm-call-tags-extraction", context=parent_ctx
|
||||
) as span:
|
||||
add_llm_span_attributes(
|
||||
span,
|
||||
service_name=self._engine.llm.__class__.__name__,
|
||||
model=model_name,
|
||||
operation_name="llm-call-tags-extraction",
|
||||
messages=extraction_messages,
|
||||
output=llm_response,
|
||||
stream=False,
|
||||
parameters={},
|
||||
)
|
||||
|
||||
if llm_response is None:
|
||||
logger.warning("Call tags extractor returned no response; returning empty list.")
|
||||
return []
|
||||
|
||||
parsed = parse_llm_json(llm_response)
|
||||
|
||||
# parse_llm_json returns a dict. If the LLM returned a JSON array,
|
||||
# it will be stored under the "raw" key or similar. We need to
|
||||
# handle both cases: a plain list from the LLM or a dict wrapper.
|
||||
if isinstance(parsed, list):
|
||||
tags = [str(t) for t in parsed if t]
|
||||
elif isinstance(parsed, dict):
|
||||
# If parse_llm_json wrapped a list in {"raw": ...}, try to
|
||||
# extract the list. Otherwise flatten dict values.
|
||||
import json as _json
|
||||
raw = parsed.get("raw")
|
||||
if raw and isinstance(raw, str):
|
||||
try:
|
||||
maybe_list = _json.loads(raw)
|
||||
if isinstance(maybe_list, list):
|
||||
tags = [str(t) for t in maybe_list if t]
|
||||
else:
|
||||
tags = []
|
||||
except _json.JSONDecodeError:
|
||||
tags = []
|
||||
else:
|
||||
# Flatten any list values in the dict
|
||||
tags = []
|
||||
for v in parsed.values():
|
||||
if isinstance(v, list):
|
||||
tags.extend(str(t) for t in v if t)
|
||||
elif isinstance(v, str) and v:
|
||||
tags.append(v)
|
||||
else:
|
||||
tags = []
|
||||
|
||||
logger.debug(f"Extracted call tags: {tags}")
|
||||
return tags
|
||||
|
|
|
|||
|
|
@ -43,6 +43,8 @@ class Node:
|
|||
self.extraction_enabled = data.extraction_enabled
|
||||
self.extraction_prompt = data.extraction_prompt
|
||||
self.extraction_variables = data.extraction_variables
|
||||
self.call_tags_enabled = data.call_tags_enabled
|
||||
self.call_tags_prompt = data.call_tags_prompt
|
||||
self.add_global_prompt = data.add_global_prompt
|
||||
self.detect_voicemail = data.detect_voicemail
|
||||
self.delayed_start = data.delayed_start
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue