feat: add call tags extraction in workflow

This commit is contained in:
Abhishek Kumar 2026-02-10 08:15:15 +05:30
parent 7a102026fb
commit 15809e03a4
8 changed files with 345 additions and 7 deletions

View file

@ -51,6 +51,8 @@ class NodeDataDTO(BaseModel):
extraction_enabled: bool = False
extraction_prompt: Optional[str] = None
extraction_variables: Optional[list[ExtractionVariableDTO]] = None
call_tags_enabled: bool = False
call_tags_prompt: Optional[str] = None
add_global_prompt: bool = True
wait_for_user_response: bool = False
wait_for_user_response_timeout: Optional[float] = None

View file

@ -207,8 +207,9 @@ class PipecatEngine:
)
logger.info(f"Arguments: {function_call_params.arguments}")
# Perform variable extraction before transitioning to new node
# Perform variable extraction and call tags extraction before transitioning to new node
await self._perform_variable_extraction_if_needed(self._current_node)
await self._perform_call_tags_extraction_if_needed(self._current_node)
# Set context for the new node, so that when the function call result
# frame is received by LLMContextAggregator and an LLM generation
@ -413,6 +414,54 @@ class PipecatEngine:
)
await _do_extraction()
async def _perform_call_tags_extraction_if_needed(
    self, node: Optional[Node], run_in_background: bool = True
) -> None:
    """Perform call tags extraction if the node has it enabled.

    Extracted tags are merged into ``gathered_context["call_tags"]``; tags
    that are already present are not added again.

    Args:
        node: The node to extract call tags from. Nothing happens when it is
            ``None`` or when ``node.call_tags_enabled`` is falsy.
        run_in_background: If True, runs extraction as a fire-and-forget task.
            If False, awaits the extraction synchronously.
    """
    if not (node and node.call_tags_enabled):
        return

    # Both values are computed before the coroutine is created, so a
    # background run uses the turn context and prompt as they were when the
    # extraction was scheduled.
    parent_context = get_current_turn_context()
    call_tags_prompt = self._format_prompt(node.call_tags_prompt or "")

    async def _do_extraction() -> None:
        try:
            extracted_tags = (
                await self._variable_extraction_manager._perform_call_tags_extraction(
                    parent_context, call_tags_prompt
                )
            )

            # Merge into existing call_tags (no duplicates). ``or []`` also
            # covers a stored ``None`` value, which ``.get(key, [])`` alone
            # would hand back unchanged and then fail on ``append``.
            existing_tags = self._gathered_context.get("call_tags") or []
            for tag in extracted_tags:
                if tag not in existing_tags:
                    existing_tags.append(tag)
            self._gathered_context["call_tags"] = existing_tags

            logger.debug(
                f"Call tags extraction completed. Tags: {existing_tags}"
            )
        except Exception as e:
            # Best-effort: a failed extraction must not break the call flow.
            logger.error(f"Error during call tags extraction: {e}")

    if run_in_background:
        logger.debug(
            f"Scheduling background call tags extraction for node: {node.name}"
        )
        task = asyncio.create_task(_do_extraction())
        # The event loop only keeps weak references to tasks, so an
        # unreferenced task can be garbage-collected before it finishes.
        # Hold a strong reference until the task completes.
        pending_tasks = getattr(self, "_call_tags_tasks", None)
        if pending_tasks is None:
            pending_tasks = set()
            self._call_tags_tasks = pending_tasks
        pending_tasks.add(task)
        task.add_done_callback(pending_tasks.discard)
    else:
        logger.debug(
            f"Performing synchronous call tags extraction for node: {node.name}"
        )
        await _do_extraction()

async def _setup_llm_context(self, node: Node) -> None:
"""Common method to set up LLM context"""
# Set node name for tracing
@ -527,10 +576,13 @@ class PipecatEngine:
# Mute the pipeline
self._mute_pipeline = True
# Perform final variable extraction synchronously before ending
# Perform final variable extraction and call tags extraction synchronously before ending
await self._perform_variable_extraction_if_needed(
self._current_node, run_in_background=False
)
await self._perform_call_tags_extraction_if_needed(
self._current_node, run_in_background=False
)
frame_to_push = CancelFrame() if abort_immediately else EndFrame()

View file

@ -119,8 +119,8 @@ class VariableExtractionManager:
conversation_history = "\n".join(conversation_lines)
system_prompt = (
"You are an assistant tasked with extracting structured data from the conversation. "
"Return ONLY a valid JSON object with the requested variables as top-level keys. Do not wrap the JSON in markdown." # noqa: E501
"You are an assistant tasked with extracting structured data from a conversation. "
"Return ONLY a valid JSON object with the requested variables as top-level keys. Do not wrap the JSON in markdown."
)
# Use provided extraction_prompt as system prompt, or default
system_prompt = (
@ -184,3 +184,134 @@ class VariableExtractionManager:
logger.debug(f"Extracted variables: {extracted}")
return extracted
async def _perform_call_tags_extraction(
    self,
    parent_ctx: Any,
    call_tags_prompt: str = "",
) -> list[str]:
    """Run a chat completion to extract call tags from the conversation.

    Args:
        parent_ctx: Context passed as the parent of the
            "llm-call-tags-extraction" span when tracing is enabled.
        call_tags_prompt: Optional extra instructions appended to the
            default system prompt.

    Returns:
        A list of tag strings, trimmed of surrounding whitespace and
        de-duplicated in first-seen order. Empty when the LLM returns no
        response or nothing that can be read as tags.
    """
    import json as _json

    # ------------------------------------------------------------------
    # Build a normalized conversation history (same helper logic as the
    # variable extraction path).
    # ------------------------------------------------------------------
    def _get_role_and_content(msg: Any) -> tuple[str | None, str | None]:
        """Return ``(role, text)`` for a message, or ``None`` parts if unreadable."""
        if isinstance(msg, dict):
            role = msg.get("role")
            content_field = msg.get("content")
            if isinstance(content_field, str):
                content = content_field
            elif isinstance(content_field, list):
                # Multi-part content: keep only the text segments.
                texts = [
                    segment.get("text", "")
                    for segment in content_field
                    if isinstance(segment, dict) and segment.get("type") == "text"
                ]
                content = " ".join(texts) if texts else None
            else:
                content = None
            return role, content

        # Non-dict messages are read through their `role` / `parts` attributes.
        role_attr = getattr(msg, "role", None)
        parts_attr = getattr(msg, "parts", None)
        if role_attr is None or parts_attr is None:
            return None, None

        # A "model" role is reported as "assistant".
        role = "assistant" if role_attr == "model" else role_attr
        texts: list[str] = []
        for part in parts_attr:
            text_val = getattr(part, "text", None)
            if text_val:
                texts.append(text_val)
        content = " ".join(texts) if texts else None
        return role, content

    def _normalize_tags(candidates: Any) -> list[str]:
        """Stringify, trim and de-duplicate tags, keeping first-seen order."""
        cleaned: list[str] = []
        seen: set[str] = set()
        for item in candidates:
            if not item:
                continue
            tag = str(item).strip()
            # Skip whitespace-only entries and repeats within one response.
            if tag and tag not in seen:
                seen.add(tag)
                cleaned.append(tag)
        return cleaned

    conversation_lines: list[str] = []
    for msg in self._context.messages:
        role, content = _get_role_and_content(msg)
        if role in ("assistant", "user") and content:
            conversation_lines.append(f"{role}: {content}")
    conversation_history = "\n".join(conversation_lines)

    system_prompt = (
        "You are an assistant tasked with extracting call tags from a conversation. "
        "Return ONLY a valid JSON array of short tag strings. Do not wrap the JSON in markdown. "
        "Example: [\"interested\", \"follow_up_needed\", \"pricing_discussed\"]"
    )
    if call_tags_prompt:
        system_prompt = system_prompt + "\n\n" + call_tags_prompt

    user_prompt = (
        "\n\nExtract relevant call tags from the following conversation:\n"
        f"{conversation_history}"
    )

    extraction_context = LLMContext()
    extraction_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    extraction_context.set_messages(extraction_messages)

    llm_response = await self._engine.llm.run_inference(extraction_context)

    model_name = getattr(self._engine.llm, "model_name", "unknown")
    if is_tracing_enabled():
        tracer = trace.get_tracer("pipecat")
        with tracer.start_as_current_span(
            "llm-call-tags-extraction", context=parent_ctx
        ) as span:
            add_llm_span_attributes(
                span,
                service_name=self._engine.llm.__class__.__name__,
                model=model_name,
                operation_name="llm-call-tags-extraction",
                messages=extraction_messages,
                output=llm_response,
                stream=False,
                parameters={},
            )

    if llm_response is None:
        logger.warning("Call tags extractor returned no response; returning empty list.")
        return []

    parsed = parse_llm_json(llm_response)

    # NOTE(review): the exact shape parse_llm_json returns for a JSON array
    # is not visible from here. Both a plain list and a dict wrapper (with
    # the unparsed text under "raw") are handled - confirm against
    # parse_llm_json.
    candidates: list[Any] = []
    if isinstance(parsed, list):
        candidates = parsed
    elif isinstance(parsed, dict):
        raw = parsed.get("raw")
        if raw and isinstance(raw, str):
            # Try to recover a list from the raw text; anything else yields
            # no tags.
            try:
                maybe_list = _json.loads(raw)
            except _json.JSONDecodeError:
                maybe_list = None
            if isinstance(maybe_list, list):
                candidates = maybe_list
        else:
            # Flatten list values and non-empty string values of the dict.
            for value in parsed.values():
                if isinstance(value, list):
                    candidates.extend(value)
                elif isinstance(value, str) and value:
                    candidates.append(value)

    tags = _normalize_tags(candidates)

    logger.debug(f"Extracted call tags: {tags}")
    return tags

View file

@ -43,6 +43,8 @@ class Node:
self.extraction_enabled = data.extraction_enabled
self.extraction_prompt = data.extraction_prompt
self.extraction_variables = data.extraction_variables
self.call_tags_enabled = data.call_tags_enabled
self.call_tags_prompt = data.call_tags_prompt
self.add_global_prompt = data.add_global_prompt
self.detect_voicemail = data.detect_voicemail
self.delayed_start = data.delayed_start