feat: add tool response in variable extraction llm

This commit is contained in:
Sabiha Khan 2026-03-19 13:40:58 +05:30
parent 1604e306ec
commit 5dd40ca90a

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any, List
from loguru import logger
@ -34,6 +35,124 @@ class VariableExtractionManager:
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
# Wrapper metadata keys removed from JSON-object tool responses before they
# are passed to the extraction LLM. They are only stripped when the response
# has no "data" key; when "data" is present its value is used instead.
_TOOL_RESPONSE_STRIP_KEYS = {"status", "status_code"}
# Maximum character length for a single tool response in the extraction
# context. Responses longer than this are truncated with a marker.
_TOOL_RESPONSE_MAX_CHARS = 2000
# Exact payload that identifies a transition tool response. A tool response
# whose stripped text equals this string is left out of the extraction context.
_TRANSITION_RESPONSE = '{"status": "done"}'
def _build_tool_call_name_lookup(self) -> dict[str, str]:
    """Build a mapping of tool_call_id → function name from assistant messages.

    This allows labelling tool responses with the function that produced them.

    Only dict-style messages whose ``role`` is ``"assistant"`` are scanned.
    Malformed entries (a ``tool_calls`` value of ``None``, tool calls that are
    not dicts, a missing id or function name) are skipped instead of raising.

    Returns:
        A dict keyed by tool call id. If the same id appears more than once,
        the last function name seen wins.
    """
    lookup: dict[str, str] = {}
    for msg in self._context.messages:
        if not isinstance(msg, dict) or msg.get("role") != "assistant":
            continue
        # The key may be present with an explicit ``None`` value, in which
        # case a ``.get(..., [])`` default would not protect the loop.
        for tc in msg.get("tool_calls") or []:
            if not isinstance(tc, dict):
                continue
            tc_id = tc.get("id")
            func = tc.get("function") or {}
            tc_name = func.get("name") if isinstance(func, dict) else None
            if tc_id and tc_name:
                lookup[tc_id] = tc_name
    return lookup
def _format_tool_response(self, raw_content: str, tool_name: str) -> str | None:
"""Clean, trim, and format a tool response for the extraction context.
Returns None if the response should be excluded (e.g. transition tools).
"""
# Skip transition tool responses
if raw_content.strip() == self._TRANSITION_RESPONSE:
return None
# Try to parse as JSON so we can strip wrapper keys and extract data
try:
parsed = json.loads(raw_content)
if isinstance(parsed, dict):
# If there is a "data" key, prefer its content — that is the
# actual HTTP response payload from custom tools.
if "data" in parsed:
parsed = parsed["data"]
else:
# Strip wrapper metadata keys
for key in self._TOOL_RESPONSE_STRIP_KEYS:
parsed.pop(key, None)
formatted = json.dumps(parsed, ensure_ascii=False)
else:
formatted = raw_content
except (json.JSONDecodeError, TypeError):
formatted = raw_content
# Truncate if too long
if len(formatted) > self._TOOL_RESPONSE_MAX_CHARS:
formatted = formatted[: self._TOOL_RESPONSE_MAX_CHARS] + "...(truncated)"
return f"[Tool Response: {tool_name}]\n{formatted}"
def _get_role_and_content(self, msg: Any) -> tuple[str | None, str | None]:
"""Return (role, content) for a single context message.
Supports both OpenAI-style dict messages and Google Gemini ``Content``
objects. Only plain textual content is returned image parts, tool
call placeholders, etc. are ignored.
"""
# OpenAI format — dict with ``role`` and ``content`` keys
if isinstance(msg, dict):
role = msg.get("role")
content_field = msg.get("content")
if isinstance(content_field, str):
return role, content_field
if isinstance(content_field, list):
texts = [
segment.get("text", "")
for segment in content_field
if isinstance(segment, dict) and segment.get("type") == "text"
]
return role, (" ".join(texts) if texts else None)
return role, None
# Google Gemini format — ``Content`` object with ``parts`` list
role_attr = getattr(msg, "role", None)
parts_attr = getattr(msg, "parts", None)
if role_attr is None or parts_attr is None:
return None, None
role = "assistant" if role_attr == "model" else role_attr
texts = [t for p in parts_attr if (t := getattr(p, "text", None))]
return role, (" ".join(texts) if texts else None)
def _build_conversation_history(self) -> str:
    """Build a text representation of the conversation for the extraction LLM.

    Includes assistant/user messages and formatted tool responses (excluding
    transition tool responses).

    Assistant and user messages are rendered as ``"<role>: <content>"``.
    Dict messages with role ``"tool"`` are rendered through
    ``_format_tool_response`` and labelled with the function name resolved
    from their ``tool_call_id`` (``"unknown"`` when it cannot be resolved).
    Messages with any other role, or without text, are skipped.

    Returns:
        The rendered entries joined with newlines; an empty string when
        nothing qualifies.
    """
    tool_call_names = self._build_tool_call_name_lookup()
    lines: list[str] = []
    for msg in self._context.messages:
        role, content = self._get_role_and_content(msg)
        if role in ("assistant", "user") and content:
            lines.append(f"{role}: {content}")
        elif isinstance(msg, dict) and msg.get("role") == "tool":
            tool_content = msg.get("content", "")
            if not isinstance(tool_content, str):
                # Content given as a list of segments or as None: use the
                # text already extracted above rather than handing a
                # non-string to the formatter.
                tool_content = content
            if tool_content is None:
                continue
            tool_call_id = msg.get("tool_call_id", "")
            tool_name = tool_call_names.get(tool_call_id, "unknown")
            formatted = self._format_tool_response(tool_content, tool_name)
            if formatted:
                lines.append(formatted)
    return "\n".join(lines)
async def _perform_extraction(
self,
extraction_variables: List[ExtractionVariableDTO],
@ -50,73 +169,9 @@ class VariableExtractionManager:
)
# ------------------------------------------------------------------
# Build a normalized representation of the existing conversation so the
# extractor works with both OpenAI-style (dict) messages and Google
# Gemini `Content` objects.
# Build a normalized conversation history including tool responses.
# ------------------------------------------------------------------
def _get_role_and_content(msg: Any) -> tuple[str | None, str | None]:
"""Return a pair of (role, content) for the given message.
The logic supports both OpenAI-style dict messages and Google
`Content` objects that expose ``role`` and ``parts`` attributes.
Only plain textual content is extracted image parts, tool call
placeholders, etc. are ignored for the purpose of variable
extraction.
"""
# --------------------------------------------------------------
# OpenAI format → simple dict with ``role`` and ``content`` keys
# --------------------------------------------------------------
if isinstance(msg, dict):
role = msg.get("role")
content_field = msg.get("content")
# Content can be a str, list of segments, or None.
if isinstance(content_field, str):
content = content_field
elif isinstance(content_field, list):
# Collapse all text parts into a single string.
texts = [
segment.get("text", "")
for segment in content_field
if isinstance(segment, dict) and segment.get("type") == "text"
]
content = " ".join(texts) if texts else None
else:
content = None
return role, content
# --------------------------------------------------------------
# Google Gemini format → ``Content`` object with ``parts`` list
# --------------------------------------------------------------
role_attr = getattr(msg, "role", None)
parts_attr = getattr(msg, "parts", None)
if role_attr is None or parts_attr is None:
return None, None # Unrecognised message format
role = (
"assistant" if role_attr == "model" else role_attr
) # Normalise role name
# Collect textual parts only (ignore images, function calls, etc.)
texts: list[str] = []
for part in parts_attr:
text_val = getattr(part, "text", None)
if text_val:
texts.append(text_val)
content = " ".join(texts) if texts else None
return role, content
conversation_lines: list[str] = []
for msg in self._context.messages:
role, content = _get_role_and_content(msg)
if role in ("assistant", "user") and content:
conversation_lines.append(f"{role}: {content}")
conversation_history = "\n".join(conversation_lines)
conversation_history = self._build_conversation_history()
system_prompt = (
"You are an assistant tasked with extracting structured data from the conversation. "