mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-12 00:05:13 +02:00
Prompt service returns token counts
This commit is contained in:
parent
05c29de5bc
commit
115db6a62f
14 changed files with 124 additions and 47 deletions
|
|
@ -21,7 +21,7 @@ from . embeddings_client import EmbeddingsClientSpec
|
|||
from . text_completion_client import (
|
||||
TextCompletionClientSpec, TextCompletionClient, TextCompletionResult,
|
||||
)
|
||||
from . prompt_client import PromptClientSpec
|
||||
from . prompt_client import PromptClientSpec, PromptClient, PromptResult
|
||||
from . triples_store_service import TriplesStoreService
|
||||
from . graph_embeddings_store_service import GraphEmbeddingsStoreService
|
||||
from . document_embeddings_store_service import DocumentEmbeddingsStoreService
|
||||
|
|
|
|||
|
|
@ -1,10 +1,22 @@
|
|||
|
||||
import json
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Any
|
||||
|
||||
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||
from .. schema import PromptRequest, PromptResponse
|
||||
|
||||
@dataclass
|
||||
class PromptResult:
|
||||
response_type: str # "text", "json", or "jsonl"
|
||||
text: Optional[str] = None # populated for "text"
|
||||
object: Any = None # populated for "json"
|
||||
objects: Optional[list] = None # populated for "jsonl"
|
||||
in_token: Optional[int] = None
|
||||
out_token: Optional[int] = None
|
||||
model: Optional[str] = None
|
||||
|
||||
class PromptClient(RequestResponse):
|
||||
|
||||
async def prompt(self, id, variables, timeout=600, streaming=False, chunk_callback=None):
|
||||
|
|
@ -26,17 +38,40 @@ class PromptClient(RequestResponse):
|
|||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
if resp.text: return resp.text
|
||||
if resp.text:
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
text=resp.text,
|
||||
in_token=resp.in_token,
|
||||
out_token=resp.out_token,
|
||||
model=resp.model,
|
||||
)
|
||||
|
||||
return json.loads(resp.object)
|
||||
parsed = json.loads(resp.object)
|
||||
|
||||
if isinstance(parsed, list):
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=parsed,
|
||||
in_token=resp.in_token,
|
||||
out_token=resp.out_token,
|
||||
model=resp.model,
|
||||
)
|
||||
|
||||
return PromptResult(
|
||||
response_type="json",
|
||||
object=parsed,
|
||||
in_token=resp.in_token,
|
||||
out_token=resp.out_token,
|
||||
model=resp.model,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
last_text = ""
|
||||
last_object = None
|
||||
last_resp = None
|
||||
|
||||
async def forward_chunks(resp):
|
||||
nonlocal last_text, last_object
|
||||
nonlocal last_resp
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
|
@ -44,14 +79,13 @@ class PromptClient(RequestResponse):
|
|||
end_stream = getattr(resp, 'end_of_stream', False)
|
||||
|
||||
if resp.text is not None:
|
||||
last_text = resp.text
|
||||
if chunk_callback:
|
||||
if asyncio.iscoroutinefunction(chunk_callback):
|
||||
await chunk_callback(resp.text, end_stream)
|
||||
else:
|
||||
chunk_callback(resp.text, end_stream)
|
||||
elif resp.object:
|
||||
last_object = resp.object
|
||||
|
||||
last_resp = resp
|
||||
|
||||
return end_stream
|
||||
|
||||
|
|
@ -70,10 +104,36 @@ class PromptClient(RequestResponse):
|
|||
timeout=timeout
|
||||
)
|
||||
|
||||
if last_text:
|
||||
return last_text
|
||||
if last_resp is None:
|
||||
return PromptResult(response_type="text")
|
||||
|
||||
return json.loads(last_object) if last_object else None
|
||||
if last_resp.object:
|
||||
parsed = json.loads(last_resp.object)
|
||||
|
||||
if isinstance(parsed, list):
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=parsed,
|
||||
in_token=last_resp.in_token,
|
||||
out_token=last_resp.out_token,
|
||||
model=last_resp.model,
|
||||
)
|
||||
|
||||
return PromptResult(
|
||||
response_type="json",
|
||||
object=parsed,
|
||||
in_token=last_resp.in_token,
|
||||
out_token=last_resp.out_token,
|
||||
model=last_resp.model,
|
||||
)
|
||||
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
text=last_resp.text,
|
||||
in_token=last_resp.in_token,
|
||||
out_token=last_resp.out_token,
|
||||
model=last_resp.model,
|
||||
)
|
||||
|
||||
async def extract_definitions(self, text, timeout=600):
|
||||
return await self.prompt(
|
||||
|
|
@ -152,4 +212,3 @@ class PromptClientSpec(RequestResponseSpec):
|
|||
response_schema = PromptResponse,
|
||||
impl = PromptClient,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ class MetaRouter:
|
|||
|
||||
try:
|
||||
client = context("prompt-request")
|
||||
response = await client.prompt(
|
||||
result = await client.prompt(
|
||||
id="task-type-classify",
|
||||
variables={
|
||||
"question": question,
|
||||
|
|
@ -81,7 +81,7 @@ class MetaRouter:
|
|||
],
|
||||
},
|
||||
)
|
||||
selected = response.strip().lower().replace('"', '').replace("'", "")
|
||||
selected = result.text.strip().lower().replace('"', '').replace("'", "")
|
||||
|
||||
if selected in self.task_types:
|
||||
framing = self.task_types[selected].get("framing", DEFAULT_FRAMING)
|
||||
|
|
@ -120,7 +120,7 @@ class MetaRouter:
|
|||
|
||||
try:
|
||||
client = context("prompt-request")
|
||||
response = await client.prompt(
|
||||
result = await client.prompt(
|
||||
id="pattern-select",
|
||||
variables={
|
||||
"question": question,
|
||||
|
|
@ -133,7 +133,7 @@ class MetaRouter:
|
|||
],
|
||||
},
|
||||
)
|
||||
selected = response.strip().lower().replace('"', '').replace("'", "")
|
||||
selected = result.text.strip().lower().replace('"', '').replace("'", "")
|
||||
|
||||
if selected in valid_patterns:
|
||||
logger.info(f"MetaRouter: selected pattern '{selected}'")
|
||||
|
|
|
|||
|
|
@ -600,10 +600,11 @@ class PatternBase:
|
|||
|
||||
return "".join(accumulated)
|
||||
else:
|
||||
return await client.prompt(
|
||||
result = await client.prompt(
|
||||
id=prompt_id,
|
||||
variables=variables,
|
||||
)
|
||||
return result.text
|
||||
|
||||
async def send_final_response(self, respond, streaming, answer_text,
|
||||
already_streamed=False, message_id=""):
|
||||
|
|
|
|||
|
|
@ -113,7 +113,7 @@ class PlanThenExecutePattern(PatternBase):
|
|||
client = context("prompt-request")
|
||||
|
||||
# Use the plan-create prompt template
|
||||
plan_steps = await client.prompt(
|
||||
result = await client.prompt(
|
||||
id="plan-create",
|
||||
variables={
|
||||
"question": request.question,
|
||||
|
|
@ -125,6 +125,7 @@ class PlanThenExecutePattern(PatternBase):
|
|||
},
|
||||
)
|
||||
|
||||
plan_steps = result.objects
|
||||
# Validate we got a list
|
||||
if not isinstance(plan_steps, list) or not plan_steps:
|
||||
logger.warning("plan-create returned invalid result, falling back to single step")
|
||||
|
|
@ -240,7 +241,7 @@ class PlanThenExecutePattern(PatternBase):
|
|||
client = context("prompt-request")
|
||||
|
||||
# Single-shot: ask LLM which tool + arguments to use for this goal
|
||||
tool_call = await client.prompt(
|
||||
result = await client.prompt(
|
||||
id="plan-step-execute",
|
||||
variables={
|
||||
"goal": goal,
|
||||
|
|
@ -259,6 +260,7 @@ class PlanThenExecutePattern(PatternBase):
|
|||
},
|
||||
)
|
||||
|
||||
tool_call = result.object
|
||||
tool_name = tool_call.get("tool", "")
|
||||
tool_arguments = tool_call.get("arguments", {})
|
||||
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ class SupervisorPattern(PatternBase):
|
|||
client = context("prompt-request")
|
||||
|
||||
# Use the supervisor-decompose prompt template
|
||||
goals = await client.prompt(
|
||||
result = await client.prompt(
|
||||
id="supervisor-decompose",
|
||||
variables={
|
||||
"question": request.question,
|
||||
|
|
@ -113,6 +113,7 @@ class SupervisorPattern(PatternBase):
|
|||
},
|
||||
)
|
||||
|
||||
goals = result.objects
|
||||
# Validate result
|
||||
if not isinstance(goals, list):
|
||||
goals = []
|
||||
|
|
|
|||
|
|
@ -255,7 +255,7 @@ class AgentManager:
|
|||
client = context("prompt-request")
|
||||
|
||||
# Get streaming response
|
||||
response_text = await client.agent_react(
|
||||
await client.agent_react(
|
||||
variables=variables,
|
||||
streaming=True,
|
||||
chunk_callback=on_chunk
|
||||
|
|
@ -275,10 +275,11 @@ class AgentManager:
|
|||
# Non-streaming path - get complete text and parse
|
||||
client = context("prompt-request")
|
||||
|
||||
response_text = await client.agent_react(
|
||||
prompt_result = await client.agent_react(
|
||||
variables=variables,
|
||||
streaming=False
|
||||
)
|
||||
response_text = prompt_result.text
|
||||
|
||||
logger.debug(f"Response text:\n{response_text}")
|
||||
|
||||
|
|
|
|||
|
|
@ -78,9 +78,10 @@ class TextCompletionImpl:
|
|||
async def invoke(self, **arguments):
|
||||
client = self.context("prompt-request")
|
||||
logger.debug("Prompt question...")
|
||||
return await client.question(
|
||||
result = await client.question(
|
||||
arguments.get("question")
|
||||
)
|
||||
return result.text
|
||||
|
||||
# This tool implementation knows how to do MCP tool invocation. This uses
|
||||
# the mcp-tool service.
|
||||
|
|
@ -227,10 +228,11 @@ class PromptImpl:
|
|||
async def invoke(self, **arguments):
|
||||
client = self.context("prompt-request")
|
||||
logger.debug(f"Prompt template invocation: {self.template_id}...")
|
||||
return await client.prompt(
|
||||
result = await client.prompt(
|
||||
id=self.template_id,
|
||||
variables=arguments
|
||||
)
|
||||
return result.text
|
||||
|
||||
|
||||
# This tool implementation invokes a dynamically configured tool service
|
||||
|
|
|
|||
|
|
@ -117,10 +117,11 @@ class Processor(FlowProcessor):
|
|||
|
||||
try:
|
||||
|
||||
defs = await flow("prompt-request").extract_definitions(
|
||||
result = await flow("prompt-request").extract_definitions(
|
||||
text = chunk
|
||||
)
|
||||
|
||||
defs = result.objects
|
||||
logger.debug(f"Definitions response: {defs}")
|
||||
|
||||
if type(defs) != list:
|
||||
|
|
|
|||
|
|
@ -376,10 +376,11 @@ class Processor(FlowProcessor):
|
|||
"""
|
||||
try:
|
||||
# Call prompt service with simplified format prompt
|
||||
extraction_response = await flow("prompt-request").prompt(
|
||||
result = await flow("prompt-request").prompt(
|
||||
id="extract-with-ontologies",
|
||||
variables=prompt_variables
|
||||
)
|
||||
extraction_response = result.object
|
||||
logger.debug(f"Simplified extraction response: {extraction_response}")
|
||||
|
||||
# Parse response into structured format
|
||||
|
|
|
|||
|
|
@ -100,10 +100,11 @@ class Processor(FlowProcessor):
|
|||
|
||||
try:
|
||||
|
||||
rels = await flow("prompt-request").extract_relationships(
|
||||
result = await flow("prompt-request").extract_relationships(
|
||||
text = chunk
|
||||
)
|
||||
|
||||
rels = result.objects
|
||||
logger.debug(f"Prompt response: {rels}")
|
||||
|
||||
if type(rels) != list:
|
||||
|
|
|
|||
|
|
@ -148,11 +148,12 @@ class Processor(FlowProcessor):
|
|||
schema_dict = row_schema_translator.encode(schema)
|
||||
|
||||
# Use prompt client to extract rows based on schema
|
||||
objects = await flow("prompt-request").extract_objects(
|
||||
result = await flow("prompt-request").extract_objects(
|
||||
schema=schema_dict,
|
||||
text=text
|
||||
)
|
||||
|
||||
|
||||
objects = result.objects
|
||||
if not isinstance(objects, list):
|
||||
return []
|
||||
|
||||
|
|
|
|||
|
|
@ -37,14 +37,14 @@ class Query:
|
|||
|
||||
async def extract_concepts(self, query):
|
||||
"""Extract key concepts from query for independent embedding."""
|
||||
response = await self.rag.prompt_client.prompt(
|
||||
result = await self.rag.prompt_client.prompt(
|
||||
"extract-concepts",
|
||||
variables={"query": query}
|
||||
)
|
||||
|
||||
concepts = []
|
||||
if isinstance(response, str):
|
||||
for line in response.strip().split('\n'):
|
||||
if result.text:
|
||||
for line in result.text.strip().split('\n'):
|
||||
line = line.strip()
|
||||
if line:
|
||||
concepts.append(line)
|
||||
|
|
@ -228,7 +228,7 @@ class DocumentRag:
|
|||
accumulated_chunks.append(chunk)
|
||||
await chunk_callback(chunk, end_of_stream)
|
||||
|
||||
resp = await self.prompt_client.document_prompt(
|
||||
await self.prompt_client.document_prompt(
|
||||
query=query,
|
||||
documents=docs,
|
||||
streaming=True,
|
||||
|
|
@ -237,10 +237,11 @@ class DocumentRag:
|
|||
# Combine all chunks into full response
|
||||
resp = "".join(accumulated_chunks)
|
||||
else:
|
||||
resp = await self.prompt_client.document_prompt(
|
||||
result = await self.prompt_client.document_prompt(
|
||||
query=query,
|
||||
documents=docs
|
||||
)
|
||||
resp = result.text
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Query processing complete")
|
||||
|
|
|
|||
|
|
@ -134,14 +134,14 @@ class Query:
|
|||
|
||||
async def extract_concepts(self, query):
|
||||
"""Extract key concepts from query for independent embedding."""
|
||||
response = await self.rag.prompt_client.prompt(
|
||||
result = await self.rag.prompt_client.prompt(
|
||||
"extract-concepts",
|
||||
variables={"query": query}
|
||||
)
|
||||
|
||||
concepts = []
|
||||
if isinstance(response, str):
|
||||
for line in response.strip().split('\n'):
|
||||
if result.text:
|
||||
for line in result.text.strip().split('\n'):
|
||||
line = line.strip()
|
||||
if line:
|
||||
concepts.append(line)
|
||||
|
|
@ -751,13 +751,14 @@ class GraphRag:
|
|||
logger.debug(f"Built edge map with {len(edge_map)} edges")
|
||||
|
||||
# Step 1a: Edge Scoring - LLM scores edges for relevance
|
||||
scoring_response = await self.prompt_client.prompt(
|
||||
scoring_result = await self.prompt_client.prompt(
|
||||
"kg-edge-scoring",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": edges_with_ids
|
||||
}
|
||||
)
|
||||
scoring_response = scoring_result.text
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Edge scoring response: {scoring_response}")
|
||||
|
|
@ -821,13 +822,17 @@ class GraphRag:
|
|||
]
|
||||
|
||||
# Run reasoning and document tracing concurrently
|
||||
reasoning_task = self.prompt_client.prompt(
|
||||
"kg-edge-reasoning",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": selected_edges_with_ids
|
||||
}
|
||||
)
|
||||
async def _get_reasoning():
|
||||
result = await self.prompt_client.prompt(
|
||||
"kg-edge-reasoning",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": selected_edges_with_ids
|
||||
}
|
||||
)
|
||||
return result.text
|
||||
|
||||
reasoning_task = _get_reasoning()
|
||||
doc_trace_task = q.trace_source_documents(selected_edge_uris)
|
||||
|
||||
reasoning_response, source_documents = await asyncio.gather(
|
||||
|
|
@ -928,10 +933,11 @@ class GraphRag:
|
|||
# Combine all chunks into full response
|
||||
resp = "".join(accumulated_chunks)
|
||||
else:
|
||||
resp = await self.prompt_client.prompt(
|
||||
result = await self.prompt_client.prompt(
|
||||
"kg-synthesis",
|
||||
variables=synthesis_variables,
|
||||
)
|
||||
resp = result.text
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Query processing complete")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue