mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-03 20:32:38 +02:00
parent
a8e437fc7f
commit
6c7af8789d
216 changed files with 31360 additions and 1611 deletions
|
|
@ -8,6 +8,7 @@ from . library import Library
|
|||
from . flow import Flow
|
||||
from . config import Config
|
||||
from . knowledge import Knowledge
|
||||
from . collection import Collection
|
||||
from . exceptions import *
|
||||
from . types import *
|
||||
|
||||
|
|
@ -68,3 +69,6 @@ class Api:
|
|||
|
||||
def library(self):
    """Return a Library helper bound to this Api instance."""
    return Library(self)
|
||||
|
||||
def collection(self):
    """Return a Collection helper bound to this Api instance."""
    return Collection(self)
|
||||
|
|
|
|||
98
trustgraph-base/trustgraph/api/collection.py
Normal file
98
trustgraph-base/trustgraph/api/collection.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
import datetime
|
||||
import logging
|
||||
|
||||
from . types import CollectionMetadata
|
||||
from . exceptions import *
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Collection:
    """Client-side wrapper around the ``collection-management`` API.

    Every operation is sent through the ``request`` method of the ``api``
    object supplied at construction time.
    """

    def __init__(self, api):
        """Keep a reference to the API object used to send requests."""
        self.api = api

    def request(self, request):
        """Send ``request`` to the collection-management endpoint."""
        return self.api.request("collection-management", request)

    @staticmethod
    def _to_metadata(entry):
        """Build a CollectionMetadata from one response entry.

        Raises KeyError (or TypeError for a non-mapping entry) when a
        required field is missing; callers translate that into
        ProtocolException.
        """
        return CollectionMetadata(
            user=entry["user"],
            collection=entry["collection"],
            name=entry["name"],
            description=entry["description"],
            tags=entry["tags"],
            created_at=entry["created_at"],
            updated_at=entry["updated_at"],
        )

    def list_collections(self, user, tag_filter=None):
        """List the collections belonging to ``user``.

        Args:
            user: User identifier sent with the request.
            tag_filter: Optional tag filter; only sent when truthy.

        Returns:
            List of CollectionMetadata; empty when the response is None or
            carries no (or a null) "collections" entry.

        Raises:
            ProtocolException: If the response cannot be parsed.
        """
        payload = {
            "operation": "list-collections",
            "user": user,
        }
        if tag_filter:
            payload["tag_filter"] = tag_filter

        response = self.request(payload)

        try:
            # A missing or null collection list means "no collections".
            if response is None or "collections" not in response:
                return []

            collections = response["collections"]
            if collections is None:
                return []

            return [self._to_metadata(entry) for entry in collections]

        except Exception as e:
            logger.error(
                "Failed to parse collection list response", exc_info=True
            )
            # Chain the original error so the root cause stays visible.
            raise ProtocolException("Response not formatted correctly") from e

    def update_collection(self, user, collection, name=None,
                          description=None, tags=None):
        """Update metadata of a collection.

        Only the fields that are not None are included in the request.

        Returns:
            CollectionMetadata built from the first entry of the response's
            "collections" list, or None when that list is absent or empty.

        Raises:
            ProtocolException: If the response cannot be parsed.
        """
        payload = {
            "operation": "update-collection",
            "user": user,
            "collection": collection,
        }
        if name is not None:
            payload["name"] = name
        if description is not None:
            payload["description"] = description
        if tags is not None:
            payload["tags"] = tags

        response = self.request(payload)

        try:
            if "collections" in response and response["collections"]:
                return self._to_metadata(response["collections"][0])
            return None

        except Exception as e:
            logger.error(
                "Failed to parse collection update response", exc_info=True
            )
            raise ProtocolException("Response not formatted correctly") from e

    def delete_collection(self, user, collection):
        """Delete a collection; the response body is ignored.

        Returns:
            An empty dict.
        """
        self.request({
            "operation": "delete-collection",
            "user": user,
            "collection": collection,
        })
        return {}
|
||||
|
|
@ -132,12 +132,24 @@ class FlowInstance:
|
|||
input
|
||||
)["response"]
|
||||
|
||||
def agent(self, question):
|
||||
def agent(self, question, user="trustgraph", state=None, group=None, history=None):
|
||||
|
||||
# The input consists of a question
|
||||
# The input consists of a question and optional context
|
||||
input = {
|
||||
"question": question
|
||||
"question": question,
|
||||
"user": user,
|
||||
}
|
||||
|
||||
# Only include state if it has a value
|
||||
if state is not None:
|
||||
input["state"] = state
|
||||
|
||||
# Only include group if it has a value
|
||||
if group is not None:
|
||||
input["group"] = group
|
||||
|
||||
# Always include history (empty list if None)
|
||||
input["history"] = history or []
|
||||
|
||||
return self.request(
|
||||
"service/agent",
|
||||
|
|
@ -383,3 +395,245 @@ class FlowInstance:
|
|||
input
|
||||
)
|
||||
|
||||
def objects_query(
    self, query, user="trustgraph", collection="default",
    variables=None, operation_name=None
):
    """Run a GraphQL query through the ``service/objects`` endpoint.

    Args:
        query: GraphQL query text.
        user: User identifier sent with the request.
        collection: Collection identifier sent with the request.
        variables: Optional GraphQL variables; only sent when truthy.
        operation_name: Optional operation name; only sent when truthy.

    Returns:
        dict with "data" when the response has that key, plus "errors"
        and "extensions" when they are present and non-empty.

    Raises:
        ProtocolException: If the response carries a system-level error.
    """
    payload = {
        "query": query,
        "user": user,
        "collection": collection,
    }
    if variables:
        payload["variables"] = variables
    if operation_name:
        payload["operation_name"] = operation_name

    reply = self.request("service/objects", payload)

    # A truthy "error" entry signals a system-level failure.
    failure = reply["error"] if "error" in reply else None
    if failure:
        kind = failure.get("type", "unknown")
        detail = failure.get("message", "Unknown error")
        raise ProtocolException(f"{kind}: {detail}")

    # "data" is copied whenever the key exists; the other two sections
    # are only copied when they hold something.
    result = {}
    if "data" in reply:
        result["data"] = reply["data"]
    for section in ("errors", "extensions"):
        if section in reply and reply[section]:
            result[section] = reply[section]

    return result
|
||||
|
||||
def nlp_query(self, question, max_results=100):
    """
    Convert a natural language question to a GraphQL query.

    Args:
        question: Natural language question
        max_results: Maximum number of results to return (default: 100)

    Returns:
        The ``service/nlp-query`` response, unchanged.

    Raises:
        ProtocolException: If the response carries a system-level error.
    """
    reply = self.request(
        "service/nlp-query",
        {"question": question, "max_results": max_results},
    )

    # A truthy "error" entry signals a system-level failure.
    failure = reply["error"] if "error" in reply else None
    if failure:
        kind = failure.get("type", "unknown")
        detail = failure.get("message", "Unknown error")
        raise ProtocolException(f"{kind}: {detail}")

    return reply
|
||||
|
||||
def structured_query(self, question, user="trustgraph", collection="default"):
    """
    Execute a natural language question against structured data.

    Args:
        question: Natural language question
        user: User identifier sent with the request (default: "trustgraph")
        collection: Data collection identifier (default: "default")

    Returns:
        The ``service/structured-query`` response, unchanged.

    Raises:
        ProtocolException: If the response carries a system-level error.
    """
    reply = self.request(
        "service/structured-query",
        {"question": question, "user": user, "collection": collection},
    )

    # A truthy "error" entry signals a system-level failure.
    failure = reply["error"] if "error" in reply else None
    if failure:
        kind = failure.get("type", "unknown")
        detail = failure.get("message", "Unknown error")
        raise ProtocolException(f"{kind}: {detail}")

    return reply
|
||||
|
||||
def detect_type(self, sample):
    """
    Detect the data type of a structured data sample.

    Args:
        sample: Data sample to analyze (string content)

    Returns:
        The value of the "detected-type" field of the response.

    Raises:
        ProtocolException: If the response carries a system-level error.
        KeyError: If the response has no "detected-type" field.
    """
    reply = self.request(
        "service/structured-diag",
        {"operation": "detect-type", "sample": sample},
    )

    # A truthy "error" entry signals a system-level failure.
    failure = reply["error"] if "error" in reply else None
    if failure:
        kind = failure.get("type", "unknown")
        detail = failure.get("message", "Unknown error")
        raise ProtocolException(f"{kind}: {detail}")

    return reply["detected-type"]
|
||||
|
||||
def generate_descriptor(self, sample, data_type, schema_name, options=None):
    """
    Generate a descriptor for structured data mapping to a specific schema.

    Args:
        sample: Data sample to analyze (string content)
        data_type: Data type (csv, json, xml)
        schema_name: Target schema name for descriptor generation
        options: Optional parameters (e.g., delimiter for CSV); only sent
            when truthy

    Returns:
        The value of the "descriptor" field of the response.

    Raises:
        ProtocolException: If the response carries a system-level error.
        KeyError: If the response has no "descriptor" field.
    """
    payload = {
        "operation": "generate-descriptor",
        "sample": sample,
        "type": data_type,
        "schema-name": schema_name,
    }
    if options:
        payload["options"] = options

    reply = self.request("service/structured-diag", payload)

    # A truthy "error" entry signals a system-level failure.
    failure = reply["error"] if "error" in reply else None
    if failure:
        kind = failure.get("type", "unknown")
        detail = failure.get("message", "Unknown error")
        raise ProtocolException(f"{kind}: {detail}")

    return reply["descriptor"]
|
||||
|
||||
def diagnose_data(self, sample, schema_name=None, options=None):
    """
    Perform combined data diagnosis: detect type and generate descriptor.

    Args:
        sample: Data sample to analyze (string content)
        schema_name: Optional target schema name; only sent when truthy
        options: Optional parameters (e.g., delimiter for CSV); only sent
            when truthy

    Returns:
        The ``service/structured-diag`` response, unchanged.

    Raises:
        ProtocolException: If the response carries a system-level error.
    """
    payload = {
        "operation": "diagnose",
        "sample": sample,
    }
    if schema_name:
        payload["schema-name"] = schema_name
    if options:
        payload["options"] = options

    reply = self.request("service/structured-diag", payload)

    # A truthy "error" entry signals a system-level failure.
    failure = reply["error"] if "error" in reply else None
    if failure:
        kind = failure.get("type", "unknown")
        detail = failure.get("message", "Unknown error")
        raise ProtocolException(f"{kind}: {detail}")

    return reply
|
||||
|
||||
def schema_selection(self, sample, options=None):
    """
    Select matching schemas for a data sample using prompt analysis.

    Args:
        sample: Data sample to analyze (string content)
        options: Optional parameters; only sent when truthy

    Returns:
        The value of the "schema-matches" field of the response.

    Raises:
        ProtocolException: If the response carries a system-level error.
        KeyError: If the response has no "schema-matches" field.
    """
    payload = {
        "operation": "schema-selection",
        "sample": sample,
    }
    if options:
        payload["options"] = options

    reply = self.request("service/structured-diag", payload)

    # A truthy "error" entry signals a system-level failure.
    failure = reply["error"] if "error" in reply else None
    if failure:
        kind = failure.get("type", "unknown")
        detail = failure.get("message", "Unknown error")
        raise ProtocolException(f"{kind}: {detail}")

    return reply["schema-matches"]
|
||||
|
||||
|
|
|
|||
|
|
@ -41,3 +41,13 @@ class ProcessingMetadata:
|
|||
user : str
|
||||
collection : str
|
||||
tags : List[str]
|
||||
|
||||
@dataclasses.dataclass
class CollectionMetadata:
    """Descriptive record for a single collection."""

    # Owner and identifier of the collection.
    user: str
    collection: str
    # Human-readable labelling.
    name: str
    description: str
    tags: List[str]
    # Timestamps, carried as strings.
    created_at: str
    updated_at: str
|
||||
|
|
|
|||
|
|
@ -31,4 +31,5 @@ from . graph_rag_client import GraphRagClientSpec
|
|||
from . tool_service import ToolService
|
||||
from . tool_client import ToolClientSpec
|
||||
from . agent_client import AgentClientSpec
|
||||
from . structured_query_client import StructuredQueryClientSpec
|
||||
|
||||
|
|
|
|||
134
trustgraph-base/trustgraph/base/cassandra_config.py
Normal file
134
trustgraph-base/trustgraph/base/cassandra_config.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
"""
|
||||
Cassandra configuration utilities for standardized parameter handling.
|
||||
|
||||
Provides consistent Cassandra configuration across all TrustGraph processors,
|
||||
including command-line arguments, environment variables, and defaults.
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
from typing import Optional, Tuple, List, Any
|
||||
|
||||
|
||||
def get_cassandra_defaults() -> dict:
    """
    Read the default Cassandra settings from the process environment.

    The host falls back to 'cassandra' when CASSANDRA_HOST is unset;
    username and password are None when their variables are unset.

    Returns:
        dict: Dictionary with 'host', 'username', and 'password' keys
    """
    env = os.environ
    return {
        'host': env.get('CASSANDRA_HOST', 'cassandra'),
        'username': env.get('CASSANDRA_USERNAME'),
        'password': env.get('CASSANDRA_PASSWORD'),
    }
|
||||
|
||||
|
||||
def add_cassandra_args(parser: argparse.ArgumentParser) -> None:
    """
    Add standardized Cassandra configuration arguments to an argument parser.

    Adds --cassandra-host, --cassandra-username and --cassandra-password,
    with defaults taken from get_cassandra_defaults(). The help text notes
    when a value comes from an environment variable. The password value is
    never displayed.

    Args:
        parser: ArgumentParser instance to add arguments to
    """
    defaults = get_cassandra_defaults()

    def _help_literal(value) -> str:
        # argparse applies %-formatting to help strings, so a raw '%' in an
        # environment-supplied value must be doubled to render literally.
        return str(value).replace('%', '%%')

    host_help = (
        "Cassandra host list, comma-separated "
        f"(default: {_help_literal(defaults['host'])})"
    )
    if 'CASSANDRA_HOST' in os.environ:
        host_help += " [from CASSANDRA_HOST]"

    username_help = "Cassandra username"
    if defaults['username']:
        username_help += f" (default: {_help_literal(defaults['username'])})"
        if 'CASSANDRA_USERNAME' in os.environ:
            username_help += " [from CASSANDRA_USERNAME]"

    password_help = "Cassandra password"
    if defaults['password']:
        # Never show actual password value
        password_help += " (default: <set>)"
        if 'CASSANDRA_PASSWORD' in os.environ:
            password_help += " [from CASSANDRA_PASSWORD]"

    parser.add_argument(
        '--cassandra-host',
        default=defaults['host'],
        help=host_help,
    )

    parser.add_argument(
        '--cassandra-username',
        default=defaults['username'],
        help=username_help,
    )

    parser.add_argument(
        '--cassandra-password',
        default=defaults['password'],
        help=password_help,
    )
|
||||
|
||||
|
||||
def resolve_cassandra_config(
    args: Optional[Any] = None,
    host: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None
) -> Tuple[List[str], Optional[str], Optional[str]]:
    """
    Resolve Cassandra configuration from various sources.

    For each setting the first truthy value wins, in this order: the
    explicit parameter, the matching attribute on ``args``, then the
    value from get_cassandra_defaults(). A string host is split on commas
    into a list of stripped, non-empty entries; a non-string host is
    returned as given.

    Args:
        args: Optional argparse namespace with cassandra_host, cassandra_username, cassandra_password
        host: Optional explicit host parameter (overrides args)
        username: Optional explicit username parameter (overrides args)
        password: Optional explicit password parameter (overrides args)

    Returns:
        tuple: (hosts_list, username, password)
    """
    chosen = [host, username, password]

    # Fill gaps from the argparse namespace, if one was supplied.
    if args is not None:
        attr_names = (
            'cassandra_host', 'cassandra_username', 'cassandra_password'
        )
        chosen = [
            value or getattr(args, attr, None)
            for value, attr in zip(chosen, attr_names)
        ]

    # Anything still unset falls back to environment/built-in defaults.
    fallback = get_cassandra_defaults()
    host, username, password = [
        value or fallback[key]
        for value, key in zip(chosen, ('host', 'username', 'password'))
    ]

    if not isinstance(host, str):
        return host, username, password

    stripped = (piece.strip() for piece in host.split(','))
    return [piece for piece in stripped if piece], username, password
|
||||
|
||||
|
||||
def get_cassandra_config_from_params(params: dict) -> Tuple[List[str], Optional[str], Optional[str]]:
    """
    Extract and resolve Cassandra configuration from a parameters dictionary.

    Reads 'cassandra_host', 'cassandra_username' and 'cassandra_password'
    from ``params`` (missing keys count as None) and passes them to
    resolve_cassandra_config for defaulting and host-list conversion.

    Args:
        params: Dictionary of parameters that may contain Cassandra configuration

    Returns:
        tuple: (hosts_list, username, password)
    """
    return resolve_cassandra_config(
        host=params.get('cassandra_host'),
        username=params.get('cassandra_username'),
        password=params.get('cassandra_password'),
    )
|
||||
|
|
@ -27,7 +27,7 @@ class DocumentEmbeddingsClient(RequestResponse):
|
|||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
return resp.documents
|
||||
return resp.chunks
|
||||
|
||||
class DocumentEmbeddingsClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
|
|||
docs = await self.query_document_embeddings(request)
|
||||
|
||||
logger.debug("Sending document embeddings query response...")
|
||||
r = DocumentEmbeddingsResponse(documents=docs, error=None)
|
||||
r = DocumentEmbeddingsResponse(chunks=docs, error=None)
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
|
||||
logger.debug("Document embeddings query request completed")
|
||||
|
|
@ -73,7 +73,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
|
|||
type = "document-embeddings-query-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
chunks=None,
|
||||
)
|
||||
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
|
|
|
|||
|
|
@ -12,22 +12,27 @@ logger = logging.getLogger(__name__)
|
|||
class Publisher:
|
||||
|
||||
def __init__(self, client, topic, schema=None, max_size=10,
|
||||
chunking_enabled=True):
|
||||
chunking_enabled=True, drain_timeout=5.0):
|
||||
self.client = client
|
||||
self.topic = topic
|
||||
self.schema = schema
|
||||
self.q = asyncio.Queue(maxsize=max_size)
|
||||
self.chunking_enabled = chunking_enabled
|
||||
self.running = True
|
||||
self.draining = False # New state for graceful shutdown
|
||||
self.task = None
|
||||
self.drain_timeout = drain_timeout
|
||||
|
||||
async def start(self):
    """Schedule run() as an asyncio task and keep a handle to it in self.task."""
    self.task = asyncio.create_task(self.run())
|
||||
|
||||
async def stop(self):
    """Request a draining shutdown and wait for the run task, if any."""
    # Leave normal operation and enter drain mode.
    self.running = False
    self.draining = True

    if not self.task:
        return

    # Returns once run() has finished draining.
    await self.task
|
||||
|
||||
async def join(self):
|
||||
|
|
@ -38,7 +43,7 @@ class Publisher:
|
|||
|
||||
async def run(self):
|
||||
|
||||
while self.running:
|
||||
while self.running or self.draining:
|
||||
|
||||
try:
|
||||
|
||||
|
|
@ -48,32 +53,71 @@ class Publisher:
|
|||
chunking_enabled=self.chunking_enabled,
|
||||
)
|
||||
|
||||
while self.running:
|
||||
drain_end_time = None
|
||||
|
||||
while self.running or self.draining:
|
||||
|
||||
try:
|
||||
# Start drain timeout when entering drain mode
|
||||
if self.draining and drain_end_time is None:
|
||||
drain_end_time = time.time() + self.drain_timeout
|
||||
logger.info(f"Publisher entering drain mode, timeout={self.drain_timeout}s")
|
||||
|
||||
# Check drain timeout
|
||||
if self.draining and drain_end_time and time.time() > drain_end_time:
|
||||
if not self.q.empty():
|
||||
logger.warning(f"Drain timeout reached with {self.q.qsize()} messages remaining")
|
||||
self.draining = False
|
||||
break
|
||||
|
||||
# Calculate wait timeout based on mode
|
||||
if self.draining:
|
||||
# Shorter timeout during draining to exit quickly when empty
|
||||
timeout = min(0.1, drain_end_time - time.time()) if drain_end_time else 0.1
|
||||
else:
|
||||
# Normal operation timeout
|
||||
timeout = 0.25
|
||||
|
||||
id, item = await asyncio.wait_for(
|
||||
self.q.get(),
|
||||
timeout=0.25
|
||||
timeout=timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
# If draining and queue is empty, we're done
|
||||
if self.draining and self.q.empty():
|
||||
logger.info("Publisher queue drained successfully")
|
||||
self.draining = False
|
||||
break
|
||||
continue
|
||||
except asyncio.QueueEmpty:
|
||||
# If draining and queue is empty, we're done
|
||||
if self.draining and self.q.empty():
|
||||
logger.info("Publisher queue drained successfully")
|
||||
self.draining = False
|
||||
break
|
||||
continue
|
||||
|
||||
if id:
|
||||
producer.send(item, { "id": id })
|
||||
else:
|
||||
producer.send(item)
|
||||
|
||||
# Flush producer before closing
|
||||
producer.flush()
|
||||
producer.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Exception in publisher: {e}", exc_info=True)
|
||||
|
||||
if not self.running:
|
||||
if not self.running and not self.draining:
|
||||
return
|
||||
|
||||
# If handler drops out, sleep a retry
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def send(self, id, item):
    """Queue ``(id, item)`` for publishing.

    Raises:
        RuntimeError: If the publisher is draining.
    """
    # New messages are refused once shutdown has started.
    if self.draining:
        raise RuntimeError("Publisher is shutting down, not accepting new messages")

    entry = (id, item)
    await self.q.put(entry)
|
||||
|
||||
|
|
|
|||
35
trustgraph-base/trustgraph/base/structured_query_client.py
Normal file
35
trustgraph-base/trustgraph/base/structured_query_client.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||
from .. schema import StructuredQueryRequest, StructuredQueryResponse
|
||||
|
||||
class StructuredQueryClient(RequestResponse):
    """Request/response client for the structured-query service."""

    async def structured_query(self, question, user="trustgraph", collection="default", timeout=600):
        """Send a structured query and return its result as a dict.

        Args:
            question: Question text placed in the request.
            user: User identifier placed in the request.
            collection: Collection identifier placed in the request.
            timeout: Timeout passed through to request().

        Returns:
            dict with "data", "errors" (empty list when the response has
            none) and "error" taken from the response.

        Raises:
            RuntimeError: If the response carries an error.
        """
        query = StructuredQueryRequest(
            question=question,
            user=user,
            collection=collection,
        )
        resp = await self.request(query, timeout=timeout)

        if resp.error:
            raise RuntimeError(resp.error.message)

        # Return the full response structure for the tool to handle
        return {
            "data": resp.data,
            "errors": resp.errors or [],
            "error": resp.error,
        }
|
||||
|
||||
class StructuredQueryClientSpec(RequestResponseSpec):
    """Spec pairing the structured-query schemas with StructuredQueryClient."""

    def __init__(
        self, request_name, response_name,
    ):
        """Configure the base spec with the given request/response names."""
        super().__init__(
            request_name=request_name,
            request_schema=StructuredQueryRequest,
            response_name=response_name,
            response_schema=StructuredQueryResponse,
            impl=StructuredQueryClient,
        )
|
||||
|
|
@ -8,6 +8,7 @@ import asyncio
|
|||
import _pulsar
|
||||
import time
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -15,7 +16,8 @@ logger = logging.getLogger(__name__)
|
|||
class Subscriber:
|
||||
|
||||
def __init__(self, client, topic, subscription, consumer_name,
|
||||
schema=None, max_size=100, metrics=None):
|
||||
schema=None, max_size=100, metrics=None,
|
||||
backpressure_strategy="block", drain_timeout=5.0):
|
||||
self.client = client
|
||||
self.topic = topic
|
||||
self.subscription = subscription
|
||||
|
|
@ -26,8 +28,12 @@ class Subscriber:
|
|||
self.max_size = max_size
|
||||
self.lock = asyncio.Lock()
|
||||
self.running = True
|
||||
self.draining = False # New state for graceful shutdown
|
||||
self.metrics = metrics
|
||||
self.task = None
|
||||
self.backpressure_strategy = backpressure_strategy
|
||||
self.drain_timeout = drain_timeout
|
||||
self.pending_acks = {} # Track messages awaiting delivery
|
||||
|
||||
self.consumer = None
|
||||
|
||||
|
|
@ -47,9 +53,12 @@ class Subscriber:
|
|||
self.task = asyncio.create_task(self.run())
|
||||
|
||||
async def stop(self):
    """Request a draining shutdown and wait for the run task, if any."""
    # Leave normal operation and enter drain mode.
    self.running = False
    self.draining = True

    if not self.task:
        return

    # Returns once run() has finished draining.
    await self.task
|
||||
|
||||
async def join(self):
|
||||
|
|
@ -59,8 +68,8 @@ class Subscriber:
|
|||
await self.task
|
||||
|
||||
async def run(self):
|
||||
|
||||
while self.running:
|
||||
"""Enhanced run method with integrated draining logic"""
|
||||
while self.running or self.draining:
|
||||
|
||||
if self.metrics:
|
||||
self.metrics.state("stopped")
|
||||
|
|
@ -71,65 +80,73 @@ class Subscriber:
|
|||
self.metrics.state("running")
|
||||
|
||||
logger.info("Subscriber running...")
|
||||
drain_end_time = None
|
||||
|
||||
while self.running:
|
||||
while self.running or self.draining:
|
||||
# Start drain timeout when entering drain mode
|
||||
if self.draining and drain_end_time is None:
|
||||
drain_end_time = time.time() + self.drain_timeout
|
||||
logger.info(f"Subscriber entering drain mode, timeout={self.drain_timeout}s")
|
||||
|
||||
# Stop accepting new messages from Pulsar during drain
|
||||
if self.consumer:
|
||||
self.consumer.pause_message_listener()
|
||||
|
||||
# Check drain timeout
|
||||
if self.draining and drain_end_time and time.time() > drain_end_time:
|
||||
async with self.lock:
|
||||
total_pending = sum(
|
||||
q.qsize() for q in
|
||||
list(self.q.values()) + list(self.full.values())
|
||||
)
|
||||
if total_pending > 0:
|
||||
logger.warning(f"Drain timeout reached with {total_pending} messages in queues")
|
||||
self.draining = False
|
||||
break
|
||||
|
||||
# Check if we can exit drain mode
|
||||
if self.draining:
|
||||
async with self.lock:
|
||||
all_empty = all(
|
||||
q.empty() for q in
|
||||
list(self.q.values()) + list(self.full.values())
|
||||
)
|
||||
if all_empty and len(self.pending_acks) == 0:
|
||||
logger.info("Subscriber queues drained successfully")
|
||||
self.draining = False
|
||||
break
|
||||
|
||||
# Process messages only if not draining
|
||||
if not self.draining:
|
||||
try:
|
||||
msg = await asyncio.to_thread(
|
||||
self.consumer.receive,
|
||||
timeout_millis=250
|
||||
)
|
||||
except _pulsar.Timeout:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Exception in subscriber receive: {e}", exc_info=True)
|
||||
raise e
|
||||
|
||||
try:
|
||||
msg = await asyncio.to_thread(
|
||||
self.consumer.receive,
|
||||
timeout_millis=250
|
||||
)
|
||||
except _pulsar.Timeout:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Exception in subscriber receive: {e}", exc_info=True)
|
||||
raise e
|
||||
if self.metrics:
|
||||
self.metrics.received()
|
||||
|
||||
if self.metrics:
|
||||
self.metrics.received()
|
||||
# Process the message with deferred acknowledgment
|
||||
await self._process_message(msg)
|
||||
else:
|
||||
# During draining, just wait for queues to empty
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Acknowledge successful reception of the message
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
try:
|
||||
id = msg.properties()["id"]
|
||||
except:
|
||||
id = None
|
||||
|
||||
value = msg.value()
|
||||
|
||||
async with self.lock:
|
||||
|
||||
# FIXME: Hard-coded timeouts
|
||||
|
||||
if id in self.q:
|
||||
|
||||
try:
|
||||
# FIXME: Timeout means data goes missing
|
||||
await asyncio.wait_for(
|
||||
self.q[id].put(value),
|
||||
timeout=1
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.metrics.dropped()
|
||||
logger.warning(f"Failed to put message in queue: {e}")
|
||||
|
||||
for q in self.full.values():
|
||||
try:
|
||||
# FIXME: Timeout means data goes missing
|
||||
await asyncio.wait_for(
|
||||
q.put(value),
|
||||
timeout=1
|
||||
)
|
||||
except Exception as e:
|
||||
self.metrics.dropped()
|
||||
logger.warning(f"Failed to put message in full queue: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Subscriber exception: {e}", exc_info=True)
|
||||
|
||||
finally:
|
||||
# Negative acknowledge any pending messages
|
||||
for msg in self.pending_acks.values():
|
||||
self.consumer.negative_acknowledge(msg)
|
||||
self.pending_acks.clear()
|
||||
|
||||
if self.consumer:
|
||||
self.consumer.unsubscribe()
|
||||
|
|
@ -140,7 +157,7 @@ class Subscriber:
|
|||
if self.metrics:
|
||||
self.metrics.state("stopped")
|
||||
|
||||
if not self.running:
|
||||
if not self.running and not self.draining:
|
||||
return
|
||||
|
||||
# If handler drops out, sleep a retry
|
||||
|
|
@ -180,3 +197,71 @@ class Subscriber:
|
|||
# self.full[id].shutdown(immediate=True)
|
||||
del self.full[id]
|
||||
|
||||
async def _process_message(self, msg):
    """Process a single message with deferred acknowledgment.

    The message is recorded in ``pending_acks`` while delivery is in
    progress. It is offered to the queue registered under its "id"
    property (if any) and to every queue in ``self.full``. It is
    acknowledged when at least one delivery succeeds and negatively
    acknowledged otherwise. The pending entry is removed only after the
    (negative) acknowledgment call returns.
    """
    # Store message for later acknowledgment
    msg_id = str(uuid.uuid4())
    self.pending_acks[msg_id] = msg

    # Messages without a usable "id" property are only delivered to the
    # "full" queues. Catch Exception rather than using a bare except so
    # KeyboardInterrupt/SystemExit are not swallowed.
    try:
        request_id = msg.properties()["id"]
    except Exception:
        request_id = None

    value = msg.value()
    delivered = False

    async with self.lock:
        # Deliver to specific subscribers
        if request_id in self.q:
            delivered = await self._deliver_to_queue(
                self.q[request_id], value
            )

        # Deliver to all subscribers
        for queue in self.full.values():
            if await self._deliver_to_queue(queue, value):
                delivered = True

    if delivered:
        # Acknowledge only on successful delivery
        self.consumer.acknowledge(msg)
    else:
        # Negative acknowledge for retry
        self.consumer.negative_acknowledge(msg)

    del self.pending_acks[msg_id]
|
||||
|
||||
async def _deliver_to_queue(self, queue, value):
    """Put ``value`` on ``queue`` according to the backpressure strategy.

    Strategies:
        "block": wait until the queue has room.
        "drop_oldest": discard the oldest entry when the queue is full,
            then enqueue.
        "drop_new": discard ``value`` when the queue is full.

    Returns:
        True when the value was enqueued, False when it was dropped or
        delivery raised. An unrecognised strategy falls through and
        returns None.
    """
    try:
        strategy = self.backpressure_strategy

        if strategy == "block":
            # Block until space available (no timeout)
            await queue.put(value)
            return True

        if strategy == "drop_oldest":
            if queue.full():
                # Make room by discarding the head of the queue.
                try:
                    queue.get_nowait()
                    if self.metrics:
                        self.metrics.dropped()
                except asyncio.QueueEmpty:
                    pass
            await queue.put(value)
            return True

        if strategy == "drop_new":
            if queue.full():
                # The incoming value is the one that is dropped.
                if self.metrics:
                    self.metrics.dropped()
                return False
            await queue.put(value)
            return True

    except Exception as e:
        logger.error(f"Failed to deliver message: {e}")
        return False
|
||||
|
||||
|
|
|
|||
|
|
@ -47,5 +47,5 @@ class DocumentEmbeddingsClient(BaseClient):
|
|||
return self.call(
|
||||
user=user, collection=collection,
|
||||
vectors=vectors, limit=limit, timeout=timeout
|
||||
).documents
|
||||
).chunks
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,11 @@ from .translators.embeddings_query import (
|
|||
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
|
||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
|
||||
)
|
||||
from .translators.objects_query import ObjectsQueryRequestTranslator, ObjectsQueryResponseTranslator
|
||||
from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator
|
||||
from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator
|
||||
from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
|
||||
from .translators.collection import CollectionManagementRequestTranslator, CollectionManagementResponseTranslator
|
||||
|
||||
# Register all service translators
|
||||
TranslatorRegistry.register_service(
|
||||
|
|
@ -107,6 +112,36 @@ TranslatorRegistry.register_service(
|
|||
GraphEmbeddingsResponseTranslator()
|
||||
)
|
||||
|
||||
# GraphQL queries against stored objects
TranslatorRegistry.register_service(
    "objects-query",
    ObjectsQueryRequestTranslator(),
    ObjectsQueryResponseTranslator()
)

# Natural-language question -> structured (GraphQL) query
TranslatorRegistry.register_service(
    "nlp-query",
    QuestionToStructuredQueryRequestTranslator(),
    QuestionToStructuredQueryResponseTranslator()
)

# Question answered against structured data
TranslatorRegistry.register_service(
    "structured-query",
    StructuredQueryRequestTranslator(),
    StructuredQueryResponseTranslator()
)

# Structured data diagnosis (type detection, descriptor generation)
TranslatorRegistry.register_service(
    "structured-diag",
    StructuredDataDiagnosisRequestTranslator(),
    StructuredDataDiagnosisResponseTranslator()
)

# Collection metadata management
TranslatorRegistry.register_service(
    "collection-management",
    CollectionManagementRequestTranslator(),
    CollectionManagementResponseTranslator()
)
|
||||
|
||||
# Register single-direction translators for document loading
|
||||
TranslatorRegistry.register_request("document", DocumentTranslator())
|
||||
TranslatorRegistry.register_request("text-document", TextDocumentTranslator())
|
||||
|
|
|
|||
|
|
@ -17,3 +17,5 @@ from .embeddings_query import (
|
|||
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
|
||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
|
||||
)
|
||||
from .objects_query import ObjectsQueryRequestTranslator, ObjectsQueryResponseTranslator
|
||||
from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
|
||||
|
|
|
|||
|
|
@ -9,17 +9,19 @@ class AgentRequestTranslator(MessageTranslator):
|
|||
def to_pulsar(self, data: Dict[str, Any]) -> AgentRequest:
    """Build an AgentRequest record from a request dict.

    ``question`` is required (KeyError if absent). ``state`` and ``group``
    default to None, ``history`` to an empty list and ``user`` to
    "trustgraph".
    """
    # Only the post-change argument list is kept; the superseded lines
    # repeated the `state` and `history` keywords, which is a syntax error.
    return AgentRequest(
        question=data["question"],
        state=data.get("state", None),
        group=data.get("group", None),
        history=data.get("history", []),
        user=data.get("user", "trustgraph")
    )
|
||||
|
||||
def from_pulsar(self, obj: AgentRequest) -> Dict[str, Any]:
    """Return the dict form of an AgentRequest record."""
    # Mirrors the fields set by to_pulsar; the superseded `plan` entry and
    # the comma-less `history` entry are dropped.
    return {
        "question": obj.question,
        "state": obj.state,
        "group": obj.group,
        "history": obj.history,
        "user": obj.user
    }
|
||||
|
||||
|
||||
|
|
|
|||
114
trustgraph-base/trustgraph/messaging/translators/collection.py
Normal file
114
trustgraph-base/trustgraph/messaging/translators/collection.py
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
from typing import Dict, Any, List
|
||||
from ...schema import CollectionManagementRequest, CollectionManagementResponse, CollectionMetadata, Error
|
||||
from .base import MessageTranslator
|
||||
|
||||
|
||||
class CollectionManagementRequestTranslator(MessageTranslator):
    """Converts collection-management requests between dicts and Pulsar records."""

    def to_pulsar(self, data: Dict[str, Any]) -> CollectionManagementRequest:
        """Build the request record; any key missing from ``data`` becomes None."""
        field_names = (
            "operation", "user", "collection", "timestamp", "name",
            "description", "tags", "created_at", "updated_at",
            "tag_filter", "limit",
        )
        return CollectionManagementRequest(
            **{key: data.get(key) for key in field_names}
        )

    def from_pulsar(self, obj: CollectionManagementRequest) -> Dict[str, Any]:
        """Return a dict holding only the fields of ``obj`` that are not None."""
        field_names = (
            "operation", "user", "collection", "timestamp", "name",
            "description", "tags", "created_at", "updated_at",
            "tag_filter", "limit",
        )
        # These two are arrays on the record and are copied into plain lists.
        array_fields = ("tags", "tag_filter")

        translated: Dict[str, Any] = {}
        for key in field_names:
            value = getattr(obj, key)
            if value is None:
                continue
            translated[key] = list(value) if key in array_fields else value

        return translated
|
||||
|
||||
|
||||
class CollectionManagementResponseTranslator(MessageTranslator):
    """Translator for CollectionManagementResponse schema objects"""

    def to_pulsar(self, data: Dict[str, Any]) -> CollectionManagementResponse:
        """Build a response record from its dict form.

        A missing or falsy ``error`` yields ``error=None``. A missing or
        None ``collections`` entry yields an empty collection list.
        """

        # Handle error
        error = None
        error_data = data.get("error")
        if error_data:
            error = Error(
                type=error_data.get("type"),
                message=error_data.get("message")
            )

        # Handle collections array; `or []` tolerates an explicit None,
        # which previously raised TypeError when iterated.
        collections = [
            CollectionMetadata(
                user=coll_data.get("user"),
                collection=coll_data.get("collection"),
                name=coll_data.get("name"),
                description=coll_data.get("description"),
                tags=coll_data.get("tags"),
                created_at=coll_data.get("created_at"),
                updated_at=coll_data.get("updated_at")
            )
            for coll_data in data.get("collections") or []
        ]

        return CollectionManagementResponse(
            error=error,
            timestamp=data.get("timestamp"),
            collections=collections
        )

    def from_pulsar(self, obj: CollectionManagementResponse) -> Dict[str, Any]:
        """Return the dict form of a response record.

        ``error``, ``timestamp`` and ``collections`` are each included only
        when not None. Collection tags are always emitted as a list (empty
        when unset).
        """
        # The unconditional debug print() calls that dumped every response
        # to stdout have been removed.
        result = {}

        if obj.error is not None:
            result["error"] = {
                "type": obj.error.type,
                "message": obj.error.message
            }
        if obj.timestamp is not None:
            result["timestamp"] = obj.timestamp
        if obj.collections is not None:
            result["collections"] = [
                {
                    "user": coll.user,
                    "collection": coll.collection,
                    "name": coll.name,
                    "description": coll.description,
                    "tags": list(coll.tags) if coll.tags else [],
                    "created_at": coll.created_at,
                    "updated_at": coll.updated_at
                }
                for coll in obj.collections
            ]

        return result
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
from typing import Dict, Any, Tuple
|
||||
import json
|
||||
from ...schema import StructuredDataDiagnosisRequest, StructuredDataDiagnosisResponse
|
||||
from .base import MessageTranslator
|
||||
|
||||
|
||||
class StructuredDataDiagnosisRequestTranslator(MessageTranslator):
    """Maps structured-data diagnosis requests between dicts and Pulsar records."""

    def to_pulsar(self, data: Dict[str, Any]) -> StructuredDataDiagnosisRequest:
        """Build the request record; ``operation`` and ``sample`` are required keys."""
        operation = data["operation"]
        sample = data["sample"]
        return StructuredDataDiagnosisRequest(
            operation=operation,
            sample=sample,
            type=data.get("type", ""),
            schema_name=data.get("schema-name", ""),
            options=data.get("options", {}),
        )

    def from_pulsar(self, obj: StructuredDataDiagnosisRequest) -> Dict[str, Any]:
        """Return the dict form; optional fields appear only when truthy."""
        translated = {
            "operation": obj.operation,
            "sample": obj.sample,
        }

        optional_fields = (
            ("type", obj.type),
            ("schema-name", obj.schema_name),
            ("options", obj.options),
        )
        for key, value in optional_fields:
            if value:
                translated[key] = value

        return translated
|
||||
|
||||
|
||||
class StructuredDataDiagnosisResponseTranslator(MessageTranslator):
    """Maps structured-data diagnosis responses from Pulsar records to dicts."""

    def to_pulsar(self, data: Dict[str, Any]) -> StructuredDataDiagnosisResponse:
        """Not supported for responses; always raises NotImplementedError."""
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: StructuredDataDiagnosisResponse) -> Dict[str, Any]:
        """Return the dict form of a diagnosis response.

        ``operation`` is always present; the remaining keys are added only
        when the corresponding record field carries a value.
        """
        payload = {"operation": obj.operation}

        detected_type = obj.detected_type
        if detected_type:
            payload["detected-type"] = detected_type

        confidence = obj.confidence
        if confidence is not None:
            payload["confidence"] = confidence

        raw_descriptor = obj.descriptor
        if raw_descriptor:
            # The descriptor travels JSON-encoded; pass it through unchanged
            # when it cannot be decoded.
            try:
                descriptor = json.loads(raw_descriptor)
            except (json.JSONDecodeError, TypeError):
                descriptor = raw_descriptor
            payload["descriptor"] = descriptor

        metadata = obj.metadata
        if metadata:
            payload["metadata"] = metadata

        schema_matches = obj.schema_matches
        if schema_matches is not None:
            payload["schema-matches"] = schema_matches

        return payload

    def from_response_with_completion(self, obj: StructuredDataDiagnosisResponse) -> Tuple[Dict[str, Any], bool]:
        """Return ``(response_dict, is_final)``; is_final is always True."""
        response = self.from_pulsar(obj)
        return response, True
|
||||
|
|
@ -36,10 +36,10 @@ class DocumentEmbeddingsResponseTranslator(MessageTranslator):
|
|||
def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]:
    """Return the dict form of a document-embeddings response.

    ``chunks`` is included whenever the field is not None (an empty list
    is preserved); byte strings are decoded as UTF-8.
    """
    # The superseded `obj.documents` branch is dropped: it was left
    # unterminated and the response schema exposes `chunks`.
    result = {}

    if obj.chunks is not None:
        result["chunks"] = [
            chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
            for chunk in obj.chunks
        ]

    return result
|
||||
|
|
@ -81,7 +81,7 @@ class GraphEmbeddingsResponseTranslator(MessageTranslator):
|
|||
def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]:
|
||||
result = {}
|
||||
|
||||
if obj.entities:
|
||||
if obj.entities is not None:
|
||||
result["entities"] = [
|
||||
self.value_translator.from_pulsar(entity)
|
||||
for entity in obj.entities
|
||||
|
|
@ -91,4 +91,4 @@ class GraphEmbeddingsResponseTranslator(MessageTranslator):
|
|||
|
||||
def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
    """Returns (response_dict, is_final)"""
    # Always reported as final; the duplicated, unreachable second return
    # has been removed.
    return self.from_pulsar(obj), True
|
||||
|
|
|
|||
|
|
@ -0,0 +1,47 @@
|
|||
from typing import Dict, Any, Tuple
|
||||
from ...schema import QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse
|
||||
from .base import MessageTranslator
|
||||
|
||||
|
||||
class QuestionToStructuredQueryRequestTranslator(MessageTranslator):
    """Maps question-to-structured-query requests between dicts and Pulsar records."""

    def to_pulsar(self, data: Dict[str, Any]) -> QuestionToStructuredQueryRequest:
        """Build the request record; defaults are "" for question, 100 for max_results."""
        question = data.get("question", "")
        max_results = data.get("max_results", 100)
        return QuestionToStructuredQueryRequest(
            question=question,
            max_results=max_results,
        )

    def from_pulsar(self, obj: QuestionToStructuredQueryRequest) -> Dict[str, Any]:
        """Return the dict form of the request record."""
        translated = {}
        translated["question"] = obj.question
        translated["max_results"] = obj.max_results
        return translated
|
||||
|
||||
|
||||
class QuestionToStructuredQueryResponseTranslator(MessageTranslator):
    """Maps question-to-structured-query responses from Pulsar records to dicts."""

    def to_pulsar(self, data: Dict[str, Any]) -> QuestionToStructuredQueryResponse:
        """Not supported for responses; always raises NotImplementedError."""
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: QuestionToStructuredQueryResponse) -> Dict[str, Any]:
        """Return the dict form of the response.

        ``variables`` and ``detected_schemas`` are copied into a plain dict
        and list (empty when unset). ``error`` is added only when set.
        """
        if obj.variables:
            variables = dict(obj.variables)
        else:
            variables = {}

        if obj.detected_schemas:
            detected_schemas = list(obj.detected_schemas)
        else:
            detected_schemas = []

        payload = {
            "graphql_query": obj.graphql_query,
            "variables": variables,
            "detected_schemas": detected_schemas,
            "confidence": obj.confidence,
        }

        # System-level error, reported alongside the query fields
        failure = obj.error
        if failure:
            payload["error"] = {
                "type": failure.type,
                "message": failure.message,
            }

        return payload

    def from_response_with_completion(self, obj: QuestionToStructuredQueryResponse) -> Tuple[Dict[str, Any], bool]:
        """Return ``(response_dict, is_final)``; is_final is always True."""
        response = self.from_pulsar(obj)
        return response, True
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
from typing import Dict, Any, Tuple, Optional
|
||||
from ...schema import ObjectsQueryRequest, ObjectsQueryResponse
|
||||
from .base import MessageTranslator
|
||||
import json
|
||||
|
||||
|
||||
class ObjectsQueryRequestTranslator(MessageTranslator):
    """Maps objects-query (GraphQL) requests between dicts and Pulsar records."""

    def to_pulsar(self, data: Dict[str, Any]) -> ObjectsQueryRequest:
        """Build the request record, filling defaults for absent keys."""
        arguments = {
            "user": data.get("user", "trustgraph"),
            "collection": data.get("collection", "default"),
            "query": data.get("query", ""),
            "variables": data.get("variables", {}),
            "operation_name": data.get("operation_name", None),
        }
        return ObjectsQueryRequest(**arguments)

    def from_pulsar(self, obj: ObjectsQueryRequest) -> Dict[str, Any]:
        """Return the dict form; ``operation_name`` appears only when truthy."""
        if obj.variables:
            variables = dict(obj.variables)
        else:
            variables = {}

        translated = {
            "user": obj.user,
            "collection": obj.collection,
            "query": obj.query,
            "variables": variables,
        }

        operation_name = obj.operation_name
        if operation_name:
            translated["operation_name"] = operation_name

        return translated
|
||||
|
||||
|
||||
class ObjectsQueryResponseTranslator(MessageTranslator):
    """Maps objects-query (GraphQL) responses from Pulsar records to dicts."""

    def to_pulsar(self, data: Dict[str, Any]) -> ObjectsQueryResponse:
        """Not supported for responses; always raises NotImplementedError."""
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: ObjectsQueryResponse) -> Dict[str, Any]:
        """Return the dict form of the response.

        ``data`` is always present (None when the record carries none) and
        is JSON-decoded when possible. ``errors``, ``extensions`` and
        ``error`` are added only when set.
        """
        payload = {}

        # GraphQL response data arrives JSON-encoded; fall back to the raw
        # string if it does not parse.
        raw_data = obj.data
        if not raw_data:
            payload["data"] = None
        else:
            try:
                payload["data"] = json.loads(raw_data)
            except json.JSONDecodeError:
                payload["data"] = raw_data

        # GraphQL field-level errors
        if obj.errors:
            field_errors = []
            for item in obj.errors:
                entry = {"message": item.message}
                if item.path:
                    entry["path"] = list(item.path)
                if item.extensions:
                    entry["extensions"] = dict(item.extensions)
                field_errors.append(entry)
            payload["errors"] = field_errors

        # Query-level extensions
        if obj.extensions:
            payload["extensions"] = dict(obj.extensions)

        # System-level error
        failure = obj.error
        if failure:
            payload["error"] = {
                "type": failure.type,
                "message": failure.message,
            }

        return payload

    def from_response_with_completion(self, obj: ObjectsQueryResponse) -> Tuple[Dict[str, Any], bool]:
        """Return ``(response_dict, is_final)``; is_final is always True."""
        response = self.from_pulsar(obj)
        return response, True
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
from typing import Dict, Any, Tuple
|
||||
from ...schema import StructuredQueryRequest, StructuredQueryResponse
|
||||
from .base import MessageTranslator
|
||||
import json
|
||||
|
||||
|
||||
class StructuredQueryRequestTranslator(MessageTranslator):
    """Maps structured-query requests between dicts and Pulsar records."""

    def to_pulsar(self, data: Dict[str, Any]) -> StructuredQueryRequest:
        """Build the request record, filling defaults for absent keys."""
        question = data.get("question", "")
        # Fallbacks used when the caller does not say whose data to query
        user = data.get("user", "trustgraph")
        collection = data.get("collection", "default")
        return StructuredQueryRequest(
            question=question,
            user=user,
            collection=collection,
        )

    def from_pulsar(self, obj: StructuredQueryRequest) -> Dict[str, Any]:
        """Return the dict form of the request record."""
        translated = {}
        translated["question"] = obj.question
        translated["user"] = obj.user
        translated["collection"] = obj.collection
        return translated
|
||||
|
||||
|
||||
class StructuredQueryResponseTranslator(MessageTranslator):
    """Maps structured-query responses from Pulsar records to dicts."""

    def to_pulsar(self, data: Dict[str, Any]) -> StructuredQueryResponse:
        """Not supported for responses; always raises NotImplementedError."""
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: StructuredQueryResponse) -> Dict[str, Any]:
        """Return the dict form of the response.

        ``data`` (JSON-decoded when possible, else raw, else None) and
        ``errors`` (a list, possibly empty) are always present; ``error``
        is added only when set.
        """
        payload = {}

        # Response data arrives JSON-encoded; fall back to the raw string
        # if it does not parse.
        raw_data = obj.data
        if not raw_data:
            payload["data"] = None
        else:
            try:
                payload["data"] = json.loads(raw_data)
            except json.JSONDecodeError:
                payload["data"] = raw_data

        # Query errors are plain strings
        payload["errors"] = list(obj.errors) if obj.errors else []

        # System-level error
        failure = obj.error
        if failure:
            payload["error"] = {
                "type": failure.type,
                "message": failure.message,
            }

        return payload

    def from_response_with_completion(self, obj: StructuredQueryResponse) -> Tuple[Dict[str, Any], bool]:
        """Return ``(response_dict, is_final)``; is_final is always True."""
        response = self.from_pulsar(obj)
        return response, True
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from pulsar.schema import Record, String, Map, Double
|
||||
from pulsar.schema import Record, String, Map, Double, Array
|
||||
|
||||
from ..core.metadata import Metadata
|
||||
from ..core.topic import topic
|
||||
|
|
@ -10,7 +10,7 @@ from ..core.topic import topic
|
|||
class ExtractedObject(Record):
    """Extracted objects belonging to a single schema."""
    metadata = Metadata()
    schema_name = String()  # Which schema this object belongs to
    # Only the Array declaration is kept; the superseded single-map
    # `values = Map(String())` line was immediately overwritten by it.
    values = Array(Map(String()))  # Array of objects, each object is field name -> value
    confidence = Double()
    source_span = String()  # Text span where object was found
|
||||
|
||||
|
|
|
|||
|
|
@ -8,4 +8,8 @@ from .config import *
|
|||
from .library import *
|
||||
from .lookup import *
|
||||
from .nlp_query import *
|
||||
from .structured_query import *
|
||||
from .structured_query import *
|
||||
from .objects_query import *
|
||||
from .diagnosis import *
|
||||
from .collection import *
|
||||
from .storage import *
|
||||
|
|
@ -13,12 +13,14 @@ class AgentStep(Record):
|
|||
action = String()
|
||||
arguments = Map(String())
|
||||
observation = String()
|
||||
user = String() # User context for the step
|
||||
|
||||
class AgentRequest(Record):
    """Request sent to the agent service."""
    # The superseded `plan` field is dropped; the request translator no
    # longer populates it.
    question = String()
    state = String()
    group = Array(String())
    history = Array(AgentStep())
    user = String()  # User context for multi-tenancy
|
||||
|
||||
class AgentResponse(Record):
|
||||
answer = String()
|
||||
|
|
|
|||
59
trustgraph-base/trustgraph/schema/services/collection.py
Normal file
59
trustgraph-base/trustgraph/schema/services/collection.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
from pulsar.schema import Record, String, Integer, Array
|
||||
from datetime import datetime
|
||||
|
||||
from ..core.primitives import Error
|
||||
from ..core.topic import topic
|
||||
|
||||
############################################################################
|
||||
|
||||
# Collection management operations
|
||||
|
||||
# Collection metadata operations (for librarian service)
|
||||
|
||||
class CollectionMetadata(Record):
    """Descriptive metadata held for one collection."""
    # Identity: owning user and collection identifier
    user = String()
    collection = String()
    # Descriptive fields
    name = String()
    description = String()
    tags = Array(String())
    # Both timestamps are ISO-formatted strings
    created_at = String()
    updated_at = String()
|
||||
|
||||
############################################################################
|
||||
|
||||
class CollectionManagementRequest(Record):
    """Request record for collection management operations."""
    # Name of the operation to perform, e.g. "delete-collection"
    operation = String()

    # Collection identity and metadata fields
    user = String()
    collection = String()
    timestamp = String()    # ISO timestamp
    name = String()
    description = String()
    tags = Array(String())
    created_at = String()   # ISO timestamp
    updated_at = String()   # ISO timestamp

    # Listing options
    tag_filter = Array(String())    # Optional filter by tags
    limit = Integer()
|
||||
|
||||
class CollectionManagementResponse(Record):
    """Response record for collection management operations."""
    # Populated only when the operation failed
    error = Error()
    # ISO timestamp
    timestamp = String()
    # Metadata records returned by the operation
    collections = Array(CollectionMetadata())
|
||||
|
||||
|
||||
############################################################################
|
||||
|
||||
# Topics
|
||||
|
||||
# Request / response topic pair for the collection management service
collection_request_queue = topic(
    'collection', kind='non-persistent', namespace='request'
)
collection_response_queue = topic(
    'collection', kind='non-persistent', namespace='response'
)
|
||||
33
trustgraph-base/trustgraph/schema/services/diagnosis.py
Normal file
33
trustgraph-base/trustgraph/schema/services/diagnosis.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
from pulsar.schema import Record, String, Map, Double, Array
|
||||
from ..core.primitives import Error
|
||||
|
||||
############################################################################
|
||||
|
||||
# Structured data diagnosis services
|
||||
|
||||
class StructuredDataDiagnosisRequest(Record):
    """Request record for the structured data diagnosis service."""
    # One of "detect-type", "generate-descriptor", "diagnose" or
    # "schema-selection"
    operation = String()
    # Data sample to analyze (text content)
    sample = String()
    # Data type (csv, json, xml); optional, required for generate-descriptor
    type = String()
    # Target schema name for descriptor generation; optional
    schema_name = String()
    # JSON encoded options (e.g., delimiter for CSV)
    options = Map(String())
|
||||
|
||||
class StructuredDataDiagnosisResponse(Record):
    """Response record for the structured data diagnosis service."""
    error = Error()

    # The operation that was performed
    operation = String()
    # Detected data type (for detect-type/diagnose); optional
    detected_type = String()
    # Confidence score for type detection; optional
    confidence = Double()
    # JSON encoded descriptor (for generate-descriptor/diagnose); optional
    descriptor = String()
    # JSON encoded additional metadata (e.g., field count, sample records)
    metadata = Map(String())
    # Matching schema IDs (for the schema-selection operation); optional
    schema_matches = Array(String())
|
||||
|
||||
############################################################################
|
||||
|
|
@ -7,16 +7,15 @@ from ..core.topic import topic
|
|||
|
||||
# NLP to Structured Query Service - converts natural language to GraphQL
|
||||
|
||||
class QuestionToStructuredQueryRequest(Record):
    """Request to turn a natural-language question into a GraphQL query."""
    # Only the renamed class and its current fields are kept; the old
    # NLPToStructuredQueryRequest header and the `natural_language_query`
    # and `context_hints` fields belonged to the superseded definition.
    question = String()
    max_results = Integer()
|
||||
|
||||
class QuestionToStructuredQueryResponse(Record):
    """Generated GraphQL query for a natural-language question."""
    # The superseded NLPToStructuredQueryResponse header (an empty class
    # statement, which is a syntax error) is dropped in favour of the
    # renamed class.
    error = Error()
    graphql_query = String()  # Generated GraphQL query
    variables = Map(String())  # GraphQL variables if any
    detected_schemas = Array(String())  # Which schemas the query targets
    confidence = Double()
|
||||
|
||||
############################################################################
|
||||
############################################################################
|
||||
|
|
|
|||
28
trustgraph-base/trustgraph/schema/services/objects_query.py
Normal file
28
trustgraph-base/trustgraph/schema/services/objects_query.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
from pulsar.schema import Record, String, Map, Array
|
||||
|
||||
from ..core.primitives import Error
|
||||
from ..core.topic import topic
|
||||
|
||||
############################################################################
|
||||
|
||||
# Objects Query Service - executes GraphQL queries against structured data
|
||||
|
||||
class GraphQLError(Record):
    """A single GraphQL field-level error."""
    message = String()
    # Path to the field that caused the error
    path = Array(String())
    # Additional error metadata
    extensions = Map(String())
|
||||
|
||||
class ObjectsQueryRequest(Record):
    """Request to run a GraphQL query against structured data."""
    # Cassandra keyspace (follows pattern from TriplesQueryRequest)
    user = String()
    # Data collection identifier (required for partition key)
    collection = String()
    # GraphQL query string
    query = String()
    # GraphQL variables
    variables = Map(String())
    # Operation to execute for multi-operation documents
    operation_name = String()
|
||||
|
||||
class ObjectsQueryResponse(Record):
    """Result of a GraphQL query against structured data."""
    # System-level error (connection, timeout, etc.)
    error = Error()
    # JSON-encoded GraphQL response data
    data = String()
    # GraphQL field-level errors
    errors = Array(GraphQLError())
    # Query metadata (execution time, etc.)
    extensions = Map(String())
|
||||
|
||||
############################################################################
|
||||
|
|
@ -45,4 +45,11 @@ class DocumentEmbeddingsRequest(Record):
|
|||
|
||||
class DocumentEmbeddingsResponse(Record):
    """Response to a document-embeddings query."""
    error = Error()
    # Declared once; the repeated `chunks` line was a redundant
    # re-assignment of the same field.
    chunks = Array(String())
|
||||
|
||||
# Request / response topics for the document-embeddings query service
document_embeddings_request_queue = topic(
    "non-persistent://trustgraph/document-embeddings-request"
)
document_embeddings_response_queue = topic(
    "non-persistent://trustgraph/document-embeddings-response"
)
|
||||
42
trustgraph-base/trustgraph/schema/services/storage.py
Normal file
42
trustgraph-base/trustgraph/schema/services/storage.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
from pulsar.schema import Record, String
|
||||
|
||||
from ..core.primitives import Error
|
||||
from ..core.topic import topic
|
||||
|
||||
############################################################################
|
||||
|
||||
# Storage management operations
|
||||
|
||||
class StorageManagementRequest(Record):
    """Storage management request delivered to store processors."""
    # Operation name, e.g. "delete-collection"
    operation = String()
    # User and collection the operation applies to
    user = String()
    collection = String()
|
||||
|
||||
class StorageManagementResponse(Record):
    """Reply from a storage processor to a management operation."""
    # Set only on failure; a null error means success
    error = Error()
|
||||
|
||||
############################################################################
|
||||
|
||||
# Storage management topics
|
||||
|
||||
# Topics for sending collection management requests to different storage types
|
||||
# One request topic per storage type
vector_storage_management_topic = topic(
    'vector-storage-management', kind='non-persistent', namespace='request'
)
object_storage_management_topic = topic(
    'object-storage-management', kind='non-persistent', namespace='request'
)
triples_storage_management_topic = topic(
    'triples-storage-management', kind='non-persistent', namespace='request'
)

# Shared topic on which storage processors send their responses
storage_management_response_topic = topic(
    'storage-management', kind='non-persistent', namespace='response'
)
|
||||
|
||||
############################################################################
|
||||
|
|
@ -8,13 +8,13 @@ from ..core.topic import topic
|
|||
# Structured Query Service - executes GraphQL queries
|
||||
|
||||
class StructuredQueryRequest(Record):
    """Request to answer a question against structured data."""
    # The superseded `query`, `variables` and `operation_name` fields are
    # dropped; the request translator populates only the fields below.
    question = String()
    user = String()  # Cassandra keyspace identifier
    collection = String()  # Data collection identifier
|
||||
|
||||
class StructuredQueryResponse(Record):
    """Result of a structured query."""
    error = Error()
    # JSON-encoded GraphQL response data
    data = String()
    # GraphQL errors if any
    errors = Array(String())
|
||||
|
||||
############################################################################
|
||||
############################################################################
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue