release/v1.4 -> master (#548)

This commit is contained in:
cybermaggedon 2025-10-06 17:54:26 +01:00 committed by GitHub
parent 3ec2cd54f9
commit 2bd68ed7f4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
94 changed files with 8571 additions and 1740 deletions

View file

@ -9,14 +9,14 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter
from prometheus_client import Histogram
from ... schema import TextDocument, Chunk
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import ChunkingService, ConsumerSpec, ProducerSpec
# Module logger
logger = logging.getLogger(__name__)
default_ident = "chunker"
class Processor(FlowProcessor):
class Processor(ChunkingService):
def __init__(self, **params):
@ -28,6 +28,10 @@ class Processor(FlowProcessor):
**params | { "id": id }
)
# Store default values for parameter override
self.default_chunk_size = chunk_size
self.default_chunk_overlap = chunk_overlap
if not hasattr(__class__, "chunk_metric"):
__class__.chunk_metric = Histogram(
'chunk_size', 'Chunk size',
@ -65,7 +69,22 @@ class Processor(FlowProcessor):
v = msg.value()
logger.info(f"Chunking document {v.metadata.id}...")
texts = self.text_splitter.create_documents(
# Extract chunk parameters from flow (allows runtime override)
chunk_size, chunk_overlap = await self.chunk_document(
msg, consumer, flow,
self.default_chunk_size,
self.default_chunk_overlap
)
# Create text splitter with effective parameters
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len,
is_separator_regex=False,
)
texts = text_splitter.create_documents(
[v.text.decode("utf-8")]
)
@ -89,7 +108,7 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
ChunkingService.add_args(parser)
parser.add_argument(
'-z', '--chunk-size',

View file

@ -9,14 +9,14 @@ from langchain_text_splitters import TokenTextSplitter
from prometheus_client import Histogram
from ... schema import TextDocument, Chunk
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import ChunkingService, ConsumerSpec, ProducerSpec
# Module logger
logger = logging.getLogger(__name__)
default_ident = "chunker"
class Processor(FlowProcessor):
class Processor(ChunkingService):
def __init__(self, **params):
@ -28,6 +28,10 @@ class Processor(FlowProcessor):
**params | { "id": id }
)
# Store default values for parameter override
self.default_chunk_size = chunk_size
self.default_chunk_overlap = chunk_overlap
if not hasattr(__class__, "chunk_metric"):
__class__.chunk_metric = Histogram(
'chunk_size', 'Chunk size',
@ -64,7 +68,21 @@ class Processor(FlowProcessor):
v = msg.value()
logger.info(f"Chunking document {v.metadata.id}...")
texts = self.text_splitter.create_documents(
# Extract chunk parameters from flow (allows runtime override)
chunk_size, chunk_overlap = await self.chunk_document(
msg, consumer, flow,
self.default_chunk_size,
self.default_chunk_overlap
)
# Create text splitter with effective parameters
text_splitter = TokenTextSplitter(
encoding_name="cl100k_base",
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
texts = text_splitter.create_documents(
[v.text.decode("utf-8")]
)
@ -88,7 +106,7 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
ChunkingService.add_args(parser)
parser.add_argument(
'-z', '--chunk-size',

View file

@ -10,6 +10,95 @@ class FlowConfig:
def __init__(self, config):
self.config = config
# Cache for parameter type definitions to avoid repeated lookups
self.param_type_cache = {}
async def resolve_parameters(self, flow_class, user_params):
"""
Resolve parameters by merging user-provided values with defaults.
Args:
flow_class: The flow class definition dict
user_params: User-provided parameters dict (may be None or empty)
Returns:
Complete parameter dict with user values and defaults merged (all values as strings)
"""
# If the flow class has no parameters section, return user params as-is (stringified)
if "parameters" not in flow_class:
if not user_params:
return {}
# Ensure all values are strings
return {k: str(v) for k, v in user_params.items()}
resolved = {}
flow_params = flow_class["parameters"]
user_params = user_params if user_params else {}
# First pass: resolve parameters with explicit values or defaults
for param_name, param_meta in flow_params.items():
# Check if user provided a value
if param_name in user_params:
# Store as string
resolved[param_name] = str(user_params[param_name])
else:
# Look up the parameter type definition
param_type = param_meta.get("type")
if param_type:
# Check cache first
if param_type not in self.param_type_cache:
try:
# Fetch parameter type definition from config store
type_def = await self.config.get("parameter-types").get(param_type)
if type_def:
self.param_type_cache[param_type] = json.loads(type_def)
else:
logger.warning(f"Parameter type '{param_type}' not found in config")
self.param_type_cache[param_type] = {}
except Exception as e:
logger.error(f"Error fetching parameter type '{param_type}': {e}")
self.param_type_cache[param_type] = {}
# Apply default from type definition (as string)
type_def = self.param_type_cache[param_type]
if "default" in type_def:
default_value = type_def["default"]
# Convert to string based on type
if isinstance(default_value, bool):
resolved[param_name] = "true" if default_value else "false"
else:
resolved[param_name] = str(default_value)
elif type_def.get("required", False):
# Required parameter with no default and no user value
raise RuntimeError(f"Required parameter '{param_name}' not provided and has no default")
# Second pass: handle controlled-by relationships
for param_name, param_meta in flow_params.items():
if param_name not in resolved and "controlled-by" in param_meta:
controller = param_meta["controlled-by"]
if controller in resolved:
# Inherit value from controlling parameter (already a string)
resolved[param_name] = resolved[controller]
else:
# Controller has no value, try to get default from type definition
param_type = param_meta.get("type")
if param_type and param_type in self.param_type_cache:
type_def = self.param_type_cache[param_type]
if "default" in type_def:
default_value = type_def["default"]
# Convert to string based on type
if isinstance(default_value, bool):
resolved[param_name] = "true" if default_value else "false"
else:
resolved[param_name] = str(default_value)
# Include any extra parameters from user that weren't in flow class definition
# This allows for forward compatibility (ensure they're strings)
for key, value in user_params.items():
if key not in resolved:
resolved[key] = str(value)
return resolved
async def handle_list_classes(self, msg):
@ -68,11 +157,14 @@ class FlowConfig:
async def handle_get_flow(self, msg):
flow = await self.config.get("flows").get(msg.flow_id)
flow_data = await self.config.get("flows").get(msg.flow_id)
flow = json.loads(flow_data)
return FlowResponse(
error = None,
flow = flow,
flow = flow_data,
description = flow.get("description", ""),
parameters = flow.get("parameters", {}),
)
async def handle_start_flow(self, msg):
@ -83,45 +175,65 @@ class FlowConfig:
if msg.flow_id is None:
raise RuntimeError("No flow ID")
if msg.flow_id in await self.config.get("flows").values():
if msg.flow_id in await self.config.get("flows").keys():
raise RuntimeError("Flow already exists")
if msg.description is None:
raise RuntimeError("No description")
if msg.class_name not in await self.config.get("flow-classes").values():
if msg.class_name not in await self.config.get("flow-classes").keys():
raise RuntimeError("Class does not exist")
def repl_template(tmp):
return tmp.replace(
"{class}", msg.class_name
).replace(
"{id}", msg.flow_id
)
cls = json.loads(
await self.config.get("flow-classes").get(msg.class_name)
)
# Resolve parameters by merging user-provided values with defaults
user_params = msg.parameters if msg.parameters else {}
parameters = await self.resolve_parameters(cls, user_params)
# Log the resolved parameters for debugging
logger.debug(f"User provided parameters: {user_params}")
logger.debug(f"Resolved parameters (with defaults): {parameters}")
# Apply parameter substitution to template replacement function
def repl_template_with_params(tmp):
result = tmp.replace(
"{class}", msg.class_name
).replace(
"{id}", msg.flow_id
)
# Apply parameter substitutions
for param_name, param_value in parameters.items():
result = result.replace(f"{{{param_name}}}", str(param_value))
return result
for kind in ("class", "flow"):
for k, v in cls[kind].items():
processor, variant = k.split(":", 1)
variant = repl_template(variant)
variant = repl_template_with_params(variant)
v = {
repl_template(k2): repl_template(v2)
repl_template_with_params(k2): repl_template_with_params(v2)
for k2, v2 in v.items()
}
flac = await self.config.get("flows-active").values()
if processor in flac:
target = json.loads(flac[processor])
flac = await self.config.get("flows-active").get(processor)
if flac is not None:
target = json.loads(flac)
else:
target = {}
# The condition if variant not in target: means it only adds
# the configuration if the variant doesn't already exist.
# If "everything" already exists in the target with old
# values, they won't update.
if variant not in target:
target[variant] = v
@ -131,10 +243,10 @@ class FlowConfig:
def repl_interface(i):
if isinstance(i, str):
return repl_template(i)
return repl_template_with_params(i)
else:
return {
k: repl_template(v)
k: repl_template_with_params(v)
for k, v in i.items()
}
@ -152,6 +264,7 @@ class FlowConfig:
"description": msg.description,
"class-name": msg.class_name,
"interfaces": interfaces,
"parameters": parameters,
})
)
@ -177,15 +290,20 @@ class FlowConfig:
raise RuntimeError("Internal error: flow has no flow class")
class_name = flow["class-name"]
parameters = flow.get("parameters", {})
cls = json.loads(await self.config.get("flow-classes").get(class_name))
def repl_template(tmp):
return tmp.replace(
result = tmp.replace(
"{class}", class_name
).replace(
"{id}", msg.flow_id
)
# Apply parameter substitutions
for param_name, param_value in parameters.items():
result = result.replace(f"{{{param_name}}}", str(param_value))
return result
for kind in ("flow",):
@ -195,10 +313,10 @@ class FlowConfig:
variant = repl_template(variant)
flac = await self.config.get("flows-active").values()
flac = await self.config.get("flows-active").get(processor)
if processor in flac:
target = json.loads(flac[processor])
if flac is not None:
target = json.loads(flac)
else:
target = {}
@ -209,7 +327,7 @@ class FlowConfig:
processor, json.dumps(target)
)
if msg.flow_id in await self.config.get("flows").values():
if msg.flow_id in await self.config.get("flows").keys():
await self.config.get("flows").delete(msg.flow_id)
await self.config.inc_version()

View file

@ -24,16 +24,12 @@ class KnowledgeGraph:
self.keyspace = keyspace
self.username = username
# Multi-table schema design for optimal performance
self.use_legacy = os.getenv('CASSANDRA_USE_LEGACY', 'false').lower() == 'true'
if self.use_legacy:
self.table = "triples" # Legacy single table
else:
# New optimized tables
self.subject_table = "triples_s"
self.po_table = "triples_p"
self.object_table = "triples_o"
# Optimized multi-table schema with collection deletion support
self.subject_table = "triples_s"
self.po_table = "triples_p"
self.object_table = "triples_o"
self.collection_table = "triples_collection" # For SPO queries and deletion
self.collection_metadata_table = "collection_metadata" # For tracking which collections exist
if username and password:
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
@ -47,9 +43,7 @@ class KnowledgeGraph:
_active_clusters.append(self.cluster)
self.init()
if not self.use_legacy:
self.prepare_statements()
self.prepare_statements()
def clear(self):
@ -70,42 +64,13 @@ class KnowledgeGraph:
""");
self.session.set_keyspace(self.keyspace)
self.init_optimized_schema()
if self.use_legacy:
self.init_legacy_schema()
else:
self.init_optimized_schema()
def init_legacy_schema(self):
"""Initialize legacy single-table schema for backward compatibility"""
self.session.execute(f"""
create table if not exists {self.table} (
collection text,
s text,
p text,
o text,
PRIMARY KEY (collection, s, p, o)
);
""");
self.session.execute(f"""
create index if not exists {self.table}_s
ON {self.table} (s);
""");
self.session.execute(f"""
create index if not exists {self.table}_p
ON {self.table} (p);
""");
self.session.execute(f"""
create index if not exists {self.table}_o
ON {self.table} (o);
""");
def init_optimized_schema(self):
"""Initialize optimized multi-table schema for performance"""
# Table 1: Subject-centric queries (get_s, get_sp, get_spo, get_os)
# Table 1: Subject-centric queries (get_s, get_sp, get_os)
# Compound partition key for optimal data distribution
self.session.execute(f"""
CREATE TABLE IF NOT EXISTS {self.subject_table} (
collection text,
@ -117,6 +82,7 @@ class KnowledgeGraph:
""");
# Table 2: Predicate-Object queries (get_p, get_po) - eliminates ALLOW FILTERING!
# Compound partition key for optimal data distribution
self.session.execute(f"""
CREATE TABLE IF NOT EXISTS {self.po_table} (
collection text,
@ -128,6 +94,7 @@ class KnowledgeGraph:
""");
# Table 3: Object-centric queries (get_o)
# Compound partition key for optimal data distribution
self.session.execute(f"""
CREATE TABLE IF NOT EXISTS {self.object_table} (
collection text,
@ -138,7 +105,29 @@ class KnowledgeGraph:
);
""");
logger.info("Optimized multi-table schema initialized")
# Table 4: Collection management and SPO queries (get_spo)
# Simple partition key enables efficient collection deletion
self.session.execute(f"""
CREATE TABLE IF NOT EXISTS {self.collection_table} (
collection text,
s text,
p text,
o text,
PRIMARY KEY (collection, s, p, o)
);
""");
# Table 5: Collection metadata tracking
# Tracks which collections exist without polluting triple data
self.session.execute(f"""
CREATE TABLE IF NOT EXISTS {self.collection_metadata_table} (
collection text,
created_at timestamp,
PRIMARY KEY (collection)
);
""");
logger.info("Optimized multi-table schema initialized (5 tables)")
def prepare_statements(self):
"""Prepare statements for optimal performance"""
@ -155,6 +144,10 @@ class KnowledgeGraph:
f"INSERT INTO {self.object_table} (collection, o, s, p) VALUES (?, ?, ?, ?)"
)
self.insert_collection_stmt = self.session.prepare(
f"INSERT INTO {self.collection_table} (collection, s, p, o) VALUES (?, ?, ?, ?)"
)
# Query statements for optimized access
self.get_all_stmt = self.session.prepare(
f"SELECT s, p, o FROM {self.subject_table} WHERE collection = ? LIMIT ? ALLOW FILTERING"
@ -186,158 +179,177 @@ class KnowledgeGraph:
)
self.get_spo_stmt = self.session.prepare(
f"SELECT s as x FROM {self.subject_table} WHERE collection = ? AND s = ? AND p = ? AND o = ? LIMIT ?"
f"SELECT s as x FROM {self.collection_table} WHERE collection = ? AND s = ? AND p = ? AND o = ? LIMIT ?"
)
logger.info("Prepared statements initialized for optimal performance")
# Delete statements for collection deletion
self.delete_subject_stmt = self.session.prepare(
f"DELETE FROM {self.subject_table} WHERE collection = ? AND s = ? AND p = ? AND o = ?"
)
self.delete_po_stmt = self.session.prepare(
f"DELETE FROM {self.po_table} WHERE collection = ? AND p = ? AND o = ? AND s = ?"
)
self.delete_object_stmt = self.session.prepare(
f"DELETE FROM {self.object_table} WHERE collection = ? AND o = ? AND s = ? AND p = ?"
)
self.delete_collection_stmt = self.session.prepare(
f"DELETE FROM {self.collection_table} WHERE collection = ? AND s = ? AND p = ? AND o = ?"
)
logger.info("Prepared statements initialized for optimal performance (4 tables)")
def insert(self, collection, s, p, o):
# Batch write to all four tables for consistency
batch = BatchStatement()
if self.use_legacy:
self.session.execute(
f"insert into {self.table} (collection, s, p, o) values (%s, %s, %s, %s)",
(collection, s, p, o)
)
else:
# Batch write to all three tables for consistency
batch = BatchStatement()
# Insert into subject table
batch.add(self.insert_subject_stmt, (collection, s, p, o))
# Insert into subject table
batch.add(self.insert_subject_stmt, (collection, s, p, o))
# Insert into predicate-object table (column order: collection, p, o, s)
batch.add(self.insert_po_stmt, (collection, p, o, s))
# Insert into predicate-object table (column order: collection, p, o, s)
batch.add(self.insert_po_stmt, (collection, p, o, s))
# Insert into object table (column order: collection, o, s, p)
batch.add(self.insert_object_stmt, (collection, o, s, p))
# Insert into object table (column order: collection, o, s, p)
batch.add(self.insert_object_stmt, (collection, o, s, p))
# Insert into collection table for SPO queries and deletion tracking
batch.add(self.insert_collection_stmt, (collection, s, p, o))
self.session.execute(batch)
self.session.execute(batch)
def get_all(self, collection, limit=50):
if self.use_legacy:
return self.session.execute(
f"select s, p, o from {self.table} where collection = %s limit {limit}",
(collection,)
)
else:
# Use subject table for get_all queries
return self.session.execute(
self.get_all_stmt,
(collection, limit)
)
# Use subject table for get_all queries
return self.session.execute(
self.get_all_stmt,
(collection, limit)
)
def get_s(self, collection, s, limit=10):
if self.use_legacy:
return self.session.execute(
f"select p, o from {self.table} where collection = %s and s = %s limit {limit}",
(collection, s)
)
else:
# Optimized: Direct partition access with (collection, s)
return self.session.execute(
self.get_s_stmt,
(collection, s, limit)
)
# Optimized: Direct partition access with (collection, s)
return self.session.execute(
self.get_s_stmt,
(collection, s, limit)
)
def get_p(self, collection, p, limit=10):
if self.use_legacy:
return self.session.execute(
f"select s, o from {self.table} where collection = %s and p = %s limit {limit}",
(collection, p)
)
else:
# Optimized: Use po_table for direct partition access
return self.session.execute(
self.get_p_stmt,
(collection, p, limit)
)
# Optimized: Use po_table for direct partition access
return self.session.execute(
self.get_p_stmt,
(collection, p, limit)
)
def get_o(self, collection, o, limit=10):
if self.use_legacy:
return self.session.execute(
f"select s, p from {self.table} where collection = %s and o = %s limit {limit}",
(collection, o)
)
else:
# Optimized: Use object_table for direct partition access
return self.session.execute(
self.get_o_stmt,
(collection, o, limit)
)
# Optimized: Use object_table for direct partition access
return self.session.execute(
self.get_o_stmt,
(collection, o, limit)
)
def get_sp(self, collection, s, p, limit=10):
if self.use_legacy:
return self.session.execute(
f"select o from {self.table} where collection = %s and s = %s and p = %s limit {limit}",
(collection, s, p)
)
else:
# Optimized: Use subject_table with clustering key access
return self.session.execute(
self.get_sp_stmt,
(collection, s, p, limit)
)
# Optimized: Use subject_table with clustering key access
return self.session.execute(
self.get_sp_stmt,
(collection, s, p, limit)
)
def get_po(self, collection, p, o, limit=10):
if self.use_legacy:
return self.session.execute(
f"select s from {self.table} where collection = %s and p = %s and o = %s limit {limit} allow filtering",
(collection, p, o)
)
else:
# CRITICAL OPTIMIZATION: Use po_table - NO MORE ALLOW FILTERING!
return self.session.execute(
self.get_po_stmt,
(collection, p, o, limit)
)
# CRITICAL OPTIMIZATION: Use po_table - NO MORE ALLOW FILTERING!
return self.session.execute(
self.get_po_stmt,
(collection, p, o, limit)
)
def get_os(self, collection, o, s, limit=10):
if self.use_legacy:
return self.session.execute(
f"select p from {self.table} where collection = %s and o = %s and s = %s limit {limit} allow filtering",
(collection, o, s)
)
else:
# Optimized: Use subject_table with clustering access (no more ALLOW FILTERING)
return self.session.execute(
self.get_os_stmt,
(collection, s, o, limit)
)
# Optimized: Use subject_table with clustering access (no more ALLOW FILTERING)
return self.session.execute(
self.get_os_stmt,
(collection, s, o, limit)
)
def get_spo(self, collection, s, p, o, limit=10):
if self.use_legacy:
return self.session.execute(
f"""select s as x from {self.table} where collection = %s and s = %s and p = %s and o = %s limit {limit}""",
(collection, s, p, o)
# Optimized: Use collection_table for exact key lookup
return self.session.execute(
self.get_spo_stmt,
(collection, s, p, o, limit)
)
def collection_exists(self, collection):
"""Check if collection exists by querying collection_metadata table"""
try:
result = self.session.execute(
f"SELECT collection FROM {self.collection_metadata_table} WHERE collection = %s LIMIT 1",
(collection,)
)
else:
# Optimized: Use subject_table for exact key lookup
return self.session.execute(
self.get_spo_stmt,
(collection, s, p, o, limit)
return bool(list(result))
except Exception as e:
logger.error(f"Error checking collection existence: {e}")
return False
def create_collection(self, collection):
"""Create collection by inserting metadata row"""
try:
import datetime
self.session.execute(
f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)",
(collection, datetime.datetime.now())
)
logger.info(f"Created collection metadata for {collection}")
except Exception as e:
logger.error(f"Error creating collection: {e}")
raise e
def delete_collection(self, collection):
"""Delete all triples for a specific collection"""
if self.use_legacy:
self.session.execute(
f"delete from {self.table} where collection = %s",
(collection,)
)
else:
# Delete from all three tables
self.session.execute(
f"delete from {self.subject_table} where collection = %s",
(collection,)
)
self.session.execute(
f"delete from {self.po_table} where collection = %s",
(collection,)
)
self.session.execute(
f"delete from {self.object_table} where collection = %s",
(collection,)
)
"""Delete all triples for a specific collection
Uses collection_table to enumerate all triples, then deletes from all 4 tables
using full partition keys for optimal performance with compound keys.
"""
# Step 1: Read all triples from collection_table (single partition read)
rows = self.session.execute(
f"SELECT s, p, o FROM {self.collection_table} WHERE collection = %s",
(collection,)
)
# Step 2: Delete each triple from all 4 tables using full partition keys
# Batch deletions for efficiency
batch = BatchStatement()
count = 0
for row in rows:
s, p, o = row.s, row.p, row.o
# Delete from subject table (partition key: collection, s)
batch.add(self.delete_subject_stmt, (collection, s, p, o))
# Delete from predicate-object table (partition key: collection, p)
batch.add(self.delete_po_stmt, (collection, p, o, s))
# Delete from object table (partition key: collection, o)
batch.add(self.delete_object_stmt, (collection, o, s, p))
# Delete from collection table (partition key: collection only)
batch.add(self.delete_collection_stmt, (collection, s, p, o))
count += 1
# Execute batch every 100 triples to avoid oversized batches
if count % 100 == 0:
self.session.execute(batch)
batch = BatchStatement()
# Execute remaining deletions
if count % 100 != 0:
self.session.execute(batch)
# Step 3: Delete collection metadata
self.session.execute(
f"DELETE FROM {self.collection_metadata_table} WHERE collection = %s",
(collection,)
)
logger.info(f"Deleted {count} triples from collection {collection}")
def close(self):
"""Close the Cassandra session and cluster connections properly"""

View file

@ -49,6 +49,22 @@ class DocVectors:
self.next_reload = time.time() + self.reload_time
logger.debug(f"Reload at {self.next_reload}")
def collection_exists(self, user, collection):
"""Check if collection exists (dimension-independent check)"""
collection_name = make_safe_collection_name(user, collection, self.prefix)
return self.client.has_collection(collection_name)
def create_collection(self, user, collection, dimension=384):
"""Create collection with default dimension"""
collection_name = make_safe_collection_name(user, collection, self.prefix)
if self.client.has_collection(collection_name):
logger.info(f"Collection {collection_name} already exists")
return
self.init_collection(dimension, user, collection)
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
def init_collection(self, dimension, user, collection):
collection_name = make_safe_collection_name(user, collection, self.prefix)
@ -128,14 +144,6 @@ class DocVectors:
coll = self.collections[(dim, user, collection)]
search_params = {
"metric_type": "COSINE",
"params": {
"radius": 0.1,
"range_filter": 0.8
}
}
logger.debug("Loading...")
self.client.load_collection(
collection_name=coll,
@ -145,10 +153,11 @@ class DocVectors:
res = self.client.search(
collection_name=coll,
anns_field="vector",
data=[embeds],
limit=limit,
output_fields=fields,
search_params=search_params,
search_params={ "metric_type": "COSINE" },
)[0]

View file

@ -49,6 +49,22 @@ class EntityVectors:
self.next_reload = time.time() + self.reload_time
logger.debug(f"Reload at {self.next_reload}")
def collection_exists(self, user, collection):
"""Check if collection exists (dimension-independent check)"""
collection_name = make_safe_collection_name(user, collection, self.prefix)
return self.client.has_collection(collection_name)
def create_collection(self, user, collection, dimension=384):
"""Create collection with default dimension"""
collection_name = make_safe_collection_name(user, collection, self.prefix)
if self.client.has_collection(collection_name):
logger.info(f"Collection {collection_name} already exists")
return
self.init_collection(dimension, user, collection)
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
def init_collection(self, dimension, user, collection):
collection_name = make_safe_collection_name(user, collection, self.prefix)
@ -128,14 +144,6 @@ class EntityVectors:
coll = self.collections[(dim, user, collection)]
search_params = {
"metric_type": "COSINE",
"params": {
"radius": 0.1,
"range_filter": 0.8
}
}
logger.debug("Loading...")
self.client.load_collection(
collection_name=coll,
@ -145,10 +153,11 @@ class EntityVectors:
res = self.client.search(
collection_name=coll,
anns_field="vector",
data=[embeds],
limit=limit,
output_fields=fields,
search_params=search_params,
search_params={ "metric_type": "COSINE" },
)[0]

View file

@ -60,7 +60,7 @@ class CollectionManager:
async def ensure_collection_exists(self, user: str, collection: str):
"""
Ensure a collection exists, creating it if necessary (lazy creation)
Ensure a collection exists, creating it if necessary with broadcast to storage
Args:
user: User ID
@ -74,7 +74,7 @@ class CollectionManager:
return
# Create new collection with default metadata
logger.info(f"Creating new collection {user}/{collection}")
logger.info(f"Auto-creating collection {user}/{collection} from document submission")
await self.table_store.create_collection(
user=user,
collection=collection,
@ -83,10 +83,64 @@ class CollectionManager:
tags=set()
)
# Broadcast collection creation to all storage backends
creation_key = (user, collection)
logger.info(f"Broadcasting create-collection for {creation_key}")
self.pending_deletions[creation_key] = {
"responses_pending": 4, # doc-embeddings, graph-embeddings, object, triples
"responses_received": [],
"all_successful": True,
"error_messages": [],
"deletion_complete": asyncio.Event()
}
storage_request = StorageManagementRequest(
operation="create-collection",
user=user,
collection=collection
)
# Send creation requests to all storage types
if self.vector_storage_producer:
await self.vector_storage_producer.send(storage_request)
if self.object_storage_producer:
await self.object_storage_producer.send(storage_request)
if self.triples_storage_producer:
await self.triples_storage_producer.send(storage_request)
# Wait for all storage creations to complete (with timeout)
creation_info = self.pending_deletions[creation_key]
try:
await asyncio.wait_for(
creation_info["deletion_complete"].wait(),
timeout=30.0 # 30 second timeout
)
except asyncio.TimeoutError:
logger.error(f"Timeout waiting for storage creation responses for {creation_key}")
creation_info["all_successful"] = False
creation_info["error_messages"].append("Timeout waiting for storage creation")
# Check if all creations succeeded
if not creation_info["all_successful"]:
error_msg = f"Storage creation failed: {'; '.join(creation_info['error_messages'])}"
logger.error(error_msg)
# Clean up metadata on failure
await self.table_store.delete_collection(user, collection)
# Clean up tracking
del self.pending_deletions[creation_key]
raise RuntimeError(error_msg)
# Clean up tracking
del self.pending_deletions[creation_key]
logger.info(f"Collection {creation_key} auto-created successfully in all storage backends")
except Exception as e:
logger.error(f"Error ensuring collection exists: {e}")
# Don't fail the operation if collection creation fails
# This maintains backward compatibility
raise e
async def list_collections(self, request: CollectionManagementRequest) -> CollectionManagementResponse:
"""
@ -154,6 +208,67 @@ class CollectionManager:
tags=tags
)
# Broadcast collection creation to all storage backends
creation_key = (request.user, request.collection)
logger.info(f"Broadcasting create-collection for {creation_key}")
self.pending_deletions[creation_key] = {
"responses_pending": 4, # doc-embeddings, graph-embeddings, object, triples
"responses_received": [],
"all_successful": True,
"error_messages": [],
"deletion_complete": asyncio.Event()
}
storage_request = StorageManagementRequest(
operation="create-collection",
user=request.user,
collection=request.collection
)
# Send creation requests to all storage types
if self.vector_storage_producer:
await self.vector_storage_producer.send(storage_request)
if self.object_storage_producer:
await self.object_storage_producer.send(storage_request)
if self.triples_storage_producer:
await self.triples_storage_producer.send(storage_request)
# Wait for all storage creations to complete (with timeout)
creation_info = self.pending_deletions[creation_key]
try:
await asyncio.wait_for(
creation_info["deletion_complete"].wait(),
timeout=30.0 # 30 second timeout
)
except asyncio.TimeoutError:
logger.error(f"Timeout waiting for storage creation responses for {creation_key}")
creation_info["all_successful"] = False
creation_info["error_messages"].append("Timeout waiting for storage creation")
# Check if all creations succeeded
if not creation_info["all_successful"]:
error_msg = f"Storage creation failed: {'; '.join(creation_info['error_messages'])}"
logger.error(error_msg)
# Clean up metadata on failure
await self.table_store.delete_collection(request.user, request.collection)
# Clean up tracking
del self.pending_deletions[creation_key]
return CollectionManagementResponse(
error=Error(
type="storage_creation_error",
message=error_msg
),
timestamp=datetime.now().isoformat()
)
# Clean up tracking
del self.pending_deletions[creation_key]
logger.info(f"Collection {creation_key} created successfully in all storage backends")
# Get the newly created collection for response
created_collection = await self.table_store.get_collection(request.user, request.collection)
@ -213,7 +328,7 @@ class CollectionManager:
# Track this deletion request
self.pending_deletions[deletion_key] = {
"responses_pending": 3, # vector, object, triples
"responses_pending": 4, # doc-embeddings, graph-embeddings, object, triples
"responses_received": [],
"all_successful": True,
"error_messages": [],
@ -303,9 +418,9 @@ class CollectionManager:
if response.error and response.error.message:
info["all_successful"] = False
info["error_messages"].append(response.error.message)
logger.warning(f"Storage deletion failed for {deletion_key}: {response.error.message}")
logger.warning(f"Storage operation failed for {deletion_key}: {response.error.message}")
else:
logger.debug(f"Storage deletion succeeded for {deletion_key}")
logger.debug(f"Storage operation succeeded for {deletion_key}")
# If all responses received, signal completion
if info["responses_pending"] == 0:

View file

@ -32,7 +32,7 @@ class Processor(LlmService):
token = params.get("token", default_token)
temperature = params.get("temperature", default_temperature)
max_output = params.get("max_output", default_max_output)
model = default_model
model = params.get("model", default_model)
if endpoint is None:
raise RuntimeError("Azure endpoint not specified")
@ -53,9 +53,11 @@ class Processor(LlmService):
self.token = token
self.temperature = temperature
self.max_output = max_output
self.model = model
self.default_model = model
def build_prompt(self, system, content):
def build_prompt(self, system, content, temperature=None):
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
data = {
"messages": [
@ -67,7 +69,7 @@ class Processor(LlmService):
}
],
"max_tokens": self.max_output,
"temperature": self.temperature,
"temperature": effective_temperature,
"top_p": 1
}
@ -100,13 +102,22 @@ class Processor(LlmService):
return result
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
try:
prompt = self.build_prompt(
system,
prompt
prompt,
effective_temperature
)
response = self.call_llm(prompt)
@ -125,7 +136,7 @@ class Processor(LlmService):
text = resp,
in_token = inputtokens,
out_token = outputtokens,
model = self.model
model = model_name
)
return resp

View file

@ -54,7 +54,7 @@ class Processor(LlmService):
self.temperature = temperature
self.max_output = max_output
self.model = model
self.default_model = model
self.openai = AzureOpenAI(
api_key=token,
@ -62,14 +62,22 @@ class Processor(LlmService):
azure_endpoint = endpoint,
)
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
prompt = system + "\n\n" + prompt
try:
resp = self.openai.chat.completions.create(
model=self.model,
model=model_name,
messages=[
{
"role": "user",
@ -81,7 +89,7 @@ class Processor(LlmService):
]
}
],
temperature=self.temperature,
temperature=effective_temperature,
max_tokens=self.max_output,
top_p=1,
)
@ -97,7 +105,7 @@ class Processor(LlmService):
text = resp.choices[0].message.content,
in_token = inputtokens,
out_token = outputtokens,
model = self.model
model = model_name
)
return r

View file

@ -41,21 +41,29 @@ class Processor(LlmService):
}
)
self.model = model
self.default_model = model
self.claude = anthropic.Anthropic(api_key=api_key)
self.temperature = temperature
self.max_output = max_output
logger.info("Claude LLM service initialized")
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
try:
response = message = self.claude.messages.create(
model=self.model,
model=model_name,
max_tokens=self.max_output,
temperature=self.temperature,
temperature=effective_temperature,
system = system,
messages=[
{
@ -81,7 +89,7 @@ class Processor(LlmService):
text = resp,
in_token = inputtokens,
out_token = outputtokens,
model = self.model
model = model_name
)
return resp

View file

@ -39,21 +39,29 @@ class Processor(LlmService):
}
)
self.model = model
self.default_model = model
self.temperature = temperature
self.cohere = cohere.Client(api_key=api_key)
logger.info("Cohere LLM service initialized")
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
try:
output = self.cohere.chat(
model=self.model,
output = self.cohere.chat(
model=model_name,
message=prompt,
preamble = system,
temperature=self.temperature,
temperature=effective_temperature,
chat_history=[],
prompt_truncation='auto',
connectors=[]
@ -71,7 +79,7 @@ class Processor(LlmService):
text = resp,
in_token = inputtokens,
out_token = outputtokens,
model = self.model
model = model_name
)
return resp

View file

@ -53,10 +53,13 @@ class Processor(LlmService):
)
self.client = genai.Client(api_key=api_key)
self.model = model
self.default_model = model
self.temperature = temperature
self.max_output = max_output
# Cache for generation configs per model
self.generation_configs = {}
block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
self.safety_settings = [
@ -83,22 +86,45 @@ class Processor(LlmService):
logger.info("GoogleAIStudio LLM service initialized")
async def generate_content(self, system, prompt):
def _get_or_create_config(self, model_name, temperature=None):
"""Get or create generation config with dynamic temperature"""
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
generation_config = types.GenerateContentConfig(
temperature = self.temperature,
top_p = 1,
top_k = 40,
max_output_tokens = self.max_output,
response_mime_type = "text/plain",
system_instruction = system,
safety_settings = self.safety_settings,
)
# Create cache key that includes temperature to avoid conflicts
cache_key = f"{model_name}:{effective_temperature}"
if cache_key not in self.generation_configs:
logger.info(f"Creating generation config for '{model_name}' with temperature {effective_temperature}")
self.generation_configs[cache_key] = types.GenerateContentConfig(
temperature = effective_temperature,
top_p = 1,
top_k = 40,
max_output_tokens = self.max_output,
response_mime_type = "text/plain",
safety_settings = self.safety_settings,
)
return self.generation_configs[cache_key]
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
generation_config = self._get_or_create_config(model_name, effective_temperature)
# Set system instruction per request (can't be cached)
generation_config.system_instruction = system
try:
response = self.client.models.generate_content(
model=self.model,
model=model_name,
config=generation_config,
contents=prompt,
)
@ -114,7 +140,7 @@ class Processor(LlmService):
text = resp,
in_token = inputtokens,
out_token = outputtokens,
model = self.model
model = model_name
)
return resp

View file

@ -39,7 +39,7 @@ class Processor(LlmService):
}
)
self.model = model
self.default_model = model
self.llamafile=llamafile
self.temperature = temperature
self.max_output = max_output
@ -50,25 +50,33 @@ class Processor(LlmService):
logger.info("Llamafile LLM service initialized")
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
prompt = system + "\n\n" + prompt
try:
resp = self.openai.chat.completions.create(
model=self.model,
model=model_name,
messages=[
{"role": "user", "content": prompt}
]
#temperature=self.temperature,
#max_tokens=self.max_output,
#top_p=1,
#frequency_penalty=0,
#presence_penalty=0,
#response_format={
# "type": "text"
#}
],
temperature=effective_temperature,
max_tokens=self.max_output,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
response_format={
"type": "text"
}
)
inputtokens = resp.usage.prompt_tokens
@ -82,7 +90,7 @@ class Processor(LlmService):
text = resp.choices[0].message.content,
in_token = inputtokens,
out_token = outputtokens,
model = "llama.cpp",
model = model_name,
)
return resp

View file

@ -39,7 +39,7 @@ class Processor(LlmService):
}
)
self.model = model
self.default_model = model
self.url = url + "v1/"
self.temperature = temperature
self.max_output = max_output
@ -50,7 +50,15 @@ class Processor(LlmService):
logger.info("LMStudio LLM service initialized")
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
prompt = system + "\n\n" + prompt
@ -59,18 +67,18 @@ class Processor(LlmService):
logger.debug(f"Prompt: {prompt}")
resp = self.openai.chat.completions.create(
model=self.model,
model=model_name,
messages=[
{"role": "user", "content": prompt}
]
#temperature=self.temperature,
#max_tokens=self.max_output,
#top_p=1,
#frequency_penalty=0,
#presence_penalty=0,
#response_format={
# "type": "text"
#}
],
temperature=effective_temperature,
max_tokens=self.max_output,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
response_format={
"type": "text"
}
)
logger.debug(f"Full response: {resp}")
@ -86,7 +94,7 @@ class Processor(LlmService):
text = resp.choices[0].message.content,
in_token = inputtokens,
out_token = outputtokens,
model = self.model
model = model_name
)
return resp

View file

@ -41,21 +41,29 @@ class Processor(LlmService):
}
)
self.model = model
self.default_model = model
self.temperature = temperature
self.max_output = max_output
self.mistral = Mistral(api_key=api_key)
logger.info("Mistral LLM service initialized")
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
prompt = system + "\n\n" + prompt
try:
resp = self.mistral.chat.complete(
model=self.model,
model=model_name,
messages=[
{
"role": "user",
@ -67,7 +75,7 @@ class Processor(LlmService):
]
}
],
temperature=self.temperature,
temperature=effective_temperature,
max_tokens=self.max_output,
top_p=1,
frequency_penalty=0,
@ -87,7 +95,7 @@ class Processor(LlmService):
text = resp.choices[0].message.content,
in_token = inputtokens,
out_token = outputtokens,
model = self.model
model = model_name
)
return resp

View file

@ -17,6 +17,7 @@ from .... base import LlmService, LlmResult
default_ident = "text-completion"
default_model = 'gemma2:9b'
default_temperature = 0.0
default_ollama = os.getenv("OLLAMA_HOST", 'http://localhost:11434')
class Processor(LlmService):
@ -24,25 +25,36 @@ class Processor(LlmService):
def __init__(self, **params):
model = params.get("model", default_model)
temperature = params.get("temperature", default_temperature)
ollama = params.get("ollama", default_ollama)
super(Processor, self).__init__(
**params | {
"model": model,
"temperature": temperature,
"ollama": ollama,
}
)
self.model = model
self.default_model = model
self.temperature = temperature
self.llm = Client(host=ollama)
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
prompt = system + "\n\n" + prompt
try:
response = self.llm.generate(self.model, prompt)
response = self.llm.generate(model_name, prompt, options={'temperature': effective_temperature})
response_text = response['response']
logger.debug("Sending response...")
@ -55,7 +67,7 @@ class Processor(LlmService):
text = response_text,
in_token = inputtokens,
out_token = outputtokens,
model = self.model
model = model_name
)
return resp
@ -84,6 +96,13 @@ class Processor(LlmService):
help=f'ollama (default: {default_ollama})'
)
parser.add_argument(
'-t', '--temperature',
type=float,
default=default_temperature,
help=f'LLM temperature parameter (default: {default_temperature})'
)
def run():
Processor.launch(default_ident, __doc__)

View file

@ -47,7 +47,7 @@ class Processor(LlmService):
}
)
self.model = model
self.default_model = model
self.temperature = temperature
self.max_output = max_output
@ -58,14 +58,22 @@ class Processor(LlmService):
logger.info("OpenAI LLM service initialized")
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
prompt = system + "\n\n" + prompt
try:
resp = self.openai.chat.completions.create(
model=self.model,
model=model_name,
messages=[
{
"role": "user",
@ -77,7 +85,7 @@ class Processor(LlmService):
]
}
],
temperature=self.temperature,
temperature=effective_temperature,
max_tokens=self.max_output,
top_p=1,
frequency_penalty=0,
@ -97,7 +105,7 @@ class Processor(LlmService):
text = resp.choices[0].message.content,
in_token = inputtokens,
out_token = outputtokens,
model = self.model
model = model_name
)
return resp

View file

@ -30,32 +30,43 @@ class Processor(LlmService):
base_url = params.get("url", default_base_url)
temperature = params.get("temperature", default_temperature)
max_output = params.get("max_output", default_max_output)
model = params.get("model", "tgi")
super(Processor, self).__init__(
**params | {
"temperature": temperature,
"max_output": max_output,
"url": base_url,
"model": model,
}
)
self.base_url = base_url
self.temperature = temperature
self.max_output = max_output
self.default_model = model
self.session = aiohttp.ClientSession()
logger.info(f"Using TGI service at {base_url}")
logger.info("TGI LLM service initialized")
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
headers = {
"Content-Type": "application/json",
}
request = {
"model": "tgi",
"model": model_name,
"messages": [
{
"role": "system",
@ -67,7 +78,7 @@ class Processor(LlmService):
}
],
"max_tokens": self.max_output,
"temperature": self.temperature,
"temperature": effective_temperature,
}
try:
@ -96,7 +107,7 @@ class Processor(LlmService):
text = ans,
in_token = inputtokens,
out_token = outputtokens,
model = "tgi",
model = model_name,
)
return resp

View file

@ -45,24 +45,32 @@ class Processor(LlmService):
self.base_url = base_url
self.temperature = temperature
self.max_output = max_output
self.model = model
self.default_model = model
self.session = aiohttp.ClientSession()
logger.info(f"Using vLLM service at {base_url}")
logger.info("vLLM LLM service initialized")
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
headers = {
"Content-Type": "application/json",
}
request = {
"model": self.model,
"model": model_name,
"prompt": system + "\n\n" + prompt,
"max_tokens": self.max_output,
"temperature": self.temperature,
"temperature": effective_temperature,
}
try:
@ -91,7 +99,7 @@ class Processor(LlmService):
text = ans,
in_token = inputtokens,
out_token = outputtokens,
model = self.model,
model = model_name,
)
return resp

View file

@ -57,21 +57,26 @@ class Processor(DocumentEmbeddingsQueryService):
raise e
self.last_collection = collection
def collection_exists(self, collection):
    """Check if collection exists (no implicit creation)"""
    # Pure existence probe against Qdrant; deliberately has no
    # create-on-miss side effect (creation is handled elsewhere).
    return self.qdrant.collection_exists(collection)
async def query_document_embeddings(self, msg):
try:
chunks = []
collection = (
"d_" + msg.user + "_" + msg.collection
)
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist, returning empty results")
return []
for vec in msg.vectors:
dim = len(vec)
collection = (
"d_" + msg.user + "_" + msg.collection
)
self.ensure_collection_exists(collection, dim)
search_result = self.qdrant.query_points(
collection_name=collection,
query=vec,

View file

@ -57,6 +57,10 @@ class Processor(GraphEmbeddingsQueryService):
raise e
self.last_collection = collection
def collection_exists(self, collection):
    """Check if collection exists (no implicit creation)"""
    # Pure existence probe against Qdrant; deliberately has no
    # create-on-miss side effect (creation is handled elsewhere).
    return self.qdrant.collection_exists(collection)
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
return Value(value=ent, is_uri=True)
@ -70,12 +74,16 @@ class Processor(GraphEmbeddingsQueryService):
entity_set = set()
entities = []
for vec in msg.vectors:
collection = (
"t_" + msg.user + "_" + msg.collection
)
dim = len(vec)
collection = (
"t_" + msg.user + "_" + msg.collection
)
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist, returning empty results")
return []
for vec in msg.vectors:
self.ensure_collection_exists(collection, dim)

View file

@ -1,12 +1,56 @@
import asyncio
import logging
import time
from collections import OrderedDict
# Module logger
logger = logging.getLogger(__name__)
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
class LRUCacheWithTTL:
    """LRU cache with TTL for label caching.

    CRITICAL SECURITY WARNING:
    This cache is shared within a GraphRag instance but GraphRag instances
    are created per-request. Cache keys MUST include user:collection prefix
    to ensure data isolation between different security contexts.
    """

    def __init__(self, max_size=5000, ttl=300):
        # Ordered mapping gives O(1) LRU bookkeeping via move_to_end().
        self.cache = OrderedDict()
        # key -> timestamp of last put(), used for TTL expiry.
        self.access_times = {}
        self.max_size = max_size
        self.ttl = ttl

    def get(self, key):
        """Return the cached value, or None if absent or expired."""
        if key not in self.cache:
            return None

        # Check TTL expiration. Use a monotonic clock so expiry is immune
        # to wall-clock adjustments (NTP steps, DST); time.time() can jump
        # backwards and keep stale entries alive, or forwards and expire
        # fresh entries prematurely.
        if time.monotonic() - self.access_times[key] > self.ttl:
            del self.cache[key]
            del self.access_times[key]
            return None

        # Move to end (most recently used)
        self.cache.move_to_end(key)
        return self.cache[key]

    def put(self, key, value):
        """Insert or refresh a value, evicting the LRU entry when full."""
        if key in self.cache:
            self.cache.move_to_end(key)
        elif len(self.cache) >= self.max_size:
            # Remove least recently used entry (front of the OrderedDict).
            oldest_key = next(iter(self.cache))
            del self.cache[oldest_key]
            del self.access_times[oldest_key]

        self.cache[key] = value
        self.access_times[key] = time.monotonic()
self.access_times[key] = time.time()
class Query:
def __init__(
@ -61,8 +105,14 @@ class Query:
async def maybe_label(self, e):
if e in self.rag.label_cache:
return self.rag.label_cache[e]
# CRITICAL SECURITY: Cache key MUST include user and collection
# to prevent data leakage between different contexts
cache_key = f"{self.user}:{self.collection}:{e}"
# Check LRU cache first with isolated key
cached_label = self.rag.label_cache.get(cache_key)
if cached_label is not None:
return cached_label
res = await self.rag.triples_client.query(
s=e, p=LABEL, o=None, limit=1,
@ -70,60 +120,104 @@ class Query:
)
if len(res) == 0:
self.rag.label_cache[e] = e
self.rag.label_cache.put(cache_key, e)
return e
self.rag.label_cache[e] = str(res[0].o)
return self.rag.label_cache[e]
label = str(res[0].o)
self.rag.label_cache.put(cache_key, label)
return label
async def execute_batch_triple_queries(self, entities, limit_per_entity):
    """Execute triple queries for multiple entities concurrently.

    For each entity, three queries are issued — entity as subject,
    as predicate, and as object — and all of them run in parallel
    via asyncio.gather. Results are flattened into one list; failed
    queries are silently dropped.
    """
    tasks = []
    for entity in entities:
        # Create concurrent tasks for all 3 query types per entity
        tasks.extend([
            self.rag.triples_client.query(
                s=entity, p=None, o=None,
                limit=limit_per_entity,
                user=self.user, collection=self.collection
            ),
            self.rag.triples_client.query(
                s=None, p=entity, o=None,
                limit=limit_per_entity,
                user=self.user, collection=self.collection
            ),
            self.rag.triples_client.query(
                s=None, p=None, o=entity,
                limit=limit_per_entity,
                user=self.user, collection=self.collection
            )
        ])

    # Execute all queries concurrently
    results = await asyncio.gather(*tasks, return_exceptions=True)

    # Combine all results
    all_triples = []
    for result in results:
        # return_exceptions=True means failures arrive as Exception
        # instances; skip them rather than aborting the whole batch.
        if not isinstance(result, Exception):
            all_triples.extend(result)

    return all_triples
async def follow_edges_batch(self, entities, max_depth):
    """Optimized iterative graph traversal with batching.

    Breadth-first expansion from the seed entities: one batched round
    of triple queries per depth level, bounded by self.max_subgraph_size.
    Returns a set of (s, p, o) string tuples.
    """
    visited = set()
    current_level = set(entities)
    subgraph = set()

    for depth in range(max_depth):
        # Stop when there is nothing left to expand or the cap is hit.
        if not current_level or len(subgraph) >= self.max_subgraph_size:
            break

        # Filter out already visited entities
        unvisited_entities = [e for e in current_level if e not in visited]
        if not unvisited_entities:
            break

        # Batch query all unvisited entities at current level
        triples = await self.execute_batch_triple_queries(
            unvisited_entities, self.triple_limit
        )

        # Process results and collect next level entities
        next_level = set()
        for triple in triples:
            triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
            subgraph.add(triple_tuple)

            # Collect entities for next level (only from s and o positions)
            if depth < max_depth - 1:  # Don't collect for final depth
                s, p, o = triple_tuple
                if s not in visited:
                    next_level.add(s)
                if o not in visited:
                    next_level.add(o)

            # Stop if subgraph size limit reached
            if len(subgraph) >= self.max_subgraph_size:
                return subgraph

        # Update for next iteration
        visited.update(current_level)
        current_level = next_level

    return subgraph
async def follow_edges(self, ent, subgraph, path_length):
# Not needed?
"""Legacy method - replaced by follow_edges_batch"""
# Maintain backward compatibility with early termination checks
if path_length <= 0:
return
# Stop spanning around if the subgraph is already maxed out
if len(subgraph) >= self.max_subgraph_size:
return
res = await self.rag.triples_client.query(
s=ent, p=None, o=None,
limit=self.triple_limit,
user=self.user, collection=self.collection,
)
for triple in res:
subgraph.add(
(str(triple.s), str(triple.p), str(triple.o))
)
if path_length > 1:
await self.follow_edges(str(triple.o), subgraph, path_length-1)
res = await self.rag.triples_client.query(
s=None, p=ent, o=None,
limit=self.triple_limit,
user=self.user, collection=self.collection,
)
for triple in res:
subgraph.add(
(str(triple.s), str(triple.p), str(triple.o))
)
res = await self.rag.triples_client.query(
s=None, p=None, o=ent,
limit=self.triple_limit,
user=self.user, collection=self.collection,
)
for triple in res:
subgraph.add(
(str(triple.s), str(triple.p), str(triple.o))
)
if path_length > 1:
await self.follow_edges(
str(triple.s), subgraph, path_length-1
)
# For backward compatibility, convert to new approach
batch_result = await self.follow_edges_batch([ent], path_length)
subgraph.update(batch_result)
async def get_subgraph(self, query):
@ -132,31 +226,52 @@ class Query:
if self.verbose:
logger.debug("Getting subgraph...")
subgraph = set()
# Use optimized batch traversal instead of sequential processing
subgraph = await self.follow_edges_batch(entities, self.max_path_length)
for ent in entities:
await self.follow_edges(ent, subgraph, self.max_path_length)
return list(subgraph)
subgraph = list(subgraph)
async def resolve_labels_batch(self, entities):
"""Resolve labels for multiple entities in parallel"""
tasks = []
for entity in entities:
tasks.append(self.maybe_label(entity))
return subgraph
return await asyncio.gather(*tasks, return_exceptions=True)
async def get_labelgraph(self, query):
subgraph = await self.get_subgraph(query)
# Filter out label triples
filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
# Collect all unique entities that need label resolution
entities_to_resolve = set()
for s, p, o in filtered_subgraph:
entities_to_resolve.update([s, p, o])
# Batch resolve labels for all entities in parallel
entity_list = list(entities_to_resolve)
resolved_labels = await self.resolve_labels_batch(entity_list)
# Create entity-to-label mapping
label_map = {}
for entity, label in zip(entity_list, resolved_labels):
if not isinstance(label, Exception):
label_map[entity] = label
else:
label_map[entity] = entity # Fallback to entity itself
# Apply labels to subgraph
sg2 = []
for edge in subgraph:
if edge[1] == LABEL:
continue
s = await self.maybe_label(edge[0])
p = await self.maybe_label(edge[1])
o = await self.maybe_label(edge[2])
sg2.append((s, p, o))
for s, p, o in filtered_subgraph:
labeled_triple = (
label_map.get(s, s),
label_map.get(p, p),
label_map.get(o, o)
)
sg2.append(labeled_triple)
sg2 = sg2[0:self.max_subgraph_size]
@ -171,6 +286,13 @@ class Query:
return sg2
class GraphRag:
"""
CRITICAL SECURITY:
This class MUST be instantiated per-request to ensure proper isolation
between users and collections. The cache within this instance will only
live for the duration of a single request, preventing cross-contamination
of data between different security contexts.
"""
def __init__(
self, prompt_client, embeddings_client, graph_embeddings_client,
@ -184,7 +306,9 @@ class GraphRag:
self.graph_embeddings_client = graph_embeddings_client
self.triples_client = triples_client
self.label_cache = {}
# Replace simple dict with LRU cache with TTL
# CRITICAL: This cache only lives for one request due to per-request instantiation
self.label_cache = LRUCacheWithTTL(max_size=5000, ttl=300)
if self.verbose:
logger.debug("GraphRag initialized")

View file

@ -45,6 +45,10 @@ class Processor(FlowProcessor):
self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length
# CRITICAL SECURITY: NEVER share data between users or collections
# Each user/collection combination MUST have isolated data access
# Caching must NEVER allow information leakage across these boundaries
self.register_specification(
ConsumerSpec(
name = "request",
@ -93,11 +97,14 @@ class Processor(FlowProcessor):
try:
self.rag = GraphRag(
embeddings_client = flow("embeddings-request"),
graph_embeddings_client = flow("graph-embeddings-request"),
triples_client = flow("triples-request"),
prompt_client = flow("prompt-request"),
# CRITICAL SECURITY: Create new GraphRag instance per request
# This ensures proper isolation between users and collections
# Flow clients are request-scoped and must not be shared
rag = GraphRag(
embeddings_client=flow("embeddings-request"),
graph_embeddings_client=flow("graph-embeddings-request"),
triples_client=flow("triples-request"),
prompt_client=flow("prompt-request"),
verbose=True,
)
@ -128,7 +135,7 @@ class Processor(FlowProcessor):
else:
max_path_length = self.default_max_path_length
response = await self.rag.query(
response = await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,

View file

@ -60,19 +60,34 @@ class Processor(DocumentEmbeddingsStoreService):
metrics=storage_response_metrics,
)
async def start(self):
    """Start the processor and its storage management consumer"""
    # Bring up the base processor first so its consumers/producers exist,
    # then start the storage-management request consumer and the response
    # producer used to acknowledge create/delete operations.
    await super().start()
    await self.storage_request_consumer.start()
    await self.storage_response_producer.start()
async def store_document_embeddings(self, message):
# Validate collection exists before accepting writes
if not self.vecstore.collection_exists(message.metadata.user, message.metadata.collection):
error_msg = (
f"Collection {message.metadata.collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
for emb in message.chunks:
if emb.chunk is None or emb.chunk == b"": continue
chunk = emb.chunk.decode("utf-8")
if chunk == "": continue
for vec in emb.vectors:
self.vecstore.insert(
vec, chunk,
message.metadata.user,
vec, chunk,
message.metadata.user,
message.metadata.collection
)
@ -87,18 +102,21 @@ class Processor(DocumentEmbeddingsStoreService):
help=f'Milvus store URI (default: {default_store_uri})'
)
async def on_storage_management(self, message):
async def on_storage_management(self, message, consumer, flow):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
request = message.value()
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
if request.operation == "create-collection":
await self.handle_create_collection(request)
elif request.operation == "delete-collection":
await self.handle_delete_collection(request)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@ -113,17 +131,40 @@ class Processor(DocumentEmbeddingsStoreService):
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
async def handle_create_collection(self, request):
    """Create a Milvus collection for document embeddings.

    Idempotent: an already-existing collection is treated as success.
    Always sends a StorageManagementResponse on the response producer
    (error=None on success, type="creation_error" on failure) rather
    than propagating the exception to the caller.
    """
    try:
        if self.vecstore.collection_exists(request.user, request.collection):
            # Idempotent create: existing collection is not an error.
            logger.info(f"Collection {request.user}/{request.collection} already exists")
        else:
            self.vecstore.create_collection(request.user, request.collection)
            logger.info(f"Created collection {request.user}/{request.collection}")

        # Send success response
        response = StorageManagementResponse(error=None)
        await self.storage_response_producer.send(response)

    except Exception as e:
        logger.error(f"Failed to create collection: {e}", exc_info=True)
        response = StorageManagementResponse(
            error=Error(
                type="creation_error",
                message=str(e)
            )
        )
        await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete the collection for document embeddings"""
try:
self.vecstore.delete_collection(message.user, message.collection)
self.vecstore.delete_collection(request.user, request.collection)
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View file

@ -115,38 +115,36 @@ class Processor(DocumentEmbeddingsStoreService):
"Gave up waiting for index creation"
)
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
await self.storage_request_consumer.start()
await self.storage_response_producer.start()
async def store_document_embeddings(self, message):
index_name = (
"d-" + message.metadata.user + "-" + message.metadata.collection
)
# Validate collection exists before accepting writes
if not self.pinecone.has_index(index_name):
error_msg = (
f"Collection {message.metadata.collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
for emb in message.chunks:
if emb.chunk is None or emb.chunk == b"": continue
chunk = emb.chunk.decode("utf-8")
if chunk == "": continue
for vec in emb.vectors:
dim = len(vec)
index_name = (
"d-" + message.metadata.user + "-" + message.metadata.collection
)
if index_name != self.last_index_name:
if not self.pinecone.has_index(index_name):
try:
self.create_index(index_name, dim)
except Exception as e:
logger.error("Pinecone index creation failed")
raise e
logger.info(f"Index {index_name} created")
self.last_index_name = index_name
index = self.pinecone.Index(index_name)
# Generate unique ID for each vector
@ -192,18 +190,21 @@ class Processor(DocumentEmbeddingsStoreService):
help=f'Pinecone region, (default: {default_region}'
)
async def on_storage_management(self, message):
async def on_storage_management(self, message, consumer, flow):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
request = message.value()
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
if request.operation == "create-collection":
await self.handle_create_collection(request)
elif request.operation == "delete-collection":
await self.handle_delete_collection(request)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@ -218,10 +219,36 @@ class Processor(DocumentEmbeddingsStoreService):
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
async def handle_create_collection(self, request):
"""Create a Pinecone index for document embeddings"""
try:
index_name = f"d-{request.user}-{request.collection}"
if self.pinecone.has_index(index_name):
logger.info(f"Pinecone index {index_name} already exists")
else:
# Create with default dimension - will need to be recreated if dimension doesn't match
self.create_index(index_name, dim=384)
logger.info(f"Created Pinecone index: {index_name}")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete the collection for document embeddings"""
try:
index_name = f"d-{message.user}-{message.collection}"
index_name = f"d-{request.user}-{request.collection}"
if self.pinecone.has_index(index_name):
self.pinecone.delete_index(index_name)
@ -234,7 +261,7 @@ class Processor(DocumentEmbeddingsStoreService):
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View file

@ -36,8 +36,6 @@ class Processor(DocumentEmbeddingsStoreService):
}
)
self.last_collection = None
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
# Set up storage management if base class attributes are available
@ -71,8 +69,30 @@ class Processor(DocumentEmbeddingsStoreService):
metrics=storage_response_metrics,
)
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
if hasattr(self, 'storage_request_consumer'):
await self.storage_request_consumer.start()
if hasattr(self, 'storage_response_producer'):
await self.storage_response_producer.start()
async def store_document_embeddings(self, message):
# Validate collection exists before accepting writes
collection = (
"d_" + message.metadata.user + "_" +
message.metadata.collection
)
if not self.qdrant.collection_exists(collection):
error_msg = (
f"Collection {message.metadata.collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
for emb in message.chunks:
chunk = emb.chunk.decode("utf-8")
@ -80,29 +100,6 @@ class Processor(DocumentEmbeddingsStoreService):
for vec in emb.vectors:
dim = len(vec)
collection = (
"d_" + message.metadata.user + "_" +
message.metadata.collection
)
if collection != self.last_collection:
if not self.qdrant.collection_exists(collection):
try:
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
)
except Exception as e:
logger.error("Qdrant collection creation failed")
raise e
self.last_collection = collection
self.qdrant.upsert(
collection_name=collection,
points=[
@ -133,18 +130,21 @@ class Processor(DocumentEmbeddingsStoreService):
help=f'Qdrant API key (default: None)'
)
async def on_storage_management(self, message):
async def on_storage_management(self, message, consumer, flow):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
request = message.value()
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
if request.operation == "create-collection":
await self.handle_create_collection(request)
elif request.operation == "delete-collection":
await self.handle_delete_collection(request)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@ -159,10 +159,43 @@ class Processor(DocumentEmbeddingsStoreService):
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
async def handle_create_collection(self, request):
"""Create a Qdrant collection for document embeddings"""
try:
collection_name = f"d_{request.user}_{request.collection}"
if self.qdrant.collection_exists(collection_name):
logger.info(f"Qdrant collection {collection_name} already exists")
else:
# Create collection with default dimension (will be recreated with correct dim on first write if needed)
# Using a placeholder dimension - actual dimension determined by first embedding
self.qdrant.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(
size=384, # Default dimension, common for many models
distance=Distance.COSINE
)
)
logger.info(f"Created Qdrant collection: {collection_name}")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete the collection for document embeddings"""
try:
collection_name = f"d_{message.user}_{message.collection}"
collection_name = f"d_{request.user}_{request.collection}"
if self.qdrant.collection_exists(collection_name):
self.qdrant.delete_collection(collection_name)
@ -175,7 +208,7 @@ class Processor(DocumentEmbeddingsStoreService):
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View file

@ -60,8 +60,23 @@ class Processor(GraphEmbeddingsStoreService):
metrics=storage_response_metrics,
)
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
await self.storage_request_consumer.start()
await self.storage_response_producer.start()
async def store_graph_embeddings(self, message):
# Validate collection exists before accepting writes
if not self.vecstore.collection_exists(message.metadata.user, message.metadata.collection):
error_msg = (
f"Collection {message.metadata.collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
for entity in message.entities:
if entity.entity.value != "" and entity.entity.value is not None:
@ -83,18 +98,21 @@ class Processor(GraphEmbeddingsStoreService):
help=f'Milvus store URI (default: {default_store_uri})'
)
async def on_storage_management(self, message):
async def on_storage_management(self, message, consumer, flow):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
request = message.value()
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
if request.operation == "create-collection":
await self.handle_create_collection(request)
elif request.operation == "delete-collection":
await self.handle_delete_collection(request)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@ -109,17 +127,40 @@ class Processor(GraphEmbeddingsStoreService):
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
async def handle_create_collection(self, request):
"""Create a Milvus collection for graph embeddings"""
try:
if self.vecstore.collection_exists(request.user, request.collection):
logger.info(f"Collection {request.user}/{request.collection} already exists")
else:
self.vecstore.create_collection(request.user, request.collection)
logger.info(f"Created collection {request.user}/{request.collection}")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete the collection for graph embeddings"""
try:
self.vecstore.delete_collection(message.user, message.collection)
self.vecstore.delete_collection(request.user, request.collection)
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View file

@ -115,8 +115,27 @@ class Processor(GraphEmbeddingsStoreService):
"Gave up waiting for index creation"
)
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
await self.storage_request_consumer.start()
await self.storage_response_producer.start()
async def store_graph_embeddings(self, message):
index_name = (
"t-" + message.metadata.user + "-" + message.metadata.collection
)
# Validate collection exists before accepting writes
if not self.pinecone.has_index(index_name):
error_msg = (
f"Collection {message.metadata.collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
for entity in message.entities:
if entity.entity.value == "" or entity.entity.value is None:
@ -124,28 +143,6 @@ class Processor(GraphEmbeddingsStoreService):
for vec in entity.vectors:
dim = len(vec)
index_name = (
"t-" + message.metadata.user + "-" + message.metadata.collection
)
if index_name != self.last_index_name:
if not self.pinecone.has_index(index_name):
try:
self.create_index(index_name, dim)
except Exception as e:
logger.error("Pinecone index creation failed")
raise e
logger.info(f"Index {index_name} created")
self.last_index_name = index_name
index = self.pinecone.Index(index_name)
# Generate unique ID for each vector
@ -191,18 +188,21 @@ class Processor(GraphEmbeddingsStoreService):
help=f'Pinecone region, (default: {default_region}'
)
async def on_storage_management(self, message):
async def on_storage_management(self, message, consumer, flow):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
request = message.value()
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
if request.operation == "create-collection":
await self.handle_create_collection(request)
elif request.operation == "delete-collection":
await self.handle_delete_collection(request)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@ -217,10 +217,36 @@ class Processor(GraphEmbeddingsStoreService):
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
async def handle_create_collection(self, request):
"""Create a Pinecone index for graph embeddings"""
try:
index_name = f"t-{request.user}-{request.collection}"
if self.pinecone.has_index(index_name):
logger.info(f"Pinecone index {index_name} already exists")
else:
# Create with default dimension - will need to be recreated if dimension doesn't match
self.create_index(index_name, dim=384)
logger.info(f"Created Pinecone index: {index_name}")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete the collection for graph embeddings"""
try:
index_name = f"t-{message.user}-{message.collection}"
index_name = f"t-{request.user}-{request.collection}"
if self.pinecone.has_index(index_name):
self.pinecone.delete_index(index_name)
@ -233,7 +259,7 @@ class Processor(GraphEmbeddingsStoreService):
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View file

@ -36,8 +36,6 @@ class Processor(GraphEmbeddingsStoreService):
}
)
self.last_collection = None
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
# Set up storage management if base class attributes are available
@ -71,31 +69,30 @@ class Processor(GraphEmbeddingsStoreService):
metrics=storage_response_metrics,
)
def get_collection(self, dim, user, collection):
def get_collection(self, user, collection):
"""Get collection name and validate it exists"""
cname = (
"t_" + user + "_" + collection
)
if cname != self.last_collection:
if not self.qdrant.collection_exists(cname):
try:
self.qdrant.create_collection(
collection_name=cname,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
)
except Exception as e:
logger.error("Qdrant collection creation failed")
raise e
self.last_collection = cname
if not self.qdrant.collection_exists(cname):
error_msg = (
f"Collection {collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
return cname
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
if hasattr(self, 'storage_request_consumer'):
await self.storage_request_consumer.start()
if hasattr(self, 'storage_response_producer'):
await self.storage_response_producer.start()
async def store_graph_embeddings(self, message):
for entity in message.entities:
@ -104,10 +101,8 @@ class Processor(GraphEmbeddingsStoreService):
for vec in entity.vectors:
dim = len(vec)
collection = self.get_collection(
dim, message.metadata.user, message.metadata.collection
message.metadata.user, message.metadata.collection
)
self.qdrant.upsert(
@ -140,18 +135,21 @@ class Processor(GraphEmbeddingsStoreService):
help=f'Qdrant API key'
)
async def on_storage_management(self, message):
async def on_storage_management(self, message, consumer, flow):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
request = message.value()
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
if request.operation == "create-collection":
await self.handle_create_collection(request)
elif request.operation == "delete-collection":
await self.handle_delete_collection(request)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@ -166,10 +164,43 @@ class Processor(GraphEmbeddingsStoreService):
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
async def handle_create_collection(self, request):
"""Create a Qdrant collection for graph embeddings"""
try:
collection_name = f"t_{request.user}_{request.collection}"
if self.qdrant.collection_exists(collection_name):
logger.info(f"Qdrant collection {collection_name} already exists")
else:
# Create collection with default dimension (will be recreated with correct dim on first write if needed)
# Using a placeholder dimension - actual dimension determined by first embedding
self.qdrant.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(
size=384, # Default dimension, common for many models
distance=Distance.COSINE
)
)
logger.info(f"Created Qdrant collection: {collection_name}")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete the collection for graph embeddings"""
try:
collection_name = f"t_{message.user}_{message.collection}"
collection_name = f"t_{request.user}_{request.collection}"
if self.qdrant.collection_exists(collection_name):
self.qdrant.delete_collection(collection_name)
@ -182,7 +213,7 @@ class Processor(GraphEmbeddingsStoreService):
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View file

@ -295,6 +295,8 @@ class Processor(FlowProcessor):
try:
self.session.execute(create_table_cql)
if keyspace not in self.known_tables:
self.known_tables[keyspace] = set()
self.known_tables[keyspace].add(table_key)
logger.info(f"Ensured table exists: {safe_keyspace}.{safe_table}")
@ -340,18 +342,47 @@ class Processor(FlowProcessor):
logger.warning(f"Failed to convert value {value} to type {field_type}: {e}")
return str(value)
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
await self.storage_request_consumer.start()
await self.storage_response_producer.start()
async def on_object(self, msg, consumer, flow):
"""Process incoming ExtractedObject and store in Cassandra"""
obj = msg.value()
logger.info(f"Storing {len(obj.values)} objects for schema {obj.schema_name} from {obj.metadata.id}")
# Validate collection/keyspace exists before accepting writes
safe_keyspace = self.sanitize_name(obj.metadata.user)
if safe_keyspace not in self.known_keyspaces:
# Check if keyspace actually exists in Cassandra
self.connect_cassandra()
check_keyspace_cql = """
SELECT keyspace_name FROM system_schema.keyspaces
WHERE keyspace_name = %s
"""
result = self.session.execute(check_keyspace_cql, (safe_keyspace,))
# Check if result is None (mock case) or has no rows
if result is None or not result.one():
error_msg = (
f"Collection {obj.metadata.collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
# Cache it if it exists
self.known_keyspaces.add(safe_keyspace)
if safe_keyspace not in self.known_tables:
self.known_tables[safe_keyspace] = set()
# Get schema definition
schema = self.schemas.get(obj.schema_name)
if not schema:
logger.warning(f"No schema found for {obj.schema_name} - skipping")
return
# Ensure table exists
keyspace = obj.metadata.user
table_name = obj.schema_name
@ -425,26 +456,36 @@ class Processor(FlowProcessor):
async def on_storage_management(self, msg, consumer, flow):
"""Handle storage management requests for collection operations"""
logger.info(f"Received storage management request: {msg.operation} for {msg.user}/{msg.collection}")
request = msg.value()
logger.info(f"Received storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if msg.operation == "delete-collection":
await self.delete_collection(msg.user, msg.collection)
if request.operation == "create-collection":
await self.create_collection(request.user, request.collection)
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {msg.user}/{msg.collection}")
logger.info(f"Successfully created collection {request.user}/{request.collection}")
elif request.operation == "delete-collection":
await self.delete_collection(request.user, request.collection)
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
else:
logger.warning(f"Unknown storage management operation: {msg.operation}")
logger.warning(f"Unknown storage management operation: {request.operation}")
# Send error response
from .... schema import Error
response = StorageManagementResponse(
error=Error(
type="unknown_operation",
message=f"Unknown operation: {msg.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@ -459,10 +500,28 @@ class Processor(FlowProcessor):
message=str(e)
)
)
await self.send("storage-response", response)
await self.storage_response_producer.send(response)
async def create_collection(self, user: str, collection: str):
"""Create/verify collection exists in Cassandra object store"""
# Connect if not already connected
self.connect_cassandra()
# Sanitize names for safety
safe_keyspace = self.sanitize_name(user)
# Ensure keyspace exists
if safe_keyspace not in self.known_keyspaces:
self.ensure_keyspace(safe_keyspace)
self.known_keyspaces.add(safe_keyspace)
# For Cassandra objects, collection is just a property in rows
# No need to create separate tables per collection
# Just mark that we've seen this collection
logger.info(f"Collection {collection} ready for user {user} (using keyspace {safe_keyspace})")
async def delete_collection(self, user: str, collection: str):
"""Delete all data for a specific collection"""
"""Delete all data for a specific collection using schema information"""
# Connect if not already connected
self.connect_cassandra()
@ -482,40 +541,78 @@ class Processor(FlowProcessor):
return
self.known_keyspaces.add(safe_keyspace)
# Get all tables in the keyspace that might contain collection data
get_tables_cql = """
SELECT table_name FROM system_schema.tables
WHERE keyspace_name = %s
"""
tables = self.session.execute(get_tables_cql, (safe_keyspace,))
# Iterate over schemas we manage to delete from relevant tables
tables_deleted = 0
for row in tables:
table_name = row.table_name
for schema_name, schema in self.schemas.items():
safe_table = self.sanitize_table(schema_name)
# Check if the table has a collection column
check_column_cql = """
SELECT column_name FROM system_schema.columns
WHERE keyspace_name = %s AND table_name = %s AND column_name = 'collection'
"""
# Check if table exists
table_key = f"{user}.{schema_name}"
if table_key not in self.known_tables.get(user, set()):
logger.debug(f"Table {safe_keyspace}.{safe_table} not in known tables, skipping")
continue
result = self.session.execute(check_column_cql, (safe_keyspace, table_name))
if result.one():
# Table has collection column, delete data for this collection
try:
delete_cql = f"""
DELETE FROM {safe_keyspace}.{table_name}
try:
# Get primary key fields from schema
primary_key_fields = [field for field in schema.fields if field.primary]
if primary_key_fields:
# Schema has primary keys: need to query for partition keys first
# Build SELECT query for primary key fields
pk_field_names = [self.sanitize_name(field.name) for field in primary_key_fields]
select_cql = f"""
SELECT {', '.join(pk_field_names)}
FROM {safe_keyspace}.{safe_table}
WHERE collection = %s
ALLOW FILTERING
"""
self.session.execute(delete_cql, (collection,))
tables_deleted += 1
logger.info(f"Deleted collection {collection} from table {safe_keyspace}.{table_name}")
except Exception as e:
logger.error(f"Failed to delete from table {safe_keyspace}.{table_name}: {e}")
raise
logger.info(f"Deleted collection {collection} from {tables_deleted} tables in keyspace {safe_keyspace}")
rows = self.session.execute(select_cql, (collection,))
# Delete each row using full partition key
for row in rows:
where_clauses = ["collection = %s"]
values = [collection]
for field_name in pk_field_names:
where_clauses.append(f"{field_name} = %s")
values.append(getattr(row, field_name))
delete_cql = f"""
DELETE FROM {safe_keyspace}.{safe_table}
WHERE {' AND '.join(where_clauses)}
"""
self.session.execute(delete_cql, tuple(values))
else:
# No primary keys, uses synthetic_id
# Need to query for synthetic_ids first
select_cql = f"""
SELECT synthetic_id
FROM {safe_keyspace}.{safe_table}
WHERE collection = %s
ALLOW FILTERING
"""
rows = self.session.execute(select_cql, (collection,))
# Delete each row using collection and synthetic_id
for row in rows:
delete_cql = f"""
DELETE FROM {safe_keyspace}.{safe_table}
WHERE collection = %s AND synthetic_id = %s
"""
self.session.execute(delete_cql, (collection, row.synthetic_id))
tables_deleted += 1
logger.info(f"Deleted collection {collection} from table {safe_keyspace}.{safe_table}")
except Exception as e:
logger.error(f"Failed to delete from table {safe_keyspace}.{safe_table}: {e}")
raise
logger.info(f"Deleted collection {collection} from {tables_deleted} schema-based tables in keyspace {safe_keyspace}")
def close(self):
"""Clean up Cassandra connections"""

View file

@ -109,6 +109,15 @@ class Processor(TriplesStoreService):
self.table = user
# Validate collection exists before accepting writes
if not self.tg.collection_exists(message.metadata.collection):
error_msg = (
f"Collection {message.metadata.collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
for t in message.triples:
self.tg.insert(
message.metadata.collection,
@ -117,18 +126,27 @@ class Processor(TriplesStoreService):
t.o.value
)
async def on_storage_management(self, message):
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
await self.storage_request_consumer.start()
await self.storage_response_producer.start()
async def on_storage_management(self, message, consumer, flow):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
request = message.value()
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
if request.operation == "create-collection":
await self.handle_create_collection(request)
elif request.operation == "delete-collection":
await self.handle_delete_collection(request)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@ -143,42 +161,85 @@ class Processor(TriplesStoreService):
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete all data for a specific collection from the unified triples table"""
async def handle_create_collection(self, request):
"""Create a collection in Cassandra triple store"""
try:
# Create or reuse connection for this user's keyspace
if self.table is None or self.table != message.user:
if self.table is None or self.table != request.user:
self.tg = None
try:
if self.cassandra_username and self.cassandra_password:
self.tg = KnowledgeGraph(
hosts=self.cassandra_host,
keyspace=message.user,
keyspace=request.user,
username=self.cassandra_username,
password=self.cassandra_password
)
else:
self.tg = KnowledgeGraph(
hosts=self.cassandra_host,
keyspace=message.user,
keyspace=request.user,
)
except Exception as e:
logger.error(f"Failed to connect to Cassandra for user {message.user}: {e}")
logger.error(f"Failed to connect to Cassandra for user {request.user}: {e}")
raise
self.table = message.user
self.table = request.user
# Delete all triples for this collection from the unified table
# In the unified table schema, collection is the partition key
delete_cql = """
DELETE FROM triples
WHERE collection = ?
"""
# Create collection using the built-in method
logger.info(f"Creating collection {request.collection} for user {request.user}")
if self.tg.collection_exists(request.collection):
logger.info(f"Collection {request.collection} already exists")
else:
self.tg.create_collection(request.collection)
logger.info(f"Created collection {request.collection}")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete all data for a specific collection from the unified triples table"""
try:
# Create or reuse connection for this user's keyspace
if self.table is None or self.table != request.user:
self.tg = None
try:
if self.cassandra_username and self.cassandra_password:
self.tg = KnowledgeGraph(
hosts=self.cassandra_host,
keyspace=request.user,
username=self.cassandra_username,
password=self.cassandra_password
)
else:
self.tg = KnowledgeGraph(
hosts=self.cassandra_host,
keyspace=request.user,
)
except Exception as e:
logger.error(f"Failed to connect to Cassandra for user {request.user}: {e}")
raise
self.table = request.user
# Delete all triples for this collection using the built-in method
try:
self.tg.session.execute(delete_cql, (message.collection,))
logger.info(f"Deleted all triples for collection {message.collection} from keyspace {message.user}")
self.tg.delete_collection(request.collection)
logger.info(f"Deleted all triples for collection {request.collection} from keyspace {request.user}")
except Exception as e:
logger.error(f"Failed to delete collection data: {e}")
raise
@ -188,7 +249,7 @@ class Processor(TriplesStoreService):
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View file

@ -152,11 +152,43 @@ class Processor(TriplesStoreService):
time=res.run_time_ms
))
def collection_exists(self, user, collection):
"""Check if collection metadata node exists"""
result = self.io.query(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
"RETURN c LIMIT 1",
params={"user": user, "collection": collection}
)
return result.result_set is not None and len(result.result_set) > 0
def create_collection(self, user, collection):
"""Create collection metadata node"""
import datetime
self.io.query(
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) "
"SET c.created_at = $created_at",
params={
"user": user,
"collection": collection,
"created_at": datetime.datetime.now().isoformat()
}
)
logger.info(f"Created collection metadata node for {user}/{collection}")
async def store_triples(self, message):
# Extract user and collection from metadata
user = message.metadata.user if message.metadata.user else "default"
collection = message.metadata.collection if message.metadata.collection else "default"
# Validate collection exists before accepting writes
if not self.collection_exists(user, collection):
error_msg = (
f"Collection {collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
for t in message.triples:
self.create_node(t.s.value, user, collection)
@ -185,18 +217,27 @@ class Processor(TriplesStoreService):
help=f'FalkorDB database (default: {default_database})'
)
async def on_storage_management(self, message):
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
await self.storage_request_consumer.start()
await self.storage_response_producer.start()
async def on_storage_management(self, message, consumer, flow):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
request = message.value()
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
if request.operation == "create-collection":
await self.handle_create_collection(request)
elif request.operation == "delete-collection":
await self.handle_delete_collection(request)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@ -211,28 +252,57 @@ class Processor(TriplesStoreService):
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
async def handle_create_collection(self, request):
"""Create collection metadata in FalkorDB"""
try:
if self.collection_exists(request.user, request.collection):
logger.info(f"Collection {request.user}/{request.collection} already exists")
else:
self.create_collection(request.user, request.collection)
logger.info(f"Created collection {request.user}/{request.collection}")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete the collection for FalkorDB triples"""
try:
# Delete all nodes and literals for this user/collection
node_result = self.io.query(
"MATCH (n:Node {user: $user, collection: $collection}) DETACH DELETE n",
params={"user": message.user, "collection": message.collection}
params={"user": request.user, "collection": request.collection}
)
literal_result = self.io.query(
"MATCH (n:Literal {user: $user, collection: $collection}) DETACH DELETE n",
params={"user": message.user, "collection": message.collection}
params={"user": request.user, "collection": request.collection}
)
logger.info(f"Deleted {node_result.nodes_deleted} nodes and {literal_result.nodes_deleted} literals for collection {message.user}/{message.collection}")
# Delete collection metadata node
metadata_result = self.io.query(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) DELETE c",
params={"user": request.user, "collection": request.collection}
)
logger.info(f"Deleted {node_result.nodes_deleted} nodes, {literal_result.nodes_deleted} literals, and {metadata_result.nodes_deleted} metadata nodes for collection {request.user}/{request.collection}")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View file

@ -267,12 +267,43 @@ class Processor(TriplesStoreService):
src=t.s.value, dest=t.o.value, uri=t.p.value, user=user, collection=collection,
)
def collection_exists(self, user, collection):
"""Check if collection metadata node exists"""
with self.io.session(database=self.db) as session:
result = session.run(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
"RETURN c LIMIT 1",
user=user, collection=collection
)
return bool(list(result))
def create_collection(self, user, collection):
"""Create collection metadata node"""
import datetime
with self.io.session(database=self.db) as session:
session.run(
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) "
"SET c.created_at = $created_at",
user=user, collection=collection,
created_at=datetime.datetime.now().isoformat()
)
logger.info(f"Created collection metadata node for {user}/{collection}")
async def store_triples(self, message):
# Extract user and collection from metadata
user = message.metadata.user if message.metadata.user else "default"
collection = message.metadata.collection if message.metadata.collection else "default"
# Validate collection exists before accepting writes
if not self.collection_exists(user, collection):
error_msg = (
f"Collection {collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
for t in message.triples:
self.create_node(t.s.value, user, collection)
@ -317,18 +348,27 @@ class Processor(TriplesStoreService):
help=f'Memgraph database (default: {default_database})'
)
async def on_storage_management(self, message):
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
await self.storage_request_consumer.start()
await self.storage_response_producer.start()
async def on_storage_management(self, message, consumer, flow):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
request = message.value()
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
if request.operation == "create-collection":
await self.handle_create_collection(request)
elif request.operation == "delete-collection":
await self.handle_delete_collection(request)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@ -343,7 +383,30 @@ class Processor(TriplesStoreService):
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
async def handle_create_collection(self, request):
"""Create collection metadata in Memgraph"""
try:
if self.collection_exists(request.user, request.collection):
logger.info(f"Collection {request.user}/{request.collection} already exists")
else:
self.create_collection(request.user, request.collection)
logger.info(f"Created collection {request.user}/{request.collection}")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete all data for a specific collection"""
try:
with self.io.session(database=self.db) as session:
@ -351,7 +414,7 @@ class Processor(TriplesStoreService):
node_result = session.run(
"MATCH (n:Node {user: $user, collection: $collection}) "
"DETACH DELETE n",
user=message.user, collection=message.collection
user=request.user, collection=request.collection
)
nodes_deleted = node_result.consume().counters.nodes_deleted
@ -359,20 +422,28 @@ class Processor(TriplesStoreService):
literal_result = session.run(
"MATCH (n:Literal {user: $user, collection: $collection}) "
"DETACH DELETE n",
user=message.user, collection=message.collection
user=request.user, collection=request.collection
)
literals_deleted = literal_result.consume().counters.nodes_deleted
# Delete collection metadata node
metadata_result = session.run(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
"DELETE c",
user=request.user, collection=request.collection
)
metadata_deleted = metadata_result.consume().counters.nodes_deleted
# Note: Relationships are automatically deleted with DETACH DELETE
logger.info(f"Deleted {nodes_deleted} nodes and {literals_deleted} literals for {message.user}/{message.collection}")
logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {request.user}/{request.collection}")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View file

@ -228,6 +228,15 @@ class Processor(TriplesStoreService):
user = message.metadata.user if message.metadata.user else "default"
collection = message.metadata.collection if message.metadata.collection else "default"
# Validate collection exists before accepting writes
if not self.collection_exists(user, collection):
error_msg = (
f"Collection {collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
for t in message.triples:
self.create_node(t.s.value, user, collection)
@ -268,18 +277,27 @@ class Processor(TriplesStoreService):
help=f'Neo4j database (default: {default_database})'
)
async def on_storage_management(self, message):
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
await self.storage_request_consumer.start()
await self.storage_response_producer.start()
async def on_storage_management(self, message, consumer, flow):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
request = message.value()
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
if request.operation == "create-collection":
await self.handle_create_collection(request)
elif request.operation == "delete-collection":
await self.handle_delete_collection(request)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@ -294,7 +312,52 @@ class Processor(TriplesStoreService):
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
def collection_exists(self, user, collection):
"""Check if collection metadata node exists"""
with self.io.session(database=self.db) as session:
result = session.run(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
"RETURN c LIMIT 1",
user=user, collection=collection
)
return bool(list(result))
def create_collection(self, user, collection):
"""Create collection metadata node"""
import datetime
with self.io.session(database=self.db) as session:
session.run(
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) "
"SET c.created_at = $created_at",
user=user, collection=collection,
created_at=datetime.datetime.now().isoformat()
)
logger.info(f"Created collection metadata node for {user}/{collection}")
async def handle_create_collection(self, request):
"""Create collection metadata in Neo4j"""
try:
if self.collection_exists(request.user, request.collection):
logger.info(f"Collection {request.user}/{request.collection} already exists")
else:
self.create_collection(request.user, request.collection)
logger.info(f"Created collection {request.user}/{request.collection}")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete all data for a specific collection"""
try:
with self.io.session(database=self.db) as session:
@ -302,7 +365,7 @@ class Processor(TriplesStoreService):
node_result = session.run(
"MATCH (n:Node {user: $user, collection: $collection}) "
"DETACH DELETE n",
user=message.user, collection=message.collection
user=request.user, collection=request.collection
)
nodes_deleted = node_result.consume().counters.nodes_deleted
@ -310,20 +373,28 @@ class Processor(TriplesStoreService):
literal_result = session.run(
"MATCH (n:Literal {user: $user, collection: $collection}) "
"DETACH DELETE n",
user=message.user, collection=message.collection
user=request.user, collection=request.collection
)
literals_deleted = literal_result.consume().counters.nodes_deleted
# Note: Relationships are automatically deleted with DETACH DELETE
logger.info(f"Deleted {nodes_deleted} nodes and {literals_deleted} literals for {message.user}/{message.collection}")
# Delete collection metadata node
metadata_result = session.run(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
"DELETE c",
user=request.user, collection=request.collection
)
metadata_deleted = metadata_result.consume().counters.nodes_deleted
logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {request.user}/{request.collection}")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View file

@ -145,7 +145,7 @@ class ConfigTableStore:
""")
self.get_all_stmt = self.cassandra.prepare("""
SELECT class, key, value FROM config;
SELECT class AS cls, key, value FROM config;
""")
self.get_values_stmt = self.cassandra.prepare("""