Prompt service returns token counts

This commit is contained in:
Cyber MacGeddon 2026-04-13 09:27:20 +01:00
parent 05c29de5bc
commit 115db6a62f
14 changed files with 124 additions and 47 deletions

View file

@ -21,7 +21,7 @@ from . embeddings_client import EmbeddingsClientSpec
from . text_completion_client import (
TextCompletionClientSpec, TextCompletionClient, TextCompletionResult,
)
from . prompt_client import PromptClientSpec
from . prompt_client import PromptClientSpec, PromptClient, PromptResult
from . triples_store_service import TriplesStoreService
from . graph_embeddings_store_service import GraphEmbeddingsStoreService
from . document_embeddings_store_service import DocumentEmbeddingsStoreService

View file

@ -1,10 +1,22 @@
import json
import asyncio
from dataclasses import dataclass
from typing import Optional, Any
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import PromptRequest, PromptResponse
@dataclass
class PromptResult:
response_type: str # "text", "json", or "jsonl"
text: Optional[str] = None # populated for "text"
object: Any = None # populated for "json"
objects: Optional[list] = None # populated for "jsonl"
in_token: Optional[int] = None
out_token: Optional[int] = None
model: Optional[str] = None
class PromptClient(RequestResponse):
async def prompt(self, id, variables, timeout=600, streaming=False, chunk_callback=None):
@ -26,17 +38,40 @@ class PromptClient(RequestResponse):
if resp.error:
raise RuntimeError(resp.error.message)
if resp.text: return resp.text
if resp.text:
return PromptResult(
response_type="text",
text=resp.text,
in_token=resp.in_token,
out_token=resp.out_token,
model=resp.model,
)
return json.loads(resp.object)
parsed = json.loads(resp.object)
if isinstance(parsed, list):
return PromptResult(
response_type="jsonl",
objects=parsed,
in_token=resp.in_token,
out_token=resp.out_token,
model=resp.model,
)
return PromptResult(
response_type="json",
object=parsed,
in_token=resp.in_token,
out_token=resp.out_token,
model=resp.model,
)
else:
last_text = ""
last_object = None
last_resp = None
async def forward_chunks(resp):
nonlocal last_text, last_object
nonlocal last_resp
if resp.error:
raise RuntimeError(resp.error.message)
@ -44,14 +79,13 @@ class PromptClient(RequestResponse):
end_stream = getattr(resp, 'end_of_stream', False)
if resp.text is not None:
last_text = resp.text
if chunk_callback:
if asyncio.iscoroutinefunction(chunk_callback):
await chunk_callback(resp.text, end_stream)
else:
chunk_callback(resp.text, end_stream)
elif resp.object:
last_object = resp.object
last_resp = resp
return end_stream
@ -70,10 +104,36 @@ class PromptClient(RequestResponse):
timeout=timeout
)
if last_text:
return last_text
if last_resp is None:
return PromptResult(response_type="text")
return json.loads(last_object) if last_object else None
if last_resp.object:
parsed = json.loads(last_resp.object)
if isinstance(parsed, list):
return PromptResult(
response_type="jsonl",
objects=parsed,
in_token=last_resp.in_token,
out_token=last_resp.out_token,
model=last_resp.model,
)
return PromptResult(
response_type="json",
object=parsed,
in_token=last_resp.in_token,
out_token=last_resp.out_token,
model=last_resp.model,
)
return PromptResult(
response_type="text",
text=last_resp.text,
in_token=last_resp.in_token,
out_token=last_resp.out_token,
model=last_resp.model,
)
async def extract_definitions(self, text, timeout=600):
return await self.prompt(
@ -152,4 +212,3 @@ class PromptClientSpec(RequestResponseSpec):
response_schema = PromptResponse,
impl = PromptClient,
)

View file

