mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-04 04:15:14 +02:00
parent
a8e437fc7f
commit
6c7af8789d
216 changed files with 31360 additions and 1611 deletions
|
|
@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
|
|||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
dependencies = [
|
||||
"trustgraph-base>=1.2,<1.3",
|
||||
"trustgraph-base>=1.4,<1.5",
|
||||
"aiohttp",
|
||||
"anthropic",
|
||||
"cassandra-driver",
|
||||
|
|
@ -40,6 +40,7 @@ dependencies = [
|
|||
"qdrant-client",
|
||||
"rdflib",
|
||||
"requests",
|
||||
"strawberry-graphql",
|
||||
"tabulate",
|
||||
"tiktoken",
|
||||
"urllib3",
|
||||
|
|
@ -86,14 +87,16 @@ kg-store = "trustgraph.storage.knowledge:run"
|
|||
librarian = "trustgraph.librarian:run"
|
||||
mcp-tool = "trustgraph.agent.mcp_tool:run"
|
||||
metering = "trustgraph.metering:run"
|
||||
nlp-query = "trustgraph.retrieval.nlp_query:run"
|
||||
objects-write-cassandra = "trustgraph.storage.objects.cassandra:run"
|
||||
oe-write-milvus = "trustgraph.storage.object_embeddings.milvus:run"
|
||||
objects-query-cassandra = "trustgraph.query.objects.cassandra:run"
|
||||
pdf-decoder = "trustgraph.decoding.pdf:run"
|
||||
pdf-ocr-mistral = "trustgraph.decoding.mistral_ocr:run"
|
||||
prompt-template = "trustgraph.prompt.template:run"
|
||||
rev-gateway = "trustgraph.rev_gateway:run"
|
||||
rows-write-cassandra = "trustgraph.storage.rows.cassandra:run"
|
||||
run-processing = "trustgraph.processing:run"
|
||||
structured-query = "trustgraph.retrieval.structured_query:run"
|
||||
structured-diag = "trustgraph.retrieval.structured_diag:run"
|
||||
text-completion-azure = "trustgraph.model.text_completion.azure:run"
|
||||
text-completion-azure-openai = "trustgraph.model.text_completion.azure_openai:run"
|
||||
text-completion-claude = "trustgraph.model.text_completion.claude:run"
|
||||
|
|
|
|||
|
|
@ -269,13 +269,7 @@ class AgentManager:
|
|||
|
||||
logger.debug(f"TOOL>>> {act}")
|
||||
|
||||
# Instantiate the tool implementation with context and config
|
||||
if action.config:
|
||||
tool_instance = action.implementation(context, **action.config)
|
||||
else:
|
||||
tool_instance = action.implementation(context)
|
||||
|
||||
resp = await tool_instance.invoke(
|
||||
resp = await action.implementation(context).invoke(
|
||||
**act.arguments
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,12 +12,13 @@ import logging
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
|
||||
from ... base import GraphRagClientSpec, ToolClientSpec
|
||||
from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
|
||||
|
||||
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
|
||||
|
||||
from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl
|
||||
from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl
|
||||
from . agent_manager import AgentManager
|
||||
from ..tool_filter import validate_tool_config, filter_tools_by_group_and_state, get_next_state
|
||||
|
||||
from . types import Final, Action, Tool, Argument
|
||||
|
||||
|
|
@ -79,6 +80,13 @@ class Processor(AgentService):
|
|||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
StructuredQueryClientSpec(
|
||||
request_name = "structured-query-request",
|
||||
response_name = "structured-query-response",
|
||||
)
|
||||
)
|
||||
|
||||
async def on_tools_config(self, config, version):
|
||||
|
||||
logger.info(f"Loading configuration version {version}")
|
||||
|
|
@ -137,11 +145,21 @@ class Processor(AgentService):
|
|||
template_id=data.get("template"),
|
||||
arguments=arguments
|
||||
)
|
||||
elif impl_id == "structured-query":
|
||||
impl = functools.partial(
|
||||
StructuredQueryImpl,
|
||||
collection=data.get("collection"),
|
||||
user=None # User will be provided dynamically via context
|
||||
)
|
||||
arguments = StructuredQueryImpl.get_arguments()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Tool type {impl_id} not known"
|
||||
)
|
||||
|
||||
# Validate tool configuration
|
||||
validate_tool_config(data)
|
||||
|
||||
tools[name] = Tool(
|
||||
name=name,
|
||||
description=data.get("description"),
|
||||
|
|
@ -219,14 +237,43 @@ class Processor(AgentService):
|
|||
|
||||
await respond(r)
|
||||
|
||||
# Apply tool filtering based on request groups and state
|
||||
filtered_tools = filter_tools_by_group_and_state(
|
||||
tools=self.agent.tools,
|
||||
requested_groups=getattr(request, 'group', None),
|
||||
current_state=getattr(request, 'state', None)
|
||||
)
|
||||
|
||||
logger.info(f"Filtered from {len(self.agent.tools)} to {len(filtered_tools)} available tools")
|
||||
|
||||
# Create temporary agent with filtered tools
|
||||
temp_agent = AgentManager(
|
||||
tools=filtered_tools,
|
||||
additional_context=self.agent.additional_context
|
||||
)
|
||||
|
||||
logger.debug("Call React")
|
||||
|
||||
act = await self.agent.react(
|
||||
# Create user-aware context wrapper that preserves the flow interface
|
||||
# but adds user information for tools that need it
|
||||
class UserAwareContext:
    """Callable wrapper around a flow that tags structured-query clients.

    Calling an instance forwards the service name to the wrapped flow and
    returns whatever the flow returns.  When the service name is
    "structured-query-request", the returned client is first given a
    ``_current_user`` attribute holding the user passed at construction.
    """

    def __init__(self, flow, user):
        self._flow, self._user = flow, user

    def __call__(self, service_name):
        resolved = self._flow(service_name)
        if service_name != "structured-query-request":
            return resolved
        # Only structured query clients carry the user context
        resolved._current_user = self._user
        return resolved
|
||||
|
||||
act = await temp_agent.react(
|
||||
question = request.question,
|
||||
history = history,
|
||||
think = think,
|
||||
observe = observe,
|
||||
context = flow,
|
||||
context = UserAwareContext(flow, request.user),
|
||||
)
|
||||
|
||||
logger.debug(f"Action: {act}")
|
||||
|
|
@ -255,11 +302,17 @@ class Processor(AgentService):
|
|||
logger.debug("Send next...")
|
||||
|
||||
history.append(act)
|
||||
|
||||
# Handle state transitions if tool execution was successful
|
||||
next_state = request.state
|
||||
if act.name in filtered_tools:
|
||||
executed_tool = filtered_tools[act.name]
|
||||
next_state = get_next_state(executed_tool, request.state or "undefined")
|
||||
|
||||
r = AgentRequest(
|
||||
question=request.question,
|
||||
plan=request.plan,
|
||||
state=request.state,
|
||||
state=next_state,
|
||||
group=getattr(request, 'group', []),
|
||||
history=[
|
||||
AgentStep(
|
||||
thought=h.thought,
|
||||
|
|
|
|||
|
|
@ -85,6 +85,49 @@ class McpToolImpl:
|
|||
return json.dumps(output)
|
||||
|
||||
|
||||
# This tool implementation knows how to query structured data using natural language
|
||||
class StructuredQueryImpl:
    """Agent tool that answers natural-language questions over structured data.

    The tool looks up the "structured-query-request" client from the
    supplied context and forwards the question to it.
    """

    def __init__(self, context, collection=None, user=None):
        """
        Args:
            context: Callable mapping a service name to a client object
            collection: Collection to query ("default" is used when unset)
            user: Fallback user when the client carries no user of its own
        """
        self.context = context
        self.collection = collection  # For multi-tenant scenarios
        self.user = user  # User context for multi-tenancy

    @staticmethod
    def get_arguments():
        """Return the argument descriptors this tool exposes to the agent."""
        return [
            Argument(
                name="question",
                type="string",
                description="Natural language question about structured data (tables, databases, etc.)"
            )
        ]

    async def invoke(self, **arguments):
        """Run the structured query and format the outcome as text.

        Args:
            **arguments: Tool arguments; only "question" is read.

        Returns:
            A string: "Error: ..." when the result reports an error,
            pretty-printed JSON when it carries data, "No data returned"
            when it carries neither, or ``str(result)`` when the result is
            not a dictionary.
        """
        client = self.context("structured-query-request")
        logger.debug("Structured query question...")

        # Resolve the user in priority order: user attached to the client,
        # then the instance user, then the default.  Using `or` (rather than
        # a getattr default) also covers a client whose _current_user was
        # set to None or "" - previously that value was passed through.
        user = (
            getattr(client, '_current_user', None)
            or self.user
            or "trustgraph"
        )

        result = await client.structured_query(
            question=arguments.get("question"),
            user=user,
            collection=self.collection or "default"
        )

        # Format the result for the agent
        if not isinstance(result, dict):
            return str(result)

        error = result.get("error")
        if error:
            # Tolerate errors that are not {"message": ...} dictionaries
            # instead of raising while reporting the failure.
            message = error.get("message") if isinstance(error, dict) else None
            if message is None:
                message = error
            return f"Error: {message}"

        if result.get("data"):
            # Pretty format JSON data for agent consumption
            return json.dumps(result["data"], indent=2)

        return "No data returned"
|
||||
|
||||
|
||||
# This tool implementation knows how to execute prompt templates
|
||||
class PromptImpl:
|
||||
def __init__(self, context, template_id, arguments=None):
|
||||
|
|
|
|||
165
trustgraph-flow/trustgraph/agent/tool_filter.py
Normal file
165
trustgraph-flow/trustgraph/agent/tool_filter.py
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
"""
|
||||
Tool filtering logic for the TrustGraph tool group system.
|
||||
|
||||
Provides functions to filter available tools based on group membership
|
||||
and execution state as defined in the tool-group tech spec.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def filter_tools_by_group_and_state(
    tools: Dict[str, Any],
    requested_groups: Optional[List[str]] = None,
    current_state: Optional[str] = None
) -> Dict[str, Any]:
    """
    Filter tools based on group membership and execution state.

    Args:
        tools: Dictionary of tool_name -> tool_object
        requested_groups: List of groups requested (defaults to ["default"]
            when None)
        current_state: Current execution state (None or "" is treated as
            "undefined")

    Returns:
        New dictionary of the tools that match both group and state
        criteria; the input dictionary is left untouched
    """

    # Apply defaults as specified in tech spec
    if requested_groups is None:
        requested_groups = ["default"]
    if current_state is None or current_state == "":
        current_state = "undefined"

    logger.info(f"Filtering tools with groups={requested_groups}, state={current_state}")

    filtered_tools = {}

    for tool_name, tool in tools.items():
        if _is_tool_available(tool, requested_groups, current_state):
            filtered_tools[tool_name] = tool
        else:
            logger.debug(f"Tool {tool_name} filtered out")

    logger.info(f"Filtered {len(tools)} tools to {len(filtered_tools)} available tools")
    return filtered_tools


def _is_tool_available(
    tool: Any,
    requested_groups: List[str],
    current_state: str
) -> bool:
    """
    Check if a tool is available based on group and state criteria.

    Args:
        tool: Tool object; its optional ``config`` attribute is read for
            the 'group' and 'applicable-states' keys
        requested_groups: List of requested groups
        current_state: Current execution state

    Returns:
        True if tool should be available, False otherwise
    """

    # Extract tool configuration.  A tool may have no config attribute at
    # all or an explicit config of None; both mean "no metadata", matching
    # the handling in get_next_state.
    config = getattr(tool, 'config', None) or {}

    # Get tool groups (default to ["default"] if not specified)
    tool_groups = config.get('group', ["default"])
    if not isinstance(tool_groups, list):
        tool_groups = [tool_groups]

    # Get tool applicable states (default to all states if not specified)
    applicable_states = config.get('applicable-states', ["*"])
    if not isinstance(applicable_states, list):
        applicable_states = [applicable_states]

    # Apply group filtering logic from tech spec:
    # Tool is available if intersection(tool_groups, requested_groups) is not empty
    # OR "*" is in requested_groups (wildcard access)
    group_match = (
        "*" in requested_groups or
        bool(set(tool_groups) & set(requested_groups))
    )

    # Apply state filtering logic from tech spec:
    # Tool is available if current_state is in applicable_states
    # OR "*" is in applicable_states (available in all states)
    state_match = (
        "*" in applicable_states or
        current_state in applicable_states
    )

    is_available = group_match and state_match

    # Guarded so the long message is only built when it will be emitted
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(
            f"Tool availability check: tool_groups={tool_groups}, "
            f"requested_groups={requested_groups}, applicable_states={applicable_states}, "
            f"current_state={current_state}, group_match={group_match}, "
            f"state_match={state_match}, is_available={is_available}"
        )

    return is_available
|
||||
|
||||
|
||||
def get_next_state(tool: Any, current_state: str) -> str:
    """
    Work out which state follows a successful run of ``tool``.

    Args:
        tool: Tool object; its optional ``config`` attribute is consulted
            for a 'state' entry
        current_state: State the execution is in right now

    Returns:
        The tool's configured 'state' when it is set to a non-empty value,
        otherwise ``current_state`` unchanged
    """
    tool_config = getattr(tool, 'config', {})
    # A config of None is treated the same as an empty configuration
    target = ({} if tool_config is None else tool_config).get('state')

    if not target:
        logger.debug(f"No state transition defined, staying in {current_state}")
        return current_state

    logger.debug(f"State transition: {current_state} -> {target}")
    return target
|
||||
|
||||
|
||||
def validate_tool_config(config: Dict[str, Any]) -> None:
    """
    Check the group and state related fields of a tool configuration.

    Fields are checked in the order 'group', 'state', 'applicable-states';
    fields that are absent are not checked.

    Args:
        config: Tool configuration dictionary

    Raises:
        ValueError: If 'group' or 'applicable-states' is not a list of
            strings, or 'state' is not a string
    """

    def _check_string_list(field, not_a_list_message, bad_item_message):
        # Shared rule for the two list-valued fields
        if field not in config:
            return
        values = config[field]
        if not isinstance(values, list):
            raise ValueError(not_a_list_message)
        if any(not isinstance(item, str) for item in values):
            raise ValueError(bad_item_message)

    # Validate group field
    _check_string_list(
        'group',
        "Tool 'group' field must be a list of strings",
        "All group names must be strings",
    )

    # Validate state field
    if 'state' in config and not isinstance(config['state'], str):
        raise ValueError("Tool 'state' field must be a string")

    # Validate applicable-states field
    _check_string_list(
        'applicable-states',
        "Tool 'applicable-states' field must be a list of strings",
        "All state names must be strings",
    )
|
||||
|
|
@ -45,13 +45,13 @@ class Configuration:
|
|||
|
||||
# FIXME: Some version vs config race conditions
|
||||
|
||||
def __init__(self, push, host, user, password, keyspace):
|
||||
def __init__(self, push, host, username, password, keyspace):
|
||||
|
||||
# External function to respond to update
|
||||
self.push = push
|
||||
|
||||
self.table_store = ConfigTableStore(
|
||||
host, user, password, keyspace
|
||||
host, username, password, keyspace
|
||||
)
|
||||
|
||||
async def inc_version(self):
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from trustgraph.schema import FlowRequest, FlowResponse
|
|||
from trustgraph.schema import flow_request_queue, flow_response_queue
|
||||
|
||||
from trustgraph.base import AsyncProcessor, Consumer, Producer
|
||||
from trustgraph.base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||
|
||||
from . config import Configuration
|
||||
from . flow import FlowConfig
|
||||
|
|
@ -60,9 +61,21 @@ class Processor(AsyncProcessor):
|
|||
"flow_response_queue", default_flow_response_queue
|
||||
)
|
||||
|
||||
cassandra_host = params.get("cassandra_host", default_cassandra_host)
|
||||
cassandra_user = params.get("cassandra_user")
|
||||
cassandra_host = params.get("cassandra_host")
|
||||
cassandra_username = params.get("cassandra_username")
|
||||
cassandra_password = params.get("cassandra_password")
|
||||
|
||||
# Resolve configuration with environment variable fallback
|
||||
hosts, username, password = resolve_cassandra_config(
|
||||
host=cassandra_host,
|
||||
username=cassandra_username,
|
||||
password=cassandra_password
|
||||
)
|
||||
|
||||
# Store resolved configuration
|
||||
self.cassandra_host = hosts
|
||||
self.cassandra_username = username
|
||||
self.cassandra_password = password
|
||||
|
||||
id = params.get("id")
|
||||
|
||||
|
|
@ -76,8 +89,9 @@ class Processor(AsyncProcessor):
|
|||
"config_push_schema": ConfigPush.__name__,
|
||||
"flow_request_schema": FlowRequest.__name__,
|
||||
"flow_response_schema": FlowResponse.__name__,
|
||||
"cassandra_host": cassandra_host,
|
||||
"cassandra_user": cassandra_user,
|
||||
"cassandra_host": self.cassandra_host,
|
||||
"cassandra_username": self.cassandra_username,
|
||||
"cassandra_password": self.cassandra_password,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -142,9 +156,9 @@ class Processor(AsyncProcessor):
|
|||
)
|
||||
|
||||
self.config = Configuration(
|
||||
host = cassandra_host.split(","),
|
||||
user = cassandra_user,
|
||||
password = cassandra_password,
|
||||
host = self.cassandra_host,
|
||||
username = self.cassandra_username,
|
||||
password = self.cassandra_password,
|
||||
keyspace = keyspace,
|
||||
push = self.push
|
||||
)
|
||||
|
|
@ -276,23 +290,7 @@ class Processor(AsyncProcessor):
|
|||
help=f'Flow response queue {default_flow_response_queue}',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--cassandra-host',
|
||||
default="cassandra",
|
||||
help=f'Graph host (default: cassandra)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--cassandra-user',
|
||||
default=None,
|
||||
help=f'Cassandra user'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--cassandra-password',
|
||||
default=None,
|
||||
help=f'Cassandra password'
|
||||
)
|
||||
add_cassandra_args(parser)
|
||||
|
||||
def run():
|
||||
|
||||
|
|
|
|||
|
|
@ -16,12 +16,12 @@ logger = logging.getLogger(__name__)
|
|||
class KnowledgeManager:
|
||||
|
||||
def __init__(
|
||||
self, cassandra_host, cassandra_user, cassandra_password,
|
||||
self, cassandra_host, cassandra_username, cassandra_password,
|
||||
keyspace, flow_config,
|
||||
):
|
||||
|
||||
self.table_store = KnowledgeTableStore(
|
||||
cassandra_host, cassandra_user, cassandra_password, keyspace
|
||||
cassandra_host, cassandra_username, cassandra_password, keyspace
|
||||
)
|
||||
|
||||
self.loader_queue = asyncio.Queue(maxsize=20)
|
||||
|
|
@ -248,6 +248,9 @@ class KnowledgeManager:
|
|||
await ge_pub.start()
|
||||
|
||||
async def publish_triples(t):
|
||||
# Override collection with request collection
|
||||
if hasattr(t, 'metadata') and hasattr(t.metadata, 'collection'):
|
||||
t.metadata.collection = request.collection or "default"
|
||||
await t_pub.send(None, t)
|
||||
|
||||
logger.debug("Publishing triples...")
|
||||
|
|
@ -260,6 +263,9 @@ class KnowledgeManager:
|
|||
)
|
||||
|
||||
async def publish_ge(g):
|
||||
# Override collection with request collection
|
||||
if hasattr(g, 'metadata') and hasattr(g.metadata, 'collection'):
|
||||
g.metadata.collection = request.collection or "default"
|
||||
await ge_pub.send(None, g)
|
||||
|
||||
logger.debug("Publishing graph embeddings...")
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import logging
|
|||
|
||||
from .. base import AsyncProcessor, Consumer, Producer, Publisher, Subscriber
|
||||
from .. base import ConsumerMetrics, ProducerMetrics
|
||||
from .. base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||
|
||||
from .. schema import KnowledgeRequest, KnowledgeResponse, Error
|
||||
from .. schema import knowledge_request_queue, knowledge_response_queue
|
||||
|
|
@ -49,16 +50,29 @@ class Processor(AsyncProcessor):
|
|||
"knowledge_response_queue", default_knowledge_response_queue
|
||||
)
|
||||
|
||||
cassandra_host = params.get("cassandra_host", default_cassandra_host)
|
||||
cassandra_user = params.get("cassandra_user")
|
||||
cassandra_host = params.get("cassandra_host")
|
||||
cassandra_username = params.get("cassandra_username")
|
||||
cassandra_password = params.get("cassandra_password")
|
||||
|
||||
# Resolve configuration with environment variable fallback
|
||||
hosts, username, password = resolve_cassandra_config(
|
||||
host=cassandra_host,
|
||||
username=cassandra_username,
|
||||
password=cassandra_password
|
||||
)
|
||||
|
||||
# Store resolved configuration
|
||||
self.cassandra_host = hosts
|
||||
self.cassandra_username = username
|
||||
self.cassandra_password = password
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"knowledge_request_queue": knowledge_request_queue,
|
||||
"knowledge_response_queue": knowledge_response_queue,
|
||||
"cassandra_host": cassandra_host,
|
||||
"cassandra_user": cassandra_user,
|
||||
"cassandra_host": self.cassandra_host,
|
||||
"cassandra_username": self.cassandra_username,
|
||||
"cassandra_password": self.cassandra_password,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -89,9 +103,9 @@ class Processor(AsyncProcessor):
|
|||
)
|
||||
|
||||
self.knowledge = KnowledgeManager(
|
||||
cassandra_host = cassandra_host.split(","),
|
||||
cassandra_user = cassandra_user,
|
||||
cassandra_password = cassandra_password,
|
||||
cassandra_host = self.cassandra_host,
|
||||
cassandra_username = self.cassandra_username,
|
||||
cassandra_password = self.cassandra_password,
|
||||
keyspace = keyspace,
|
||||
flow_config = self,
|
||||
)
|
||||
|
|
@ -210,23 +224,7 @@ class Processor(AsyncProcessor):
|
|||
help=f'Config response queue {default_knowledge_response_queue}',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--cassandra-host',
|
||||
default="cassandra",
|
||||
help=f'Graph host (default: cassandra)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--cassandra-user',
|
||||
default=None,
|
||||
help=f'Cassandra user'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--cassandra-password',
|
||||
default=None,
|
||||
help=f'Cassandra password'
|
||||
)
|
||||
add_cassandra_args(parser)
|
||||
|
||||
def run():
|
||||
|
||||
|
|
|
|||
|
|
@ -1,137 +0,0 @@
|
|||
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.auth import PlainTextAuthProvider
|
||||
from ssl import SSLContext, PROTOCOL_TLSv1_2
|
||||
|
||||
# Global list to track clusters for cleanup
|
||||
_active_clusters = []
|
||||
|
||||
class TrustGraph:
    """Triple store kept in a single Cassandra table with columns (s, p, o).

    Constructing an instance connects to the cluster, records the cluster
    in the module-level ``_active_clusters`` list and creates the keyspace,
    table and indexes if they do not exist yet.
    """

    def __init__(
            self, hosts=None,
            keyspace="trustgraph", table="default", username=None, password=None
    ):
        contact_points = ["localhost"] if hosts is None else hosts

        self.keyspace = keyspace
        self.table = table
        self.username = username

        # Authentication and the TLS context are only configured when both
        # a username and a password are supplied
        cluster_options = {}
        if username and password:
            cluster_options["ssl_context"] = SSLContext(PROTOCOL_TLSv1_2)
            cluster_options["auth_provider"] = PlainTextAuthProvider(
                username=username, password=password
            )

        self.cluster = Cluster(contact_points, **cluster_options)
        self.session = self.cluster.connect()

        # Track this cluster globally
        _active_clusters.append(self.cluster)

        self.init()

    def clear(self):
        """Drop the keyspace, then recreate the empty schema."""
        drop_keyspace = f"""
            drop keyspace if exists {self.keyspace};
        """
        self.session.execute(drop_keyspace)
        self.init()

    def init(self):
        """Create the keyspace, the triple table and the p/o indexes if missing."""
        create_keyspace = f"""
            create keyspace if not exists {self.keyspace}
                with replication = {{
                   'class' : 'SimpleStrategy',
                   'replication_factor' : 1
                }};
        """
        self.session.execute(create_keyspace)

        self.session.set_keyspace(self.keyspace)

        create_table = f"""
            create table if not exists {self.table} (
                s text,
                p text,
                o text,
                PRIMARY KEY (s, p, o)
            );
        """
        self.session.execute(create_table)

        # Secondary indexes on the two non-leading key columns
        for column in ("p", "o"):
            self.session.execute(f"""
                create index if not exists {self.table}_{column}
                    ON {self.table} ({column});
            """)

    def insert(self, s, p, o):
        """Store one (s, p, o) triple."""
        statement = f"insert into {self.table} (s, p, o) values (%s, %s, %s)"
        self.session.execute(statement, (s, p, o))

    def get_all(self, limit=50):
        """Return up to ``limit`` triples."""
        query = f"select s, p, o from {self.table} limit {limit}"
        return self.session.execute(query)

    def get_s(self, s, limit=10):
        """Return (p, o) rows for subject ``s``."""
        query = f"select p, o from {self.table} where s = %s limit {limit}"
        return self.session.execute(query, (s,))

    def get_p(self, p, limit=10):
        """Return (s, o) rows for predicate ``p``."""
        query = f"select s, o from {self.table} where p = %s limit {limit}"
        return self.session.execute(query, (p,))

    def get_o(self, o, limit=10):
        """Return (s, p) rows for object ``o``."""
        query = f"select s, p from {self.table} where o = %s limit {limit}"
        return self.session.execute(query, (o,))

    def get_sp(self, s, p, limit=10):
        """Return o rows for subject ``s`` and predicate ``p``."""
        query = f"select o from {self.table} where s = %s and p = %s limit {limit}"
        return self.session.execute(query, (s, p))

    def get_po(self, p, o, limit=10):
        """Return s rows for predicate ``p`` and object ``o``."""
        query = f"select s from {self.table} where p = %s and o = %s limit {limit} allow filtering"
        return self.session.execute(query, (p, o))

    def get_os(self, o, s, limit=10):
        """Return p rows for object ``o`` and subject ``s``."""
        query = f"select p from {self.table} where o = %s and s = %s limit {limit}"
        return self.session.execute(query, (o, s))

    def get_spo(self, s, p, o, limit=10):
        """Return a row (column ``x``) per stored triple equal to (s, p, o)."""
        query = f"""select s as x from {self.table} where s = %s and p = %s and o = %s limit {limit}"""
        return self.session.execute(query, (s, p, o))

    def close(self):
        """Close the Cassandra session and cluster connections properly"""
        session = self.session if hasattr(self, 'session') else None
        if session:
            session.shutdown()

        cluster = self.cluster if hasattr(self, 'cluster') else None
        if cluster:
            cluster.shutdown()
            # Remove from global tracking
            if cluster in _active_clusters:
                _active_clusters.remove(cluster)
|
||||
350
trustgraph-flow/trustgraph/direct/cassandra_kg.py
Normal file
350
trustgraph-flow/trustgraph/direct/cassandra_kg.py
Normal file
|
|
@ -0,0 +1,350 @@
|
|||
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.auth import PlainTextAuthProvider
|
||||
from cassandra.query import BatchStatement, SimpleStatement
|
||||
from ssl import SSLContext, PROTOCOL_TLSv1_2
|
||||
import os
|
||||
import logging
|
||||
|
||||
# Global list to track clusters for cleanup
|
||||
_active_clusters = []
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class KnowledgeGraph:
|
||||
|
||||
def __init__(
        self, hosts=None,
        keyspace="trustgraph", username=None, password=None
):
    """Connect to Cassandra and make sure the triple schema exists.

    The CASSANDRA_USE_LEGACY environment variable (compared
    case-insensitively against 'true') selects the legacy single-table
    layout; otherwise the three-table layout is used and its statements
    are prepared.  The cluster is recorded in ``_active_clusters``.
    """
    contact_points = ["localhost"] if hosts is None else hosts

    self.keyspace = keyspace
    self.username = username

    # Multi-table schema design for optimal performance
    legacy_flag = os.getenv('CASSANDRA_USE_LEGACY', 'false')
    self.use_legacy = legacy_flag.lower() == 'true'

    if not self.use_legacy:
        # New optimized tables
        self.subject_table = "triples_s"
        self.po_table = "triples_p"
        self.object_table = "triples_o"
    else:
        self.table = "triples"  # Legacy single table

    # Authentication and the TLS context are only configured when both a
    # username and a password are supplied
    cluster_options = {}
    if username and password:
        cluster_options["ssl_context"] = SSLContext(PROTOCOL_TLSv1_2)
        cluster_options["auth_provider"] = PlainTextAuthProvider(
            username=username, password=password
        )

    self.cluster = Cluster(contact_points, **cluster_options)
    self.session = self.cluster.connect()

    # Track this cluster globally
    _active_clusters.append(self.cluster)

    self.init()

    if self.use_legacy:
        return
    self.prepare_statements()
|
||||
|
||||
def clear(self):
    """Drop the keyspace if it exists, then rebuild the schema via init()."""
    drop_keyspace = f"""
        drop keyspace if exists {self.keyspace};
    """
    self.session.execute(drop_keyspace)
    self.init()
|
||||
|
||||
def init(self):
    """Create the keyspace if missing, switch to it, and build the tables.

    Table creation is delegated to init_legacy_schema() or
    init_optimized_schema() depending on ``self.use_legacy``.
    """
    create_keyspace = f"""
        create keyspace if not exists {self.keyspace}
            with replication = {{
               'class' : 'SimpleStrategy',
               'replication_factor' : 1
            }};
    """
    self.session.execute(create_keyspace)

    self.session.set_keyspace(self.keyspace)

    build_schema = (
        self.init_legacy_schema if self.use_legacy
        else self.init_optimized_schema
    )
    build_schema()
|
||||
|
||||
def init_legacy_schema(self):
    """Initialize legacy single-table schema for backward compatibility"""
    create_table = f"""
        create table if not exists {self.table} (
            collection text,
            s text,
            p text,
            o text,
            PRIMARY KEY (collection, s, p, o)
        );
    """
    self.session.execute(create_table)

    # One secondary index per triple column, created in s, p, o order
    for column in ("s", "p", "o"):
        self.session.execute(f"""
            create index if not exists {self.table}_{column}
                ON {self.table} ({column});
        """)
|
||||
|
||||
def init_optimized_schema(self):
    """Initialize optimized multi-table schema for performance"""
    table_definitions = (
        # Table 1: Subject-centric queries (get_s, get_sp, get_spo, get_os)
        f"""
        CREATE TABLE IF NOT EXISTS {self.subject_table} (
            collection text,
            s text,
            p text,
            o text,
            PRIMARY KEY ((collection, s), p, o)
        );
        """,
        # Table 2: Predicate-Object queries (get_p, get_po) - eliminates ALLOW FILTERING!
        f"""
        CREATE TABLE IF NOT EXISTS {self.po_table} (
            collection text,
            p text,
            o text,
            s text,
            PRIMARY KEY ((collection, p), o, s)
        );
        """,
        # Table 3: Object-centric queries (get_o)
        f"""
        CREATE TABLE IF NOT EXISTS {self.object_table} (
            collection text,
            o text,
            s text,
            p text,
            PRIMARY KEY ((collection, o), s, p)
        );
        """,
    )

    for definition in table_definitions:
        self.session.execute(definition)

    logger.info("Optimized multi-table schema initialized")
|
||||
|
||||
def prepare_statements(self):
    """Prepare statements for optimal performance"""
    subject = self.subject_table
    pred_obj = self.po_table
    obj = self.object_table

    # (attribute name, CQL) pairs, prepared in this order and stored on
    # the instance under the attribute name
    statements = (
        # Insert statements for batch operations
        ("insert_subject_stmt",
         f"INSERT INTO {subject} (collection, s, p, o) VALUES (?, ?, ?, ?)"),
        ("insert_po_stmt",
         f"INSERT INTO {pred_obj} (collection, p, o, s) VALUES (?, ?, ?, ?)"),
        ("insert_object_stmt",
         f"INSERT INTO {obj} (collection, o, s, p) VALUES (?, ?, ?, ?)"),
        # Query statements for optimized access
        ("get_all_stmt",
         f"SELECT s, p, o FROM {subject} WHERE collection = ? LIMIT ? ALLOW FILTERING"),
        ("get_s_stmt",
         f"SELECT p, o FROM {subject} WHERE collection = ? AND s = ? LIMIT ?"),
        ("get_p_stmt",
         f"SELECT s, o FROM {pred_obj} WHERE collection = ? AND p = ? LIMIT ?"),
        ("get_o_stmt",
         f"SELECT s, p FROM {obj} WHERE collection = ? AND o = ? LIMIT ?"),
        ("get_sp_stmt",
         f"SELECT o FROM {subject} WHERE collection = ? AND s = ? AND p = ? LIMIT ?"),
        # The critical optimization: get_po without ALLOW FILTERING!
        ("get_po_stmt",
         f"SELECT s FROM {pred_obj} WHERE collection = ? AND p = ? AND o = ? LIMIT ?"),
        ("get_os_stmt",
         f"SELECT p FROM {obj} WHERE collection = ? AND o = ? AND s = ? LIMIT ?"),
        ("get_spo_stmt",
         f"SELECT s as x FROM {subject} WHERE collection = ? AND s = ? AND p = ? AND o = ? LIMIT ?"),
    )

    for attribute, cql in statements:
        setattr(self, attribute, self.session.prepare(cql))

    logger.info("Prepared statements initialized for optimal performance")
|
||||
|
||||
def insert(self, collection, s, p, o):
    """Store the triple (s, p, o) under `collection`.

    In legacy mode a single INSERT is issued against ``self.table``.
    Otherwise the triple is written to the subject, predicate-object and
    object tables in one batch, with the values reordered to match the
    column list of each prepared INSERT.
    """
    if not self.use_legacy:
        # (prepared statement, values in that statement's column order)
        writes = (
            (self.insert_subject_stmt, (collection, s, p, o)),
            (self.insert_po_stmt, (collection, p, o, s)),
            (self.insert_object_stmt, (collection, o, s, p)),
        )

        batch = BatchStatement()
        for statement, values in writes:
            batch.add(statement, values)

        self.session.execute(batch)
        return

    self.session.execute(
        f"insert into {self.table} (collection, s, p, o) values (%s, %s, %s, %s)",
        (collection, s, p, o)
    )
|
||||
|
||||
def get_all(self, collection, limit=50):
    """Return up to `limit` (s, p, o) rows stored for `collection`.

    Legacy mode builds the query text against ``self.table``; otherwise the
    prepared ``get_all_stmt`` is executed with the limit as a bound value.
    """
    if not self.use_legacy:
        # Prepared statement reads from the subject table
        return self.session.execute(self.get_all_stmt, (collection, limit))

    query = f"select s, p, o from {self.table} where collection = %s limit {limit}"
    return self.session.execute(query, (collection,))
|
||||
|
||||
def get_s(self, collection, s, limit=10):
    """Return up to `limit` (p, o) rows for subject `s` in `collection`.

    Legacy mode queries ``self.table`` with the limit written into the query
    text; otherwise the prepared ``get_s_stmt`` (subject table) is used.
    """
    if not self.use_legacy:
        bound = (collection, s, limit)
        return self.session.execute(self.get_s_stmt, bound)

    query = f"select p, o from {self.table} where collection = %s and s = %s limit {limit}"
    return self.session.execute(query, (collection, s))
|
||||
|
||||
def get_p(self, collection, p, limit=10):
    """Return up to `limit` (s, o) rows for predicate `p` in `collection`.

    Legacy mode queries ``self.table`` with the limit written into the query
    text; otherwise the prepared ``get_p_stmt`` (predicate-object table) is
    used.
    """
    if not self.use_legacy:
        bound = (collection, p, limit)
        return self.session.execute(self.get_p_stmt, bound)

    query = f"select s, o from {self.table} where collection = %s and p = %s limit {limit}"
    return self.session.execute(query, (collection, p))
|
||||
|
||||
def get_o(self, collection, o, limit=10):
    """Return up to `limit` (s, p) rows for object `o` in `collection`.

    Legacy mode queries ``self.table`` with the limit written into the query
    text; otherwise the prepared ``get_o_stmt`` (object table) is used.
    """
    if not self.use_legacy:
        bound = (collection, o, limit)
        return self.session.execute(self.get_o_stmt, bound)

    query = f"select s, p from {self.table} where collection = %s and o = %s limit {limit}"
    return self.session.execute(query, (collection, o))
|
||||
|
||||
def get_sp(self, collection, s, p, limit=10):
    """Return up to `limit` object rows for subject `s` and predicate `p`.

    Legacy mode queries ``self.table`` with the limit written into the query
    text; otherwise the prepared ``get_sp_stmt`` (subject table) is used.
    """
    if not self.use_legacy:
        bound = (collection, s, p, limit)
        return self.session.execute(self.get_sp_stmt, bound)

    query = f"select o from {self.table} where collection = %s and s = %s and p = %s limit {limit}"
    return self.session.execute(query, (collection, s, p))
|
||||
|
||||
def get_po(self, collection, p, o, limit=10):
    """Return up to `limit` subject rows for predicate `p` and object `o`.

    Legacy mode queries ``self.table`` and needs ``allow filtering``;
    otherwise the prepared ``get_po_stmt`` (predicate-object table) is used,
    which carries no ALLOW FILTERING clause.
    """
    if not self.use_legacy:
        bound = (collection, p, o, limit)
        return self.session.execute(self.get_po_stmt, bound)

    query = f"select s from {self.table} where collection = %s and p = %s and o = %s limit {limit} allow filtering"
    return self.session.execute(query, (collection, p, o))
|
||||
|
||||
def get_os(self, collection, o, s, limit=10):
    """Return up to `limit` predicate rows linking subject `s` to object `o`.

    Args:
        collection: Collection the triples belong to.
        o: Object value to match.
        s: Subject value to match.
        limit: Maximum number of rows to return.

    Returns:
        Whatever ``self.session.execute`` returns for the query.
    """
    if self.use_legacy:
        return self.session.execute(
            f"select p from {self.table} where collection = %s and o = %s and s = %s limit {limit} allow filtering",
            (collection, o, s)
        )

    # get_os_stmt reads the object table and its placeholders are, in order,
    # collection, o, s, LIMIT -- the bound values must follow that order or
    # subject and object are silently swapped.
    return self.session.execute(
        self.get_os_stmt,
        (collection, o, s, limit)
    )
|
||||
|
||||
def get_spo(self, collection, s, p, o, limit=10):
    """Look up the exact triple (s, p, o) in `collection`.

    Matching rows expose the subject under the alias ``x``.  Legacy mode
    queries ``self.table``; otherwise the prepared ``get_spo_stmt`` (subject
    table) is used.
    """
    if not self.use_legacy:
        bound = (collection, s, p, o, limit)
        return self.session.execute(self.get_spo_stmt, bound)

    query = f"""select s as x from {self.table} where collection = %s and s = %s and p = %s and o = %s limit {limit}"""
    return self.session.execute(query, (collection, s, p, o))
|
||||
|
||||
def delete_collection(self, collection):
    """Remove every triple stored for `collection`.

    Legacy mode deletes from ``self.table`` only; otherwise one DELETE is
    issued per table, in the order subject, predicate-object, object.
    """
    if self.use_legacy:
        table_attrs = ("table",)
    else:
        table_attrs = ("subject_table", "po_table", "object_table")

    for attr in table_attrs:
        # Resolve each table name just before its own DELETE is issued
        table_name = getattr(self, attr)
        self.session.execute(
            f"delete from {table_name} where collection = %s",
            (collection,)
        )
|
||||
|
||||
def close(self):
    """Shut down the Cassandra session and cluster held by this instance.

    Attributes that are missing or falsy are skipped.  After the cluster is
    shut down it is also removed from the module-level ``_active_clusters``
    collection if it is registered there.
    """
    session = getattr(self, 'session', None)
    if session:
        session.shutdown()

    cluster = getattr(self, 'cluster', None)
    if not cluster:
        return

    cluster.shutdown()

    # Stop tracking the cluster globally once it has been shut down
    if cluster in _active_clusters:
        _active_clusters.remove(cluster)
|
||||
|
|
@ -2,9 +2,32 @@
|
|||
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
||||
import time
|
||||
import logging
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def make_safe_collection_name(user, collection, prefix):
    """
    Build a Milvus collection name of the form ``{prefix}_{user}_{collection}``.

    `user` and `collection` are reduced to letters and digits joined by
    single underscores; a part that ends up empty becomes ``default``.
    `prefix` is used unchanged.
    """
    def sanitize(s):
        # A run of anything other than ASCII letters/digits (underscores
        # included) becomes one underscore; edge underscores are dropped.
        cleaned = re.sub(r'[^a-zA-Z0-9]+', '_', s).strip('_')
        return cleaned if cleaned else 'default'

    parts = (prefix, sanitize(user), sanitize(collection))
    return f"{parts[0]}_{parts[1]}_{parts[2]}"
|
||||
|
||||
class DocVectors:
|
||||
|
||||
def __init__(self, uri="http://localhost:19530", prefix='doc'):
|
||||
|
|
@ -26,9 +49,9 @@ class DocVectors:
|
|||
self.next_reload = time.time() + self.reload_time
|
||||
logger.debug(f"Reload at {self.next_reload}")
|
||||
|
||||
def init_collection(self, dimension):
|
||||
def init_collection(self, dimension, user, collection):
|
||||
|
||||
collection_name = self.prefix + "_" + str(dimension)
|
||||
collection_name = make_safe_collection_name(user, collection, self.prefix)
|
||||
|
||||
pkey_field = FieldSchema(
|
||||
name="id",
|
||||
|
|
@ -75,14 +98,14 @@ class DocVectors:
|
|||
index_params=index_params
|
||||
)
|
||||
|
||||
self.collections[dimension] = collection_name
|
||||
self.collections[(dimension, user, collection)] = collection_name
|
||||
|
||||
def insert(self, embeds, doc):
|
||||
def insert(self, embeds, doc, user, collection):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim)
|
||||
if (dim, user, collection) not in self.collections:
|
||||
self.init_collection(dim, user, collection)
|
||||
|
||||
data = [
|
||||
{
|
||||
|
|
@ -92,18 +115,18 @@ class DocVectors:
|
|||
]
|
||||
|
||||
self.client.insert(
|
||||
collection_name=self.collections[dim],
|
||||
collection_name=self.collections[(dim, user, collection)],
|
||||
data=data
|
||||
)
|
||||
|
||||
def search(self, embeds, fields=["doc"], limit=10):
|
||||
def search(self, embeds, user, collection, fields=["doc"], limit=10):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim)
|
||||
if (dim, user, collection) not in self.collections:
|
||||
self.init_collection(dim, user, collection)
|
||||
|
||||
coll = self.collections[dim]
|
||||
coll = self.collections[(dim, user, collection)]
|
||||
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
|
|
@ -139,3 +162,20 @@ class DocVectors:
|
|||
|
||||
return res
|
||||
|
||||
def delete_collection(self, user, collection):
    """Drop the Milvus collection for (user, collection), if it exists.

    The collection name is derived with make_safe_collection_name() from
    `user`, `collection` and this instance's prefix.  Cached entries for the
    pair are discarded whether or not Milvus still has the collection, so a
    collection dropped elsewhere cannot leave a stale name in
    ``self.collections``.
    """
    collection_name = make_safe_collection_name(user, collection, self.prefix)

    if self.client.has_collection(collection_name):
        self.client.drop_collection(collection_name)
        logger.info(f"Deleted Milvus collection: {collection_name}")
    else:
        logger.info(f"Collection {collection_name} does not exist, nothing to delete")

    # Snapshot the matching keys first: a dict cannot change size while it
    # is being iterated.  Keys are matched on their 2nd and 3rd elements.
    stale_keys = [
        key for key in self.collections
        if key[1] == user and key[2] == collection
    ]
    for key in stale_keys:
        del self.collections[key]
|
||||
|
||||
|
|
|
|||
|
|
@ -2,9 +2,32 @@
|
|||
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
||||
import time
|
||||
import logging
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def make_safe_collection_name(user, collection, prefix):
    """
    Return ``{prefix}_{user}_{collection}`` with both name parts sanitized.

    Each of `user` and `collection` keeps only ASCII letters and digits,
    joined by single underscores with none at either end; an empty result is
    replaced by ``default``.  `prefix` is inserted as given.
    """
    def sanitize(s):
        # Underscores count as separators here too, so any mix of symbols
        # and underscores collapses to exactly one underscore.
        cleaned = re.sub(r'[^a-zA-Z0-9]+', '_', s).strip('_')
        return cleaned or 'default'

    safe_parts = [sanitize(part) for part in (user, collection)]
    return f"{prefix}_{safe_parts[0]}_{safe_parts[1]}"
|
||||
|
||||
class EntityVectors:
|
||||
|
||||
def __init__(self, uri="http://localhost:19530", prefix='entity'):
|
||||
|
|
@ -26,9 +49,9 @@ class EntityVectors:
|
|||
self.next_reload = time.time() + self.reload_time
|
||||
logger.debug(f"Reload at {self.next_reload}")
|
||||
|
||||
def init_collection(self, dimension):
|
||||
def init_collection(self, dimension, user, collection):
|
||||
|
||||
collection_name = self.prefix + "_" + str(dimension)
|
||||
collection_name = make_safe_collection_name(user, collection, self.prefix)
|
||||
|
||||
pkey_field = FieldSchema(
|
||||
name="id",
|
||||
|
|
@ -75,14 +98,14 @@ class EntityVectors:
|
|||
index_params=index_params
|
||||
)
|
||||
|
||||
self.collections[dimension] = collection_name
|
||||
self.collections[(dimension, user, collection)] = collection_name
|
||||
|
||||
def insert(self, embeds, entity):
|
||||
def insert(self, embeds, entity, user, collection):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim)
|
||||
if (dim, user, collection) not in self.collections:
|
||||
self.init_collection(dim, user, collection)
|
||||
|
||||
data = [
|
||||
{
|
||||
|
|
@ -92,18 +115,18 @@ class EntityVectors:
|
|||
]
|
||||
|
||||
self.client.insert(
|
||||
collection_name=self.collections[dim],
|
||||
collection_name=self.collections[(dim, user, collection)],
|
||||
data=data
|
||||
)
|
||||
|
||||
def search(self, embeds, fields=["entity"], limit=10):
|
||||
def search(self, embeds, user, collection, fields=["entity"], limit=10):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim)
|
||||
if (dim, user, collection) not in self.collections:
|
||||
self.init_collection(dim, user, collection)
|
||||
|
||||
coll = self.collections[dim]
|
||||
coll = self.collections[(dim, user, collection)]
|
||||
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
|
|
@ -139,3 +162,20 @@ class EntityVectors:
|
|||
|
||||
return res
|
||||
|
||||
def delete_collection(self, user, collection):
    """Drop the Milvus collection for (user, collection), if it exists.

    The collection name is derived with make_safe_collection_name() from
    `user`, `collection` and this instance's prefix.  Cached entries for the
    pair are discarded whether or not Milvus still has the collection, so a
    collection dropped elsewhere cannot leave a stale name in
    ``self.collections``.
    """
    collection_name = make_safe_collection_name(user, collection, self.prefix)

    if self.client.has_collection(collection_name):
        self.client.drop_collection(collection_name)
        logger.info(f"Deleted Milvus collection: {collection_name}")
    else:
        logger.info(f"Collection {collection_name} does not exist, nothing to delete")

    # Snapshot the matching keys first: a dict cannot change size while it
    # is being iterated.  Keys are matched on their 2nd and 3rd elements.
    stale_keys = [
        key for key in self.collections
        if key[1] == user and key[2] == collection
    ]
    for key in stale_keys:
        del self.collections[key]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,157 +0,0 @@
|
|||
|
||||
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
||||
import time
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ObjectVectors:
    """Milvus-backed store of object embeddings, one collection per
    (vector dimension, object name) pair."""

    def __init__(self, uri="http://localhost:19530", prefix='obj'):
        """Connect to Milvus at `uri`; `prefix` starts every collection name."""

        self.client = MilvusClient(uri=uri)

        # Strategy is to create collections per dimension.  Probably only
        # going to be using 1 anyway, but that means we don't need to
        # hard-code the dimension anywhere, and no big deal if more than
        # one are created.
        # Maps (dimension, name) -> Milvus collection name.
        self.collections = {}

        self.prefix = prefix

        # Time between reloads
        self.reload_time = 90

        # Next time to reload - this forces a reload at next window
        self.next_reload = time.time() + self.reload_time
        logger.debug(f"Reload at {self.next_reload}")

    def init_collection(self, dimension, name):
        """Create and index the collection for (dimension, name), then
        record its name in ``self.collections``."""

        collection_name = self.prefix + "_" + name + "_" + str(dimension)

        pkey_field = FieldSchema(
            name="id",
            dtype=DataType.INT64,
            is_primary=True,
            auto_id=True,
        )

        vec_field = FieldSchema(
            name="vector",
            dtype=DataType.FLOAT_VECTOR,
            dim=dimension,
        )

        name_field = FieldSchema(
            name="name",
            dtype=DataType.VARCHAR,
            max_length=65535,
        )

        key_name_field = FieldSchema(
            name="key_name",
            dtype=DataType.VARCHAR,
            max_length=65535,
        )

        key_field = FieldSchema(
            name="key",
            dtype=DataType.VARCHAR,
            max_length=65535,
        )

        schema = CollectionSchema(
            fields = [
                pkey_field, vec_field, name_field, key_name_field, key_field
            ],
            description = "Object embedding schema",
        )

        self.client.create_collection(
            collection_name=collection_name,
            schema=schema,
            metric_type="COSINE",
        )

        index_params = MilvusClient.prepare_index_params()

        index_params.add_index(
            field_name="vector",
            metric_type="COSINE",
            index_type="IVF_SQ8",
            index_name="vector_index",
            params={ "nlist": 128 }
        )

        self.client.create_index(
            collection_name=collection_name,
            index_params=index_params
        )

        self.collections[(dimension, name)] = collection_name

    def insert(self, embeds, name, key_name, key):
        """Insert one vector with its name / key_name / key fields, creating
        the target collection on first use."""

        dim = len(embeds)

        if (dim, name) not in self.collections:
            self.init_collection(dim, name)

        data = [
            {
                "vector": embeds,
                "name": name,
                "key_name": key_name,
                "key": key,
            }
        ]

        self.client.insert(
            collection_name=self.collections[(dim, name)],
            data=data
        )

    def search(self, embeds, name, fields=["key_name", "name"], limit=10):
        """Search the (len(embeds), name) collection for vectors near `embeds`.

        Args:
            embeds: Query vector; its length selects the collection.
            name: Object name; selects the collection together with the
                vector length.
            fields: Output fields requested from Milvus (the default list is
                never modified here).
            limit: Maximum number of hits.

        Returns:
            The hit list for the single query vector.
        """

        dim = len(embeds)

        # The cache is keyed by (dimension, name); testing the bare dimension
        # never matches and would re-create the collection on every search.
        if (dim, name) not in self.collections:
            self.init_collection(dim, name)

        coll = self.collections[(dim, name)]

        search_params = {
            "metric_type": "COSINE",
            "params": {
                "radius": 0.1,
                "range_filter": 0.8
            }
        }

        logger.debug("Loading...")
        self.client.load_collection(
            collection_name=coll,
        )

        logger.debug("Searching...")

        res = self.client.search(
            collection_name=coll,
            data=[embeds],
            limit=limit,
            output_fields=fields,
            search_params=search_params,
        )[0]

        # If reload time has passed, unload collection
        if time.time() > self.next_reload:
            logger.debug(f"Unloading, reload at {self.next_reload}")
            self.client.release_collection(
                collection_name=coll,
            )
            self.next_reload = time.time() + self.reload_time

        return res
|
||||
|
||||
|
|
@ -27,13 +27,13 @@ class Processor(FlowProcessor):
|
|||
|
||||
id = params.get("id")
|
||||
concurrency = params.get("concurrency", 1)
|
||||
template_id = params.get("template-id", default_template_id)
|
||||
config_key = params.get("config-type", default_config_type)
|
||||
template_id = params.get("template_id", default_template_id)
|
||||
config_key = params.get("config_type", default_config_type)
|
||||
|
||||
super().__init__(**params | {
|
||||
"id": id,
|
||||
"template-id": template_id,
|
||||
"config-type": config_key,
|
||||
"template_id": template_id,
|
||||
"config_type": config_key,
|
||||
"concurrency": concurrency,
|
||||
})
|
||||
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ class Processor(FlowProcessor):
|
|||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"config-type": self.config_key,
|
||||
"config_type": self.config_key,
|
||||
"concurrency": concurrency,
|
||||
}
|
||||
)
|
||||
|
|
@ -256,31 +256,34 @@ class Processor(FlowProcessor):
|
|||
flow
|
||||
)
|
||||
|
||||
# Emit each extracted object
|
||||
for obj in objects:
|
||||
# Emit extracted objects as a batch if any were found
|
||||
if objects:
|
||||
|
||||
# Calculate confidence (could be enhanced with actual confidence from prompt)
|
||||
confidence = 0.8 # Default confidence
|
||||
|
||||
# Convert all values to strings for Pulsar compatibility
|
||||
string_values = convert_values_to_strings(obj)
|
||||
# Convert all objects' values to strings for Pulsar compatibility
|
||||
batch_values = []
|
||||
for obj in objects:
|
||||
string_values = convert_values_to_strings(obj)
|
||||
batch_values.append(string_values)
|
||||
|
||||
# Create ExtractedObject
|
||||
# Create ExtractedObject with batched values
|
||||
extracted = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id=f"{v.metadata.id}:{schema_name}:{hash(str(obj))}",
|
||||
id=f"{v.metadata.id}:{schema_name}",
|
||||
metadata=[],
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
schema_name=schema_name,
|
||||
values=string_values,
|
||||
values=batch_values, # Array of objects
|
||||
confidence=confidence,
|
||||
source_span=chunk_text[:100] # First 100 chars as source reference
|
||||
)
|
||||
|
||||
await flow("output").send(extracted)
|
||||
logger.debug(f"Emitted extracted object for schema {schema_name}")
|
||||
logger.debug(f"Emitted batch of {len(objects)} objects for schema {schema_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Object extraction exception: {e}", exc_info=True)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,30 @@
|
|||
import logging

from ... schema import CollectionManagementRequest, CollectionManagementResponse
from ... schema import collection_request_queue, collection_response_queue
from ... messaging import TranslatorRegistry

from . requestor import ServiceRequestor
|
||||
|
||||
# Module logger
logger = logging.getLogger(__name__)


class CollectionManagementRequestor(ServiceRequestor):
    """Request/response dispatcher for the collection-management service.

    Binds ServiceRequestor to the collection request/response queues and
    schemas, and translates message bodies with the "collection-management"
    translators from TranslatorRegistry.
    """

    def __init__(self, pulsar_client, consumer, subscriber, timeout=120):
        """Configure queues, schemas and translators.

        Args:
            pulsar_client: Client handed to ServiceRequestor.
            consumer: Passed on as ``consumer_name``.
            subscriber: Passed on as ``subscription``.
            timeout: Passed on as ``timeout``.
        """

        super().__init__(
            pulsar_client=pulsar_client,
            consumer_name=consumer,
            subscription=subscriber,
            request_queue=collection_request_queue,
            response_queue=collection_response_queue,
            request_schema=CollectionManagementRequest,
            response_schema=CollectionManagementResponse,
            timeout=timeout,
        )

        self.request_translator = TranslatorRegistry.get_request_translator("collection-management")
        self.response_translator = TranslatorRegistry.get_response_translator("collection-management")

    def to_request(self, body):
        """Translate an incoming request body into its Pulsar message."""
        # Debug-level logging instead of printing every payload to stdout
        logger.debug("Collection management request: %s", body)
        return self.request_translator.to_pulsar(body)

    def from_response(self, message):
        """Translate a response message via from_response_with_completion()."""
        logger.debug("Collection management response: %s", message)
        return self.response_translator.from_response_with_completion(message)
|
||||
|
|
@ -26,46 +26,66 @@ class DocumentEmbeddingsExport:
|
|||
self.subscriber = subscriber
|
||||
|
||||
async def destroy(self):
|
||||
# Step 1: Signal stop to prevent new messages
|
||||
self.running.stop()
|
||||
await self.ws.close()
|
||||
|
||||
# Step 2: Wait briefly for in-flight messages
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Step 3: Unsubscribe and stop subscriber (triggers queue drain)
|
||||
if hasattr(self, 'subs'):
|
||||
await self.subs.unsubscribe_all(self.id)
|
||||
await self.subs.stop()
|
||||
|
||||
# Step 4: Close websocket last
|
||||
if self.ws and not self.ws.closed:
|
||||
await self.ws.close()
|
||||
|
||||
async def receive(self, msg):
|
||||
# Ignore incoming info from websocket
|
||||
pass
|
||||
|
||||
async def run(self):
|
||||
|
||||
subs = Subscriber(
|
||||
client = self.pulsar_client, topic = self.queue,
|
||||
consumer_name = self.consumer, subscription = self.subscriber,
|
||||
schema = DocumentEmbeddings
|
||||
"""Enhanced run with better error handling"""
|
||||
self.subs = Subscriber(
|
||||
client = self.pulsar_client,
|
||||
topic = self.queue,
|
||||
consumer_name = self.consumer,
|
||||
subscription = self.subscriber,
|
||||
schema = DocumentEmbeddings,
|
||||
backpressure_strategy = "block" # Configurable
|
||||
)
|
||||
|
||||
await subs.start()
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
q = await subs.subscribe_all(id)
|
||||
|
||||
|
||||
await self.subs.start()
|
||||
|
||||
self.id = str(uuid.uuid4())
|
||||
q = await self.subs.subscribe_all(self.id)
|
||||
|
||||
consecutive_errors = 0
|
||||
max_consecutive_errors = 5
|
||||
|
||||
while self.running.get():
|
||||
try:
|
||||
|
||||
resp = await asyncio.wait_for(q.get(), timeout=0.5)
|
||||
await self.ws.send_json(serialize_document_embeddings(resp))
|
||||
|
||||
except TimeoutError:
|
||||
consecutive_errors = 0 # Reset on success
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Exception: {str(e)}", exc_info=True)
|
||||
break
|
||||
|
||||
await subs.unsubscribe_all(id)
|
||||
|
||||
await subs.stop()
|
||||
|
||||
await self.ws.close()
|
||||
self.running.stop()
|
||||
logger.error(f"Exception sending to websocket: {str(e)}")
|
||||
consecutive_errors += 1
|
||||
|
||||
if consecutive_errors >= max_consecutive_errors:
|
||||
logger.error("Too many consecutive errors, shutting down")
|
||||
break
|
||||
|
||||
# Brief pause before retry
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Graceful cleanup handled in destroy()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
|
||||
import asyncio
|
||||
import uuid
|
||||
import logging
|
||||
from aiohttp import WSMsgType
|
||||
|
||||
from ... schema import Metadata
|
||||
|
|
@ -8,6 +9,9 @@ from ... schema import DocumentEmbeddings, ChunkEmbeddings
|
|||
from ... base import Publisher
|
||||
from ... messaging.translators.document_loading import DocumentEmbeddingsTranslator
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DocumentEmbeddingsImport:
|
||||
|
||||
def __init__(
|
||||
|
|
@ -26,13 +30,17 @@ class DocumentEmbeddingsImport:
|
|||
await self.publisher.start()
|
||||
|
||||
async def destroy(self):
|
||||
# Step 1: Stop accepting new messages
|
||||
self.running.stop()
|
||||
|
||||
# Step 2: Wait for publisher to drain its queue
|
||||
logger.info("Draining publisher queue...")
|
||||
await self.publisher.stop()
|
||||
|
||||
# Step 3: Close websocket only after queue is drained
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
await self.publisher.stop()
|
||||
|
||||
async def receive(self, msg):
|
||||
|
||||
data = msg.json()
|
||||
|
|
|
|||
|
|
@ -26,46 +26,66 @@ class EntityContextsExport:
|
|||
self.subscriber = subscriber
|
||||
|
||||
async def destroy(self):
|
||||
# Step 1: Signal stop to prevent new messages
|
||||
self.running.stop()
|
||||
await self.ws.close()
|
||||
|
||||
# Step 2: Wait briefly for in-flight messages
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Step 3: Unsubscribe and stop subscriber (triggers queue drain)
|
||||
if hasattr(self, 'subs'):
|
||||
await self.subs.unsubscribe_all(self.id)
|
||||
await self.subs.stop()
|
||||
|
||||
# Step 4: Close websocket last
|
||||
if self.ws and not self.ws.closed:
|
||||
await self.ws.close()
|
||||
|
||||
async def receive(self, msg):
|
||||
# Ignore incoming info from websocket
|
||||
pass
|
||||
|
||||
async def run(self):
|
||||
|
||||
subs = Subscriber(
|
||||
client = self.pulsar_client, topic = self.queue,
|
||||
consumer_name = self.consumer, subscription = self.subscriber,
|
||||
schema = EntityContexts
|
||||
"""Enhanced run with better error handling"""
|
||||
self.subs = Subscriber(
|
||||
client = self.pulsar_client,
|
||||
topic = self.queue,
|
||||
consumer_name = self.consumer,
|
||||
subscription = self.subscriber,
|
||||
schema = EntityContexts,
|
||||
backpressure_strategy = "block" # Configurable
|
||||
)
|
||||
|
||||
await subs.start()
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
q = await subs.subscribe_all(id)
|
||||
|
||||
|
||||
await self.subs.start()
|
||||
|
||||
self.id = str(uuid.uuid4())
|
||||
q = await self.subs.subscribe_all(self.id)
|
||||
|
||||
consecutive_errors = 0
|
||||
max_consecutive_errors = 5
|
||||
|
||||
while self.running.get():
|
||||
try:
|
||||
|
||||
resp = await asyncio.wait_for(q.get(), timeout=0.5)
|
||||
await self.ws.send_json(serialize_entity_contexts(resp))
|
||||
|
||||
except TimeoutError:
|
||||
consecutive_errors = 0 # Reset on success
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Exception: {str(e)}", exc_info=True)
|
||||
break
|
||||
|
||||
await subs.unsubscribe_all(id)
|
||||
|
||||
await subs.stop()
|
||||
|
||||
await self.ws.close()
|
||||
self.running.stop()
|
||||
logger.error(f"Exception sending to websocket: {str(e)}")
|
||||
consecutive_errors += 1
|
||||
|
||||
if consecutive_errors >= max_consecutive_errors:
|
||||
logger.error("Too many consecutive errors, shutting down")
|
||||
break
|
||||
|
||||
# Brief pause before retry
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Graceful cleanup handled in destroy()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
|
||||
import asyncio
|
||||
import uuid
|
||||
import logging
|
||||
from aiohttp import WSMsgType
|
||||
|
||||
from ... schema import Metadata
|
||||
|
|
@ -9,6 +10,9 @@ from ... base import Publisher
|
|||
|
||||
from . serialize import to_subgraph, to_value
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EntityContextsImport:
|
||||
|
||||
def __init__(
|
||||
|
|
@ -26,13 +30,17 @@ class EntityContextsImport:
|
|||
await self.publisher.start()
|
||||
|
||||
async def destroy(self):
|
||||
# Step 1: Stop accepting new messages
|
||||
self.running.stop()
|
||||
|
||||
# Step 2: Wait for publisher to drain its queue
|
||||
logger.info("Draining publisher queue...")
|
||||
await self.publisher.stop()
|
||||
|
||||
# Step 3: Close websocket only after queue is drained
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
await self.publisher.stop()
|
||||
|
||||
async def receive(self, msg):
|
||||
|
||||
data = msg.json()
|
||||
|
|
|
|||
|
|
@ -26,46 +26,66 @@ class GraphEmbeddingsExport:
|
|||
self.subscriber = subscriber
|
||||
|
||||
async def destroy(self):
|
||||
# Step 1: Signal stop to prevent new messages
|
||||
self.running.stop()
|
||||
await self.ws.close()
|
||||
|
||||
# Step 2: Wait briefly for in-flight messages
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Step 3: Unsubscribe and stop subscriber (triggers queue drain)
|
||||
if hasattr(self, 'subs'):
|
||||
await self.subs.unsubscribe_all(self.id)
|
||||
await self.subs.stop()
|
||||
|
||||
# Step 4: Close websocket last
|
||||
if self.ws and not self.ws.closed:
|
||||
await self.ws.close()
|
||||
|
||||
async def receive(self, msg):
|
||||
# Ignore incoming info from websocket
|
||||
pass
|
||||
|
||||
async def run(self):
|
||||
|
||||
subs = Subscriber(
|
||||
client = self.pulsar_client, topic = self.queue,
|
||||
consumer_name = self.consumer, subscription = self.subscriber,
|
||||
schema = GraphEmbeddings
|
||||
"""Enhanced run with better error handling"""
|
||||
self.subs = Subscriber(
|
||||
client = self.pulsar_client,
|
||||
topic = self.queue,
|
||||
consumer_name = self.consumer,
|
||||
subscription = self.subscriber,
|
||||
schema = GraphEmbeddings,
|
||||
backpressure_strategy = "block" # Configurable
|
||||
)
|
||||
|
||||
await subs.start()
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
q = await subs.subscribe_all(id)
|
||||
|
||||
|
||||
await self.subs.start()
|
||||
|
||||
self.id = str(uuid.uuid4())
|
||||
q = await self.subs.subscribe_all(self.id)
|
||||
|
||||
consecutive_errors = 0
|
||||
max_consecutive_errors = 5
|
||||
|
||||
while self.running.get():
|
||||
try:
|
||||
|
||||
resp = await asyncio.wait_for(q.get(), timeout=0.5)
|
||||
await self.ws.send_json(serialize_graph_embeddings(resp))
|
||||
|
||||
except TimeoutError:
|
||||
consecutive_errors = 0 # Reset on success
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Exception: {str(e)}", exc_info=True)
|
||||
break
|
||||
|
||||
await subs.unsubscribe_all(id)
|
||||
|
||||
await subs.stop()
|
||||
|
||||
await self.ws.close()
|
||||
self.running.stop()
|
||||
logger.error(f"Exception sending to websocket: {str(e)}")
|
||||
consecutive_errors += 1
|
||||
|
||||
if consecutive_errors >= max_consecutive_errors:
|
||||
logger.error("Too many consecutive errors, shutting down")
|
||||
break
|
||||
|
||||
# Brief pause before retry
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Graceful cleanup handled in destroy()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
|
||||
import asyncio
|
||||
import uuid
|
||||
import logging
|
||||
from aiohttp import WSMsgType
|
||||
|
||||
from ... schema import Metadata
|
||||
|
|
@ -9,6 +10,9 @@ from ... base import Publisher
|
|||
|
||||
from . serialize import to_subgraph, to_value
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class GraphEmbeddingsImport:
|
||||
|
||||
def __init__(
|
||||
|
|
@ -26,13 +30,17 @@ class GraphEmbeddingsImport:
|
|||
await self.publisher.start()
|
||||
|
||||
async def destroy(self):
|
||||
# Step 1: Stop accepting new messages
|
||||
self.running.stop()
|
||||
|
||||
# Step 2: Wait for publisher to drain its queue
|
||||
logger.info("Draining publisher queue...")
|
||||
await self.publisher.stop()
|
||||
|
||||
# Step 3: Close websocket only after queue is drained
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
await self.publisher.stop()
|
||||
|
||||
async def receive(self, msg):
|
||||
|
||||
data = msg.json()
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from . config import ConfigRequestor
|
|||
from . flow import FlowRequestor
|
||||
from . librarian import LibrarianRequestor
|
||||
from . knowledge import KnowledgeRequestor
|
||||
from . collection_management import CollectionManagementRequestor
|
||||
|
||||
from . embeddings import EmbeddingsRequestor
|
||||
from . agent import AgentRequestor
|
||||
|
|
@ -19,6 +20,10 @@ from . prompt import PromptRequestor
|
|||
from . graph_rag import GraphRagRequestor
|
||||
from . document_rag import DocumentRagRequestor
|
||||
from . triples_query import TriplesQueryRequestor
|
||||
from . objects_query import ObjectsQueryRequestor
|
||||
from . nlp_query import NLPQueryRequestor
|
||||
from . structured_query import StructuredQueryRequestor
|
||||
from . structured_diag import StructuredDiagRequestor
|
||||
from . embeddings import EmbeddingsRequestor
|
||||
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
|
||||
from . mcp_tool import McpToolRequestor
|
||||
|
|
@ -34,6 +39,7 @@ from . triples_import import TriplesImport
|
|||
from . graph_embeddings_import import GraphEmbeddingsImport
|
||||
from . document_embeddings_import import DocumentEmbeddingsImport
|
||||
from . entity_contexts_import import EntityContextsImport
|
||||
from . objects_import import ObjectsImport
|
||||
|
||||
from . core_export import CoreExport
|
||||
from . core_import import CoreImport
|
||||
|
|
@ -50,6 +56,10 @@ request_response_dispatchers = {
|
|||
"embeddings": EmbeddingsRequestor,
|
||||
"graph-embeddings": GraphEmbeddingsQueryRequestor,
|
||||
"triples": TriplesQueryRequestor,
|
||||
"objects": ObjectsQueryRequestor,
|
||||
"nlp-query": NLPQueryRequestor,
|
||||
"structured-query": StructuredQueryRequestor,
|
||||
"structured-diag": StructuredDiagRequestor,
|
||||
}
|
||||
|
||||
global_dispatchers = {
|
||||
|
|
@ -57,6 +67,7 @@ global_dispatchers = {
|
|||
"flow": FlowRequestor,
|
||||
"librarian": LibrarianRequestor,
|
||||
"knowledge": KnowledgeRequestor,
|
||||
"collection-management": CollectionManagementRequestor,
|
||||
}
|
||||
|
||||
sender_dispatchers = {
|
||||
|
|
@ -76,6 +87,7 @@ import_dispatchers = {
|
|||
"graph-embeddings": GraphEmbeddingsImport,
|
||||
"document-embeddings": DocumentEmbeddingsImport,
|
||||
"entity-contexts": EntityContextsImport,
|
||||
"objects": ObjectsImport,
|
||||
}
|
||||
|
||||
class DispatcherWrapper:
|
||||
|
|
|
|||
|
|
@ -147,7 +147,7 @@ class Mux:
|
|||
self.running.stop()
|
||||
|
||||
if self.ws:
|
||||
self.ws.close()
|
||||
await self.ws.close()
|
||||
self.ws = None
|
||||
|
||||
break
|
||||
|
|
@ -165,6 +165,6 @@ class Mux:
|
|||
self.running.stop()
|
||||
|
||||
if self.ws:
|
||||
self.ws.close()
|
||||
await self.ws.close()
|
||||
self.ws = None
|
||||
|
||||
|
|
|
|||
30
trustgraph-flow/trustgraph/gateway/dispatch/nlp_query.py
Normal file
30
trustgraph-flow/trustgraph/gateway/dispatch/nlp_query.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
from ... schema import QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class NLPQueryRequestor(ServiceRequestor):
    """Service requestor bound to the "nlp-query" translators.

    Requests use the ``QuestionToStructuredQueryRequest`` schema and
    responses use ``QuestionToStructuredQueryResponse``; conversion to and
    from those schemas is delegated to the translators registered under
    the "nlp-query" key.
    """

    def __init__(
            self, pulsar_client, request_queue, response_queue, timeout,
            consumer, subscriber,
    ):
        """Configure the base requestor and look up the translators.

        Args:
            pulsar_client: Passed through to ``ServiceRequestor``.
            request_queue: Passed through to ``ServiceRequestor``.
            response_queue: Passed through to ``ServiceRequestor``.
            timeout: Passed through to ``ServiceRequestor``.
            consumer: Forwarded as ``consumer_name``.
            subscriber: Forwarded as ``subscription``.
        """
        super().__init__(
            pulsar_client=pulsar_client,
            request_queue=request_queue,
            response_queue=response_queue,
            request_schema=QuestionToStructuredQueryRequest,
            response_schema=QuestionToStructuredQueryResponse,
            subscription=subscriber,
            consumer_name=consumer,
            timeout=timeout,
        )

        translator_key = "nlp-query"
        self.request_translator = TranslatorRegistry.get_request_translator(
            translator_key
        )
        self.response_translator = TranslatorRegistry.get_response_translator(
            translator_key
        )

    def to_request(self, body):
        """Return ``body`` converted by the request translator's ``to_pulsar``."""
        outgoing = self.request_translator.to_pulsar(body)
        return outgoing

    def from_response(self, message):
        """Return ``message`` converted by ``from_response_with_completion``.

        NOTE(review): presumably this yields the decoded response together
        with a completion indicator - confirm against the translator.
        """
        decoded = self.response_translator.from_response_with_completion(message)
        return decoded
|
||||
|
|
@ -0,0 +1,76 @@
|
|||
import asyncio
|
||||
import uuid
|
||||
import logging
|
||||
from aiohttp import WSMsgType
|
||||
|
||||
from ... schema import Metadata
|
||||
from ... schema import ExtractedObject
|
||||
from ... base import Publisher
|
||||
|
||||
from . serialize import to_subgraph
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ObjectsImport:
    """Websocket import dispatcher that publishes ``ExtractedObject`` messages.

    Each received websocket message is decoded from JSON, converted into an
    ``ExtractedObject`` and handed to a ``Publisher`` created for the given
    queue.
    """

    def __init__(
            self, ws, running, pulsar_client, queue
    ):
        """Store the websocket and running flag and create the publisher.

        Args:
            ws: Websocket object; ``close()`` is awaited on it at shutdown.
            running: Flag object exposing ``get()`` and ``stop()``.
            pulsar_client: Client handed to ``Publisher``.
            queue: Topic the publisher sends to.
        """
        self.ws = ws
        self.running = running
        self.publisher = Publisher(
            pulsar_client, topic=queue, schema=ExtractedObject
        )

    async def start(self):
        """Start the underlying publisher."""
        await self.publisher.start()

    async def destroy(self):
        """Shut down: clear the running flag, stop the publisher, close the socket."""
        # Clear the running flag first so run() leaves its polling loop.
        self.running.stop()

        # Stop the publisher before touching the websocket.
        # NOTE(review): per the log message, stop() is expected to drain
        # the publisher's queue - confirm against Publisher.stop().
        logger.info("Draining publisher queue...")
        await self.publisher.stop()

        # The websocket is closed last; run() may already have set it to None.
        if not self.ws:
            return
        await self.ws.close()

    async def receive(self, msg):
        """Convert one websocket message into an ``ExtractedObject`` and publish it.

        The JSON payload must contain ``values``, ``metadata`` (with ``id``,
        ``user`` and ``collection``) and ``schema_name``; ``confidence``
        defaults to 1.0 and ``source_span`` to an empty string.

        Raises:
            KeyError: If a required key is missing from the payload.
        """
        payload = msg.json()

        # "values" may be a single object or a list of objects; a single
        # object is wrapped so that downstream always sees a list.
        raw_values = payload["values"]
        rows = raw_values if isinstance(raw_values, list) else [raw_values]

        meta = payload["metadata"]
        message_metadata = Metadata(
            id=meta["id"],
            metadata=to_subgraph(meta.get("metadata", [])),
            user=meta["user"],
            collection=meta["collection"],
        )

        extracted = ExtractedObject(
            metadata=message_metadata,
            schema_name=payload["schema_name"],
            values=rows,
            confidence=payload.get("confidence", 1.0),
            source_span=payload.get("source_span", ""),
        )

        # NOTE(review): the first argument is always None here; presumably
        # it is a message id/key - confirm against Publisher.send().
        await self.publisher.send(None, extracted)

    async def run(self):
        """Idle until the running flag clears, then close and drop the websocket."""
        # Poll the flag twice a second; the actual work happens in receive().
        while True:
            if not self.running.get():
                break
            await asyncio.sleep(0.5)

        if self.ws:
            await self.ws.close()

        # Drop the reference so destroy() does not close it a second time.
        self.ws = None
|
||||
30
trustgraph-flow/trustgraph/gateway/dispatch/objects_query.py
Normal file
30
trustgraph-flow/trustgraph/gateway/dispatch/objects_query.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
from ... schema import ObjectsQueryRequest, ObjectsQueryResponse
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class ObjectsQueryRequestor(ServiceRequestor):
    """Service requestor bound to the "objects-query" translators.

    Requests use the ``ObjectsQueryRequest`` schema and responses use
    ``ObjectsQueryResponse``; conversion is delegated to the translators
    registered under the "objects-query" key.
    """

    def __init__(
            self, pulsar_client, request_queue, response_queue, timeout,
            consumer, subscriber,
    ):
        """Configure the base requestor and look up the translators.

        Args:
            pulsar_client: Passed through to ``ServiceRequestor``.
            request_queue: Passed through to ``ServiceRequestor``.
            response_queue: Passed through to ``ServiceRequestor``.
            timeout: Passed through to ``ServiceRequestor``.
            consumer: Forwarded as ``consumer_name``.
            subscriber: Forwarded as ``subscription``.
        """
        super().__init__(
            pulsar_client=pulsar_client,
            request_queue=request_queue,
            response_queue=response_queue,
            request_schema=ObjectsQueryRequest,
            response_schema=ObjectsQueryResponse,
            subscription=subscriber,
            consumer_name=consumer,
            timeout=timeout,
        )

        translator_key = "objects-query"
        self.request_translator = TranslatorRegistry.get_request_translator(
            translator_key
        )
        self.response_translator = TranslatorRegistry.get_response_translator(
            translator_key
        )

    def to_request(self, body):
        """Return ``body`` converted by the request translator's ``to_pulsar``."""
        outgoing = self.request_translator.to_pulsar(body)
        return outgoing

    def from_response(self, message):
        """Return ``message`` converted by ``from_response_with_completion``.

        NOTE(review): presumably this yields the decoded response together
        with a completion indicator - confirm against the translator.
        """
        decoded = self.response_translator.from_response_with_completion(message)
        return decoded
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
from ... schema import StructuredDataDiagnosisRequest, StructuredDataDiagnosisResponse
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class StructuredDiagRequestor(ServiceRequestor):
    """Service requestor bound to the "structured-diag" translators.

    Requests use the ``StructuredDataDiagnosisRequest`` schema and responses
    use ``StructuredDataDiagnosisResponse``; conversion is delegated to the
    translators registered under the "structured-diag" key.
    """

    def __init__(
            self, pulsar_client, request_queue, response_queue, timeout,
            consumer, subscriber,
    ):
        """Configure the base requestor and look up the translators.

        Args:
            pulsar_client: Passed through to ``ServiceRequestor``.
            request_queue: Passed through to ``ServiceRequestor``.
            response_queue: Passed through to ``ServiceRequestor``.
            timeout: Passed through to ``ServiceRequestor``.
            consumer: Forwarded as ``consumer_name``.
            subscriber: Forwarded as ``subscription``.
        """
        super().__init__(
            pulsar_client=pulsar_client,
            request_queue=request_queue,
            response_queue=response_queue,
            request_schema=StructuredDataDiagnosisRequest,
            response_schema=StructuredDataDiagnosisResponse,
            subscription=subscriber,
            consumer_name=consumer,
            timeout=timeout,
        )

        translator_key = "structured-diag"
        self.request_translator = TranslatorRegistry.get_request_translator(
            translator_key
        )
        self.response_translator = TranslatorRegistry.get_response_translator(
            translator_key
        )

    def to_request(self, body):
        """Return ``body`` converted by the request translator's ``to_pulsar``."""
        outgoing = self.request_translator.to_pulsar(body)
        return outgoing

    def from_response(self, message):
        """Return ``message`` converted by ``from_response_with_completion``.

        NOTE(review): presumably this yields the decoded response together
        with a completion indicator - confirm against the translator.
        """
        decoded = self.response_translator.from_response_with_completion(message)
        return decoded
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
from ... schema import StructuredQueryRequest, StructuredQueryResponse
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class StructuredQueryRequestor(ServiceRequestor):
    """Service requestor bound to the "structured-query" translators.

    Requests use the ``StructuredQueryRequest`` schema and responses use
    ``StructuredQueryResponse``; conversion is delegated to the translators
    registered under the "structured-query" key.
    """

    def __init__(
            self, pulsar_client, request_queue, response_queue, timeout,
            consumer, subscriber,
    ):
        """Configure the base requestor and look up the translators.

        Args:
            pulsar_client: Passed through to ``ServiceRequestor``.
            request_queue: Passed through to ``ServiceRequestor``.
            response_queue: Passed through to ``ServiceRequestor``.
            timeout: Passed through to ``ServiceRequestor``.
            consumer: Forwarded as ``consumer_name``.
            subscriber: Forwarded as ``subscription``.
        """
        super().__init__(
            pulsar_client=pulsar_client,
            request_queue=request_queue,
            response_queue=response_queue,
            request_schema=StructuredQueryRequest,
            response_schema=StructuredQueryResponse,
            subscription=subscriber,
            consumer_name=consumer,
            timeout=timeout,
        )

        translator_key = "structured-query"
        self.request_translator = TranslatorRegistry.get_request_translator(
            translator_key
        )
        self.response_translator = TranslatorRegistry.get_response_translator(
            translator_key
        )

    def to_request(self, body):
        """Return ``body`` converted by the request translator's ``to_pulsar``."""
        outgoing = self.request_translator.to_pulsar(body)
        return outgoing

    def from_response(self, message):
        """Return ``message`` converted by ``from_response_with_completion``.

        NOTE(review): presumably this yields the decoded response together
        with a completion indicator - confirm against the translator.
        """
        decoded = self.response_translator.from_response_with_completion(message)
        return decoded
|
||||
|
|
@ -26,46 +26,66 @@ class TriplesExport:
|
|||
self.subscriber = subscriber
|
||||
|
||||
async def destroy(self):
|
||||
# Step 1: Signal stop to prevent new messages
|
||||
self.running.stop()
|
||||
await self.ws.close()
|
||||
|
||||
# Step 2: Wait briefly for in-flight messages
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Step 3: Unsubscribe and stop subscriber (triggers queue drain)
|
||||
if hasattr(self, 'subs'):
|
||||
await self.subs.unsubscribe_all(self.id)
|
||||
await self.subs.stop()
|
||||
|
||||
# Step 4: Close websocket last
|
||||
if self.ws and not self.ws.closed:
|
||||
await self.ws.close()
|
||||
|
||||
async def receive(self, msg):
|
||||
# Ignore incoming info from websocket
|
||||
pass
|
||||
|
||||
async def run(self):
|
||||
|
||||
subs = Subscriber(
|
||||
client = self.pulsar_client, topic = self.queue,
|
||||
consumer_name = self.consumer, subscription = self.subscriber,
|
||||
schema = Triples
|
||||
"""Enhanced run with better error handling"""
|
||||
self.subs = Subscriber(
|
||||
client = self.pulsar_client,
|
||||
topic = self.queue,
|
||||
consumer_name = self.consumer,
|
||||
subscription = self.subscriber,
|
||||
schema = Triples,
|
||||
backpressure_strategy = "block" # Configurable
|
||||
)
|
||||
|
||||
await subs.start()
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
q = await subs.subscribe_all(id)
|
||||
|
||||
|
||||
await self.subs.start()
|
||||
|
||||
self.id = str(uuid.uuid4())
|
||||
q = await self.subs.subscribe_all(self.id)
|
||||
|
||||
consecutive_errors = 0
|
||||
max_consecutive_errors = 5
|
||||
|
||||
while self.running.get():
|
||||
try:
|
||||
|
||||
resp = await asyncio.wait_for(q.get(), timeout=0.5)
|
||||
await self.ws.send_json(serialize_triples(resp))
|
||||
|
||||
except TimeoutError:
|
||||
consecutive_errors = 0 # Reset on success
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Exception: {str(e)}", exc_info=True)
|
||||
break
|
||||
|
||||
await subs.unsubscribe_all(id)
|
||||
|
||||
await subs.stop()
|
||||
|
||||
await self.ws.close()
|
||||
self.running.stop()
|
||||
logger.error(f"Exception sending to websocket: {str(e)}")
|
||||
consecutive_errors += 1
|
||||
|
||||
if consecutive_errors >= max_consecutive_errors:
|
||||
logger.error("Too many consecutive errors, shutting down")
|
||||
break
|
||||
|
||||
# Brief pause before retry
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Graceful cleanup handled in destroy()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
|
||||
import asyncio
|
||||
import uuid
|
||||
import logging
|
||||
from aiohttp import WSMsgType
|
||||
|
||||
from ... schema import Metadata
|
||||
|
|
@ -9,6 +10,9 @@ from ... base import Publisher
|
|||
|
||||
from . serialize import to_subgraph
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TriplesImport:
|
||||
|
||||
def __init__(
|
||||
|
|
@ -26,13 +30,17 @@ class TriplesImport:
|
|||
await self.publisher.start()
|
||||
|
||||
async def destroy(self):
|
||||
# Step 1: Stop accepting new messages
|
||||
self.running.stop()
|
||||
|
||||
# Step 2: Wait for publisher to drain its queue
|
||||
logger.info("Draining publisher queue...")
|
||||
await self.publisher.stop()
|
||||
|
||||
# Step 3: Close websocket only after queue is drained
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
await self.publisher.stop()
|
||||
|
||||
async def receive(self, msg):
|
||||
|
||||
data = msg.json()
|
||||
|
|
|
|||
|
|
@ -25,24 +25,43 @@ class SocketEndpoint:
|
|||
await dispatcher.run()
|
||||
|
||||
async def listener(self, ws, dispatcher, running):
|
||||
|
||||
async for msg in ws:
|
||||
|
||||
# On error, finish
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
await dispatcher.receive(msg)
|
||||
continue
|
||||
elif msg.type == WSMsgType.BINARY:
|
||||
await dispatcher.receive(msg)
|
||||
continue
|
||||
"""Enhanced listener with graceful shutdown"""
|
||||
try:
|
||||
async for msg in ws:
|
||||
# On error, finish
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
await dispatcher.receive(msg)
|
||||
continue
|
||||
elif msg.type == WSMsgType.BINARY:
|
||||
await dispatcher.receive(msg)
|
||||
continue
|
||||
else:
|
||||
# Graceful shutdown on close
|
||||
logger.info("Websocket closing, initiating graceful shutdown")
|
||||
running.stop()
|
||||
|
||||
# Allow time for dispatcher cleanup
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
# Close websocket if not already closed
|
||||
if not ws.closed:
|
||||
await ws.close()
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
running.stop()
|
||||
await ws.close()
|
||||
# This executes when the async for loop completes normally (no break)
|
||||
logger.debug("Websocket iteration completed, performing cleanup")
|
||||
running.stop()
|
||||
if not ws.closed:
|
||||
await ws.close()
|
||||
except Exception:
|
||||
# Handle exceptions and cleanup
|
||||
running.stop()
|
||||
if not ws.closed:
|
||||
await ws.close()
|
||||
raise
|
||||
|
||||
async def handle(self, request):
|
||||
|
||||
"""Enhanced handler with better cleanup"""
|
||||
try:
|
||||
token = request.query['token']
|
||||
except:
|
||||
|
|
@ -55,7 +74,9 @@ class SocketEndpoint:
|
|||
ws = web.WebSocketResponse(max_msg_size=52428800)
|
||||
|
||||
await ws.prepare(request)
|
||||
|
||||
|
||||
dispatcher = None
|
||||
|
||||
try:
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
|
|
@ -80,9 +101,6 @@ class SocketEndpoint:
|
|||
|
||||
logger.debug("Task group closed")
|
||||
|
||||
# Finally?
|
||||
await dispatcher.destroy()
|
||||
|
||||
except ExceptionGroup as e:
|
||||
|
||||
logger.error("Exception group occurred:", exc_info=True)
|
||||
|
|
@ -90,11 +108,34 @@ class SocketEndpoint:
|
|||
for se in e.exceptions:
|
||||
logger.error(f" Exception type: {type(se)}")
|
||||
logger.error(f" Exception: {se}")
|
||||
|
||||
# Attempt graceful dispatcher shutdown
|
||||
if dispatcher and hasattr(dispatcher, 'destroy'):
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
dispatcher.destroy(),
|
||||
timeout=5.0
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Dispatcher shutdown timed out")
|
||||
except Exception as de:
|
||||
logger.error(f"Error during dispatcher cleanup: {de}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Socket exception: {e}", exc_info=True)
|
||||
|
||||
await ws.close()
|
||||
|
||||
|
||||
finally:
|
||||
# Ensure dispatcher cleanup
|
||||
if dispatcher and hasattr(dispatcher, 'destroy'):
|
||||
try:
|
||||
await dispatcher.destroy()
|
||||
except Exception as de:
|
||||
logger.error(f"Error in final dispatcher cleanup: {de}")
|
||||
|
||||
# Ensure websocket is closed
|
||||
if ws and not ws.closed:
|
||||
await ws.close()
|
||||
|
||||
return ws
|
||||
|
||||
async def start(self):
|
||||
|
|
|
|||
315
trustgraph-flow/trustgraph/librarian/collection_manager.py
Normal file
315
trustgraph-flow/trustgraph/librarian/collection_manager.py
Normal file
|
|
@ -0,0 +1,315 @@
|
|||
"""
|
||||
Collection management for the librarian
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from .. schema import CollectionManagementRequest, CollectionManagementResponse, Error
|
||||
from .. schema import CollectionMetadata
|
||||
from .. schema import StorageManagementRequest, StorageManagementResponse
|
||||
from .. exceptions import RequestError
|
||||
from .. tables.library import LibraryTableStore
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CollectionManager:
    """Manages collection metadata and coordinates collection operations across storage types.

    Metadata is read and written through a ``LibraryTableStore``. Collection
    deletion is fanned out to the configured storage-management producers
    and tracked in ``pending_deletions`` until the responses delivered to
    ``on_storage_response`` have been counted.
    """

    # Seconds delete_collection() waits for the storage services to respond.
    STORAGE_DELETION_TIMEOUT = 30.0

    def __init__(
        self,
        cassandra_host,
        cassandra_username,
        cassandra_password,
        keyspace,
        vector_storage_producer=None,
        object_storage_producer=None,
        triples_storage_producer=None,
        storage_response_consumer=None
    ):
        """
        Initialize the CollectionManager

        Args:
            cassandra_host: Cassandra host(s)
            cassandra_username: Cassandra username
            cassandra_password: Cassandra password
            keyspace: Cassandra keyspace for library data
            vector_storage_producer: Producer for vector storage management
            object_storage_producer: Producer for object storage management
            triples_storage_producer: Producer for triples storage management
            storage_response_consumer: Consumer for storage management responses
        """
        self.table_store = LibraryTableStore(
            cassandra_host, cassandra_username, cassandra_password, keyspace
        )

        # Storage management producers
        self.vector_storage_producer = vector_storage_producer
        self.object_storage_producer = object_storage_producer
        self.triples_storage_producer = triples_storage_producer
        self.storage_response_consumer = storage_response_consumer

        # Track pending deletion operations, keyed by (user, collection)
        self.pending_deletions = {}

        logger.info("Collection manager initialized")

    @staticmethod
    def _collection_metadata(record, created_at=None):
        """Build a CollectionMetadata from a table-store record.

        Args:
            record: Mapping with the keys user, collection, name,
                description, tags, updated_at and (unless ``created_at``
                is supplied) created_at.
            created_at: Value to use instead of ``record["created_at"]``;
                when given, the record's own key is not read.
        """
        return CollectionMetadata(
            user=record["user"],
            collection=record["collection"],
            name=record["name"],
            description=record["description"],
            tags=record["tags"],
            created_at=(
                record["created_at"] if created_at is None else created_at
            ),
            updated_at=record["updated_at"]
        )

    async def ensure_collection_exists(self, user: str, collection: str):
        """
        Ensure a collection exists, creating it if necessary (lazy creation)

        Any exception is logged and swallowed so that the calling operation
        is not failed by a collection-creation problem.

        Args:
            user: User ID
            collection: Collection ID
        """
        try:
            # Check if collection already exists
            existing = await self.table_store.get_collection(user, collection)
            if existing:
                logger.debug(f"Collection {user}/{collection} already exists")
                return

            # Create new collection with default metadata
            logger.info(f"Creating new collection {user}/{collection}")
            await self.table_store.create_collection(
                user=user,
                collection=collection,
                name=collection,  # Default name to collection ID
                description="",
                tags=set()
            )

        except Exception as e:
            logger.error(f"Error ensuring collection exists: {e}")
            # Don't fail the operation if collection creation fails
            # This maintains backward compatibility

    async def list_collections(self, request: CollectionManagementRequest) -> CollectionManagementResponse:
        """
        List collections for a user with optional tag filtering

        Args:
            request: Collection management request

        Returns:
            CollectionManagementResponse with list of collections

        Raises:
            RequestError: If the lookup or the response construction fails
        """
        try:
            tag_filter = list(request.tag_filter) if request.tag_filter else None
            collections = await self.table_store.list_collections(request.user, tag_filter)

            collection_metadata = [
                self._collection_metadata(coll) for coll in collections
            ]

            return CollectionManagementResponse(
                error=None,
                collections=collection_metadata,
                timestamp=datetime.now().isoformat()
            )

        except Exception as e:
            logger.error(f"Error listing collections: {e}")
            raise RequestError(f"Failed to list collections: {str(e)}") from e

    async def update_collection(self, request: CollectionManagementRequest) -> CollectionManagementResponse:
        """
        Update collection metadata (creates if doesn't exist)

        Args:
            request: Collection management request

        Returns:
            CollectionManagementResponse with updated collection

        Raises:
            RequestError: If any table-store call or the response
                construction fails
        """
        try:
            # Check if collection exists, create if it doesn't
            existing = await self.table_store.get_collection(request.user, request.collection)
            if not existing:
                # Create new collection with provided metadata
                logger.info(f"Creating new collection {request.user}/{request.collection}")

                name = request.name if request.name else request.collection
                description = request.description if request.description else ""
                tags = set(request.tags) if request.tags else set()

                await self.table_store.create_collection(
                    user=request.user,
                    collection=request.collection,
                    name=name,
                    description=description,
                    tags=tags
                )

                # Get the newly created collection for response
                created_collection = await self.table_store.get_collection(request.user, request.collection)

                collection_metadata = self._collection_metadata(created_collection)
            else:
                # Collection exists, update it; empty fields are passed as
                # None rather than as empty values
                name = request.name if request.name else None
                description = request.description if request.description else None
                tags = list(request.tags) if request.tags else None

                updated_collection = await self.table_store.update_collection(
                    request.user, request.collection, name, description, tags
                )

                collection_metadata = self._collection_metadata(
                    updated_collection,
                    created_at="",  # Not returned by update
                )

            return CollectionManagementResponse(
                error=None,
                collections=[collection_metadata],
                timestamp=datetime.now().isoformat()
            )

        except Exception as e:
            logger.error(f"Error updating collection: {e}")
            raise RequestError(f"Failed to update collection: {str(e)}") from e

    async def delete_collection(self, request: CollectionManagementRequest) -> CollectionManagementResponse:
        """
        Delete collection with cascade to all storage types

        A delete-collection request is sent to every configured storage
        producer, and the method waits up to ``STORAGE_DELETION_TIMEOUT``
        seconds for one response per request sent. Collection metadata is
        removed only when no response reported an error and the wait did
        not time out.

        Args:
            request: Collection management request

        Returns:
            CollectionManagementResponse indicating success or failure

        Raises:
            RequestError: If an unexpected exception occurs
        """
        # Bound before the try block so the cleanup below can always refer
        # to it, even when reading the request itself fails.
        deletion_key = None

        try:
            deletion_key = (request.user, request.collection)

            logger.info(f"Starting cascade deletion for {request.user}/{request.collection}")

            # Only producers that are actually configured receive a request,
            # so only those can produce a response to wait for.
            producers = [
                producer
                for producer in (
                    self.vector_storage_producer,
                    self.object_storage_producer,
                    self.triples_storage_producer,
                )
                if producer
            ]

            # Track this deletion request
            deletion_info = {
                "responses_pending": len(producers),
                "responses_received": [],
                "all_successful": True,
                "error_messages": [],
                "deletion_complete": asyncio.Event()
            }

            # With nothing to wait for, the deletion is complete at once;
            # otherwise the wait below could only ever end in a timeout.
            if not producers:
                deletion_info["deletion_complete"].set()

            self.pending_deletions[deletion_key] = deletion_info

            # Create storage management request
            storage_request = StorageManagementRequest(
                operation="delete-collection",
                user=request.user,
                collection=request.collection
            )

            # Send deletion requests to all configured storage types
            for producer in producers:
                await producer.send(storage_request)

            # Wait for all storage deletions to complete (with timeout)
            try:
                await asyncio.wait_for(
                    deletion_info["deletion_complete"].wait(),
                    timeout=self.STORAGE_DELETION_TIMEOUT
                )
            except asyncio.TimeoutError:
                logger.error(f"Timeout waiting for storage deletion responses for {deletion_key}")
                deletion_info["all_successful"] = False
                deletion_info["error_messages"].append("Timeout waiting for storage deletion")

            # Check if all deletions succeeded
            if not deletion_info["all_successful"]:
                error_msg = f"Storage deletion failed: {'; '.join(deletion_info['error_messages'])}"
                logger.error(error_msg)

                return CollectionManagementResponse(
                    error=Error(
                        type="storage_deletion_error",
                        message=error_msg
                    ),
                    timestamp=datetime.now().isoformat()
                )

            # All storage deletions succeeded, now delete metadata
            logger.info(f"Storage deletions complete, removing metadata for {deletion_key}")
            await self.table_store.delete_collection(request.user, request.collection)

            return CollectionManagementResponse(
                error=None,
                timestamp=datetime.now().isoformat()
            )

        except Exception as e:
            logger.error(f"Error deleting collection: {e}")
            raise RequestError(f"Failed to delete collection: {str(e)}") from e

        finally:
            # Clean up tracking on every exit path, including cancellation
            self.pending_deletions.pop(deletion_key, None)

    async def on_storage_response(self, response: StorageManagementResponse):
        """
        Handle storage management responses for deletion tracking

        The response is credited to the first tracked deletion that still
        has responses outstanding; it is not matched against the user or
        collection it belongs to.

        Args:
            response: Storage management response
        """
        logger.debug(f"Received storage response: error={response.error}")

        # Find matching deletion by checking all pending deletions
        # Note: This is simplified correlation - in production we'd want better correlation
        for deletion_key, info in list(self.pending_deletions.items()):
            if info["responses_pending"] > 0:
                # Record this response
                info["responses_received"].append(response)
                info["responses_pending"] -= 1

                # Check if this response indicates failure
                if response.error and response.error.message:
                    info["all_successful"] = False
                    info["error_messages"].append(response.error.message)
                    logger.warning(f"Storage deletion failed for {deletion_key}: {response.error.message}")
                else:
                    logger.debug(f"Storage deletion succeeded for {deletion_key}")

                # If all responses received, signal completion
                if info["responses_pending"] == 0:
                    logger.info(f"All storage responses received for {deletion_key}")
                    info["deletion_complete"].set()

                break  # Only process for first matching deletion
|
||||
|
|
@ -16,7 +16,7 @@ class Librarian:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
cassandra_host, cassandra_user, cassandra_password,
|
||||
cassandra_host, cassandra_username, cassandra_password,
|
||||
minio_host, minio_access_key, minio_secret_key,
|
||||
bucket_name, keyspace, load_document,
|
||||
):
|
||||
|
|
@ -26,7 +26,7 @@ class Librarian:
|
|||
)
|
||||
|
||||
self.table_store = LibraryTableStore(
|
||||
cassandra_host, cassandra_user, cassandra_password, keyspace
|
||||
cassandra_host, cassandra_username, cassandra_password, keyspace
|
||||
)
|
||||
|
||||
self.load_document = load_document
|
||||
|
|
|
|||
|
|
@ -8,12 +8,19 @@ import asyncio
|
|||
import base64
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from .. base import AsyncProcessor, Consumer, Producer, Publisher, Subscriber
|
||||
from .. base import ConsumerMetrics, ProducerMetrics
|
||||
from .. base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||
|
||||
from .. schema import LibrarianRequest, LibrarianResponse, Error
|
||||
from .. schema import librarian_request_queue, librarian_response_queue
|
||||
from .. schema import CollectionManagementRequest, CollectionManagementResponse
|
||||
from .. schema import collection_request_queue, collection_response_queue
|
||||
from .. schema import StorageManagementRequest, StorageManagementResponse
|
||||
from .. schema import vector_storage_management_topic, object_storage_management_topic
|
||||
from .. schema import triples_storage_management_topic, storage_management_response_topic
|
||||
|
||||
from .. schema import Document, Metadata
|
||||
from .. schema import TextDocument, Metadata
|
||||
|
|
@ -21,6 +28,7 @@ from .. schema import TextDocument, Metadata
|
|||
from .. exceptions import RequestError
|
||||
|
||||
from . librarian import Librarian
|
||||
from . collection_manager import CollectionManager
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -29,6 +37,8 @@ default_ident = "librarian"
|
|||
|
||||
default_librarian_request_queue = librarian_request_queue
|
||||
default_librarian_response_queue = librarian_response_queue
|
||||
default_collection_request_queue = collection_request_queue
|
||||
default_collection_response_queue = collection_response_queue
|
||||
|
||||
default_minio_host = "minio:9000"
|
||||
default_minio_access_key = "minioadmin"
|
||||
|
|
@ -56,6 +66,14 @@ class Processor(AsyncProcessor):
|
|||
"librarian_response_queue", default_librarian_response_queue
|
||||
)
|
||||
|
||||
collection_request_queue = params.get(
|
||||
"collection_request_queue", default_collection_request_queue
|
||||
)
|
||||
|
||||
collection_response_queue = params.get(
|
||||
"collection_response_queue", default_collection_response_queue
|
||||
)
|
||||
|
||||
minio_host = params.get("minio_host", default_minio_host)
|
||||
minio_access_key = params.get(
|
||||
"minio_access_key",
|
||||
|
|
@ -66,18 +84,33 @@ class Processor(AsyncProcessor):
|
|||
default_minio_secret_key
|
||||
)
|
||||
|
||||
cassandra_host = params.get("cassandra_host", default_cassandra_host)
|
||||
cassandra_user = params.get("cassandra_user")
|
||||
cassandra_host = params.get("cassandra_host")
|
||||
cassandra_username = params.get("cassandra_username")
|
||||
cassandra_password = params.get("cassandra_password")
|
||||
|
||||
# Resolve configuration with environment variable fallback
|
||||
hosts, username, password = resolve_cassandra_config(
|
||||
host=cassandra_host,
|
||||
username=cassandra_username,
|
||||
password=cassandra_password
|
||||
)
|
||||
|
||||
# Store resolved configuration
|
||||
self.cassandra_host = hosts
|
||||
self.cassandra_username = username
|
||||
self.cassandra_password = password
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"librarian_request_queue": librarian_request_queue,
|
||||
"librarian_response_queue": librarian_response_queue,
|
||||
"collection_request_queue": collection_request_queue,
|
||||
"collection_response_queue": collection_response_queue,
|
||||
"minio_host": minio_host,
|
||||
"minio_access_key": minio_access_key,
|
||||
"cassandra_host": cassandra_host,
|
||||
"cassandra_user": cassandra_user,
|
||||
"cassandra_host": self.cassandra_host,
|
||||
"cassandra_username": self.cassandra_username,
|
||||
"cassandra_password": self.cassandra_password,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -89,6 +122,18 @@ class Processor(AsyncProcessor):
|
|||
processor = self.id, flow = None, name = "librarian-response"
|
||||
)
|
||||
|
||||
collection_request_metrics = ConsumerMetrics(
|
||||
processor = self.id, flow = None, name = "collection-request"
|
||||
)
|
||||
|
||||
collection_response_metrics = ProducerMetrics(
|
||||
processor = self.id, flow = None, name = "collection-response"
|
||||
)
|
||||
|
||||
storage_response_metrics = ConsumerMetrics(
|
||||
processor = self.id, flow = None, name = "storage-response"
|
||||
)
|
||||
|
||||
self.librarian_request_consumer = Consumer(
|
||||
taskgroup = self.taskgroup,
|
||||
client = self.pulsar_client,
|
||||
|
|
@ -107,10 +152,58 @@ class Processor(AsyncProcessor):
|
|||
metrics = librarian_response_metrics,
|
||||
)
|
||||
|
||||
self.collection_request_consumer = Consumer(
|
||||
taskgroup = self.taskgroup,
|
||||
client = self.pulsar_client,
|
||||
flow = None,
|
||||
topic = collection_request_queue,
|
||||
subscriber = id,
|
||||
schema = CollectionManagementRequest,
|
||||
handler = self.on_collection_request,
|
||||
metrics = collection_request_metrics,
|
||||
)
|
||||
|
||||
self.collection_response_producer = Producer(
|
||||
client = self.pulsar_client,
|
||||
topic = collection_response_queue,
|
||||
schema = CollectionManagementResponse,
|
||||
metrics = collection_response_metrics,
|
||||
)
|
||||
|
||||
# Storage management producers for collection deletion
|
||||
self.vector_storage_producer = Producer(
|
||||
client = self.pulsar_client,
|
||||
topic = vector_storage_management_topic,
|
||||
schema = StorageManagementRequest,
|
||||
)
|
||||
|
||||
self.object_storage_producer = Producer(
|
||||
client = self.pulsar_client,
|
||||
topic = object_storage_management_topic,
|
||||
schema = StorageManagementRequest,
|
||||
)
|
||||
|
||||
self.triples_storage_producer = Producer(
|
||||
client = self.pulsar_client,
|
||||
topic = triples_storage_management_topic,
|
||||
schema = StorageManagementRequest,
|
||||
)
|
||||
|
||||
self.storage_response_consumer = Consumer(
|
||||
taskgroup = self.taskgroup,
|
||||
client = self.pulsar_client,
|
||||
flow = None,
|
||||
topic = storage_management_response_topic,
|
||||
subscriber = id,
|
||||
schema = StorageManagementResponse,
|
||||
handler = self.on_storage_response,
|
||||
metrics = storage_response_metrics,
|
||||
)
|
||||
|
||||
self.librarian = Librarian(
|
||||
cassandra_host = cassandra_host.split(","),
|
||||
cassandra_user = cassandra_user,
|
||||
cassandra_password = cassandra_password,
|
||||
cassandra_host = self.cassandra_host,
|
||||
cassandra_username = self.cassandra_username,
|
||||
cassandra_password = self.cassandra_password,
|
||||
minio_host = minio_host,
|
||||
minio_access_key = minio_access_key,
|
||||
minio_secret_key = minio_secret_key,
|
||||
|
|
@ -119,6 +212,17 @@ class Processor(AsyncProcessor):
|
|||
load_document = self.load_document,
|
||||
)
|
||||
|
||||
self.collection_manager = CollectionManager(
|
||||
cassandra_host = self.cassandra_host,
|
||||
cassandra_username = self.cassandra_username,
|
||||
cassandra_password = self.cassandra_password,
|
||||
keyspace = keyspace,
|
||||
vector_storage_producer = self.vector_storage_producer,
|
||||
object_storage_producer = self.object_storage_producer,
|
||||
triples_storage_producer = self.triples_storage_producer,
|
||||
storage_response_consumer = self.storage_response_consumer,
|
||||
)
|
||||
|
||||
self.register_config_handler(self.on_librarian_config)
|
||||
|
||||
self.flows = {}
|
||||
|
|
@ -130,6 +234,12 @@ class Processor(AsyncProcessor):
|
|||
await super(Processor, self).start()
|
||||
await self.librarian_request_consumer.start()
|
||||
await self.librarian_response_producer.start()
|
||||
await self.collection_request_consumer.start()
|
||||
await self.collection_response_producer.start()
|
||||
await self.vector_storage_producer.start()
|
||||
await self.object_storage_producer.start()
|
||||
await self.triples_storage_producer.start()
|
||||
await self.storage_response_consumer.start()
|
||||
|
||||
async def on_librarian_config(self, config, version):
|
||||
|
||||
|
|
@ -209,6 +319,19 @@ class Processor(AsyncProcessor):
|
|||
|
||||
logger.debug("Document submitted")
|
||||
|
||||
async def add_processing_with_collection(self, request):
    """
    Wrapper for add_processing that ensures collection exists

    When the request carries a truthy ``processing_metadata`` attribute, its
    ``user`` and ``collection`` are passed to the collection manager before
    the request is forwarded to the librarian.  Returns whatever the
    librarian's ``add_processing`` returns.
    """
    metadata = getattr(request, 'processing_metadata', None)

    if metadata:
        # Ensure collection exists when processing is added
        await self.collection_manager.ensure_collection_exists(
            metadata.user, metadata.collection
        )

    # Delegate to the original add_processing implementation
    return await self.librarian.add_processing(request)
|
||||
|
||||
async def process_request(self, v):
|
||||
|
||||
if v.operation is None:
|
||||
|
|
@ -222,7 +345,7 @@ class Processor(AsyncProcessor):
|
|||
"update-document": self.librarian.update_document,
|
||||
"get-document-metadata": self.librarian.get_document_metadata,
|
||||
"get-document-content": self.librarian.get_document_content,
|
||||
"add-processing": self.librarian.add_processing,
|
||||
"add-processing": self.add_processing_with_collection,
|
||||
"remove-processing": self.librarian.remove_processing,
|
||||
"list-documents": self.librarian.list_documents,
|
||||
"list-processing": self.librarian.list_processing,
|
||||
|
|
@ -282,6 +405,73 @@ class Processor(AsyncProcessor):
|
|||
|
||||
logger.debug("Librarian input processing complete")
|
||||
|
||||
async def process_collection_request(self, v):
    """
    Process collection management requests

    Dispatches on ``v.operation`` to the matching collection manager method
    and returns its awaited result.  Raises RequestError when the operation
    is None or is not one of the supported operations.
    """
    operation = v.operation

    if operation is None:
        raise RequestError("Null operation")

    logger.debug(f"Collection request: {operation}")

    # Operation name -> collection manager coroutine
    handlers = {
        "list-collections": self.collection_manager.list_collections,
        "update-collection": self.collection_manager.update_collection,
        "delete-collection": self.collection_manager.delete_collection,
    }

    handler = handlers.get(operation)
    if handler is None:
        raise RequestError(f"Invalid collection operation: {operation}")

    return await handler(v)
|
||||
|
||||
async def on_collection_request(self, msg, consumer, flow):
    """
    Handle collection management request messages

    The message value is run through process_collection_request and the
    result is sent on the collection response producer, tagged with the
    request's "id" property ("unknown" when the property is absent).  An
    exception raised while processing or sending is reported to the caller
    as an error response rather than propagated.
    """
    v = msg.value()
    request_id = msg.properties().get("id", "unknown")

    logger.info(f"Handling collection request {request_id}...")

    try:
        reply = await self.process_collection_request(v)
        await self.collection_response_producer.send(
            reply, properties={"id": request_id}
        )
    except Exception as e:
        # RequestError marks a problem with the request itself; anything
        # else is reported as unexpected.
        if isinstance(e, RequestError):
            error_type = "request-error"
        else:
            error_type = "unexpected-error"

        failure = CollectionManagementResponse(
            error=Error(
                type=error_type,
                message=str(e),
            ),
            timestamp=datetime.now().isoformat()
        )
        await self.collection_response_producer.send(
            failure, properties={"id": request_id}
        )

    logger.debug("Collection request processing complete")
|
||||
|
||||
async def on_storage_response(self, msg, consumer, flow):
    """
    Handle storage management response messages

    Unwraps the message value and hands it to the collection manager's
    on_storage_response coroutine.  ``consumer`` and ``flow`` are unused.
    """
    payload = msg.value()
    logger.debug("Received storage management response")
    await self.collection_manager.on_storage_response(payload)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
|
|
@ -299,6 +489,18 @@ class Processor(AsyncProcessor):
|
|||
help=f'Config response queue {default_librarian_response_queue}',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--collection-request-queue',
|
||||
default=default_collection_request_queue,
|
||||
help=f'Collection request queue (default: {default_collection_request_queue})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--collection-response-queue',
|
||||
default=default_collection_response_queue,
|
||||
help=f'Collection response queue (default: {default_collection_response_queue})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--minio-host',
|
||||
default=default_minio_host,
|
||||
|
|
@ -319,23 +521,7 @@ class Processor(AsyncProcessor):
|
|||
f'(default: {default_minio_access_key})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--cassandra-host',
|
||||
default="cassandra",
|
||||
help=f'Graph host (default: cassandra)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--cassandra-user',
|
||||
default=None,
|
||||
help=f'Cassandra user'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--cassandra-password',
|
||||
default=None,
|
||||
help=f'Cassandra password'
|
||||
)
|
||||
add_cassandra_args(parser)
|
||||
|
||||
def run():
|
||||
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ class Processor(FlowProcessor):
|
|||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"config-type": self.config_key,
|
||||
"config_type": self.config_key,
|
||||
"concurrency": concurrency,
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -43,7 +43,12 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
|
||||
for vec in msg.vectors:
|
||||
|
||||
resp = self.vecstore.search(vec, limit=msg.limit)
|
||||
resp = self.vecstore.search(
|
||||
vec,
|
||||
msg.user,
|
||||
msg.collection,
|
||||
limit=msg.limit
|
||||
)
|
||||
|
||||
for r in resp:
|
||||
chunk = r["entity"]["doc"]
|
||||
|
|
|
|||
|
|
@ -47,6 +47,39 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
}
|
||||
)
|
||||
|
||||
self.last_index_name = None
|
||||
|
||||
def ensure_index_exists(self, index_name, dim):
    """Ensure index exists, create if it doesn't

    The most recently checked index name is remembered in
    ``self.last_index_name`` so repeated calls for the same index return
    without contacting Pinecone.  A missing index is created with ``dim``
    dimensions and cosine metric, then polled once per second (up to 1000
    polls) until it reports ready.

    Raises RuntimeError if the index is still not ready after polling;
    any creation failure is logged and re-raised.
    """
    import time

    # Same index as the previous call: nothing to check.
    if index_name == self.last_index_name:
        return

    if not self.pinecone.has_index(index_name):
        try:
            self.pinecone.create_index(
                name=index_name,
                dimension=dim,
                metric="cosine",
                spec=ServerlessSpec(
                    cloud="aws",
                    region="us-east-1",
                )
            )
            logger.info(f"Created index: {index_name}")

            # Wait for index to be ready
            for _ in range(1000):
                if self.pinecone.describe_index(index_name).status["ready"]:
                    break
                time.sleep(1)

            # One last check after the polling loop ends
            if not self.pinecone.describe_index(index_name).status["ready"]:
                raise RuntimeError("Gave up waiting for index creation")

        except Exception as e:
            logger.error(f"Pinecone index creation failed: {e}")
            raise

    self.last_index_name = index_name
|
||||
|
||||
async def query_document_embeddings(self, msg):
|
||||
|
||||
try:
|
||||
|
|
@ -62,9 +95,11 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
dim = len(vec)
|
||||
|
||||
index_name = (
|
||||
"d-" + msg.user + "-" + msg.collection + "-" + str(dim)
|
||||
"d-" + msg.user + "-" + msg.collection
|
||||
)
|
||||
|
||||
self.ensure_index_exists(index_name, dim)
|
||||
|
||||
index = self.pinecone.Index(index_name)
|
||||
|
||||
results = index.query(
|
||||
|
|
|
|||
|
|
@ -38,6 +38,24 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
)
|
||||
|
||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
||||
self.last_collection = None
|
||||
|
||||
def ensure_collection_exists(self, collection, dim):
    """Ensure collection exists, create if it doesn't

    The last collection name checked is kept in ``self.last_collection`` so
    consecutive calls for the same collection skip the existence check.  A
    missing collection is created with vectors of size ``dim`` and cosine
    distance; a creation failure is logged and re-raised.
    """
    # Same collection as the previous call: nothing to check.
    if collection == self.last_collection:
        return

    if not self.qdrant.collection_exists(collection):
        try:
            vectors_config = VectorParams(
                size=dim, distance=Distance.COSINE
            )
            self.qdrant.create_collection(
                collection_name=collection,
                vectors_config=vectors_config,
            )
            logger.info(f"Created collection: {collection}")
        except Exception as e:
            logger.error(f"Qdrant collection creation failed: {e}")
            raise

    self.last_collection = collection
|
||||
|
||||
async def query_document_embeddings(self, msg):
|
||||
|
||||
|
|
@ -49,10 +67,11 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
|
||||
dim = len(vec)
|
||||
collection = (
|
||||
"d_" + msg.user + "_" + msg.collection + "_" +
|
||||
str(dim)
|
||||
"d_" + msg.user + "_" + msg.collection
|
||||
)
|
||||
|
||||
self.ensure_collection_exists(collection, dim)
|
||||
|
||||
search_result = self.qdrant.query_points(
|
||||
collection_name=collection,
|
||||
query=vec,
|
||||
|
|
|
|||
|
|
@ -50,7 +50,12 @@ class Processor(GraphEmbeddingsQueryService):
|
|||
|
||||
for vec in msg.vectors:
|
||||
|
||||
resp = self.vecstore.search(vec, limit=msg.limit * 2)
|
||||
resp = self.vecstore.search(
|
||||
vec,
|
||||
msg.user,
|
||||
msg.collection,
|
||||
limit=msg.limit * 2
|
||||
)
|
||||
|
||||
for r in resp:
|
||||
ent = r["entity"]["entity"]
|
||||
|
|
|
|||
|
|
@ -49,6 +49,39 @@ class Processor(GraphEmbeddingsQueryService):
|
|||
}
|
||||
)
|
||||
|
||||
self.last_index_name = None
|
||||
|
||||
def ensure_index_exists(self, index_name, dim):
    """Ensure index exists, create if it doesn't

    ``self.last_index_name`` caches the most recently checked index so that
    repeated calls for the same name return immediately.  A missing index
    is created with ``dim`` dimensions and cosine metric, then polled once
    per second (up to 1000 polls) until it reports ready.

    Raises RuntimeError if the index is still not ready after polling;
    any creation failure is logged and re-raised.
    """
    import time

    # Same index as the previous call: nothing to check.
    if index_name == self.last_index_name:
        return

    if not self.pinecone.has_index(index_name):
        try:
            self.pinecone.create_index(
                name=index_name,
                dimension=dim,
                metric="cosine",
                spec=ServerlessSpec(
                    cloud="aws",
                    region="us-east-1",
                )
            )
            logger.info(f"Created index: {index_name}")

            # Wait for index to be ready
            for _ in range(1000):
                if self.pinecone.describe_index(index_name).status["ready"]:
                    break
                time.sleep(1)

            # One last check after the polling loop ends
            if not self.pinecone.describe_index(index_name).status["ready"]:
                raise RuntimeError("Gave up waiting for index creation")

        except Exception as e:
            logger.error(f"Pinecone index creation failed: {e}")
            raise

    self.last_index_name = index_name
|
||||
|
||||
def create_value(self, ent):
|
||||
if ent.startswith("http://") or ent.startswith("https://"):
|
||||
return Value(value=ent, is_uri=True)
|
||||
|
|
@ -71,9 +104,11 @@ class Processor(GraphEmbeddingsQueryService):
|
|||
dim = len(vec)
|
||||
|
||||
index_name = (
|
||||
"t-" + msg.user + "-" + msg.collection + "-" + str(dim)
|
||||
"t-" + msg.user + "-" + msg.collection
|
||||
)
|
||||
|
||||
self.ensure_index_exists(index_name, dim)
|
||||
|
||||
index = self.pinecone.Index(index_name)
|
||||
|
||||
# Heuristic hack, get (2*limit), so that we have more chance
|
||||
|
|
|
|||
|
|
@ -38,6 +38,24 @@ class Processor(GraphEmbeddingsQueryService):
|
|||
)
|
||||
|
||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
||||
self.last_collection = None
|
||||
|
||||
def ensure_collection_exists(self, collection, dim):
    """Ensure collection exists, create if it doesn't

    ``self.last_collection`` remembers the last collection checked, so
    consecutive calls for the same name skip the existence check.  A
    missing collection is created with vectors of size ``dim`` and cosine
    distance; a creation failure is logged and re-raised.
    """
    # Same collection as the previous call: nothing to check.
    if collection == self.last_collection:
        return

    if not self.qdrant.collection_exists(collection):
        try:
            vectors_config = VectorParams(
                size=dim, distance=Distance.COSINE
            )
            self.qdrant.create_collection(
                collection_name=collection,
                vectors_config=vectors_config,
            )
            logger.info(f"Created collection: {collection}")
        except Exception as e:
            logger.error(f"Qdrant collection creation failed: {e}")
            raise

    self.last_collection = collection
|
||||
|
||||
def create_value(self, ent):
|
||||
if ent.startswith("http://") or ent.startswith("https://"):
|
||||
|
|
@ -56,10 +74,11 @@ class Processor(GraphEmbeddingsQueryService):
|
|||
|
||||
dim = len(vec)
|
||||
collection = (
|
||||
"t_" + msg.user + "_" + msg.collection + "_" +
|
||||
str(dim)
|
||||
"t_" + msg.user + "_" + msg.collection
|
||||
)
|
||||
|
||||
self.ensure_collection_exists(collection, dim)
|
||||
|
||||
# Heuristic hack, get (2*limit), so that we have more chance
|
||||
# of getting (limit) entities
|
||||
search_result = self.qdrant.query_points(
|
||||
|
|
|
|||
0
trustgraph-flow/trustgraph/query/objects/__init__.py
Normal file
0
trustgraph-flow/trustgraph/query/objects/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
|
||||
from . service import *
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . service import run
|
||||
|
||||
run()
|
||||
|
||||
738
trustgraph-flow/trustgraph/query/objects/cassandra/service.py
Normal file
738
trustgraph-flow/trustgraph/query/objects/cassandra/service.py
Normal file
|
|
@ -0,0 +1,738 @@
|
|||
"""
|
||||
Objects query service using GraphQL. Input is a GraphQL query with variables.
|
||||
Output is GraphQL response data with any errors.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional, List, Set
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass, field
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.auth import PlainTextAuthProvider
|
||||
|
||||
import strawberry
|
||||
from strawberry import Schema
|
||||
from strawberry.types import Info
|
||||
from strawberry.scalars import JSON
|
||||
from strawberry.tools import create_type
|
||||
|
||||
from .... schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
|
||||
from .... schema import Error, RowSchema, Field as SchemaField
|
||||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "objects-query"
|
||||
|
||||
# GraphQL filter input types
|
||||
@strawberry.input
class IntFilter:
    # Filter operators offered for integer-typed schema fields.  Every
    # operator is optional; None means "not specified".
    #
    # NOTE: Processor.parse_idiomatic_where_clause in this file only
    # translates eq, gt, gte, lt, lte and in_; not_ and not_in are accepted
    # by the schema but are not turned into query conditions there.
    eq: Optional[int] = None
    gt: Optional[int] = None
    gte: Optional[int] = None
    lt: Optional[int] = None
    lte: Optional[int] = None
    # `in` and `not` are Python keywords, so the attributes carry a trailing
    # underscore and are given their GraphQL names explicitly.
    in_: Optional[List[int]] = strawberry.field(name="in", default=None)
    not_: Optional[int] = strawberry.field(name="not", default=None)
    not_in: Optional[List[int]] = None
|
||||
|
||||
@strawberry.input
class StringFilter:
    # Filter operators offered for string-typed schema fields.  Every
    # operator is optional; None means "not specified".
    #
    # NOTE: Processor.parse_idiomatic_where_clause in this file only
    # translates eq, contains and in_; startsWith, endsWith, not_ and not_in
    # are accepted by the schema but are not turned into query conditions
    # there.
    eq: Optional[str] = None
    contains: Optional[str] = None
    startsWith: Optional[str] = None
    endsWith: Optional[str] = None
    # `in` and `not` are Python keywords, so the attributes carry a trailing
    # underscore and are given their GraphQL names explicitly.
    in_: Optional[List[str]] = strawberry.field(name="in", default=None)
    not_: Optional[str] = strawberry.field(name="not", default=None)
    not_in: Optional[List[str]] = None
|
||||
|
||||
@strawberry.input
class FloatFilter:
    # Filter operators offered for float-typed schema fields.  Every
    # operator is optional; None means "not specified".
    #
    # NOTE: Processor.parse_idiomatic_where_clause in this file only
    # translates eq, gt, gte, lt, lte and in_; not_ and not_in are accepted
    # by the schema but are not turned into query conditions there.
    eq: Optional[float] = None
    gt: Optional[float] = None
    gte: Optional[float] = None
    lt: Optional[float] = None
    lte: Optional[float] = None
    # `in` and `not` are Python keywords, so the attributes carry a trailing
    # underscore and are given their GraphQL names explicitly.
    in_: Optional[List[float]] = strawberry.field(name="in", default=None)
    not_: Optional[float] = strawberry.field(name="not", default=None)
    not_in: Optional[List[float]] = None
|
||||
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
    """Set up the objects query processor.

    Reads ``id``, ``cassandra_host``, ``cassandra_username``,
    ``cassandra_password`` and ``config_type`` from ``params``, registers
    the "request" consumer and "response" producer specifications plus the
    schema config handler, and initialises empty schema / GraphQL / Cassandra
    state.  No Cassandra connection is opened here.
    """
    ident = params.get("id", default_ident)

    # Resolve Cassandra settings from the supplied parameters; the helper
    # supplies the environment variable fallback.
    hosts, username, password = resolve_cassandra_config(
        host=params.get("cassandra_host"),
        username=params.get("cassandra_username"),
        password=params.get("cassandra_password")
    )

    # Store resolved configuration with proper names
    self.cassandra_host = hosts  # Store as list
    self.cassandra_username = username
    self.cassandra_password = password

    # Config key for schemas
    self.config_key = params.get("config_type", "schema")

    base_params = params | {
        "id": ident,
        "config_type": self.config_key,
    }
    super(Processor, self).__init__(**base_params)

    request_spec = ConsumerSpec(
        name = "request",
        schema = ObjectsQueryRequest,
        handler = self.on_message
    )
    self.register_specification(request_spec)

    response_spec = ProducerSpec(
        name = "response",
        schema = ObjectsQueryResponse,
    )
    self.register_specification(response_spec)

    # Register config handler for schema updates
    self.register_config_handler(self.on_schema_config)

    # Schema storage: name -> RowSchema
    self.schemas: Dict[str, RowSchema] = {}

    # GraphQL schema
    self.graphql_schema: Optional[Schema] = None

    # GraphQL types cache
    self.graphql_types: Dict[str, type] = {}

    # Cassandra session
    self.cluster = None
    self.session = None

    # Known keyspaces and tables
    self.known_keyspaces: Set[str] = set()
    self.known_tables: Dict[str, Set[str]] = {}
|
||||
|
||||
def connect_cassandra(self):
    """Connect to Cassandra cluster

    Does nothing when a session already exists.  Otherwise builds a Cluster
    for ``self.cassandra_host`` (with plain-text authentication only when
    both username and password are set), connects, and stores the cluster
    and session on the instance.  Connection failures are logged with a
    traceback and re-raised.
    """
    if self.session:
        return

    try:
        cluster_args = {"contact_points": self.cassandra_host}

        # Credentials are only used when both halves are present.
        if self.cassandra_username and self.cassandra_password:
            cluster_args["auth_provider"] = PlainTextAuthProvider(
                username=self.cassandra_username,
                password=self.cassandra_password
            )

        self.cluster = Cluster(**cluster_args)
        self.session = self.cluster.connect()
        logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")

    except Exception as e:
        logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
        raise
|
||||
|
||||
def sanitize_name(self, name: str) -> str:
    """Sanitize names for Cassandra compatibility

    Every character outside ``[a-zA-Z0-9_]`` becomes an underscore, a
    non-empty result that does not start with a letter is prefixed with
    ``o_``, and the whole name is lower-cased.  An empty name stays empty.
    """
    import re

    cleaned = re.sub(r'[^a-zA-Z0-9_]', '_', name)
    needs_prefix = bool(cleaned) and not cleaned[0].isalpha()
    if needs_prefix:
        cleaned = 'o_' + cleaned
    return cleaned.lower()
|
||||
|
||||
def sanitize_table(self, name: str) -> str:
    """Sanitize table names for Cassandra compatibility

    Every character outside ``[a-zA-Z0-9_]`` becomes an underscore, the
    result is always prefixed with ``o_``, and the whole name is
    lower-cased.
    """
    import re

    cleaned = re.sub(r'[^a-zA-Z0-9_]', '_', name)
    return ('o_' + cleaned).lower()
|
||||
|
||||
def parse_filter_key(self, filter_key: str) -> tuple[str, str]:
    """Parse GraphQL filter key into field name and operator

    A key ending in one of the suffixes ``_gt``, ``_gte``, ``_lt``,
    ``_lte``, ``_in`` or ``_eq`` is split into ``(field_name, operator)``,
    e.g. ``price_gte`` -> ``("price", "gte")``.  Any other key is returned
    unchanged with the ``eq`` operator; an empty key yields ``("", "eq")``.
    """
    if not filter_key:
        return ("", "eq")

    known_operators = {"gte", "lte", "gt", "lt", "in", "eq"}

    # The operator, if any, is whatever follows the last underscore.
    field_name, separator, operator = filter_key.rpartition("_")
    if separator and operator in known_operators:
        return (field_name, operator)

    # Default to equality if no operator suffix found
    return (filter_key, "eq")
|
||||
|
||||
async def on_schema_config(self, config, version):
    """Handle schema configuration updates

    Replaces the processor's schema registry with the definitions found
    under ``config[self.config_key]`` (a mapping of schema name to a JSON
    schema definition) and regenerates the GraphQL schema from them.

    A definition that cannot be parsed is logged and skipped; the remaining
    definitions are still loaded.  When the configuration has no entry for
    ``self.config_key`` the GraphQL schema is dropped as well, so it never
    outlives the schemas it was generated from.
    """
    logger.info(f"Loading schema configuration version {version}")

    # Clear existing schemas
    self.schemas = {}
    self.graphql_types = {}

    # Check if our config type exists
    if self.config_key not in config:
        logger.warning(f"No '{self.config_key}' type in configuration")
        # The registry was just emptied; discard the GraphQL schema built
        # from the previous configuration so the two stay consistent.
        self.graphql_schema = None
        return

    # Get the schemas dictionary for our type
    schemas_config = config[self.config_key]

    # Process each schema in the schemas config
    for schema_name, schema_json in schemas_config.items():
        try:
            # Parse the JSON schema definition
            schema_def = json.loads(schema_json)

            # Create Field objects
            fields = [
                SchemaField(
                    name=field_def["name"],
                    type=field_def["type"],
                    size=field_def.get("size", 0),
                    primary=field_def.get("primary_key", False),
                    description=field_def.get("description", ""),
                    required=field_def.get("required", False),
                    enum_values=field_def.get("enum", []),
                    indexed=field_def.get("indexed", False)
                )
                for field_def in schema_def.get("fields", [])
            ]

            # Create RowSchema
            row_schema = RowSchema(
                name=schema_def.get("name", schema_name),
                description=schema_def.get("description", ""),
                fields=fields
            )

            self.schemas[schema_name] = row_schema
            logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")

        except Exception as e:
            # Best effort: one bad definition must not block the others.
            logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)

    logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")

    # Regenerate GraphQL schema
    self.generate_graphql_schema()
|
||||
|
||||
def get_python_type(self, field_type: str):
    """Convert schema field type to Python type for GraphQL

    Timestamps, dates, times and UUIDs are represented as strings; any
    unrecognised type name also falls back to ``str``.
    """
    if field_type == "integer":
        return int
    if field_type == "float":
        return float
    if field_type == "boolean":
        return bool
    # "string", "timestamp", "date", "time", "uuid" and unknown types
    return str
|
||||
|
||||
def create_graphql_type(self, schema_name: str, row_schema: RowSchema) -> type:
    """Create a GraphQL type from a RowSchema

    Builds a class named ``<Schema_name>Type`` with one annotated attribute
    per schema field and wraps it with ``strawberry.type``.  Fields that
    are neither required nor primary are Optional and default to None.
    """
    # Class namespace: annotations first, then defaults for optional fields
    annotations = {}
    namespace = {"__annotations__": annotations}

    for field in row_schema.fields:
        python_type = self.get_python_type(field.type)

        if field.required or field.primary:
            annotations[field.name] = python_type
        else:
            # Make field optional if not required
            annotations[field.name] = Optional[python_type]
            namespace[field.name] = None

    # Create the class dynamically
    type_name = f"{schema_name.capitalize()}Type"
    graphql_class = type(type_name, (), namespace)

    # Apply strawberry decorator
    return strawberry.type(graphql_class)
|
||||
|
||||
def create_filter_type_for_schema(self, schema_name: str, row_schema: RowSchema):
    """Create a dynamic filter input type for a schema

    Returns a strawberry input class named ``<Schema_name>Filter`` with an
    optional filter attribute for every integer, float and string field of
    the schema.  Fields of other types get no filter.
    """
    filter_type_name = f"{schema_name.capitalize()}Filter"

    # Schema field type -> (filter class, label used in the log message)
    filter_for_type = {
        "integer": (IntFilter, "IntFilter"),
        "float": (FloatFilter, "FloatFilter"),
        "string": (StringFilter, "StringFilter"),
    }

    # Annotations and defaults for the generated class
    annotations = {}
    defaults = {}

    logger.info(f"Creating filter type {filter_type_name} for schema {schema_name}")

    for field in row_schema.fields:
        logger.info(f"Field {field.name}: type={field.type}, indexed={field.indexed}, primary={field.primary}")

        # Allow filtering on any field for now, not just indexed/primary
        entry = filter_for_type.get(field.type)
        if entry is None:
            continue

        filter_class, label = entry
        annotations[field.name] = Optional[filter_class]
        defaults[field.name] = None
        logger.info(f"Added {label} for {field.name}")

    logger.info(f"Filter type {filter_type_name} will have fields: {list(annotations.keys())}")

    # Create the class dynamically, then apply strawberry input decorator
    namespace = {"__annotations__": annotations, **defaults}
    return strawberry.input(type(filter_type_name, (), namespace))
|
||||
|
||||
def create_sort_direction_enum(self):
    """Create sort direction enum

    Returns a new strawberry enum class ``SortDirection`` with members
    ``ASC`` ("asc") and ``DESC`` ("desc") on every call.
    """
    class SortDirection(Enum):
        ASC = "asc"
        DESC = "desc"

    # Explicit call form of the @strawberry.enum decorator
    return strawberry.enum(SortDirection)
|
||||
|
||||
def parse_idiomatic_where_clause(self, where_obj) -> Dict[str, Any]:
    """Parse the idiomatic nested filter structure

    Flattens a filter input object (one attribute per field, each holding a
    filter object or None) into a dict of conditions: ``eq`` is stored under
    the bare field name, the other recognised operators under
    ``<field>_<operator>``.  Unset (None) entries and operators not listed
    below are skipped.  A falsy ``where_obj`` yields an empty dict.
    """
    if not where_obj:
        return {}

    # GraphQL operator attribute -> suffix appended to the field name
    key_suffix = {
        "eq": "",
        "gt": "_gt",
        "gte": "_gte",
        "lt": "_lt",
        "lte": "_lte",
        "in_": "_in",
        "contains": "_contains",
    }

    conditions = {}

    logger.info(f"Parsing where clause: {where_obj}")

    for field_name, filter_obj in vars(where_obj).items():
        if filter_obj is None:
            continue

        logger.info(f"Processing field {field_name} with filter_obj: {filter_obj}")

        # Only filter objects (StringFilter, IntFilter, etc.) are expanded
        if not hasattr(filter_obj, '__dict__'):
            continue

        for operator, value in vars(filter_obj).items():
            if value is None:
                continue

            logger.info(f"Found operator {operator} with value {value}")

            # Map GraphQL operators to our internal format
            if operator in key_suffix:
                conditions[field_name + key_suffix[operator]] = value

    logger.info(f"Final parsed conditions: {conditions}")
    return conditions
|
||||
|
||||
def generate_graphql_schema(self):
    """Generate GraphQL schema from loaded schemas using dynamic filter types

    For every loaded schema this creates an object type, a filter input
    type and a query field named after the schema whose resolver queries
    Cassandra.  The result is stored in ``self.graphql_schema``; with no
    schemas loaded it is set to None instead.
    """
    if not self.schemas:
        logger.warning("No schemas loaded, cannot generate GraphQL schema")
        self.graphql_schema = None
        return

    sort_direction_enum = self.create_sort_direction_enum()

    def make_resolver(s_name, r_schema, g_type, f_type, sort_enum):
        # Factory so each resolver closes over its own schema and types
        async def resolver(
            info: Info,
            where: Optional[f_type] = None,
            order_by: Optional[str] = None,
            direction: Optional[sort_enum] = None,
            limit: Optional[int] = 100
        ) -> List[g_type]:
            # Processor instance and request scope come from the context
            processor = info.context["processor"]
            user = info.context["user"]
            collection = info.context["collection"]

            # Parse the idiomatic where clause
            filters = processor.parse_idiomatic_where_clause(where)

            # Query Cassandra
            results = await processor.query_cassandra(
                user, collection, s_name, r_schema,
                filters, limit, order_by, direction
            )

            # Convert to GraphQL types
            return [g_type(**row) for row in results]

        return resolver

    # Create GraphQL types and filter types for each schema
    filter_types = {}
    for schema_name, row_schema in self.schemas.items():
        graphql_type = self.create_graphql_type(schema_name, row_schema)
        filter_type = self.create_filter_type_for_schema(schema_name, row_schema)

        self.graphql_types[schema_name] = graphql_type
        filter_types[schema_name] = filter_type

    # Create the Query class namespace: one resolver field per schema
    query_annotations = {}
    query_dict = {'__annotations__': query_annotations}

    for schema_name, row_schema in self.schemas.items():
        graphql_type = self.graphql_types[schema_name]

        resolver_func = make_resolver(
            schema_name, row_schema, graphql_type,
            filter_types[schema_name], sort_direction_enum
        )

        query_dict[schema_name] = strawberry.field(resolver=resolver_func)
        query_annotations[schema_name] = List[graphql_type]

    # Create the Query class
    Query = strawberry.type(type('Query', (), query_dict))

    # Create the schema with auto_camel_case disabled to keep snake_case field names
    schema_config = strawberry.schema.config.StrawberryConfig(auto_camel_case=False)
    self.graphql_schema = strawberry.Schema(
        query=Query,
        config=schema_config
    )
    logger.info(f"Generated GraphQL schema with {len(self.schemas)} types")
|
||||
|
||||
async def query_cassandra(
    self,
    user: str,
    collection: str,
    schema_name: str,
    row_schema: RowSchema,
    filters: Dict[str, Any],
    limit: int,
    order_by: Optional[str] = None,
    direction: Optional[Any] = None
) -> List[Dict[str, Any]]:
    """Execute a query against Cassandra and return the rows as dicts.

    Args:
        user: Used (sanitized) as the keyspace name.
        collection: Value matched against the ``collection`` column.
        schema_name: Used (sanitized) as the table name.
        row_schema: Schema whose ``fields`` define which filters are
            honoured and which columns are copied into the result.
        filters: Mapping of filter key to value; each key is decoded with
            ``self.parse_filter_key`` into ``(field_name, operator)``.
            Entries with a ``None`` value, an undecodable key, or a field
            not present in the schema are skipped.
        limit: Maximum number of rows; a falsy value means no LIMIT clause.
        order_by: Optional field name to sort by (ignored unless it is a
            schema field and ``direction`` is also given).
        direction: Optional object whose ``.value`` is ``"asc"`` or
            ``"desc"``.

    Returns:
        One dict per row, keyed by the original (unsanitized) field names.

    Raises:
        Exception: Whatever ``self.session.execute`` raises, unless the
            failing query carried an ORDER BY clause, in which case it is
            retried once without it and the rows are sorted in Python.
    """

    # Connect if needed
    self.connect_cassandra()

    keyspace = self.sanitize_name(user)
    table = self.sanitize_table(schema_name)

    schema_field_names = {f.name for f in row_schema.fields}

    # Operators that map one-to-one onto a CQL comparison
    comparisons = {"eq": "=", "gt": ">", "gte": ">=", "lt": "<", "lte": "<="}

    # The collection restriction is always present
    where_clauses = ["collection = %s"]
    params = [collection]

    # Add filters for indexed or primary key fields
    for filter_key, value in filters.items():
        if value is None:
            continue

        # Parse field name and operator from filter key
        logger.debug(f"Parsing filter key: '{filter_key}' (type: {type(filter_key)})")
        result = self.parse_filter_key(filter_key)
        logger.debug(f"parse_filter_key returned: {result} (type: {type(result)}, len: {len(result) if hasattr(result, '__len__') else 'N/A'})")

        if not result or len(result) != 2:
            logger.error(f"parse_filter_key returned invalid result: {result}")
            continue  # Skip this filter

        field_name, operator = result

        # Filters on fields the schema does not declare are ignored
        if field_name not in schema_field_names:
            continue

        safe_field = self.sanitize_name(field_name)

        if operator == "in":
            if isinstance(value, list):
                placeholders = ",".join(["%s"] * len(value))
                where_clauses.append(f"{safe_field} IN ({placeholders})")
                params.extend(value)
            else:
                # Previously dropped without a trace; the filter is still
                # skipped, but the caller can now see why.
                logger.warning(
                    f"Ignoring 'in' filter on {field_name}: expected a list, got {type(value)}"
                )
        else:
            # Unknown operators default to equality
            cql_op = comparisons.get(operator, "=")
            where_clauses.append(f"{safe_field} {cql_op} %s")
            params.append(value)

    base_query = (
        f"SELECT * FROM {keyspace}.{table}"
        " WHERE " + " AND ".join(where_clauses)
    )

    # ORDER BY is attempted in Cassandra first; if Cassandra rejects it the
    # query is retried without it and sorted after the fact.
    order_clause = ""
    if order_by and direction and order_by in schema_field_names:
        safe_order_field = self.sanitize_name(order_by)
        direction_str = "ASC" if direction.value == "asc" else "DESC"
        order_clause = f" ORDER BY {safe_order_field} {direction_str}"

    # LIMIT must come before ALLOW FILTERING.  int() guarantees that only a
    # number is ever interpolated into the statement text.
    tail = f" LIMIT {int(limit)}" if limit else ""
    # ALLOW FILTERING for now (should optimize with proper indexes later)
    tail += " ALLOW FILTERING"

    # NOTE(review): session.execute is a blocking call inside a coroutine.
    try:
        rows = self.session.execute(base_query + order_clause + tail, params)
        cassandra_order_by_added = True
    except Exception as e:
        if not order_clause:
            raise
        logger.info(f"Cassandra rejected ORDER BY, falling back to post-query sorting: {e}")
        # Composing from the parts keeps the SELECT/WHERE text intact even
        # when no LIMIT was requested (the old string surgery reduced the
        # retry statement to " ALLOW FILTERING" in that case).
        rows = self.session.execute(base_query + tail, params)
        cassandra_order_by_added = False

    # Convert rows to dicts, using the original field names as keys
    results = []
    for row in rows:
        row_dict = {}
        for field in row_schema.fields:
            safe_field = self.sanitize_name(field.name)
            if hasattr(row, safe_field):
                row_dict[field.name] = getattr(row, safe_field)
        results.append(row_dict)

    # Post-query sorting if Cassandra didn't handle ORDER BY
    if order_by and direction and not cassandra_order_by_added:
        reverse_order = (direction.value == "desc")
        try:
            results.sort(key=lambda x: x.get(order_by, 0), reverse=reverse_order)
        except Exception as e:
            # Best effort: e.g. mixed or None values are not comparable
            logger.warning(f"Failed to sort results by {order_by}: {e}")

    return results
|
||||
|
||||
async def execute_graphql_query(
    self,
    query: str,
    variables: Dict[str, Any],
    operation_name: Optional[str],
    user: str,
    collection: str
) -> Dict[str, Any]:
    """Execute a GraphQL query against the generated schema.

    Args:
        query: GraphQL query text.
        variables: Variable values passed to the execution.
        operation_name: Name of the operation to run, or ``None``.
        user: Placed in the execution context for the resolvers.
        collection: Placed in the execution context for the resolvers.

    Returns:
        A dict with ``"data"`` (``None`` when the result data is falsy),
        ``"errors"`` (a list of ``{"message", "path", "extensions"}``
        dicts, empty when there were none) and, only when the result
        carries any, ``"extensions"``.

    Raises:
        RuntimeError: If no GraphQL schema has been generated.
    """

    if not self.graphql_schema:
        raise RuntimeError("No GraphQL schema available - no schemas loaded")

    # Resolvers read the processor, user and collection from this context
    context = {
        "processor": self,
        "user": user,
        "collection": collection
    }

    result = await self.graphql_schema.execute(
        query,
        variable_values=variables,
        operation_name=operation_name,
        context_value=context
    )

    response = {"data": result.data if result.data else None}

    # getattr's default only applies when the attribute is missing; an
    # error object may carry path/extensions explicitly set to None, so
    # normalise those to empty containers as well.
    response["errors"] = [
        {
            "message": str(error),
            "path": getattr(error, "path", None) or [],
            "extensions": getattr(error, "extensions", None) or {}
        }
        for error in (result.errors or [])
    ]

    # Add extensions if any
    extensions = getattr(result, "extensions", None)
    if extensions:
        response["extensions"] = extensions

    return response
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
    """Handle incoming query request.

    Runs the request's GraphQL query and sends an ``ObjectsQueryResponse``
    on the ``"response"`` flow, tagged with the sender's message id.  Any
    exception raised while handling the request is logged and reported
    back as an ``objects-query-error`` response.

    Args:
        msg: Incoming message; ``msg.properties()["id"]`` is the
            sender-produced id and ``msg.value()`` the request.
        consumer: Not used by this handler.
        flow: Callable returning the producer for a named flow.
    """

    # Sender-produced ID.  Read before the try block so the error path
    # below always has an id to reply with; previously a failure before
    # the id was bound surfaced as an UnboundLocalError in the handler.
    request_id = msg.properties()["id"]

    try:
        request = msg.value()

        logger.debug(f"Handling objects query request {request_id}...")

        # Execute GraphQL query
        result = await self.execute_graphql_query(
            query=request.query,
            variables=dict(request.variables) if request.variables else {},
            operation_name=request.operation_name,
            user=request.user,
            collection=request.collection
        )

        # Convert the plain error dicts into schema error objects
        graphql_errors = [
            GraphQLError(
                message=err.get("message", ""),
                path=err.get("path") or [],
                extensions=err.get("extensions") or {}
            )
            for err in (result.get("errors") or [])
        ]

        # Data travels as a JSON string; falsy data is sent as "null"
        data = result.get("data")
        response = ObjectsQueryResponse(
            error=None,
            data=json.dumps(data) if data else "null",
            errors=graphql_errors,
            extensions=result.get("extensions", {})
        )

        logger.debug("Sending objects query response...")
        await flow("response").send(response, properties={"id": request_id})

        logger.debug("Objects query request completed")

    except Exception as e:

        logger.error(f"Exception in objects query service: {e}", exc_info=True)

        logger.info("Sending error response...")

        response = ObjectsQueryResponse(
            error = Error(
                type = "objects-query-error",
                message = str(e),
            ),
            data = None,
            errors = [],
            extensions = {}
        )

        await flow("response").send(response, properties={"id": request_id})
|
||||
|
||||
def close(self):
    """Shut down the Cassandra cluster handle, if one is held."""
    if not self.cluster:
        # Nothing to release
        return

    self.cluster.shutdown()
    logger.info("Closed Cassandra connection")
|
||||
|
||||
@staticmethod
def add_args(parser):
    """Register this processor's command-line options on *parser*."""

    # Flow-processor options first, then the Cassandra options
    for register in (FlowProcessor.add_args, add_cassandra_args):
        register(parser)

    # Prefix under which schema definitions are looked up in configuration
    config_type_option = {
        'default': 'schema',
        'help': 'Configuration type prefix for schemas (default: schema)',
    }
    parser.add_argument('--config-type', **config_type_option)
|
||||
|
||||
def run():
    """Launch the objects-query-graphql-cassandra processor (command entry point)."""
    # __doc__ resolves to the module docstring here, not this function's
    launch = Processor.launch
    launch(default_ident, __doc__)
|
||||
|
||||
|
|
@ -6,36 +6,44 @@ null. Output is a list of triples.
|
|||
|
||||
import logging
|
||||
|
||||
from .... direct.cassandra import TrustGraph
|
||||
from .... direct.cassandra_kg import KnowledgeGraph
|
||||
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
|
||||
from .... schema import Value, Triple
|
||||
from .... base import TriplesQueryService
|
||||
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "triples-query"
|
||||
|
||||
default_graph_host='localhost'
|
||||
|
||||
class Processor(TriplesQueryService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
graph_host = params.get("graph_host", default_graph_host)
|
||||
graph_username = params.get("graph_username", None)
|
||||
graph_password = params.get("graph_password", None)
|
||||
# Get Cassandra parameters
|
||||
cassandra_host = params.get("cassandra_host")
|
||||
cassandra_username = params.get("cassandra_username")
|
||||
cassandra_password = params.get("cassandra_password")
|
||||
|
||||
# Resolve configuration with environment variable fallback
|
||||
hosts, username, password = resolve_cassandra_config(
|
||||
host=cassandra_host,
|
||||
username=cassandra_username,
|
||||
password=cassandra_password
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"graph_host": graph_host,
|
||||
"graph_username": graph_username,
|
||||
"cassandra_host": ','.join(hosts),
|
||||
"cassandra_username": username,
|
||||
}
|
||||
)
|
||||
|
||||
self.graph_host = [graph_host]
|
||||
self.username = graph_username
|
||||
self.password = graph_password
|
||||
self.cassandra_host = hosts
|
||||
self.cassandra_username = username
|
||||
self.cassandra_password = password
|
||||
self.table = None
|
||||
|
||||
def create_value(self, ent):
|
||||
|
|
@ -48,21 +56,21 @@ class Processor(TriplesQueryService):
|
|||
|
||||
try:
|
||||
|
||||
table = (query.user, query.collection)
|
||||
user = query.user
|
||||
|
||||
if table != self.table:
|
||||
if self.username and self.password:
|
||||
self.tg = TrustGraph(
|
||||
hosts=self.graph_host,
|
||||
keyspace=query.user, table=query.collection,
|
||||
username=self.username, password=self.password
|
||||
if user != self.table:
|
||||
if self.cassandra_username and self.cassandra_password:
|
||||
self.tg = KnowledgeGraph(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=query.user,
|
||||
username=self.cassandra_username, password=self.cassandra_password
|
||||
)
|
||||
else:
|
||||
self.tg = TrustGraph(
|
||||
hosts=self.graph_host,
|
||||
keyspace=query.user, table=query.collection,
|
||||
self.tg = KnowledgeGraph(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=query.user,
|
||||
)
|
||||
self.table = table
|
||||
self.table = user
|
||||
|
||||
triples = []
|
||||
|
||||
|
|
@ -70,13 +78,13 @@ class Processor(TriplesQueryService):
|
|||
if query.p is not None:
|
||||
if query.o is not None:
|
||||
resp = self.tg.get_spo(
|
||||
query.s.value, query.p.value, query.o.value,
|
||||
query.collection, query.s.value, query.p.value, query.o.value,
|
||||
limit=query.limit
|
||||
)
|
||||
triples.append((query.s.value, query.p.value, query.o.value))
|
||||
else:
|
||||
resp = self.tg.get_sp(
|
||||
query.s.value, query.p.value,
|
||||
query.collection, query.s.value, query.p.value,
|
||||
limit=query.limit
|
||||
)
|
||||
for t in resp:
|
||||
|
|
@ -84,14 +92,14 @@ class Processor(TriplesQueryService):
|
|||
else:
|
||||
if query.o is not None:
|
||||
resp = self.tg.get_os(
|
||||
query.o.value, query.s.value,
|
||||
query.collection, query.o.value, query.s.value,
|
||||
limit=query.limit
|
||||
)
|
||||
for t in resp:
|
||||
triples.append((query.s.value, t.p, query.o.value))
|
||||
else:
|
||||
resp = self.tg.get_s(
|
||||
query.s.value,
|
||||
query.collection, query.s.value,
|
||||
limit=query.limit
|
||||
)
|
||||
for t in resp:
|
||||
|
|
@ -100,14 +108,14 @@ class Processor(TriplesQueryService):
|
|||
if query.p is not None:
|
||||
if query.o is not None:
|
||||
resp = self.tg.get_po(
|
||||
query.p.value, query.o.value,
|
||||
query.collection, query.p.value, query.o.value,
|
||||
limit=query.limit
|
||||
)
|
||||
for t in resp:
|
||||
triples.append((t.s, query.p.value, query.o.value))
|
||||
else:
|
||||
resp = self.tg.get_p(
|
||||
query.p.value,
|
||||
query.collection, query.p.value,
|
||||
limit=query.limit
|
||||
)
|
||||
for t in resp:
|
||||
|
|
@ -115,13 +123,14 @@ class Processor(TriplesQueryService):
|
|||
else:
|
||||
if query.o is not None:
|
||||
resp = self.tg.get_o(
|
||||
query.o.value,
|
||||
query.collection, query.o.value,
|
||||
limit=query.limit
|
||||
)
|
||||
for t in resp:
|
||||
triples.append((t.s, t.p, query.o.value))
|
||||
else:
|
||||
resp = self.tg.get_all(
|
||||
query.collection,
|
||||
limit=query.limit
|
||||
)
|
||||
for t in resp:
|
||||
|
|
@ -147,24 +156,7 @@ class Processor(TriplesQueryService):
|
|||
def add_args(parser):
|
||||
|
||||
TriplesQueryService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-g', '--graph-host',
|
||||
default="localhost",
|
||||
help=f'Graph host (default: localhost)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--graph-username',
|
||||
default=None,
|
||||
help=f'Cassandra username'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--graph-password',
|
||||
default=None,
|
||||
help=f'Cassandra password'
|
||||
)
|
||||
add_cassandra_args(parser)
|
||||
|
||||
|
||||
def run():
|
||||
|
|
|
|||
|
|
@ -55,6 +55,10 @@ class Processor(TriplesQueryService):
|
|||
|
||||
try:
|
||||
|
||||
# Extract user and collection, use defaults if not provided
|
||||
user = query.user if query.user else "default"
|
||||
collection = query.collection if query.collection else "default"
|
||||
|
||||
triples = []
|
||||
|
||||
if query.s is not None:
|
||||
|
|
@ -64,10 +68,13 @@ class Processor(TriplesQueryService):
|
|||
# SPO
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal {value: $value}) "
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN $src as src "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value, rel=query.p.value, value=query.o.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -75,10 +82,13 @@ class Processor(TriplesQueryService):
|
|||
triples.append((query.s.value, query.p.value, query.o.value))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) "
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
|
||||
"RETURN $src as src "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value, rel=query.p.value, uri=query.o.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -90,10 +100,13 @@ class Processor(TriplesQueryService):
|
|||
# SP
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal) "
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN dest.value as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value, rel=query.p.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -102,10 +115,13 @@ class Processor(TriplesQueryService):
|
|||
triples.append((query.s.value, query.p.value, data["dest"]))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) "
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"RETURN dest.uri as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value, rel=query.p.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -120,10 +136,13 @@ class Processor(TriplesQueryService):
|
|||
# SO
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal {value: $value}) "
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN rel.uri as rel "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value, value=query.o.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -132,10 +151,13 @@ class Processor(TriplesQueryService):
|
|||
triples.append((query.s.value, data["rel"], query.o.value))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) "
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
|
||||
"RETURN rel.uri as rel "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value, uri=query.o.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -148,10 +170,13 @@ class Processor(TriplesQueryService):
|
|||
# S
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal) "
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN rel.uri as rel, dest.value as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -160,10 +185,13 @@ class Processor(TriplesQueryService):
|
|||
triples.append((query.s.value, data["rel"], data["dest"]))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) "
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"RETURN rel.uri as rel, dest.uri as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -181,10 +209,13 @@ class Processor(TriplesQueryService):
|
|||
# PO
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal {value: $value}) "
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src "
|
||||
"LIMIT " + str(query.limit),
|
||||
uri=query.p.value, value=query.o.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -193,10 +224,13 @@ class Processor(TriplesQueryService):
|
|||
triples.append((data["src"], query.p.value, query.o.value))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $dest}) "
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Node {uri: $dest, user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src "
|
||||
"LIMIT " + str(query.limit),
|
||||
uri=query.p.value, dest=query.o.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -209,10 +243,13 @@ class Processor(TriplesQueryService):
|
|||
# P
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal) "
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, dest.value as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
uri=query.p.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -221,10 +258,13 @@ class Processor(TriplesQueryService):
|
|||
triples.append((data["src"], query.p.value, data["dest"]))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) "
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, dest.uri as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
uri=query.p.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -239,10 +279,13 @@ class Processor(TriplesQueryService):
|
|||
# O
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel]->(dest:Literal {value: $value}) "
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel "
|
||||
"LIMIT " + str(query.limit),
|
||||
value=query.o.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -251,10 +294,13 @@ class Processor(TriplesQueryService):
|
|||
triples.append((data["src"], data["rel"], query.o.value))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) "
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel "
|
||||
"LIMIT " + str(query.limit),
|
||||
uri=query.o.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -267,9 +313,12 @@ class Processor(TriplesQueryService):
|
|||
# *
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel]->(dest:Literal) "
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -278,9 +327,12 @@ class Processor(TriplesQueryService):
|
|||
triples.append((data["src"], data["rel"], data["dest"]))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel]->(dest:Node) "
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -55,6 +55,10 @@ class Processor(TriplesQueryService):
|
|||
|
||||
try:
|
||||
|
||||
# Extract user and collection, use defaults if not provided
|
||||
user = query.user if query.user else "default"
|
||||
collection = query.collection if query.collection else "default"
|
||||
|
||||
triples = []
|
||||
|
||||
if query.s is not None:
|
||||
|
|
@ -64,9 +68,12 @@ class Processor(TriplesQueryService):
|
|||
# SPO
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal {value: $value}) "
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN $src as src",
|
||||
src=query.s.value, rel=query.p.value, value=query.o.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -74,9 +81,12 @@ class Processor(TriplesQueryService):
|
|||
triples.append((query.s.value, query.p.value, query.o.value))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) "
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
|
||||
"RETURN $src as src",
|
||||
src=query.s.value, rel=query.p.value, uri=query.o.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -88,9 +98,12 @@ class Processor(TriplesQueryService):
|
|||
# SP
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal) "
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN dest.value as dest",
|
||||
src=query.s.value, rel=query.p.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -99,9 +112,12 @@ class Processor(TriplesQueryService):
|
|||
triples.append((query.s.value, query.p.value, data["dest"]))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) "
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"RETURN dest.uri as dest",
|
||||
src=query.s.value, rel=query.p.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -116,9 +132,12 @@ class Processor(TriplesQueryService):
|
|||
# SO
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal {value: $value}) "
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN rel.uri as rel",
|
||||
src=query.s.value, value=query.o.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -127,9 +146,12 @@ class Processor(TriplesQueryService):
|
|||
triples.append((query.s.value, data["rel"], query.o.value))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) "
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
|
||||
"RETURN rel.uri as rel",
|
||||
src=query.s.value, uri=query.o.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -142,9 +164,12 @@ class Processor(TriplesQueryService):
|
|||
# S
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal) "
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN rel.uri as rel, dest.value as dest",
|
||||
src=query.s.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -153,9 +178,12 @@ class Processor(TriplesQueryService):
|
|||
triples.append((query.s.value, data["rel"], data["dest"]))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) "
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"RETURN rel.uri as rel, dest.uri as dest",
|
||||
src=query.s.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -173,9 +201,12 @@ class Processor(TriplesQueryService):
|
|||
# PO
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal {value: $value}) "
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src",
|
||||
uri=query.p.value, value=query.o.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -184,9 +215,12 @@ class Processor(TriplesQueryService):
|
|||
triples.append((data["src"], query.p.value, query.o.value))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $dest}) "
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Node {uri: $dest, user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src",
|
||||
uri=query.p.value, dest=query.o.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -199,9 +233,12 @@ class Processor(TriplesQueryService):
|
|||
# P
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal) "
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, dest.value as dest",
|
||||
uri=query.p.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -210,9 +247,12 @@ class Processor(TriplesQueryService):
|
|||
triples.append((data["src"], query.p.value, data["dest"]))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) "
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, dest.uri as dest",
|
||||
uri=query.p.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -227,9 +267,12 @@ class Processor(TriplesQueryService):
|
|||
# O
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel]->(dest:Literal {value: $value}) "
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel",
|
||||
value=query.o.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -238,9 +281,12 @@ class Processor(TriplesQueryService):
|
|||
triples.append((data["src"], data["rel"], query.o.value))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) "
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel",
|
||||
uri=query.o.value,
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -253,8 +299,11 @@ class Processor(TriplesQueryService):
|
|||
# *
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel]->(dest:Literal) "
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.value as dest",
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -263,8 +312,11 @@ class Processor(TriplesQueryService):
|
|||
triples.append((data["src"], data["rel"], data["dest"]))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel]->(dest:Node) "
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest",
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -92,7 +92,12 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
doc_limit = self.doc_limit
|
||||
|
||||
response = await self.rag.query(v.query, doc_limit=doc_limit)
|
||||
response = await self.rag.query(
|
||||
v.query,
|
||||
user=v.user,
|
||||
collection=v.collection,
|
||||
doc_limit=doc_limit
|
||||
)
|
||||
|
||||
await flow("response").send(
|
||||
DocumentRagResponse(
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
from . service import *
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . service import run
|
||||
|
||||
run()
|
||||
25
trustgraph-flow/trustgraph/retrieval/nlp_query/pass1.txt
Normal file
25
trustgraph-flow/trustgraph/retrieval/nlp_query/pass1.txt
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
You are a database schema selection expert. Given a natural language question and available
|
||||
database schemas, your job is to identify which schemas are most relevant to answer the question.
|
||||
|
||||
## Available Schemas:
|
||||
{% for schema in schemas %}
|
||||
**{{ schema.name }}**: {{ schema.description }}
|
||||
Fields:
|
||||
{% for field in schema.fields %}
|
||||
- {{ field.name }} ({{ field.type }}): {{ field.description }}
|
||||
{% endfor %}
|
||||
|
||||
{% endfor %}
|
||||
|
||||
## Question:
|
||||
{{ question }}
|
||||
|
||||
## Instructions:
|
||||
1. Analyze the question to understand what data is being requested
|
||||
2. Examine each schema to understand what data it contains
|
||||
3. Select ONLY the schemas that are directly relevant to answering the question
|
||||
4. Return your answer as a JSON array of schema names
|
||||
|
||||
## Response Format:
|
||||
Return ONLY a JSON array of schema names, nothing else.
|
||||
Example: ["customers", "orders", "products"]
|
||||
101
trustgraph-flow/trustgraph/retrieval/nlp_query/pass2.txt
Normal file
101
trustgraph-flow/trustgraph/retrieval/nlp_query/pass2.txt
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
You are a GraphQL query generation expert. Given a natural language question and relevant database
|
||||
schemas, generate a precise GraphQL query to answer the question.
|
||||
|
||||
## Question:
|
||||
{{ question }}
|
||||
|
||||
## Relevant Schemas:
|
||||
{% for schema in schemas %}
|
||||
**{{ schema.name }}**: {{ schema.description }}
|
||||
Fields:
|
||||
{% for field in schema.fields %}
|
||||
- {{ field.name }} ({{ field.type }}){% if field.description %}: {{ field.description }}{% endif
|
||||
%}{% if field.primary_key %} [PRIMARY KEY]{% endif %}{% if field.required %} [REQUIRED]{% endif
|
||||
%}{% if field.indexed %} [INDEXED]{% endif %}{% if field.enum_values %} [OPTIONS: {{
|
||||
field.enum_values|join(', ') }}]{% endif %}
|
||||
{% endfor %}
|
||||
|
||||
{% endfor %}
|
||||
|
||||
## GraphQL Query Rules:
|
||||
1. Use the schema names as GraphQL query fields (e.g., `customers`, `orders`)
|
||||
2. Apply filters using the `where` parameter with nested filter objects
|
||||
3. Available filter operators per field type:
|
||||
- String fields: `eq`, `contains`, `startsWith`, `endsWith`, `in`, `not`, `not_in`
|
||||
- Integer/Float fields: `eq`, `gt`, `gte`, `lt`, `lte`, `in`, `not`, `not_in`
|
||||
4. Use `order_by` for sorting (field name as string)
|
||||
5. Use `direction` for sort direction: `ASC` or `DESC`
|
||||
6. Use `limit` to restrict number of results
|
||||
7. Select specific fields in the query body
|
||||
|
||||
## Example GraphQL Queries:
|
||||
|
||||
**Question**: "Show me customers from California"
|
||||
```graphql
|
||||
query {
|
||||
customers(where: {state: {eq: "California"}}, limit: 100) {
|
||||
customer_id
|
||||
name
|
||||
email
|
||||
state
|
||||
}
|
||||
}
```
|
||||
|
||||
**Question**: "Top 10 products by price"
|
||||
query {
|
||||
products(order_by: "price", direction: DESC, limit: 10) {
|
||||
product_id
|
||||
name
|
||||
price
|
||||
category
|
||||
}
|
||||
}
|
||||
|
||||
**Question**: "Recent orders over $100"
|
||||
query {
|
||||
orders(
|
||||
where: {
|
||||
total_amount: {gt: 100}
|
||||
order_date: {gte: "2024-01-01"}
|
||||
}
|
||||
order_by: "order_date"
|
||||
direction: DESC
|
||||
limit: 50
|
||||
) {
|
||||
order_id
|
||||
customer_id
|
||||
total_amount
|
||||
order_date
|
||||
status
|
||||
}
|
||||
}
|
||||
|
||||
## Instructions:
|
||||
|
||||
1. Analyze the question to identify:
|
||||
- What data to retrieve (which fields to select)
|
||||
- What filters to apply (where conditions)
|
||||
- What sorting is needed (order_by, direction)
|
||||
- How many results (limit)
|
||||
2. Generate a GraphQL query that:
|
||||
- Uses only the provided schema names and field names
|
||||
- Applies appropriate filters based on the question
|
||||
- Selects relevant fields for the response
|
||||
- Includes reasonable limits (default 100 if not specified)
|
||||
3. If variables are needed, include them in the response
|
||||
|
||||
## Response Format:
|
||||
|
||||
Return a JSON object with:
|
||||
- "query": the GraphQL query string
|
||||
- "variables": object with any GraphQL variables (empty object if none)
|
||||
- "confidence": float between 0.0-1.0 indicating confidence in the query
|
||||
|
||||
Example:
|
||||
{
|
||||
"query": "query { customers(where: {state: {eq: \"California\"}}, limit: 100) { customer_id name
|
||||
email state } }",
|
||||
"variables": {},
|
||||
"confidence": 0.95
|
||||
}
|
||||
|
||||
315
trustgraph-flow/trustgraph/retrieval/nlp_query/service.py
Normal file
315
trustgraph-flow/trustgraph/retrieval/nlp_query/service.py
Normal file
|
|
@ -0,0 +1,315 @@
|
|||
"""
|
||||
NLP to Structured Query Service - converts natural language questions to GraphQL queries.
|
||||
Two-phase approach: 1) Select relevant schemas, 2) Generate GraphQL query.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
from ...schema import QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse
|
||||
from ...schema import PromptRequest
|
||||
from ...schema import Error, RowSchema, Field as SchemaField
|
||||
|
||||
from ...base import FlowProcessor, ConsumerSpec, ProducerSpec, PromptClientSpec
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "nlp-query"
|
||||
default_schema_selection_template = "schema-selection"
|
||||
default_graphql_generation_template = "graphql-generation"
|
||||
|
||||
class Processor(FlowProcessor):
    """Flow processor that converts natural-language questions to GraphQL.

    Requests are consumed from "request" and answered on "response". Each
    request triggers two prompt-service round trips: phase 1 selects the
    schemas relevant to the question, phase 2 generates the GraphQL query
    from those schemas. Schemas come from configuration (see
    ``on_schema_config``) and are held in ``self.schemas``.
    """

    def __init__(self, **params):
        """Register the flow specifications and the schema config handler.

        Optional ``params`` keys read here: ``id``, ``config_type``,
        ``schema_selection_template`` and ``graphql_generation_template``.
        The whole mapping (with ``id`` and ``config_type`` filled in) is
        forwarded to ``FlowProcessor``.
        """

        ident = params.get("id", default_ident)

        # Config key for schemas
        self.config_key = params.get("config_type", "schema")

        # Configurable prompt template names
        self.schema_selection_template = params.get(
            "schema_selection_template", default_schema_selection_template
        )
        self.graphql_generation_template = params.get(
            "graphql_generation_template", default_graphql_generation_template
        )

        super().__init__(
            **params | {
                "id": ident,
                "config_type": self.config_key,
            }
        )

        self.register_specification(
            ConsumerSpec(
                name="request",
                schema=QuestionToStructuredQueryRequest,
                handler=self.on_message,
            )
        )

        self.register_specification(
            ProducerSpec(
                name="response",
                schema=QuestionToStructuredQueryResponse,
            )
        )

        # Client spec for calling prompt service
        self.register_specification(
            PromptClientSpec(
                request_name="prompt-request",
                response_name="prompt-response",
            )
        )

        # Register config handler for schema updates
        self.register_config_handler(self.on_schema_config)

        # Schema storage: name -> RowSchema
        self.schemas: Dict[str, RowSchema] = {}

        logger.info("NLP Query service initialized")

    async def on_schema_config(self, config, version):
        """Rebuild ``self.schemas`` from a configuration update.

        ``config[self.config_key]`` is expected to map schema names to JSON
        schema definitions. A definition that fails to parse is logged and
        skipped; the remaining schemas are still loaded.
        """
        logger.info(f"Loading schema configuration version {version}")

        # Clear existing schemas
        self.schemas = {}

        # Check if our config type exists
        if self.config_key not in config:
            logger.warning(f"No '{self.config_key}' type in configuration")
            return

        for schema_name, schema_json in config[self.config_key].items():
            try:
                schema_def = json.loads(schema_json)

                fields = [
                    SchemaField(
                        name=field_def["name"],
                        type=field_def["type"],
                        size=field_def.get("size", 0),
                        primary=field_def.get("primary_key", False),
                        description=field_def.get("description", ""),
                        required=field_def.get("required", False),
                        enum_values=field_def.get("enum", []),
                        indexed=field_def.get("indexed", False),
                    )
                    for field_def in schema_def.get("fields", [])
                ]

                self.schemas[schema_name] = RowSchema(
                    name=schema_def.get("name", schema_name),
                    description=schema_def.get("description", ""),
                    fields=fields,
                )
                logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")

            except Exception as e:
                # One bad definition must not prevent the others from loading
                logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)

        logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")

    async def _run_prompt(self, template_id: str, variables: Dict[str, Any], flow) -> Any:
        """Invoke a prompt template and return its JSON-decoded reply.

        Each value in ``variables`` is sent as a JSON-encoded term. The reply
        is taken from the response's ``text`` field, falling back to
        ``object``.

        Raises:
            Exception: the prompt service reported an error, or returned no
                payload (``None`` or empty).
            json.JSONDecodeError: the payload is not valid JSON.
        """
        terms = {key: json.dumps(value) for key, value in variables.items()}

        response = await flow("prompt-request").request(
            PromptRequest(id=template_id, terms=terms)
        )

        if response.error is not None:
            raise Exception(f"Prompt service error: {response.error}")

        # Response could be in either text or object field
        payload = response.text if response.text else response.object
        if not payload:
            raise Exception("Prompt service returned empty response")

        return json.loads(payload)

    async def phase1_select_schemas(self, question: str, flow) -> List[str]:
        """Phase 1: ask the prompt service which schemas are relevant.

        Returns the list of schema names decoded from the prompt reply.

        Raises:
            Exception: on any prompt-service or parsing failure, or when the
                reply is not a JSON array. The failure is logged first.
        """
        logger.info("Starting Phase 1: Schema selection")

        # Only name/type/description are needed to choose between schemas
        schema_info = [
            {
                "name": name,
                "description": schema.description,
                "fields": [
                    {"name": f.name, "type": f.type, "description": f.description}
                    for f in schema.fields
                ],
            }
            for name, schema in self.schemas.items()
        ]

        try:
            selected_schemas = await self._run_prompt(
                self.schema_selection_template,
                {"question": question, "schemas": schema_info},
                flow,
            )

            # The template asks for a JSON array; anything else would be
            # iterated element-wise (characters, dict keys) in phase 2.
            if not isinstance(selected_schemas, list):
                raise ValueError(
                    "Schema selection must return a JSON array, got "
                    f"{type(selected_schemas).__name__}"
                )

            logger.info(f"Phase 1 selected schemas: {selected_schemas}")
            return selected_schemas

        except Exception as e:
            logger.error(f"Phase 1 schema selection failed: {e}")
            raise

    async def phase2_generate_graphql(self, question: str, selected_schemas: List[str], flow) -> Dict[str, Any]:
        """Phase 2: generate a GraphQL query from the selected schemas.

        Names in ``selected_schemas`` that are not present in
        ``self.schemas`` are ignored. Returns the decoded JSON object from
        the prompt reply; ``on_message`` reads its "query", "variables" and
        "confidence" keys.

        Raises:
            Exception: on any prompt-service or parsing failure, or when the
                reply is not a JSON object. The failure is logged first.
        """
        logger.info(f"Starting Phase 2: GraphQL generation for schemas: {selected_schemas}")

        # Get detailed schema information for selected schemas only
        selected_schema_info = [
            {
                "name": schema_name,
                "description": self.schemas[schema_name].description,
                "fields": [
                    {
                        "name": f.name,
                        "type": f.type,
                        "description": f.description,
                        "required": f.required,
                        "primary_key": f.primary,
                        "indexed": f.indexed,
                        "enum_values": f.enum_values if f.enum_values else [],
                    }
                    for f in self.schemas[schema_name].fields
                ],
            }
            for schema_name in selected_schemas
            if schema_name in self.schemas
        ]

        try:
            result = await self._run_prompt(
                self.graphql_generation_template,
                {"question": question, "schemas": selected_schema_info},
                flow,
            )

            if not isinstance(result, dict):
                raise ValueError(
                    "GraphQL generation must return a JSON object, got "
                    f"{type(result).__name__}"
                )

            logger.info(f"Phase 2 generated GraphQL: {result.get('query', '')[:100]}...")
            return result

        except Exception as e:
            logger.error(f"Phase 2 GraphQL generation failed: {e}")
            raise

    async def on_message(self, msg, consumer, flow):
        """Handle incoming question to structured query request.

        Runs both phases and sends the result on "response". Any failure is
        converted into an error response of type "nlp-query-error" carrying
        the same request ID.
        """

        # Sender-produced ID. Read before the try block so that the error
        # path below can always address its reply; previously a failure
        # before this assignment left the name unbound in the handler.
        request_id = msg.properties()["id"]

        try:
            request = msg.value()

            logger.info(f"Handling NLP query request {request_id}: {request.question[:100]}...")

            # Phase 1: Select relevant schemas
            selected_schemas = await self.phase1_select_schemas(request.question, flow)

            # Phase 2: Generate GraphQL query
            graphql_result = await self.phase2_generate_graphql(
                request.question, selected_schemas, flow
            )

            response = QuestionToStructuredQueryResponse(
                error=None,
                graphql_query=graphql_result.get("query", ""),
                variables=graphql_result.get("variables", {}),
                detected_schemas=selected_schemas,
                confidence=graphql_result.get("confidence", 0.8),  # Default confidence
            )

            logger.info("Sending NLP query response...")
            await flow("response").send(response, properties={"id": request_id})

            logger.info("NLP query request completed")

        except Exception as e:

            logger.error(f"Exception in NLP query service: {e}", exc_info=True)

            logger.info("Sending error response...")

            response = QuestionToStructuredQueryResponse(
                error=Error(
                    type="nlp-query-error",
                    message=str(e),
                ),
                graphql_query="",
                variables={},
                detected_schemas=[],
                confidence=0.0,
            )

            await flow("response").send(response, properties={"id": request_id})

    @staticmethod
    def add_args(parser):
        """Add command-line arguments"""

        FlowProcessor.add_args(parser)

        parser.add_argument(
            '--config-type',
            default='schema',
            help='Configuration type prefix for schemas (default: schema)'
        )

        parser.add_argument(
            '--schema-selection-template',
            default=default_schema_selection_template,
            help=f'Prompt template name for schema selection (default: {default_schema_selection_template})'
        )

        parser.add_argument(
            '--graphql-generation-template',
            default=default_graphql_generation_template,
            help=f'Prompt template name for GraphQL generation (default: {default_graphql_generation_template})'
        )
|
||||
|
||||
def run():
    """Launch the processor under the default ident, using the module docstring as its description."""
    Processor.launch(default_ident, __doc__)
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
# Structured data diagnosis service
|
||||
from .service import *
|
||||
494
trustgraph-flow/trustgraph/retrieval/structured_diag/service.py
Normal file
494
trustgraph-flow/trustgraph/retrieval/structured_diag/service.py
Normal file
|
|
@ -0,0 +1,494 @@
|
|||
"""
|
||||
Structured Data Diagnosis Service - analyzes structured data and generates descriptors.
|
||||
Supports four operations: detect-type, generate-descriptor, diagnose (combined), and schema-selection.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from ...schema import StructuredDataDiagnosisRequest, StructuredDataDiagnosisResponse
|
||||
from ...schema import PromptRequest, Error, RowSchema, Field as SchemaField
|
||||
|
||||
from ...base import FlowProcessor, ConsumerSpec, ProducerSpec, PromptClientSpec
|
||||
|
||||
from .type_detector import detect_data_type, detect_csv_options
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "structured-diag"
|
||||
default_csv_prompt = "diagnose-csv"
|
||||
default_json_prompt = "diagnose-json"
|
||||
default_xml_prompt = "diagnose-xml"
|
||||
default_schema_selection_prompt = "schema-selection"
|
||||
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
    """Set up prompt names, flow specifications and the schema config handler.

    Optional ``params`` keys read here: ``id``, ``config_type``,
    ``csv_prompt``, ``json_prompt``, ``xml_prompt`` and
    ``schema_selection_prompt``. The whole mapping, with ``id`` and
    ``config_type`` filled in, is passed on to the base class.
    """

    ident = params.get("id", default_ident)

    # Configuration type under which schema definitions are stored
    self.config_key = params.get("config_type", "schema")

    # Prompt template names, each overridable through params
    self.csv_prompt = params.get("csv_prompt", default_csv_prompt)
    self.json_prompt = params.get("json_prompt", default_json_prompt)
    self.xml_prompt = params.get("xml_prompt", default_xml_prompt)
    self.schema_selection_prompt = params.get(
        "schema_selection_prompt", default_schema_selection_prompt
    )

    base_params = params | {
        "id": ident,
        "config_type": self.config_key,
    }
    super(Processor, self).__init__(**base_params)

    # Incoming diagnosis requests
    self.register_specification(
        ConsumerSpec(
            name="request",
            schema=StructuredDataDiagnosisRequest,
            handler=self.on_message,
        )
    )

    # Outgoing diagnosis responses
    self.register_specification(
        ProducerSpec(name="response", schema=StructuredDataDiagnosisResponse)
    )

    # Request/response pair used to call the prompt service
    self.register_specification(
        PromptClientSpec(
            request_name="prompt-request",
            response_name="prompt-response",
        )
    )

    # Schema updates arrive through the config handler
    self.register_config_handler(self.on_schema_config)

    # Loaded schemas, keyed by schema name
    self.schemas: Dict[str, RowSchema] = {}

    logger.info("Structured Data Diagnosis service initialized")
|
||||
|
||||
async def on_schema_config(self, config, version):
    """Reload ``self.schemas`` from a configuration update.

    Every entry under ``config[self.config_key]`` is parsed as a JSON schema
    definition and stored as a ``RowSchema``. Entries that fail to parse are
    logged and skipped.
    """
    logger.info(f"Loading schema configuration version {version}")

    # Start from an empty table on every update
    self.schemas = {}

    if self.config_key not in config:
        logger.warning(f"No '{self.config_key}' type in configuration")
        return

    for schema_name, schema_json in config[self.config_key].items():
        try:
            definition = json.loads(schema_json)

            # One SchemaField per entry of the definition's "fields" list
            fields = [
                SchemaField(
                    name=spec["name"],
                    type=spec["type"],
                    size=spec.get("size", 0),
                    primary=spec.get("primary_key", False),
                    description=spec.get("description", ""),
                    required=spec.get("required", False),
                    enum_values=spec.get("enum", []),
                    indexed=spec.get("indexed", False),
                )
                for spec in definition.get("fields", [])
            ]

            self.schemas[schema_name] = RowSchema(
                name=definition.get("name", schema_name),
                description=definition.get("description", ""),
                fields=fields,
            )
            logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")

        except Exception as e:
            # Keep loading the remaining schemas
            logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)

    logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
    """Handle incoming structured data diagnosis request.

    Dispatches on ``request.operation`` to the matching ``*_operation``
    coroutine and sends its response on "response" with the request's ID.
    An unknown operation yields an "InvalidOperation" error response; any
    exception raised while processing yields a "ProcessingError" response.
    """

    # Sender-produced ID. Read before the try block so the error path below
    # can always address its reply.
    request_id = msg.properties()["id"]

    # Pre-bound so the handler's "unknown" fallback works when msg.value()
    # itself fails; previously the name was unbound in that case.
    request = None

    try:
        request = msg.value()

        logger.info(f"Handling structured data diagnosis request {request_id}: operation={request.operation}")

        operations = {
            "detect-type": self.detect_type_operation,
            "generate-descriptor": self.generate_descriptor_operation,
            "diagnose": self.diagnose_operation,
            "schema-selection": self.schema_selection_operation,
        }
        operation = operations.get(request.operation)

        if operation is not None:
            response = await operation(request, flow)
        else:
            error = Error(
                type="InvalidOperation",
                message=f"Unknown operation: {request.operation}. Supported: detect-type, generate-descriptor, diagnose, schema-selection"
            )
            response = StructuredDataDiagnosisResponse(
                error=error,
                operation=request.operation
            )

        # Send response
        await flow("response").send(
            response, properties={"id": request_id}
        )

    except Exception as e:
        logger.error(f"Error processing diagnosis request: {e}", exc_info=True)

        error = Error(
            type="ProcessingError",
            message=f"Failed to process diagnosis request: {str(e)}"
        )

        response = StructuredDataDiagnosisResponse(
            error=error,
            operation=request.operation if request else "unknown"
        )

        await flow("response").send(
            response, properties={"id": request_id}
        )
|
||||
|
||||
async def detect_type_operation(self, request: StructuredDataDiagnosisRequest, flow) -> StructuredDataDiagnosisResponse:
    """Detect the data type of ``request.sample``.

    The response carries the detected type (empty string when detection
    returned nothing) and its confidence. For CSV samples the detected CSV
    options are added to the metadata as a JSON string under "csv_options".
    """
    logger.info("Processing detect-type operation")

    kind, score = detect_data_type(request.sample)

    if kind == "csv":
        extra = {"csv_options": json.dumps(detect_csv_options(request.sample))}
    else:
        extra = {}

    return StructuredDataDiagnosisResponse(
        error=None,
        operation=request.operation,
        detected_type=kind if kind else "",
        confidence=score,
        metadata=extra,
    )
|
||||
|
||||
async def generate_descriptor_operation(self, request: StructuredDataDiagnosisRequest, flow) -> StructuredDataDiagnosisResponse:
    """Generate a descriptor for ``request.sample`` against a named schema.

    Requires ``request.type`` and ``request.schema_name``, and the schema
    must be present in ``self.schemas``. Each unmet precondition, and a
    ``None`` result from the prompt-based generator, produces an error
    response instead of raising.
    """
    logger.info(f"Processing generate-descriptor operation for type: {request.type}")

    def failure(error_type, message):
        # Error response for this operation
        return StructuredDataDiagnosisResponse(
            error=Error(type=error_type, message=message),
            operation=request.operation,
        )

    if not request.type:
        return failure(
            "MissingParameter",
            "Type parameter is required for generate-descriptor operation",
        )

    if not request.schema_name:
        return failure(
            "MissingParameter",
            "Schema name parameter is required for generate-descriptor operation",
        )

    # The target schema must have been loaded from configuration
    if request.schema_name not in self.schemas:
        return failure(
            "SchemaNotFound",
            f"Schema '{request.schema_name}' not found in configuration",
        )

    schema = self.schemas[request.schema_name]

    # Delegate descriptor generation to the prompt service
    generated = await self.generate_descriptor_with_prompt(
        request.sample, request.type, schema, request.options, flow
    )

    if generated is None:
        return failure(
            "DescriptorGenerationFailed",
            "Failed to generate descriptor using prompt service",
        )

    return StructuredDataDiagnosisResponse(
        error=None,
        operation=request.operation,
        descriptor=json.dumps(generated),
        metadata={"schema_name": request.schema_name, "type": request.type},
    )
|
||||
|
||||
async def diagnose_operation(self, request: StructuredDataDiagnosisRequest, flow) -> StructuredDataDiagnosisResponse:
    """Handle the combined diagnose operation.

    Runs three steps in order: detect the data type of ``request.sample``,
    pick a schema (the requested one, otherwise the first key of
    ``self.schemas``), and generate a descriptor through
    ``generate_descriptor_with_prompt``.

    Returns:
        A response with the detected type, confidence, JSON-encoded
        descriptor and metadata; or an error response of type
        ``TypeDetectionFailed``, ``NoSchemaAvailable``, ``SchemaNotFound``
        or ``DescriptorGenerationFailed``.
    """
    logger.info("Processing combined diagnose operation")

    def failure(kind, text):
        # Every error path answers with the same envelope shape.
        return StructuredDataDiagnosisResponse(
            error=Error(type=kind, message=text),
            operation=request.operation,
        )

    # Step 1: detect the type
    detected_type, confidence = detect_data_type(request.sample)
    if not detected_type:
        return failure(
            "TypeDetectionFailed",
            "Unable to detect data type from sample",
        )

    # Step 2: use the requested schema, else fall back to the first one
    requested_schema = request.schema_name
    schema_name = requested_schema
    if not schema_name and self.schemas:
        schema_name = next(iter(self.schemas.keys()))
        logger.info(f"Auto-selected schema: {schema_name}")

    if not schema_name:
        return failure(
            "NoSchemaAvailable",
            "No schema specified and no schemas available in configuration",
        )

    if schema_name not in self.schemas:
        return failure(
            "SchemaNotFound",
            f"Schema '{schema_name}' not found in configuration",
        )

    # Step 3: generate the descriptor; None signals failure
    descriptor = await self.generate_descriptor_with_prompt(
        request.sample,
        detected_type,
        self.schemas[schema_name],
        request.options,
        flow,
    )
    if descriptor is None:
        return failure(
            "DescriptorGenerationFailed",
            "Failed to generate descriptor using prompt service",
        )

    metadata = {
        "schema_name": schema_name,
        # True when the schema used differs from the one in the request
        "auto_selected_schema": requested_schema != schema_name,
    }

    # CSV samples additionally report the detected dialect options
    if detected_type == "csv":
        metadata["csv_options"] = json.dumps(detect_csv_options(request.sample))

    return StructuredDataDiagnosisResponse(
        error=None,
        operation=request.operation,
        detected_type=detected_type,
        confidence=confidence,
        descriptor=json.dumps(descriptor),
        metadata=metadata,
    )
|
||||
|
||||
async def schema_selection_operation(self, request: StructuredDataDiagnosisRequest, flow) -> StructuredDataDiagnosisResponse:
    """Handle the schema-selection operation.

    Describes every schema in ``self.schemas``, sends them together with
    the sample to the prompt identified by ``self.schema_selection_prompt``,
    and expects a JSON array back.

    Returns:
        A response whose ``schema_matches`` is the parsed array; or an
        error response of type ``PromptServiceError`` or ``ParseError``.
    """
    logger.info("Processing schema-selection operation")

    def failure(kind, text):
        # Every error path answers with the same envelope shape.
        return StructuredDataDiagnosisResponse(
            error=Error(type=kind, message=text),
            operation=request.operation,
        )

    def describe_field(field):
        return {
            "name": field.name,
            "type": field.type,
            "description": field.description,
            "required": field.required,
            "primary_key": field.primary,
            "indexed": field.indexed,
            "enum": field.enum_values if field.enum_values else [],
            "size": field.size if hasattr(field, 'size') else 0,
        }

    # Describe all schemas for the prompt - match the original config format
    all_schemas = [
        {
            "name": row_schema.name,
            "description": row_schema.description,
            "fields": [describe_field(field) for field in row_schema.fields],
        }
        for row_schema in self.schemas.values()
    ]

    # Prompt variables; note the template expects 'question', not 'sample'
    variables = {
        "question": request.sample,
        "schemas": all_schemas,
        "options": request.options or {},
    }

    # Each term is sent JSON-encoded
    prompt_request = PromptRequest(
        id=self.schema_selection_prompt,
        terms={name: json.dumps(value) for name, value in variables.items()},
    )

    try:
        logger.info(f"Calling prompt service for schema selection with template: {self.schema_selection_prompt}")
        response = await flow("prompt-request").request(prompt_request)

        if response.error:
            logger.error(f"Prompt service error: {response.error.message}")
            return failure(
                "PromptServiceError",
                "Failed to select schemas using prompt service",
            )

        # The answer may arrive in either field; 'object' takes precedence
        if response.object and response.object.strip():
            payload = response.object.strip()
            logger.debug(f"Using response from 'object' field: {payload}")
        elif response.text and response.text.strip():
            payload = response.text.strip()
            logger.debug(f"Using response from 'text' field: {payload}")
        else:
            logger.error("Empty response from prompt service (checked both text and object fields)")
            return failure(
                "PromptServiceError",
                "Empty response from prompt service",
            )

        # The payload must decode to a JSON array
        try:
            schema_matches = json.loads(payload)
            if not isinstance(schema_matches, list):
                raise ValueError("Response must be an array")
        except (json.JSONDecodeError, ValueError) as parse_error:
            logger.error(f"Failed to parse schema matches response: {parse_error}")
            return failure(
                "ParseError",
                "Failed to parse schema selection response as JSON array",
            )

        return StructuredDataDiagnosisResponse(
            error=None,
            operation=request.operation,
            schema_matches=schema_matches,
        )

    except Exception as exc:
        logger.error(f"Error calling prompt service: {exc}", exc_info=True)
        return failure(
            "PromptServiceError",
            "Failed to select schemas using prompt service",
        )
