mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: add tool response in variable extraction llm
This commit is contained in:
parent
1604e306ec
commit
5dd40ca90a
1 changed files with 121 additions and 66 deletions
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, List
|
||||
|
||||
from loguru import logger
|
||||
|
|
@ -34,6 +35,124 @@ class VariableExtractionManager:
|
|||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# Keys stripped from HTTP tool responses before passing them to the extraction LLM.
|
||||
_TOOL_RESPONSE_STRIP_KEYS = {"status", "status_code"}
|
||||
|
||||
# Maximum character length for a single tool response in the extraction
|
||||
# context. Responses longer than this are truncated with a marker.
|
||||
_TOOL_RESPONSE_MAX_CHARS = 2000
|
||||
|
||||
# Serialized payload returned by transition tools; tool responses equal to this are excluded from the extraction context.
|
||||
_TRANSITION_RESPONSE = '{"status": "done"}'
|
||||
|
||||
def _build_tool_call_name_lookup(self) -> dict[str, str]:
|
||||
"""Build a mapping of tool_call_id → function name from assistant messages.
|
||||
|
||||
This allows labelling tool responses with the function that produced them.
|
||||
"""
|
||||
lookup: dict[str, str] = {}
|
||||
for msg in self._context.messages:
|
||||
if not isinstance(msg, dict) or msg.get("role") != "assistant":
|
||||
continue
|
||||
for tc in msg.get("tool_calls", []):
|
||||
tc_id = tc.get("id")
|
||||
func = tc.get("function") or {}
|
||||
tc_name = func.get("name")
|
||||
if tc_id and tc_name:
|
||||
lookup[tc_id] = tc_name
|
||||
return lookup
|
||||
|
||||
def _format_tool_response(self, raw_content: str, tool_name: str) -> str | None:
|
||||
"""Clean, trim, and format a tool response for the extraction context.
|
||||
|
||||
Returns None if the response should be excluded (e.g. transition tools).
|
||||
"""
|
||||
# Skip transition tool responses
|
||||
if raw_content.strip() == self._TRANSITION_RESPONSE:
|
||||
return None
|
||||
|
||||
# Try to parse as JSON so we can strip wrapper keys and extract data
|
||||
try:
|
||||
parsed = json.loads(raw_content)
|
||||
if isinstance(parsed, dict):
|
||||
# If there is a "data" key, prefer its content — that is the
|
||||
# actual HTTP response payload from custom tools.
|
||||
if "data" in parsed:
|
||||
parsed = parsed["data"]
|
||||
else:
|
||||
# Strip wrapper metadata keys
|
||||
for key in self._TOOL_RESPONSE_STRIP_KEYS:
|
||||
parsed.pop(key, None)
|
||||
|
||||
formatted = json.dumps(parsed, ensure_ascii=False)
|
||||
else:
|
||||
formatted = raw_content
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
formatted = raw_content
|
||||
|
||||
# Truncate if too long
|
||||
if len(formatted) > self._TOOL_RESPONSE_MAX_CHARS:
|
||||
formatted = formatted[: self._TOOL_RESPONSE_MAX_CHARS] + "...(truncated)"
|
||||
|
||||
return f"[Tool Response: {tool_name}]\n{formatted}"
|
||||
|
||||
def _get_role_and_content(self, msg: Any) -> tuple[str | None, str | None]:
|
||||
"""Return (role, content) for a single context message.
|
||||
|
||||
Supports both OpenAI-style dict messages and Google Gemini ``Content``
|
||||
objects. Only plain textual content is returned — image parts, tool
|
||||
call placeholders, etc. are ignored.
|
||||
"""
|
||||
# OpenAI format — dict with ``role`` and ``content`` keys
|
||||
if isinstance(msg, dict):
|
||||
role = msg.get("role")
|
||||
content_field = msg.get("content")
|
||||
|
||||
if isinstance(content_field, str):
|
||||
return role, content_field
|
||||
if isinstance(content_field, list):
|
||||
texts = [
|
||||
segment.get("text", "")
|
||||
for segment in content_field
|
||||
if isinstance(segment, dict) and segment.get("type") == "text"
|
||||
]
|
||||
return role, (" ".join(texts) if texts else None)
|
||||
return role, None
|
||||
|
||||
# Google Gemini format — ``Content`` object with ``parts`` list
|
||||
role_attr = getattr(msg, "role", None)
|
||||
parts_attr = getattr(msg, "parts", None)
|
||||
if role_attr is None or parts_attr is None:
|
||||
return None, None
|
||||
|
||||
role = "assistant" if role_attr == "model" else role_attr
|
||||
texts = [t for p in parts_attr if (t := getattr(p, "text", None))]
|
||||
return role, (" ".join(texts) if texts else None)
|
||||
|
||||
def _build_conversation_history(self) -> str:
|
||||
"""Build a text representation of the conversation for the extraction LLM.
|
||||
|
||||
Includes assistant/user messages and formatted tool responses (excluding
|
||||
transition tool responses).
|
||||
"""
|
||||
tool_call_names = self._build_tool_call_name_lookup()
|
||||
|
||||
lines: list[str] = []
|
||||
for msg in self._context.messages:
|
||||
role, content = self._get_role_and_content(msg)
|
||||
if role in ("assistant", "user") and content:
|
||||
lines.append(f"{role}: {content}")
|
||||
elif isinstance(msg, dict) and msg.get("role") == "tool":
|
||||
tool_content = msg.get("content", "")
|
||||
tool_call_id = msg.get("tool_call_id", "")
|
||||
tool_name = tool_call_names.get(tool_call_id, "unknown")
|
||||
formatted = self._format_tool_response(tool_content, tool_name)
|
||||
if formatted:
|
||||
lines.append(formatted)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _perform_extraction(
|
||||
self,
|
||||
extraction_variables: List[ExtractionVariableDTO],
|
||||
|
|
@ -50,73 +169,9 @@ class VariableExtractionManager:
|
|||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Build a normalized representation of the existing conversation so the
|
||||
# extractor works with both OpenAI-style (dict) messages and Google
|
||||
# Gemini `Content` objects.
|
||||
# Build a normalized conversation history including tool responses.
|
||||
# ------------------------------------------------------------------
|
||||
def _get_role_and_content(msg: Any) -> tuple[str | None, str | None]:
|
||||
"""Return a pair of (role, content) for the given message.
|
||||
|
||||
The logic supports both OpenAI-style dict messages and Google
|
||||
`Content` objects that expose ``role`` and ``parts`` attributes.
|
||||
Only plain textual content is extracted – image parts, tool call
|
||||
placeholders, etc. are ignored for the purpose of variable
|
||||
extraction.
|
||||
"""
|
||||
|
||||
# --------------------------------------------------------------
|
||||
# OpenAI format → simple dict with ``role`` and ``content`` keys
|
||||
# --------------------------------------------------------------
|
||||
if isinstance(msg, dict):
|
||||
role = msg.get("role")
|
||||
content_field = msg.get("content")
|
||||
|
||||
# Content can be a str, list of segments, or None.
|
||||
if isinstance(content_field, str):
|
||||
content = content_field
|
||||
elif isinstance(content_field, list):
|
||||
# Collapse all text parts into a single string.
|
||||
texts = [
|
||||
segment.get("text", "")
|
||||
for segment in content_field
|
||||
if isinstance(segment, dict) and segment.get("type") == "text"
|
||||
]
|
||||
content = " ".join(texts) if texts else None
|
||||
else:
|
||||
content = None
|
||||
|
||||
return role, content
|
||||
|
||||
# --------------------------------------------------------------
|
||||
# Google Gemini format → ``Content`` object with ``parts`` list
|
||||
# --------------------------------------------------------------
|
||||
role_attr = getattr(msg, "role", None)
|
||||
parts_attr = getattr(msg, "parts", None)
|
||||
|
||||
if role_attr is None or parts_attr is None:
|
||||
return None, None # Unrecognised message format
|
||||
|
||||
role = (
|
||||
"assistant" if role_attr == "model" else role_attr
|
||||
) # Normalise role name
|
||||
|
||||
# Collect textual parts only (ignore images, function calls, etc.)
|
||||
texts: list[str] = []
|
||||
for part in parts_attr:
|
||||
text_val = getattr(part, "text", None)
|
||||
if text_val:
|
||||
texts.append(text_val)
|
||||
|
||||
content = " ".join(texts) if texts else None
|
||||
return role, content
|
||||
|
||||
conversation_lines: list[str] = []
|
||||
for msg in self._context.messages:
|
||||
role, content = _get_role_and_content(msg)
|
||||
if role in ("assistant", "user") and content:
|
||||
conversation_lines.append(f"{role}: {content}")
|
||||
|
||||
conversation_history = "\n".join(conversation_lines)
|
||||
conversation_history = self._build_conversation_history()
|
||||
|
||||
system_prompt = (
|
||||
"You are an assistant tasked with extracting structured data from the conversation. "
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue