From 115db6a62f94e570659803fbf0031b7a9eadc9b6 Mon Sep 17 00:00:00 2001 From: Cyber MacGeddon Date: Mon, 13 Apr 2026 09:27:20 +0100 Subject: [PATCH] Prompt service returns token counts --- trustgraph-base/trustgraph/base/__init__.py | 2 +- .../trustgraph/base/prompt_client.py | 83 ++++++++++++++++--- .../agent/orchestrator/meta_router.py | 8 +- .../agent/orchestrator/pattern_base.py | 3 +- .../agent/orchestrator/plan_pattern.py | 6 +- .../agent/orchestrator/supervisor_pattern.py | 3 +- .../trustgraph/agent/react/agent_manager.py | 5 +- .../trustgraph/agent/react/tools.py | 6 +- .../extract/kg/definitions/extract.py | 3 +- .../trustgraph/extract/kg/ontology/extract.py | 3 +- .../extract/kg/relationships/extract.py | 3 +- .../trustgraph/extract/kg/rows/processor.py | 5 +- .../retrieval/document_rag/document_rag.py | 11 +-- .../retrieval/graph_rag/graph_rag.py | 30 ++++--- 14 files changed, 124 insertions(+), 47 deletions(-) diff --git a/trustgraph-base/trustgraph/base/__init__.py b/trustgraph-base/trustgraph/base/__init__.py index 9511c44d..ce17a585 100644 --- a/trustgraph-base/trustgraph/base/__init__.py +++ b/trustgraph-base/trustgraph/base/__init__.py @@ -21,7 +21,7 @@ from . embeddings_client import EmbeddingsClientSpec from . text_completion_client import ( TextCompletionClientSpec, TextCompletionClient, TextCompletionResult, ) -from . prompt_client import PromptClientSpec +from . prompt_client import PromptClientSpec, PromptClient, PromptResult from . triples_store_service import TriplesStoreService from . graph_embeddings_store_service import GraphEmbeddingsStoreService from . document_embeddings_store_service import DocumentEmbeddingsStoreService diff --git a/trustgraph-base/trustgraph/base/prompt_client.py b/trustgraph-base/trustgraph/base/prompt_client.py index 6859a9f0..853e7e66 100644 --- a/trustgraph-base/trustgraph/base/prompt_client.py +++ b/trustgraph-base/trustgraph/base/prompt_client.py @@ -1,10 +1,22 @@ import json import asyncio +from dataclasses import dataclass +from typing import Optional, Any from . request_response_spec import RequestResponse, RequestResponseSpec from .. schema import PromptRequest, PromptResponse +@dataclass +class PromptResult: + response_type: str # "text", "json", or "jsonl" + text: Optional[str] = None # populated for "text" + object: Any = None # populated for "json" + objects: Optional[list] = None # populated for "jsonl" + in_token: Optional[int] = None + out_token: Optional[int] = None + model: Optional[str] = None + class PromptClient(RequestResponse): async def prompt(self, id, variables, timeout=600, streaming=False, chunk_callback=None): @@ -26,17 +38,40 @@ class PromptClient(RequestResponse): if resp.error: raise RuntimeError(resp.error.message) - if resp.text: return resp.text + if resp.text: + return PromptResult( + response_type="text", + text=resp.text, + in_token=resp.in_token, + out_token=resp.out_token, + model=resp.model, + ) - return json.loads(resp.object) + parsed = json.loads(resp.object) + + if isinstance(parsed, list): + return PromptResult( + response_type="jsonl", + objects=parsed, + in_token=resp.in_token, + out_token=resp.out_token, + model=resp.model, + ) + + return PromptResult( + response_type="json", + object=parsed, + in_token=resp.in_token, + out_token=resp.out_token, + model=resp.model, + ) else: - last_text = "" - last_object = None + last_resp = None async def forward_chunks(resp): - nonlocal last_text, last_object + nonlocal last_resp if resp.error: raise RuntimeError(resp.error.message) @@ -44,14 +79,13 @@ class PromptClient(RequestResponse): end_stream = getattr(resp, 'end_of_stream', False) if resp.text is not None: - last_text = resp.text if chunk_callback: if asyncio.iscoroutinefunction(chunk_callback): await chunk_callback(resp.text, end_stream) else: chunk_callback(resp.text, end_stream) - elif resp.object: - last_object = resp.object + + last_resp = resp return end_stream @@ -70,10 +104,36 @@ class PromptClient(RequestResponse): timeout=timeout ) - if last_text: - return last_text + if last_resp is None: + return PromptResult(response_type="text") - return json.loads(last_object) if last_object else None + if last_resp.object: + parsed = json.loads(last_resp.object) + + if isinstance(parsed, list): + return PromptResult( + response_type="jsonl", + objects=parsed, + in_token=last_resp.in_token, + out_token=last_resp.out_token, + model=last_resp.model, + ) + + return PromptResult( + response_type="json", + object=parsed, + in_token=last_resp.in_token, + out_token=last_resp.out_token, + model=last_resp.model, + ) + + return PromptResult( + response_type="text", + text=last_resp.text, + in_token=last_resp.in_token, + out_token=last_resp.out_token, + model=last_resp.model, + ) async def extract_definitions(self, text, timeout=600): return await self.prompt( @@ -152,4 +212,3 @@ class PromptClientSpec(RequestResponseSpec): response_schema = PromptResponse, impl = PromptClient, ) - diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py b/trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py index c3b1afa6..b99bb9e7 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py @@ -71,7 +71,7 @@ class MetaRouter: try: client = context("prompt-request") - response = await client.prompt( + result = await client.prompt( id="task-type-classify", variables={ "question": question, @@ -81,7 +81,7 @@ class MetaRouter: ], }, ) - selected = response.strip().lower().replace('"', '').replace("'", "") + selected = result.text.strip().lower().replace('"', '').replace("'", "") if selected in self.task_types: framing = self.task_types[selected].get("framing", DEFAULT_FRAMING) @@ -120,7 +120,7 @@ class MetaRouter: try: client = context("prompt-request") - response = await client.prompt( + result = await client.prompt( id="pattern-select", variables={ "question": question, @@ -133,7 +133,7 @@ class MetaRouter: ], }, ) - selected = response.strip().lower().replace('"', '').replace("'", "") + selected = result.text.strip().lower().replace('"', '').replace("'", "") if selected in valid_patterns: logger.info(f"MetaRouter: selected pattern '{selected}'") diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py index c18c5bac..b0f271cd 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py @@ -600,10 +600,11 @@ class PatternBase: return "".join(accumulated) else: - return await client.prompt( + result = await client.prompt( id=prompt_id, variables=variables, ) + return result.text async def send_final_response(self, respond, streaming, answer_text, already_streamed=False, message_id=""): diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py index 59d22929..b50074c4 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py @@ -113,7 +113,7 @@ class PlanThenExecutePattern(PatternBase): client = context("prompt-request") # Use the plan-create prompt template - plan_steps = await client.prompt( + result = await client.prompt( id="plan-create", variables={ "question": request.question, @@ -125,6 +125,7 @@ class PlanThenExecutePattern(PatternBase): }, ) + plan_steps = result.objects # Validate we got a list if not isinstance(plan_steps, list) or not plan_steps: logger.warning("plan-create returned invalid result, falling back to single step") @@ -240,7 +241,7 @@ class PlanThenExecutePattern(PatternBase): client = context("prompt-request") # Single-shot: ask LLM which tool + arguments to use for this goal - tool_call = await client.prompt( + result = await client.prompt( id="plan-step-execute", variables={ "goal": goal, @@ -259,6 +260,7 @@ class PlanThenExecutePattern(PatternBase): }, ) + tool_call = result.object tool_name = tool_call.get("tool", "") tool_arguments = tool_call.get("arguments", {}) diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py index d5537876..6dec1f18 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py @@ -100,7 +100,7 @@ class SupervisorPattern(PatternBase): client = context("prompt-request") # Use the supervisor-decompose prompt template - goals = await client.prompt( + result = await client.prompt( id="supervisor-decompose", variables={ "question": request.question, @@ -113,6 +113,7 @@ class SupervisorPattern(PatternBase): }, ) + goals = result.objects # Validate result if not isinstance(goals, list): goals = [] diff --git a/trustgraph-flow/trustgraph/agent/react/agent_manager.py b/trustgraph-flow/trustgraph/agent/react/agent_manager.py index e86a2d6c..64ba977a 100644 --- a/trustgraph-flow/trustgraph/agent/react/agent_manager.py +++ b/trustgraph-flow/trustgraph/agent/react/agent_manager.py @@ -255,7 +255,7 @@ class AgentManager: client = context("prompt-request") # Get streaming response - response_text = await client.agent_react( + await client.agent_react( variables=variables, streaming=True, chunk_callback=on_chunk @@ -275,10 +275,11 @@ class AgentManager: # Non-streaming path - get complete text and parse client = context("prompt-request") - response_text = await client.agent_react( + prompt_result = await client.agent_react( variables=variables, streaming=False ) + response_text = prompt_result.text logger.debug(f"Response text:\n{response_text}") diff --git a/trustgraph-flow/trustgraph/agent/react/tools.py b/trustgraph-flow/trustgraph/agent/react/tools.py index 6fd96ade..c474f740 100644 --- a/trustgraph-flow/trustgraph/agent/react/tools.py +++ b/trustgraph-flow/trustgraph/agent/react/tools.py @@ -78,9 +78,10 @@ class TextCompletionImpl: async def invoke(self, **arguments): client = self.context("prompt-request") logger.debug("Prompt question...") - return await client.question( + result = await client.question( arguments.get("question") ) + return result.text # This tool implementation knows how to do MCP tool invocation. This uses # the mcp-tool service. @@ -227,10 +228,11 @@ class PromptImpl: async def invoke(self, **arguments): client = self.context("prompt-request") logger.debug(f"Prompt template invocation: {self.template_id}...") - return await client.prompt( + result = await client.prompt( id=self.template_id, variables=arguments ) + return result.text # This tool implementation invokes a dynamically configured tool service diff --git a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py index 2bb88c8a..9b5bbb79 100755 --- a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py @@ -117,10 +117,11 @@ class Processor(FlowProcessor): try: - defs = await flow("prompt-request").extract_definitions( + result = await flow("prompt-request").extract_definitions( text = chunk ) + defs = result.objects logger.debug(f"Definitions response: {defs}") if type(defs) != list: diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py index 29808cae..bdb0e6e8 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py @@ -376,10 +376,11 @@ class Processor(FlowProcessor): """ try: # Call prompt service with simplified format prompt - extraction_response = await flow("prompt-request").prompt( + result = await flow("prompt-request").prompt( id="extract-with-ontologies", variables=prompt_variables ) + extraction_response = result.object logger.debug(f"Simplified extraction response: {extraction_response}") # Parse response into structured format diff --git a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py index b557ec32..8068a23d 100755 --- a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py @@ -100,10 +100,11 @@ class Processor(FlowProcessor): try: - rels = await flow("prompt-request").extract_relationships( + result = await flow("prompt-request").extract_relationships( text = chunk ) + rels = result.objects logger.debug(f"Prompt response: {rels}") if type(rels) != list: diff --git a/trustgraph-flow/trustgraph/extract/kg/rows/processor.py b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py index 8fd494b0..973bb3d7 100644 --- a/trustgraph-flow/trustgraph/extract/kg/rows/processor.py +++ b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py @@ -148,11 +148,12 @@ class Processor(FlowProcessor): schema_dict = row_schema_translator.encode(schema) # Use prompt client to extract rows based on schema - objects = await flow("prompt-request").extract_objects( + result = await flow("prompt-request").extract_objects( schema=schema_dict, text=text ) - + + objects = result.objects if not isinstance(objects, list): return [] diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py index 730a7226..f51a2c3c 100644 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py @@ -37,14 +37,14 @@ class Query: async def extract_concepts(self, query): """Extract key concepts from query for independent embedding.""" - response = await self.rag.prompt_client.prompt( + result = await self.rag.prompt_client.prompt( "extract-concepts", variables={"query": query} ) concepts = [] - if isinstance(response, str): - for line in response.strip().split('\n'): + if result.text: + for line in result.text.strip().split('\n'): line = line.strip() if line: concepts.append(line) @@ -228,7 +228,7 @@ class DocumentRag: accumulated_chunks.append(chunk) await chunk_callback(chunk, end_of_stream) - resp = await self.prompt_client.document_prompt( + await self.prompt_client.document_prompt( query=query, documents=docs, streaming=True, @@ -237,10 +237,11 @@ class DocumentRag: # Combine all chunks into full response resp = "".join(accumulated_chunks) else: - resp = await self.prompt_client.document_prompt( + result = await self.prompt_client.document_prompt( query=query, documents=docs ) + resp = result.text if self.verbose: logger.debug("Query processing complete") diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index 5cf7b991..7d8d8801 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -134,14 +134,14 @@ class Query: async def extract_concepts(self, query): """Extract key concepts from query for independent embedding.""" - response = await self.rag.prompt_client.prompt( + result = await self.rag.prompt_client.prompt( "extract-concepts", variables={"query": query} ) concepts = [] - if isinstance(response, str): - for line in response.strip().split('\n'): + if result.text: + for line in result.text.strip().split('\n'): line = line.strip() if line: concepts.append(line) @@ -751,13 +751,14 @@ class GraphRag: logger.debug(f"Built edge map with {len(edge_map)} edges") # Step 1a: Edge Scoring - LLM scores edges for relevance - scoring_response = await self.prompt_client.prompt( + scoring_result = await self.prompt_client.prompt( "kg-edge-scoring", variables={ "query": query, "knowledge": edges_with_ids } ) + scoring_response = scoring_result.text if self.verbose: logger.debug(f"Edge scoring response: {scoring_response}") @@ -821,13 +822,17 @@ class GraphRag: ] # Run reasoning and document tracing concurrently - reasoning_task = self.prompt_client.prompt( - "kg-edge-reasoning", - variables={ - "query": query, - "knowledge": selected_edges_with_ids - } - ) + async def _get_reasoning(): + result = await self.prompt_client.prompt( + "kg-edge-reasoning", + variables={ + "query": query, + "knowledge": selected_edges_with_ids + } + ) + return result.text + + reasoning_task = _get_reasoning() doc_trace_task = q.trace_source_documents(selected_edge_uris) reasoning_response, source_documents = await asyncio.gather( @@ -928,10 +933,11 @@ class GraphRag: # Combine all chunks into full response resp = "".join(accumulated_chunks) else: - resp = await self.prompt_client.prompt( + result = await self.prompt_client.prompt( "kg-synthesis", variables=synthesis_variables, ) + resp = result.text if self.verbose: logger.debug("Query processing complete")