Release 1.4 -> master (#524)

Catch up
This commit is contained in:
cybermaggedon 2025-09-20 16:00:37 +01:00 committed by GitHub
parent a8e437fc7f
commit 6c7af8789d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
216 changed files with 31360 additions and 1611 deletions

View file

@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=1.2,<1.3",
"trustgraph-base>=1.4,<1.5",
"aiohttp",
"anthropic",
"cassandra-driver",
@ -40,6 +40,7 @@ dependencies = [
"qdrant-client",
"rdflib",
"requests",
"strawberry-graphql",
"tabulate",
"tiktoken",
"urllib3",
@ -86,14 +87,16 @@ kg-store = "trustgraph.storage.knowledge:run"
librarian = "trustgraph.librarian:run"
mcp-tool = "trustgraph.agent.mcp_tool:run"
metering = "trustgraph.metering:run"
nlp-query = "trustgraph.retrieval.nlp_query:run"
objects-write-cassandra = "trustgraph.storage.objects.cassandra:run"
oe-write-milvus = "trustgraph.storage.object_embeddings.milvus:run"
objects-query-cassandra = "trustgraph.query.objects.cassandra:run"
pdf-decoder = "trustgraph.decoding.pdf:run"
pdf-ocr-mistral = "trustgraph.decoding.mistral_ocr:run"
prompt-template = "trustgraph.prompt.template:run"
rev-gateway = "trustgraph.rev_gateway:run"
rows-write-cassandra = "trustgraph.storage.rows.cassandra:run"
run-processing = "trustgraph.processing:run"
structured-query = "trustgraph.retrieval.structured_query:run"
structured-diag = "trustgraph.retrieval.structured_diag:run"
text-completion-azure = "trustgraph.model.text_completion.azure:run"
text-completion-azure-openai = "trustgraph.model.text_completion.azure_openai:run"
text-completion-claude = "trustgraph.model.text_completion.claude:run"

View file

@ -269,13 +269,7 @@ class AgentManager:
logger.debug(f"TOOL>>> {act}")
# Instantiate the tool implementation with context and config
if action.config:
tool_instance = action.implementation(context, **action.config)
else:
tool_instance = action.implementation(context)
resp = await tool_instance.invoke(
resp = await action.implementation(context).invoke(
**act.arguments
)

View file

@ -12,12 +12,13 @@ import logging
logger = logging.getLogger(__name__)
from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
from ... base import GraphRagClientSpec, ToolClientSpec
from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl
from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl
from . agent_manager import AgentManager
from ..tool_filter import validate_tool_config, filter_tools_by_group_and_state, get_next_state
from . types import Final, Action, Tool, Argument
@ -79,6 +80,13 @@ class Processor(AgentService):
)
)
self.register_specification(
StructuredQueryClientSpec(
request_name = "structured-query-request",
response_name = "structured-query-response",
)
)
async def on_tools_config(self, config, version):
logger.info(f"Loading configuration version {version}")
@ -137,11 +145,21 @@ class Processor(AgentService):
template_id=data.get("template"),
arguments=arguments
)
elif impl_id == "structured-query":
impl = functools.partial(
StructuredQueryImpl,
collection=data.get("collection"),
user=None # User will be provided dynamically via context
)
arguments = StructuredQueryImpl.get_arguments()
else:
raise RuntimeError(
f"Tool type {impl_id} not known"
)
# Validate tool configuration
validate_tool_config(data)
tools[name] = Tool(
name=name,
description=data.get("description"),
@ -219,14 +237,43 @@ class Processor(AgentService):
await respond(r)
# Apply tool filtering based on request groups and state
filtered_tools = filter_tools_by_group_and_state(
tools=self.agent.tools,
requested_groups=getattr(request, 'group', None),
current_state=getattr(request, 'state', None)
)
logger.info(f"Filtered from {len(self.agent.tools)} to {len(filtered_tools)} available tools")
# Create temporary agent with filtered tools
temp_agent = AgentManager(
tools=filtered_tools,
additional_context=self.agent.additional_context
)
logger.debug("Call React")
act = await self.agent.react(
# Create user-aware context wrapper that preserves the flow interface
# but adds user information for tools that need it
class UserAwareContext:
def __init__(self, flow, user):
self._flow = flow
self._user = user
def __call__(self, service_name):
client = self._flow(service_name)
# For structured query clients, store user context
if service_name == "structured-query-request":
client._current_user = self._user
return client
act = await temp_agent.react(
question = request.question,
history = history,
think = think,
observe = observe,
context = flow,
context = UserAwareContext(flow, request.user),
)
logger.debug(f"Action: {act}")
@ -255,11 +302,17 @@ class Processor(AgentService):
logger.debug("Send next...")
history.append(act)
# Handle state transitions if tool execution was successful
next_state = request.state
if act.name in filtered_tools:
executed_tool = filtered_tools[act.name]
next_state = get_next_state(executed_tool, request.state or "undefined")
r = AgentRequest(
question=request.question,
plan=request.plan,
state=request.state,
state=next_state,
group=getattr(request, 'group', []),
history=[
AgentStep(
thought=h.thought,

View file

@ -85,6 +85,49 @@ class McpToolImpl:
return json.dumps(output)
# This tool implementation knows how to query structured data using natural language
class StructuredQueryImpl:
def __init__(self, context, collection=None, user=None):
self.context = context
self.collection = collection # For multi-tenant scenarios
self.user = user # User context for multi-tenancy
@staticmethod
def get_arguments():
return [
Argument(
name="question",
type="string",
description="Natural language question about structured data (tables, databases, etc.)"
)
]
async def invoke(self, **arguments):
client = self.context("structured-query-request")
logger.debug("Structured query question...")
# Get user from client context if available, otherwise use instance user or default
user = getattr(client, '_current_user', self.user or "trustgraph")
result = await client.structured_query(
question=arguments.get("question"),
user=user,
collection=self.collection or "default"
)
# Format the result for the agent
if isinstance(result, dict):
if result.get("error"):
return f"Error: {result['error']['message']}"
elif result.get("data"):
# Pretty format JSON data for agent consumption
return json.dumps(result["data"], indent=2)
else:
return "No data returned"
else:
return str(result)
# This tool implementation knows how to execute prompt templates
class PromptImpl:
def __init__(self, context, template_id, arguments=None):

View file

@ -0,0 +1,165 @@
"""
Tool filtering logic for the TrustGraph tool group system.
Provides functions to filter available tools based on group membership
and execution state as defined in the tool-group tech spec.
"""
import logging
from typing import Dict, List, Optional, Any
logger = logging.getLogger(__name__)
def filter_tools_by_group_and_state(
tools: Dict[str, Any],
requested_groups: Optional[List[str]] = None,
current_state: Optional[str] = None
) -> Dict[str, Any]:
"""
Filter tools based on group membership and execution state.
Args:
tools: Dictionary of tool_name -> tool_object
requested_groups: List of groups requested (defaults to ["default"])
current_state: Current execution state (defaults to "undefined")
Returns:
Dictionary of filtered tools that match group and state criteria
"""
# Apply defaults as specified in tech spec
if requested_groups is None:
requested_groups = ["default"]
if current_state is None or current_state == "":
current_state = "undefined"
logger.info(f"Filtering tools with groups={requested_groups}, state={current_state}")
filtered_tools = {}
for tool_name, tool in tools.items():
if _is_tool_available(tool, requested_groups, current_state):
filtered_tools[tool_name] = tool
else:
logger.debug(f"Tool {tool_name} filtered out")
logger.info(f"Filtered {len(tools)} tools to {len(filtered_tools)} available tools")
return filtered_tools
def _is_tool_available(
tool: Any,
requested_groups: List[str],
current_state: str
) -> bool:
"""
Check if a tool is available based on group and state criteria.
Args:
tool: Tool object with config attribute containing group/state metadata
requested_groups: List of requested groups
current_state: Current execution state
Returns:
True if tool should be available, False otherwise
"""
# Extract tool configuration
config = getattr(tool, 'config', {})
# Get tool groups (default to ["default"] if not specified)
tool_groups = config.get('group', ["default"])
if not isinstance(tool_groups, list):
tool_groups = [tool_groups]
# Get tool applicable states (default to all states if not specified)
applicable_states = config.get('applicable-states', ["*"])
if not isinstance(applicable_states, list):
applicable_states = [applicable_states]
# Apply group filtering logic from tech spec:
# Tool is available if intersection(tool_groups, requested_groups) is not empty
# OR "*" is in requested_groups (wildcard access)
group_match = (
"*" in requested_groups or
bool(set(tool_groups) & set(requested_groups))
)
# Apply state filtering logic from tech spec:
# Tool is available if current_state is in applicable_states
# OR "*" is in applicable_states (available in all states)
state_match = (
"*" in applicable_states or
current_state in applicable_states
)
is_available = group_match and state_match
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
f"Tool availability check: tool_groups={tool_groups}, "
f"requested_groups={requested_groups}, applicable_states={applicable_states}, "
f"current_state={current_state}, group_match={group_match}, "
f"state_match={state_match}, is_available={is_available}"
)
return is_available
def get_next_state(tool: Any, current_state: str) -> str:
"""
Get the next state after successful tool execution.
Args:
tool: Tool object with config attribute
current_state: Current execution state
Returns:
Next state, or current_state if no transition is defined
"""
config = getattr(tool, 'config', {})
if config is None:
config = {}
next_state = config.get('state')
if next_state:
logger.debug(f"State transition: {current_state} -> {next_state}")
return next_state
else:
logger.debug(f"No state transition defined, staying in {current_state}")
return current_state
def validate_tool_config(config: Dict[str, Any]) -> None:
"""
Validate tool configuration for group and state fields.
Args:
config: Tool configuration dictionary
Raises:
ValueError: If configuration is invalid
"""
# Validate group field
if 'group' in config:
groups = config['group']
if not isinstance(groups, list):
raise ValueError("Tool 'group' field must be a list of strings")
if not all(isinstance(g, str) for g in groups):
raise ValueError("All group names must be strings")
# Validate state field
if 'state' in config:
state = config['state']
if not isinstance(state, str):
raise ValueError("Tool 'state' field must be a string")
# Validate applicable-states field
if 'applicable-states' in config:
states = config['applicable-states']
if not isinstance(states, list):
raise ValueError("Tool 'applicable-states' field must be a list of strings")
if not all(isinstance(s, str) for s in states):
raise ValueError("All state names must be strings")

View file

@ -45,13 +45,13 @@ class Configuration:
# FIXME: Some version vs config race conditions
def __init__(self, push, host, user, password, keyspace):
def __init__(self, push, host, username, password, keyspace):
# External function to respond to update
self.push = push
self.table_store = ConfigTableStore(
host, user, password, keyspace
host, username, password, keyspace
)
async def inc_version(self):

View file

@ -15,6 +15,7 @@ from trustgraph.schema import FlowRequest, FlowResponse
from trustgraph.schema import flow_request_queue, flow_response_queue
from trustgraph.base import AsyncProcessor, Consumer, Producer
from trustgraph.base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from . config import Configuration
from . flow import FlowConfig
@ -60,9 +61,21 @@ class Processor(AsyncProcessor):
"flow_response_queue", default_flow_response_queue
)
cassandra_host = params.get("cassandra_host", default_cassandra_host)
cassandra_user = params.get("cassandra_user")
cassandra_host = params.get("cassandra_host")
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password
)
# Store resolved configuration
self.cassandra_host = hosts
self.cassandra_username = username
self.cassandra_password = password
id = params.get("id")
@ -76,8 +89,9 @@ class Processor(AsyncProcessor):
"config_push_schema": ConfigPush.__name__,
"flow_request_schema": FlowRequest.__name__,
"flow_response_schema": FlowResponse.__name__,
"cassandra_host": cassandra_host,
"cassandra_user": cassandra_user,
"cassandra_host": self.cassandra_host,
"cassandra_username": self.cassandra_username,
"cassandra_password": self.cassandra_password,
}
)
@ -142,9 +156,9 @@ class Processor(AsyncProcessor):
)
self.config = Configuration(
host = cassandra_host.split(","),
user = cassandra_user,
password = cassandra_password,
host = self.cassandra_host,
username = self.cassandra_username,
password = self.cassandra_password,
keyspace = keyspace,
push = self.push
)
@ -276,23 +290,7 @@ class Processor(AsyncProcessor):
help=f'Flow response queue {default_flow_response_queue}',
)
parser.add_argument(
'--cassandra-host',
default="cassandra",
help=f'Graph host (default: cassandra)'
)
parser.add_argument(
'--cassandra-user',
default=None,
help=f'Cassandra user'
)
parser.add_argument(
'--cassandra-password',
default=None,
help=f'Cassandra password'
)
add_cassandra_args(parser)
def run():

View file

@ -16,12 +16,12 @@ logger = logging.getLogger(__name__)
class KnowledgeManager:
def __init__(
self, cassandra_host, cassandra_user, cassandra_password,
self, cassandra_host, cassandra_username, cassandra_password,
keyspace, flow_config,
):
self.table_store = KnowledgeTableStore(
cassandra_host, cassandra_user, cassandra_password, keyspace
cassandra_host, cassandra_username, cassandra_password, keyspace
)
self.loader_queue = asyncio.Queue(maxsize=20)
@ -248,6 +248,9 @@ class KnowledgeManager:
await ge_pub.start()
async def publish_triples(t):
# Override collection with request collection
if hasattr(t, 'metadata') and hasattr(t.metadata, 'collection'):
t.metadata.collection = request.collection or "default"
await t_pub.send(None, t)
logger.debug("Publishing triples...")
@ -260,6 +263,9 @@ class KnowledgeManager:
)
async def publish_ge(g):
# Override collection with request collection
if hasattr(g, 'metadata') and hasattr(g.metadata, 'collection'):
g.metadata.collection = request.collection or "default"
await ge_pub.send(None, g)
logger.debug("Publishing graph embeddings...")

View file

@ -11,6 +11,7 @@ import logging
from .. base import AsyncProcessor, Consumer, Producer, Publisher, Subscriber
from .. base import ConsumerMetrics, ProducerMetrics
from .. base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from .. schema import KnowledgeRequest, KnowledgeResponse, Error
from .. schema import knowledge_request_queue, knowledge_response_queue
@ -49,16 +50,29 @@ class Processor(AsyncProcessor):
"knowledge_response_queue", default_knowledge_response_queue
)
cassandra_host = params.get("cassandra_host", default_cassandra_host)
cassandra_user = params.get("cassandra_user")
cassandra_host = params.get("cassandra_host")
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password
)
# Store resolved configuration
self.cassandra_host = hosts
self.cassandra_username = username
self.cassandra_password = password
super(Processor, self).__init__(
**params | {
"knowledge_request_queue": knowledge_request_queue,
"knowledge_response_queue": knowledge_response_queue,
"cassandra_host": cassandra_host,
"cassandra_user": cassandra_user,
"cassandra_host": self.cassandra_host,
"cassandra_username": self.cassandra_username,
"cassandra_password": self.cassandra_password,
}
)
@ -89,9 +103,9 @@ class Processor(AsyncProcessor):
)
self.knowledge = KnowledgeManager(
cassandra_host = cassandra_host.split(","),
cassandra_user = cassandra_user,
cassandra_password = cassandra_password,
cassandra_host = self.cassandra_host,
cassandra_username = self.cassandra_username,
cassandra_password = self.cassandra_password,
keyspace = keyspace,
flow_config = self,
)
@ -210,23 +224,7 @@ class Processor(AsyncProcessor):
help=f'Config response queue {default_knowledge_response_queue}',
)
parser.add_argument(
'--cassandra-host',
default="cassandra",
help=f'Graph host (default: cassandra)'
)
parser.add_argument(
'--cassandra-user',
default=None,
help=f'Cassandra user'
)
parser.add_argument(
'--cassandra-password',
default=None,
help=f'Cassandra password'
)
add_cassandra_args(parser)
def run():

View file

@ -1,137 +0,0 @@
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from ssl import SSLContext, PROTOCOL_TLSv1_2
# Global list to track clusters for cleanup
_active_clusters = []
class TrustGraph:
def __init__(
self, hosts=None,
keyspace="trustgraph", table="default", username=None, password=None
):
if hosts is None:
hosts = ["localhost"]
self.keyspace = keyspace
self.table = table
self.username = username
if username and password:
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
auth_provider = PlainTextAuthProvider(username=username, password=password)
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
else:
self.cluster = Cluster(hosts)
self.session = self.cluster.connect()
# Track this cluster globally
_active_clusters.append(self.cluster)
self.init()
def clear(self):
self.session.execute(f"""
drop keyspace if exists {self.keyspace};
""");
self.init()
def init(self):
self.session.execute(f"""
create keyspace if not exists {self.keyspace}
with replication = {{
'class' : 'SimpleStrategy',
'replication_factor' : 1
}};
""");
self.session.set_keyspace(self.keyspace)
self.session.execute(f"""
create table if not exists {self.table} (
s text,
p text,
o text,
PRIMARY KEY (s, p, o)
);
""");
self.session.execute(f"""
create index if not exists {self.table}_p
ON {self.table} (p);
""");
self.session.execute(f"""
create index if not exists {self.table}_o
ON {self.table} (o);
""");
def insert(self, s, p, o):
self.session.execute(
f"insert into {self.table} (s, p, o) values (%s, %s, %s)",
(s, p, o)
)
def get_all(self, limit=50):
return self.session.execute(
f"select s, p, o from {self.table} limit {limit}"
)
def get_s(self, s, limit=10):
return self.session.execute(
f"select p, o from {self.table} where s = %s limit {limit}",
(s,)
)
def get_p(self, p, limit=10):
return self.session.execute(
f"select s, o from {self.table} where p = %s limit {limit}",
(p,)
)
def get_o(self, o, limit=10):
return self.session.execute(
f"select s, p from {self.table} where o = %s limit {limit}",
(o,)
)
def get_sp(self, s, p, limit=10):
return self.session.execute(
f"select o from {self.table} where s = %s and p = %s limit {limit}",
(s, p)
)
def get_po(self, p, o, limit=10):
return self.session.execute(
f"select s from {self.table} where p = %s and o = %s limit {limit} allow filtering",
(p, o)
)
def get_os(self, o, s, limit=10):
return self.session.execute(
f"select p from {self.table} where o = %s and s = %s limit {limit}",
(o, s)
)
def get_spo(self, s, p, o, limit=10):
return self.session.execute(
f"""select s as x from {self.table} where s = %s and p = %s and o = %s limit {limit}""",
(s, p, o)
)
def close(self):
"""Close the Cassandra session and cluster connections properly"""
if hasattr(self, 'session') and self.session:
self.session.shutdown()
if hasattr(self, 'cluster') and self.cluster:
self.cluster.shutdown()
# Remove from global tracking
if self.cluster in _active_clusters:
_active_clusters.remove(self.cluster)

View file

@ -0,0 +1,350 @@
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.query import BatchStatement, SimpleStatement
from ssl import SSLContext, PROTOCOL_TLSv1_2
import os
import logging
# Global list to track clusters for cleanup
_active_clusters = []
logger = logging.getLogger(__name__)
class KnowledgeGraph:
def __init__(
self, hosts=None,
keyspace="trustgraph", username=None, password=None
):
if hosts is None:
hosts = ["localhost"]
self.keyspace = keyspace
self.username = username
# Multi-table schema design for optimal performance
self.use_legacy = os.getenv('CASSANDRA_USE_LEGACY', 'false').lower() == 'true'
if self.use_legacy:
self.table = "triples" # Legacy single table
else:
# New optimized tables
self.subject_table = "triples_s"
self.po_table = "triples_p"
self.object_table = "triples_o"
if username and password:
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
auth_provider = PlainTextAuthProvider(username=username, password=password)
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
else:
self.cluster = Cluster(hosts)
self.session = self.cluster.connect()
# Track this cluster globally
_active_clusters.append(self.cluster)
self.init()
if not self.use_legacy:
self.prepare_statements()
def clear(self):
self.session.execute(f"""
drop keyspace if exists {self.keyspace};
""");
self.init()
def init(self):
self.session.execute(f"""
create keyspace if not exists {self.keyspace}
with replication = {{
'class' : 'SimpleStrategy',
'replication_factor' : 1
}};
""");
self.session.set_keyspace(self.keyspace)
if self.use_legacy:
self.init_legacy_schema()
else:
self.init_optimized_schema()
def init_legacy_schema(self):
"""Initialize legacy single-table schema for backward compatibility"""
self.session.execute(f"""
create table if not exists {self.table} (
collection text,
s text,
p text,
o text,
PRIMARY KEY (collection, s, p, o)
);
""");
self.session.execute(f"""
create index if not exists {self.table}_s
ON {self.table} (s);
""");
self.session.execute(f"""
create index if not exists {self.table}_p
ON {self.table} (p);
""");
self.session.execute(f"""
create index if not exists {self.table}_o
ON {self.table} (o);
""");
def init_optimized_schema(self):
"""Initialize optimized multi-table schema for performance"""
# Table 1: Subject-centric queries (get_s, get_sp, get_spo, get_os)
self.session.execute(f"""
CREATE TABLE IF NOT EXISTS {self.subject_table} (
collection text,
s text,
p text,
o text,
PRIMARY KEY ((collection, s), p, o)
);
""");
# Table 2: Predicate-Object queries (get_p, get_po) - eliminates ALLOW FILTERING!
self.session.execute(f"""
CREATE TABLE IF NOT EXISTS {self.po_table} (
collection text,
p text,
o text,
s text,
PRIMARY KEY ((collection, p), o, s)
);
""");
# Table 3: Object-centric queries (get_o)
self.session.execute(f"""
CREATE TABLE IF NOT EXISTS {self.object_table} (
collection text,
o text,
s text,
p text,
PRIMARY KEY ((collection, o), s, p)
);
""");
logger.info("Optimized multi-table schema initialized")
def prepare_statements(self):
"""Prepare statements for optimal performance"""
# Insert statements for batch operations
self.insert_subject_stmt = self.session.prepare(
f"INSERT INTO {self.subject_table} (collection, s, p, o) VALUES (?, ?, ?, ?)"
)
self.insert_po_stmt = self.session.prepare(
f"INSERT INTO {self.po_table} (collection, p, o, s) VALUES (?, ?, ?, ?)"
)
self.insert_object_stmt = self.session.prepare(
f"INSERT INTO {self.object_table} (collection, o, s, p) VALUES (?, ?, ?, ?)"
)
# Query statements for optimized access
self.get_all_stmt = self.session.prepare(
f"SELECT s, p, o FROM {self.subject_table} WHERE collection = ? LIMIT ? ALLOW FILTERING"
)
self.get_s_stmt = self.session.prepare(
f"SELECT p, o FROM {self.subject_table} WHERE collection = ? AND s = ? LIMIT ?"
)
self.get_p_stmt = self.session.prepare(
f"SELECT s, o FROM {self.po_table} WHERE collection = ? AND p = ? LIMIT ?"
)
self.get_o_stmt = self.session.prepare(
f"SELECT s, p FROM {self.object_table} WHERE collection = ? AND o = ? LIMIT ?"
)
self.get_sp_stmt = self.session.prepare(
f"SELECT o FROM {self.subject_table} WHERE collection = ? AND s = ? AND p = ? LIMIT ?"
)
# The critical optimization: get_po without ALLOW FILTERING!
self.get_po_stmt = self.session.prepare(
f"SELECT s FROM {self.po_table} WHERE collection = ? AND p = ? AND o = ? LIMIT ?"
)
self.get_os_stmt = self.session.prepare(
f"SELECT p FROM {self.object_table} WHERE collection = ? AND o = ? AND s = ? LIMIT ?"
)
self.get_spo_stmt = self.session.prepare(
f"SELECT s as x FROM {self.subject_table} WHERE collection = ? AND s = ? AND p = ? AND o = ? LIMIT ?"
)
logger.info("Prepared statements initialized for optimal performance")
def insert(self, collection, s, p, o):
if self.use_legacy:
self.session.execute(
f"insert into {self.table} (collection, s, p, o) values (%s, %s, %s, %s)",
(collection, s, p, o)
)
else:
# Batch write to all three tables for consistency
batch = BatchStatement()
# Insert into subject table
batch.add(self.insert_subject_stmt, (collection, s, p, o))
# Insert into predicate-object table (column order: collection, p, o, s)
batch.add(self.insert_po_stmt, (collection, p, o, s))
# Insert into object table (column order: collection, o, s, p)
batch.add(self.insert_object_stmt, (collection, o, s, p))
self.session.execute(batch)
def get_all(self, collection, limit=50):
if self.use_legacy:
return self.session.execute(
f"select s, p, o from {self.table} where collection = %s limit {limit}",
(collection,)
)
else:
# Use subject table for get_all queries
return self.session.execute(
self.get_all_stmt,
(collection, limit)
)
def get_s(self, collection, s, limit=10):
if self.use_legacy:
return self.session.execute(
f"select p, o from {self.table} where collection = %s and s = %s limit {limit}",
(collection, s)
)
else:
# Optimized: Direct partition access with (collection, s)
return self.session.execute(
self.get_s_stmt,
(collection, s, limit)
)
def get_p(self, collection, p, limit=10):
if self.use_legacy:
return self.session.execute(
f"select s, o from {self.table} where collection = %s and p = %s limit {limit}",
(collection, p)
)
else:
# Optimized: Use po_table for direct partition access
return self.session.execute(
self.get_p_stmt,
(collection, p, limit)
)
def get_o(self, collection, o, limit=10):
if self.use_legacy:
return self.session.execute(
f"select s, p from {self.table} where collection = %s and o = %s limit {limit}",
(collection, o)
)
else:
# Optimized: Use object_table for direct partition access
return self.session.execute(
self.get_o_stmt,
(collection, o, limit)
)
def get_sp(self, collection, s, p, limit=10):
if self.use_legacy:
return self.session.execute(
f"select o from {self.table} where collection = %s and s = %s and p = %s limit {limit}",
(collection, s, p)
)
else:
# Optimized: Use subject_table with clustering key access
return self.session.execute(
self.get_sp_stmt,
(collection, s, p, limit)
)
def get_po(self, collection, p, o, limit=10):
if self.use_legacy:
return self.session.execute(
f"select s from {self.table} where collection = %s and p = %s and o = %s limit {limit} allow filtering",
(collection, p, o)
)
else:
# CRITICAL OPTIMIZATION: Use po_table - NO MORE ALLOW FILTERING!
return self.session.execute(
self.get_po_stmt,
(collection, p, o, limit)
)
def get_os(self, collection, o, s, limit=10):
if self.use_legacy:
return self.session.execute(
f"select p from {self.table} where collection = %s and o = %s and s = %s limit {limit} allow filtering",
(collection, o, s)
)
else:
# Optimized: Use subject_table with clustering access (no more ALLOW FILTERING)
return self.session.execute(
self.get_os_stmt,
(collection, s, o, limit)
)
def get_spo(self, collection, s, p, o, limit=10):
if self.use_legacy:
return self.session.execute(
f"""select s as x from {self.table} where collection = %s and s = %s and p = %s and o = %s limit {limit}""",
(collection, s, p, o)
)
else:
# Optimized: Use subject_table for exact key lookup
return self.session.execute(
self.get_spo_stmt,
(collection, s, p, o, limit)
)
def delete_collection(self, collection):
"""Delete all triples for a specific collection"""
if self.use_legacy:
self.session.execute(
f"delete from {self.table} where collection = %s",
(collection,)
)
else:
# Delete from all three tables
self.session.execute(
f"delete from {self.subject_table} where collection = %s",
(collection,)
)
self.session.execute(
f"delete from {self.po_table} where collection = %s",
(collection,)
)
self.session.execute(
f"delete from {self.object_table} where collection = %s",
(collection,)
)
def close(self):
"""Close the Cassandra session and cluster connections properly"""
if hasattr(self, 'session') and self.session:
self.session.shutdown()
if hasattr(self, 'cluster') and self.cluster:
self.cluster.shutdown()
# Remove from global tracking
if self.cluster in _active_clusters:
_active_clusters.remove(self.cluster)

View file

@ -2,9 +2,32 @@
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
import time
import logging
import re
logger = logging.getLogger(__name__)
def make_safe_collection_name(user, collection, prefix):
"""
Create a safe Milvus collection name from user/collection parameters.
Milvus only allows letters, numbers, and underscores.
"""
def sanitize(s):
# Replace non-alphanumeric characters (except underscore) with underscore
# Then collapse multiple underscores into single underscore
safe = re.sub(r'[^a-zA-Z0-9_]', '_', s)
safe = re.sub(r'_+', '_', safe)
# Remove leading/trailing underscores
safe = safe.strip('_')
# Ensure it's not empty
if not safe:
safe = 'default'
return safe
safe_user = sanitize(user)
safe_collection = sanitize(collection)
return f"{prefix}_{safe_user}_{safe_collection}"
class DocVectors:
def __init__(self, uri="http://localhost:19530", prefix='doc'):
@ -26,9 +49,9 @@ class DocVectors:
self.next_reload = time.time() + self.reload_time
logger.debug(f"Reload at {self.next_reload}")
def init_collection(self, dimension):
def init_collection(self, dimension, user, collection):
collection_name = self.prefix + "_" + str(dimension)
collection_name = make_safe_collection_name(user, collection, self.prefix)
pkey_field = FieldSchema(
name="id",
@ -75,14 +98,14 @@ class DocVectors:
index_params=index_params
)
self.collections[dimension] = collection_name
self.collections[(dimension, user, collection)] = collection_name
def insert(self, embeds, doc):
def insert(self, embeds, doc, user, collection):
dim = len(embeds)
if dim not in self.collections:
self.init_collection(dim)
if (dim, user, collection) not in self.collections:
self.init_collection(dim, user, collection)
data = [
{
@ -92,18 +115,18 @@ class DocVectors:
]
self.client.insert(
collection_name=self.collections[dim],
collection_name=self.collections[(dim, user, collection)],
data=data
)
def search(self, embeds, fields=["doc"], limit=10):
def search(self, embeds, user, collection, fields=["doc"], limit=10):
dim = len(embeds)
if dim not in self.collections:
self.init_collection(dim)
if (dim, user, collection) not in self.collections:
self.init_collection(dim, user, collection)
coll = self.collections[dim]
coll = self.collections[(dim, user, collection)]
search_params = {
"metric_type": "COSINE",
@ -139,3 +162,20 @@ class DocVectors:
return res
def delete_collection(self, user, collection):
"""Delete a collection for the given user and collection"""
collection_name = make_safe_collection_name(user, collection, self.prefix)
# Check if collection exists
if self.client.has_collection(collection_name):
# Drop the collection
self.client.drop_collection(collection_name)
logger.info(f"Deleted Milvus collection: {collection_name}")
# Remove from our local cache
keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection]
for key in keys_to_remove:
del self.collections[key]
else:
logger.info(f"Collection {collection_name} does not exist, nothing to delete")

View file

@ -2,9 +2,32 @@
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
import time
import logging
import re
logger = logging.getLogger(__name__)
def make_safe_collection_name(user, collection, prefix):
"""
Create a safe Milvus collection name from user/collection parameters.
Milvus only allows letters, numbers, and underscores.
"""
def sanitize(s):
# Replace non-alphanumeric characters (except underscore) with underscore
# Then collapse multiple underscores into single underscore
safe = re.sub(r'[^a-zA-Z0-9_]', '_', s)
safe = re.sub(r'_+', '_', safe)
# Remove leading/trailing underscores
safe = safe.strip('_')
# Ensure it's not empty
if not safe:
safe = 'default'
return safe
safe_user = sanitize(user)
safe_collection = sanitize(collection)
return f"{prefix}_{safe_user}_{safe_collection}"
class EntityVectors:
def __init__(self, uri="http://localhost:19530", prefix='entity'):
@ -26,9 +49,9 @@ class EntityVectors:
self.next_reload = time.time() + self.reload_time
logger.debug(f"Reload at {self.next_reload}")
def init_collection(self, dimension):
def init_collection(self, dimension, user, collection):
collection_name = self.prefix + "_" + str(dimension)
collection_name = make_safe_collection_name(user, collection, self.prefix)
pkey_field = FieldSchema(
name="id",
@ -75,14 +98,14 @@ class EntityVectors:
index_params=index_params
)
self.collections[dimension] = collection_name
self.collections[(dimension, user, collection)] = collection_name
def insert(self, embeds, entity):
def insert(self, embeds, entity, user, collection):
dim = len(embeds)
if dim not in self.collections:
self.init_collection(dim)
if (dim, user, collection) not in self.collections:
self.init_collection(dim, user, collection)
data = [
{
@ -92,18 +115,18 @@ class EntityVectors:
]
self.client.insert(
collection_name=self.collections[dim],
collection_name=self.collections[(dim, user, collection)],
data=data
)
def search(self, embeds, fields=["entity"], limit=10):
def search(self, embeds, user, collection, fields=["entity"], limit=10):
dim = len(embeds)
if dim not in self.collections:
self.init_collection(dim)
if (dim, user, collection) not in self.collections:
self.init_collection(dim, user, collection)
coll = self.collections[dim]
coll = self.collections[(dim, user, collection)]
search_params = {
"metric_type": "COSINE",
@ -139,3 +162,20 @@ class EntityVectors:
return res
def delete_collection(self, user, collection):
"""Delete a collection for the given user and collection"""
collection_name = make_safe_collection_name(user, collection, self.prefix)
# Check if collection exists
if self.client.has_collection(collection_name):
# Drop the collection
self.client.drop_collection(collection_name)
logger.info(f"Deleted Milvus collection: {collection_name}")
# Remove from our local cache
keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection]
for key in keys_to_remove:
del self.collections[key]
else:
logger.info(f"Collection {collection_name} does not exist, nothing to delete")

View file

@ -1,157 +0,0 @@
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
import time
import logging
logger = logging.getLogger(__name__)
class ObjectVectors:
def __init__(self, uri="http://localhost:19530", prefix='obj'):
self.client = MilvusClient(uri=uri)
# Strategy is to create collections per dimension. Probably only
# going to be using 1 anyway, but that means we don't need to
# hard-code the dimension anywhere, and no big deal if more than
# one are created.
self.collections = {}
self.prefix = prefix
# Time between reloads
self.reload_time = 90
# Next time to reload - this forces a reload at next window
self.next_reload = time.time() + self.reload_time
logger.debug(f"Reload at {self.next_reload}")
def init_collection(self, dimension, name):
collection_name = self.prefix + "_" + name + "_" + str(dimension)
pkey_field = FieldSchema(
name="id",
dtype=DataType.INT64,
is_primary=True,
auto_id=True,
)
vec_field = FieldSchema(
name="vector",
dtype=DataType.FLOAT_VECTOR,
dim=dimension,
)
name_field = FieldSchema(
name="name",
dtype=DataType.VARCHAR,
max_length=65535,
)
key_name_field = FieldSchema(
name="key_name",
dtype=DataType.VARCHAR,
max_length=65535,
)
key_field = FieldSchema(
name="key",
dtype=DataType.VARCHAR,
max_length=65535,
)
schema = CollectionSchema(
fields = [
pkey_field, vec_field, name_field, key_name_field, key_field
],
description = "Object embedding schema",
)
self.client.create_collection(
collection_name=collection_name,
schema=schema,
metric_type="COSINE",
)
index_params = MilvusClient.prepare_index_params()
index_params.add_index(
field_name="vector",
metric_type="COSINE",
index_type="IVF_SQ8",
index_name="vector_index",
params={ "nlist": 128 }
)
self.client.create_index(
collection_name=collection_name,
index_params=index_params
)
self.collections[(dimension, name)] = collection_name
def insert(self, embeds, name, key_name, key):
dim = len(embeds)
if (dim, name) not in self.collections:
self.init_collection(dim, name)
data = [
{
"vector": embeds,
"name": name,
"key_name": key_name,
"key": key,
}
]
self.client.insert(
collection_name=self.collections[(dim, name)],
data=data
)
def search(self, embeds, name, fields=["key_name", "name"], limit=10):
dim = len(embeds)
if dim not in self.collections:
self.init_collection(dim, name)
coll = self.collections[(dim, name)]
search_params = {
"metric_type": "COSINE",
"params": {
"radius": 0.1,
"range_filter": 0.8
}
}
logger.debug("Loading...")
self.client.load_collection(
collection_name=coll,
)
logger.debug("Searching...")
res = self.client.search(
collection_name=coll,
data=[embeds],
limit=limit,
output_fields=fields,
search_params=search_params,
)[0]
# If reload time has passed, unload collection
if time.time() > self.next_reload:
logger.debug(f"Unloading, reload at {self.next_reload}")
self.client.release_collection(
collection_name=coll,
)
self.next_reload = time.time() + self.reload_time
return res

View file

@ -27,13 +27,13 @@ class Processor(FlowProcessor):
id = params.get("id")
concurrency = params.get("concurrency", 1)
template_id = params.get("template-id", default_template_id)
config_key = params.get("config-type", default_config_type)
template_id = params.get("template_id", default_template_id)
config_key = params.get("config_type", default_config_type)
super().__init__(**params | {
"id": id,
"template-id": template_id,
"config-type": config_key,
"template_id": template_id,
"config_type": config_key,
"concurrency": concurrency,
})

View file

@ -53,7 +53,7 @@ class Processor(FlowProcessor):
super(Processor, self).__init__(
**params | {
"id": id,
"config-type": self.config_key,
"config_type": self.config_key,
"concurrency": concurrency,
}
)
@ -256,31 +256,34 @@ class Processor(FlowProcessor):
flow
)
# Emit each extracted object
for obj in objects:
# Emit extracted objects as a batch if any were found
if objects:
# Calculate confidence (could be enhanced with actual confidence from prompt)
confidence = 0.8 # Default confidence
# Convert all values to strings for Pulsar compatibility
string_values = convert_values_to_strings(obj)
# Convert all objects' values to strings for Pulsar compatibility
batch_values = []
for obj in objects:
string_values = convert_values_to_strings(obj)
batch_values.append(string_values)
# Create ExtractedObject
# Create ExtractedObject with batched values
extracted = ExtractedObject(
metadata=Metadata(
id=f"{v.metadata.id}:{schema_name}:{hash(str(obj))}",
id=f"{v.metadata.id}:{schema_name}",
metadata=[],
user=v.metadata.user,
collection=v.metadata.collection,
),
schema_name=schema_name,
values=string_values,
values=batch_values, # Array of objects
confidence=confidence,
source_span=chunk_text[:100] # First 100 chars as source reference
)
await flow("output").send(extracted)
logger.debug(f"Emitted extracted object for schema {schema_name}")
logger.debug(f"Emitted batch of {len(objects)} objects for schema {schema_name}")
except Exception as e:
logger.error(f"Object extraction exception: {e}", exc_info=True)

View file

@ -0,0 +1,30 @@
from ... schema import CollectionManagementRequest, CollectionManagementResponse
from ... schema import collection_request_queue, collection_response_queue
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
class CollectionManagementRequestor(ServiceRequestor):
def __init__(self, pulsar_client, consumer, subscriber, timeout=120):
super(CollectionManagementRequestor, self).__init__(
pulsar_client=pulsar_client,
consumer_name = consumer,
subscription = subscriber,
request_queue=collection_request_queue,
response_queue=collection_response_queue,
request_schema=CollectionManagementRequest,
response_schema=CollectionManagementResponse,
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("collection-management")
self.response_translator = TranslatorRegistry.get_response_translator("collection-management")
def to_request(self, body):
print("REQUEST", body, flush=True)
return self.request_translator.to_pulsar(body)
def from_response(self, message):
print("RESPONSE", message, flush=True)
return self.response_translator.from_response_with_completion(message)

View file

@ -26,46 +26,66 @@ class DocumentEmbeddingsExport:
self.subscriber = subscriber
async def destroy(self):
# Step 1: Signal stop to prevent new messages
self.running.stop()
await self.ws.close()
# Step 2: Wait briefly for in-flight messages
await asyncio.sleep(0.5)
# Step 3: Unsubscribe and stop subscriber (triggers queue drain)
if hasattr(self, 'subs'):
await self.subs.unsubscribe_all(self.id)
await self.subs.stop()
# Step 4: Close websocket last
if self.ws and not self.ws.closed:
await self.ws.close()
async def receive(self, msg):
# Ignore incoming info from websocket
pass
async def run(self):
subs = Subscriber(
client = self.pulsar_client, topic = self.queue,
consumer_name = self.consumer, subscription = self.subscriber,
schema = DocumentEmbeddings
"""Enhanced run with better error handling"""
self.subs = Subscriber(
client = self.pulsar_client,
topic = self.queue,
consumer_name = self.consumer,
subscription = self.subscriber,
schema = DocumentEmbeddings,
backpressure_strategy = "block" # Configurable
)
await subs.start()
id = str(uuid.uuid4())
q = await subs.subscribe_all(id)
await self.subs.start()
self.id = str(uuid.uuid4())
q = await self.subs.subscribe_all(self.id)
consecutive_errors = 0
max_consecutive_errors = 5
while self.running.get():
try:
resp = await asyncio.wait_for(q.get(), timeout=0.5)
await self.ws.send_json(serialize_document_embeddings(resp))
except TimeoutError:
consecutive_errors = 0 # Reset on success
except asyncio.TimeoutError:
continue
except queue.Empty:
continue
except Exception as e:
logger.error(f"Exception: {str(e)}", exc_info=True)
break
await subs.unsubscribe_all(id)
await subs.stop()
await self.ws.close()
self.running.stop()
logger.error(f"Exception sending to websocket: {str(e)}")
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
logger.error("Too many consecutive errors, shutting down")
break
# Brief pause before retry
await asyncio.sleep(0.1)
# Graceful cleanup handled in destroy()

View file

@ -1,6 +1,7 @@
import asyncio
import uuid
import logging
from aiohttp import WSMsgType
from ... schema import Metadata
@ -8,6 +9,9 @@ from ... schema import DocumentEmbeddings, ChunkEmbeddings
from ... base import Publisher
from ... messaging.translators.document_loading import DocumentEmbeddingsTranslator
# Module logger
logger = logging.getLogger(__name__)
class DocumentEmbeddingsImport:
def __init__(
@ -26,13 +30,17 @@ class DocumentEmbeddingsImport:
await self.publisher.start()
async def destroy(self):
# Step 1: Stop accepting new messages
self.running.stop()
# Step 2: Wait for publisher to drain its queue
logger.info("Draining publisher queue...")
await self.publisher.stop()
# Step 3: Close websocket only after queue is drained
if self.ws:
await self.ws.close()
await self.publisher.stop()
async def receive(self, msg):
data = msg.json()

View file

@ -26,46 +26,66 @@ class EntityContextsExport:
self.subscriber = subscriber
async def destroy(self):
# Step 1: Signal stop to prevent new messages
self.running.stop()
await self.ws.close()
# Step 2: Wait briefly for in-flight messages
await asyncio.sleep(0.5)
# Step 3: Unsubscribe and stop subscriber (triggers queue drain)
if hasattr(self, 'subs'):
await self.subs.unsubscribe_all(self.id)
await self.subs.stop()
# Step 4: Close websocket last
if self.ws and not self.ws.closed:
await self.ws.close()
async def receive(self, msg):
# Ignore incoming info from websocket
pass
async def run(self):
subs = Subscriber(
client = self.pulsar_client, topic = self.queue,
consumer_name = self.consumer, subscription = self.subscriber,
schema = EntityContexts
"""Enhanced run with better error handling"""
self.subs = Subscriber(
client = self.pulsar_client,
topic = self.queue,
consumer_name = self.consumer,
subscription = self.subscriber,
schema = EntityContexts,
backpressure_strategy = "block" # Configurable
)
await subs.start()
id = str(uuid.uuid4())
q = await subs.subscribe_all(id)
await self.subs.start()
self.id = str(uuid.uuid4())
q = await self.subs.subscribe_all(self.id)
consecutive_errors = 0
max_consecutive_errors = 5
while self.running.get():
try:
resp = await asyncio.wait_for(q.get(), timeout=0.5)
await self.ws.send_json(serialize_entity_contexts(resp))
except TimeoutError:
consecutive_errors = 0 # Reset on success
except asyncio.TimeoutError:
continue
except queue.Empty:
continue
except Exception as e:
logger.error(f"Exception: {str(e)}", exc_info=True)
break
await subs.unsubscribe_all(id)
await subs.stop()
await self.ws.close()
self.running.stop()
logger.error(f"Exception sending to websocket: {str(e)}")
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
logger.error("Too many consecutive errors, shutting down")
break
# Brief pause before retry
await asyncio.sleep(0.1)
# Graceful cleanup handled in destroy()

View file

@ -1,6 +1,7 @@
import asyncio
import uuid
import logging
from aiohttp import WSMsgType
from ... schema import Metadata
@ -9,6 +10,9 @@ from ... base import Publisher
from . serialize import to_subgraph, to_value
# Module logger
logger = logging.getLogger(__name__)
class EntityContextsImport:
def __init__(
@ -26,13 +30,17 @@ class EntityContextsImport:
await self.publisher.start()
async def destroy(self):
# Step 1: Stop accepting new messages
self.running.stop()
# Step 2: Wait for publisher to drain its queue
logger.info("Draining publisher queue...")
await self.publisher.stop()
# Step 3: Close websocket only after queue is drained
if self.ws:
await self.ws.close()
await self.publisher.stop()
async def receive(self, msg):
data = msg.json()

View file

@ -26,46 +26,66 @@ class GraphEmbeddingsExport:
self.subscriber = subscriber
async def destroy(self):
# Step 1: Signal stop to prevent new messages
self.running.stop()
await self.ws.close()
# Step 2: Wait briefly for in-flight messages
await asyncio.sleep(0.5)
# Step 3: Unsubscribe and stop subscriber (triggers queue drain)
if hasattr(self, 'subs'):
await self.subs.unsubscribe_all(self.id)
await self.subs.stop()
# Step 4: Close websocket last
if self.ws and not self.ws.closed:
await self.ws.close()
async def receive(self, msg):
# Ignore incoming info from websocket
pass
async def run(self):
subs = Subscriber(
client = self.pulsar_client, topic = self.queue,
consumer_name = self.consumer, subscription = self.subscriber,
schema = GraphEmbeddings
"""Enhanced run with better error handling"""
self.subs = Subscriber(
client = self.pulsar_client,
topic = self.queue,
consumer_name = self.consumer,
subscription = self.subscriber,
schema = GraphEmbeddings,
backpressure_strategy = "block" # Configurable
)
await subs.start()
id = str(uuid.uuid4())
q = await subs.subscribe_all(id)
await self.subs.start()
self.id = str(uuid.uuid4())
q = await self.subs.subscribe_all(self.id)
consecutive_errors = 0
max_consecutive_errors = 5
while self.running.get():
try:
resp = await asyncio.wait_for(q.get(), timeout=0.5)
await self.ws.send_json(serialize_graph_embeddings(resp))
except TimeoutError:
consecutive_errors = 0 # Reset on success
except asyncio.TimeoutError:
continue
except queue.Empty:
continue
except Exception as e:
logger.error(f"Exception: {str(e)}", exc_info=True)
break
await subs.unsubscribe_all(id)
await subs.stop()
await self.ws.close()
self.running.stop()
logger.error(f"Exception sending to websocket: {str(e)}")
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
logger.error("Too many consecutive errors, shutting down")
break
# Brief pause before retry
await asyncio.sleep(0.1)
# Graceful cleanup handled in destroy()

View file

@ -1,6 +1,7 @@
import asyncio
import uuid
import logging
from aiohttp import WSMsgType
from ... schema import Metadata
@ -9,6 +10,9 @@ from ... base import Publisher
from . serialize import to_subgraph, to_value
# Module logger
logger = logging.getLogger(__name__)
class GraphEmbeddingsImport:
def __init__(
@ -26,13 +30,17 @@ class GraphEmbeddingsImport:
await self.publisher.start()
async def destroy(self):
# Step 1: Stop accepting new messages
self.running.stop()
# Step 2: Wait for publisher to drain its queue
logger.info("Draining publisher queue...")
await self.publisher.stop()
# Step 3: Close websocket only after queue is drained
if self.ws:
await self.ws.close()
await self.publisher.stop()
async def receive(self, msg):
data = msg.json()

View file

@ -11,6 +11,7 @@ from . config import ConfigRequestor
from . flow import FlowRequestor
from . librarian import LibrarianRequestor
from . knowledge import KnowledgeRequestor
from . collection_management import CollectionManagementRequestor
from . embeddings import EmbeddingsRequestor
from . agent import AgentRequestor
@ -19,6 +20,10 @@ from . prompt import PromptRequestor
from . graph_rag import GraphRagRequestor
from . document_rag import DocumentRagRequestor
from . triples_query import TriplesQueryRequestor
from . objects_query import ObjectsQueryRequestor
from . nlp_query import NLPQueryRequestor
from . structured_query import StructuredQueryRequestor
from . structured_diag import StructuredDiagRequestor
from . embeddings import EmbeddingsRequestor
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
from . mcp_tool import McpToolRequestor
@ -34,6 +39,7 @@ from . triples_import import TriplesImport
from . graph_embeddings_import import GraphEmbeddingsImport
from . document_embeddings_import import DocumentEmbeddingsImport
from . entity_contexts_import import EntityContextsImport
from . objects_import import ObjectsImport
from . core_export import CoreExport
from . core_import import CoreImport
@ -50,6 +56,10 @@ request_response_dispatchers = {
"embeddings": EmbeddingsRequestor,
"graph-embeddings": GraphEmbeddingsQueryRequestor,
"triples": TriplesQueryRequestor,
"objects": ObjectsQueryRequestor,
"nlp-query": NLPQueryRequestor,
"structured-query": StructuredQueryRequestor,
"structured-diag": StructuredDiagRequestor,
}
global_dispatchers = {
@ -57,6 +67,7 @@ global_dispatchers = {
"flow": FlowRequestor,
"librarian": LibrarianRequestor,
"knowledge": KnowledgeRequestor,
"collection-management": CollectionManagementRequestor,
}
sender_dispatchers = {
@ -76,6 +87,7 @@ import_dispatchers = {
"graph-embeddings": GraphEmbeddingsImport,
"document-embeddings": DocumentEmbeddingsImport,
"entity-contexts": EntityContextsImport,
"objects": ObjectsImport,
}
class DispatcherWrapper:

View file

@ -147,7 +147,7 @@ class Mux:
self.running.stop()
if self.ws:
self.ws.close()
await self.ws.close()
self.ws = None
break
@ -165,6 +165,6 @@ class Mux:
self.running.stop()
if self.ws:
self.ws.close()
await self.ws.close()
self.ws = None

View file

@ -0,0 +1,30 @@
from ... schema import QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
class NLPQueryRequestor(ServiceRequestor):
def __init__(
self, pulsar_client, request_queue, response_queue, timeout,
consumer, subscriber,
):
super(NLPQueryRequestor, self).__init__(
pulsar_client=pulsar_client,
request_queue=request_queue,
response_queue=response_queue,
request_schema=QuestionToStructuredQueryRequest,
response_schema=QuestionToStructuredQueryResponse,
subscription = subscriber,
consumer_name = consumer,
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("nlp-query")
self.response_translator = TranslatorRegistry.get_response_translator("nlp-query")
def to_request(self, body):
return self.request_translator.to_pulsar(body)
def from_response(self, message):
return self.response_translator.from_response_with_completion(message)

View file

@ -0,0 +1,76 @@
import asyncio
import uuid
import logging
from aiohttp import WSMsgType
from ... schema import Metadata
from ... schema import ExtractedObject
from ... base import Publisher
from . serialize import to_subgraph
# Module logger
logger = logging.getLogger(__name__)
class ObjectsImport:
def __init__(
self, ws, running, pulsar_client, queue
):
self.ws = ws
self.running = running
self.publisher = Publisher(
pulsar_client, topic = queue, schema = ExtractedObject
)
async def start(self):
await self.publisher.start()
async def destroy(self):
# Step 1: Stop accepting new messages
self.running.stop()
# Step 2: Wait for publisher to drain its queue
logger.info("Draining publisher queue...")
await self.publisher.stop()
# Step 3: Close websocket only after queue is drained
if self.ws:
await self.ws.close()
async def receive(self, msg):
data = msg.json()
# Handle both single object and array of objects for backward compatibility
values_data = data["values"]
if not isinstance(values_data, list):
# Single object - wrap in array
values_data = [values_data]
elt = ExtractedObject(
metadata=Metadata(
id=data["metadata"]["id"],
metadata=to_subgraph(data["metadata"].get("metadata", [])),
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
schema_name=data["schema_name"],
values=values_data,
confidence=data.get("confidence", 1.0),
source_span=data.get("source_span", ""),
)
await self.publisher.send(None, elt)
async def run(self):
while self.running.get():
await asyncio.sleep(0.5)
if self.ws:
await self.ws.close()
self.ws = None

View file

@ -0,0 +1,30 @@
from ... schema import ObjectsQueryRequest, ObjectsQueryResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
class ObjectsQueryRequestor(ServiceRequestor):
def __init__(
self, pulsar_client, request_queue, response_queue, timeout,
consumer, subscriber,
):
super(ObjectsQueryRequestor, self).__init__(
pulsar_client=pulsar_client,
request_queue=request_queue,
response_queue=response_queue,
request_schema=ObjectsQueryRequest,
response_schema=ObjectsQueryResponse,
subscription = subscriber,
consumer_name = consumer,
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("objects-query")
self.response_translator = TranslatorRegistry.get_response_translator("objects-query")
def to_request(self, body):
return self.request_translator.to_pulsar(body)
def from_response(self, message):
return self.response_translator.from_response_with_completion(message)

View file

@ -0,0 +1,30 @@
from ... schema import StructuredDataDiagnosisRequest, StructuredDataDiagnosisResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
class StructuredDiagRequestor(ServiceRequestor):
def __init__(
self, pulsar_client, request_queue, response_queue, timeout,
consumer, subscriber,
):
super(StructuredDiagRequestor, self).__init__(
pulsar_client=pulsar_client,
request_queue=request_queue,
response_queue=response_queue,
request_schema=StructuredDataDiagnosisRequest,
response_schema=StructuredDataDiagnosisResponse,
subscription = subscriber,
consumer_name = consumer,
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("structured-diag")
self.response_translator = TranslatorRegistry.get_response_translator("structured-diag")
def to_request(self, body):
return self.request_translator.to_pulsar(body)
def from_response(self, message):
return self.response_translator.from_response_with_completion(message)

View file

@ -0,0 +1,30 @@
from ... schema import StructuredQueryRequest, StructuredQueryResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
class StructuredQueryRequestor(ServiceRequestor):
def __init__(
self, pulsar_client, request_queue, response_queue, timeout,
consumer, subscriber,
):
super(StructuredQueryRequestor, self).__init__(
pulsar_client=pulsar_client,
request_queue=request_queue,
response_queue=response_queue,
request_schema=StructuredQueryRequest,
response_schema=StructuredQueryResponse,
subscription = subscriber,
consumer_name = consumer,
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("structured-query")
self.response_translator = TranslatorRegistry.get_response_translator("structured-query")
def to_request(self, body):
return self.request_translator.to_pulsar(body)
def from_response(self, message):
return self.response_translator.from_response_with_completion(message)

View file

@ -26,46 +26,66 @@ class TriplesExport:
self.subscriber = subscriber
async def destroy(self):
# Step 1: Signal stop to prevent new messages
self.running.stop()
await self.ws.close()
# Step 2: Wait briefly for in-flight messages
await asyncio.sleep(0.5)
# Step 3: Unsubscribe and stop subscriber (triggers queue drain)
if hasattr(self, 'subs'):
await self.subs.unsubscribe_all(self.id)
await self.subs.stop()
# Step 4: Close websocket last
if self.ws and not self.ws.closed:
await self.ws.close()
async def receive(self, msg):
# Ignore incoming info from websocket
pass
async def run(self):
subs = Subscriber(
client = self.pulsar_client, topic = self.queue,
consumer_name = self.consumer, subscription = self.subscriber,
schema = Triples
"""Enhanced run with better error handling"""
self.subs = Subscriber(
client = self.pulsar_client,
topic = self.queue,
consumer_name = self.consumer,
subscription = self.subscriber,
schema = Triples,
backpressure_strategy = "block" # Configurable
)
await subs.start()
id = str(uuid.uuid4())
q = await subs.subscribe_all(id)
await self.subs.start()
self.id = str(uuid.uuid4())
q = await self.subs.subscribe_all(self.id)
consecutive_errors = 0
max_consecutive_errors = 5
while self.running.get():
try:
resp = await asyncio.wait_for(q.get(), timeout=0.5)
await self.ws.send_json(serialize_triples(resp))
except TimeoutError:
consecutive_errors = 0 # Reset on success
except asyncio.TimeoutError:
continue
except queue.Empty:
continue
except Exception as e:
logger.error(f"Exception: {str(e)}", exc_info=True)
break
await subs.unsubscribe_all(id)
await subs.stop()
await self.ws.close()
self.running.stop()
logger.error(f"Exception sending to websocket: {str(e)}")
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
logger.error("Too many consecutive errors, shutting down")
break
# Brief pause before retry
await asyncio.sleep(0.1)
# Graceful cleanup handled in destroy()

View file

@ -1,6 +1,7 @@
import asyncio
import uuid
import logging
from aiohttp import WSMsgType
from ... schema import Metadata
@ -9,6 +10,9 @@ from ... base import Publisher
from . serialize import to_subgraph
# Module logger
logger = logging.getLogger(__name__)
class TriplesImport:
def __init__(
@ -26,13 +30,17 @@ class TriplesImport:
await self.publisher.start()
async def destroy(self):
# Step 1: Stop accepting new messages
self.running.stop()
# Step 2: Wait for publisher to drain its queue
logger.info("Draining publisher queue...")
await self.publisher.stop()
# Step 3: Close websocket only after queue is drained
if self.ws:
await self.ws.close()
await self.publisher.stop()
async def receive(self, msg):
data = msg.json()

View file

@ -25,24 +25,43 @@ class SocketEndpoint:
await dispatcher.run()
async def listener(self, ws, dispatcher, running):
async for msg in ws:
# On error, finish
if msg.type == WSMsgType.TEXT:
await dispatcher.receive(msg)
continue
elif msg.type == WSMsgType.BINARY:
await dispatcher.receive(msg)
continue
"""Enhanced listener with graceful shutdown"""
try:
async for msg in ws:
# On error, finish
if msg.type == WSMsgType.TEXT:
await dispatcher.receive(msg)
continue
elif msg.type == WSMsgType.BINARY:
await dispatcher.receive(msg)
continue
else:
# Graceful shutdown on close
logger.info("Websocket closing, initiating graceful shutdown")
running.stop()
# Allow time for dispatcher cleanup
await asyncio.sleep(1.0)
# Close websocket if not already closed
if not ws.closed:
await ws.close()
break
else:
break
running.stop()
await ws.close()
# This executes when the async for loop completes normally (no break)
logger.debug("Websocket iteration completed, performing cleanup")
running.stop()
if not ws.closed:
await ws.close()
except Exception:
# Handle exceptions and cleanup
running.stop()
if not ws.closed:
await ws.close()
raise
async def handle(self, request):
"""Enhanced handler with better cleanup"""
try:
token = request.query['token']
except:
@ -55,7 +74,9 @@ class SocketEndpoint:
ws = web.WebSocketResponse(max_msg_size=52428800)
await ws.prepare(request)
dispatcher = None
try:
async with asyncio.TaskGroup() as tg:
@ -80,9 +101,6 @@ class SocketEndpoint:
logger.debug("Task group closed")
# Finally?
await dispatcher.destroy()
except ExceptionGroup as e:
logger.error("Exception group occurred:", exc_info=True)
@ -90,11 +108,34 @@ class SocketEndpoint:
for se in e.exceptions:
logger.error(f" Exception type: {type(se)}")
logger.error(f" Exception: {se}")
# Attempt graceful dispatcher shutdown
if dispatcher and hasattr(dispatcher, 'destroy'):
try:
await asyncio.wait_for(
dispatcher.destroy(),
timeout=5.0
)
except asyncio.TimeoutError:
logger.warning("Dispatcher shutdown timed out")
except Exception as de:
logger.error(f"Error during dispatcher cleanup: {de}")
except Exception as e:
logger.error(f"Socket exception: {e}", exc_info=True)
await ws.close()
finally:
# Ensure dispatcher cleanup
if dispatcher and hasattr(dispatcher, 'destroy'):
try:
await dispatcher.destroy()
except Exception as de:
logger.error(f"Error in final dispatcher cleanup: {de}")
# Ensure websocket is closed
if ws and not ws.closed:
await ws.close()
return ws
async def start(self):

View file

@ -0,0 +1,315 @@
"""
Collection management for the librarian
"""
import asyncio
import logging
from datetime import datetime
from typing import Dict, Any, List, Optional
from .. schema import CollectionManagementRequest, CollectionManagementResponse, Error
from .. schema import CollectionMetadata
from .. schema import StorageManagementRequest, StorageManagementResponse
from .. exceptions import RequestError
from .. tables.library import LibraryTableStore
# Module logger
logger = logging.getLogger(__name__)
class CollectionManager:
"""Manages collection metadata and coordinates collection operations across storage types"""
def __init__(
self,
cassandra_host,
cassandra_username,
cassandra_password,
keyspace,
vector_storage_producer=None,
object_storage_producer=None,
triples_storage_producer=None,
storage_response_consumer=None
):
"""
Initialize the CollectionManager
Args:
cassandra_host: Cassandra host(s)
cassandra_username: Cassandra username
cassandra_password: Cassandra password
keyspace: Cassandra keyspace for library data
vector_storage_producer: Producer for vector storage management
object_storage_producer: Producer for object storage management
triples_storage_producer: Producer for triples storage management
storage_response_consumer: Consumer for storage management responses
"""
self.table_store = LibraryTableStore(
cassandra_host, cassandra_username, cassandra_password, keyspace
)
# Storage management producers
self.vector_storage_producer = vector_storage_producer
self.object_storage_producer = object_storage_producer
self.triples_storage_producer = triples_storage_producer
self.storage_response_consumer = storage_response_consumer
# Track pending deletion operations
self.pending_deletions = {}
logger.info("Collection manager initialized")
async def ensure_collection_exists(self, user: str, collection: str):
"""
Ensure a collection exists, creating it if necessary (lazy creation)
Args:
user: User ID
collection: Collection ID
"""
try:
# Check if collection already exists
existing = await self.table_store.get_collection(user, collection)
if existing:
logger.debug(f"Collection {user}/{collection} already exists")
return
# Create new collection with default metadata
logger.info(f"Creating new collection {user}/{collection}")
await self.table_store.create_collection(
user=user,
collection=collection,
name=collection, # Default name to collection ID
description="",
tags=set()
)
except Exception as e:
logger.error(f"Error ensuring collection exists: {e}")
# Don't fail the operation if collection creation fails
# This maintains backward compatibility
async def list_collections(self, request: CollectionManagementRequest) -> CollectionManagementResponse:
"""
List collections for a user with optional tag filtering
Args:
request: Collection management request
Returns:
CollectionManagementResponse with list of collections
"""
try:
tag_filter = list(request.tag_filter) if request.tag_filter else None
collections = await self.table_store.list_collections(request.user, tag_filter)
collection_metadata = [
CollectionMetadata(
user=coll["user"],
collection=coll["collection"],
name=coll["name"],
description=coll["description"],
tags=coll["tags"],
created_at=coll["created_at"],
updated_at=coll["updated_at"]
)
for coll in collections
]
return CollectionManagementResponse(
error=None,
collections=collection_metadata,
timestamp=datetime.now().isoformat()
)
except Exception as e:
logger.error(f"Error listing collections: {e}")
raise RequestError(f"Failed to list collections: {str(e)}")
async def update_collection(self, request: CollectionManagementRequest) -> CollectionManagementResponse:
"""
Update collection metadata (creates if doesn't exist)
Args:
request: Collection management request
Returns:
CollectionManagementResponse with updated collection
"""
try:
# Check if collection exists, create if it doesn't
existing = await self.table_store.get_collection(request.user, request.collection)
if not existing:
# Create new collection with provided metadata
logger.info(f"Creating new collection {request.user}/{request.collection}")
name = request.name if request.name else request.collection
description = request.description if request.description else ""
tags = set(request.tags) if request.tags else set()
await self.table_store.create_collection(
user=request.user,
collection=request.collection,
name=name,
description=description,
tags=tags
)
# Get the newly created collection for response
created_collection = await self.table_store.get_collection(request.user, request.collection)
collection_metadata = CollectionMetadata(
user=created_collection["user"],
collection=created_collection["collection"],
name=created_collection["name"],
description=created_collection["description"],
tags=created_collection["tags"],
created_at=created_collection["created_at"],
updated_at=created_collection["updated_at"]
)
else:
# Collection exists, update it
name = request.name if request.name else None
description = request.description if request.description else None
tags = list(request.tags) if request.tags else None
updated_collection = await self.table_store.update_collection(
request.user, request.collection, name, description, tags
)
collection_metadata = CollectionMetadata(
user=updated_collection["user"],
collection=updated_collection["collection"],
name=updated_collection["name"],
description=updated_collection["description"],
tags=updated_collection["tags"],
created_at="", # Not returned by update
updated_at=updated_collection["updated_at"]
)
return CollectionManagementResponse(
error=None,
collections=[collection_metadata],
timestamp=datetime.now().isoformat()
)
except Exception as e:
logger.error(f"Error updating collection: {e}")
raise RequestError(f"Failed to update collection: {str(e)}")
async def delete_collection(self, request: CollectionManagementRequest) -> CollectionManagementResponse:
"""
Delete collection with cascade to all storage types
Args:
request: Collection management request
Returns:
CollectionManagementResponse indicating success or failure
"""
try:
deletion_key = (request.user, request.collection)
logger.info(f"Starting cascade deletion for {request.user}/{request.collection}")
# Track this deletion request
self.pending_deletions[deletion_key] = {
"responses_pending": 3, # vector, object, triples
"responses_received": [],
"all_successful": True,
"error_messages": [],
"deletion_complete": asyncio.Event()
}
# Create storage management request
storage_request = StorageManagementRequest(
operation="delete-collection",
user=request.user,
collection=request.collection
)
# Send deletion requests to all storage types
if self.vector_storage_producer:
await self.vector_storage_producer.send(storage_request)
if self.object_storage_producer:
await self.object_storage_producer.send(storage_request)
if self.triples_storage_producer:
await self.triples_storage_producer.send(storage_request)
# Wait for all storage deletions to complete (with timeout)
deletion_info = self.pending_deletions[deletion_key]
try:
await asyncio.wait_for(
deletion_info["deletion_complete"].wait(),
timeout=30.0 # 30 second timeout
)
except asyncio.TimeoutError:
logger.error(f"Timeout waiting for storage deletion responses for {deletion_key}")
deletion_info["all_successful"] = False
deletion_info["error_messages"].append("Timeout waiting for storage deletion")
# Check if all deletions succeeded
if not deletion_info["all_successful"]:
error_msg = f"Storage deletion failed: {'; '.join(deletion_info['error_messages'])}"
logger.error(error_msg)
# Clean up tracking
del self.pending_deletions[deletion_key]
return CollectionManagementResponse(
error=Error(
type="storage_deletion_error",
message=error_msg
),
timestamp=datetime.now().isoformat()
)
# All storage deletions succeeded, now delete metadata
logger.info(f"Storage deletions complete, removing metadata for {deletion_key}")
await self.table_store.delete_collection(request.user, request.collection)
# Clean up tracking
del self.pending_deletions[deletion_key]
return CollectionManagementResponse(
error=None,
timestamp=datetime.now().isoformat()
)
except Exception as e:
logger.error(f"Error deleting collection: {e}")
# Clean up tracking on error
if deletion_key in self.pending_deletions:
del self.pending_deletions[deletion_key]
raise RequestError(f"Failed to delete collection: {str(e)}")
async def on_storage_response(self, response: StorageManagementResponse):
"""
Handle storage management responses for deletion tracking
Args:
response: Storage management response
"""
logger.debug(f"Received storage response: error={response.error}")
# Find matching deletion by checking all pending deletions
# Note: This is simplified correlation - in production we'd want better correlation
for deletion_key, info in list(self.pending_deletions.items()):
if info["responses_pending"] > 0:
# Record this response
info["responses_received"].append(response)
info["responses_pending"] -= 1
# Check if this response indicates failure
if response.error and response.error.message:
info["all_successful"] = False
info["error_messages"].append(response.error.message)
logger.warning(f"Storage deletion failed for {deletion_key}: {response.error.message}")
else:
logger.debug(f"Storage deletion succeeded for {deletion_key}")
# If all responses received, signal completion
if info["responses_pending"] == 0:
logger.info(f"All storage responses received for {deletion_key}")
info["deletion_complete"].set()
break # Only process for first matching deletion

View file

@ -16,7 +16,7 @@ class Librarian:
def __init__(
self,
cassandra_host, cassandra_user, cassandra_password,
cassandra_host, cassandra_username, cassandra_password,
minio_host, minio_access_key, minio_secret_key,
bucket_name, keyspace, load_document,
):
@ -26,7 +26,7 @@ class Librarian:
)
self.table_store = LibraryTableStore(
cassandra_host, cassandra_user, cassandra_password, keyspace
cassandra_host, cassandra_username, cassandra_password, keyspace
)
self.load_document = load_document

View file

@ -8,12 +8,19 @@ import asyncio
import base64
import json
import logging
from datetime import datetime
from .. base import AsyncProcessor, Consumer, Producer, Publisher, Subscriber
from .. base import ConsumerMetrics, ProducerMetrics
from .. base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from .. schema import LibrarianRequest, LibrarianResponse, Error
from .. schema import librarian_request_queue, librarian_response_queue
from .. schema import CollectionManagementRequest, CollectionManagementResponse
from .. schema import collection_request_queue, collection_response_queue
from .. schema import StorageManagementRequest, StorageManagementResponse
from .. schema import vector_storage_management_topic, object_storage_management_topic
from .. schema import triples_storage_management_topic, storage_management_response_topic
from .. schema import Document, Metadata
from .. schema import TextDocument, Metadata
@ -21,6 +28,7 @@ from .. schema import TextDocument, Metadata
from .. exceptions import RequestError
from . librarian import Librarian
from . collection_manager import CollectionManager
# Module logger
logger = logging.getLogger(__name__)
@ -29,6 +37,8 @@ default_ident = "librarian"
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
default_collection_request_queue = collection_request_queue
default_collection_response_queue = collection_response_queue
default_minio_host = "minio:9000"
default_minio_access_key = "minioadmin"
@ -56,6 +66,14 @@ class Processor(AsyncProcessor):
"librarian_response_queue", default_librarian_response_queue
)
collection_request_queue = params.get(
"collection_request_queue", default_collection_request_queue
)
collection_response_queue = params.get(
"collection_response_queue", default_collection_response_queue
)
minio_host = params.get("minio_host", default_minio_host)
minio_access_key = params.get(
"minio_access_key",
@ -66,18 +84,33 @@ class Processor(AsyncProcessor):
default_minio_secret_key
)
cassandra_host = params.get("cassandra_host", default_cassandra_host)
cassandra_user = params.get("cassandra_user")
cassandra_host = params.get("cassandra_host")
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password
)
# Store resolved configuration
self.cassandra_host = hosts
self.cassandra_username = username
self.cassandra_password = password
super(Processor, self).__init__(
**params | {
"librarian_request_queue": librarian_request_queue,
"librarian_response_queue": librarian_response_queue,
"collection_request_queue": collection_request_queue,
"collection_response_queue": collection_response_queue,
"minio_host": minio_host,
"minio_access_key": minio_access_key,
"cassandra_host": cassandra_host,
"cassandra_user": cassandra_user,
"cassandra_host": self.cassandra_host,
"cassandra_username": self.cassandra_username,
"cassandra_password": self.cassandra_password,
}
)
@ -89,6 +122,18 @@ class Processor(AsyncProcessor):
processor = self.id, flow = None, name = "librarian-response"
)
collection_request_metrics = ConsumerMetrics(
processor = self.id, flow = None, name = "collection-request"
)
collection_response_metrics = ProducerMetrics(
processor = self.id, flow = None, name = "collection-response"
)
storage_response_metrics = ConsumerMetrics(
processor = self.id, flow = None, name = "storage-response"
)
self.librarian_request_consumer = Consumer(
taskgroup = self.taskgroup,
client = self.pulsar_client,
@ -107,10 +152,58 @@ class Processor(AsyncProcessor):
metrics = librarian_response_metrics,
)
self.collection_request_consumer = Consumer(
taskgroup = self.taskgroup,
client = self.pulsar_client,
flow = None,
topic = collection_request_queue,
subscriber = id,
schema = CollectionManagementRequest,
handler = self.on_collection_request,
metrics = collection_request_metrics,
)
self.collection_response_producer = Producer(
client = self.pulsar_client,
topic = collection_response_queue,
schema = CollectionManagementResponse,
metrics = collection_response_metrics,
)
# Storage management producers for collection deletion
self.vector_storage_producer = Producer(
client = self.pulsar_client,
topic = vector_storage_management_topic,
schema = StorageManagementRequest,
)
self.object_storage_producer = Producer(
client = self.pulsar_client,
topic = object_storage_management_topic,
schema = StorageManagementRequest,
)
self.triples_storage_producer = Producer(
client = self.pulsar_client,
topic = triples_storage_management_topic,
schema = StorageManagementRequest,
)
self.storage_response_consumer = Consumer(
taskgroup = self.taskgroup,
client = self.pulsar_client,
flow = None,
topic = storage_management_response_topic,
subscriber = id,
schema = StorageManagementResponse,
handler = self.on_storage_response,
metrics = storage_response_metrics,
)
self.librarian = Librarian(
cassandra_host = cassandra_host.split(","),
cassandra_user = cassandra_user,
cassandra_password = cassandra_password,
cassandra_host = self.cassandra_host,
cassandra_username = self.cassandra_username,
cassandra_password = self.cassandra_password,
minio_host = minio_host,
minio_access_key = minio_access_key,
minio_secret_key = minio_secret_key,
@ -119,6 +212,17 @@ class Processor(AsyncProcessor):
load_document = self.load_document,
)
self.collection_manager = CollectionManager(
cassandra_host = self.cassandra_host,
cassandra_username = self.cassandra_username,
cassandra_password = self.cassandra_password,
keyspace = keyspace,
vector_storage_producer = self.vector_storage_producer,
object_storage_producer = self.object_storage_producer,
triples_storage_producer = self.triples_storage_producer,
storage_response_consumer = self.storage_response_consumer,
)
self.register_config_handler(self.on_librarian_config)
self.flows = {}
@ -130,6 +234,12 @@ class Processor(AsyncProcessor):
await super(Processor, self).start()
await self.librarian_request_consumer.start()
await self.librarian_response_producer.start()
await self.collection_request_consumer.start()
await self.collection_response_producer.start()
await self.vector_storage_producer.start()
await self.object_storage_producer.start()
await self.triples_storage_producer.start()
await self.storage_response_consumer.start()
async def on_librarian_config(self, config, version):
@ -209,6 +319,19 @@ class Processor(AsyncProcessor):
logger.debug("Document submitted")
async def add_processing_with_collection(self, request):
"""
Wrapper for add_processing that ensures collection exists
"""
# Ensure collection exists when processing is added
if hasattr(request, 'processing_metadata') and request.processing_metadata:
user = request.processing_metadata.user
collection = request.processing_metadata.collection
await self.collection_manager.ensure_collection_exists(user, collection)
# Call the original add_processing method
return await self.librarian.add_processing(request)
async def process_request(self, v):
if v.operation is None:
@ -222,7 +345,7 @@ class Processor(AsyncProcessor):
"update-document": self.librarian.update_document,
"get-document-metadata": self.librarian.get_document_metadata,
"get-document-content": self.librarian.get_document_content,
"add-processing": self.librarian.add_processing,
"add-processing": self.add_processing_with_collection,
"remove-processing": self.librarian.remove_processing,
"list-documents": self.librarian.list_documents,
"list-processing": self.librarian.list_processing,
@ -282,6 +405,73 @@ class Processor(AsyncProcessor):
logger.debug("Librarian input processing complete")
async def process_collection_request(self, v):
"""
Process collection management requests
"""
if v.operation is None:
raise RequestError("Null operation")
logger.debug(f"Collection request: {v.operation}")
impls = {
"list-collections": self.collection_manager.list_collections,
"update-collection": self.collection_manager.update_collection,
"delete-collection": self.collection_manager.delete_collection,
}
if v.operation not in impls:
raise RequestError(f"Invalid collection operation: {v.operation}")
return await impls[v.operation](v)
async def on_collection_request(self, msg, consumer, flow):
"""
Handle collection management request messages
"""
v = msg.value()
id = msg.properties().get("id", "unknown")
logger.info(f"Handling collection request {id}...")
try:
resp = await self.process_collection_request(v)
await self.collection_response_producer.send(
resp, properties={"id": id}
)
except RequestError as e:
resp = CollectionManagementResponse(
error=Error(
type="request-error",
message=str(e),
),
timestamp=datetime.now().isoformat()
)
await self.collection_response_producer.send(
resp, properties={"id": id}
)
except Exception as e:
resp = CollectionManagementResponse(
error=Error(
type="unexpected-error",
message=str(e),
),
timestamp=datetime.now().isoformat()
)
await self.collection_response_producer.send(
resp, properties={"id": id}
)
logger.debug("Collection request processing complete")
async def on_storage_response(self, msg, consumer, flow):
"""
Handle storage management response messages
"""
v = msg.value()
logger.debug("Received storage management response")
await self.collection_manager.on_storage_response(v)
@staticmethod
def add_args(parser):
@ -299,6 +489,18 @@ class Processor(AsyncProcessor):
help=f'Config response queue {default_librarian_response_queue}',
)
parser.add_argument(
'--collection-request-queue',
default=default_collection_request_queue,
help=f'Collection request queue (default: {default_collection_request_queue})'
)
parser.add_argument(
'--collection-response-queue',
default=default_collection_response_queue,
help=f'Collection response queue (default: {default_collection_response_queue})'
)
parser.add_argument(
'--minio-host',
default=default_minio_host,
@ -319,23 +521,7 @@ class Processor(AsyncProcessor):
f'(default: {default_minio_access_key})',
)
parser.add_argument(
'--cassandra-host',
default="cassandra",
help=f'Graph host (default: cassandra)'
)
parser.add_argument(
'--cassandra-user',
default=None,
help=f'Cassandra user'
)
parser.add_argument(
'--cassandra-password',
default=None,
help=f'Cassandra password'
)
add_cassandra_args(parser)
def run():

View file

@ -37,7 +37,7 @@ class Processor(FlowProcessor):
super(Processor, self).__init__(
**params | {
"id": id,
"config-type": self.config_key,
"config_type": self.config_key,
"concurrency": concurrency,
}
)

View file

@ -43,7 +43,12 @@ class Processor(DocumentEmbeddingsQueryService):
for vec in msg.vectors:
resp = self.vecstore.search(vec, limit=msg.limit)
resp = self.vecstore.search(
vec,
msg.user,
msg.collection,
limit=msg.limit
)
for r in resp:
chunk = r["entity"]["doc"]

View file

@ -47,6 +47,39 @@ class Processor(DocumentEmbeddingsQueryService):
}
)
self.last_index_name = None
def ensure_index_exists(self, index_name, dim):
"""Ensure index exists, create if it doesn't"""
if index_name != self.last_index_name:
if not self.pinecone.has_index(index_name):
try:
self.pinecone.create_index(
name=index_name,
dimension=dim,
metric="cosine",
spec=ServerlessSpec(
cloud="aws",
region="us-east-1",
)
)
logger.info(f"Created index: {index_name}")
# Wait for index to be ready
import time
for i in range(0, 1000):
if self.pinecone.describe_index(index_name).status["ready"]:
break
time.sleep(1)
if not self.pinecone.describe_index(index_name).status["ready"]:
raise RuntimeError("Gave up waiting for index creation")
except Exception as e:
logger.error(f"Pinecone index creation failed: {e}")
raise e
self.last_index_name = index_name
async def query_document_embeddings(self, msg):
try:
@ -62,9 +95,11 @@ class Processor(DocumentEmbeddingsQueryService):
dim = len(vec)
index_name = (
"d-" + msg.user + "-" + msg.collection + "-" + str(dim)
"d-" + msg.user + "-" + msg.collection
)
self.ensure_index_exists(index_name, dim)
index = self.pinecone.Index(index_name)
results = index.query(

View file

@ -38,6 +38,24 @@ class Processor(DocumentEmbeddingsQueryService):
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.last_collection = None
def ensure_collection_exists(self, collection, dim):
"""Ensure collection exists, create if it doesn't"""
if collection != self.last_collection:
if not self.qdrant.collection_exists(collection):
try:
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
)
logger.info(f"Created collection: {collection}")
except Exception as e:
logger.error(f"Qdrant collection creation failed: {e}")
raise e
self.last_collection = collection
async def query_document_embeddings(self, msg):
@ -49,10 +67,11 @@ class Processor(DocumentEmbeddingsQueryService):
dim = len(vec)
collection = (
"d_" + msg.user + "_" + msg.collection + "_" +
str(dim)
"d_" + msg.user + "_" + msg.collection
)
self.ensure_collection_exists(collection, dim)
search_result = self.qdrant.query_points(
collection_name=collection,
query=vec,

View file

@ -50,7 +50,12 @@ class Processor(GraphEmbeddingsQueryService):
for vec in msg.vectors:
resp = self.vecstore.search(vec, limit=msg.limit * 2)
resp = self.vecstore.search(
vec,
msg.user,
msg.collection,
limit=msg.limit * 2
)
for r in resp:
ent = r["entity"]["entity"]

View file

@ -49,6 +49,39 @@ class Processor(GraphEmbeddingsQueryService):
}
)
self.last_index_name = None
def ensure_index_exists(self, index_name, dim):
"""Ensure index exists, create if it doesn't"""
if index_name != self.last_index_name:
if not self.pinecone.has_index(index_name):
try:
self.pinecone.create_index(
name=index_name,
dimension=dim,
metric="cosine",
spec=ServerlessSpec(
cloud="aws",
region="us-east-1",
)
)
logger.info(f"Created index: {index_name}")
# Wait for index to be ready
import time
for i in range(0, 1000):
if self.pinecone.describe_index(index_name).status["ready"]:
break
time.sleep(1)
if not self.pinecone.describe_index(index_name).status["ready"]:
raise RuntimeError("Gave up waiting for index creation")
except Exception as e:
logger.error(f"Pinecone index creation failed: {e}")
raise e
self.last_index_name = index_name
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
return Value(value=ent, is_uri=True)
@ -71,9 +104,11 @@ class Processor(GraphEmbeddingsQueryService):
dim = len(vec)
index_name = (
"t-" + msg.user + "-" + msg.collection + "-" + str(dim)
"t-" + msg.user + "-" + msg.collection
)
self.ensure_index_exists(index_name, dim)
index = self.pinecone.Index(index_name)
# Heuristic hack, get (2*limit), so that we have more chance

View file

@ -38,6 +38,24 @@ class Processor(GraphEmbeddingsQueryService):
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.last_collection = None
def ensure_collection_exists(self, collection, dim):
"""Ensure collection exists, create if it doesn't"""
if collection != self.last_collection:
if not self.qdrant.collection_exists(collection):
try:
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
)
logger.info(f"Created collection: {collection}")
except Exception as e:
logger.error(f"Qdrant collection creation failed: {e}")
raise e
self.last_collection = collection
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
@ -56,10 +74,11 @@ class Processor(GraphEmbeddingsQueryService):
dim = len(vec)
collection = (
"t_" + msg.user + "_" + msg.collection + "_" +
str(dim)
"t_" + msg.user + "_" + msg.collection
)
self.ensure_collection_exists(collection, dim)
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) entities
search_result = self.qdrant.query_points(

View file

@ -0,0 +1,2 @@
from . service import *

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from . service import run
run()

View file

@ -0,0 +1,738 @@
"""
Objects query service using GraphQL. Input is a GraphQL query with variables.
Output is GraphQL response data with any errors.
"""
import json
import logging
import asyncio
from typing import Dict, Any, Optional, List, Set
from enum import Enum
from dataclasses import dataclass, field
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
import strawberry
from strawberry import Schema
from strawberry.types import Info
from strawberry.scalars import JSON
from strawberry.tools import create_type
from .... schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
from .... schema import Error, RowSchema, Field as SchemaField
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "objects-query"
# GraphQL filter input types
@strawberry.input
class IntFilter:
eq: Optional[int] = None
gt: Optional[int] = None
gte: Optional[int] = None
lt: Optional[int] = None
lte: Optional[int] = None
in_: Optional[List[int]] = strawberry.field(name="in", default=None)
not_: Optional[int] = strawberry.field(name="not", default=None)
not_in: Optional[List[int]] = None
@strawberry.input
class StringFilter:
eq: Optional[str] = None
contains: Optional[str] = None
startsWith: Optional[str] = None
endsWith: Optional[str] = None
in_: Optional[List[str]] = strawberry.field(name="in", default=None)
not_: Optional[str] = strawberry.field(name="not", default=None)
not_in: Optional[List[str]] = None
@strawberry.input
class FloatFilter:
eq: Optional[float] = None
gt: Optional[float] = None
gte: Optional[float] = None
lt: Optional[float] = None
lte: Optional[float] = None
in_: Optional[List[float]] = strawberry.field(name="in", default=None)
not_: Optional[float] = strawberry.field(name="not", default=None)
not_in: Optional[List[float]] = None
class Processor(FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
# Get Cassandra parameters
cassandra_host = params.get("cassandra_host")
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password
)
# Store resolved configuration with proper names
self.cassandra_host = hosts # Store as list
self.cassandra_username = username
self.cassandra_password = password
# Config key for schemas
self.config_key = params.get("config_type", "schema")
super(Processor, self).__init__(
**params | {
"id": id,
"config_type": self.config_key,
}
)
self.register_specification(
ConsumerSpec(
name = "request",
schema = ObjectsQueryRequest,
handler = self.on_message
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = ObjectsQueryResponse,
)
)
# Register config handler for schema updates
self.register_config_handler(self.on_schema_config)
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# GraphQL schema
self.graphql_schema: Optional[Schema] = None
# GraphQL types cache
self.graphql_types: Dict[str, type] = {}
# Cassandra session
self.cluster = None
self.session = None
# Known keyspaces and tables
self.known_keyspaces: Set[str] = set()
self.known_tables: Dict[str, Set[str]] = {}
def connect_cassandra(self):
"""Connect to Cassandra cluster"""
if self.session:
return
try:
if self.cassandra_username and self.cassandra_password:
auth_provider = PlainTextAuthProvider(
username=self.cassandra_username,
password=self.cassandra_password
)
self.cluster = Cluster(
contact_points=self.cassandra_host,
auth_provider=auth_provider
)
else:
self.cluster = Cluster(contact_points=self.cassandra_host)
self.session = self.cluster.connect()
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
except Exception as e:
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
raise
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Cassandra compatibility"""
import re
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
if safe_name and not safe_name[0].isalpha():
safe_name = 'o_' + safe_name
return safe_name.lower()
def sanitize_table(self, name: str) -> str:
"""Sanitize table names for Cassandra compatibility"""
import re
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
safe_name = 'o_' + safe_name
return safe_name.lower()
def parse_filter_key(self, filter_key: str) -> tuple[str, str]:
"""Parse GraphQL filter key into field name and operator"""
if not filter_key:
return ("", "eq")
# Support common GraphQL filter patterns:
# field_name -> (field_name, "eq")
# field_name_gt -> (field_name, "gt")
# field_name_gte -> (field_name, "gte")
# field_name_lt -> (field_name, "lt")
# field_name_lte -> (field_name, "lte")
# field_name_in -> (field_name, "in")
operators = ["_gte", "_lte", "_gt", "_lt", "_in", "_eq"]
for op_suffix in operators:
if filter_key.endswith(op_suffix):
field_name = filter_key[:-len(op_suffix)]
operator = op_suffix[1:] # Remove the leading underscore
return (field_name, operator)
# Default to equality if no operator suffix found
return (filter_key, "eq")
async def on_schema_config(self, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
# Clear existing schemas
self.schemas = {}
self.graphql_types = {}
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
field = SchemaField(
name=field_def["name"],
type=field_def["type"],
size=field_def.get("size", 0),
primary=field_def.get("primary_key", False),
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
# Regenerate GraphQL schema
self.generate_graphql_schema()
def get_python_type(self, field_type: str):
"""Convert schema field type to Python type for GraphQL"""
type_mapping = {
"string": str,
"integer": int,
"float": float,
"boolean": bool,
"timestamp": str, # Use string for timestamps in GraphQL
"date": str,
"time": str,
"uuid": str
}
return type_mapping.get(field_type, str)
def create_graphql_type(self, schema_name: str, row_schema: RowSchema) -> type:
"""Create a GraphQL type from a RowSchema"""
# Create annotations for the GraphQL type
annotations = {}
defaults = {}
for field in row_schema.fields:
python_type = self.get_python_type(field.type)
# Make field optional if not required
if not field.required and not field.primary:
annotations[field.name] = Optional[python_type]
defaults[field.name] = None
else:
annotations[field.name] = python_type
# Create the class dynamically
type_name = f"{schema_name.capitalize()}Type"
graphql_class = type(
type_name,
(),
{
"__annotations__": annotations,
**defaults
}
)
# Apply strawberry decorator
return strawberry.type(graphql_class)
def create_filter_type_for_schema(self, schema_name: str, row_schema: RowSchema):
"""Create a dynamic filter input type for a schema"""
# Create the filter type dynamically
filter_type_name = f"{schema_name.capitalize()}Filter"
# Add __annotations__ and defaults for the fields
annotations = {}
defaults = {}
logger.info(f"Creating filter type {filter_type_name} for schema {schema_name}")
for field in row_schema.fields:
logger.info(f"Field {field.name}: type={field.type}, indexed={field.indexed}, primary={field.primary}")
# Allow filtering on any field for now, not just indexed/primary
# if field.indexed or field.primary:
if field.type == "integer":
annotations[field.name] = Optional[IntFilter]
defaults[field.name] = None
logger.info(f"Added IntFilter for {field.name}")
elif field.type == "float":
annotations[field.name] = Optional[FloatFilter]
defaults[field.name] = None
logger.info(f"Added FloatFilter for {field.name}")
elif field.type == "string":
annotations[field.name] = Optional[StringFilter]
defaults[field.name] = None
logger.info(f"Added StringFilter for {field.name}")
logger.info(f"Filter type {filter_type_name} will have fields: {list(annotations.keys())}")
# Create the class dynamically
FilterType = type(
filter_type_name,
(),
{
"__annotations__": annotations,
**defaults
}
)
# Apply strawberry input decorator
FilterType = strawberry.input(FilterType)
return FilterType
def create_sort_direction_enum(self):
"""Create sort direction enum"""
@strawberry.enum
class SortDirection(Enum):
ASC = "asc"
DESC = "desc"
return SortDirection
def parse_idiomatic_where_clause(self, where_obj) -> Dict[str, Any]:
"""Parse the idiomatic nested filter structure"""
if not where_obj:
return {}
conditions = {}
logger.info(f"Parsing where clause: {where_obj}")
for field_name, filter_obj in where_obj.__dict__.items():
if filter_obj is None:
continue
logger.info(f"Processing field {field_name} with filter_obj: {filter_obj}")
if hasattr(filter_obj, '__dict__'):
# This is a filter object (StringFilter, IntFilter, etc.)
for operator, value in filter_obj.__dict__.items():
if value is not None:
logger.info(f"Found operator {operator} with value {value}")
# Map GraphQL operators to our internal format
if operator == "eq":
conditions[field_name] = value
elif operator in ["gt", "gte", "lt", "lte"]:
conditions[f"{field_name}_{operator}"] = value
elif operator == "in_":
conditions[f"{field_name}_in"] = value
elif operator == "contains":
conditions[f"{field_name}_contains"] = value
logger.info(f"Final parsed conditions: {conditions}")
return conditions
def generate_graphql_schema(self):
"""Generate GraphQL schema from loaded schemas using dynamic filter types"""
if not self.schemas:
logger.warning("No schemas loaded, cannot generate GraphQL schema")
self.graphql_schema = None
return
# Create GraphQL types and filter types for each schema
filter_types = {}
sort_direction_enum = self.create_sort_direction_enum()
for schema_name, row_schema in self.schemas.items():
graphql_type = self.create_graphql_type(schema_name, row_schema)
filter_type = self.create_filter_type_for_schema(schema_name, row_schema)
self.graphql_types[schema_name] = graphql_type
filter_types[schema_name] = filter_type
# Create the Query class with resolvers
query_dict = {'__annotations__': {}}
for schema_name, row_schema in self.schemas.items():
graphql_type = self.graphql_types[schema_name]
filter_type = filter_types[schema_name]
# Create resolver function for this schema
def make_resolver(s_name, r_schema, g_type, f_type, sort_enum):
async def resolver(
info: Info,
where: Optional[f_type] = None,
order_by: Optional[str] = None,
direction: Optional[sort_enum] = None,
limit: Optional[int] = 100
) -> List[g_type]:
# Get the processor instance from context
processor = info.context["processor"]
user = info.context["user"]
collection = info.context["collection"]
# Parse the idiomatic where clause
filters = processor.parse_idiomatic_where_clause(where)
# Query Cassandra
results = await processor.query_cassandra(
user, collection, s_name, r_schema,
filters, limit, order_by, direction
)
# Convert to GraphQL types
graphql_results = []
for row in results:
graphql_obj = g_type(**row)
graphql_results.append(graphql_obj)
return graphql_results
return resolver
# Add resolver to query
resolver_name = schema_name
resolver_func = make_resolver(schema_name, row_schema, graphql_type, filter_type, sort_direction_enum)
# Add field to query dictionary
query_dict[resolver_name] = strawberry.field(resolver=resolver_func)
query_dict['__annotations__'][resolver_name] = List[graphql_type]
# Create the Query class
Query = type('Query', (), query_dict)
Query = strawberry.type(Query)
# Create the schema with auto_camel_case disabled to keep snake_case field names
self.graphql_schema = strawberry.Schema(
query=Query,
config=strawberry.schema.config.StrawberryConfig(auto_camel_case=False)
)
logger.info(f"Generated GraphQL schema with {len(self.schemas)} types")
async def query_cassandra(
self,
user: str,
collection: str,
schema_name: str,
row_schema: RowSchema,
filters: Dict[str, Any],
limit: int,
order_by: Optional[str] = None,
direction: Optional[Any] = None
) -> List[Dict[str, Any]]:
"""Execute a query against Cassandra"""
# Connect if needed
self.connect_cassandra()
# Build the query
keyspace = self.sanitize_name(user)
table = self.sanitize_table(schema_name)
# Start with basic SELECT
query = f"SELECT * FROM {keyspace}.{table}"
# Add WHERE clauses
where_clauses = [f"collection = %s"]
params = [collection]
# Add filters for indexed or primary key fields
for filter_key, value in filters.items():
if value is not None:
# Parse field name and operator from filter key
logger.debug(f"Parsing filter key: '{filter_key}' (type: {type(filter_key)})")
result = self.parse_filter_key(filter_key)
logger.debug(f"parse_filter_key returned: {result} (type: {type(result)}, len: {len(result) if hasattr(result, '__len__') else 'N/A'})")
if not result or len(result) != 2:
logger.error(f"parse_filter_key returned invalid result: {result}")
continue # Skip this filter
field_name, operator = result
# Find the field in schema
schema_field = None
for f in row_schema.fields:
if f.name == field_name:
schema_field = f
break
if schema_field:
safe_field = self.sanitize_name(field_name)
# Build WHERE clause based on operator
if operator == "eq":
where_clauses.append(f"{safe_field} = %s")
params.append(value)
elif operator == "gt":
where_clauses.append(f"{safe_field} > %s")
params.append(value)
elif operator == "gte":
where_clauses.append(f"{safe_field} >= %s")
params.append(value)
elif operator == "lt":
where_clauses.append(f"{safe_field} < %s")
params.append(value)
elif operator == "lte":
where_clauses.append(f"{safe_field} <= %s")
params.append(value)
elif operator == "in":
if isinstance(value, list):
placeholders = ",".join(["%s"] * len(value))
where_clauses.append(f"{safe_field} IN ({placeholders})")
params.extend(value)
else:
# Default to equality for unknown operators
where_clauses.append(f"{safe_field} = %s")
params.append(value)
if where_clauses:
query += " WHERE " + " AND ".join(where_clauses)
# Add ORDER BY if requested (will try Cassandra first, then fall back to post-query sort)
cassandra_order_by_added = False
if order_by and direction:
# Validate that order_by field exists in schema
order_field_exists = any(f.name == order_by for f in row_schema.fields)
if order_field_exists:
safe_order_field = self.sanitize_name(order_by)
direction_str = "ASC" if direction.value == "asc" else "DESC"
# Add ORDER BY - if Cassandra rejects it, we'll catch the error during execution
query += f" ORDER BY {safe_order_field} {direction_str}"
# Add limit first (must come before ALLOW FILTERING)
if limit:
query += f" LIMIT {limit}"
# Add ALLOW FILTERING for now (should optimize with proper indexes later)
query += " ALLOW FILTERING"
# Execute query
try:
result = self.session.execute(query, params)
cassandra_order_by_added = True # If we get here, Cassandra handled ORDER BY
except Exception as e:
# If ORDER BY fails, try without it
if order_by and direction and "ORDER BY" in query:
logger.info(f"Cassandra rejected ORDER BY, falling back to post-query sorting: {e}")
# Remove ORDER BY clause and retry
query_parts = query.split(" ORDER BY ")
if len(query_parts) == 2:
query_without_order = query_parts[0] + " LIMIT " + str(limit) + " ALLOW FILTERING" if limit else " ALLOW FILTERING"
result = self.session.execute(query_without_order, params)
cassandra_order_by_added = False
else:
raise
else:
raise
# Convert rows to dicts
results = []
for row in result:
row_dict = {}
for field in row_schema.fields:
safe_field = self.sanitize_name(field.name)
if hasattr(row, safe_field):
value = getattr(row, safe_field)
# Use original field name in result
row_dict[field.name] = value
results.append(row_dict)
# Post-query sorting if Cassandra didn't handle ORDER BY
if order_by and direction and not cassandra_order_by_added:
reverse_order = (direction.value == "desc")
try:
results.sort(key=lambda x: x.get(order_by, 0), reverse=reverse_order)
except Exception as e:
logger.warning(f"Failed to sort results by {order_by}: {e}")
return results
async def execute_graphql_query(
self,
query: str,
variables: Dict[str, Any],
operation_name: Optional[str],
user: str,
collection: str
) -> Dict[str, Any]:
"""Execute a GraphQL query"""
if not self.graphql_schema:
raise RuntimeError("No GraphQL schema available - no schemas loaded")
# Create context for the query
context = {
"processor": self,
"user": user,
"collection": collection
}
# Execute the query
result = await self.graphql_schema.execute(
query,
variable_values=variables,
operation_name=operation_name,
context_value=context
)
# Build response
response = {}
if result.data:
response["data"] = result.data
else:
response["data"] = None
if result.errors:
response["errors"] = [
{
"message": str(error),
"path": getattr(error, "path", []),
"extensions": getattr(error, "extensions", {})
}
for error in result.errors
]
else:
response["errors"] = []
# Add extensions if any
if hasattr(result, "extensions") and result.extensions:
response["extensions"] = result.extensions
return response
async def on_message(self, msg, consumer, flow):
"""Handle incoming query request"""
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
logger.debug(f"Handling objects query request {id}...")
# Execute GraphQL query
result = await self.execute_graphql_query(
query=request.query,
variables=dict(request.variables) if request.variables else {},
operation_name=request.operation_name,
user=request.user,
collection=request.collection
)
# Create response
graphql_errors = []
if "errors" in result and result["errors"]:
for err in result["errors"]:
graphql_error = GraphQLError(
message=err.get("message", ""),
path=err.get("path", []),
extensions=err.get("extensions", {})
)
graphql_errors.append(graphql_error)
response = ObjectsQueryResponse(
error=None,
data=json.dumps(result.get("data")) if result.get("data") else "null",
errors=graphql_errors,
extensions=result.get("extensions", {})
)
logger.debug("Sending objects query response...")
await flow("response").send(response, properties={"id": id})
logger.debug("Objects query request completed")
except Exception as e:
logger.error(f"Exception in objects query service: {e}", exc_info=True)
logger.info("Sending error response...")
response = ObjectsQueryResponse(
error = Error(
type = "objects-query-error",
message = str(e),
),
data = None,
errors = [],
extensions = {}
)
await flow("response").send(response, properties={"id": id})
def close(self):
"""Clean up Cassandra connections"""
if self.cluster:
self.cluster.shutdown()
logger.info("Closed Cassandra connection")
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
add_cassandra_args(parser)
parser.add_argument(
'--config-type',
default='schema',
help='Configuration type prefix for schemas (default: schema)'
)
def run():
"""Entry point for objects-query-graphql-cassandra command"""
Processor.launch(default_ident, __doc__)

View file

@ -6,36 +6,44 @@ null. Output is a list of triples.
import logging
from .... direct.cassandra import TrustGraph
from .... direct.cassandra_kg import KnowledgeGraph
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .... schema import Value, Triple
from .... base import TriplesQueryService
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "triples-query"
default_graph_host='localhost'
class Processor(TriplesQueryService):
def __init__(self, **params):
graph_host = params.get("graph_host", default_graph_host)
graph_username = params.get("graph_username", None)
graph_password = params.get("graph_password", None)
# Get Cassandra parameters
cassandra_host = params.get("cassandra_host")
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password
)
super(Processor, self).__init__(
**params | {
"graph_host": graph_host,
"graph_username": graph_username,
"cassandra_host": ','.join(hosts),
"cassandra_username": username,
}
)
self.graph_host = [graph_host]
self.username = graph_username
self.password = graph_password
self.cassandra_host = hosts
self.cassandra_username = username
self.cassandra_password = password
self.table = None
def create_value(self, ent):
@ -48,21 +56,21 @@ class Processor(TriplesQueryService):
try:
table = (query.user, query.collection)
user = query.user
if table != self.table:
if self.username and self.password:
self.tg = TrustGraph(
hosts=self.graph_host,
keyspace=query.user, table=query.collection,
username=self.username, password=self.password
if user != self.table:
if self.cassandra_username and self.cassandra_password:
self.tg = KnowledgeGraph(
hosts=self.cassandra_host,
keyspace=query.user,
username=self.cassandra_username, password=self.cassandra_password
)
else:
self.tg = TrustGraph(
hosts=self.graph_host,
keyspace=query.user, table=query.collection,
self.tg = KnowledgeGraph(
hosts=self.cassandra_host,
keyspace=query.user,
)
self.table = table
self.table = user
triples = []
@ -70,13 +78,13 @@ class Processor(TriplesQueryService):
if query.p is not None:
if query.o is not None:
resp = self.tg.get_spo(
query.s.value, query.p.value, query.o.value,
query.collection, query.s.value, query.p.value, query.o.value,
limit=query.limit
)
triples.append((query.s.value, query.p.value, query.o.value))
else:
resp = self.tg.get_sp(
query.s.value, query.p.value,
query.collection, query.s.value, query.p.value,
limit=query.limit
)
for t in resp:
@ -84,14 +92,14 @@ class Processor(TriplesQueryService):
else:
if query.o is not None:
resp = self.tg.get_os(
query.o.value, query.s.value,
query.collection, query.o.value, query.s.value,
limit=query.limit
)
for t in resp:
triples.append((query.s.value, t.p, query.o.value))
else:
resp = self.tg.get_s(
query.s.value,
query.collection, query.s.value,
limit=query.limit
)
for t in resp:
@ -100,14 +108,14 @@ class Processor(TriplesQueryService):
if query.p is not None:
if query.o is not None:
resp = self.tg.get_po(
query.p.value, query.o.value,
query.collection, query.p.value, query.o.value,
limit=query.limit
)
for t in resp:
triples.append((t.s, query.p.value, query.o.value))
else:
resp = self.tg.get_p(
query.p.value,
query.collection, query.p.value,
limit=query.limit
)
for t in resp:
@ -115,13 +123,14 @@ class Processor(TriplesQueryService):
else:
if query.o is not None:
resp = self.tg.get_o(
query.o.value,
query.collection, query.o.value,
limit=query.limit
)
for t in resp:
triples.append((t.s, t.p, query.o.value))
else:
resp = self.tg.get_all(
query.collection,
limit=query.limit
)
for t in resp:
@ -147,24 +156,7 @@ class Processor(TriplesQueryService):
def add_args(parser):
TriplesQueryService.add_args(parser)
parser.add_argument(
'-g', '--graph-host',
default="localhost",
help=f'Graph host (default: localhost)'
)
parser.add_argument(
'--graph-username',
default=None,
help=f'Cassandra username'
)
parser.add_argument(
'--graph-password',
default=None,
help=f'Cassandra password'
)
add_cassandra_args(parser)
def run():

View file

@ -55,6 +55,10 @@ class Processor(TriplesQueryService):
try:
# Extract user and collection, use defaults if not provided
user = query.user if query.user else "default"
collection = query.collection if query.collection else "default"
triples = []
if query.s is not None:
@ -64,10 +68,13 @@ class Processor(TriplesQueryService):
# SPO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal {value: $value}) "
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN $src as src "
"LIMIT " + str(query.limit),
src=query.s.value, rel=query.p.value, value=query.o.value,
user=user, collection=collection,
database_=self.db,
)
@ -75,10 +82,13 @@ class Processor(TriplesQueryService):
triples.append((query.s.value, query.p.value, query.o.value))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) "
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"RETURN $src as src "
"LIMIT " + str(query.limit),
src=query.s.value, rel=query.p.value, uri=query.o.value,
user=user, collection=collection,
database_=self.db,
)
@ -90,10 +100,13 @@ class Processor(TriplesQueryService):
# SP
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal) "
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN dest.value as dest "
"LIMIT " + str(query.limit),
src=query.s.value, rel=query.p.value,
user=user, collection=collection,
database_=self.db,
)
@ -102,10 +115,13 @@ class Processor(TriplesQueryService):
triples.append((query.s.value, query.p.value, data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) "
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN dest.uri as dest "
"LIMIT " + str(query.limit),
src=query.s.value, rel=query.p.value,
user=user, collection=collection,
database_=self.db,
)
@ -120,10 +136,13 @@ class Processor(TriplesQueryService):
# SO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal {value: $value}) "
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
src=query.s.value, value=query.o.value,
user=user, collection=collection,
database_=self.db,
)
@ -132,10 +151,13 @@ class Processor(TriplesQueryService):
triples.append((query.s.value, data["rel"], query.o.value))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) "
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
src=query.s.value, uri=query.o.value,
user=user, collection=collection,
database_=self.db,
)
@ -148,10 +170,13 @@ class Processor(TriplesQueryService):
# S
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal) "
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit),
src=query.s.value,
user=user, collection=collection,
database_=self.db,
)
@ -160,10 +185,13 @@ class Processor(TriplesQueryService):
triples.append((query.s.value, data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) "
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit),
src=query.s.value,
user=user, collection=collection,
database_=self.db,
)
@ -181,10 +209,13 @@ class Processor(TriplesQueryService):
# PO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal {value: $value}) "
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
uri=query.p.value, value=query.o.value,
user=user, collection=collection,
database_=self.db,
)
@ -193,10 +224,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], query.p.value, query.o.value))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $dest}) "
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Node {uri: $dest, user: $user, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
uri=query.p.value, dest=query.o.value,
user=user, collection=collection,
database_=self.db,
)
@ -209,10 +243,13 @@ class Processor(TriplesQueryService):
# P
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal) "
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest "
"LIMIT " + str(query.limit),
uri=query.p.value,
user=user, collection=collection,
database_=self.db,
)
@ -221,10 +258,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], query.p.value, data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) "
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN src.uri as src, dest.uri as dest "
"LIMIT " + str(query.limit),
uri=query.p.value,
user=user, collection=collection,
database_=self.db,
)
@ -239,10 +279,13 @@ class Processor(TriplesQueryService):
# O
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel]->(dest:Literal {value: $value}) "
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
value=query.o.value,
user=user, collection=collection,
database_=self.db,
)
@ -251,10 +294,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], data["rel"], query.o.value))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) "
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
uri=query.o.value,
user=user, collection=collection,
database_=self.db,
)
@ -267,9 +313,12 @@ class Processor(TriplesQueryService):
# *
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel]->(dest:Literal) "
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit),
user=user, collection=collection,
database_=self.db,
)
@ -278,9 +327,12 @@ class Processor(TriplesQueryService):
triples.append((data["src"], data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel]->(dest:Node) "
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit),
user=user, collection=collection,
database_=self.db,
)

View file

@ -55,6 +55,10 @@ class Processor(TriplesQueryService):
try:
# Extract user and collection, use defaults if not provided
user = query.user if query.user else "default"
collection = query.collection if query.collection else "default"
triples = []
if query.s is not None:
@ -64,9 +68,12 @@ class Processor(TriplesQueryService):
# SPO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal {value: $value}) "
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN $src as src",
src=query.s.value, rel=query.p.value, value=query.o.value,
user=user, collection=collection,
database_=self.db,
)
@ -74,9 +81,12 @@ class Processor(TriplesQueryService):
triples.append((query.s.value, query.p.value, query.o.value))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) "
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"RETURN $src as src",
src=query.s.value, rel=query.p.value, uri=query.o.value,
user=user, collection=collection,
database_=self.db,
)
@ -88,9 +98,12 @@ class Processor(TriplesQueryService):
# SP
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal) "
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN dest.value as dest",
src=query.s.value, rel=query.p.value,
user=user, collection=collection,
database_=self.db,
)
@ -99,9 +112,12 @@ class Processor(TriplesQueryService):
triples.append((query.s.value, query.p.value, data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) "
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN dest.uri as dest",
src=query.s.value, rel=query.p.value,
user=user, collection=collection,
database_=self.db,
)
@ -116,9 +132,12 @@ class Processor(TriplesQueryService):
# SO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal {value: $value}) "
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN rel.uri as rel",
src=query.s.value, value=query.o.value,
user=user, collection=collection,
database_=self.db,
)
@ -127,9 +146,12 @@ class Processor(TriplesQueryService):
triples.append((query.s.value, data["rel"], query.o.value))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) "
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"RETURN rel.uri as rel",
src=query.s.value, uri=query.o.value,
user=user, collection=collection,
database_=self.db,
)
@ -142,9 +164,12 @@ class Processor(TriplesQueryService):
# S
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal) "
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest",
src=query.s.value,
user=user, collection=collection,
database_=self.db,
)
@ -153,9 +178,12 @@ class Processor(TriplesQueryService):
triples.append((query.s.value, data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) "
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN rel.uri as rel, dest.uri as dest",
src=query.s.value,
user=user, collection=collection,
database_=self.db,
)
@ -173,9 +201,12 @@ class Processor(TriplesQueryService):
# PO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal {value: $value}) "
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN src.uri as src",
uri=query.p.value, value=query.o.value,
user=user, collection=collection,
database_=self.db,
)
@ -184,9 +215,12 @@ class Processor(TriplesQueryService):
triples.append((data["src"], query.p.value, query.o.value))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $dest}) "
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Node {uri: $dest, user: $user, collection: $collection}) "
"RETURN src.uri as src",
uri=query.p.value, dest=query.o.value,
user=user, collection=collection,
database_=self.db,
)
@ -199,9 +233,12 @@ class Processor(TriplesQueryService):
# P
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal) "
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest",
uri=query.p.value,
user=user, collection=collection,
database_=self.db,
)
@ -210,9 +247,12 @@ class Processor(TriplesQueryService):
triples.append((data["src"], query.p.value, data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) "
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN src.uri as src, dest.uri as dest",
uri=query.p.value,
user=user, collection=collection,
database_=self.db,
)
@ -227,9 +267,12 @@ class Processor(TriplesQueryService):
# O
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel]->(dest:Literal {value: $value}) "
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel",
value=query.o.value,
user=user, collection=collection,
database_=self.db,
)
@ -238,9 +281,12 @@ class Processor(TriplesQueryService):
triples.append((data["src"], data["rel"], query.o.value))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) "
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel",
uri=query.o.value,
user=user, collection=collection,
database_=self.db,
)
@ -253,8 +299,11 @@ class Processor(TriplesQueryService):
# *
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel]->(dest:Literal) "
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest",
user=user, collection=collection,
database_=self.db,
)
@ -263,8 +312,11 @@ class Processor(TriplesQueryService):
triples.append((data["src"], data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel]->(dest:Node) "
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest",
user=user, collection=collection,
database_=self.db,
)

View file

@ -92,7 +92,12 @@ class Processor(FlowProcessor):
else:
doc_limit = self.doc_limit
response = await self.rag.query(v.query, doc_limit=doc_limit)
response = await self.rag.query(
v.query,
user=v.user,
collection=v.collection,
doc_limit=doc_limit
)
await flow("response").send(
DocumentRagResponse(

View file

@ -0,0 +1 @@
from . service import *

View file

@ -0,0 +1,5 @@
#!/usr/bin/env python3
from . service import run
run()

View file

@ -0,0 +1,25 @@
You are a database schema selection expert. Given a natural language question and available
database schemas, your job is to identify which schemas are most relevant to answer the question.
## Available Schemas:
{% for schema in schemas %}
**{{ schema.name }}**: {{ schema.description }}
Fields:
{% for field in schema.fields %}
- {{ field.name }} ({{ field.type }}): {{ field.description }}
{% endfor %}
{% endfor %}
## Question:
{{ question }}
## Instructions:
1. Analyze the question to understand what data is being requested
2. Examine each schema to understand what data it contains
3. Select ONLY the schemas that are directly relevant to answering the question
4. Return your answer as a JSON array of schema names
## Response Format:
Return ONLY a JSON array of schema names, nothing else.
Example: ["customers", "orders", "products"]

View file

@ -0,0 +1,101 @@
You are a GraphQL query generation expert. Given a natural language question and relevant database
schemas, generate a precise GraphQL query to answer the question.
## Question:
{{ question }}
## Relevant Schemas:
{% for schema in schemas %}
**{{ schema.name }}**: {{ schema.description }}
Fields:
{% for field in schema.fields %}
- {{ field.name }} ({{ field.type }}){% if field.description %}: {{ field.description }}{% endif
%}{% if field.primary_key %} [PRIMARY KEY]{% endif %}{% if field.required %} [REQUIRED]{% endif
%}{% if field.indexed %} [INDEXED]{% endif %}{% if field.enum_values %} [OPTIONS: {{
field.enum_values|join(', ') }}]{% endif %}
{% endfor %}
{% endfor %}
## GraphQL Query Rules:
1. Use the schema names as GraphQL query fields (e.g., `customers`, `orders`)
2. Apply filters using the `where` parameter with nested filter objects
3. Available filter operators per field type:
- String fields: `eq`, `contains`, `startsWith`, `endsWith`, `in`, `not`, `not_in`
- Integer/Float fields: `eq`, `gt`, `gte`, `lt`, `lte`, `in`, `not`, `not_in`
4. Use `order_by` for sorting (field name as string)
5. Use `direction` for sort direction: `ASC` or `DESC`
6. Use `limit` to restrict number of results
7. Select specific fields in the query body
## Example GraphQL Queries:
**Question**: "Show me customers from California"
```graphql
query {
customers(where: {state: {eq: "California"}}, limit: 100) {
customer_id
name
email
state
}
}
Question: "Top 10 products by price"
query {
products(order_by: "price", direction: DESC, limit: 10) {
product_id
name
price
category
}
}
Question: "Recent orders over $100"
query {
orders(
where: {
total_amount: {gt: 100}
order_date: {gte: "2024-01-01"}
}
order_by: "order_date"
direction: DESC
limit: 50
) {
order_id
customer_id
total_amount
order_date
status
}
}
Instructions:
1. Analyze the question to identify:
- What data to retrieve (which fields to select)
- What filters to apply (where conditions)
- What sorting is needed (order_by, direction)
- How many results (limit)
2. Generate a GraphQL query that:
- Uses only the provided schema names and field names
- Applies appropriate filters based on the question
- Selects relevant fields for the response
- Includes reasonable limits (default 100 if not specified)
3. If variables are needed, include them in the response
Response Format:
Return a JSON object with:
- "query": the GraphQL query string
- "variables": object with any GraphQL variables (empty object if none)
- "confidence": float between 0.0-1.0 indicating confidence in the query
Example:
{
"query": "query { customers(where: {state: {eq: \"California\"}}, limit: 100) { customer_id name
email state } }",
"variables": {},
"confidence": 0.95
}

View file

@ -0,0 +1,315 @@
"""
NLP to Structured Query Service - converts natural language questions to GraphQL queries.
Two-phase approach: 1) Select relevant schemas, 2) Generate GraphQL query.
"""
import json
import logging
from typing import Dict, Any, Optional, List
from ...schema import QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse
from ...schema import PromptRequest
from ...schema import Error, RowSchema, Field as SchemaField
from ...base import FlowProcessor, ConsumerSpec, ProducerSpec, PromptClientSpec
# Module logger
logger = logging.getLogger(__name__)
default_ident = "nlp-query"
default_schema_selection_template = "schema-selection"
default_graphql_generation_template = "graphql-generation"
class Processor(FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
# Config key for schemas
self.config_key = params.get("config_type", "schema")
# Configurable prompt template names
self.schema_selection_template = params.get("schema_selection_template", default_schema_selection_template)
self.graphql_generation_template = params.get("graphql_generation_template", default_graphql_generation_template)
super(Processor, self).__init__(
**params | {
"id": id,
"config_type": self.config_key,
}
)
self.register_specification(
ConsumerSpec(
name = "request",
schema = QuestionToStructuredQueryRequest,
handler = self.on_message
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = QuestionToStructuredQueryResponse,
)
)
# Client spec for calling prompt service
self.register_specification(
PromptClientSpec(
request_name = "prompt-request",
response_name = "prompt-response",
)
)
# Register config handler for schema updates
self.register_config_handler(self.on_schema_config)
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
logger.info("NLP Query service initialized")
async def on_schema_config(self, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
# Clear existing schemas
self.schemas = {}
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
field = SchemaField(
name=field_def["name"],
type=field_def["type"],
size=field_def.get("size", 0),
primary=field_def.get("primary_key", False),
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
async def phase1_select_schemas(self, question: str, flow) -> List[str]:
"""Phase 1: Use prompt service to select relevant schemas for the question"""
logger.info("Starting Phase 1: Schema selection")
# Prepare schema information for the prompt
schema_info = []
for name, schema in self.schemas.items():
schema_desc = {
"name": name,
"description": schema.description,
"fields": [{"name": f.name, "type": f.type, "description": f.description}
for f in schema.fields]
}
schema_info.append(schema_desc)
# Create prompt variables
variables = {
"question": question,
"schemas": schema_info # Pass structured data directly
}
# Call prompt service for schema selection
# Convert variables to JSON-encoded terms
terms = {k: json.dumps(v) for k, v in variables.items()}
prompt_request = PromptRequest(
id=self.schema_selection_template,
terms=terms
)
try:
response = await flow("prompt-request").request(prompt_request)
if response.error is not None:
raise Exception(f"Prompt service error: {response.error}")
# Parse the response to get selected schema names
# Response could be in either text or object field
response_data = response.text if response.text else response.object
if response_data is None:
raise Exception("Prompt service returned empty response")
# Parse JSON array of schema names
selected_schemas = json.loads(response_data)
logger.info(f"Phase 1 selected schemas: {selected_schemas}")
return selected_schemas
except Exception as e:
logger.error(f"Phase 1 schema selection failed: {e}")
raise
async def phase2_generate_graphql(self, question: str, selected_schemas: List[str], flow) -> Dict[str, Any]:
"""Phase 2: Generate GraphQL query using selected schemas"""
logger.info(f"Starting Phase 2: GraphQL generation for schemas: {selected_schemas}")
# Get detailed schema information for selected schemas only
selected_schema_info = []
for schema_name in selected_schemas:
if schema_name in self.schemas:
schema = self.schemas[schema_name]
schema_desc = {
"name": schema_name,
"description": schema.description,
"fields": [
{
"name": f.name,
"type": f.type,
"description": f.description,
"required": f.required,
"primary_key": f.primary,
"indexed": f.indexed,
"enum_values": f.enum_values if f.enum_values else []
}
for f in schema.fields
]
}
selected_schema_info.append(schema_desc)
# Create prompt variables for GraphQL generation
variables = {
"question": question,
"schemas": selected_schema_info # Pass structured data directly
}
# Call prompt service for GraphQL generation
# Convert variables to JSON-encoded terms
terms = {k: json.dumps(v) for k, v in variables.items()}
prompt_request = PromptRequest(
id=self.graphql_generation_template,
terms=terms
)
try:
response = await flow("prompt-request").request(prompt_request)
if response.error is not None:
raise Exception(f"Prompt service error: {response.error}")
# Parse the response to get GraphQL query and variables
# Response could be in either text or object field
response_data = response.text if response.text else response.object
if response_data is None:
raise Exception("Prompt service returned empty response")
# Parse JSON with "query" and "variables" fields
result = json.loads(response_data)
logger.info(f"Phase 2 generated GraphQL: {result.get('query', '')[:100]}...")
return result
except Exception as e:
logger.error(f"Phase 2 GraphQL generation failed: {e}")
raise
async def on_message(self, msg, consumer, flow):
"""Handle incoming question to structured query request"""
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
logger.info(f"Handling NLP query request {id}: {request.question[:100]}...")
# Phase 1: Select relevant schemas
selected_schemas = await self.phase1_select_schemas(request.question, flow)
# Phase 2: Generate GraphQL query
graphql_result = await self.phase2_generate_graphql(request.question, selected_schemas, flow)
# Create response
response = QuestionToStructuredQueryResponse(
error=None,
graphql_query=graphql_result.get("query", ""),
variables=graphql_result.get("variables", {}),
detected_schemas=selected_schemas,
confidence=graphql_result.get("confidence", 0.8) # Default confidence
)
logger.info("Sending NLP query response...")
await flow("response").send(response, properties={"id": id})
logger.info("NLP query request completed")
except Exception as e:
logger.error(f"Exception in NLP query service: {e}", exc_info=True)
logger.info("Sending error response...")
response = QuestionToStructuredQueryResponse(
error = Error(
type = "nlp-query-error",
message = str(e),
),
graphql_query = "",
variables = {},
detected_schemas = [],
confidence = 0.0
)
await flow("response").send(response, properties={"id": id})
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
parser.add_argument(
'--config-type',
default='schema',
help='Configuration type prefix for schemas (default: schema)'
)
parser.add_argument(
'--schema-selection-template',
default=default_schema_selection_template,
help=f'Prompt template name for schema selection (default: {default_schema_selection_template})'
)
parser.add_argument(
'--graphql-generation-template',
default=default_graphql_generation_template,
help=f'Prompt template name for GraphQL generation (default: {default_graphql_generation_template})'
)
def run():
"""Entry point for nlp-query command"""
Processor.launch(default_ident, __doc__)

View file

@ -0,0 +1,2 @@
# Structured data diagnosis service
from .service import *

View file

@ -0,0 +1,494 @@
"""
Structured Data Diagnosis Service - analyzes structured data and generates descriptors.
Supports three operations: detect-type, generate-descriptor, and diagnose (combined).
"""
import json
import logging
from typing import Dict, Any, Optional
from ...schema import StructuredDataDiagnosisRequest, StructuredDataDiagnosisResponse
from ...schema import PromptRequest, Error, RowSchema, Field as SchemaField
from ...base import FlowProcessor, ConsumerSpec, ProducerSpec, PromptClientSpec
from .type_detector import detect_data_type, detect_csv_options
# Module logger
logger = logging.getLogger(__name__)
default_ident = "structured-diag"
default_csv_prompt = "diagnose-csv"
default_json_prompt = "diagnose-json"
default_xml_prompt = "diagnose-xml"
default_schema_selection_prompt = "schema-selection"
class Processor(FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
# Config key for schemas
self.config_key = params.get("config_type", "schema")
# Configurable prompt template names
self.csv_prompt = params.get("csv_prompt", default_csv_prompt)
self.json_prompt = params.get("json_prompt", default_json_prompt)
self.xml_prompt = params.get("xml_prompt", default_xml_prompt)
self.schema_selection_prompt = params.get("schema_selection_prompt", default_schema_selection_prompt)
super(Processor, self).__init__(
**params | {
"id": id,
"config_type": self.config_key,
}
)
self.register_specification(
ConsumerSpec(
name = "request",
schema = StructuredDataDiagnosisRequest,
handler = self.on_message
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = StructuredDataDiagnosisResponse,
)
)
# Client spec for calling prompt service
self.register_specification(
PromptClientSpec(
request_name = "prompt-request",
response_name = "prompt-response",
)
)
# Register config handler for schema updates
self.register_config_handler(self.on_schema_config)
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
logger.info("Structured Data Diagnosis service initialized")
async def on_schema_config(self, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
# Clear existing schemas
self.schemas = {}
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
field = SchemaField(
name=field_def["name"],
type=field_def["type"],
size=field_def.get("size", 0),
primary=field_def.get("primary_key", False),
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
async def on_message(self, msg, consumer, flow):
"""Handle incoming structured data diagnosis request"""
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
logger.info(f"Handling structured data diagnosis request {id}: operation={request.operation}")
if request.operation == "detect-type":
response = await self.detect_type_operation(request, flow)
elif request.operation == "generate-descriptor":
response = await self.generate_descriptor_operation(request, flow)
elif request.operation == "diagnose":
response = await self.diagnose_operation(request, flow)
elif request.operation == "schema-selection":
response = await self.schema_selection_operation(request, flow)
else:
error = Error(
type="InvalidOperation",
message=f"Unknown operation: {request.operation}. Supported: detect-type, generate-descriptor, diagnose, schema-selection"
)
response = StructuredDataDiagnosisResponse(
error=error,
operation=request.operation
)
# Send response
await flow("response").send(
response, properties={"id": id}
)
except Exception as e:
logger.error(f"Error processing diagnosis request: {e}", exc_info=True)
error = Error(
type="ProcessingError",
message=f"Failed to process diagnosis request: {str(e)}"
)
response = StructuredDataDiagnosisResponse(
error=error,
operation=request.operation if request else "unknown"
)
await flow("response").send(
response, properties={"id": id}
)
async def detect_type_operation(self, request: StructuredDataDiagnosisRequest, flow) -> StructuredDataDiagnosisResponse:
"""Handle detect-type operation"""
logger.info("Processing detect-type operation")
detected_type, confidence = detect_data_type(request.sample)
metadata = {}
if detected_type == "csv":
csv_options = detect_csv_options(request.sample)
metadata["csv_options"] = json.dumps(csv_options)
return StructuredDataDiagnosisResponse(
error=None,
operation=request.operation,
detected_type=detected_type or "",
confidence=confidence,
metadata=metadata
)
async def generate_descriptor_operation(self, request: StructuredDataDiagnosisRequest, flow) -> StructuredDataDiagnosisResponse:
"""Handle generate-descriptor operation"""
logger.info(f"Processing generate-descriptor operation for type: {request.type}")
if not request.type:
error = Error(
type="MissingParameter",
message="Type parameter is required for generate-descriptor operation"
)
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
if not request.schema_name:
error = Error(
type="MissingParameter",
message="Schema name parameter is required for generate-descriptor operation"
)
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
# Get target schema
if request.schema_name not in self.schemas:
error = Error(
type="SchemaNotFound",
message=f"Schema '{request.schema_name}' not found in configuration"
)
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
target_schema = self.schemas[request.schema_name]
# Generate descriptor using prompt service
descriptor = await self.generate_descriptor_with_prompt(
request.sample, request.type, target_schema, request.options, flow
)
if descriptor is None:
error = Error(
type="DescriptorGenerationFailed",
message="Failed to generate descriptor using prompt service"
)
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
return StructuredDataDiagnosisResponse(
error=None,
operation=request.operation,
descriptor=json.dumps(descriptor),
metadata={"schema_name": request.schema_name, "type": request.type}
)
async def diagnose_operation(self, request: StructuredDataDiagnosisRequest, flow) -> StructuredDataDiagnosisResponse:
"""Handle combined diagnose operation"""
logger.info("Processing combined diagnose operation")
# Step 1: Detect type
detected_type, confidence = detect_data_type(request.sample)
if not detected_type:
error = Error(
type="TypeDetectionFailed",
message="Unable to detect data type from sample"
)
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
# Step 2: Use provided schema name or auto-select first available
schema_name = request.schema_name
if not schema_name and self.schemas:
schema_name = list(self.schemas.keys())[0]
logger.info(f"Auto-selected schema: {schema_name}")
if not schema_name:
error = Error(
type="NoSchemaAvailable",
message="No schema specified and no schemas available in configuration"
)
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
if schema_name not in self.schemas:
error = Error(
type="SchemaNotFound",
message=f"Schema '{schema_name}' not found in configuration"
)
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
target_schema = self.schemas[schema_name]
# Step 3: Generate descriptor
descriptor = await self.generate_descriptor_with_prompt(
request.sample, detected_type, target_schema, request.options, flow
)
if descriptor is None:
error = Error(
type="DescriptorGenerationFailed",
message="Failed to generate descriptor using prompt service"
)
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
metadata = {
"schema_name": schema_name,
"auto_selected_schema": request.schema_name != schema_name
}
if detected_type == "csv":
csv_options = detect_csv_options(request.sample)
metadata["csv_options"] = json.dumps(csv_options)
return StructuredDataDiagnosisResponse(
error=None,
operation=request.operation,
detected_type=detected_type,
confidence=confidence,
descriptor=json.dumps(descriptor),
metadata=metadata
)
async def schema_selection_operation(self, request: StructuredDataDiagnosisRequest, flow) -> StructuredDataDiagnosisResponse:
"""Handle schema-selection operation"""
logger.info("Processing schema-selection operation")
# Prepare all schemas for the prompt - match the original config format
all_schemas = []
for schema_name, row_schema in self.schemas.items():
schema_info = {
"name": row_schema.name,
"description": row_schema.description,
"fields": [
{
"name": f.name,
"type": f.type,
"description": f.description,
"required": f.required,
"primary_key": f.primary,
"indexed": f.indexed,
"enum": f.enum_values if f.enum_values else [],
"size": f.size if hasattr(f, 'size') else 0
}
for f in row_schema.fields
]
}
all_schemas.append(schema_info)
# Create prompt variables - schemas array contains ALL schemas
# Note: The prompt expects 'question' not 'sample'
variables = {
"question": request.sample, # The prompt template expects 'question'
"schemas": all_schemas,
"options": request.options or {}
}
# Call prompt service with configurable template
terms = {k: json.dumps(v) for k, v in variables.items()}
prompt_request = PromptRequest(
id=self.schema_selection_prompt,
terms=terms
)
try:
logger.info(f"Calling prompt service for schema selection with template: {self.schema_selection_prompt}")
response = await flow("prompt-request").request(prompt_request)
if response.error:
logger.error(f"Prompt service error: {response.error.message}")
error = Error(
type="PromptServiceError",
message="Failed to select schemas using prompt service"
)
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
# Check both text and object fields for response
response_data = None
if response.object and response.object.strip():
response_data = response.object.strip()
logger.debug(f"Using response from 'object' field: {response_data}")
elif response.text and response.text.strip():
response_data = response.text.strip()
logger.debug(f"Using response from 'text' field: {response_data}")
else:
logger.error("Empty response from prompt service (checked both text and object fields)")
error = Error(
type="PromptServiceError",
message="Empty response from prompt service"
)
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
# Parse the response as JSON array of schema IDs
try:
schema_matches = json.loads(response_data)
if not isinstance(schema_matches, list):
raise ValueError("Response must be an array")
except (json.JSONDecodeError, ValueError) as e:
logger.error(f"Failed to parse schema matches response: {e}")
error = Error(
type="ParseError",
message="Failed to parse schema selection response as JSON array"
)
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
return StructuredDataDiagnosisResponse(
error=None,
operation=request.operation,
schema_matches=schema_matches
)
except Exception as e:
logger.error(f"Error calling prompt service: {e}", exc_info=True)
error = Error(
type="PromptServiceError",
message="Failed to select schemas using prompt service"
)
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
async def generate_descriptor_with_prompt(
self, sample: str, data_type: str, target_schema: RowSchema,
options: Dict[str, str], flow
) -> Optional[Dict[str, Any]]:
"""Generate descriptor using appropriate prompt service"""
# Select prompt template based on data type
prompt_templates = {
"csv": self.csv_prompt,
"json": self.json_prompt,
"xml": self.xml_prompt
}
prompt_id = prompt_templates.get(data_type)
if not prompt_id:
logger.error(f"No prompt template defined for data type: {data_type}")
return None
# Prepare schema information for prompt
schema_info = {
"name": target_schema.name,
"description": target_schema.description,
"fields": [
{
"name": f.name,
"type": f.type,
"description": f.description,
"required": f.required,
"primary_key": f.primary,
"indexed": f.indexed,
"enum_values": f.enum_values if f.enum_values else []
}
for f in target_schema.fields
]
}
# Create prompt variables
variables = {
"sample": sample,
"schemas": [schema_info], # Array with single target schema
"options": options or {}
}
# Call prompt service
terms = {k: json.dumps(v) for k, v in variables.items()}
prompt_request = PromptRequest(
id=prompt_id,
terms=terms
)
try:
logger.info(f"Calling prompt service with template: {prompt_id}")
response = await flow("prompt-request").request(prompt_request)
if response.error:
logger.error(f"Prompt service error: {response.error.message}")
return None
# Parse response
if response.object:
try:
return json.loads(response.object)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse prompt response as JSON: {e}")
logger.debug(f"Response object: {response.object}")
return None
elif response.text:
try:
return json.loads(response.text)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse prompt text response as JSON: {e}")
logger.debug(f"Response text: {response.text}")
return None
else:
logger.error("Empty response from prompt service")
return None
except Exception as e:
logger.error(f"Error calling prompt service: {e}", exc_info=True)
return None
def run():
"""Entry point for structured-diag command"""
Processor.launch(default_ident, __doc__)

View file

@ -0,0 +1,208 @@
"""
Algorithmic data type detection for structured data.
Determines if data is CSV, JSON, or XML based on content analysis.
"""
import json
import xml.etree.ElementTree as ET
import csv
from io import StringIO
import logging
from typing import Dict, Optional, Tuple
# Module logger
logger = logging.getLogger(__name__)
def detect_data_type(sample: str) -> Tuple[Optional[str], float]:
"""
Detect the data type (csv, json, xml) of a data sample.
Args:
sample: String containing data sample to analyze
Returns:
Tuple of (detected_type, confidence_score)
detected_type: "csv", "json", "xml", or None if unable to determine
confidence_score: Float between 0.0 and 1.0 indicating confidence
"""
if not sample or not sample.strip():
return None, 0.0
sample = sample.strip()
# Simple pattern matching
if sample.startswith('<?xml') or (sample.startswith('<') and '</' in sample):
return 'xml', 0.9
elif sample.startswith(('{', '[')):
return 'json', 0.9
else:
return 'csv', 0.8
def _check_json_format(sample: str) -> float:
"""Check if sample is valid JSON format"""
try:
# Must start with { or [
if not (sample.startswith('{') or sample.startswith('[')):
return 0.0
# Try to parse as JSON
data = json.loads(sample)
# Higher confidence for structured data
if isinstance(data, dict):
return 0.95
elif isinstance(data, list) and len(data) > 0:
# Check if it's an array of objects (common for structured data)
if isinstance(data[0], dict):
return 0.9
else:
return 0.7
else:
return 0.6
except (json.JSONDecodeError, ValueError):
return 0.0
def _check_xml_format(sample: str) -> float:
"""Check if sample is valid XML format"""
# XML declaration or starts with tag
if sample.startswith('<?xml') or sample.startswith('<'):
# Must have closing tags for valid XML
if '</' in sample and '>' in sample:
try:
# Quick parse test
ET.fromstring(sample)
return 0.9 # Valid XML
except ET.ParseError:
return 0.3 # Looks like XML but malformed
else:
return 0.1 # Incomplete XML
return 0.0 # Not XML
def _check_csv_format(sample: str) -> float:
"""Check if sample is valid CSV format"""
try:
lines = sample.strip().split('\n')
if len(lines) < 2:
return 0.0
# Try to parse as CSV with different delimiters
delimiters = [',', ';', '\t', '|']
best_score = 0.0
for delimiter in delimiters:
score = _check_csv_with_delimiter(sample, delimiter)
best_score = max(best_score, score)
return best_score
except Exception:
return 0.0
def _check_csv_with_delimiter(sample: str, delimiter: str) -> float:
"""Check CSV format with specific delimiter"""
try:
reader = csv.reader(StringIO(sample), delimiter=delimiter)
rows = list(reader)
if len(rows) < 2:
return 0.0
# Check consistency of column counts
first_row_cols = len(rows[0])
if first_row_cols < 2:
return 0.0
consistent_rows = 0
for row in rows[1:]:
if len(row) == first_row_cols:
consistent_rows += 1
consistency_ratio = consistent_rows / (len(rows) - 1) if len(rows) > 1 else 0
# Base score on consistency and structure
if consistency_ratio > 0.8:
# Higher score for more columns and rows
column_bonus = min(first_row_cols * 0.05, 0.2)
row_bonus = min(len(rows) * 0.01, 0.1)
return min(0.7 + column_bonus + row_bonus, 0.95)
elif consistency_ratio > 0.6:
return 0.5
else:
return 0.2
except Exception:
return 0.0
def detect_csv_options(sample: str) -> Dict[str, any]:
"""
Detect CSV-specific options like delimiter and header presence.
Args:
sample: CSV data sample
Returns:
Dict with detected options: delimiter, has_header, etc.
"""
options = {
"delimiter": ",",
"has_header": True,
"encoding": "utf-8"
}
try:
lines = sample.strip().split('\n')
if len(lines) < 2:
return options
# Detect delimiter
delimiters = [',', ';', '\t', '|']
best_delimiter = ","
best_score = 0
for delimiter in delimiters:
score = _check_csv_with_delimiter(sample, delimiter)
if score > best_score:
best_score = score
best_delimiter = delimiter
options["delimiter"] = best_delimiter
# Detect header (heuristic: first row has text, second row has more numbers/structured data)
reader = csv.reader(StringIO(sample), delimiter=best_delimiter)
rows = list(reader)
if len(rows) >= 2:
first_row = rows[0]
second_row = rows[1]
# Count numeric fields in each row
first_numeric = sum(1 for cell in first_row if _is_numeric(cell))
second_numeric = sum(1 for cell in second_row if _is_numeric(cell))
# If second row has more numeric values, first row is likely header
if second_numeric > first_numeric and first_numeric < len(first_row) * 0.7:
options["has_header"] = True
else:
options["has_header"] = False
except Exception as e:
logger.debug(f"Error detecting CSV options: {e}")
return options
def _is_numeric(value: str) -> bool:
"""Check if a string value represents a number"""
try:
float(value.strip())
return True
except (ValueError, AttributeError):
return False

View file

@ -0,0 +1 @@
from . service import *

View file

@ -0,0 +1,5 @@
#!/usr/bin/env python3
from . service import run
run()

View file

@ -0,0 +1,175 @@
"""
Structured Query Service - orchestrates natural language question processing.
Takes a question, converts it to GraphQL via nlp-query, executes via objects-query,
and returns the results.
"""
import json
import logging
from typing import Dict, Any, Optional
from ...schema import StructuredQueryRequest, StructuredQueryResponse
from ...schema import QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse
from ...schema import ObjectsQueryRequest, ObjectsQueryResponse
from ...schema import Error
from ...base import FlowProcessor, ConsumerSpec, ProducerSpec, RequestResponseSpec
# Module logger
logger = logging.getLogger(__name__)
default_ident = "structured-query"
class Processor(FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
super(Processor, self).__init__(
**params | {
"id": id,
}
)
self.register_specification(
ConsumerSpec(
name = "request",
schema = StructuredQueryRequest,
handler = self.on_message
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = StructuredQueryResponse,
)
)
# Client spec for calling NLP query service
self.register_specification(
RequestResponseSpec(
request_name = "nlp-query-request",
response_name = "nlp-query-response",
request_schema = QuestionToStructuredQueryRequest,
response_schema = QuestionToStructuredQueryResponse
)
)
# Client spec for calling objects query service
self.register_specification(
RequestResponseSpec(
request_name = "objects-query-request",
response_name = "objects-query-response",
request_schema = ObjectsQueryRequest,
response_schema = ObjectsQueryResponse
)
)
logger.info("Structured Query service initialized")
async def on_message(self, msg, consumer, flow):
"""Handle incoming structured query request"""
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
logger.info(f"Handling structured query request {id}: {request.question[:100]}...")
# Step 1: Convert question to GraphQL using NLP query service
logger.info("Step 1: Converting question to GraphQL")
nlp_request = QuestionToStructuredQueryRequest(
question=request.question,
max_results=100 # Default limit
)
nlp_response = await flow("nlp-query-request").request(nlp_request)
if nlp_response.error is not None:
raise Exception(f"NLP query service error: {nlp_response.error.message}")
if not nlp_response.graphql_query:
raise Exception("NLP query service returned empty GraphQL query")
logger.info(f"Generated GraphQL query: {nlp_response.graphql_query[:200]}...")
logger.info(f"Detected schemas: {nlp_response.detected_schemas}")
logger.info(f"Confidence: {nlp_response.confidence}")
# Step 2: Execute GraphQL query using objects query service
logger.info("Step 2: Executing GraphQL query")
# Convert variables to strings (GraphQL variables can be various types, but Pulsar schema expects strings)
variables_as_strings = {}
if nlp_response.variables:
for key, value in nlp_response.variables.items():
if isinstance(value, str):
variables_as_strings[key] = value
else:
variables_as_strings[key] = str(value)
# Use user/collection values from request
objects_request = ObjectsQueryRequest(
user=request.user,
collection=request.collection,
query=nlp_response.graphql_query,
variables=variables_as_strings,
operation_name=None
)
objects_response = await flow("objects-query-request").request(objects_request)
if objects_response.error is not None:
raise Exception(f"Objects query service error: {objects_response.error.message}")
# Handle GraphQL errors from the objects query service
graphql_errors = []
if objects_response.errors:
for gql_error in objects_response.errors:
graphql_errors.append(f"{gql_error.message} (path: {gql_error.path})")
logger.info("Step 3: Returning results")
# Create response
response = StructuredQueryResponse(
error=None,
data=objects_response.data or "null", # JSON string
errors=graphql_errors
)
logger.info("Sending structured query response...")
await flow("response").send(response, properties={"id": id})
logger.info("Structured query request completed")
except Exception as e:
logger.error(f"Exception in structured query service: {e}", exc_info=True)
logger.info("Sending error response...")
response = StructuredQueryResponse(
error = Error(
type = "structured-query-error",
message = str(e),
),
data = "null",
errors = []
)
await flow("response").send(response, properties={"id": id})
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
# No additional arguments needed for this orchestrator service
def run():
"""Entry point for structured-query command"""
Processor.launch(default_ident, __doc__)

View file

@ -3,8 +3,17 @@
Accepts entity/vector pairs and writes them to a Milvus store.
"""
import logging
from .... direct.milvus_doc_embeddings import DocVectors
from .... base import DocumentEmbeddingsStoreService
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
from .... schema import vector_storage_management_topic, storage_management_response_topic
# Module logger
logger = logging.getLogger(__name__)
default_ident = "de-write"
default_store_uri = 'http://localhost:19530'
@ -23,6 +32,34 @@ class Processor(DocumentEmbeddingsStoreService):
self.vecstore = DocVectors(store_uri)
# Set up metrics for storage management
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
# Set up consumer for storage management requests
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=vector_storage_management_topic,
subscriber=f"{self.id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Set up producer for storage management responses
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
async def store_document_embeddings(self, message):
for emb in message.chunks:
@ -33,7 +70,11 @@ class Processor(DocumentEmbeddingsStoreService):
if chunk == "": continue
for vec in emb.vectors:
self.vecstore.insert(vec, chunk)
self.vecstore.insert(
vec, chunk,
message.metadata.user,
message.metadata.collection
)
@staticmethod
def add_args(parser):
@ -46,6 +87,48 @@ class Processor(DocumentEmbeddingsStoreService):
help=f'Milvus store URI (default: {default_store_uri})'
)
async def on_storage_management(self, message):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Error processing storage management request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete the collection for document embeddings"""
try:
self.vecstore.delete_collection(message.user, message.collection)
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
raise
def run():
Processor.launch(default_ident, __doc__)

View file

@ -12,6 +12,10 @@ import os
import logging
from .... base import DocumentEmbeddingsStoreService
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
from .... schema import vector_storage_management_topic, storage_management_response_topic
# Module logger
logger = logging.getLogger(__name__)
@ -55,6 +59,34 @@ class Processor(DocumentEmbeddingsStoreService):
self.last_index_name = None
# Set up metrics for storage management
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
# Set up consumer for storage management requests
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=vector_storage_management_topic,
subscriber=f"{self.id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Set up producer for storage management responses
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
def create_index(self, index_name, dim):
self.pinecone.create_index(
@ -96,7 +128,7 @@ class Processor(DocumentEmbeddingsStoreService):
dim = len(vec)
index_name = (
"d-" + message.metadata.user + "-" + message.metadata.collection + "-" + str(dim)
"d-" + message.metadata.user + "-" + message.metadata.collection
)
if index_name != self.last_index_name:
@ -160,6 +192,54 @@ class Processor(DocumentEmbeddingsStoreService):
help=f'Pinecone region, (default: {default_region}'
)
async def on_storage_management(self, message):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Error processing storage management request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete the collection for document embeddings"""
try:
index_name = f"d-{message.user}-{message.collection}"
if self.pinecone.has_index(index_name):
self.pinecone.delete_index(index_name)
logger.info(f"Deleted Pinecone index: {index_name}")
else:
logger.info(f"Index {index_name} does not exist, nothing to delete")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
raise
def run():
Processor.launch(default_ident, __doc__)

View file

@ -10,6 +10,10 @@ import uuid
import logging
from .... base import DocumentEmbeddingsStoreService
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
from .... schema import vector_storage_management_topic, storage_management_response_topic
# Module logger
logger = logging.getLogger(__name__)
@ -36,6 +40,37 @@ class Processor(DocumentEmbeddingsStoreService):
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
# Set up storage management if base class attributes are available
# (they may not be in unit tests)
if hasattr(self, 'id') and hasattr(self, 'taskgroup') and hasattr(self, 'pulsar_client'):
# Set up metrics for storage management
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
# Set up consumer for storage management requests
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=vector_storage_management_topic,
subscriber=f"{self.id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Set up producer for storage management responses
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
async def store_document_embeddings(self, message):
for emb in message.chunks:
@ -48,8 +83,7 @@ class Processor(DocumentEmbeddingsStoreService):
dim = len(vec)
collection = (
"d_" + message.metadata.user + "_" +
message.metadata.collection + "_" +
str(dim)
message.metadata.collection
)
if collection != self.last_collection:
@ -99,6 +133,54 @@ class Processor(DocumentEmbeddingsStoreService):
help=f'Qdrant API key (default: None)'
)
async def on_storage_management(self, message):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Error processing storage management request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete the collection for document embeddings"""
try:
collection_name = f"d_{message.user}_{message.collection}"
if self.qdrant.collection_exists(collection_name):
self.qdrant.delete_collection(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
else:
logger.info(f"Collection {collection_name} does not exist, nothing to delete")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
raise
def run():
Processor.launch(default_ident, __doc__)

View file

@ -3,8 +3,17 @@
Accepts entity/vector pairs and writes them to a Milvus store.
"""
import logging
from .... direct.milvus_graph_embeddings import EntityVectors
from .... base import GraphEmbeddingsStoreService
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
from .... schema import vector_storage_management_topic, storage_management_response_topic
# Module logger
logger = logging.getLogger(__name__)
default_ident = "ge-write"
default_store_uri = 'http://localhost:19530'
@ -23,13 +32,45 @@ class Processor(GraphEmbeddingsStoreService):
self.vecstore = EntityVectors(store_uri)
# Set up metrics for storage management
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
# Set up consumer for storage management requests
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=vector_storage_management_topic,
subscriber=f"{self.id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Set up producer for storage management responses
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
async def store_graph_embeddings(self, message):
for entity in message.entities:
if entity.entity.value != "" and entity.entity.value is not None:
for vec in entity.vectors:
self.vecstore.insert(vec, entity.entity.value)
self.vecstore.insert(
vec, entity.entity.value,
message.metadata.user,
message.metadata.collection
)
@staticmethod
def add_args(parser):
@ -42,6 +83,48 @@ class Processor(GraphEmbeddingsStoreService):
help=f'Milvus store URI (default: {default_store_uri})'
)
async def on_storage_management(self, message):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Error processing storage management request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete the collection for graph embeddings"""
try:
self.vecstore.delete_collection(message.user, message.collection)
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
raise
def run():
Processor.launch(default_ident, __doc__)

View file

@ -12,6 +12,10 @@ import os
import logging
from .... base import GraphEmbeddingsStoreService
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
from .... schema import vector_storage_management_topic, storage_management_response_topic
# Module logger
logger = logging.getLogger(__name__)
@ -55,6 +59,34 @@ class Processor(GraphEmbeddingsStoreService):
self.last_index_name = None
# Set up metrics for storage management
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
# Set up consumer for storage management requests
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=vector_storage_management_topic,
subscriber=f"{self.id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Set up producer for storage management responses
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
def create_index(self, index_name, dim):
self.pinecone.create_index(
@ -95,7 +127,7 @@ class Processor(GraphEmbeddingsStoreService):
dim = len(vec)
index_name = (
"t-" + message.metadata.user + "-" + message.metadata.collection + "-" + str(dim)
"t-" + message.metadata.user + "-" + message.metadata.collection
)
if index_name != self.last_index_name:
@ -159,6 +191,54 @@ class Processor(GraphEmbeddingsStoreService):
help=f'Pinecone region, (default: {default_region}'
)
async def on_storage_management(self, message):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Error processing storage management request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete the collection for graph embeddings"""
try:
index_name = f"t-{message.user}-{message.collection}"
if self.pinecone.has_index(index_name):
self.pinecone.delete_index(index_name)
logger.info(f"Deleted Pinecone index: {index_name}")
else:
logger.info(f"Index {index_name} does not exist, nothing to delete")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
raise
def run():
Processor.launch(default_ident, __doc__)

View file

@ -10,6 +10,10 @@ import uuid
import logging
from .... base import GraphEmbeddingsStoreService
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
from .... schema import vector_storage_management_topic, storage_management_response_topic
# Module logger
logger = logging.getLogger(__name__)
@ -36,10 +40,41 @@ class Processor(GraphEmbeddingsStoreService):
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
# Set up storage management if base class attributes are available
# (they may not be in unit tests)
if hasattr(self, 'id') and hasattr(self, 'taskgroup') and hasattr(self, 'pulsar_client'):
# Set up metrics for storage management
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
# Set up consumer for storage management requests
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=vector_storage_management_topic,
subscriber=f"{self.id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Set up producer for storage management responses
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
def get_collection(self, dim, user, collection):
cname = (
"t_" + user + "_" + collection + "_" + str(dim)
"t_" + user + "_" + collection
)
if cname != self.last_collection:
@ -105,6 +140,54 @@ class Processor(GraphEmbeddingsStoreService):
help=f'Qdrant API key'
)
async def on_storage_management(self, message):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Error processing storage management request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete the collection for graph embeddings"""
try:
collection_name = f"t_{message.user}_{message.collection}"
if self.qdrant.collection_exists(collection_name):
self.qdrant.delete_collection(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
else:
logger.info(f"Collection {collection_name} does not exist, nothing to delete")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
raise
def run():
Processor.launch(default_ident, __doc__)

View file

@ -8,12 +8,12 @@ import urllib.parse
from ... schema import Triples, GraphEmbeddings
from ... base import FlowProcessor, ConsumerSpec
from ... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from ... tables.knowledge import KnowledgeTableStore
default_ident = "kg-store"
default_cassandra_host = "cassandra"
keyspace = "knowledge"
class Processor(FlowProcessor):
@ -22,15 +22,18 @@ class Processor(FlowProcessor):
id = params.get("id")
cassandra_host = params.get("cassandra_host", default_cassandra_host)
cassandra_user = params.get("cassandra_user")
cassandra_password = params.get("cassandra_password")
# Use helper to resolve configuration
hosts, username, password = resolve_cassandra_config(
host=params.get("cassandra_host"),
username=params.get("cassandra_username"),
password=params.get("cassandra_password")
)
super(Processor, self).__init__(
**params | {
"id": id,
"cassandra_host": cassandra_host,
"cassandra_user": cassandra_user,
"cassandra_host": ','.join(hosts),
"cassandra_username": username,
}
)
@ -51,9 +54,9 @@ class Processor(FlowProcessor):
)
self.table_store = KnowledgeTableStore(
cassandra_host = cassandra_host.split(","),
cassandra_user = cassandra_user,
cassandra_password = cassandra_password,
cassandra_host = hosts,
cassandra_username = username,
cassandra_password = password,
keyspace = keyspace,
)
@ -71,6 +74,7 @@ class Processor(FlowProcessor):
def add_args(parser):
FlowProcessor.add_args(parser)
add_cassandra_args(parser)
def run():

View file

@ -1,3 +0,0 @@
from . write import *

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python3
from . write import run
if __name__ == '__main__':
run()

View file

@ -1,61 +0,0 @@
"""
Accepts entity/vector pairs and writes them to a Milvus store.
"""
from .... schema import ObjectEmbeddings
from .... schema import object_embeddings_store_queue
from .... log_level import LogLevel
from .... direct.milvus_object_embeddings import ObjectVectors
from .... base import Consumer
module = "oe-write"
default_input_queue = object_embeddings_store_queue
default_subscriber = module
default_store_uri = 'http://localhost:19530'
class Processor(Consumer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
store_uri = params.get("store_uri", default_store_uri)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": ObjectEmbeddings,
"store_uri": store_uri,
}
)
self.vecstore = ObjectVectors(store_uri)
async def handle(self, msg):
v = msg.value()
if v.id != "" and v.id is not None:
for vec in v.vectors:
self.vecstore.insert(vec, v.name, v.key_name, v.id)
@staticmethod
def add_args(parser):
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Milvus store URI (default: {default_store_uri})'
)
def run():
Processor.launch(module, __doc__)

View file

@ -13,13 +13,15 @@ from cassandra import ConsistencyLevel
from .... schema import ExtractedObject
from .... schema import RowSchema, Field
from .... base import FlowProcessor, ConsumerSpec
from .... schema import StorageManagementRequest, StorageManagementResponse
from .... schema import object_storage_management_topic, storage_management_response_topic
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "objects-write"
default_graph_host = 'localhost'
class Processor(FlowProcessor):
@ -27,10 +29,22 @@ class Processor(FlowProcessor):
id = params.get("id", default_ident)
# Cassandra connection parameters
self.graph_host = params.get("graph_host", default_graph_host)
self.graph_username = params.get("graph_username", None)
self.graph_password = params.get("graph_password", None)
# Get Cassandra parameters
cassandra_host = params.get("cassandra_host")
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password
)
# Store resolved configuration with proper names
self.cassandra_host = hosts # Store as list
self.cassandra_username = username
self.cassandra_password = password
# Config key for schemas
self.config_key = params.get("config_type", "schema")
@ -38,7 +52,7 @@ class Processor(FlowProcessor):
super(Processor, self).__init__(
**params | {
"id": id,
"config-type": self.config_key,
"config_type": self.config_key,
}
)
@ -49,7 +63,38 @@ class Processor(FlowProcessor):
handler = self.on_object
)
)
# Set up storage management consumer and producer directly
# (FlowProcessor doesn't support topic-based specs outside of flows)
from .... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
# Create storage management consumer
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=object_storage_management_topic,
subscriber=f"{id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Create storage management response producer
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
# Register config handler for schema updates
self.register_config_handler(self.on_schema_config)
@ -70,20 +115,20 @@ class Processor(FlowProcessor):
return
try:
if self.graph_username and self.graph_password:
if self.cassandra_username and self.cassandra_password:
auth_provider = PlainTextAuthProvider(
username=self.graph_username,
password=self.graph_password
username=self.cassandra_username,
password=self.cassandra_password
)
self.cluster = Cluster(
contact_points=[self.graph_host],
contact_points=self.cassandra_host,
auth_provider=auth_provider
)
else:
self.cluster = Cluster(contact_points=[self.graph_host])
self.cluster = Cluster(contact_points=self.cassandra_host)
self.session = self.cluster.connect()
logger.info(f"Connected to Cassandra cluster at {self.graph_host}")
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
except Exception as e:
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
@ -299,7 +344,7 @@ class Processor(FlowProcessor):
"""Process incoming ExtractedObject and store in Cassandra"""
obj = msg.value()
logger.info(f"Storing object for schema {obj.schema_name} from {obj.metadata.id}")
logger.info(f"Storing {len(obj.values)} objects for schema {obj.schema_name} from {obj.metadata.id}")
# Get schema definition
schema = self.schemas.get(obj.schema_name)
@ -316,59 +361,161 @@ class Processor(FlowProcessor):
safe_keyspace = self.sanitize_name(keyspace)
safe_table = self.sanitize_table(table_name)
# Build column names and values
columns = ["collection"]
values = [obj.metadata.collection]
placeholders = ["%s"]
# Check if we need a synthetic ID
has_primary_key = any(field.primary for field in schema.fields)
if not has_primary_key:
import uuid
columns.append("synthetic_id")
values.append(uuid.uuid4())
placeholders.append("%s")
# Process fields
for field in schema.fields:
safe_field_name = self.sanitize_name(field.name)
raw_value = obj.values.get(field.name)
# Process each object in the batch
for obj_index, value_map in enumerate(obj.values):
# Build column names and values for this object
columns = ["collection"]
values = [obj.metadata.collection]
placeholders = ["%s"]
# Handle required fields
if field.required and raw_value is None:
logger.warning(f"Required field {field.name} is missing in object")
# Continue anyway - Cassandra doesn't enforce NOT NULL
# Check if we need a synthetic ID
has_primary_key = any(field.primary for field in schema.fields)
if not has_primary_key:
import uuid
columns.append("synthetic_id")
values.append(uuid.uuid4())
placeholders.append("%s")
# Check if primary key field is NULL
if field.primary and raw_value is None:
logger.error(f"Primary key field {field.name} cannot be NULL - skipping object")
return
# Process fields for this object
skip_object = False
for field in schema.fields:
safe_field_name = self.sanitize_name(field.name)
raw_value = value_map.get(field.name)
# Handle required fields
if field.required and raw_value is None:
logger.warning(f"Required field {field.name} is missing in object {obj_index}")
# Continue anyway - Cassandra doesn't enforce NOT NULL
# Check if primary key field is NULL
if field.primary and raw_value is None:
logger.error(f"Primary key field {field.name} cannot be NULL - skipping object {obj_index}")
skip_object = True
break
# Convert value to appropriate type
converted_value = self.convert_value(raw_value, field.type)
columns.append(safe_field_name)
values.append(converted_value)
placeholders.append("%s")
# Convert value to appropriate type
converted_value = self.convert_value(raw_value, field.type)
# Skip this object if primary key validation failed
if skip_object:
continue
columns.append(safe_field_name)
values.append(converted_value)
placeholders.append("%s")
# Build and execute insert query
insert_cql = f"""
INSERT INTO {safe_keyspace}.{safe_table} ({', '.join(columns)})
VALUES ({', '.join(placeholders)})
"""
# Debug: Show data being inserted
logger.debug(f"Storing {obj.schema_name}: {dict(zip(columns, values))}")
if len(columns) != len(values) or len(columns) != len(placeholders):
raise ValueError(f"Mismatch in counts - columns: {len(columns)}, values: {len(values)}, placeholders: {len(placeholders)}")
# Build and execute insert query for this object
insert_cql = f"""
INSERT INTO {safe_keyspace}.{safe_table} ({', '.join(columns)})
VALUES ({', '.join(placeholders)})
"""
# Debug: Show data being inserted
logger.debug(f"Storing {obj.schema_name} object {obj_index}: {dict(zip(columns, values))}")
if len(columns) != len(values) or len(columns) != len(placeholders):
raise ValueError(f"Mismatch in counts - columns: {len(columns)}, values: {len(values)}, placeholders: {len(placeholders)}")
try:
# Convert to tuple - Cassandra driver requires tuple for parameters
self.session.execute(insert_cql, tuple(values))
except Exception as e:
logger.error(f"Failed to insert object {obj_index}: {e}", exc_info=True)
raise
async def on_storage_management(self, msg, consumer, flow):
"""Handle storage management requests for collection operations"""
logger.info(f"Received storage management request: {msg.operation} for {msg.user}/{msg.collection}")
try:
# Convert to tuple - Cassandra driver requires tuple for parameters
self.session.execute(insert_cql, tuple(values))
if msg.operation == "delete-collection":
await self.delete_collection(msg.user, msg.collection)
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {msg.user}/{msg.collection}")
else:
logger.warning(f"Unknown storage management operation: {msg.operation}")
# Send error response
from .... schema import Error
response = StorageManagementResponse(
error=Error(
type="unknown_operation",
message=f"Unknown operation: {msg.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to insert object: {e}", exc_info=True)
raise
logger.error(f"Error handling storage management request: {e}", exc_info=True)
# Send error response
from .... schema import Error
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.send("storage-response", response)
async def delete_collection(self, user: str, collection: str):
"""Delete all data for a specific collection"""
# Connect if not already connected
self.connect_cassandra()
# Sanitize names for safety
safe_keyspace = self.sanitize_name(user)
# Check if keyspace exists
if safe_keyspace not in self.known_keyspaces:
# Query to verify keyspace exists
check_keyspace_cql = """
SELECT keyspace_name FROM system_schema.keyspaces
WHERE keyspace_name = %s
"""
result = self.session.execute(check_keyspace_cql, (safe_keyspace,))
if not result.one():
logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
return
self.known_keyspaces.add(safe_keyspace)
# Get all tables in the keyspace that might contain collection data
get_tables_cql = """
SELECT table_name FROM system_schema.tables
WHERE keyspace_name = %s
"""
tables = self.session.execute(get_tables_cql, (safe_keyspace,))
tables_deleted = 0
for row in tables:
table_name = row.table_name
# Check if the table has a collection column
check_column_cql = """
SELECT column_name FROM system_schema.columns
WHERE keyspace_name = %s AND table_name = %s AND column_name = 'collection'
"""
result = self.session.execute(check_column_cql, (safe_keyspace, table_name))
if result.one():
# Table has collection column, delete data for this collection
try:
delete_cql = f"""
DELETE FROM {safe_keyspace}.{table_name}
WHERE collection = %s
"""
self.session.execute(delete_cql, (collection,))
tables_deleted += 1
logger.info(f"Deleted collection {collection} from table {safe_keyspace}.{table_name}")
except Exception as e:
logger.error(f"Failed to delete from table {safe_keyspace}.{table_name}: {e}")
raise
logger.info(f"Deleted collection {collection} from {tables_deleted} tables in keyspace {safe_keyspace}")
def close(self):
"""Clean up Cassandra connections"""
@ -381,24 +528,7 @@ class Processor(FlowProcessor):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
parser.add_argument(
'-g', '--graph-host',
default=default_graph_host,
help=f'Cassandra host (default: {default_graph_host})'
)
parser.add_argument(
'--graph-username',
default=None,
help='Cassandra username'
)
parser.add_argument(
'--graph-password',
default=None,
help='Cassandra password'
)
add_cassandra_args(parser)
parser.add_argument(
'--config-type',

View file

@ -3,6 +3,8 @@
Graph writer. Input is graph edge. Writes edges to Cassandra graph.
"""
raise RuntimeError("This code is no longer in use")
import pulsar
import base64
import os
@ -14,9 +16,9 @@ from cassandra.auth import PlainTextAuthProvider
from ssl import SSLContext, PROTOCOL_TLSv1_2
from .... schema import Rows
from .... schema import rows_store_queue
from .... log_level import LogLevel
from .... base import Consumer
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
# Module logger
logger = logging.getLogger(__name__)
@ -24,9 +26,8 @@ logger = logging.getLogger(__name__)
module = "rows-write"
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
default_input_queue = rows_store_queue
default_input_queue = "rows-store" # Default queue name
default_subscriber = module
default_graph_host='localhost'
class Processor(Consumer):
@ -34,26 +35,35 @@ class Processor(Consumer):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
graph_host = params.get("graph_host", default_graph_host)
graph_username = params.get("graph_username", None)
graph_password = params.get("graph_password", None)
# Get Cassandra parameters
cassandra_host = params.get("cassandra_host")
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password
)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": Rows,
"graph_host": graph_host,
"graph_username": graph_username,
"graph_password": graph_password,
"cassandra_host": ','.join(hosts),
"cassandra_username": username,
"cassandra_password": password,
}
)
if graph_username and graph_password:
auth_provider = PlainTextAuthProvider(username=graph_username, password=graph_password)
self.cluster = Cluster(graph_host.split(","), auth_provider=auth_provider, ssl_context=ssl_context)
if username and password:
auth_provider = PlainTextAuthProvider(username=username, password=password)
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
else:
self.cluster = Cluster(graph_host.split(","))
self.cluster = Cluster(hosts)
self.session = self.cluster.connect()
self.tables = set()
@ -128,24 +138,7 @@ class Processor(Consumer):
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
parser.add_argument(
'-g', '--graph-host',
default="localhost",
help=f'Graph host (default: localhost)'
)
parser.add_argument(
'--graph-username',
default=None,
help=f'Cassandra username'
)
parser.add_argument(
'--graph-password',
default=None,
help=f'Cassandra password'
)
add_cassandra_args(parser)
def run():

View file

@ -10,15 +10,19 @@ import argparse
import time
import logging
from .... direct.cassandra import TrustGraph
from .... direct.cassandra_kg import KnowledgeGraph
from .... base import TriplesStoreService
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
from .... schema import triples_storage_management_topic, storage_management_response_topic
# Module logger
logger = logging.getLogger(__name__)
default_ident = "triples-write"
default_graph_host='localhost'
class Processor(TriplesStoreService):
@ -26,80 +30,175 @@ class Processor(TriplesStoreService):
id = params.get("id", default_ident)
graph_host = params.get("graph_host", default_graph_host)
graph_username = params.get("graph_username", None)
graph_password = params.get("graph_password", None)
# Get Cassandra parameters
cassandra_host = params.get("cassandra_host")
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password
)
super(Processor, self).__init__(
**params | {
"graph_host": graph_host,
"graph_username": graph_username
"cassandra_host": ','.join(hosts),
"cassandra_username": username
}
)
self.graph_host = [graph_host]
self.username = graph_username
self.password = graph_password
self.cassandra_host = hosts
self.cassandra_username = username
self.cassandra_password = password
self.table = None
# Set up metrics for storage management
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
# Set up consumer for storage management requests
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=triples_storage_management_topic,
subscriber=f"{id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Set up producer for storage management responses
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
async def store_triples(self, message):
table = (message.metadata.user, message.metadata.collection)
user = message.metadata.user
if self.table is None or self.table != table:
if self.table is None or self.table != user:
self.tg = None
try:
if self.username and self.password:
self.tg = TrustGraph(
hosts=self.graph_host,
if self.cassandra_username and self.cassandra_password:
self.tg = KnowledgeGraph(
hosts=self.cassandra_host,
keyspace=message.metadata.user,
table=message.metadata.collection,
username=self.username, password=self.password
username=self.cassandra_username, password=self.cassandra_password
)
else:
self.tg = TrustGraph(
hosts=self.graph_host,
self.tg = KnowledgeGraph(
hosts=self.cassandra_host,
keyspace=message.metadata.user,
table=message.metadata.collection,
)
except Exception as e:
logger.error(f"Exception: {e}", exc_info=True)
time.sleep(1)
raise e
self.table = table
self.table = user
for t in message.triples:
self.tg.insert(
message.metadata.collection,
t.s.value,
t.p.value,
t.o.value
)
async def on_storage_management(self, message):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Error processing storage management request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete all data for a specific collection from the unified triples table"""
try:
# Create or reuse connection for this user's keyspace
if self.table is None or self.table != message.user:
self.tg = None
try:
if self.cassandra_username and self.cassandra_password:
self.tg = KnowledgeGraph(
hosts=self.cassandra_host,
keyspace=message.user,
username=self.cassandra_username,
password=self.cassandra_password
)
else:
self.tg = KnowledgeGraph(
hosts=self.cassandra_host,
keyspace=message.user,
)
except Exception as e:
logger.error(f"Failed to connect to Cassandra for user {message.user}: {e}")
raise
self.table = message.user
# Delete all triples for this collection from the unified table
# In the unified table schema, collection is the partition key
delete_cql = """
DELETE FROM triples
WHERE collection = ?
"""
try:
self.tg.session.execute(delete_cql, (message.collection,))
logger.info(f"Deleted all triples for collection {message.collection} from keyspace {message.user}")
except Exception as e:
logger.error(f"Failed to delete collection data: {e}")
raise
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
raise
@staticmethod
def add_args(parser):
TriplesStoreService.add_args(parser)
parser.add_argument(
'-g', '--graph-host',
default="localhost",
help=f'Graph host (default: localhost)'
)
parser.add_argument(
'--graph-username',
default=None,
help=f'Cassandra username'
)
parser.add_argument(
'--graph-password',
default=None,
help=f'Cassandra password'
)
add_cassandra_args(parser)
def run():

View file

@ -13,6 +13,10 @@ import logging
from falkordb import FalkorDB
from .... base import TriplesStoreService
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
from .... schema import triples_storage_management_topic, storage_management_response_topic
# Module logger
logger = logging.getLogger(__name__)
@ -40,14 +44,44 @@ class Processor(TriplesStoreService):
self.io = FalkorDB.from_url(graph_url).select_graph(database)
def create_node(self, uri):
# Set up metrics for storage management
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
logger.debug(f"Create node {uri}")
# Set up consumer for storage management requests
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=triples_storage_management_topic,
subscriber=f"{self.id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Set up producer for storage management responses
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
def create_node(self, uri, user, collection):
logger.debug(f"Create node {uri} for user={user}, collection={collection}")
res = self.io.query(
"MERGE (n:Node {uri: $uri})",
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
params={
"uri": uri,
"user": user,
"collection": collection,
},
)
@ -56,14 +90,16 @@ class Processor(TriplesStoreService):
time=res.run_time_ms
))
def create_literal(self, value):
def create_literal(self, value, user, collection):
logger.debug(f"Create literal {value}")
logger.debug(f"Create literal {value} for user={user}, collection={collection}")
res = self.io.query(
"MERGE (n:Literal {value: $value})",
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
params={
"value": value,
"user": user,
"collection": collection,
},
)
@ -72,18 +108,20 @@ class Processor(TriplesStoreService):
time=res.run_time_ms
))
def relate_node(self, src, uri, dest):
def relate_node(self, src, uri, dest, user, collection):
logger.debug(f"Create node rel {src} {uri} {dest}")
logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}")
res = self.io.query(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Node {uri: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
params={
"src": src,
"dest": dest,
"uri": uri,
"user": user,
"collection": collection,
},
)
@ -92,18 +130,20 @@ class Processor(TriplesStoreService):
time=res.run_time_ms
))
def relate_literal(self, src, uri, dest):
def relate_literal(self, src, uri, dest, user, collection):
logger.debug(f"Create literal rel {src} {uri} {dest}")
logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}")
res = self.io.query(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
params={
"src": src,
"dest": dest,
"uri": uri,
"user": user,
"collection": collection,
},
)
@ -113,17 +153,20 @@ class Processor(TriplesStoreService):
))
async def store_triples(self, message):
# Extract user and collection from metadata
user = message.metadata.user if message.metadata.user else "default"
collection = message.metadata.collection if message.metadata.collection else "default"
for t in message.triples:
self.create_node(t.s.value)
self.create_node(t.s.value, user, collection)
if t.o.is_uri:
self.create_node(t.o.value)
self.relate_node(t.s.value, t.p.value, t.o.value)
self.create_node(t.o.value, user, collection)
self.relate_node(t.s.value, t.p.value, t.o.value, user, collection)
else:
self.create_literal(t.o.value)
self.relate_literal(t.s.value, t.p.value, t.o.value)
self.create_literal(t.o.value, user, collection)
self.relate_literal(t.s.value, t.p.value, t.o.value, user, collection)
@staticmethod
def add_args(parser):
@ -142,6 +185,59 @@ class Processor(TriplesStoreService):
help=f'FalkorDB database (default: {default_database})'
)
async def on_storage_management(self, message):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Error processing storage management request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete the collection for FalkorDB triples"""
try:
# Delete all nodes and literals for this user/collection
node_result = self.io.query(
"MATCH (n:Node {user: $user, collection: $collection}) DETACH DELETE n",
params={"user": message.user, "collection": message.collection}
)
literal_result = self.io.query(
"MATCH (n:Literal {user: $user, collection: $collection}) DETACH DELETE n",
params={"user": message.user, "collection": message.collection}
)
logger.info(f"Deleted {node_result.nodes_deleted} nodes and {literal_result.nodes_deleted} literals for collection {message.user}/{message.collection}")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
raise
def run():
Processor.launch(default_ident, __doc__)

View file

@ -13,6 +13,10 @@ import logging
from neo4j import GraphDatabase
from .... base import TriplesStoreService
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
from .... schema import triples_storage_management_topic, storage_management_response_topic
# Module logger
logger = logging.getLogger(__name__)
@ -49,6 +53,34 @@ class Processor(TriplesStoreService):
with self.io.session(database=self.db) as session:
self.create_indexes(session)
# Set up metrics for storage management
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
# Set up consumer for storage management requests
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=triples_storage_management_topic,
subscriber=f"{self.id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Set up producer for storage management responses
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
def create_indexes(self, session):
# Race condition, index creation failure is ignored. Right thing
@ -61,6 +93,7 @@ class Processor(TriplesStoreService):
logger.info("Create indexes...")
# Legacy indexes for backwards compatibility
try:
session.run(
"CREATE INDEX ON :Node",
@ -97,15 +130,48 @@ class Processor(TriplesStoreService):
# Maybe index already exists
logger.warning("Index create failure ignored")
# New indexes for user/collection filtering
try:
session.run(
"CREATE INDEX ON :Node(user)"
)
except Exception as e:
logger.warning(f"User index create failure: {e}")
logger.warning("Index create failure ignored")
try:
session.run(
"CREATE INDEX ON :Node(collection)"
)
except Exception as e:
logger.warning(f"Collection index create failure: {e}")
logger.warning("Index create failure ignored")
try:
session.run(
"CREATE INDEX ON :Literal(user)"
)
except Exception as e:
logger.warning(f"User index create failure: {e}")
logger.warning("Index create failure ignored")
try:
session.run(
"CREATE INDEX ON :Literal(collection)"
)
except Exception as e:
logger.warning(f"Collection index create failure: {e}")
logger.warning("Index create failure ignored")
logger.info("Index creation done")
def create_node(self, uri):
def create_node(self, uri, user, collection):
logger.debug(f"Create node {uri}")
logger.debug(f"Create node {uri} for user={user}, collection={collection}")
summary = self.io.execute_query(
"MERGE (n:Node {uri: $uri})",
uri=uri,
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
uri=uri, user=user, collection=collection,
database_=self.db,
).summary
@ -114,13 +180,13 @@ class Processor(TriplesStoreService):
time=summary.result_available_after
))
def create_literal(self, value):
def create_literal(self, value, user, collection):
logger.debug(f"Create literal {value}")
logger.debug(f"Create literal {value} for user={user}, collection={collection}")
summary = self.io.execute_query(
"MERGE (n:Literal {value: $value})",
value=value,
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
value=value, user=user, collection=collection,
database_=self.db,
).summary
@ -129,15 +195,15 @@ class Processor(TriplesStoreService):
time=summary.result_available_after
))
def relate_node(self, src, uri, dest):
def relate_node(self, src, uri, dest, user, collection):
logger.debug(f"Create node rel {src} {uri} {dest}")
logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}")
summary = self.io.execute_query(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Node {uri: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
src=src, dest=dest, uri=uri,
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, user=user, collection=collection,
database_=self.db,
).summary
@ -146,15 +212,15 @@ class Processor(TriplesStoreService):
time=summary.result_available_after
))
def relate_literal(self, src, uri, dest):
def relate_literal(self, src, uri, dest, user, collection):
logger.debug(f"Create literal rel {src} {uri} {dest}")
logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}")
summary = self.io.execute_query(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
src=src, dest=dest, uri=uri,
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, user=user, collection=collection,
database_=self.db,
).summary
@ -163,59 +229,64 @@ class Processor(TriplesStoreService):
time=summary.result_available_after
))
def create_triple(self, tx, t):
def create_triple(self, tx, t, user, collection):
# Create new s node with given uri, if not exists
result = tx.run(
"MERGE (n:Node {uri: $uri})",
uri=t.s.value
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
uri=t.s.value, user=user, collection=collection
)
if t.o.is_uri:
# Create new o node with given uri, if not exists
result = tx.run(
"MERGE (n:Node {uri: $uri})",
uri=t.o.value
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
uri=t.o.value, user=user, collection=collection
)
result = tx.run(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Node {uri: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
src=t.s.value, dest=t.o.value, uri=t.p.value,
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=t.s.value, dest=t.o.value, uri=t.p.value, user=user, collection=collection,
)
else:
# Create new o literal with given uri, if not exists
result = tx.run(
"MERGE (n:Literal {value: $value})",
value=t.o.value
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
value=t.o.value, user=user, collection=collection
)
result = tx.run(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
src=t.s.value, dest=t.o.value, uri=t.p.value,
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=t.s.value, dest=t.o.value, uri=t.p.value, user=user, collection=collection,
)
async def store_triples(self, message):
# Extract user and collection from metadata
user = message.metadata.user if message.metadata.user else "default"
collection = message.metadata.collection if message.metadata.collection else "default"
for t in message.triples:
# self.create_node(t.s.value)
self.create_node(t.s.value, user, collection)
# if t.o.is_uri:
# self.create_node(t.o.value)
# self.relate_node(t.s.value, t.p.value, t.o.value)
# else:
# self.create_literal(t.o.value)
# self.relate_literal(t.s.value, t.p.value, t.o.value)
if t.o.is_uri:
self.create_node(t.o.value, user, collection)
self.relate_node(t.s.value, t.p.value, t.o.value, user, collection)
else:
self.create_literal(t.o.value, user, collection)
self.relate_literal(t.s.value, t.p.value, t.o.value, user, collection)
with self.io.session(database=self.db) as session:
session.execute_write(self.create_triple, t)
# Alternative implementation using transactions
# with self.io.session(database=self.db) as session:
# session.execute_write(self.create_triple, t, user, collection)
@staticmethod
def add_args(parser):
@ -246,6 +317,67 @@ class Processor(TriplesStoreService):
help=f'Memgraph database (default: {default_database})'
)
async def on_storage_management(self, message):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Error processing storage management request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete all data for a specific collection"""
try:
with self.io.session(database=self.db) as session:
# Delete all nodes for this user and collection
node_result = session.run(
"MATCH (n:Node {user: $user, collection: $collection}) "
"DETACH DELETE n",
user=message.user, collection=message.collection
)
nodes_deleted = node_result.consume().counters.nodes_deleted
# Delete all literals for this user and collection
literal_result = session.run(
"MATCH (n:Literal {user: $user, collection: $collection}) "
"DETACH DELETE n",
user=message.user, collection=message.collection
)
literals_deleted = literal_result.consume().counters.nodes_deleted
# Note: Relationships are automatically deleted with DETACH DELETE
logger.info(f"Deleted {nodes_deleted} nodes and {literals_deleted} literals for {message.user}/{message.collection}")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
raise
def run():
Processor.launch(default_ident, __doc__)

View file

@ -12,6 +12,10 @@ import logging
from neo4j import GraphDatabase
from .... base import TriplesStoreService
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
from .... schema import triples_storage_management_topic, storage_management_response_topic
# Module logger
logger = logging.getLogger(__name__)
@ -49,6 +53,34 @@ class Processor(TriplesStoreService):
with self.io.session(database=self.db) as session:
self.create_indexes(session)
# Set up metrics for storage management
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
# Set up consumer for storage management requests
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=triples_storage_management_topic,
subscriber=f"{id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Set up producer for storage management responses
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
def create_indexes(self, session):
# Race condition, index creation failure is ignored. Right thing
@ -61,6 +93,7 @@ class Processor(TriplesStoreService):
logger.info("Create indexes...")
# Legacy indexes for backwards compatibility
try:
session.run(
"CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
@ -88,15 +121,50 @@ class Processor(TriplesStoreService):
# Maybe index already exists
logger.warning("Index create failure ignored")
# New compound indexes for user/collection filtering
try:
session.run(
"CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)",
)
except Exception as e:
logger.warning(f"Compound index create failure: {e}")
logger.warning("Index create failure ignored")
try:
session.run(
"CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)",
)
except Exception as e:
logger.warning(f"Compound index create failure: {e}")
logger.warning("Index create failure ignored")
# Note: Neo4j doesn't support compound indexes on relationships in all versions
# Try to create individual indexes on relationship properties
try:
session.run(
"CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)",
)
except Exception as e:
logger.warning(f"Relationship index create failure: {e}")
logger.warning("Index create failure ignored")
try:
session.run(
"CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)",
)
except Exception as e:
logger.warning(f"Relationship index create failure: {e}")
logger.warning("Index create failure ignored")
logger.info("Index creation done")
def create_node(self, uri):
def create_node(self, uri, user, collection):
logger.debug(f"Create node {uri}")
logger.debug(f"Create node {uri} for user={user}, collection={collection}")
summary = self.io.execute_query(
"MERGE (n:Node {uri: $uri})",
uri=uri,
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
uri=uri, user=user, collection=collection,
database_=self.db,
).summary
@ -105,13 +173,13 @@ class Processor(TriplesStoreService):
time=summary.result_available_after
))
def create_literal(self, value):
def create_literal(self, value, user, collection):
logger.debug(f"Create literal {value}")
logger.debug(f"Create literal {value} for user={user}, collection={collection}")
summary = self.io.execute_query(
"MERGE (n:Literal {value: $value})",
value=value,
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
value=value, user=user, collection=collection,
database_=self.db,
).summary
@ -120,15 +188,15 @@ class Processor(TriplesStoreService):
time=summary.result_available_after
))
def relate_node(self, src, uri, dest):
def relate_node(self, src, uri, dest, user, collection):
logger.debug(f"Create node rel {src} {uri} {dest}")
logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}")
summary = self.io.execute_query(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Node {uri: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
src=src, dest=dest, uri=uri,
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, user=user, collection=collection,
database_=self.db,
).summary
@ -137,15 +205,15 @@ class Processor(TriplesStoreService):
time=summary.result_available_after
))
def relate_literal(self, src, uri, dest):
def relate_literal(self, src, uri, dest, user, collection):
logger.debug(f"Create literal rel {src} {uri} {dest}")
logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}")
summary = self.io.execute_query(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
src=src, dest=dest, uri=uri,
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, user=user, collection=collection,
database_=self.db,
).summary
@ -156,16 +224,20 @@ class Processor(TriplesStoreService):
async def store_triples(self, message):
# Extract user and collection from metadata
user = message.metadata.user if message.metadata.user else "default"
collection = message.metadata.collection if message.metadata.collection else "default"
for t in message.triples:
self.create_node(t.s.value)
self.create_node(t.s.value, user, collection)
if t.o.is_uri:
self.create_node(t.o.value)
self.relate_node(t.s.value, t.p.value, t.o.value)
self.create_node(t.o.value, user, collection)
self.relate_node(t.s.value, t.p.value, t.o.value, user, collection)
else:
self.create_literal(t.o.value)
self.relate_literal(t.s.value, t.p.value, t.o.value)
self.create_literal(t.o.value, user, collection)
self.relate_literal(t.s.value, t.p.value, t.o.value, user, collection)
@staticmethod
def add_args(parser):
@ -196,6 +268,67 @@ class Processor(TriplesStoreService):
help=f'Neo4j database (default: {default_database})'
)
async def on_storage_management(self, message):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Error processing storage management request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete all data for a specific collection"""
try:
with self.io.session(database=self.db) as session:
# Delete all nodes for this user and collection
node_result = session.run(
"MATCH (n:Node {user: $user, collection: $collection}) "
"DETACH DELETE n",
user=message.user, collection=message.collection
)
nodes_deleted = node_result.consume().counters.nodes_deleted
# Delete all literals for this user and collection
literal_result = session.run(
"MATCH (n:Literal {user: $user, collection: $collection}) "
"DETACH DELETE n",
user=message.user, collection=message.collection
)
literals_deleted = literal_result.consume().counters.nodes_deleted
# Note: Relationships are automatically deleted with DETACH DELETE
logger.info(f"Deleted {nodes_deleted} nodes and {literals_deleted} literals for {message.user}/{message.collection}")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
raise
def run():
Processor.launch(default_ident, __doc__)

View file

@ -17,17 +17,21 @@ class ConfigTableStore:
def __init__(
self,
cassandra_host, cassandra_user, cassandra_password, keyspace,
cassandra_host, cassandra_username, cassandra_password, keyspace,
):
self.keyspace = keyspace
logger.info("Connecting to Cassandra...")
if cassandra_user and cassandra_password:
# Ensure cassandra_host is a list
if isinstance(cassandra_host, str):
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
if cassandra_username and cassandra_password:
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
auth_provider = PlainTextAuthProvider(
username=cassandra_user, password=cassandra_password
username=cassandra_username, password=cassandra_password
)
self.cluster = Cluster(
cassandra_host,

View file

@ -17,17 +17,21 @@ class KnowledgeTableStore:
def __init__(
self,
cassandra_host, cassandra_user, cassandra_password, keyspace,
cassandra_host, cassandra_username, cassandra_password, keyspace,
):
self.keyspace = keyspace
logger.info("Connecting to Cassandra...")
if cassandra_user and cassandra_password:
# Ensure cassandra_host is a list
if isinstance(cassandra_host, str):
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
if cassandra_username and cassandra_password:
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
auth_provider = PlainTextAuthProvider(
username=cassandra_user, password=cassandra_password
username=cassandra_username, password=cassandra_password
)
self.cluster = Cluster(
cassandra_host,

View file

@ -21,17 +21,21 @@ class LibraryTableStore:
def __init__(
self,
cassandra_host, cassandra_user, cassandra_password, keyspace,
cassandra_host, cassandra_username, cassandra_password, keyspace,
):
self.keyspace = keyspace
logger.info("Connecting to Cassandra...")
if cassandra_user and cassandra_password:
# Ensure cassandra_host is a list
if isinstance(cassandra_host, str):
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
if cassandra_username and cassandra_password:
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
auth_provider = PlainTextAuthProvider(
username=cassandra_user, password=cassandra_password
username=cassandra_username, password=cassandra_password
)
self.cluster = Cluster(
cassandra_host,
@ -107,6 +111,21 @@ class LibraryTableStore:
);
""");
logger.debug("collections table...")
self.cassandra.execute("""
CREATE TABLE IF NOT EXISTS collections (
user text,
collection text,
name text,
description text,
tags set<text>,
created_at timestamp,
updated_at timestamp,
PRIMARY KEY (user, collection)
);
""");
logger.info("Cassandra schema OK.")
def prepare_statements(self):
@ -183,6 +202,43 @@ class LibraryTableStore:
LIMIT 1
""")
# Collection management statements
self.insert_collection_stmt = self.cassandra.prepare("""
INSERT INTO collections
(user, collection, name, description, tags, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
""")
self.update_collection_stmt = self.cassandra.prepare("""
UPDATE collections
SET name = ?, description = ?, tags = ?, updated_at = ?
WHERE user = ? AND collection = ?
""")
self.get_collection_stmt = self.cassandra.prepare("""
SELECT collection, name, description, tags, created_at, updated_at
FROM collections
WHERE user = ? AND collection = ?
""")
self.list_collections_stmt = self.cassandra.prepare("""
SELECT collection, name, description, tags, created_at, updated_at
FROM collections
WHERE user = ?
""")
self.delete_collection_stmt = self.cassandra.prepare("""
DELETE FROM collections
WHERE user = ? AND collection = ?
""")
self.collection_exists_stmt = self.cassandra.prepare("""
SELECT collection
FROM collections
WHERE user = ? AND collection = ?
LIMIT 1
""")
self.list_processing_stmt = self.cassandra.prepare("""
SELECT
id, document_id, time, flow, collection, tags
@ -517,3 +573,145 @@ class LibraryTableStore:
return lst
# Collection management methods
async def ensure_collection_exists(self, user, collection):
"""Ensure collection metadata record exists, create if not"""
try:
resp = await asyncio.get_event_loop().run_in_executor(
None, self.cassandra.execute, self.collection_exists_stmt, [user, collection]
)
if resp:
return
import datetime
now = datetime.datetime.now()
await asyncio.get_event_loop().run_in_executor(
None, self.cassandra.execute, self.insert_collection_stmt,
[user, collection, collection, "", set(), now, now]
)
logger.debug(f"Created collection metadata for {user}/{collection}")
except Exception as e:
logger.error(f"Error ensuring collection exists: {e}")
raise
async def list_collections(self, user, tag_filter=None):
"""List collections for a user, optionally filtered by tags"""
try:
resp = await asyncio.get_event_loop().run_in_executor(
None, self.cassandra.execute, self.list_collections_stmt, [user]
)
collections = []
for row in resp:
collection_data = {
"user": user,
"collection": row[0],
"name": row[1] or row[0],
"description": row[2] or "",
"tags": list(row[3]) if row[3] else [],
"created_at": row[4].isoformat() if row[4] else "",
"updated_at": row[5].isoformat() if row[5] else ""
}
if tag_filter:
collection_tags = set(collection_data["tags"])
filter_tags = set(tag_filter)
if not filter_tags.intersection(collection_tags):
continue
collections.append(collection_data)
return collections
except Exception as e:
logger.error(f"Error listing collections: {e}")
raise
async def update_collection(self, user, collection, name=None, description=None, tags=None):
"""Update collection metadata"""
try:
resp = await asyncio.get_event_loop().run_in_executor(
None, self.cassandra.execute, self.get_collection_stmt, [user, collection]
)
if not resp:
raise RequestError(f"Collection {collection} not found")
row = resp.one()
current_name = row[1] or collection
current_description = row[2] or ""
current_tags = set(row[3]) if row[3] else set()
new_name = name if name is not None else current_name
new_description = description if description is not None else current_description
new_tags = set(tags) if tags is not None else current_tags
import datetime
now = datetime.datetime.now()
await asyncio.get_event_loop().run_in_executor(
None, self.cassandra.execute, self.update_collection_stmt,
[new_name, new_description, new_tags, now, user, collection]
)
return {
"user": user, "collection": collection, "name": new_name,
"description": new_description, "tags": list(new_tags),
"updated_at": now.isoformat()
}
except Exception as e:
logger.error(f"Error updating collection: {e}")
raise
async def delete_collection(self, user, collection):
"""Delete collection metadata record"""
try:
await asyncio.get_event_loop().run_in_executor(
None, self.cassandra.execute, self.delete_collection_stmt, [user, collection]
)
logger.debug(f"Deleted collection metadata for {user}/{collection}")
except Exception as e:
logger.error(f"Error deleting collection metadata: {e}")
raise
async def get_collection(self, user, collection):
"""Get collection metadata"""
try:
resp = await asyncio.get_event_loop().run_in_executor(
None, self.cassandra.execute, self.get_collection_stmt, [user, collection]
)
if not resp:
return None
row = resp.one()
return {
"user": user, "collection": row[0], "name": row[1] or row[0],
"description": row[2] or "", "tags": list(row[3]) if row[3] else [],
"created_at": row[4].isoformat() if row[4] else "",
"updated_at": row[5].isoformat() if row[5] else ""
}
except Exception as e:
logger.error(f"Error getting collection: {e}")
raise
async def create_collection(self, user, collection, name=None, description=None, tags=None):
"""Create a new collection metadata record"""
try:
import datetime
now = datetime.datetime.now()
# Set defaults for optional parameters
name = name if name is not None else collection
description = description if description is not None else ""
tags = tags if tags is not None else set()
await asyncio.get_event_loop().run_in_executor(
None, self.cassandra.execute, self.insert_collection_stmt,
[user, collection, name, description, tags, now, now]
)
logger.info(f"Created collection {user}/{collection}")
# Return the created collection data
return {
"user": user,
"collection": collection,
"name": name,
"description": description,
"tags": list(tags) if isinstance(tags, set) else tags,
"created_at": now.isoformat(),
"updated_at": now.isoformat()
}
except Exception as e:
logger.error(f"Error creating collection: {e}")
raise