@ -71,7 +71,7 @@ class MetaRouter:
try:
client = context("prompt-request")
response = await client.prompt(
result = await client.prompt(
id="task-type-classify",
variables={
"question": question,
@ -81,7 +81,7 @@ class MetaRouter:
],
},
)
selected = response.strip().lower().replace('"', '').replace("'", "")
selected = result.text.strip().lower().replace('"', '').replace("'", "")
if selected in self.task_types:
framing = self.task_types[selected].get("framing", DEFAULT_FRAMING)
@ -120,7 +120,7 @@ class MetaRouter:
try:
client = context("prompt-request")
response = await client.prompt(
result = await client.prompt(
id="pattern-select",
variables={
"question": question,
@ -133,7 +133,7 @@ class MetaRouter:
],
},
)
selected = response.strip().lower().replace('"', '').replace("'", "")
selected = result.text.strip().lower().replace('"', '').replace("'", "")
if selected in valid_patterns:
logger.info(f"MetaRouter: selected pattern '{selected}'")

View file

@ -600,10 +600,11 @@ class PatternBase:
return "".join(accumulated)
else:
return await client.prompt(
result = await client.prompt(
id=prompt_id,
variables=variables,
)
return result.text
async def send_final_response(self, respond, streaming, answer_text,
already_streamed=False, message_id=""):

View file

@ -113,7 +113,7 @@ class PlanThenExecutePattern(PatternBase):
client = context("prompt-request")
# Use the plan-create prompt template
plan_steps = await client.prompt(
result = await client.prompt(
id="plan-create",
variables={
"question": request.question,
@ -125,6 +125,7 @@ class PlanThenExecutePattern(PatternBase):
},
)
plan_steps = result.objects
# Validate we got a list
if not isinstance(plan_steps, list) or not plan_steps:
logger.warning("plan-create returned invalid result, falling back to single step")
@ -240,7 +241,7 @@ class PlanThenExecutePattern(PatternBase):
client = context("prompt-request")
# Single-shot: ask LLM which tool + arguments to use for this goal
tool_call = await client.prompt(
result = await client.prompt(
id="plan-step-execute",
variables={
"goal": goal,
@ -259,6 +260,7 @@ class PlanThenExecutePattern(PatternBase):
},
)
tool_call = result.object
tool_name = tool_call.get("tool", "")
tool_arguments = tool_call.get("arguments", {})

View file

@ -100,7 +100,7 @@ class SupervisorPattern(PatternBase):
client = context("prompt-request")
# Use the supervisor-decompose prompt template
goals = await client.prompt(
result = await client.prompt(
id="supervisor-decompose",
variables={
"question": request.question,
@ -113,6 +113,7 @@ class SupervisorPattern(PatternBase):
},
)
goals = result.objects
# Validate result
if not isinstance(goals, list):
goals = []

View file

@ -255,7 +255,7 @@ class AgentManager:
client = context("prompt-request")
# Get streaming response
response_text = await client.agent_react(
await client.agent_react(
variables=variables,
streaming=True,
chunk_callback=on_chunk
@ -275,10 +275,11 @@ class AgentManager:
# Non-streaming path - get complete text and parse
client = context("prompt-request")
response_text = await client.agent_react(
prompt_result = await client.agent_react(
variables=variables,
streaming=False
)
response_text = prompt_result.text
logger.debug(f"Response text:\n{response_text}")

View file

@ -78,9 +78,10 @@ class TextCompletionImpl:
async def invoke(self, **arguments):
client = self.context("prompt-request")
logger.debug("Prompt question...")
return await client.question(
result = await client.question(
arguments.get("question")
)
return result.text
# This tool implementation knows how to do MCP tool invocation. This uses
# the mcp-tool service.
@ -227,10 +228,11 @@ class PromptImpl:
async def invoke(self, **arguments):
client = self.context("prompt-request")
logger.debug(f"Prompt template invocation: {self.template_id}...")
return await client.prompt(
result = await client.prompt(
id=self.template_id,
variables=arguments
)
return result.text
# This tool implementation invokes a dynamically configured tool service

View file

@ -117,10 +117,11 @@ class Processor(FlowProcessor):
try:
defs = await flow("prompt-request").extract_definitions(
result = await flow("prompt-request").extract_definitions(
text = chunk
)
defs = result.objects
logger.debug(f"Definitions response: {defs}")
if type(defs) != list:

View file

@ -376,10 +376,11 @@ class Processor(FlowProcessor):
"""
try:
# Call prompt service with simplified format prompt
extraction_response = await flow("prompt-request").prompt(
result = await flow("prompt-request").prompt(
id="extract-with-ontologies",
variables=prompt_variables
)
extraction_response = result.object
logger.debug(f"Simplified extraction response: {extraction_response}")
# Parse response into structured format

View file

@ -100,10 +100,11 @@ class Processor(FlowProcessor):
try:
rels = await flow("prompt-request").extract_relationships(
result = await flow("prompt-request").extract_relationships(
text = chunk
)
rels = result.objects
logger.debug(f"Prompt response: {rels}")
if type(rels) != list:

View file

@ -148,11 +148,12 @@ class Processor(FlowProcessor):
schema_dict = row_schema_translator.encode(schema)
# Use prompt client to extract rows based on schema
objects = await flow("prompt-request").extract_objects(
result = await flow("prompt-request").extract_objects(
schema=schema_dict,
text=text
)
objects = result.objects
if not isinstance(objects, list):
return []

View file

@ -37,14 +37,14 @@ class Query:
async def extract_concepts(self, query):
"""Extract key concepts from query for independent embedding."""
response = await self.rag.prompt_client.prompt(
result = await self.rag.prompt_client.prompt(
"extract-concepts",
variables={"query": query}
)
concepts = []
if isinstance(response, str):
for line in response.strip().split('\n'):
if result.text:
for line in result.text.strip().split('\n'):
line = line.strip()
if line:
concepts.append(line)
@ -228,7 +228,7 @@ class DocumentRag:
accumulated_chunks.append(chunk)
await chunk_callback(chunk, end_of_stream)
resp = await self.prompt_client.document_prompt(
await self.prompt_client.document_prompt(
query=query,
documents=docs,
streaming=True,
@ -237,10 +237,11 @@ class DocumentRag:
# Combine all chunks into full response
resp = "".join(accumulated_chunks)
else:
resp = await self.prompt_client.document_prompt(
result = await self.prompt_client.document_prompt(
query=query,
documents=docs
)
resp = result.text
if self.verbose:
logger.debug("Query processing complete")

View file

@ -134,14 +134,14 @@ class Query:
async def extract_concepts(self, query):
"""Extract key concepts from query for independent embedding."""
response = await self.rag.prompt_client.prompt(
result = await self.rag.prompt_client.prompt(
"extract-concepts",
variables={"query": query}
)
concepts = []
if isinstance(response, str):
for line in response.strip().split('\n'):
if result.text:
for line in result.text.strip().split('\n'):
line = line.strip()
if line:
concepts.append(line)
@ -751,13 +751,14 @@ class GraphRag:
logger.debug(f"Built edge map with {len(edge_map)} edges")
# Step 1a: Edge Scoring - LLM scores edges for relevance
scoring_response = await self.prompt_client.prompt(
scoring_result = await self.prompt_client.prompt(
"kg-edge-scoring",
variables={
"query": query,
"knowledge": edges_with_ids
}
)
scoring_response = scoring_result.text
if self.verbose:
logger.debug(f"Edge scoring response: {scoring_response}")
@ -821,13 +822,17 @@ class GraphRag:
]
# Run reasoning and document tracing concurrently
reasoning_task = self.prompt_client.prompt(
"kg-edge-reasoning",
variables={
"query": query,
"knowledge": selected_edges_with_ids
}
)
async def _get_reasoning():
result = await self.prompt_client.prompt(
"kg-edge-reasoning",
variables={
"query": query,
"knowledge": selected_edges_with_ids
}
)
return result.text
reasoning_task = _get_reasoning()
doc_trace_task = q.trace_source_documents(selected_edge_uris)
reasoning_response, source_documents = await asyncio.gather(
@ -928,10 +933,11 @@ class GraphRag:
# Combine all chunks into full response
resp = "".join(accumulated_chunks)
else:
resp = await self.prompt_client.prompt(
result = await self.prompt_client.prompt(
"kg-synthesis",
variables=synthesis_variables,
)
resp = result.text
if self.verbose:
logger.debug("Query processing complete")