|
||||
|
||||
async def generate_descriptor_with_prompt(
    self, sample: str, data_type: str, target_schema: RowSchema,
    options: Dict[str, str], flow
) -> Optional[Dict[str, Any]]:
    """Generate a descriptor by calling the prompt for ``data_type``.

    Args:
        sample: Data sample forwarded to the prompt.
        data_type: One of ``"csv"``, ``"json"`` or ``"xml"``; selects the
            prompt template.
        target_schema: Schema whose name, description and fields are
            forwarded to the prompt.
        options: Extra options forwarded to the prompt (``{}`` if falsy).
        flow: Flow accessor used to reach ``"prompt-request"``.

    Returns:
        The JSON-decoded prompt response, or ``None`` when no template is
        defined for the type, the prompt service reports an error, the
        response is empty or not valid JSON, or the call raises.
    """
    # Select the prompt template for this data type
    template_by_type = {
        "csv": self.csv_prompt,
        "json": self.json_prompt,
        "xml": self.xml_prompt,
    }
    prompt_id = template_by_type.get(data_type)
    if not prompt_id:
        logger.error(f"No prompt template defined for data type: {data_type}")
        return None

    # Schema information handed to the prompt
    field_descriptions = [
        {
            "name": field.name,
            "type": field.type,
            "description": field.description,
            "required": field.required,
            "primary_key": field.primary,
            "indexed": field.indexed,
            "enum_values": field.enum_values if field.enum_values else [],
        }
        for field in target_schema.fields
    ]
    schema_info = {
        "name": target_schema.name,
        "description": target_schema.description,
        "fields": field_descriptions,
    }

    variables = {
        "sample": sample,
        "schemas": [schema_info],  # Array with the single target schema
        "options": options or {},
    }

    # Each term is sent JSON-encoded
    prompt_request = PromptRequest(
        id=prompt_id,
        terms={name: json.dumps(value) for name, value in variables.items()},
    )

    try:
        logger.info(f"Calling prompt service with template: {prompt_id}")
        response = await flow("prompt-request").request(prompt_request)

        if response.error:
            logger.error(f"Prompt service error: {response.error.message}")
            return None

        if not response.object and not response.text:
            logger.error("Empty response from prompt service")
            return None

        # 'object' takes precedence over 'text'
        if response.object:
            try:
                return json.loads(response.object)
            except json.JSONDecodeError as parse_error:
                logger.error(f"Failed to parse prompt response as JSON: {parse_error}")
                logger.debug(f"Response object: {response.object}")
                return None

        try:
            return json.loads(response.text)
        except json.JSONDecodeError as parse_error:
            logger.error(f"Failed to parse prompt text response as JSON: {parse_error}")
            logger.debug(f"Response text: {response.text}")
            return None

    except Exception as exc:
        logger.error(f"Error calling prompt service: {exc}", exc_info=True)
        return None
|
||||
|
||||
|
||||
def run():
    """Entry point for the structured-diag command.

    Passes the default identity and the module docstring to
    ``Processor.launch``.
    """
    Processor.launch(default_ident, __doc__)
|
||||
|
|
@ -0,0 +1,208 @@
|
|||
"""
|
||||
Algorithmic data type detection for structured data.
|
||||
Determines if data is CSV, JSON, or XML based on content analysis.
|
||||
"""
|
||||
|
||||
import json
|
||||
import xml.etree.ElementTree as ET
|
||||
import csv
|
||||
from io import StringIO
|
||||
import logging
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def detect_data_type(sample: str) -> Tuple[Optional[str], float]:
    """
    Classify a data sample as csv, json or xml by its leading characters.

    Args:
        sample: String containing the data sample to analyze

    Returns:
        Tuple of (detected_type, confidence_score)
        detected_type: "xml", "json", "csv", or None for an empty/blank sample
        confidence_score: 0.9 for xml and json, 0.8 for csv, 0.0 for None
    """
    text = sample.strip() if sample else ""
    if not text:
        return None, 0.0

    # XML: an explicit declaration, or an opening tag plus some closing tag
    has_xml_declaration = text.startswith('<?xml')
    if has_xml_declaration or (text.startswith('<') and '</' in text):
        return 'xml', 0.9

    # JSON: object or array opener
    if text.startswith(('{', '[')):
        return 'json', 0.9

    # Anything else is treated as CSV
    return 'csv', 0.8
|
||||
|
||||
|
||||
def _check_json_format(sample: str) -> float:
|
||||
"""Check if sample is valid JSON format"""
|
||||
try:
|
||||
# Must start with { or [
|
||||
if not (sample.startswith('{') or sample.startswith('[')):
|
||||
return 0.0
|
||||
|
||||
# Try to parse as JSON
|
||||
data = json.loads(sample)
|
||||
|
||||
# Higher confidence for structured data
|
||||
if isinstance(data, dict):
|
||||
return 0.95
|
||||
elif isinstance(data, list) and len(data) > 0:
|
||||
# Check if it's an array of objects (common for structured data)
|
||||
if isinstance(data[0], dict):
|
||||
return 0.9
|
||||
else:
|
||||
return 0.7
|
||||
else:
|
||||
return 0.6
|
||||
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
return 0.0
|
||||
|
||||
|
||||
def _check_xml_format(sample: str) -> float:
|
||||
"""Check if sample is valid XML format"""
|
||||
# XML declaration or starts with tag
|
||||
if sample.startswith('<?xml') or sample.startswith('<'):
|
||||
# Must have closing tags for valid XML
|
||||
if '</' in sample and '>' in sample:
|
||||
try:
|
||||
# Quick parse test
|
||||
ET.fromstring(sample)
|
||||
return 0.9 # Valid XML
|
||||
except ET.ParseError:
|
||||
return 0.3 # Looks like XML but malformed
|
||||
else:
|
||||
return 0.1 # Incomplete XML
|
||||
|
||||
return 0.0 # Not XML
|
||||
|
||||
|
||||
def _check_csv_format(sample: str) -> float:
|
||||
"""Check if sample is valid CSV format"""
|
||||
try:
|
||||
lines = sample.strip().split('\n')
|
||||
if len(lines) < 2:
|
||||
return 0.0
|
||||
|
||||
# Try to parse as CSV with different delimiters
|
||||
delimiters = [',', ';', '\t', '|']
|
||||
best_score = 0.0
|
||||
|
||||
for delimiter in delimiters:
|
||||
score = _check_csv_with_delimiter(sample, delimiter)
|
||||
best_score = max(best_score, score)
|
||||
|
||||
return best_score
|
||||
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
|
||||
def _check_csv_with_delimiter(sample: str, delimiter: str) -> float:
|
||||
"""Check CSV format with specific delimiter"""
|
||||
try:
|
||||
reader = csv.reader(StringIO(sample), delimiter=delimiter)
|
||||
rows = list(reader)
|
||||
|
||||
if len(rows) < 2:
|
||||
return 0.0
|
||||
|
||||
# Check consistency of column counts
|
||||
first_row_cols = len(rows[0])
|
||||
if first_row_cols < 2:
|
||||
return 0.0
|
||||
|
||||
consistent_rows = 0
|
||||
for row in rows[1:]:
|
||||
if len(row) == first_row_cols:
|
||||
consistent_rows += 1
|
||||
|
||||
consistency_ratio = consistent_rows / (len(rows) - 1) if len(rows) > 1 else 0
|
||||
|
||||
# Base score on consistency and structure
|
||||
if consistency_ratio > 0.8:
|
||||
# Higher score for more columns and rows
|
||||
column_bonus = min(first_row_cols * 0.05, 0.2)
|
||||
row_bonus = min(len(rows) * 0.01, 0.1)
|
||||
return min(0.7 + column_bonus + row_bonus, 0.95)
|
||||
elif consistency_ratio > 0.6:
|
||||
return 0.5
|
||||
else:
|
||||
return 0.2
|
||||
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
|
||||
def detect_csv_options(sample: str) -> Dict[str, object]:
    """
    Detect CSV-specific options like delimiter and header presence.

    Args:
        sample: CSV data sample

    Returns:
        Dict with the keys ``delimiter``, ``has_header`` and ``encoding``.
        The defaults (``","``, ``True``, ``"utf-8"``) are returned unchanged
        when the sample has fewer than two lines; any error during
        detection is logged at debug level and the options gathered so far
        are returned.
    """
    # NOTE: the return annotation previously used the builtin function
    # ``any`` as a type argument; ``object`` is a real type and needs no
    # extra import.
    options = {
        "delimiter": ",",
        "has_header": True,
        "encoding": "utf-8"
    }

    try:
        lines = sample.strip().split('\n')
        if len(lines) < 2:
            return options

        # Detect delimiter: keep the candidate with the highest score,
        # falling back to "," when nothing scores above zero
        delimiters = [',', ';', '\t', '|']
        best_delimiter = ","
        best_score = 0

        for delimiter in delimiters:
            score = _check_csv_with_delimiter(sample, delimiter)
            if score > best_score:
                best_score = score
                best_delimiter = delimiter

        options["delimiter"] = best_delimiter

        # Detect header (heuristic: first row has text, second row has
        # more numbers/structured data)
        reader = csv.reader(StringIO(sample), delimiter=best_delimiter)
        rows = list(reader)

        if len(rows) >= 2:
            first_row = rows[0]
            second_row = rows[1]

            # Count numeric fields in each row
            first_numeric = sum(1 for cell in first_row if _is_numeric(cell))
            second_numeric = sum(1 for cell in second_row if _is_numeric(cell))

            # If the second row has more numeric values, and the first row
            # is less than 70% numeric, the first row is likely a header
            options["has_header"] = (
                second_numeric > first_numeric
                and first_numeric < len(first_row) * 0.7
            )

    except Exception as e:
        logger.debug(f"Error detecting CSV options: {e}")

    return options
|
||||
|
||||
|
||||
def _is_numeric(value: str) -> bool:
|
||||
"""Check if a string value represents a number"""
|
||||
try:
|
||||
float(value.strip())
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
|
|
@ -0,0 +1 @@
|
|||
from . service import *
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . service import run
|
||||
|
||||
run()
|
||||
175
trustgraph-flow/trustgraph/retrieval/structured_query/service.py
Normal file
175
trustgraph-flow/trustgraph/retrieval/structured_query/service.py
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
"""
|
||||
Structured Query Service - orchestrates natural language question processing.
|
||||
Takes a question, converts it to GraphQL via nlp-query, executes via objects-query,
|
||||
and returns the results.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from ...schema import StructuredQueryRequest, StructuredQueryResponse
|
||||
from ...schema import QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse
|
||||
from ...schema import ObjectsQueryRequest, ObjectsQueryResponse
|
||||
from ...schema import Error
|
||||
|
||||
from ...base import FlowProcessor, ConsumerSpec, ProducerSpec, RequestResponseSpec
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "structured-query"
|
||||
|
||||
class Processor(FlowProcessor):
    """Orchestrates natural-language questions over structured data.

    For each request it asks the NLP query service to turn the question
    into a GraphQL query, runs that query through the objects query
    service, and sends the result on the "response" producer.
    """

    def __init__(self, **params):
        """Register the request consumer, response producer and the two
        request/response clients used to reach the downstream services."""

        processor_id = params.get("id", default_ident)

        super(Processor, self).__init__(
            **params | {
                "id": processor_id,
            }
        )

        self.register_specification(
            ConsumerSpec(
                name = "request",
                schema = StructuredQueryRequest,
                handler = self.on_message
            )
        )

        self.register_specification(
            ProducerSpec(
                name = "response",
                schema = StructuredQueryResponse,
            )
        )

        # Client spec for calling NLP query service
        self.register_specification(
            RequestResponseSpec(
                request_name = "nlp-query-request",
                response_name = "nlp-query-response",
                request_schema = QuestionToStructuredQueryRequest,
                response_schema = QuestionToStructuredQueryResponse
            )
        )

        # Client spec for calling objects query service
        self.register_specification(
            RequestResponseSpec(
                request_name = "objects-query-request",
                response_name = "objects-query-response",
                request_schema = ObjectsQueryRequest,
                response_schema = ObjectsQueryResponse
            )
        )

        logger.info("Structured Query service initialized")

    async def on_message(self, msg, consumer, flow):
        """Handle an incoming structured query request.

        On success a response with the query data and any GraphQL errors
        is sent, tagged with the sender's id. On failure an error response
        of type "structured-query-error" is sent instead.

        Raises:
            Exception: the original failure is re-raised when it happened
                before the sender's id could be read, because an error
                response cannot be correlated without it.
        """

        # Stays None until the sender-produced id has been read, so the
        # error handler never touches an unbound name.
        request_id = None

        try:
            request = msg.value()

            # Sender-produced ID
            request_id = msg.properties()["id"]

            logger.info(f"Handling structured query request {request_id}: {request.question[:100]}...")

            # Step 1: Convert question to GraphQL using NLP query service
            logger.info("Step 1: Converting question to GraphQL")
            nlp_request = QuestionToStructuredQueryRequest(
                question=request.question,
                max_results=100  # Default limit
            )

            nlp_response = await flow("nlp-query-request").request(nlp_request)

            if nlp_response.error is not None:
                raise Exception(f"NLP query service error: {nlp_response.error.message}")

            if not nlp_response.graphql_query:
                raise Exception("NLP query service returned empty GraphQL query")

            logger.info(f"Generated GraphQL query: {nlp_response.graphql_query[:200]}...")
            logger.info(f"Detected schemas: {nlp_response.detected_schemas}")
            logger.info(f"Confidence: {nlp_response.confidence}")

            # Step 2: Execute GraphQL query using objects query service
            logger.info("Step 2: Executing GraphQL query")

            # Convert variables to strings (GraphQL variables can be various
            # types, but Pulsar schema expects strings)
            variables_as_strings = {
                key: value if isinstance(value, str) else str(value)
                for key, value in (nlp_response.variables or {}).items()
            }

            # Use user/collection values from request
            objects_request = ObjectsQueryRequest(
                user=request.user,
                collection=request.collection,
                query=nlp_response.graphql_query,
                variables=variables_as_strings,
                operation_name=None
            )

            objects_response = await flow("objects-query-request").request(objects_request)

            if objects_response.error is not None:
                raise Exception(f"Objects query service error: {objects_response.error.message}")

            # GraphQL-level errors are passed through as strings rather
            # than failing the whole request
            graphql_errors = [
                f"{gql_error.message} (path: {gql_error.path})"
                for gql_error in (objects_response.errors or [])
            ]

            logger.info("Step 3: Returning results")

            response = StructuredQueryResponse(
                error=None,
                data=objects_response.data or "null",  # JSON string
                errors=graphql_errors
            )

            logger.info("Sending structured query response...")
            await flow("response").send(response, properties={"id": request_id})

            logger.info("Structured query request completed")

        except Exception as e:

            logger.error(f"Exception in structured query service: {e}", exc_info=True)

            if request_id is None:
                # Failed before the id was known: surface the real error
                # instead of masking it with an UnboundLocalError.
                raise

            logger.info("Sending error response...")

            response = StructuredQueryResponse(
                error = Error(
                    type = "structured-query-error",
                    message = str(e),
                ),
                data = "null",
                errors = []
            )

            await flow("response").send(response, properties={"id": request_id})

    @staticmethod
    def add_args(parser):
        """Add command-line arguments"""

        FlowProcessor.add_args(parser)

        # No additional arguments needed for this orchestrator service
|
||||
|
||||
def run():
    """Entry point for the structured-query command.

    Passes the default identity and the module docstring to
    ``Processor.launch``.
    """
    Processor.launch(default_ident, __doc__)
|
||||
|
|
@ -3,8 +3,17 @@
|
|||
Accepts entity/vector pairs and writes them to a Milvus store.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from .... direct.milvus_doc_embeddings import DocVectors
|
||||
from .... base import DocumentEmbeddingsStoreService
|
||||
from .... base import AsyncProcessor, Consumer, Producer
|
||||
from .... base import ConsumerMetrics, ProducerMetrics
|
||||
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
|
||||
from .... schema import vector_storage_management_topic, storage_management_response_topic
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "de-write"
|
||||
default_store_uri = 'http://localhost:19530'
|
||||
|
|
@ -23,6 +32,34 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
|
||||
self.vecstore = DocVectors(store_uri)
|
||||
|
||||
# Set up metrics for storage management
|
||||
storage_request_metrics = ConsumerMetrics(
|
||||
processor=self.id, flow=None, name="storage-request"
|
||||
)
|
||||
storage_response_metrics = ProducerMetrics(
|
||||
processor=self.id, flow=None, name="storage-response"
|
||||
)
|
||||
|
||||
# Set up consumer for storage management requests
|
||||
self.storage_request_consumer = Consumer(
|
||||
taskgroup=self.taskgroup,
|
||||
client=self.pulsar_client,
|
||||
flow=None,
|
||||
topic=vector_storage_management_topic,
|
||||
subscriber=f"{self.id}-storage",
|
||||
schema=StorageManagementRequest,
|
||||
handler=self.on_storage_management,
|
||||
metrics=storage_request_metrics,
|
||||
)
|
||||
|
||||
# Set up producer for storage management responses
|
||||
self.storage_response_producer = Producer(
|
||||
client=self.pulsar_client,
|
||||
topic=storage_management_response_topic,
|
||||
schema=StorageManagementResponse,
|
||||
metrics=storage_response_metrics,
|
||||
)
|
||||
|
||||
async def store_document_embeddings(self, message):
|
||||
|
||||
for emb in message.chunks:
|
||||
|
|
@ -33,7 +70,11 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
if chunk == "": continue
|
||||
|
||||
for vec in emb.vectors:
|
||||
self.vecstore.insert(vec, chunk)
|
||||
self.vecstore.insert(
|
||||
vec, chunk,
|
||||
message.metadata.user,
|
||||
message.metadata.collection
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
|
@ -46,6 +87,48 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
help=f'Milvus store URI (default: {default_store_uri})'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
    """Dispatch a storage management request.

    "delete-collection" is forwarded to ``handle_delete_collection``; any
    other operation is answered with an "invalid_operation" error. An
    exception raised while handling is logged and answered with a
    "processing_error" response.
    """
    logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")

    try:
        if message.operation != "delete-collection":
            # Unsupported operation: reply with an error and stop
            rejection = StorageManagementResponse(
                error=Error(
                    type="invalid_operation",
                    message=f"Unknown operation: {message.operation}"
                )
            )
            await self.storage_response_producer.send(rejection)
            return

        await self.handle_delete_collection(message)

    except Exception as exc:
        logger.error(f"Error processing storage management request: {exc}", exc_info=True)
        failure = StorageManagementResponse(
            error=Error(
                type="processing_error",
                message=str(exc)
            )
        )
        await self.storage_response_producer.send(failure)
|
||||
|
||||
async def handle_delete_collection(self, message):
    """Delete the document-embeddings collection named by the request.

    Calls ``self.vecstore.delete_collection`` with the request's user and
    collection, then sends a response with ``error=None``. Any exception
    is logged and re-raised for the caller to report.
    """
    try:
        self.vecstore.delete_collection(message.user, message.collection)

        # A response without an error signals success
        acknowledgement = StorageManagementResponse(error=None)
        await self.storage_response_producer.send(acknowledgement)
        logger.info(f"Successfully deleted collection {message.user}/{message.collection}")

    except Exception as exc:
        logger.error(f"Failed to delete collection: {exc}")
        raise
|
||||
|
||||
def run():
    """Launch the processor with the default identity and the module docstring."""
    Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,10 @@ import os
|
|||
import logging
|
||||
|
||||
from .... base import DocumentEmbeddingsStoreService
|
||||
from .... base import AsyncProcessor, Consumer, Producer
|
||||
from .... base import ConsumerMetrics, ProducerMetrics
|
||||
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
|
||||
from .... schema import vector_storage_management_topic, storage_management_response_topic
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -55,6 +59,34 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
|
||||
self.last_index_name = None
|
||||
|
||||
# Set up metrics for storage management
|
||||
storage_request_metrics = ConsumerMetrics(
|
||||
processor=self.id, flow=None, name="storage-request"
|
||||
)
|
||||
storage_response_metrics = ProducerMetrics(
|
||||
processor=self.id, flow=None, name="storage-response"
|
||||
)
|
||||
|
||||
# Set up consumer for storage management requests
|
||||
self.storage_request_consumer = Consumer(
|
||||
taskgroup=self.taskgroup,
|
||||
client=self.pulsar_client,
|
||||
flow=None,
|
||||
topic=vector_storage_management_topic,
|
||||
subscriber=f"{self.id}-storage",
|
||||
schema=StorageManagementRequest,
|
||||
handler=self.on_storage_management,
|
||||
metrics=storage_request_metrics,
|
||||
)
|
||||
|
||||
# Set up producer for storage management responses
|
||||
self.storage_response_producer = Producer(
|
||||
client=self.pulsar_client,
|
||||
topic=storage_management_response_topic,
|
||||
schema=StorageManagementResponse,
|
||||
metrics=storage_response_metrics,
|
||||
)
|
||||
|
||||
def create_index(self, index_name, dim):
|
||||
|
||||
self.pinecone.create_index(
|
||||
|
|
@ -96,7 +128,7 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
|
||||
dim = len(vec)
|
||||
index_name = (
|
||||
"d-" + message.metadata.user + "-" + message.metadata.collection + "-" + str(dim)
|
||||
"d-" + message.metadata.user + "-" + message.metadata.collection
|
||||
)
|
||||
|
||||
if index_name != self.last_index_name:
|
||||
|
|
@ -160,6 +192,54 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
help=f'Pinecone region, (default: {default_region}'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
    """Dispatch a storage management request.

    "delete-collection" is forwarded to ``handle_delete_collection``; any
    other operation is answered with an "invalid_operation" error. An
    exception raised while handling is logged and answered with a
    "processing_error" response.
    """
    logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")

    try:
        if message.operation != "delete-collection":
            # Unsupported operation: reply with an error and stop
            rejection = StorageManagementResponse(
                error=Error(
                    type="invalid_operation",
                    message=f"Unknown operation: {message.operation}"
                )
            )
            await self.storage_response_producer.send(rejection)
            return

        await self.handle_delete_collection(message)

    except Exception as exc:
        logger.error(f"Error processing storage management request: {exc}", exc_info=True)
        failure = StorageManagementResponse(
            error=Error(
                type="processing_error",
                message=str(exc)
            )
        )
        await self.storage_response_producer.send(failure)
|
||||
|
||||
async def handle_delete_collection(self, message):
    """Delete the Pinecone index backing the request's collection.

    The index is named ``d-<user>-<collection>``. It is deleted when it
    exists; a missing index is only logged. Either way a response with
    ``error=None`` is sent. Any exception is logged and re-raised for the
    caller to report.
    """
    try:
        index_name = f"d-{message.user}-{message.collection}"

        if not self.pinecone.has_index(index_name):
            logger.info(f"Index {index_name} does not exist, nothing to delete")
        else:
            self.pinecone.delete_index(index_name)
            logger.info(f"Deleted Pinecone index: {index_name}")

        # A response without an error signals success
        acknowledgement = StorageManagementResponse(error=None)
        await self.storage_response_producer.send(acknowledgement)
        logger.info(f"Successfully deleted collection {message.user}/{message.collection}")

    except Exception as exc:
        logger.error(f"Failed to delete collection: {exc}")
        raise
|
||||
|
||||
def run():
    """Launch the processor with the default identity and the module docstring."""
    Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,10 @@ import uuid
|
|||
import logging
|
||||
|
||||
from .... base import DocumentEmbeddingsStoreService
|
||||
from .... base import AsyncProcessor, Consumer, Producer
|
||||
from .... base import ConsumerMetrics, ProducerMetrics
|
||||
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
|
||||
from .... schema import vector_storage_management_topic, storage_management_response_topic
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -36,6 +40,37 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
|
||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
||||
|
||||
# Set up storage management if base class attributes are available
|
||||
# (they may not be in unit tests)
|
||||
if hasattr(self, 'id') and hasattr(self, 'taskgroup') and hasattr(self, 'pulsar_client'):
|
||||
# Set up metrics for storage management
|
||||
storage_request_metrics = ConsumerMetrics(
|
||||
processor=self.id, flow=None, name="storage-request"
|
||||
)
|
||||
storage_response_metrics = ProducerMetrics(
|
||||
processor=self.id, flow=None, name="storage-response"
|
||||
)
|
||||
|
||||
# Set up consumer for storage management requests
|
||||
self.storage_request_consumer = Consumer(
|
||||
taskgroup=self.taskgroup,
|
||||
client=self.pulsar_client,
|
||||
flow=None,
|
||||
topic=vector_storage_management_topic,
|
||||
subscriber=f"{self.id}-storage",
|
||||
schema=StorageManagementRequest,
|
||||
handler=self.on_storage_management,
|
||||
metrics=storage_request_metrics,
|
||||
)
|
||||
|
||||
# Set up producer for storage management responses
|
||||
self.storage_response_producer = Producer(
|
||||
client=self.pulsar_client,
|
||||
topic=storage_management_response_topic,
|
||||
schema=StorageManagementResponse,
|
||||
metrics=storage_response_metrics,
|
||||
)
|
||||
|
||||
async def store_document_embeddings(self, message):
|
||||
|
||||
for emb in message.chunks:
|
||||
|
|
@ -48,8 +83,7 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
dim = len(vec)
|
||||
collection = (
|
||||
"d_" + message.metadata.user + "_" +
|
||||
message.metadata.collection + "_" +
|
||||
str(dim)
|
||||
message.metadata.collection
|
||||
)
|
||||
|
||||
if collection != self.last_collection:
|
||||
|
|
@ -99,6 +133,54 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
help=f'Qdrant API key (default: None)'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
    """Dispatch a storage management request and report failures.

    ``delete-collection`` requests are delegated to
    ``handle_delete_collection``.  Any other operation is answered with an
    ``invalid_operation`` error response, and any exception raised while
    handling the request is answered with a ``processing_error`` response
    carrying the exception text.

    NOTE(review): ``operation``/``user``/``collection`` are read directly
    off ``message``, and the handler takes a single argument - confirm this
    matches what ``Consumer`` passes to its handler.
    """
    logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")

    try:
        # Guard clause: only one operation is supported
        if message.operation != "delete-collection":
            rejection = StorageManagementResponse(
                error=Error(
                    type="invalid_operation",
                    message=f"Unknown operation: {message.operation}",
                )
            )
            await self.storage_response_producer.send(rejection)
            return

        await self.handle_delete_collection(message)

    except Exception as e:
        logger.error(f"Error processing storage management request: {e}", exc_info=True)
        failure = StorageManagementResponse(
            error=Error(type="processing_error", message=str(e))
        )
        await self.storage_response_producer.send(failure)
|
||||
|
||||
async def handle_delete_collection(self, message):
    """Delete the Qdrant collection holding a user's document embeddings.

    The collection name is ``d_<user>_<collection>``, the same name the
    write path builds.  A collection that does not exist is treated as
    success.  On success a ``StorageManagementResponse`` with
    ``error=None`` is sent on the storage response producer.

    Args:
        message: request object exposing ``user`` and ``collection``.

    Raises:
        Exception: any failure is logged and re-raised; the caller
            (``on_storage_management``) turns it into an error response.
    """
    try:
        collection_name = f"d_{message.user}_{message.collection}"

        if self.qdrant.collection_exists(collection_name):
            self.qdrant.delete_collection(collection_name)
            logger.info(f"Deleted Qdrant collection: {collection_name}")
        else:
            logger.info(f"Collection {collection_name} does not exist, nothing to delete")

        # The write path compares the target name against
        # self.last_collection before (presumably) ensuring the collection
        # exists.  Forget the cached name so the next write after a delete
        # does not assume the collection is still there.
        if getattr(self, "last_collection", None) == collection_name:
            self.last_collection = None

        # Send success response
        response = StorageManagementResponse(
            error=None  # No error means success
        )
        await self.storage_response_producer.send(response)
        logger.info(f"Successfully deleted collection {message.user}/{message.collection}")

    except Exception as e:
        logger.error(f"Failed to delete collection: {e}")
        raise
|
||||
|
||||
def run():
    """Entry point: hand control to ``Processor.launch``.

    Passes ``default_ident`` as the identity and this module's docstring
    (``__doc__``) as the second argument.
    """

    Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
|
|
@ -3,8 +3,17 @@
|
|||
Accepts entity/vector pairs and writes them to a Milvus store.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from .... direct.milvus_graph_embeddings import EntityVectors
|
||||
from .... base import GraphEmbeddingsStoreService
|
||||
from .... base import AsyncProcessor, Consumer, Producer
|
||||
from .... base import ConsumerMetrics, ProducerMetrics
|
||||
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
|
||||
from .... schema import vector_storage_management_topic, storage_management_response_topic
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "ge-write"
|
||||
default_store_uri = 'http://localhost:19530'
|
||||
|
|
@ -23,13 +32,45 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
|
||||
self.vecstore = EntityVectors(store_uri)
|
||||
|
||||
# Set up metrics for storage management
|
||||
storage_request_metrics = ConsumerMetrics(
|
||||
processor=self.id, flow=None, name="storage-request"
|
||||
)
|
||||
storage_response_metrics = ProducerMetrics(
|
||||
processor=self.id, flow=None, name="storage-response"
|
||||
)
|
||||
|
||||
# Set up consumer for storage management requests
|
||||
self.storage_request_consumer = Consumer(
|
||||
taskgroup=self.taskgroup,
|
||||
client=self.pulsar_client,
|
||||
flow=None,
|
||||
topic=vector_storage_management_topic,
|
||||
subscriber=f"{self.id}-storage",
|
||||
schema=StorageManagementRequest,
|
||||
handler=self.on_storage_management,
|
||||
metrics=storage_request_metrics,
|
||||
)
|
||||
|
||||
# Set up producer for storage management responses
|
||||
self.storage_response_producer = Producer(
|
||||
client=self.pulsar_client,
|
||||
topic=storage_management_response_topic,
|
||||
schema=StorageManagementResponse,
|
||||
metrics=storage_response_metrics,
|
||||
)
|
||||
|
||||
async def store_graph_embeddings(self, message):
    """Write every entity vector in ``message`` to the vector store.

    Each vector is inserted exactly once, tagged with the entity value and
    with the user and collection taken from ``message.metadata``.  Entities
    whose value is ``None`` or the empty string are skipped.

    Args:
        message: object exposing ``entities`` (each with ``entity.value``
            and ``vectors``) and ``metadata.user`` / ``metadata.collection``.
    """
    user = message.metadata.user
    collection = message.metadata.collection

    for entity in message.entities:

        value = entity.entity.value

        # Nothing to key the vectors on; skip this entity
        if value is None or value == "":
            continue

        for vec in entity.vectors:
            # Single insert per vector.  The older two-argument call
            # (vec, value) has been superseded by this user/collection
            # scoped form and must not run alongside it.
            self.vecstore.insert(vec, value, user, collection)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
|
@ -42,6 +83,48 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
help=f'Milvus store URI (default: {default_store_uri})'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
    """Route a storage management request for the Milvus graph store.

    A ``delete-collection`` operation is passed to
    ``handle_delete_collection``; anything else gets an
    ``invalid_operation`` error response.  Exceptions raised along the way
    are logged with a traceback and answered with a ``processing_error``
    response.

    NOTE(review): fields are read straight off ``message`` - confirm this
    is the object ``Consumer`` hands to its handler.
    """
    logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")

    try:
        # Unsupported operations are rejected up front
        if message.operation != "delete-collection":
            rejection = StorageManagementResponse(
                error=Error(
                    type="invalid_operation",
                    message=f"Unknown operation: {message.operation}",
                )
            )
            await self.storage_response_producer.send(rejection)
            return

        await self.handle_delete_collection(message)

    except Exception as e:
        logger.error(f"Error processing storage management request: {e}", exc_info=True)
        failure = StorageManagementResponse(
            error=Error(type="processing_error", message=str(e))
        )
        await self.storage_response_producer.send(failure)
|
||||
|
||||
async def handle_delete_collection(self, message):
    """Remove a user's graph-embedding collection from the vector store.

    Calls ``self.vecstore.delete_collection`` with the request's user and
    collection, then sends a ``StorageManagementResponse`` with
    ``error=None``.  Failures are logged and re-raised to the caller.
    """
    try:
        self.vecstore.delete_collection(message.user, message.collection)

        # An error-free response signals success to the requester
        await self.storage_response_producer.send(
            StorageManagementResponse(error=None)
        )
        logger.info(f"Successfully deleted collection {message.user}/{message.collection}")

    except Exception as e:
        logger.error(f"Failed to delete collection: {e}")
        raise
|
||||
|
||||
def run():
    """Entry point: hand control to ``Processor.launch``.

    Passes ``default_ident`` as the identity and this module's docstring
    (``__doc__``) as the second argument.
    """

    Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,10 @@ import os
|
|||
import logging
|
||||
|
||||
from .... base import GraphEmbeddingsStoreService
|
||||
from .... base import AsyncProcessor, Consumer, Producer
|
||||
from .... base import ConsumerMetrics, ProducerMetrics
|
||||
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
|
||||
from .... schema import vector_storage_management_topic, storage_management_response_topic
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -55,6 +59,34 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
|
||||
self.last_index_name = None
|
||||
|
||||
# Set up metrics for storage management
|
||||
storage_request_metrics = ConsumerMetrics(
|
||||
processor=self.id, flow=None, name="storage-request"
|
||||
)
|
||||
storage_response_metrics = ProducerMetrics(
|
||||
processor=self.id, flow=None, name="storage-response"
|
||||
)
|
||||
|
||||
# Set up consumer for storage management requests
|
||||
self.storage_request_consumer = Consumer(
|
||||
taskgroup=self.taskgroup,
|
||||
client=self.pulsar_client,
|
||||
flow=None,
|
||||
topic=vector_storage_management_topic,
|
||||
subscriber=f"{self.id}-storage",
|
||||
schema=StorageManagementRequest,
|
||||
handler=self.on_storage_management,
|
||||
metrics=storage_request_metrics,
|
||||
)
|
||||
|
||||
# Set up producer for storage management responses
|
||||
self.storage_response_producer = Producer(
|
||||
client=self.pulsar_client,
|
||||
topic=storage_management_response_topic,
|
||||
schema=StorageManagementResponse,
|
||||
metrics=storage_response_metrics,
|
||||
)
|
||||
|
||||
def create_index(self, index_name, dim):
|
||||
|
||||
self.pinecone.create_index(
|
||||
|
|
@ -95,7 +127,7 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
dim = len(vec)
|
||||
|
||||
index_name = (
|
||||
"t-" + message.metadata.user + "-" + message.metadata.collection + "-" + str(dim)
|
||||
"t-" + message.metadata.user + "-" + message.metadata.collection
|
||||
)
|
||||
|
||||
if index_name != self.last_index_name:
|
||||
|
|
@ -159,6 +191,54 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
help=f'Pinecone region, (default: {default_region}'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
    """Route a storage management request for the Pinecone graph store.

    ``delete-collection`` is delegated to ``handle_delete_collection``;
    every other operation is answered with an ``invalid_operation`` error.
    Any exception is logged with a traceback and reported back as a
    ``processing_error`` response.

    NOTE(review): fields are read straight off ``message`` - confirm this
    is the object ``Consumer`` hands to its handler.
    """
    logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")

    try:
        # Reject unsupported operations before doing any work
        if message.operation != "delete-collection":
            rejection = StorageManagementResponse(
                error=Error(
                    type="invalid_operation",
                    message=f"Unknown operation: {message.operation}",
                )
            )
            await self.storage_response_producer.send(rejection)
            return

        await self.handle_delete_collection(message)

    except Exception as e:
        logger.error(f"Error processing storage management request: {e}", exc_info=True)
        failure = StorageManagementResponse(
            error=Error(type="processing_error", message=str(e))
        )
        await self.storage_response_producer.send(failure)
|
||||
|
||||
async def handle_delete_collection(self, message):
    """Delete the Pinecone index holding a user's graph embeddings.

    The index name is ``t-<user>-<collection>``, the same name the write
    path builds.  A missing index is treated as success.  On success a
    ``StorageManagementResponse`` with ``error=None`` is sent on the
    storage response producer.

    Args:
        message: request object exposing ``user`` and ``collection``.

    Raises:
        Exception: any failure is logged and re-raised; the caller
            (``on_storage_management``) turns it into an error response.
    """
    try:
        index_name = f"t-{message.user}-{message.collection}"

        if self.pinecone.has_index(index_name):
            self.pinecone.delete_index(index_name)
            logger.info(f"Deleted Pinecone index: {index_name}")
        else:
            logger.info(f"Index {index_name} does not exist, nothing to delete")

        # The write path compares the target name against
        # self.last_index_name before (presumably) ensuring the index
        # exists.  Drop the cached name so a write following this delete
        # does not assume the index is still present.
        if getattr(self, "last_index_name", None) == index_name:
            self.last_index_name = None

        # Send success response
        response = StorageManagementResponse(
            error=None  # No error means success
        )
        await self.storage_response_producer.send(response)
        logger.info(f"Successfully deleted collection {message.user}/{message.collection}")

    except Exception as e:
        logger.error(f"Failed to delete collection: {e}")
        raise
|
||||
|
||||
def run():
    """Entry point: hand control to ``Processor.launch``.

    Passes ``default_ident`` as the identity and this module's docstring
    (``__doc__``) as the second argument.
    """

    Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,10 @@ import uuid
|
|||
import logging
|
||||
|
||||
from .... base import GraphEmbeddingsStoreService
|
||||
from .... base import AsyncProcessor, Consumer, Producer
|
||||
from .... base import ConsumerMetrics, ProducerMetrics
|
||||
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
|
||||
from .... schema import vector_storage_management_topic, storage_management_response_topic
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -36,10 +40,41 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
|
||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
||||
|
||||
# Set up storage management if base class attributes are available
|
||||
# (they may not be in unit tests)
|
||||
if hasattr(self, 'id') and hasattr(self, 'taskgroup') and hasattr(self, 'pulsar_client'):
|
||||
# Set up metrics for storage management
|
||||
storage_request_metrics = ConsumerMetrics(
|
||||
processor=self.id, flow=None, name="storage-request"
|
||||
)
|
||||
storage_response_metrics = ProducerMetrics(
|
||||
processor=self.id, flow=None, name="storage-response"
|
||||
)
|
||||
|
||||
# Set up consumer for storage management requests
|
||||
self.storage_request_consumer = Consumer(
|
||||
taskgroup=self.taskgroup,
|
||||
client=self.pulsar_client,
|
||||
flow=None,
|
||||
topic=vector_storage_management_topic,
|
||||
subscriber=f"{self.id}-storage",
|
||||
schema=StorageManagementRequest,
|
||||
handler=self.on_storage_management,
|
||||
metrics=storage_request_metrics,
|
||||
)
|
||||
|
||||
# Set up producer for storage management responses
|
||||
self.storage_response_producer = Producer(
|
||||
client=self.pulsar_client,
|
||||
topic=storage_management_response_topic,
|
||||
schema=StorageManagementResponse,
|
||||
metrics=storage_response_metrics,
|
||||
)
|
||||
|
||||
def get_collection(self, dim, user, collection):
|
||||
|
||||
cname = (
|
||||
"t_" + user + "_" + collection + "_" + str(dim)
|
||||
"t_" + user + "_" + collection
|
||||
)
|
||||
|
||||
if cname != self.last_collection:
|
||||
|
|
@ -105,6 +140,54 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
help=f'Qdrant API key'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
    """Route a storage management request for the Qdrant graph store.

    ``delete-collection`` is delegated to ``handle_delete_collection``;
    every other operation is answered with an ``invalid_operation`` error.
    Any exception is logged with a traceback and reported back as a
    ``processing_error`` response.

    NOTE(review): fields are read straight off ``message`` - confirm this
    is the object ``Consumer`` hands to its handler.
    """
    logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")

    try:
        # Reject unsupported operations before doing any work
        if message.operation != "delete-collection":
            rejection = StorageManagementResponse(
                error=Error(
                    type="invalid_operation",
                    message=f"Unknown operation: {message.operation}",
                )
            )
            await self.storage_response_producer.send(rejection)
            return

        await self.handle_delete_collection(message)

    except Exception as e:
        logger.error(f"Error processing storage management request: {e}", exc_info=True)
        failure = StorageManagementResponse(
            error=Error(type="processing_error", message=str(e))
        )
        await self.storage_response_producer.send(failure)
|
||||
|
||||
async def handle_delete_collection(self, message):
    """Delete the Qdrant collection holding a user's graph embeddings.

    The collection name is ``t_<user>_<collection>``, the same name
    ``get_collection`` builds.  A collection that does not exist is treated
    as success.  On success a ``StorageManagementResponse`` with
    ``error=None`` is sent on the storage response producer.

    Args:
        message: request object exposing ``user`` and ``collection``.

    Raises:
        Exception: any failure is logged and re-raised; the caller
            (``on_storage_management``) turns it into an error response.
    """
    try:
        collection_name = f"t_{message.user}_{message.collection}"

        if self.qdrant.collection_exists(collection_name):
            self.qdrant.delete_collection(collection_name)
            logger.info(f"Deleted Qdrant collection: {collection_name}")
        else:
            logger.info(f"Collection {collection_name} does not exist, nothing to delete")

        # get_collection compares the target name against
        # self.last_collection before (presumably) ensuring the collection
        # exists.  Forget the cached name so the next write after a delete
        # does not assume the collection is still there.
        if getattr(self, "last_collection", None) == collection_name:
            self.last_collection = None

        # Send success response
        response = StorageManagementResponse(
            error=None  # No error means success
        )
        await self.storage_response_producer.send(response)
        logger.info(f"Successfully deleted collection {message.user}/{message.collection}")

    except Exception as e:
        logger.error(f"Failed to delete collection: {e}")
        raise
|
||||
|
||||
def run():
    """Entry point: hand control to ``Processor.launch``.

    Passes ``default_ident`` as the identity and this module's docstring
    (``__doc__``) as the second argument.
    """

    Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
|
|
@ -8,12 +8,12 @@ import urllib.parse
|
|||
|
||||
from ... schema import Triples, GraphEmbeddings
|
||||
from ... base import FlowProcessor, ConsumerSpec
|
||||
from ... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||
|
||||
from ... tables.knowledge import KnowledgeTableStore
|
||||
|
||||
default_ident = "kg-store"
|
||||
|
||||
default_cassandra_host = "cassandra"
|
||||
keyspace = "knowledge"
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
|
@ -22,15 +22,18 @@ class Processor(FlowProcessor):
|
|||
|
||||
id = params.get("id")
|
||||
|
||||
cassandra_host = params.get("cassandra_host", default_cassandra_host)
|
||||
cassandra_user = params.get("cassandra_user")
|
||||
cassandra_password = params.get("cassandra_password")
|
||||
# Use helper to resolve configuration
|
||||
hosts, username, password = resolve_cassandra_config(
|
||||
host=params.get("cassandra_host"),
|
||||
username=params.get("cassandra_username"),
|
||||
password=params.get("cassandra_password")
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"cassandra_host": cassandra_host,
|
||||
"cassandra_user": cassandra_user,
|
||||
"cassandra_host": ','.join(hosts),
|
||||
"cassandra_username": username,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -51,9 +54,9 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
|
||||
self.table_store = KnowledgeTableStore(
|
||||
cassandra_host = cassandra_host.split(","),
|
||||
cassandra_user = cassandra_user,
|
||||
cassandra_password = cassandra_password,
|
||||
cassandra_host = hosts,
|
||||
cassandra_username = username,
|
||||
cassandra_password = password,
|
||||
keyspace = keyspace,
|
||||
)
|
||||
|
||||
|
|
@ -71,6 +74,7 @@ class Processor(FlowProcessor):
|
|||
def add_args(parser):
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
add_cassandra_args(parser)
|
||||
|
||||
def run():
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . write import *
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . write import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -1,61 +0,0 @@
|
|||
|
||||
"""
|
||||
Accepts entity/vector pairs and writes them to a Milvus store.
|
||||
"""
|
||||
|
||||
from .... schema import ObjectEmbeddings
|
||||
from .... schema import object_embeddings_store_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... direct.milvus_object_embeddings import ObjectVectors
|
||||
from .... base import Consumer
|
||||
|
||||
module = "oe-write"
|
||||
|
||||
default_input_queue = object_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
default_store_uri = 'http://localhost:19530'
|
||||
|
||||
class Processor(Consumer):
    """Consumer that writes object embedding vectors to a Milvus store."""

    def __init__(self, **params):
        """Resolve queue/subscriber/store settings and open the store.

        Missing ``input_queue``, ``subscriber`` and ``store_uri`` params
        fall back to the module defaults; the resolved values (plus the
        ``ObjectEmbeddings`` input schema) are forwarded to ``Consumer``.
        """
        store_uri = params.get("store_uri", default_store_uri)

        # Resolved settings override whatever the caller passed in
        resolved = {
            "input_queue": params.get("input_queue", default_input_queue),
            "subscriber": params.get("subscriber", default_subscriber),
            "input_schema": ObjectEmbeddings,
            "store_uri": store_uri,
        }

        super(Processor, self).__init__(**(params | resolved))

        self.vecstore = ObjectVectors(store_uri)

    async def handle(self, msg):
        """Insert every vector of the received object, if it has an id."""
        v = msg.value()

        # Objects without an id are ignored
        if not (v.id != "" and v.id is not None):
            return

        for vec in v.vectors:
            self.vecstore.insert(vec, v.name, v.key_name, v.id)

    @staticmethod
    def add_args(parser):
        """Register the base consumer arguments plus ``--store-uri``."""
        Consumer.add_args(
            parser, default_input_queue, default_subscriber,
        )

        parser.add_argument(
            '-t', '--store-uri',
            default=default_store_uri,
            help=f'Milvus store URI (default: {default_store_uri})'
        )
|
||||
|
||||
def run():
    """Entry point: hand control to ``Processor.launch``.

    Passes ``module`` as the identity and this module's docstring
    (``__doc__``) as the second argument.
    """

    Processor.launch(module, __doc__)
|
||||
|
||||
|
|
@ -13,13 +13,15 @@ from cassandra import ConsistencyLevel
|
|||
|
||||
from .... schema import ExtractedObject
|
||||
from .... schema import RowSchema, Field
|
||||
from .... base import FlowProcessor, ConsumerSpec
|
||||
from .... schema import StorageManagementRequest, StorageManagementResponse
|
||||
from .... schema import object_storage_management_topic, storage_management_response_topic
|
||||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "objects-write"
|
||||
default_graph_host = 'localhost'
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
|
|
@ -27,10 +29,22 @@ class Processor(FlowProcessor):
|
|||
|
||||
id = params.get("id", default_ident)
|
||||
|
||||
# Cassandra connection parameters
|
||||
self.graph_host = params.get("graph_host", default_graph_host)
|
||||
self.graph_username = params.get("graph_username", None)
|
||||
self.graph_password = params.get("graph_password", None)
|
||||
# Get Cassandra parameters
|
||||
cassandra_host = params.get("cassandra_host")
|
||||
cassandra_username = params.get("cassandra_username")
|
||||
cassandra_password = params.get("cassandra_password")
|
||||
|
||||
# Resolve configuration with environment variable fallback
|
||||
hosts, username, password = resolve_cassandra_config(
|
||||
host=cassandra_host,
|
||||
username=cassandra_username,
|
||||
password=cassandra_password
|
||||
)
|
||||
|
||||
# Store resolved configuration with proper names
|
||||
self.cassandra_host = hosts # Store as list
|
||||
self.cassandra_username = username
|
||||
self.cassandra_password = password
|
||||
|
||||
# Config key for schemas
|
||||
self.config_key = params.get("config_type", "schema")
|
||||
|
|
@ -38,7 +52,7 @@ class Processor(FlowProcessor):
|
|||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"config-type": self.config_key,
|
||||
"config_type": self.config_key,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -49,7 +63,38 @@ class Processor(FlowProcessor):
|
|||
handler = self.on_object
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Set up storage management consumer and producer directly
|
||||
# (FlowProcessor doesn't support topic-based specs outside of flows)
|
||||
from .... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
|
||||
|
||||
storage_request_metrics = ConsumerMetrics(
|
||||
processor=self.id, flow=None, name="storage-request"
|
||||
)
|
||||
storage_response_metrics = ProducerMetrics(
|
||||
processor=self.id, flow=None, name="storage-response"
|
||||
)
|
||||
|
||||
# Create storage management consumer
|
||||
self.storage_request_consumer = Consumer(
|
||||
taskgroup=self.taskgroup,
|
||||
client=self.pulsar_client,
|
||||
flow=None,
|
||||
topic=object_storage_management_topic,
|
||||
subscriber=f"{id}-storage",
|
||||
schema=StorageManagementRequest,
|
||||
handler=self.on_storage_management,
|
||||
metrics=storage_request_metrics,
|
||||
)
|
||||
|
||||
# Create storage management response producer
|
||||
self.storage_response_producer = Producer(
|
||||
client=self.pulsar_client,
|
||||
topic=storage_management_response_topic,
|
||||
schema=StorageManagementResponse,
|
||||
metrics=storage_response_metrics,
|
||||
)
|
||||
|
||||
# Register config handler for schema updates
|
||||
self.register_config_handler(self.on_schema_config)
|
||||
|
||||
|
|
@ -70,20 +115,20 @@ class Processor(FlowProcessor):
|
|||
return
|
||||
|
||||
try:
|
||||
if self.graph_username and self.graph_password:
|
||||
if self.cassandra_username and self.cassandra_password:
|
||||
auth_provider = PlainTextAuthProvider(
|
||||
username=self.graph_username,
|
||||
password=self.graph_password
|
||||
username=self.cassandra_username,
|
||||
password=self.cassandra_password
|
||||
)
|
||||
self.cluster = Cluster(
|
||||
contact_points=[self.graph_host],
|
||||
contact_points=self.cassandra_host,
|
||||
auth_provider=auth_provider
|
||||
)
|
||||
else:
|
||||
self.cluster = Cluster(contact_points=[self.graph_host])
|
||||
self.cluster = Cluster(contact_points=self.cassandra_host)
|
||||
|
||||
self.session = self.cluster.connect()
|
||||
logger.info(f"Connected to Cassandra cluster at {self.graph_host}")
|
||||
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
|
||||
|
|
@ -299,7 +344,7 @@ class Processor(FlowProcessor):
|
|||
"""Process incoming ExtractedObject and store in Cassandra"""
|
||||
|
||||
obj = msg.value()
|
||||
logger.info(f"Storing object for schema {obj.schema_name} from {obj.metadata.id}")
|
||||
logger.info(f"Storing {len(obj.values)} objects for schema {obj.schema_name} from {obj.metadata.id}")
|
||||
|
||||
# Get schema definition
|
||||
schema = self.schemas.get(obj.schema_name)
|
||||
|
|
@ -316,59 +361,161 @@ class Processor(FlowProcessor):
|
|||
safe_keyspace = self.sanitize_name(keyspace)
|
||||
safe_table = self.sanitize_table(table_name)
|
||||
|
||||
# Build column names and values
|
||||
columns = ["collection"]
|
||||
values = [obj.metadata.collection]
|
||||
placeholders = ["%s"]
|
||||
|
||||
# Check if we need a synthetic ID
|
||||
has_primary_key = any(field.primary for field in schema.fields)
|
||||
if not has_primary_key:
|
||||
import uuid
|
||||
columns.append("synthetic_id")
|
||||
values.append(uuid.uuid4())
|
||||
placeholders.append("%s")
|
||||
|
||||
# Process fields
|
||||
for field in schema.fields:
|
||||
safe_field_name = self.sanitize_name(field.name)
|
||||
raw_value = obj.values.get(field.name)
|
||||
# Process each object in the batch
|
||||
for obj_index, value_map in enumerate(obj.values):
|
||||
# Build column names and values for this object
|
||||
columns = ["collection"]
|
||||
values = [obj.metadata.collection]
|
||||
placeholders = ["%s"]
|
||||
|
||||
# Handle required fields
|
||||
if field.required and raw_value is None:
|
||||
logger.warning(f"Required field {field.name} is missing in object")
|
||||
# Continue anyway - Cassandra doesn't enforce NOT NULL
|
||||
# Check if we need a synthetic ID
|
||||
has_primary_key = any(field.primary for field in schema.fields)
|
||||
if not has_primary_key:
|
||||
import uuid
|
||||
columns.append("synthetic_id")
|
||||
values.append(uuid.uuid4())
|
||||
placeholders.append("%s")
|
||||
|
||||
# Check if primary key field is NULL
|
||||
if field.primary and raw_value is None:
|
||||
logger.error(f"Primary key field {field.name} cannot be NULL - skipping object")
|
||||
return
|
||||
# Process fields for this object
|
||||
skip_object = False
|
||||
for field in schema.fields:
|
||||
safe_field_name = self.sanitize_name(field.name)
|
||||
raw_value = value_map.get(field.name)
|
||||
|
||||
# Handle required fields
|
||||
if field.required and raw_value is None:
|
||||
logger.warning(f"Required field {field.name} is missing in object {obj_index}")
|
||||
# Continue anyway - Cassandra doesn't enforce NOT NULL
|
||||
|
||||
# Check if primary key field is NULL
|
||||
if field.primary and raw_value is None:
|
||||
logger.error(f"Primary key field {field.name} cannot be NULL - skipping object {obj_index}")
|
||||
skip_object = True
|
||||
break
|
||||
|
||||
# Convert value to appropriate type
|
||||
converted_value = self.convert_value(raw_value, field.type)
|
||||
|
||||
columns.append(safe_field_name)
|
||||
values.append(converted_value)
|
||||
placeholders.append("%s")
|
||||
|
||||
# Convert value to appropriate type
|
||||
converted_value = self.convert_value(raw_value, field.type)
|
||||
# Skip this object if primary key validation failed
|
||||
if skip_object:
|
||||
continue
|
||||
|
||||
columns.append(safe_field_name)
|
||||
values.append(converted_value)
|
||||
placeholders.append("%s")
|
||||
|
||||
# Build and execute insert query
|
||||
insert_cql = f"""
|
||||
INSERT INTO {safe_keyspace}.{safe_table} ({', '.join(columns)})
|
||||
VALUES ({', '.join(placeholders)})
|
||||
"""
|
||||
|
||||
# Debug: Show data being inserted
|
||||
logger.debug(f"Storing {obj.schema_name}: {dict(zip(columns, values))}")
|
||||
|
||||
if len(columns) != len(values) or len(columns) != len(placeholders):
|
||||
raise ValueError(f"Mismatch in counts - columns: {len(columns)}, values: {len(values)}, placeholders: {len(placeholders)}")
|
||||
|
||||
# Build and execute insert query for this object
|
||||
insert_cql = f"""
|
||||
INSERT INTO {safe_keyspace}.{safe_table} ({', '.join(columns)})
|
||||
VALUES ({', '.join(placeholders)})
|
||||
"""
|
||||
|
||||
# Debug: Show data being inserted
|
||||
logger.debug(f"Storing {obj.schema_name} object {obj_index}: {dict(zip(columns, values))}")
|
||||
|
||||
if len(columns) != len(values) or len(columns) != len(placeholders):
|
||||
raise ValueError(f"Mismatch in counts - columns: {len(columns)}, values: {len(values)}, placeholders: {len(placeholders)}")
|
||||
|
||||
try:
|
||||
# Convert to tuple - Cassandra driver requires tuple for parameters
|
||||
self.session.execute(insert_cql, tuple(values))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to insert object {obj_index}: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def on_storage_management(self, msg, consumer, flow):
    """Handle storage management requests for collection operations.

    ``delete-collection`` removes the collection's data via
    ``delete_collection`` and answers with an error-free
    ``StorageManagementResponse``.  Any other operation is answered with an
    ``unknown_operation`` error, and any exception is logged with a
    traceback and answered with a ``processing_error`` error.

    All responses go out through ``self.storage_response_producer``, the
    producer created directly in ``__init__`` for this purpose.

    Args:
        msg: request exposing ``operation``, ``user`` and ``collection``.
            NOTE(review): these are read straight off ``msg`` whereas
            ``on_object`` unwraps with ``msg.value()`` - confirm which
            object the Consumer passes here.
        consumer: not used by this handler.
        flow: not used by this handler.
    """
    # Function-scope import: Error is not among this module's top-level
    # schema imports visible here.
    from .... schema import Error

    logger.info(f"Received storage management request: {msg.operation} for {msg.user}/{msg.collection}")

    try:
        if msg.operation == "delete-collection":
            await self.delete_collection(msg.user, msg.collection)

            # Send success response
            response = StorageManagementResponse(
                error=None  # No error means success
            )
            await self.storage_response_producer.send(response)
            logger.info(f"Successfully deleted collection {msg.user}/{msg.collection}")
        else:
            logger.warning(f"Unknown storage management operation: {msg.operation}")
            # Send error response
            response = StorageManagementResponse(
                error=Error(
                    type="unknown_operation",
                    message=f"Unknown operation: {msg.operation}"
                )
            )
            await self.storage_response_producer.send(response)

    except Exception as e:
        logger.error(f"Error handling storage management request: {e}", exc_info=True)
        # Send error response through the same producer as every other
        # path, so a failure is actually reported to the requester.
        response = StorageManagementResponse(
            error=Error(
                type="processing_error",
                message=str(e)
            )
        )
        await self.storage_response_producer.send(response)
|
||||
|
||||
async def delete_collection(self, user: str, collection: str):
    """Remove one collection's rows from every table in the user's keyspace.

    The keyspace is the sanitized ``user`` name.  If it is not already in
    ``self.known_keyspaces`` its existence is checked in
    ``system_schema.keyspaces``; a missing keyspace means there is nothing
    to delete.  Each table in the keyspace that has a ``collection`` column
    then gets a ``DELETE ... WHERE collection = <collection>``.

    Raises:
        Exception: a failing DELETE is logged and re-raised.
    """
    # Connect if not already connected
    self.connect_cassandra()

    # Sanitize names for safety
    safe_keyspace = self.sanitize_name(user)

    if safe_keyspace not in self.known_keyspaces:
        # Not seen before: verify the keyspace exists
        keyspace_row = self.session.execute(
            """
            SELECT keyspace_name FROM system_schema.keyspaces
            WHERE keyspace_name = %s
            """,
            (safe_keyspace,),
        ).one()
        if not keyspace_row:
            logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
            return
        self.known_keyspaces.add(safe_keyspace)

    # Queries are constant, so build them once outside the loop
    list_tables_cql = """
        SELECT table_name FROM system_schema.tables
        WHERE keyspace_name = %s
    """
    collection_column_cql = """
        SELECT column_name FROM system_schema.columns
        WHERE keyspace_name = %s AND table_name = %s AND column_name = 'collection'
    """

    tables_deleted = 0

    for row in self.session.execute(list_tables_cql, (safe_keyspace,)):
        table_name = row.table_name

        # Tables without a collection column hold nothing to purge
        column_row = self.session.execute(
            collection_column_cql, (safe_keyspace, table_name)
        ).one()
        if not column_row:
            continue

        try:
            # NOTE(review): presumably ``collection`` is the partition key
            # of these tables, otherwise this DELETE is rejected - confirm
            self.session.execute(
                f"""
                DELETE FROM {safe_keyspace}.{table_name}
                WHERE collection = %s
                """,
                (collection,),
            )
            tables_deleted += 1
            logger.info(f"Deleted collection {collection} from table {safe_keyspace}.{table_name}")
        except Exception as e:
            logger.error(f"Failed to delete from table {safe_keyspace}.{table_name}: {e}")
            raise

    logger.info(f"Deleted collection {collection} from {tables_deleted} tables in keyspace {safe_keyspace}")
|
||||
|
||||
def close(self):
|
||||
"""Clean up Cassandra connections"""
|
||||
|
|
@ -381,24 +528,7 @@ class Processor(FlowProcessor):
|
|||
"""Add command-line arguments"""
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-g', '--graph-host',
|
||||
default=default_graph_host,
|
||||
help=f'Cassandra host (default: {default_graph_host})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--graph-username',
|
||||
default=None,
|
||||
help='Cassandra username'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--graph-password',
|
||||
default=None,
|
||||
help='Cassandra password'
|
||||
)
|
||||
add_cassandra_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'--config-type',
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@
|
|||
Graph writer. Input is graph edge. Writes edges to Cassandra graph.
|
||||
"""
|
||||
|
||||
raise RuntimeError("This code is no longer in use")
|
||||
|
||||
import pulsar
|
||||
import base64
|
||||
import os
|
||||
|
|
@ -14,9 +16,9 @@ from cassandra.auth import PlainTextAuthProvider
|
|||
from ssl import SSLContext, PROTOCOL_TLSv1_2
|
||||
|
||||
from .... schema import Rows
|
||||
from .... schema import rows_store_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import Consumer
|
||||
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -24,9 +26,8 @@ logger = logging.getLogger(__name__)
|
|||
module = "rows-write"
|
||||
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
|
||||
|
||||
default_input_queue = rows_store_queue
|
||||
default_input_queue = "rows-store" # Default queue name
|
||||
default_subscriber = module
|
||||
default_graph_host='localhost'
|
||||
|
||||
class Processor(Consumer):
|
||||
|
||||
|
|
@ -34,26 +35,35 @@ class Processor(Consumer):
|
|||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
graph_host = params.get("graph_host", default_graph_host)
|
||||
graph_username = params.get("graph_username", None)
|
||||
graph_password = params.get("graph_password", None)
|
||||
|
||||
# Get Cassandra parameters
|
||||
cassandra_host = params.get("cassandra_host")
|
||||
cassandra_username = params.get("cassandra_username")
|
||||
cassandra_password = params.get("cassandra_password")
|
||||
|
||||
# Resolve configuration with environment variable fallback
|
||||
hosts, username, password = resolve_cassandra_config(
|
||||
host=cassandra_host,
|
||||
username=cassandra_username,
|
||||
password=cassandra_password
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Rows,
|
||||
"graph_host": graph_host,
|
||||
"graph_username": graph_username,
|
||||
"graph_password": graph_password,
|
||||
"cassandra_host": ','.join(hosts),
|
||||
"cassandra_username": username,
|
||||
"cassandra_password": password,
|
||||
}
|
||||
)
|
||||
|
||||
if graph_username and graph_password:
|
||||
auth_provider = PlainTextAuthProvider(username=graph_username, password=graph_password)
|
||||
self.cluster = Cluster(graph_host.split(","), auth_provider=auth_provider, ssl_context=ssl_context)
|
||||
if username and password:
|
||||
auth_provider = PlainTextAuthProvider(username=username, password=password)
|
||||
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
|
||||
else:
|
||||
self.cluster = Cluster(graph_host.split(","))
|
||||
self.cluster = Cluster(hosts)
|
||||
self.session = self.cluster.connect()
|
||||
|
||||
self.tables = set()
|
||||
|
|
@ -128,24 +138,7 @@ class Processor(Consumer):
|
|||
Consumer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-g', '--graph-host',
|
||||
default="localhost",
|
||||
help=f'Graph host (default: localhost)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--graph-username',
|
||||
default=None,
|
||||
help=f'Cassandra username'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--graph-password',
|
||||
default=None,
|
||||
help=f'Cassandra password'
|
||||
)
|
||||
add_cassandra_args(parser)
|
||||
|
||||
def run():
|
||||
|
||||
|
|
|
|||
|
|
@ -10,15 +10,19 @@ import argparse
|
|||
import time
|
||||
import logging
|
||||
|
||||
from .... direct.cassandra import TrustGraph
|
||||
from .... direct.cassandra_kg import KnowledgeGraph
|
||||
from .... base import TriplesStoreService
|
||||
from .... base import AsyncProcessor, Consumer, Producer
|
||||
from .... base import ConsumerMetrics, ProducerMetrics
|
||||
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
|
||||
from .... schema import triples_storage_management_topic, storage_management_response_topic
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "triples-write"
|
||||
|
||||
default_graph_host='localhost'
|
||||
|
||||
class Processor(TriplesStoreService):
|
||||
|
||||
|
|
@ -26,80 +30,175 @@ class Processor(TriplesStoreService):
|
|||
|
||||
id = params.get("id", default_ident)
|
||||
|
||||
graph_host = params.get("graph_host", default_graph_host)
|
||||
graph_username = params.get("graph_username", None)
|
||||
graph_password = params.get("graph_password", None)
|
||||
# Get Cassandra parameters
|
||||
cassandra_host = params.get("cassandra_host")
|
||||
cassandra_username = params.get("cassandra_username")
|
||||
cassandra_password = params.get("cassandra_password")
|
||||
|
||||
# Resolve configuration with environment variable fallback
|
||||
hosts, username, password = resolve_cassandra_config(
|
||||
host=cassandra_host,
|
||||
username=cassandra_username,
|
||||
password=cassandra_password
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"graph_host": graph_host,
|
||||
"graph_username": graph_username
|
||||
"cassandra_host": ','.join(hosts),
|
||||
"cassandra_username": username
|
||||
}
|
||||
)
|
||||
|
||||
self.graph_host = [graph_host]
|
||||
self.username = graph_username
|
||||
self.password = graph_password
|
||||
self.cassandra_host = hosts
|
||||
self.cassandra_username = username
|
||||
self.cassandra_password = password
|
||||
self.table = None
|
||||
|
||||
# Set up metrics for storage management
|
||||
storage_request_metrics = ConsumerMetrics(
|
||||
processor=self.id, flow=None, name="storage-request"
|
||||
)
|
||||
storage_response_metrics = ProducerMetrics(
|
||||
processor=self.id, flow=None, name="storage-response"
|
||||
)
|
||||
|
||||
# Set up consumer for storage management requests
|
||||
self.storage_request_consumer = Consumer(
|
||||
taskgroup=self.taskgroup,
|
||||
client=self.pulsar_client,
|
||||
flow=None,
|
||||
topic=triples_storage_management_topic,
|
||||
subscriber=f"{id}-storage",
|
||||
schema=StorageManagementRequest,
|
||||
handler=self.on_storage_management,
|
||||
metrics=storage_request_metrics,
|
||||
)
|
||||
|
||||
# Set up producer for storage management responses
|
||||
self.storage_response_producer = Producer(
|
||||
client=self.pulsar_client,
|
||||
topic=storage_management_response_topic,
|
||||
schema=StorageManagementResponse,
|
||||
metrics=storage_response_metrics,
|
||||
)
|
||||
|
||||
async def store_triples(self, message):
|
||||
|
||||
table = (message.metadata.user, message.metadata.collection)
|
||||
user = message.metadata.user
|
||||
|
||||
if self.table is None or self.table != table:
|
||||
if self.table is None or self.table != user:
|
||||
|
||||
self.tg = None
|
||||
|
||||
try:
|
||||
if self.username and self.password:
|
||||
self.tg = TrustGraph(
|
||||
hosts=self.graph_host,
|
||||
if self.cassandra_username and self.cassandra_password:
|
||||
self.tg = KnowledgeGraph(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=message.metadata.user,
|
||||
table=message.metadata.collection,
|
||||
username=self.username, password=self.password
|
||||
username=self.cassandra_username, password=self.cassandra_password
|
||||
)
|
||||
else:
|
||||
self.tg = TrustGraph(
|
||||
hosts=self.graph_host,
|
||||
self.tg = KnowledgeGraph(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=message.metadata.user,
|
||||
table=message.metadata.collection,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Exception: {e}", exc_info=True)
|
||||
time.sleep(1)
|
||||
raise e
|
||||
|
||||
self.table = table
|
||||
self.table = user
|
||||
|
||||
for t in message.triples:
|
||||
self.tg.insert(
|
||||
message.metadata.collection,
|
||||
t.s.value,
|
||||
t.p.value,
|
||||
t.o.value
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
    """Dispatch a storage management request to its handler.

    Only the "delete-collection" operation is handled; any other operation
    is answered with an "invalid_operation" error response. An exception
    raised while handling the request is logged and answered with a
    "processing_error" response instead of being re-raised.
    """
    logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")

    try:
        if message.operation != "delete-collection":
            # Unsupported operation: report it back and stop here.
            rejection = StorageManagementResponse(
                error=Error(
                    type="invalid_operation",
                    message=f"Unknown operation: {message.operation}"
                )
            )
            await self.storage_response_producer.send(rejection)
            return

        await self.handle_delete_collection(message)

    except Exception as e:
        logger.error(f"Error processing storage management request: {e}", exc_info=True)
        failure = StorageManagementResponse(
            error=Error(
                type="processing_error",
                message=str(e)
            )
        )
        await self.storage_response_producer.send(failure)
|
||||
|
||||
async def handle_delete_collection(self, message):
    """Delete all data for a specific collection from the unified triples table.

    Connects to the keyspace named by ``message.user`` (reusing the cached
    connection when it already targets that user), deletes every row of
    the ``triples`` table whose ``collection`` equals
    ``message.collection``, then publishes a success response.

    Raises:
        Exception: any connection, query or publish failure is logged and
            re-raised so the caller can report it.
    """
    try:
        # Create or reuse connection for this user's keyspace
        if self.table is None or self.table != message.user:

            # Invalidate both halves of the cache together, so a failed
            # connect cannot leave the previous user marker paired with a
            # missing connection.
            self.tg = None
            self.table = None

            # Credentials are only passed when both are configured.
            auth = {}
            if self.cassandra_username and self.cassandra_password:
                auth = {
                    "username": self.cassandra_username,
                    "password": self.cassandra_password,
                }

            try:
                self.tg = KnowledgeGraph(
                    hosts=self.cassandra_host,
                    keyspace=message.user,
                    **auth
                )
            except Exception as e:
                logger.error(f"Failed to connect to Cassandra for user {message.user}: {e}")
                raise

            self.table = message.user

        # Delete all triples for this collection from the unified table
        # In the unified table schema, collection is the partition key
        # The statement is executed unprepared, so the placeholder must be
        # %s; "?" is only accepted by prepared statements.
        delete_cql = "DELETE FROM triples WHERE collection = %s"

        try:
            self.tg.session.execute(delete_cql, (message.collection,))
            logger.info(f"Deleted all triples for collection {message.collection} from keyspace {message.user}")
        except Exception as e:
            logger.error(f"Failed to delete collection data: {e}")
            raise

        # Send success response
        response = StorageManagementResponse(
            error=None  # No error means success
        )
        await self.storage_response_producer.send(response)
        logger.info(f"Successfully deleted collection {message.user}/{message.collection}")

    except Exception as e:
        logger.error(f"Failed to delete collection: {e}")
        raise
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
TriplesStoreService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-g', '--graph-host',
|
||||
default="localhost",
|
||||
help=f'Graph host (default: localhost)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--graph-username',
|
||||
default=None,
|
||||
help=f'Cassandra username'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--graph-password',
|
||||
default=None,
|
||||
help=f'Cassandra password'
|
||||
)
|
||||
add_cassandra_args(parser)
|
||||
|
||||
def run():
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,10 @@ import logging
|
|||
from falkordb import FalkorDB
|
||||
|
||||
from .... base import TriplesStoreService
|
||||
from .... base import AsyncProcessor, Consumer, Producer
|
||||
from .... base import ConsumerMetrics, ProducerMetrics
|
||||
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
|
||||
from .... schema import triples_storage_management_topic, storage_management_response_topic
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -40,14 +44,44 @@ class Processor(TriplesStoreService):
|
|||
|
||||
self.io = FalkorDB.from_url(graph_url).select_graph(database)
|
||||
|
||||
def create_node(self, uri):
|
||||
# Set up metrics for storage management
|
||||
storage_request_metrics = ConsumerMetrics(
|
||||
processor=self.id, flow=None, name="storage-request"
|
||||
)
|
||||
storage_response_metrics = ProducerMetrics(
|
||||
processor=self.id, flow=None, name="storage-response"
|
||||
)
|
||||
|
||||
logger.debug(f"Create node {uri}")
|
||||
# Set up consumer for storage management requests
|
||||
self.storage_request_consumer = Consumer(
|
||||
taskgroup=self.taskgroup,
|
||||
client=self.pulsar_client,
|
||||
flow=None,
|
||||
topic=triples_storage_management_topic,
|
||||
subscriber=f"{self.id}-storage",
|
||||
schema=StorageManagementRequest,
|
||||
handler=self.on_storage_management,
|
||||
metrics=storage_request_metrics,
|
||||
)
|
||||
|
||||
# Set up producer for storage management responses
|
||||
self.storage_response_producer = Producer(
|
||||
client=self.pulsar_client,
|
||||
topic=storage_management_response_topic,
|
||||
schema=StorageManagementResponse,
|
||||
metrics=storage_response_metrics,
|
||||
)
|
||||
|
||||
def create_node(self, uri, user, collection):
|
||||
|
||||
logger.debug(f"Create node {uri} for user={user}, collection={collection}")
|
||||
|
||||
res = self.io.query(
|
||||
"MERGE (n:Node {uri: $uri})",
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
params={
|
||||
"uri": uri,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -56,14 +90,16 @@ class Processor(TriplesStoreService):
|
|||
time=res.run_time_ms
|
||||
))
|
||||
|
||||
def create_literal(self, value):
|
||||
def create_literal(self, value, user, collection):
|
||||
|
||||
logger.debug(f"Create literal {value}")
|
||||
logger.debug(f"Create literal {value} for user={user}, collection={collection}")
|
||||
|
||||
res = self.io.query(
|
||||
"MERGE (n:Literal {value: $value})",
|
||||
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
params={
|
||||
"value": value,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -72,18 +108,20 @@ class Processor(TriplesStoreService):
|
|||
time=res.run_time_ms
|
||||
))
|
||||
|
||||
def relate_node(self, src, uri, dest):
|
||||
def relate_node(self, src, uri, dest, user, collection):
|
||||
|
||||
logger.debug(f"Create node rel {src} {uri} {dest}")
|
||||
logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}")
|
||||
|
||||
res = self.io.query(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Node {uri: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
params={
|
||||
"src": src,
|
||||
"dest": dest,
|
||||
"uri": uri,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -92,18 +130,20 @@ class Processor(TriplesStoreService):
|
|||
time=res.run_time_ms
|
||||
))
|
||||
|
||||
def relate_literal(self, src, uri, dest):
|
||||
def relate_literal(self, src, uri, dest, user, collection):
|
||||
|
||||
logger.debug(f"Create literal rel {src} {uri} {dest}")
|
||||
logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}")
|
||||
|
||||
res = self.io.query(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Literal {value: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
params={
|
||||
"src": src,
|
||||
"dest": dest,
|
||||
"uri": uri,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -113,17 +153,20 @@ class Processor(TriplesStoreService):
|
|||
))
|
||||
|
||||
async def store_triples(self, message):
|
||||
# Extract user and collection from metadata
|
||||
user = message.metadata.user if message.metadata.user else "default"
|
||||
collection = message.metadata.collection if message.metadata.collection else "default"
|
||||
|
||||
for t in message.triples:
|
||||
|
||||
self.create_node(t.s.value)
|
||||
self.create_node(t.s.value, user, collection)
|
||||
|
||||
if t.o.is_uri:
|
||||
self.create_node(t.o.value)
|
||||
self.relate_node(t.s.value, t.p.value, t.o.value)
|
||||
self.create_node(t.o.value, user, collection)
|
||||
self.relate_node(t.s.value, t.p.value, t.o.value, user, collection)
|
||||
else:
|
||||
self.create_literal(t.o.value)
|
||||
self.relate_literal(t.s.value, t.p.value, t.o.value)
|
||||
self.create_literal(t.o.value, user, collection)
|
||||
self.relate_literal(t.s.value, t.p.value, t.o.value, user, collection)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
|
@ -142,6 +185,59 @@ class Processor(TriplesStoreService):
|
|||
help=f'FalkorDB database (default: {default_database})'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
    """Route one storage management request.

    A "delete-collection" operation is forwarded to
    ``handle_delete_collection``; every other operation gets an
    "invalid_operation" error response. Failures are logged with a
    traceback and reported as a "processing_error" response.
    """
    logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")

    try:
        if message.operation == "delete-collection":
            await self.handle_delete_collection(message)
            return

        # Anything else is not supported by this store.
        await self.storage_response_producer.send(
            StorageManagementResponse(
                error=Error(
                    type="invalid_operation",
                    message=f"Unknown operation: {message.operation}"
                )
            )
        )

    except Exception as exc:
        logger.error(f"Error processing storage management request: {exc}", exc_info=True)
        await self.storage_response_producer.send(
            StorageManagementResponse(
                error=Error(
                    type="processing_error",
                    message=str(exc)
                )
            )
        )
|
||||
|
||||
async def handle_delete_collection(self, message):
    """Delete the FalkorDB nodes and literals of one user/collection.

    Runs one DETACH DELETE query for ``Node`` entries and one for
    ``Literal`` entries matching ``message.user`` and
    ``message.collection``, logs the deleted counts, then publishes a
    success response.

    Raises:
        Exception: any failure is logged and re-raised to the caller.
    """
    try:
        # Delete all nodes and literals for this user/collection
        removed_nodes = self.io.query(
            "MATCH (n:Node {user: $user, collection: $collection}) DETACH DELETE n",
            params={"user": message.user, "collection": message.collection}
        )
        removed_literals = self.io.query(
            "MATCH (n:Literal {user: $user, collection: $collection}) DETACH DELETE n",
            params={"user": message.user, "collection": message.collection}
        )

        logger.info(f"Deleted {removed_nodes.nodes_deleted} nodes and {removed_literals.nodes_deleted} literals for collection {message.user}/{message.collection}")

        # A response without an error signals success.
        await self.storage_response_producer.send(
            StorageManagementResponse(error=None)
        )
        logger.info(f"Successfully deleted collection {message.user}/{message.collection}")

    except Exception as e:
        logger.error(f"Failed to delete collection: {e}")
        raise
|
||||
|
||||
def run():
    """Entry point: launch the processor under its default identity.

    Hands ``default_ident`` and the module docstring to
    ``Processor.launch``.
    """
    Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,10 @@ import logging
|
|||
from neo4j import GraphDatabase
|
||||
|
||||
from .... base import TriplesStoreService
|
||||
from .... base import AsyncProcessor, Consumer, Producer
|
||||
from .... base import ConsumerMetrics, ProducerMetrics
|
||||
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
|
||||
from .... schema import triples_storage_management_topic, storage_management_response_topic
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -49,6 +53,34 @@ class Processor(TriplesStoreService):
|
|||
with self.io.session(database=self.db) as session:
|
||||
self.create_indexes(session)
|
||||
|
||||
# Set up metrics for storage management
|
||||
storage_request_metrics = ConsumerMetrics(
|
||||
processor=self.id, flow=None, name="storage-request"
|
||||
)
|
||||
storage_response_metrics = ProducerMetrics(
|
||||
processor=self.id, flow=None, name="storage-response"
|
||||
)
|
||||
|
||||
# Set up consumer for storage management requests
|
||||
self.storage_request_consumer = Consumer(
|
||||
taskgroup=self.taskgroup,
|
||||
client=self.pulsar_client,
|
||||
flow=None,
|
||||
topic=triples_storage_management_topic,
|
||||
subscriber=f"{self.id}-storage",
|
||||
schema=StorageManagementRequest,
|
||||
handler=self.on_storage_management,
|
||||
metrics=storage_request_metrics,
|
||||
)
|
||||
|
||||
# Set up producer for storage management responses
|
||||
self.storage_response_producer = Producer(
|
||||
client=self.pulsar_client,
|
||||
topic=storage_management_response_topic,
|
||||
schema=StorageManagementResponse,
|
||||
metrics=storage_response_metrics,
|
||||
)
|
||||
|
||||
def create_indexes(self, session):
|
||||
|
||||
# Race condition, index creation failure is ignored. Right thing
|
||||
|
|
@ -61,6 +93,7 @@ class Processor(TriplesStoreService):
|
|||
|
||||
logger.info("Create indexes...")
|
||||
|
||||
# Legacy indexes for backwards compatibility
|
||||
try:
|
||||
session.run(
|
||||
"CREATE INDEX ON :Node",
|
||||
|
|
@ -97,15 +130,48 @@ class Processor(TriplesStoreService):
|
|||
# Maybe index already exists
|
||||
logger.warning("Index create failure ignored")
|
||||
|
||||
# New indexes for user/collection filtering
|
||||
try:
|
||||
session.run(
|
||||
"CREATE INDEX ON :Node(user)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"User index create failure: {e}")
|
||||
logger.warning("Index create failure ignored")
|
||||
|
||||
try:
|
||||
session.run(
|
||||
"CREATE INDEX ON :Node(collection)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Collection index create failure: {e}")
|
||||
logger.warning("Index create failure ignored")
|
||||
|
||||
try:
|
||||
session.run(
|
||||
"CREATE INDEX ON :Literal(user)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"User index create failure: {e}")
|
||||
logger.warning("Index create failure ignored")
|
||||
|
||||
try:
|
||||
session.run(
|
||||
"CREATE INDEX ON :Literal(collection)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Collection index create failure: {e}")
|
||||
logger.warning("Index create failure ignored")
|
||||
|
||||
logger.info("Index creation done")
|
||||
|
||||
def create_node(self, uri):
|
||||
def create_node(self, uri, user, collection):
|
||||
|
||||
logger.debug(f"Create node {uri}")
|
||||
logger.debug(f"Create node {uri} for user={user}, collection={collection}")
|
||||
|
||||
summary = self.io.execute_query(
|
||||
"MERGE (n:Node {uri: $uri})",
|
||||
uri=uri,
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
uri=uri, user=user, collection=collection,
|
||||
database_=self.db,
|
||||
).summary
|
||||
|
||||
|
|
@ -114,13 +180,13 @@ class Processor(TriplesStoreService):
|
|||
time=summary.result_available_after
|
||||
))
|
||||
|
||||
def create_literal(self, value):
|
||||
def create_literal(self, value, user, collection):
|
||||
|
||||
logger.debug(f"Create literal {value}")
|
||||
logger.debug(f"Create literal {value} for user={user}, collection={collection}")
|
||||
|
||||
summary = self.io.execute_query(
|
||||
"MERGE (n:Literal {value: $value})",
|
||||
value=value,
|
||||
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
value=value, user=user, collection=collection,
|
||||
database_=self.db,
|
||||
).summary
|
||||
|
||||
|
|
@ -129,15 +195,15 @@ class Processor(TriplesStoreService):
|
|||
time=summary.result_available_after
|
||||
))
|
||||
|
||||
def relate_node(self, src, uri, dest):
|
||||
def relate_node(self, src, uri, dest, user, collection):
|
||||
|
||||
logger.debug(f"Create node rel {src} {uri} {dest}")
|
||||
logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}")
|
||||
|
||||
summary = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Node {uri: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
src=src, dest=dest, uri=uri,
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
src=src, dest=dest, uri=uri, user=user, collection=collection,
|
||||
database_=self.db,
|
||||
).summary
|
||||
|
||||
|
|
@ -146,15 +212,15 @@ class Processor(TriplesStoreService):
|
|||
time=summary.result_available_after
|
||||
))
|
||||
|
||||
def relate_literal(self, src, uri, dest):
|
||||
def relate_literal(self, src, uri, dest, user, collection):
|
||||
|
||||
logger.debug(f"Create literal rel {src} {uri} {dest}")
|
||||
logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}")
|
||||
|
||||
summary = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Literal {value: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
src=src, dest=dest, uri=uri,
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
src=src, dest=dest, uri=uri, user=user, collection=collection,
|
||||
database_=self.db,
|
||||
).summary
|
||||
|
||||
|
|
@ -163,59 +229,64 @@ class Processor(TriplesStoreService):
|
|||
time=summary.result_available_after
|
||||
))
|
||||
|
||||
def create_triple(self, tx, t):
|
||||
def create_triple(self, tx, t, user, collection):
|
||||
|
||||
# Create new s node with given uri, if not exists
|
||||
result = tx.run(
|
||||
"MERGE (n:Node {uri: $uri})",
|
||||
uri=t.s.value
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
uri=t.s.value, user=user, collection=collection
|
||||
)
|
||||
|
||||
if t.o.is_uri:
|
||||
|
||||
# Create new o node with given uri, if not exists
|
||||
result = tx.run(
|
||||
"MERGE (n:Node {uri: $uri})",
|
||||
uri=t.o.value
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
uri=t.o.value, user=user, collection=collection
|
||||
)
|
||||
|
||||
result = tx.run(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Node {uri: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
src=t.s.value, dest=t.o.value, uri=t.p.value,
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
src=t.s.value, dest=t.o.value, uri=t.p.value, user=user, collection=collection,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
# Create new o literal with given uri, if not exists
|
||||
result = tx.run(
|
||||
"MERGE (n:Literal {value: $value})",
|
||||
value=t.o.value
|
||||
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
value=t.o.value, user=user, collection=collection
|
||||
)
|
||||
|
||||
result = tx.run(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Literal {value: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
src=t.s.value, dest=t.o.value, uri=t.p.value,
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
src=t.s.value, dest=t.o.value, uri=t.p.value, user=user, collection=collection,
|
||||
)
|
||||
|
||||
async def store_triples(self, message):
|
||||
|
||||
# Extract user and collection from metadata
|
||||
user = message.metadata.user if message.metadata.user else "default"
|
||||
collection = message.metadata.collection if message.metadata.collection else "default"
|
||||
|
||||
for t in message.triples:
|
||||
|
||||
# self.create_node(t.s.value)
|
||||
self.create_node(t.s.value, user, collection)
|
||||
|
||||
# if t.o.is_uri:
|
||||
# self.create_node(t.o.value)
|
||||
# self.relate_node(t.s.value, t.p.value, t.o.value)
|
||||
# else:
|
||||
# self.create_literal(t.o.value)
|
||||
# self.relate_literal(t.s.value, t.p.value, t.o.value)
|
||||
if t.o.is_uri:
|
||||
self.create_node(t.o.value, user, collection)
|
||||
self.relate_node(t.s.value, t.p.value, t.o.value, user, collection)
|
||||
else:
|
||||
self.create_literal(t.o.value, user, collection)
|
||||
self.relate_literal(t.s.value, t.p.value, t.o.value, user, collection)
|
||||
|
||||
with self.io.session(database=self.db) as session:
|
||||
session.execute_write(self.create_triple, t)
|
||||
# Alternative implementation using transactions
|
||||
# with self.io.session(database=self.db) as session:
|
||||
# session.execute_write(self.create_triple, t, user, collection)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
|
@ -246,6 +317,67 @@ class Processor(TriplesStoreService):
|
|||
help=f'Memgraph database (default: {default_database})'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
    """Handle an incoming storage management request.

    "delete-collection" is delegated to ``handle_delete_collection``.
    Unknown operations are answered with an "invalid_operation" error,
    and any exception raised along the way is logged and answered with a
    "processing_error" response rather than propagated.
    """
    logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")

    operation = message.operation

    try:
        if operation != "delete-collection":
            # Tell the requester that this operation is not supported.
            unknown = Error(
                type="invalid_operation",
                message=f"Unknown operation: {message.operation}"
            )
            await self.storage_response_producer.send(
                StorageManagementResponse(error=unknown)
            )
        else:
            await self.handle_delete_collection(message)

    except Exception as err:
        logger.error(f"Error processing storage management request: {err}", exc_info=True)
        problem = Error(
            type="processing_error",
            message=str(err)
        )
        await self.storage_response_producer.send(
            StorageManagementResponse(error=problem)
        )
|
||||
|
||||
async def handle_delete_collection(self, message):
    """Delete all data for a specific collection.

    Within one database session, removes every ``Node`` and every
    ``Literal`` tagged with ``message.user`` and ``message.collection``,
    logs how many of each were deleted, then publishes a success
    response.

    Raises:
        Exception: any failure is logged and re-raised to the caller.
    """
    try:
        with self.io.session(database=self.db) as session:

            # Delete all nodes for this user and collection
            node_count = session.run(
                "MATCH (n:Node {user: $user, collection: $collection}) DETACH DELETE n",
                user=message.user, collection=message.collection
            ).consume().counters.nodes_deleted

            # Delete all literals for this user and collection
            literal_count = session.run(
                "MATCH (n:Literal {user: $user, collection: $collection}) DETACH DELETE n",
                user=message.user, collection=message.collection
            ).consume().counters.nodes_deleted

            # Note: Relationships are automatically deleted with DETACH DELETE

            logger.info(f"Deleted {node_count} nodes and {literal_count} literals for {message.user}/{message.collection}")

        # A response without an error signals success.
        await self.storage_response_producer.send(
            StorageManagementResponse(error=None)
        )
        logger.info(f"Successfully deleted collection {message.user}/{message.collection}")

    except Exception as e:
        logger.error(f"Failed to delete collection: {e}")
        raise
|
||||
|
||||
def run():
    """Entry point: start this processor with its default identity.

    Calls ``Processor.launch`` with ``default_ident`` and the module
    docstring.
    """
    Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,10 @@ import logging
|
|||
|
||||
from neo4j import GraphDatabase
|
||||
from .... base import TriplesStoreService
|
||||
from .... base import AsyncProcessor, Consumer, Producer
|
||||
from .... base import ConsumerMetrics, ProducerMetrics
|
||||
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
|
||||
from .... schema import triples_storage_management_topic, storage_management_response_topic
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -49,6 +53,34 @@ class Processor(TriplesStoreService):
|
|||
with self.io.session(database=self.db) as session:
|
||||
self.create_indexes(session)
|
||||
|
||||
# Set up metrics for storage management
|
||||
storage_request_metrics = ConsumerMetrics(
|
||||
processor=self.id, flow=None, name="storage-request"
|
||||
)
|
||||
storage_response_metrics = ProducerMetrics(
|
||||
processor=self.id, flow=None, name="storage-response"
|
||||
)
|
||||
|
||||
# Set up consumer for storage management requests
|
||||
self.storage_request_consumer = Consumer(
|
||||
taskgroup=self.taskgroup,
|
||||
client=self.pulsar_client,
|
||||
flow=None,
|
||||
topic=triples_storage_management_topic,
|
||||
subscriber=f"{id}-storage",
|
||||
schema=StorageManagementRequest,
|
||||
handler=self.on_storage_management,
|
||||
metrics=storage_request_metrics,
|
||||
)
|
||||
|
||||
# Set up producer for storage management responses
|
||||
self.storage_response_producer = Producer(
|
||||
client=self.pulsar_client,
|
||||
topic=storage_management_response_topic,
|
||||
schema=StorageManagementResponse,
|
||||
metrics=storage_response_metrics,
|
||||
)
|
||||
|
||||
def create_indexes(self, session):
|
||||
|
||||
# Race condition, index creation failure is ignored. Right thing
|
||||
|
|
@ -61,6 +93,7 @@ class Processor(TriplesStoreService):
|
|||
|
||||
logger.info("Create indexes...")
|
||||
|
||||
# Legacy indexes for backwards compatibility
|
||||
try:
|
||||
session.run(
|
||||
"CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
|
||||
|
|
@ -88,15 +121,50 @@ class Processor(TriplesStoreService):
|
|||
# Maybe index already exists
|
||||
logger.warning("Index create failure ignored")
|
||||
|
||||
# New compound indexes for user/collection filtering
|
||||
try:
|
||||
session.run(
|
||||
"CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Compound index create failure: {e}")
|
||||
logger.warning("Index create failure ignored")
|
||||
|
||||
try:
|
||||
session.run(
|
||||
"CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Compound index create failure: {e}")
|
||||
logger.warning("Index create failure ignored")
|
||||
|
||||
# Note: Neo4j doesn't support compound indexes on relationships in all versions
|
||||
# Try to create individual indexes on relationship properties
|
||||
try:
|
||||
session.run(
|
||||
"CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Relationship index create failure: {e}")
|
||||
logger.warning("Index create failure ignored")
|
||||
|
||||
try:
|
||||
session.run(
|
||||
"CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Relationship index create failure: {e}")
|
||||
logger.warning("Index create failure ignored")
|
||||
|
||||
logger.info("Index creation done")
|
||||
|
||||
def create_node(self, uri):
|
||||
def create_node(self, uri, user, collection):
|
||||
|
||||
logger.debug(f"Create node {uri}")
|
||||
logger.debug(f"Create node {uri} for user={user}, collection={collection}")
|
||||
|
||||
summary = self.io.execute_query(
|
||||
"MERGE (n:Node {uri: $uri})",
|
||||
uri=uri,
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
uri=uri, user=user, collection=collection,
|
||||
database_=self.db,
|
||||
).summary
|
||||
|
||||
|
|
@ -105,13 +173,13 @@ class Processor(TriplesStoreService):
|
|||
time=summary.result_available_after
|
||||
))
|
||||
|
||||
def create_literal(self, value):
|
||||
def create_literal(self, value, user, collection):
|
||||
|
||||
logger.debug(f"Create literal {value}")
|
||||
logger.debug(f"Create literal {value} for user={user}, collection={collection}")
|
||||
|
||||
summary = self.io.execute_query(
|
||||
"MERGE (n:Literal {value: $value})",
|
||||
value=value,
|
||||
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
value=value, user=user, collection=collection,
|
||||
database_=self.db,
|
||||
).summary
|
||||
|
||||
|
|
@ -120,15 +188,15 @@ class Processor(TriplesStoreService):
|
|||
time=summary.result_available_after
|
||||
))
|
||||
|
||||
def relate_node(self, src, uri, dest):
|
||||
def relate_node(self, src, uri, dest, user, collection):
|
||||
|
||||
logger.debug(f"Create node rel {src} {uri} {dest}")
|
||||
logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}")
|
||||
|
||||
summary = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Node {uri: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
src=src, dest=dest, uri=uri,
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
src=src, dest=dest, uri=uri, user=user, collection=collection,
|
||||
database_=self.db,
|
||||
).summary
|
||||
|
||||
|
|
@ -137,15 +205,15 @@ class Processor(TriplesStoreService):
|
|||
time=summary.result_available_after
|
||||
))
|
||||
|
||||
def relate_literal(self, src, uri, dest):
|
||||
def relate_literal(self, src, uri, dest, user, collection):
|
||||
|
||||
logger.debug(f"Create literal rel {src} {uri} {dest}")
|
||||
logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}")
|
||||
|
||||
summary = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Literal {value: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
src=src, dest=dest, uri=uri,
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
src=src, dest=dest, uri=uri, user=user, collection=collection,
|
||||
database_=self.db,
|
||||
).summary
|
||||
|
||||
|
|
@ -156,16 +224,20 @@ class Processor(TriplesStoreService):
|
|||
|
||||
async def store_triples(self, message):
|
||||
|
||||
# Extract user and collection from metadata
|
||||
user = message.metadata.user if message.metadata.user else "default"
|
||||
collection = message.metadata.collection if message.metadata.collection else "default"
|
||||
|
||||
for t in message.triples:
|
||||
|
||||
self.create_node(t.s.value)
|
||||
self.create_node(t.s.value, user, collection)
|
||||
|
||||
if t.o.is_uri:
|
||||
self.create_node(t.o.value)
|
||||
self.relate_node(t.s.value, t.p.value, t.o.value)
|
||||
self.create_node(t.o.value, user, collection)
|
||||
self.relate_node(t.s.value, t.p.value, t.o.value, user, collection)
|
||||
else:
|
||||
self.create_literal(t.o.value)
|
||||
self.relate_literal(t.s.value, t.p.value, t.o.value)
|
||||
self.create_literal(t.o.value, user, collection)
|
||||
self.relate_literal(t.s.value, t.p.value, t.o.value, user, collection)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
|
@ -196,6 +268,67 @@ class Processor(TriplesStoreService):
|
|||
help=f'Neo4j database (default: {default_database})'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
    """Dispatch a storage management request to its handler.

    Only the ``"delete-collection"`` operation is recognised; any other
    operation is answered with an ``invalid_operation`` error response.
    An exception raised while handling the request is logged with its
    traceback and reported back as a ``processing_error`` response.
    """
    logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")

    try:
        # Guard clause: reject anything we do not know how to handle.
        if message.operation != "delete-collection":
            rejection = StorageManagementResponse(
                error=Error(
                    type="invalid_operation",
                    message=f"Unknown operation: {message.operation}",
                )
            )
            await self.storage_response_producer.send(rejection)
            return

        await self.handle_delete_collection(message)

    except Exception as e:
        logger.error(f"Error processing storage management request: {e}", exc_info=True)
        failure = StorageManagementResponse(
            error=Error(
                type="processing_error",
                message=str(e),
            )
        )
        await self.storage_response_producer.send(failure)
|
||||
|
||||
async def handle_delete_collection(self, message):
    """Remove every Node and Literal stored for one user/collection pair.

    Two Cypher ``DETACH DELETE`` statements are run in a single session,
    one for ``Node`` and one for ``Literal``; detaching also removes the
    relationships attached to the deleted nodes. The deletion counts are
    logged, then a StorageManagementResponse with ``error=None`` is sent
    on the storage response producer. Any exception is logged and
    re-raised to the caller.
    """
    try:
        with self.io.session(database=self.db) as session:
            scope = {"user": message.user, "collection": message.collection}

            # URI-bearing nodes for this user and collection.
            nodes_removed = session.run(
                "MATCH (n:Node {user: $user, collection: $collection}) "
                "DETACH DELETE n",
                **scope,
            ).consume().counters.nodes_deleted

            # Literal value nodes for the same scope.
            literals_removed = session.run(
                "MATCH (n:Literal {user: $user, collection: $collection}) "
                "DETACH DELETE n",
                **scope,
            ).consume().counters.nodes_deleted

            logger.info(f"Deleted {nodes_removed} nodes and {literals_removed} literals for {message.user}/{message.collection}")

        # A response carrying no error signals success.
        await self.storage_response_producer.send(
            StorageManagementResponse(error=None)
        )
        logger.info(f"Successfully deleted collection {message.user}/{message.collection}")

    except Exception as e:
        logger.error(f"Failed to delete collection: {e}")
        raise
|
||||
|
||||
def run():
    """Console entry point: launch the Processor.

    ``__doc__`` here resolves to the module-level docstring, which is
    handed to ``Processor.launch`` together with the default identifier.
    """
    Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
|
|
@ -17,17 +17,21 @@ class ConfigTableStore:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
cassandra_host, cassandra_user, cassandra_password, keyspace,
|
||||
cassandra_host, cassandra_username, cassandra_password, keyspace,
|
||||
):
|
||||
|
||||
self.keyspace = keyspace
|
||||
|
||||
logger.info("Connecting to Cassandra...")
|
||||
|
||||
if cassandra_user and cassandra_password:
|
||||
# Ensure cassandra_host is a list
|
||||
if isinstance(cassandra_host, str):
|
||||
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
|
||||
|
||||
if cassandra_username and cassandra_password:
|
||||
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
|
||||
auth_provider = PlainTextAuthProvider(
|
||||
username=cassandra_user, password=cassandra_password
|
||||
username=cassandra_username, password=cassandra_password
|
||||
)
|
||||
self.cluster = Cluster(
|
||||
cassandra_host,
|
||||
|
|
|
|||
|
|
@ -17,17 +17,21 @@ class KnowledgeTableStore:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
cassandra_host, cassandra_user, cassandra_password, keyspace,
|
||||
cassandra_host, cassandra_username, cassandra_password, keyspace,
|
||||
):
|
||||
|
||||
self.keyspace = keyspace
|
||||
|
||||
logger.info("Connecting to Cassandra...")
|
||||
|
||||
if cassandra_user and cassandra_password:
|
||||
# Ensure cassandra_host is a list
|
||||
if isinstance(cassandra_host, str):
|
||||
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
|
||||
|
||||
if cassandra_username and cassandra_password:
|
||||
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
|
||||
auth_provider = PlainTextAuthProvider(
|
||||
username=cassandra_user, password=cassandra_password
|
||||
username=cassandra_username, password=cassandra_password
|
||||
)
|
||||
self.cluster = Cluster(
|
||||
cassandra_host,
|
||||
|
|
|
|||
|
|
@ -21,17 +21,21 @@ class LibraryTableStore:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
cassandra_host, cassandra_user, cassandra_password, keyspace,
|
||||
cassandra_host, cassandra_username, cassandra_password, keyspace,
|
||||
):
|
||||
|
||||
self.keyspace = keyspace
|
||||
|
||||
logger.info("Connecting to Cassandra...")
|
||||
|
||||
if cassandra_user and cassandra_password:
|
||||
# Ensure cassandra_host is a list
|
||||
if isinstance(cassandra_host, str):
|
||||
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
|
||||
|
||||
if cassandra_username and cassandra_password:
|
||||
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
|
||||
auth_provider = PlainTextAuthProvider(
|
||||
username=cassandra_user, password=cassandra_password
|
||||
username=cassandra_username, password=cassandra_password
|
||||
)
|
||||
self.cluster = Cluster(
|
||||
cassandra_host,
|
||||
|
|
@ -107,6 +111,21 @@ class LibraryTableStore:
|
|||
);
|
||||
""");
|
||||
|
||||
logger.debug("collections table...")
|
||||
|
||||
self.cassandra.execute("""
|
||||
CREATE TABLE IF NOT EXISTS collections (
|
||||
user text,
|
||||
collection text,
|
||||
name text,
|
||||
description text,
|
||||
tags set<text>,
|
||||
created_at timestamp,
|
||||
updated_at timestamp,
|
||||
PRIMARY KEY (user, collection)
|
||||
);
|
||||
""");
|
||||
|
||||
logger.info("Cassandra schema OK.")
|
||||
|
||||
def prepare_statements(self):
|
||||
|
|
@ -183,6 +202,43 @@ class LibraryTableStore:
|
|||
LIMIT 1
|
||||
""")
|
||||
|
||||
# Collection management statements
|
||||
self.insert_collection_stmt = self.cassandra.prepare("""
|
||||
INSERT INTO collections
|
||||
(user, collection, name, description, tags, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""")
|
||||
|
||||
self.update_collection_stmt = self.cassandra.prepare("""
|
||||
UPDATE collections
|
||||
SET name = ?, description = ?, tags = ?, updated_at = ?
|
||||
WHERE user = ? AND collection = ?
|
||||
""")
|
||||
|
||||
self.get_collection_stmt = self.cassandra.prepare("""
|
||||
SELECT collection, name, description, tags, created_at, updated_at
|
||||
FROM collections
|
||||
WHERE user = ? AND collection = ?
|
||||
""")
|
||||
|
||||
self.list_collections_stmt = self.cassandra.prepare("""
|
||||
SELECT collection, name, description, tags, created_at, updated_at
|
||||
FROM collections
|
||||
WHERE user = ?
|
||||
""")
|
||||
|
||||
self.delete_collection_stmt = self.cassandra.prepare("""
|
||||
DELETE FROM collections
|
||||
WHERE user = ? AND collection = ?
|
||||
""")
|
||||
|
||||
self.collection_exists_stmt = self.cassandra.prepare("""
|
||||
SELECT collection
|
||||
FROM collections
|
||||
WHERE user = ? AND collection = ?
|
||||
LIMIT 1
|
||||
""")
|
||||
|
||||
self.list_processing_stmt = self.cassandra.prepare("""
|
||||
SELECT
|
||||
id, document_id, time, flow, collection, tags
|
||||
|
|
@ -517,3 +573,145 @@ class LibraryTableStore:
|
|||
|
||||
return lst
|
||||
|
||||
|
||||
|
||||
# Collection management methods
|
||||
|
||||
async def ensure_collection_exists(self, user, collection):
    """Create a default metadata record for a collection if none exists.

    Looks the collection up with ``collection_exists_stmt``; when a row is
    found nothing is written. Otherwise a row is inserted whose name is
    the collection identifier, with an empty description, no tags, and
    identical created/updated timestamps.

    Args:
        user: Owner of the collection (partition key).
        collection: Collection identifier.

    Raises:
        Exception: Any error from the Cassandra calls is logged and
            re-raised unchanged.
    """
    try:
        loop = asyncio.get_event_loop()

        existing = await loop.run_in_executor(
            None, self.cassandra.execute,
            self.collection_exists_stmt, [user, collection],
        )
        if existing:
            return

        import datetime

        # The Cassandra driver serialises naive datetimes as if they were
        # UTC, so a naive *local* time would be stored skewed by the
        # host's UTC offset. Use naive UTC, which also matches what the
        # driver hands back on reads.
        now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

        await loop.run_in_executor(
            None, self.cassandra.execute, self.insert_collection_stmt,
            [user, collection, collection, "", set(), now, now],
        )
        logger.debug(f"Created collection metadata for {user}/{collection}")

    except Exception as e:
        logger.error(f"Error ensuring collection exists: {e}")
        raise
|
||||
|
||||
async def list_collections(self, user, tag_filter=None):
    """Return metadata dicts for a user's collections.

    Args:
        user: Owner whose collections are listed.
        tag_filter: Optional iterable of tags. When given (and non-empty),
            only collections sharing at least one tag with it are kept.

    Returns:
        A list of dicts with keys ``user``, ``collection``, ``name``,
        ``description``, ``tags``, ``created_at`` and ``updated_at``.
        A missing name falls back to the collection identifier; missing
        timestamps become empty strings.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        loop = asyncio.get_event_loop()
        rows = await loop.run_in_executor(
            None, self.cassandra.execute, self.list_collections_stmt, [user]
        )

        found = []
        for coll_id, name, description, tags, created, updated in rows:
            tag_list = list(tags) if tags else []

            # Skip rows that share no tag with the requested filter.
            if tag_filter and set(tag_filter).isdisjoint(tag_list):
                continue

            found.append({
                "user": user,
                "collection": coll_id,
                "name": name or coll_id,
                "description": description or "",
                "tags": tag_list,
                "created_at": created.isoformat() if created else "",
                "updated_at": updated.isoformat() if updated else "",
            })

        return found

    except Exception as e:
        logger.error(f"Error listing collections: {e}")
        raise
|
||||
|
||||
async def update_collection(self, user, collection, name=None, description=None, tags=None):
    """Update the metadata of an existing collection.

    Fields passed as ``None`` keep their currently stored value; the
    ``updated_at`` timestamp is always refreshed.

    Args:
        user: Owner of the collection.
        collection: Collection identifier.
        name: New display name, or ``None`` to keep the current one.
        description: New description, or ``None`` to keep the current one.
        tags: Iterable of new tags, or ``None`` to keep the current ones.

    Returns:
        A dict with ``user``, ``collection``, ``name``, ``description``,
        ``tags`` (list) and ``updated_at`` (ISO-8601 string).

    Raises:
        RequestError: If the collection has no metadata row.
        Exception: Any other error is logged and re-raised unchanged.
    """
    try:
        loop = asyncio.get_event_loop()

        resp = await loop.run_in_executor(
            None, self.cassandra.execute,
            self.get_collection_stmt, [user, collection],
        )
        if not resp:
            raise RequestError(f"Collection {collection} not found")

        row = resp.one()
        current_name = row[1] or collection
        current_description = row[2] or ""
        current_tags = set(row[3]) if row[3] else set()

        new_name = name if name is not None else current_name
        new_description = (
            description if description is not None else current_description
        )
        new_tags = set(tags) if tags is not None else current_tags

        import datetime

        # The Cassandra driver serialises naive datetimes as if they were
        # UTC, so a naive *local* time would be stored skewed by the
        # host's UTC offset. Use naive UTC, which also matches what the
        # driver hands back on reads.
        now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

        await loop.run_in_executor(
            None, self.cassandra.execute, self.update_collection_stmt,
            [new_name, new_description, new_tags, now, user, collection],
        )

        return {
            "user": user, "collection": collection, "name": new_name,
            "description": new_description, "tags": list(new_tags),
            "updated_at": now.isoformat()
        }

    except Exception as e:
        logger.error(f"Error updating collection: {e}")
        raise
|
||||
|
||||
async def delete_collection(self, user, collection):
    """Delete the metadata row for ``user``/``collection``.

    Executes ``delete_collection_stmt`` on the default executor. Any
    error is logged and re-raised unchanged.
    """
    try:
        loop = asyncio.get_event_loop()
        key = [user, collection]
        await loop.run_in_executor(
            None, self.cassandra.execute, self.delete_collection_stmt, key
        )
        logger.debug(f"Deleted collection metadata for {user}/{collection}")

    except Exception as e:
        logger.error(f"Error deleting collection metadata: {e}")
        raise
|
||||
|
||||
async def get_collection(self, user, collection):
    """Fetch the metadata of one collection.

    Returns:
        ``None`` when no row exists; otherwise a dict with ``user``,
        ``collection``, ``name``, ``description``, ``tags``,
        ``created_at`` and ``updated_at``. A missing name falls back to
        the collection identifier and missing timestamps become empty
        strings.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        loop = asyncio.get_event_loop()
        result = await loop.run_in_executor(
            None, self.cassandra.execute,
            self.get_collection_stmt, [user, collection],
        )
        if not result:
            return None

        coll_id, name, description, tags, created, updated = result.one()

        return {
            "user": user,
            "collection": coll_id,
            "name": name or coll_id,
            "description": description or "",
            "tags": list(tags) if tags else [],
            "created_at": created.isoformat() if created else "",
            "updated_at": updated.isoformat() if updated else "",
        }

    except Exception as e:
        logger.error(f"Error getting collection: {e}")
        raise
|
||||
|
||||
async def create_collection(self, user, collection, name=None, description=None, tags=None):
    """Insert a new collection metadata record.

    Args:
        user: Owner of the collection.
        collection: Collection identifier.
        name: Display name; defaults to the collection identifier.
        description: Free-text description; defaults to ``""``.
        tags: Tags for the collection; defaults to an empty set.

    Returns:
        A dict with ``user``, ``collection``, ``name``, ``description``,
        ``tags``, ``created_at`` and ``updated_at``. Both timestamps are
        the same ISO-8601 string.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        import datetime

        # The Cassandra driver serialises naive datetimes as if they were
        # UTC, so a naive *local* time would be stored skewed by the
        # host's UTC offset. Use naive UTC, which also matches what the
        # driver hands back on reads.
        now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

        # Set defaults for optional parameters
        name = name if name is not None else collection
        description = description if description is not None else ""
        tags = tags if tags is not None else set()

        await asyncio.get_event_loop().run_in_executor(
            None, self.cassandra.execute, self.insert_collection_stmt,
            [user, collection, name, description, tags, now, now],
        )

        logger.info(f"Created collection {user}/{collection}")

        # Return the created collection data
        return {
            "user": user,
            "collection": collection,
            "name": name,
            "description": description,
            "tags": list(tags) if isinstance(tags, set) else tags,
            "created_at": now.isoformat(),
            "updated_at": now.isoformat()
        }

    except Exception as e:
        logger.error(f"Error creating collection: {e}")
        raise
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue