mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-01 11:26:22 +02:00
release/v1.4 -> master (#548)
This commit is contained in:
parent
3ec2cd54f9
commit
2bd68ed7f4
94 changed files with 8571 additions and 1740 deletions
|
|
@ -9,14 +9,14 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|||
from prometheus_client import Histogram
|
||||
|
||||
from ... schema import TextDocument, Chunk
|
||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import ChunkingService, ConsumerSpec, ProducerSpec
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "chunker"
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
class Processor(ChunkingService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
|
|
@ -28,6 +28,10 @@ class Processor(FlowProcessor):
|
|||
**params | { "id": id }
|
||||
)
|
||||
|
||||
# Store default values for parameter override
|
||||
self.default_chunk_size = chunk_size
|
||||
self.default_chunk_overlap = chunk_overlap
|
||||
|
||||
if not hasattr(__class__, "chunk_metric"):
|
||||
__class__.chunk_metric = Histogram(
|
||||
'chunk_size', 'Chunk size',
|
||||
|
|
@ -65,7 +69,22 @@ class Processor(FlowProcessor):
|
|||
v = msg.value()
|
||||
logger.info(f"Chunking document {v.metadata.id}...")
|
||||
|
||||
texts = self.text_splitter.create_documents(
|
||||
# Extract chunk parameters from flow (allows runtime override)
|
||||
chunk_size, chunk_overlap = await self.chunk_document(
|
||||
msg, consumer, flow,
|
||||
self.default_chunk_size,
|
||||
self.default_chunk_overlap
|
||||
)
|
||||
|
||||
# Create text splitter with effective parameters
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
length_function=len,
|
||||
is_separator_regex=False,
|
||||
)
|
||||
|
||||
texts = text_splitter.create_documents(
|
||||
[v.text.decode("utf-8")]
|
||||
)
|
||||
|
||||
|
|
@ -89,7 +108,7 @@ class Processor(FlowProcessor):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
ChunkingService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-z', '--chunk-size',
|
||||
|
|
|
|||
|
|
@ -9,14 +9,14 @@ from langchain_text_splitters import TokenTextSplitter
|
|||
from prometheus_client import Histogram
|
||||
|
||||
from ... schema import TextDocument, Chunk
|
||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import ChunkingService, ConsumerSpec, ProducerSpec
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "chunker"
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
class Processor(ChunkingService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
|
|
@ -28,6 +28,10 @@ class Processor(FlowProcessor):
|
|||
**params | { "id": id }
|
||||
)
|
||||
|
||||
# Store default values for parameter override
|
||||
self.default_chunk_size = chunk_size
|
||||
self.default_chunk_overlap = chunk_overlap
|
||||
|
||||
if not hasattr(__class__, "chunk_metric"):
|
||||
__class__.chunk_metric = Histogram(
|
||||
'chunk_size', 'Chunk size',
|
||||
|
|
@ -64,7 +68,21 @@ class Processor(FlowProcessor):
|
|||
v = msg.value()
|
||||
logger.info(f"Chunking document {v.metadata.id}...")
|
||||
|
||||
texts = self.text_splitter.create_documents(
|
||||
# Extract chunk parameters from flow (allows runtime override)
|
||||
chunk_size, chunk_overlap = await self.chunk_document(
|
||||
msg, consumer, flow,
|
||||
self.default_chunk_size,
|
||||
self.default_chunk_overlap
|
||||
)
|
||||
|
||||
# Create text splitter with effective parameters
|
||||
text_splitter = TokenTextSplitter(
|
||||
encoding_name="cl100k_base",
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
|
||||
texts = text_splitter.create_documents(
|
||||
[v.text.decode("utf-8")]
|
||||
)
|
||||
|
||||
|
|
@ -88,7 +106,7 @@ class Processor(FlowProcessor):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
ChunkingService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-z', '--chunk-size',
|
||||
|
|
|
|||
|
|
@ -10,6 +10,95 @@ class FlowConfig:
|
|||
def __init__(self, config):
|
||||
|
||||
self.config = config
|
||||
# Cache for parameter type definitions to avoid repeated lookups
|
||||
self.param_type_cache = {}
|
||||
|
||||
async def resolve_parameters(self, flow_class, user_params):
    """
    Resolve parameters by merging user-provided values with defaults.

    Each parameter declared in the flow class is resolved from, in order:
    the user-provided value, the "default" of its parameter type
    definition, then the already-resolved value of its "controlled-by"
    parameter.  User parameters that the flow class does not declare are
    passed through unchanged (stringified) for forward compatibility.

    Args:
        flow_class: The flow class definition dict
        user_params: User-provided parameters dict (may be None or empty)

    Returns:
        Complete parameter dict with user values and defaults merged (all values as strings)

    Raises:
        RuntimeError: A parameter whose type is marked "required" has no
            user value, no default and no resolved controlling parameter.
    """

    def as_param_string(value):
        # JSON booleans become "true"/"false" rather than Python's
        # "True"/"False"; everything else uses plain str().
        if isinstance(value, bool):
            return "true" if value else "false"
        return str(value)

    async def load_type_def(param_type):
        # Return the (cached) definition dict for a parameter type.
        if param_type in self.param_type_cache:
            return self.param_type_cache[param_type]

        try:
            # Fetch parameter type definition from config store
            raw = await self.config.get("parameter-types").get(param_type)
        except Exception as e:
            # A failed fetch may be transient, so it is deliberately NOT
            # cached: the next resolution retries the lookup instead of
            # silently losing this type's default for good.
            logger.error(f"Error fetching parameter type '{param_type}': {e}")
            return {}

        if not raw:
            logger.warning(f"Parameter type '{param_type}' not found in config")
            type_def = {}
        else:
            try:
                type_def = json.loads(raw)
            except (TypeError, ValueError) as e:
                logger.error(f"Error fetching parameter type '{param_type}': {e}")
                type_def = {}
            if not isinstance(type_def, dict):
                # Only an object can carry "default" / "required"
                logger.warning(f"Parameter type '{param_type}' is not a JSON object")
                type_def = {}

        self.param_type_cache[param_type] = type_def
        return type_def

    user_params = user_params if user_params else {}

    # If the flow class has no parameters section, return user params as-is (stringified)
    if "parameters" not in flow_class:
        return {k: str(v) for k, v in user_params.items()}

    flow_params = flow_class["parameters"]
    resolved = {}

    # Required parameters that the first pass could not fill.  They are
    # only reported after the controlled-by pass has had a chance to
    # supply a value.
    unresolved_required = []

    # First pass: resolve parameters with explicit values or defaults
    for param_name, param_meta in flow_params.items():

        if param_name in user_params:
            resolved[param_name] = str(user_params[param_name])
            continue

        param_type = param_meta.get("type")
        if not param_type:
            continue

        type_def = await load_type_def(param_type)

        if "default" in type_def:
            resolved[param_name] = as_param_string(type_def["default"])
        elif type_def.get("required", False):
            unresolved_required.append(param_name)

    # Second pass: handle controlled-by relationships.  A parameter still
    # unresolved here has no type default (the first pass would have
    # applied it), so inheriting from the controller is the only source.
    for param_name, param_meta in flow_params.items():
        if param_name in resolved or "controlled-by" not in param_meta:
            continue
        controller = param_meta["controlled-by"]
        if controller in resolved:
            # Inherit value from controlling parameter (already a string)
            resolved[param_name] = resolved[controller]

    for param_name in unresolved_required:
        if param_name not in resolved:
            # Required parameter with no default and no user value
            raise RuntimeError(f"Required parameter '{param_name}' not provided and has no default")

    # Include any extra parameters from user that weren't in flow class definition
    # This allows for forward compatibility (ensure they're strings)
    for key, value in user_params.items():
        if key not in resolved:
            resolved[key] = str(value)

    return resolved
|
||||
|
||||
async def handle_list_classes(self, msg):
|
||||
|
||||
|
|
@ -68,11 +157,14 @@ class FlowConfig:
|
|||
|
||||
async def handle_get_flow(self, msg):
|
||||
|
||||
flow = await self.config.get("flows").get(msg.flow_id)
|
||||
flow_data = await self.config.get("flows").get(msg.flow_id)
|
||||
flow = json.loads(flow_data)
|
||||
|
||||
return FlowResponse(
|
||||
error = None,
|
||||
flow = flow,
|
||||
flow = flow_data,
|
||||
description = flow.get("description", ""),
|
||||
parameters = flow.get("parameters", {}),
|
||||
)
|
||||
|
||||
async def handle_start_flow(self, msg):
|
||||
|
|
@ -83,45 +175,65 @@ class FlowConfig:
|
|||
if msg.flow_id is None:
|
||||
raise RuntimeError("No flow ID")
|
||||
|
||||
if msg.flow_id in await self.config.get("flows").values():
|
||||
if msg.flow_id in await self.config.get("flows").keys():
|
||||
raise RuntimeError("Flow already exists")
|
||||
|
||||
if msg.description is None:
|
||||
raise RuntimeError("No description")
|
||||
|
||||
if msg.class_name not in await self.config.get("flow-classes").values():
|
||||
if msg.class_name not in await self.config.get("flow-classes").keys():
|
||||
raise RuntimeError("Class does not exist")
|
||||
|
||||
def repl_template(tmp):
|
||||
return tmp.replace(
|
||||
"{class}", msg.class_name
|
||||
).replace(
|
||||
"{id}", msg.flow_id
|
||||
)
|
||||
|
||||
cls = json.loads(
|
||||
await self.config.get("flow-classes").get(msg.class_name)
|
||||
)
|
||||
|
||||
# Resolve parameters by merging user-provided values with defaults
|
||||
user_params = msg.parameters if msg.parameters else {}
|
||||
parameters = await self.resolve_parameters(cls, user_params)
|
||||
|
||||
# Log the resolved parameters for debugging
|
||||
logger.debug(f"User provided parameters: {user_params}")
|
||||
logger.debug(f"Resolved parameters (with defaults): {parameters}")
|
||||
|
||||
# Apply parameter substitution to template replacement function
|
||||
def repl_template_with_params(tmp):
|
||||
|
||||
result = tmp.replace(
|
||||
"{class}", msg.class_name
|
||||
).replace(
|
||||
"{id}", msg.flow_id
|
||||
)
|
||||
# Apply parameter substitutions
|
||||
for param_name, param_value in parameters.items():
|
||||
result = result.replace(f"{{{param_name}}}", str(param_value))
|
||||
|
||||
return result
|
||||
|
||||
for kind in ("class", "flow"):
|
||||
|
||||
for k, v in cls[kind].items():
|
||||
|
||||
processor, variant = k.split(":", 1)
|
||||
|
||||
variant = repl_template(variant)
|
||||
variant = repl_template_with_params(variant)
|
||||
|
||||
v = {
|
||||
repl_template(k2): repl_template(v2)
|
||||
repl_template_with_params(k2): repl_template_with_params(v2)
|
||||
for k2, v2 in v.items()
|
||||
}
|
||||
|
||||
flac = await self.config.get("flows-active").values()
|
||||
if processor in flac:
|
||||
target = json.loads(flac[processor])
|
||||
flac = await self.config.get("flows-active").get(processor)
|
||||
if flac is not None:
|
||||
target = json.loads(flac)
|
||||
else:
|
||||
target = {}
|
||||
|
||||
# The condition if variant not in target: means it only adds
|
||||
# the configuration if the variant doesn't already exist.
|
||||
# If "everything" already exists in the target with old
|
||||
# values, they won't update.
|
||||
|
||||
if variant not in target:
|
||||
target[variant] = v
|
||||
|
||||
|
|
@ -131,10 +243,10 @@ class FlowConfig:
|
|||
|
||||
def repl_interface(i):
|
||||
if isinstance(i, str):
|
||||
return repl_template(i)
|
||||
return repl_template_with_params(i)
|
||||
else:
|
||||
return {
|
||||
k: repl_template(v)
|
||||
k: repl_template_with_params(v)
|
||||
for k, v in i.items()
|
||||
}
|
||||
|
||||
|
|
@ -152,6 +264,7 @@ class FlowConfig:
|
|||
"description": msg.description,
|
||||
"class-name": msg.class_name,
|
||||
"interfaces": interfaces,
|
||||
"parameters": parameters,
|
||||
})
|
||||
)
|
||||
|
||||
|
|
@ -177,15 +290,20 @@ class FlowConfig:
|
|||
raise RuntimeError("Internal error: flow has no flow class")
|
||||
|
||||
class_name = flow["class-name"]
|
||||
parameters = flow.get("parameters", {})
|
||||
|
||||
cls = json.loads(await self.config.get("flow-classes").get(class_name))
|
||||
|
||||
def repl_template(tmp):
|
||||
return tmp.replace(
|
||||
result = tmp.replace(
|
||||
"{class}", class_name
|
||||
).replace(
|
||||
"{id}", msg.flow_id
|
||||
)
|
||||
# Apply parameter substitutions
|
||||
for param_name, param_value in parameters.items():
|
||||
result = result.replace(f"{{{param_name}}}", str(param_value))
|
||||
return result
|
||||
|
||||
for kind in ("flow",):
|
||||
|
||||
|
|
@ -195,10 +313,10 @@ class FlowConfig:
|
|||
|
||||
variant = repl_template(variant)
|
||||
|
||||
flac = await self.config.get("flows-active").values()
|
||||
flac = await self.config.get("flows-active").get(processor)
|
||||
|
||||
if processor in flac:
|
||||
target = json.loads(flac[processor])
|
||||
if flac is not None:
|
||||
target = json.loads(flac)
|
||||
else:
|
||||
target = {}
|
||||
|
||||
|
|
@ -209,7 +327,7 @@ class FlowConfig:
|
|||
processor, json.dumps(target)
|
||||
)
|
||||
|
||||
if msg.flow_id in await self.config.get("flows").values():
|
||||
if msg.flow_id in await self.config.get("flows").keys():
|
||||
await self.config.get("flows").delete(msg.flow_id)
|
||||
|
||||
await self.config.inc_version()
|
||||
|
|
|
|||
|
|
@ -24,16 +24,12 @@ class KnowledgeGraph:
|
|||
self.keyspace = keyspace
|
||||
self.username = username
|
||||
|
||||
# Multi-table schema design for optimal performance
|
||||
self.use_legacy = os.getenv('CASSANDRA_USE_LEGACY', 'false').lower() == 'true'
|
||||
|
||||
if self.use_legacy:
|
||||
self.table = "triples" # Legacy single table
|
||||
else:
|
||||
# New optimized tables
|
||||
self.subject_table = "triples_s"
|
||||
self.po_table = "triples_p"
|
||||
self.object_table = "triples_o"
|
||||
# Optimized multi-table schema with collection deletion support
|
||||
self.subject_table = "triples_s"
|
||||
self.po_table = "triples_p"
|
||||
self.object_table = "triples_o"
|
||||
self.collection_table = "triples_collection" # For SPO queries and deletion
|
||||
self.collection_metadata_table = "collection_metadata" # For tracking which collections exist
|
||||
|
||||
if username and password:
|
||||
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
|
||||
|
|
@ -47,9 +43,7 @@ class KnowledgeGraph:
|
|||
_active_clusters.append(self.cluster)
|
||||
|
||||
self.init()
|
||||
|
||||
if not self.use_legacy:
|
||||
self.prepare_statements()
|
||||
self.prepare_statements()
|
||||
|
||||
def clear(self):
|
||||
|
||||
|
|
@ -70,42 +64,13 @@ class KnowledgeGraph:
|
|||
""");
|
||||
|
||||
self.session.set_keyspace(self.keyspace)
|
||||
self.init_optimized_schema()
|
||||
|
||||
if self.use_legacy:
|
||||
self.init_legacy_schema()
|
||||
else:
|
||||
self.init_optimized_schema()
|
||||
|
||||
def init_legacy_schema(self):
|
||||
"""Initialize legacy single-table schema for backward compatibility"""
|
||||
self.session.execute(f"""
|
||||
create table if not exists {self.table} (
|
||||
collection text,
|
||||
s text,
|
||||
p text,
|
||||
o text,
|
||||
PRIMARY KEY (collection, s, p, o)
|
||||
);
|
||||
""");
|
||||
|
||||
self.session.execute(f"""
|
||||
create index if not exists {self.table}_s
|
||||
ON {self.table} (s);
|
||||
""");
|
||||
|
||||
self.session.execute(f"""
|
||||
create index if not exists {self.table}_p
|
||||
ON {self.table} (p);
|
||||
""");
|
||||
|
||||
self.session.execute(f"""
|
||||
create index if not exists {self.table}_o
|
||||
ON {self.table} (o);
|
||||
""");
|
||||
|
||||
def init_optimized_schema(self):
|
||||
"""Initialize optimized multi-table schema for performance"""
|
||||
# Table 1: Subject-centric queries (get_s, get_sp, get_spo, get_os)
|
||||
# Table 1: Subject-centric queries (get_s, get_sp, get_os)
|
||||
# Compound partition key for optimal data distribution
|
||||
self.session.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.subject_table} (
|
||||
collection text,
|
||||
|
|
@ -117,6 +82,7 @@ class KnowledgeGraph:
|
|||
""");
|
||||
|
||||
# Table 2: Predicate-Object queries (get_p, get_po) - eliminates ALLOW FILTERING!
|
||||
# Compound partition key for optimal data distribution
|
||||
self.session.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.po_table} (
|
||||
collection text,
|
||||
|
|
@ -128,6 +94,7 @@ class KnowledgeGraph:
|
|||
""");
|
||||
|
||||
# Table 3: Object-centric queries (get_o)
|
||||
# Compound partition key for optimal data distribution
|
||||
self.session.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.object_table} (
|
||||
collection text,
|
||||
|
|
@ -138,7 +105,29 @@ class KnowledgeGraph:
|
|||
);
|
||||
""");
|
||||
|
||||
logger.info("Optimized multi-table schema initialized")
|
||||
# Table 4: Collection management and SPO queries (get_spo)
|
||||
# Simple partition key enables efficient collection deletion
|
||||
self.session.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.collection_table} (
|
||||
collection text,
|
||||
s text,
|
||||
p text,
|
||||
o text,
|
||||
PRIMARY KEY (collection, s, p, o)
|
||||
);
|
||||
""");
|
||||
|
||||
# Table 5: Collection metadata tracking
|
||||
# Tracks which collections exist without polluting triple data
|
||||
self.session.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.collection_metadata_table} (
|
||||
collection text,
|
||||
created_at timestamp,
|
||||
PRIMARY KEY (collection)
|
||||
);
|
||||
""");
|
||||
|
||||
logger.info("Optimized multi-table schema initialized (5 tables)")
|
||||
|
||||
def prepare_statements(self):
|
||||
"""Prepare statements for optimal performance"""
|
||||
|
|
@ -155,6 +144,10 @@ class KnowledgeGraph:
|
|||
f"INSERT INTO {self.object_table} (collection, o, s, p) VALUES (?, ?, ?, ?)"
|
||||
)
|
||||
|
||||
self.insert_collection_stmt = self.session.prepare(
|
||||
f"INSERT INTO {self.collection_table} (collection, s, p, o) VALUES (?, ?, ?, ?)"
|
||||
)
|
||||
|
||||
# Query statements for optimized access
|
||||
self.get_all_stmt = self.session.prepare(
|
||||
f"SELECT s, p, o FROM {self.subject_table} WHERE collection = ? LIMIT ? ALLOW FILTERING"
|
||||
|
|
@ -186,158 +179,177 @@ class KnowledgeGraph:
|
|||
)
|
||||
|
||||
self.get_spo_stmt = self.session.prepare(
|
||||
f"SELECT s as x FROM {self.subject_table} WHERE collection = ? AND s = ? AND p = ? AND o = ? LIMIT ?"
|
||||
f"SELECT s as x FROM {self.collection_table} WHERE collection = ? AND s = ? AND p = ? AND o = ? LIMIT ?"
|
||||
)
|
||||
|
||||
logger.info("Prepared statements initialized for optimal performance")
|
||||
# Delete statements for collection deletion
|
||||
self.delete_subject_stmt = self.session.prepare(
|
||||
f"DELETE FROM {self.subject_table} WHERE collection = ? AND s = ? AND p = ? AND o = ?"
|
||||
)
|
||||
|
||||
self.delete_po_stmt = self.session.prepare(
|
||||
f"DELETE FROM {self.po_table} WHERE collection = ? AND p = ? AND o = ? AND s = ?"
|
||||
)
|
||||
|
||||
self.delete_object_stmt = self.session.prepare(
|
||||
f"DELETE FROM {self.object_table} WHERE collection = ? AND o = ? AND s = ? AND p = ?"
|
||||
)
|
||||
|
||||
self.delete_collection_stmt = self.session.prepare(
|
||||
f"DELETE FROM {self.collection_table} WHERE collection = ? AND s = ? AND p = ? AND o = ?"
|
||||
)
|
||||
|
||||
logger.info("Prepared statements initialized for optimal performance (4 tables)")
|
||||
|
||||
def insert(self, collection, s, p, o):
|
||||
# Batch write to all four tables for consistency
|
||||
batch = BatchStatement()
|
||||
|
||||
if self.use_legacy:
|
||||
self.session.execute(
|
||||
f"insert into {self.table} (collection, s, p, o) values (%s, %s, %s, %s)",
|
||||
(collection, s, p, o)
|
||||
)
|
||||
else:
|
||||
# Batch write to all three tables for consistency
|
||||
batch = BatchStatement()
|
||||
# Insert into subject table
|
||||
batch.add(self.insert_subject_stmt, (collection, s, p, o))
|
||||
|
||||
# Insert into subject table
|
||||
batch.add(self.insert_subject_stmt, (collection, s, p, o))
|
||||
# Insert into predicate-object table (column order: collection, p, o, s)
|
||||
batch.add(self.insert_po_stmt, (collection, p, o, s))
|
||||
|
||||
# Insert into predicate-object table (column order: collection, p, o, s)
|
||||
batch.add(self.insert_po_stmt, (collection, p, o, s))
|
||||
# Insert into object table (column order: collection, o, s, p)
|
||||
batch.add(self.insert_object_stmt, (collection, o, s, p))
|
||||
|
||||
# Insert into object table (column order: collection, o, s, p)
|
||||
batch.add(self.insert_object_stmt, (collection, o, s, p))
|
||||
# Insert into collection table for SPO queries and deletion tracking
|
||||
batch.add(self.insert_collection_stmt, (collection, s, p, o))
|
||||
|
||||
self.session.execute(batch)
|
||||
self.session.execute(batch)
|
||||
|
||||
def get_all(self, collection, limit=50):
|
||||
if self.use_legacy:
|
||||
return self.session.execute(
|
||||
f"select s, p, o from {self.table} where collection = %s limit {limit}",
|
||||
(collection,)
|
||||
)
|
||||
else:
|
||||
# Use subject table for get_all queries
|
||||
return self.session.execute(
|
||||
self.get_all_stmt,
|
||||
(collection, limit)
|
||||
)
|
||||
# Use subject table for get_all queries
|
||||
return self.session.execute(
|
||||
self.get_all_stmt,
|
||||
(collection, limit)
|
||||
)
|
||||
|
||||
def get_s(self, collection, s, limit=10):
|
||||
if self.use_legacy:
|
||||
return self.session.execute(
|
||||
f"select p, o from {self.table} where collection = %s and s = %s limit {limit}",
|
||||
(collection, s)
|
||||
)
|
||||
else:
|
||||
# Optimized: Direct partition access with (collection, s)
|
||||
return self.session.execute(
|
||||
self.get_s_stmt,
|
||||
(collection, s, limit)
|
||||
)
|
||||
# Optimized: Direct partition access with (collection, s)
|
||||
return self.session.execute(
|
||||
self.get_s_stmt,
|
||||
(collection, s, limit)
|
||||
)
|
||||
|
||||
def get_p(self, collection, p, limit=10):
|
||||
if self.use_legacy:
|
||||
return self.session.execute(
|
||||
f"select s, o from {self.table} where collection = %s and p = %s limit {limit}",
|
||||
(collection, p)
|
||||
)
|
||||
else:
|
||||
# Optimized: Use po_table for direct partition access
|
||||
return self.session.execute(
|
||||
self.get_p_stmt,
|
||||
(collection, p, limit)
|
||||
)
|
||||
# Optimized: Use po_table for direct partition access
|
||||
return self.session.execute(
|
||||
self.get_p_stmt,
|
||||
(collection, p, limit)
|
||||
)
|
||||
|
||||
def get_o(self, collection, o, limit=10):
|
||||
if self.use_legacy:
|
||||
return self.session.execute(
|
||||
f"select s, p from {self.table} where collection = %s and o = %s limit {limit}",
|
||||
(collection, o)
|
||||
)
|
||||
else:
|
||||
# Optimized: Use object_table for direct partition access
|
||||
return self.session.execute(
|
||||
self.get_o_stmt,
|
||||
(collection, o, limit)
|
||||
)
|
||||
# Optimized: Use object_table for direct partition access
|
||||
return self.session.execute(
|
||||
self.get_o_stmt,
|
||||
(collection, o, limit)
|
||||
)
|
||||
|
||||
def get_sp(self, collection, s, p, limit=10):
|
||||
if self.use_legacy:
|
||||
return self.session.execute(
|
||||
f"select o from {self.table} where collection = %s and s = %s and p = %s limit {limit}",
|
||||
(collection, s, p)
|
||||
)
|
||||
else:
|
||||
# Optimized: Use subject_table with clustering key access
|
||||
return self.session.execute(
|
||||
self.get_sp_stmt,
|
||||
(collection, s, p, limit)
|
||||
)
|
||||
# Optimized: Use subject_table with clustering key access
|
||||
return self.session.execute(
|
||||
self.get_sp_stmt,
|
||||
(collection, s, p, limit)
|
||||
)
|
||||
|
||||
def get_po(self, collection, p, o, limit=10):
|
||||
if self.use_legacy:
|
||||
return self.session.execute(
|
||||
f"select s from {self.table} where collection = %s and p = %s and o = %s limit {limit} allow filtering",
|
||||
(collection, p, o)
|
||||
)
|
||||
else:
|
||||
# CRITICAL OPTIMIZATION: Use po_table - NO MORE ALLOW FILTERING!
|
||||
return self.session.execute(
|
||||
self.get_po_stmt,
|
||||
(collection, p, o, limit)
|
||||
)
|
||||
# CRITICAL OPTIMIZATION: Use po_table - NO MORE ALLOW FILTERING!
|
||||
return self.session.execute(
|
||||
self.get_po_stmt,
|
||||
(collection, p, o, limit)
|
||||
)
|
||||
|
||||
def get_os(self, collection, o, s, limit=10):
|
||||
if self.use_legacy:
|
||||
return self.session.execute(
|
||||
f"select p from {self.table} where collection = %s and o = %s and s = %s limit {limit} allow filtering",
|
||||
(collection, o, s)
|
||||
)
|
||||
else:
|
||||
# Optimized: Use subject_table with clustering access (no more ALLOW FILTERING)
|
||||
return self.session.execute(
|
||||
self.get_os_stmt,
|
||||
(collection, s, o, limit)
|
||||
)
|
||||
# Optimized: Use subject_table with clustering access (no more ALLOW FILTERING)
|
||||
return self.session.execute(
|
||||
self.get_os_stmt,
|
||||
(collection, s, o, limit)
|
||||
)
|
||||
|
||||
def get_spo(self, collection, s, p, o, limit=10):
|
||||
if self.use_legacy:
|
||||
return self.session.execute(
|
||||
f"""select s as x from {self.table} where collection = %s and s = %s and p = %s and o = %s limit {limit}""",
|
||||
(collection, s, p, o)
|
||||
# Optimized: Use collection_table for exact key lookup
|
||||
return self.session.execute(
|
||||
self.get_spo_stmt,
|
||||
(collection, s, p, o, limit)
|
||||
)
|
||||
|
||||
def collection_exists(self, collection):
|
||||
"""Check if collection exists by querying collection_metadata table"""
|
||||
try:
|
||||
result = self.session.execute(
|
||||
f"SELECT collection FROM {self.collection_metadata_table} WHERE collection = %s LIMIT 1",
|
||||
(collection,)
|
||||
)
|
||||
else:
|
||||
# Optimized: Use subject_table for exact key lookup
|
||||
return self.session.execute(
|
||||
self.get_spo_stmt,
|
||||
(collection, s, p, o, limit)
|
||||
return bool(list(result))
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking collection existence: {e}")
|
||||
return False
|
||||
|
||||
def create_collection(self, collection):
    """Create collection by inserting metadata row

    Writes a (collection, created_at) row into the collection metadata
    table.

    Args:
        collection: Collection identifier stored in the metadata row

    Raises:
        Exception: Any error raised while executing the insert is logged
            and re-raised unchanged.
    """
    import datetime

    # Use a timezone-aware UTC timestamp.  The Cassandra driver treats a
    # naive datetime as already being UTC, so a naive local time would be
    # stored shifted by the host's UTC offset.
    created_at = datetime.datetime.now(datetime.timezone.utc)

    try:
        self.session.execute(
            f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)",
            (collection, created_at)
        )
        logger.info(f"Created collection metadata for {collection}")
    except Exception as e:
        logger.error(f"Error creating collection: {e}")
        # Bare raise propagates the original exception and traceback
        raise
|
||||
|
||||
def delete_collection(self, collection):
|
||||
"""Delete all triples for a specific collection"""
|
||||
if self.use_legacy:
|
||||
self.session.execute(
|
||||
f"delete from {self.table} where collection = %s",
|
||||
(collection,)
|
||||
)
|
||||
else:
|
||||
# Delete from all three tables
|
||||
self.session.execute(
|
||||
f"delete from {self.subject_table} where collection = %s",
|
||||
(collection,)
|
||||
)
|
||||
self.session.execute(
|
||||
f"delete from {self.po_table} where collection = %s",
|
||||
(collection,)
|
||||
)
|
||||
self.session.execute(
|
||||
f"delete from {self.object_table} where collection = %s",
|
||||
(collection,)
|
||||
)
|
||||
"""Delete all triples for a specific collection
|
||||
|
||||
Uses collection_table to enumerate all triples, then deletes from all 4 tables
|
||||
using full partition keys for optimal performance with compound keys.
|
||||
"""
|
||||
# Step 1: Read all triples from collection_table (single partition read)
|
||||
rows = self.session.execute(
|
||||
f"SELECT s, p, o FROM {self.collection_table} WHERE collection = %s",
|
||||
(collection,)
|
||||
)
|
||||
|
||||
# Step 2: Delete each triple from all 4 tables using full partition keys
|
||||
# Batch deletions for efficiency
|
||||
batch = BatchStatement()
|
||||
count = 0
|
||||
|
||||
for row in rows:
|
||||
s, p, o = row.s, row.p, row.o
|
||||
|
||||
# Delete from subject table (partition key: collection, s)
|
||||
batch.add(self.delete_subject_stmt, (collection, s, p, o))
|
||||
|
||||
# Delete from predicate-object table (partition key: collection, p)
|
||||
batch.add(self.delete_po_stmt, (collection, p, o, s))
|
||||
|
||||
# Delete from object table (partition key: collection, o)
|
||||
batch.add(self.delete_object_stmt, (collection, o, s, p))
|
||||
|
||||
# Delete from collection table (partition key: collection only)
|
||||
batch.add(self.delete_collection_stmt, (collection, s, p, o))
|
||||
|
||||
count += 1
|
||||
|
||||
# Execute batch every 100 triples to avoid oversized batches
|
||||
if count % 100 == 0:
|
||||
self.session.execute(batch)
|
||||
batch = BatchStatement()
|
||||
|
||||
# Execute remaining deletions
|
||||
if count % 100 != 0:
|
||||
self.session.execute(batch)
|
||||
|
||||
# Step 3: Delete collection metadata
|
||||
self.session.execute(
|
||||
f"DELETE FROM {self.collection_metadata_table} WHERE collection = %s",
|
||||
(collection,)
|
||||
)
|
||||
|
||||
logger.info(f"Deleted {count} triples from collection {collection}")
|
||||
|
||||
def close(self):
|
||||
"""Close the Cassandra session and cluster connections properly"""
|
||||
|
|
|
|||
|
|
@ -49,6 +49,22 @@ class DocVectors:
|
|||
self.next_reload = time.time() + self.reload_time
|
||||
logger.debug(f"Reload at {self.next_reload}")
|
||||
|
||||
def collection_exists(self, user, collection):
    """Report whether the collection for this user/collection pair exists.

    The collection name is built from the user, collection and prefix
    only, so the check does not depend on the embedding dimension.
    """
    return self.client.has_collection(
        make_safe_collection_name(user, collection, self.prefix)
    )
|
||||
|
||||
def create_collection(self, user, collection, dimension=384):
    """Create the collection for this user/collection pair if it is missing.

    Args:
        user: User the collection belongs to
        collection: Collection identifier
        dimension: Vector dimension passed to init_collection (default 384)
    """
    name = make_safe_collection_name(user, collection, self.prefix)

    if not self.client.has_collection(name):
        self.init_collection(dimension, user, collection)
        logger.info(f"Created Milvus collection {name} with dimension {dimension}")
    else:
        # Nothing to do: leave the existing collection untouched
        logger.info(f"Collection {name} already exists")
|
||||
|
||||
def init_collection(self, dimension, user, collection):
|
||||
|
||||
collection_name = make_safe_collection_name(user, collection, self.prefix)
|
||||
|
|
@ -128,14 +144,6 @@ class DocVectors:
|
|||
|
||||
coll = self.collections[(dim, user, collection)]
|
||||
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
"params": {
|
||||
"radius": 0.1,
|
||||
"range_filter": 0.8
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug("Loading...")
|
||||
self.client.load_collection(
|
||||
collection_name=coll,
|
||||
|
|
@ -145,10 +153,11 @@ class DocVectors:
|
|||
|
||||
res = self.client.search(
|
||||
collection_name=coll,
|
||||
anns_field="vector",
|
||||
data=[embeds],
|
||||
limit=limit,
|
||||
output_fields=fields,
|
||||
search_params=search_params,
|
||||
search_params={ "metric_type": "COSINE" },
|
||||
)[0]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -49,6 +49,22 @@ class EntityVectors:
|
|||
self.next_reload = time.time() + self.reload_time
|
||||
logger.debug(f"Reload at {self.next_reload}")
|
||||
|
||||
def collection_exists(self, user, collection):
    """Report whether the collection for this user/collection pair exists.

    The lookup name is derived from user, collection and prefix alone,
    which makes the check independent of the embedding dimension.
    """
    safe_name = make_safe_collection_name(user, collection, self.prefix)
    exists = self.client.has_collection(safe_name)
    return exists
|
||||
|
||||
def create_collection(self, user, collection, dimension=384):
    """Create the collection for this user/collection pair unless it exists.

    Args:
        user: User the collection belongs to
        collection: Collection identifier
        dimension: Vector dimension passed to init_collection (default 384)
    """
    safe_name = make_safe_collection_name(user, collection, self.prefix)
    present = self.client.has_collection(safe_name)

    if present:
        # Existing collections are left as they are
        logger.info(f"Collection {safe_name} already exists")
    else:
        self.init_collection(dimension, user, collection)
        logger.info(f"Created Milvus collection {safe_name} with dimension {dimension}")
|
||||
|
||||
def init_collection(self, dimension, user, collection):
|
||||
|
||||
collection_name = make_safe_collection_name(user, collection, self.prefix)
|
||||
|
|
@ -128,14 +144,6 @@ class EntityVectors:
|
|||
|
||||
coll = self.collections[(dim, user, collection)]
|
||||
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
"params": {
|
||||
"radius": 0.1,
|
||||
"range_filter": 0.8
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug("Loading...")
|
||||
self.client.load_collection(
|
||||
collection_name=coll,
|
||||
|
|
@ -145,10 +153,11 @@ class EntityVectors:
|
|||
|
||||
res = self.client.search(
|
||||
collection_name=coll,
|
||||
anns_field="vector",
|
||||
data=[embeds],
|
||||
limit=limit,
|
||||
output_fields=fields,
|
||||
search_params=search_params,
|
||||
search_params={ "metric_type": "COSINE" },
|
||||
)[0]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ class CollectionManager:
|
|||
|
||||
async def ensure_collection_exists(self, user: str, collection: str):
|
||||
"""
|
||||
Ensure a collection exists, creating it if necessary (lazy creation)
|
||||
Ensure a collection exists, creating it if necessary with broadcast to storage
|
||||
|
||||
Args:
|
||||
user: User ID
|
||||
|
|
@ -74,7 +74,7 @@ class CollectionManager:
|
|||
return
|
||||
|
||||
# Create new collection with default metadata
|
||||
logger.info(f"Creating new collection {user}/{collection}")
|
||||
logger.info(f"Auto-creating collection {user}/{collection} from document submission")
|
||||
await self.table_store.create_collection(
|
||||
user=user,
|
||||
collection=collection,
|
||||
|
|
@ -83,10 +83,64 @@ class CollectionManager:
|
|||
tags=set()
|
||||
)
|
||||
|
||||
# Broadcast collection creation to all storage backends
|
||||
creation_key = (user, collection)
|
||||
logger.info(f"Broadcasting create-collection for {creation_key}")
|
||||
|
||||
self.pending_deletions[creation_key] = {
|
||||
"responses_pending": 4, # doc-embeddings, graph-embeddings, object, triples
|
||||
"responses_received": [],
|
||||
"all_successful": True,
|
||||
"error_messages": [],
|
||||
"deletion_complete": asyncio.Event()
|
||||
}
|
||||
|
||||
storage_request = StorageManagementRequest(
|
||||
operation="create-collection",
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
|
||||
# Send creation requests to all storage types
|
||||
if self.vector_storage_producer:
|
||||
await self.vector_storage_producer.send(storage_request)
|
||||
if self.object_storage_producer:
|
||||
await self.object_storage_producer.send(storage_request)
|
||||
if self.triples_storage_producer:
|
||||
await self.triples_storage_producer.send(storage_request)
|
||||
|
||||
# Wait for all storage creations to complete (with timeout)
|
||||
creation_info = self.pending_deletions[creation_key]
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
creation_info["deletion_complete"].wait(),
|
||||
timeout=30.0 # 30 second timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Timeout waiting for storage creation responses for {creation_key}")
|
||||
creation_info["all_successful"] = False
|
||||
creation_info["error_messages"].append("Timeout waiting for storage creation")
|
||||
|
||||
# Check if all creations succeeded
|
||||
if not creation_info["all_successful"]:
|
||||
error_msg = f"Storage creation failed: {'; '.join(creation_info['error_messages'])}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# Clean up metadata on failure
|
||||
await self.table_store.delete_collection(user, collection)
|
||||
|
||||
# Clean up tracking
|
||||
del self.pending_deletions[creation_key]
|
||||
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
# Clean up tracking
|
||||
del self.pending_deletions[creation_key]
|
||||
logger.info(f"Collection {creation_key} auto-created successfully in all storage backends")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error ensuring collection exists: {e}")
|
||||
# The error is logged above and then re-raised: a failed collection
|
||||
# creation is propagated to the caller rather than swallowed
|
||||
raise e
|
||||
|
||||
async def list_collections(self, request: CollectionManagementRequest) -> CollectionManagementResponse:
|
||||
"""
|
||||
|
|
@ -154,6 +208,67 @@ class CollectionManager:
|
|||
tags=tags
|
||||
)
|
||||
|
||||
# Broadcast collection creation to all storage backends
|
||||
creation_key = (request.user, request.collection)
|
||||
logger.info(f"Broadcasting create-collection for {creation_key}")
|
||||
|
||||
self.pending_deletions[creation_key] = {
|
||||
"responses_pending": 4, # doc-embeddings, graph-embeddings, object, triples
|
||||
"responses_received": [],
|
||||
"all_successful": True,
|
||||
"error_messages": [],
|
||||
"deletion_complete": asyncio.Event()
|
||||
}
|
||||
|
||||
storage_request = StorageManagementRequest(
|
||||
operation="create-collection",
|
||||
user=request.user,
|
||||
collection=request.collection
|
||||
)
|
||||
|
||||
# Send creation requests to all storage types
|
||||
if self.vector_storage_producer:
|
||||
await self.vector_storage_producer.send(storage_request)
|
||||
if self.object_storage_producer:
|
||||
await self.object_storage_producer.send(storage_request)
|
||||
if self.triples_storage_producer:
|
||||
await self.triples_storage_producer.send(storage_request)
|
||||
|
||||
# Wait for all storage creations to complete (with timeout)
|
||||
creation_info = self.pending_deletions[creation_key]
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
creation_info["deletion_complete"].wait(),
|
||||
timeout=30.0 # 30 second timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Timeout waiting for storage creation responses for {creation_key}")
|
||||
creation_info["all_successful"] = False
|
||||
creation_info["error_messages"].append("Timeout waiting for storage creation")
|
||||
|
||||
# Check if all creations succeeded
|
||||
if not creation_info["all_successful"]:
|
||||
error_msg = f"Storage creation failed: {'; '.join(creation_info['error_messages'])}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# Clean up metadata on failure
|
||||
await self.table_store.delete_collection(request.user, request.collection)
|
||||
|
||||
# Clean up tracking
|
||||
del self.pending_deletions[creation_key]
|
||||
|
||||
return CollectionManagementResponse(
|
||||
error=Error(
|
||||
type="storage_creation_error",
|
||||
message=error_msg
|
||||
),
|
||||
timestamp=datetime.now().isoformat()
|
||||
)
|
||||
|
||||
# Clean up tracking
|
||||
del self.pending_deletions[creation_key]
|
||||
logger.info(f"Collection {creation_key} created successfully in all storage backends")
|
||||
|
||||
# Get the newly created collection for response
|
||||
created_collection = await self.table_store.get_collection(request.user, request.collection)
|
||||
|
||||
|
|
@ -213,7 +328,7 @@ class CollectionManager:
|
|||
|
||||
# Track this deletion request
|
||||
self.pending_deletions[deletion_key] = {
|
||||
"responses_pending": 3, # vector, object, triples
|
||||
"responses_pending": 4, # doc-embeddings, graph-embeddings, object, triples
|
||||
"responses_received": [],
|
||||
"all_successful": True,
|
||||
"error_messages": [],
|
||||
|
|
@ -303,9 +418,9 @@ class CollectionManager:
|
|||
if response.error and response.error.message:
|
||||
info["all_successful"] = False
|
||||
info["error_messages"].append(response.error.message)
|
||||
logger.warning(f"Storage deletion failed for {deletion_key}: {response.error.message}")
|
||||
logger.warning(f"Storage operation failed for {deletion_key}: {response.error.message}")
|
||||
else:
|
||||
logger.debug(f"Storage deletion succeeded for {deletion_key}")
|
||||
logger.debug(f"Storage operation succeeded for {deletion_key}")
|
||||
|
||||
# If all responses received, signal completion
|
||||
if info["responses_pending"] == 0:
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class Processor(LlmService):
|
|||
token = params.get("token", default_token)
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
model = default_model
|
||||
model = params.get("model", default_model)
|
||||
|
||||
if endpoint is None:
|
||||
raise RuntimeError("Azure endpoint not specified")
|
||||
|
|
@ -53,9 +53,11 @@ class Processor(LlmService):
|
|||
self.token = token
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
|
||||
def build_prompt(self, system, content):
|
||||
def build_prompt(self, system, content, temperature=None):
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
data = {
|
||||
"messages": [
|
||||
|
|
@ -67,7 +69,7 @@ class Processor(LlmService):
|
|||
}
|
||||
],
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"temperature": effective_temperature,
|
||||
"top_p": 1
|
||||
}
|
||||
|
||||
|
|
@ -100,13 +102,22 @@ class Processor(LlmService):
|
|||
|
||||
return result
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
try:
|
||||
|
||||
prompt = self.build_prompt(
|
||||
system,
|
||||
prompt
|
||||
prompt,
|
||||
effective_temperature
|
||||
)
|
||||
|
||||
response = self.call_llm(prompt)
|
||||
|
|
@ -125,7 +136,7 @@ class Processor(LlmService):
|
|||
text = resp,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model
|
||||
model = model_name
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ class Processor(LlmService):
|
|||
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
|
||||
self.openai = AzureOpenAI(
|
||||
api_key=token,
|
||||
|
|
@ -62,14 +62,22 @@ class Processor(LlmService):
|
|||
azure_endpoint = endpoint,
|
||||
)
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
prompt = system + "\n\n" + prompt
|
||||
|
||||
try:
|
||||
|
||||
resp = self.openai.chat.completions.create(
|
||||
model=self.model,
|
||||
model=model_name,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
|
|
@ -81,7 +89,7 @@ class Processor(LlmService):
|
|||
]
|
||||
}
|
||||
],
|
||||
temperature=self.temperature,
|
||||
temperature=effective_temperature,
|
||||
max_tokens=self.max_output,
|
||||
top_p=1,
|
||||
)
|
||||
|
|
@ -97,7 +105,7 @@ class Processor(LlmService):
|
|||
text = resp.choices[0].message.content,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model
|
||||
model = model_name
|
||||
)
|
||||
|
||||
return r
|
||||
|
|
|
|||
|
|
@ -41,21 +41,29 @@ class Processor(LlmService):
|
|||
}
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
self.claude = anthropic.Anthropic(api_key=api_key)
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
|
||||
logger.info("Claude LLM service initialized")
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
try:
|
||||
|
||||
response = message = self.claude.messages.create(
|
||||
model=self.model,
|
||||
model=model_name,
|
||||
max_tokens=self.max_output,
|
||||
temperature=self.temperature,
|
||||
temperature=effective_temperature,
|
||||
system = system,
|
||||
messages=[
|
||||
{
|
||||
|
|
@ -81,7 +89,7 @@ class Processor(LlmService):
|
|||
text = resp,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model
|
||||
model = model_name
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -39,21 +39,29 @@ class Processor(LlmService):
|
|||
}
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
self.temperature = temperature
|
||||
self.cohere = cohere.Client(api_key=api_key)
|
||||
|
||||
logger.info("Cohere LLM service initialized")
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
try:
|
||||
|
||||
output = self.cohere.chat(
|
||||
model=self.model,
|
||||
output = self.cohere.chat(
|
||||
model=model_name,
|
||||
message=prompt,
|
||||
preamble = system,
|
||||
temperature=self.temperature,
|
||||
temperature=effective_temperature,
|
||||
chat_history=[],
|
||||
prompt_truncation='auto',
|
||||
connectors=[]
|
||||
|
|
@ -71,7 +79,7 @@ class Processor(LlmService):
|
|||
text = resp,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model
|
||||
model = model_name
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -53,10 +53,13 @@ class Processor(LlmService):
|
|||
)
|
||||
|
||||
self.client = genai.Client(api_key=api_key)
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
|
||||
# Cache for generation configs per model
|
||||
self.generation_configs = {}
|
||||
|
||||
block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
|
||||
|
||||
self.safety_settings = [
|
||||
|
|
@ -83,22 +86,45 @@ class Processor(LlmService):
|
|||
|
||||
logger.info("GoogleAIStudio LLM service initialized")
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
def _get_or_create_config(self, model_name, temperature=None):
|
||||
"""Get or create generation config with dynamic temperature"""
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
generation_config = types.GenerateContentConfig(
|
||||
temperature = self.temperature,
|
||||
top_p = 1,
|
||||
top_k = 40,
|
||||
max_output_tokens = self.max_output,
|
||||
response_mime_type = "text/plain",
|
||||
system_instruction = system,
|
||||
safety_settings = self.safety_settings,
|
||||
)
|
||||
# Create cache key that includes temperature to avoid conflicts
|
||||
cache_key = f"{model_name}:{effective_temperature}"
|
||||
|
||||
if cache_key not in self.generation_configs:
|
||||
logger.info(f"Creating generation config for '{model_name}' with temperature {effective_temperature}")
|
||||
self.generation_configs[cache_key] = types.GenerateContentConfig(
|
||||
temperature = effective_temperature,
|
||||
top_p = 1,
|
||||
top_k = 40,
|
||||
max_output_tokens = self.max_output,
|
||||
response_mime_type = "text/plain",
|
||||
safety_settings = self.safety_settings,
|
||||
)
|
||||
|
||||
return self.generation_configs[cache_key]
|
||||
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
generation_config = self._get_or_create_config(model_name, effective_temperature)
|
||||
# Set system instruction per request (can't be cached)
|
||||
generation_config.system_instruction = system
|
||||
|
||||
try:
|
||||
|
||||
response = self.client.models.generate_content(
|
||||
model=self.model,
|
||||
model=model_name,
|
||||
config=generation_config,
|
||||
contents=prompt,
|
||||
)
|
||||
|
|
@ -114,7 +140,7 @@ class Processor(LlmService):
|
|||
text = resp,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model
|
||||
model = model_name
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class Processor(LlmService):
|
|||
}
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
self.llamafile=llamafile
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
|
|
@ -50,25 +50,33 @@ class Processor(LlmService):
|
|||
|
||||
logger.info("Llamafile LLM service initialized")
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
prompt = system + "\n\n" + prompt
|
||||
|
||||
try:
|
||||
|
||||
resp = self.openai.chat.completions.create(
|
||||
model=self.model,
|
||||
model=model_name,
|
||||
messages=[
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
#temperature=self.temperature,
|
||||
#max_tokens=self.max_output,
|
||||
#top_p=1,
|
||||
#frequency_penalty=0,
|
||||
#presence_penalty=0,
|
||||
#response_format={
|
||||
# "type": "text"
|
||||
#}
|
||||
],
|
||||
temperature=effective_temperature,
|
||||
max_tokens=self.max_output,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
presence_penalty=0,
|
||||
response_format={
|
||||
"type": "text"
|
||||
}
|
||||
)
|
||||
|
||||
inputtokens = resp.usage.prompt_tokens
|
||||
|
|
@ -82,7 +90,7 @@ class Processor(LlmService):
|
|||
text = resp.choices[0].message.content,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = "llama.cpp",
|
||||
model = model_name,
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class Processor(LlmService):
|
|||
}
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
self.url = url + "v1/"
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
|
|
@ -50,7 +50,15 @@ class Processor(LlmService):
|
|||
|
||||
logger.info("LMStudio LLM service initialized")
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
prompt = system + "\n\n" + prompt
|
||||
|
||||
|
|
@ -59,18 +67,18 @@ class Processor(LlmService):
|
|||
logger.debug(f"Prompt: {prompt}")
|
||||
|
||||
resp = self.openai.chat.completions.create(
|
||||
model=self.model,
|
||||
model=model_name,
|
||||
messages=[
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
#temperature=self.temperature,
|
||||
#max_tokens=self.max_output,
|
||||
#top_p=1,
|
||||
#frequency_penalty=0,
|
||||
#presence_penalty=0,
|
||||
#response_format={
|
||||
# "type": "text"
|
||||
#}
|
||||
],
|
||||
temperature=effective_temperature,
|
||||
max_tokens=self.max_output,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
presence_penalty=0,
|
||||
response_format={
|
||||
"type": "text"
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(f"Full response: {resp}")
|
||||
|
|
@ -86,7 +94,7 @@ class Processor(LlmService):
|
|||
text = resp.choices[0].message.content,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model
|
||||
model = model_name
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -41,21 +41,29 @@ class Processor(LlmService):
|
|||
}
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
self.mistral = Mistral(api_key=api_key)
|
||||
|
||||
logger.info("Mistral LLM service initialized")
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
prompt = system + "\n\n" + prompt
|
||||
|
||||
try:
|
||||
|
||||
resp = self.mistral.chat.complete(
|
||||
model=self.model,
|
||||
model=model_name,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
|
|
@ -67,7 +75,7 @@ class Processor(LlmService):
|
|||
]
|
||||
}
|
||||
],
|
||||
temperature=self.temperature,
|
||||
temperature=effective_temperature,
|
||||
max_tokens=self.max_output,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
|
|
@ -87,7 +95,7 @@ class Processor(LlmService):
|
|||
text = resp.choices[0].message.content,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model
|
||||
model = model_name
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from .... base import LlmService, LlmResult
|
|||
default_ident = "text-completion"
|
||||
|
||||
default_model = 'gemma2:9b'
|
||||
default_temperature = 0.0
|
||||
default_ollama = os.getenv("OLLAMA_HOST", 'http://localhost:11434')
|
||||
|
||||
class Processor(LlmService):
|
||||
|
|
@ -24,25 +25,36 @@ class Processor(LlmService):
|
|||
def __init__(self, **params):
|
||||
|
||||
model = params.get("model", default_model)
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
ollama = params.get("ollama", default_ollama)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"model": model,
|
||||
"temperature": temperature,
|
||||
"ollama": ollama,
|
||||
}
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
self.temperature = temperature
|
||||
self.llm = Client(host=ollama)
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
prompt = system + "\n\n" + prompt
|
||||
|
||||
try:
|
||||
|
||||
response = self.llm.generate(self.model, prompt)
|
||||
response = self.llm.generate(model_name, prompt, options={'temperature': effective_temperature})
|
||||
|
||||
response_text = response['response']
|
||||
logger.debug("Sending response...")
|
||||
|
|
@ -55,7 +67,7 @@ class Processor(LlmService):
|
|||
text = response_text,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model
|
||||
model = model_name
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
@ -84,6 +96,13 @@ class Processor(LlmService):
|
|||
help=f'ollama (default: {default_ollama})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--temperature',
|
||||
type=float,
|
||||
default=default_temperature,
|
||||
help=f'LLM temperature parameter (default: {default_temperature})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ class Processor(LlmService):
|
|||
}
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
|
||||
|
|
@ -58,14 +58,22 @@ class Processor(LlmService):
|
|||
|
||||
logger.info("OpenAI LLM service initialized")
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
prompt = system + "\n\n" + prompt
|
||||
|
||||
try:
|
||||
|
||||
resp = self.openai.chat.completions.create(
|
||||
model=self.model,
|
||||
model=model_name,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
|
|
@ -77,7 +85,7 @@ class Processor(LlmService):
|
|||
]
|
||||
}
|
||||
],
|
||||
temperature=self.temperature,
|
||||
temperature=effective_temperature,
|
||||
max_tokens=self.max_output,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
|
|
@ -97,7 +105,7 @@ class Processor(LlmService):
|
|||
text = resp.choices[0].message.content,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model
|
||||
model = model_name
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -30,32 +30,43 @@ class Processor(LlmService):
|
|||
base_url = params.get("url", default_base_url)
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
model = params.get("model", "tgi")
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"temperature": temperature,
|
||||
"max_output": max_output,
|
||||
"url": base_url,
|
||||
"model": model,
|
||||
}
|
||||
)
|
||||
|
||||
self.base_url = base_url
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
self.default_model = model
|
||||
|
||||
self.session = aiohttp.ClientSession()
|
||||
|
||||
logger.info(f"Using TGI service at {base_url}")
|
||||
logger.info("TGI LLM service initialized")
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
request = {
|
||||
"model": "tgi",
|
||||
"model": model_name,
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
|
|
@ -67,7 +78,7 @@ class Processor(LlmService):
|
|||
}
|
||||
],
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"temperature": effective_temperature,
|
||||
}
|
||||
|
||||
try:
|
||||
|
|
@ -96,7 +107,7 @@ class Processor(LlmService):
|
|||
text = ans,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = "tgi",
|
||||
model = model_name,
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -45,24 +45,32 @@ class Processor(LlmService):
|
|||
self.base_url = base_url
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
|
||||
self.session = aiohttp.ClientSession()
|
||||
|
||||
logger.info(f"Using vLLM service at {base_url}")
|
||||
logger.info("vLLM LLM service initialized")
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
request = {
|
||||
"model": self.model,
|
||||
"model": model_name,
|
||||
"prompt": system + "\n\n" + prompt,
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"temperature": effective_temperature,
|
||||
}
|
||||
|
||||
try:
|
||||
|
|
@ -91,7 +99,7 @@ class Processor(LlmService):
|
|||
text = ans,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model,
|
||||
model = model_name,
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -57,21 +57,26 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
raise e
|
||||
self.last_collection = collection
|
||||
|
||||
def collection_exists(self, collection):
    """Return whether ``collection`` is present in the vector store.

    This is a pure lookup delegated to ``self.qdrant.collection_exists``;
    the method itself never creates a collection.
    """
    exists = self.qdrant.collection_exists(collection)
    return exists
|
||||
|
||||
async def query_document_embeddings(self, msg):
|
||||
|
||||
try:
|
||||
|
||||
chunks = []
|
||||
|
||||
collection = (
|
||||
"d_" + msg.user + "_" + msg.collection
|
||||
)
|
||||
|
||||
# Check if collection exists - return empty if not
|
||||
if not self.collection_exists(collection):
|
||||
logger.info(f"Collection {collection} does not exist, returning empty results")
|
||||
return []
|
||||
|
||||
for vec in msg.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
collection = (
|
||||
"d_" + msg.user + "_" + msg.collection
|
||||
)
|
||||
|
||||
self.ensure_collection_exists(collection, dim)
|
||||
|
||||
search_result = self.qdrant.query_points(
|
||||
collection_name=collection,
|
||||
query=vec,
|
||||
|
|
|
|||
|
|
@ -57,6 +57,10 @@ class Processor(GraphEmbeddingsQueryService):
|
|||
raise e
|
||||
self.last_collection = collection
|
||||
|
||||
def collection_exists(self, collection):
    """Return whether ``collection`` is present in the vector store.

    This is a pure lookup delegated to ``self.qdrant.collection_exists``;
    the method itself never creates a collection.
    """
    exists = self.qdrant.collection_exists(collection)
    return exists
|
||||
|
||||
def create_value(self, ent):
|
||||
if ent.startswith("http://") or ent.startswith("https://"):
|
||||
return Value(value=ent, is_uri=True)
|
||||
|
|
@ -70,12 +74,16 @@ class Processor(GraphEmbeddingsQueryService):
|
|||
entity_set = set()
|
||||
entities = []
|
||||
|
||||
for vec in msg.vectors:
|
||||
collection = (
|
||||
"t_" + msg.user + "_" + msg.collection
|
||||
)
|
||||
|
||||
dim = len(vec)
|
||||
collection = (
|
||||
"t_" + msg.user + "_" + msg.collection
|
||||
)
|
||||
# Check if collection exists - return empty if not
|
||||
if not self.collection_exists(collection):
|
||||
logger.info(f"Collection {collection} does not exist, returning empty results")
|
||||
return []
|
||||
|
||||
for vec in msg.vectors:
|
||||
|
||||
self.ensure_collection_exists(collection, dim)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,56 @@
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||
|
||||
class LRUCacheWithTTL:
    """LRU cache with TTL for label caching

    Holds at most `max_size` entries; when full, the least recently used
    entry is evicted to make room.  An entry expires `ttl` seconds after
    it was last written with put(); reading it with get() refreshes its
    LRU position but not its age.

    get() returns None for both "missing" and "expired", so None is not a
    usable cached value.

    CRITICAL SECURITY WARNING:
    This cache is shared within a GraphRag instance but GraphRag instances
    are created per-request. Cache keys MUST include user:collection prefix
    to ensure data isolation between different security contexts.
    """

    def __init__(self, max_size=5000, ttl=300):
        # Insertion/use order of `cache` is the LRU order: oldest first.
        self.cache = OrderedDict()
        # Wall-clock time of the last put() for each key in `cache`.
        self.access_times = {}
        self.max_size = max_size
        self.ttl = ttl

    def get(self, key):
        """Return the cached value for `key`, or None if absent or expired"""
        if key not in self.cache:
            return None

        # Check TTL expiration; expired entries are dropped on access
        if time.time() - self.access_times[key] > self.ttl:
            del self.cache[key]
            del self.access_times[key]
            return None

        # Move to end (most recently used)
        self.cache.move_to_end(key)
        return self.cache[key]

    def put(self, key, value):
        """Store `value` under `key`, evicting the LRU entry if full"""
        # A non-positive capacity means caching is disabled.  Without this
        # guard the eviction below would run on an empty dict and fail.
        if self.max_size <= 0:
            return

        if key in self.cache:
            self.cache.move_to_end(key)
        elif len(self.cache) >= self.max_size:
            # Remove least recently used
            oldest_key, _ = self.cache.popitem(last=False)
            del self.access_times[oldest_key]

        self.cache[key] = value
        self.access_times[key] = time.time()
|
||||
|
||||
class Query:
|
||||
|
||||
def __init__(
|
||||
|
|
@ -61,8 +105,14 @@ class Query:
|
|||
|
||||
async def maybe_label(self, e):
|
||||
|
||||
if e in self.rag.label_cache:
|
||||
return self.rag.label_cache[e]
|
||||
# CRITICAL SECURITY: Cache key MUST include user and collection
|
||||
# to prevent data leakage between different contexts
|
||||
cache_key = f"{self.user}:{self.collection}:{e}"
|
||||
|
||||
# Check LRU cache first with isolated key
|
||||
cached_label = self.rag.label_cache.get(cache_key)
|
||||
if cached_label is not None:
|
||||
return cached_label
|
||||
|
||||
res = await self.rag.triples_client.query(
|
||||
s=e, p=LABEL, o=None, limit=1,
|
||||
|
|
@ -70,60 +120,104 @@ class Query:
|
|||
)
|
||||
|
||||
if len(res) == 0:
|
||||
self.rag.label_cache[e] = e
|
||||
self.rag.label_cache.put(cache_key, e)
|
||||
return e
|
||||
|
||||
self.rag.label_cache[e] = str(res[0].o)
|
||||
return self.rag.label_cache[e]
|
||||
label = str(res[0].o)
|
||||
self.rag.label_cache.put(cache_key, label)
|
||||
return label
|
||||
|
||||
async def execute_batch_triple_queries(self, entities, limit_per_entity):
    """Execute triple queries for multiple entities concurrently

    For every entity three queries are issued: the entity as subject, as
    predicate and as object.  Each is capped at `limit_per_entity` and
    scoped to this query's user and collection.  All queries are awaited
    together.

    Returns a flat list holding the results of every query that
    succeeded, in issue order.  Queries that failed or were cancelled
    contribute nothing; they are logged instead of raised so that one bad
    lookup does not discard the rest of the batch.
    """
    tasks = []

    for entity in entities:
        # Create concurrent tasks for all 3 query types per entity
        for position in ("s", "p", "o"):
            pattern = {"s": None, "p": None, "o": None}
            pattern[position] = entity
            tasks.append(
                self.rag.triples_client.query(
                    limit=limit_per_entity,
                    user=self.user, collection=self.collection,
                    **pattern
                )
            )

    # Execute all queries concurrently
    results = await asyncio.gather(*tasks, return_exceptions=True)

    # Combine all results
    all_triples = []
    failures = []
    for result in results:
        # BaseException, not Exception: a cancelled sub-task is reported
        # as CancelledError, which is not an Exception subclass and must
        # not be treated as a list of triples.
        if isinstance(result, BaseException):
            failures.append(result)
        else:
            all_triples.extend(result)

    if failures:
        # Same logger object as the module-level one (same name).
        logging.getLogger(__name__).warning(
            "%d of %d triple queries failed; first error: %r",
            len(failures), len(results), failures[0],
        )

    return all_triples
|
||||
|
||||
async def follow_edges_batch(self, entities, max_depth):
    """Breadth-first expansion of the graph around `entities`

    Each round queries every not-yet-expanded entity of the current
    frontier in one batch, records the returned triples as
    (subject, predicate, object) string tuples, and uses their subjects
    and objects as the next frontier.  Expansion stops after `max_depth`
    rounds, when the frontier is exhausted, or as soon as the collected
    set reaches self.max_subgraph_size.

    Returns the set of collected triple tuples.
    """
    seen = set()
    frontier = set(entities)
    edges = set()
    final_round = max_depth - 1

    for round_no in range(max_depth):
        if not frontier or len(edges) >= self.max_subgraph_size:
            break

        # Skip anything an earlier round already expanded
        pending = [node for node in frontier if node not in seen]
        if not pending:
            break

        # One batched lookup for the whole frontier
        found = await self.execute_batch_triple_queries(
            pending, self.triple_limit
        )

        upcoming = set()
        for item in found:
            edge = (str(item.s), str(item.p), str(item.o))
            edges.add(edge)

            # Only subject and object feed the next round, and the final
            # round has no successor to feed.
            if round_no < final_round:
                for endpoint in (edge[0], edge[2]):
                    if endpoint not in seen:
                        upcoming.add(endpoint)

            # Size cap reached: hand back what we have immediately
            if len(edges) >= self.max_subgraph_size:
                return edges

        seen.update(frontier)
        frontier = upcoming

    return edges
|
||||
|
||||
async def follow_edges(self, ent, subgraph, path_length):
|
||||
|
||||
# Not needed?
|
||||
"""Legacy method - replaced by follow_edges_batch"""
|
||||
# Maintain backward compatibility with early termination checks
|
||||
if path_length <= 0:
|
||||
return
|
||||
|
||||
# Stop spanning around if the subgraph is already maxed out
|
||||
if len(subgraph) >= self.max_subgraph_size:
|
||||
return
|
||||
|
||||
res = await self.rag.triples_client.query(
|
||||
s=ent, p=None, o=None,
|
||||
limit=self.triple_limit,
|
||||
user=self.user, collection=self.collection,
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(str(triple.s), str(triple.p), str(triple.o))
|
||||
)
|
||||
if path_length > 1:
|
||||
await self.follow_edges(str(triple.o), subgraph, path_length-1)
|
||||
|
||||
res = await self.rag.triples_client.query(
|
||||
s=None, p=ent, o=None,
|
||||
limit=self.triple_limit,
|
||||
user=self.user, collection=self.collection,
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(str(triple.s), str(triple.p), str(triple.o))
|
||||
)
|
||||
|
||||
res = await self.rag.triples_client.query(
|
||||
s=None, p=None, o=ent,
|
||||
limit=self.triple_limit,
|
||||
user=self.user, collection=self.collection,
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(str(triple.s), str(triple.p), str(triple.o))
|
||||
)
|
||||
if path_length > 1:
|
||||
await self.follow_edges(
|
||||
str(triple.s), subgraph, path_length-1
|
||||
)
|
||||
# For backward compatibility, convert to new approach
|
||||
batch_result = await self.follow_edges_batch([ent], path_length)
|
||||
subgraph.update(batch_result)
|
||||
|
||||
async def get_subgraph(self, query):
|
||||
|
||||
|
|
@ -132,31 +226,52 @@ class Query:
|
|||
if self.verbose:
|
||||
logger.debug("Getting subgraph...")
|
||||
|
||||
subgraph = set()
|
||||
# Use optimized batch traversal instead of sequential processing
|
||||
subgraph = await self.follow_edges_batch(entities, self.max_path_length)
|
||||
|
||||
for ent in entities:
|
||||
await self.follow_edges(ent, subgraph, self.max_path_length)
|
||||
return list(subgraph)
|
||||
|
||||
subgraph = list(subgraph)
|
||||
async def resolve_labels_batch(self, entities):
|
||||
"""Resolve labels for multiple entities in parallel"""
|
||||
tasks = []
|
||||
for entity in entities:
|
||||
tasks.append(self.maybe_label(entity))
|
||||
|
||||
return subgraph
|
||||
return await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def get_labelgraph(self, query):
|
||||
|
||||
subgraph = await self.get_subgraph(query)
|
||||
|
||||
# Filter out label triples
|
||||
filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
|
||||
|
||||
# Collect all unique entities that need label resolution
|
||||
entities_to_resolve = set()
|
||||
for s, p, o in filtered_subgraph:
|
||||
entities_to_resolve.update([s, p, o])
|
||||
|
||||
# Batch resolve labels for all entities in parallel
|
||||
entity_list = list(entities_to_resolve)
|
||||
resolved_labels = await self.resolve_labels_batch(entity_list)
|
||||
|
||||
# Create entity-to-label mapping
|
||||
label_map = {}
|
||||
for entity, label in zip(entity_list, resolved_labels):
|
||||
if not isinstance(label, Exception):
|
||||
label_map[entity] = label
|
||||
else:
|
||||
label_map[entity] = entity # Fallback to entity itself
|
||||
|
||||
# Apply labels to subgraph
|
||||
sg2 = []
|
||||
|
||||
for edge in subgraph:
|
||||
|
||||
if edge[1] == LABEL:
|
||||
continue
|
||||
|
||||
s = await self.maybe_label(edge[0])
|
||||
p = await self.maybe_label(edge[1])
|
||||
o = await self.maybe_label(edge[2])
|
||||
|
||||
sg2.append((s, p, o))
|
||||
for s, p, o in filtered_subgraph:
|
||||
labeled_triple = (
|
||||
label_map.get(s, s),
|
||||
label_map.get(p, p),
|
||||
label_map.get(o, o)
|
||||
)
|
||||
sg2.append(labeled_triple)
|
||||
|
||||
sg2 = sg2[0:self.max_subgraph_size]
|
||||
|
||||
|
|
@ -171,6 +286,13 @@ class Query:
|
|||
return sg2
|
||||
|
||||
class GraphRag:
|
||||
"""
|
||||
CRITICAL SECURITY:
|
||||
This class MUST be instantiated per-request to ensure proper isolation
|
||||
between users and collections. The cache within this instance will only
|
||||
live for the duration of a single request, preventing cross-contamination
|
||||
of data between different security contexts.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, prompt_client, embeddings_client, graph_embeddings_client,
|
||||
|
|
@ -184,7 +306,9 @@ class GraphRag:
|
|||
self.graph_embeddings_client = graph_embeddings_client
|
||||
self.triples_client = triples_client
|
||||
|
||||
self.label_cache = {}
|
||||
# Replace simple dict with LRU cache with TTL
|
||||
# CRITICAL: This cache only lives for one request due to per-request instantiation
|
||||
self.label_cache = LRUCacheWithTTL(max_size=5000, ttl=300)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("GraphRag initialized")
|
||||
|
|
|
|||
|
|
@ -45,6 +45,10 @@ class Processor(FlowProcessor):
|
|||
self.default_max_subgraph_size = max_subgraph_size
|
||||
self.default_max_path_length = max_path_length
|
||||
|
||||
# CRITICAL SECURITY: NEVER share data between users or collections
|
||||
# Each user/collection combination MUST have isolated data access
|
||||
# Caching must NEVER allow information leakage across these boundaries
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name = "request",
|
||||
|
|
@ -93,11 +97,14 @@ class Processor(FlowProcessor):
|
|||
|
||||
try:
|
||||
|
||||
self.rag = GraphRag(
|
||||
embeddings_client = flow("embeddings-request"),
|
||||
graph_embeddings_client = flow("graph-embeddings-request"),
|
||||
triples_client = flow("triples-request"),
|
||||
prompt_client = flow("prompt-request"),
|
||||
# CRITICAL SECURITY: Create new GraphRag instance per request
|
||||
# This ensures proper isolation between users and collections
|
||||
# Flow clients are request-scoped and must not be shared
|
||||
rag = GraphRag(
|
||||
embeddings_client=flow("embeddings-request"),
|
||||
graph_embeddings_client=flow("graph-embeddings-request"),
|
||||
triples_client=flow("triples-request"),
|
||||
prompt_client=flow("prompt-request"),
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
|
@ -128,7 +135,7 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
max_path_length = self.default_max_path_length
|
||||
|
||||
response = await self.rag.query(
|
||||
response = await rag.query(
|
||||
query = v.query, user = v.user, collection = v.collection,
|
||||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
|
|
|
|||
|
|
@ -60,19 +60,34 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
metrics=storage_response_metrics,
|
||||
)
|
||||
|
||||
async def start(self):
    """Bring up the base processor, then the storage-management channel

    The request consumer is started before the response producer, after
    the parent class has finished its own start-up.
    """
    await super().start()

    request_side = self.storage_request_consumer
    await request_side.start()

    response_side = self.storage_response_producer
    await response_side.start()
|
||||
|
||||
async def store_document_embeddings(self, message):
|
||||
|
||||
# Validate collection exists before accepting writes
|
||||
if not self.vecstore.collection_exists(message.metadata.user, message.metadata.collection):
|
||||
error_msg = (
|
||||
f"Collection {message.metadata.collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
for emb in message.chunks:
|
||||
|
||||
if emb.chunk is None or emb.chunk == b"": continue
|
||||
|
||||
|
||||
chunk = emb.chunk.decode("utf-8")
|
||||
if chunk == "": continue
|
||||
|
||||
for vec in emb.vectors:
|
||||
self.vecstore.insert(
|
||||
vec, chunk,
|
||||
message.metadata.user,
|
||||
vec, chunk,
|
||||
message.metadata.user,
|
||||
message.metadata.collection
|
||||
)
|
||||
|
||||
|
|
@ -87,18 +102,21 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
help=f'Milvus store URI (default: {default_store_uri})'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
|
||||
async def on_storage_management(self, message, consumer, flow):
|
||||
"""Handle storage management requests"""
|
||||
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
|
||||
request = message.value()
|
||||
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if message.operation == "delete-collection":
|
||||
await self.handle_delete_collection(message)
|
||||
if request.operation == "create-collection":
|
||||
await self.handle_create_collection(request)
|
||||
elif request.operation == "delete-collection":
|
||||
await self.handle_delete_collection(request)
|
||||
else:
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="invalid_operation",
|
||||
message=f"Unknown operation: {message.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -113,17 +131,40 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, message):
|
||||
async def handle_create_collection(self, request):
    """Ensure the Milvus document-embeddings collection exists

    Reads `user` and `collection` from the request.  An existing
    collection is left alone; otherwise it is created through the vector
    store.  A StorageManagementResponse is always sent on the response
    producer: with error=None on success, or with a "creation_error"
    Error carrying the exception text on failure.
    """
    try:
        already_there = self.vecstore.collection_exists(
            request.user, request.collection
        )

        if not already_there:
            self.vecstore.create_collection(request.user, request.collection)
            logger.info(f"Created collection {request.user}/{request.collection}")
        else:
            logger.info(f"Collection {request.user}/{request.collection} already exists")

        # Acknowledge success
        ok = StorageManagementResponse(error=None)
        await self.storage_response_producer.send(ok)

    except Exception as e:
        # Report the failure to the requester instead of propagating it
        logger.error(f"Failed to create collection: {e}", exc_info=True)
        failure = Error(
            type="creation_error",
            message=str(e)
        )
        await self.storage_response_producer.send(
            StorageManagementResponse(error=failure)
        )
|
||||
|
||||
async def handle_delete_collection(self, request):
|
||||
"""Delete the collection for document embeddings"""
|
||||
try:
|
||||
self.vecstore.delete_collection(message.user, message.collection)
|
||||
self.vecstore.delete_collection(request.user, request.collection)
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(
|
||||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
|
|
|||
|
|
@ -115,38 +115,36 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
"Gave up waiting for index creation"
|
||||
)
|
||||
|
||||
async def start(self):
    """Start the parent processor followed by the storage-management I/O

    Order matters only in that the base class goes first; the request
    consumer is then started ahead of the response producer.
    """
    await super().start()

    # Same order as the attributes are listed: consumer, then producer
    for attr in ("storage_request_consumer", "storage_response_producer"):
        await getattr(self, attr).start()
|
||||
|
||||
async def store_document_embeddings(self, message):
|
||||
|
||||
index_name = (
|
||||
"d-" + message.metadata.user + "-" + message.metadata.collection
|
||||
)
|
||||
|
||||
# Validate collection exists before accepting writes
|
||||
if not self.pinecone.has_index(index_name):
|
||||
error_msg = (
|
||||
f"Collection {message.metadata.collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
for emb in message.chunks:
|
||||
|
||||
if emb.chunk is None or emb.chunk == b"": continue
|
||||
|
||||
|
||||
chunk = emb.chunk.decode("utf-8")
|
||||
if chunk == "": continue
|
||||
|
||||
for vec in emb.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
index_name = (
|
||||
"d-" + message.metadata.user + "-" + message.metadata.collection
|
||||
)
|
||||
|
||||
if index_name != self.last_index_name:
|
||||
|
||||
if not self.pinecone.has_index(index_name):
|
||||
|
||||
try:
|
||||
|
||||
self.create_index(index_name, dim)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Pinecone index creation failed")
|
||||
raise e
|
||||
|
||||
logger.info(f"Index {index_name} created")
|
||||
|
||||
self.last_index_name = index_name
|
||||
|
||||
index = self.pinecone.Index(index_name)
|
||||
|
||||
# Generate unique ID for each vector
|
||||
|
|
@ -192,18 +190,21 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
help=f'Pinecone region, (default: {default_region}'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
|
||||
async def on_storage_management(self, message, consumer, flow):
|
||||
"""Handle storage management requests"""
|
||||
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
|
||||
request = message.value()
|
||||
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if message.operation == "delete-collection":
|
||||
await self.handle_delete_collection(message)
|
||||
if request.operation == "create-collection":
|
||||
await self.handle_create_collection(request)
|
||||
elif request.operation == "delete-collection":
|
||||
await self.handle_delete_collection(request)
|
||||
else:
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="invalid_operation",
|
||||
message=f"Unknown operation: {message.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -218,10 +219,36 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, message):
|
||||
async def handle_create_collection(self, request, dim=384):
    """Create a Pinecone index for document embeddings

    The index is named "d-<user>-<collection>".  If an index of that name
    already exists nothing is created.  A StorageManagementResponse is
    always sent on the response producer: error=None on success, or a
    "creation_error" Error carrying the exception text on failure.

    request: storage management request; `user` and `collection` are read.
    dim: vector dimension used when a new index has to be created.  The
        default of 384 is the value that used to be hard-coded here.
    """
    try:
        index_name = f"d-{request.user}-{request.collection}"

        if self.pinecone.has_index(index_name):
            logger.info(f"Pinecone index {index_name} already exists")
        else:
            # The dimension is fixed when the index is created.  If the
            # embeddings written later have a different size the index
            # will need to be recreated with a matching `dim`.
            self.create_index(index_name, dim=dim)
            logger.info(f"Created Pinecone index: {index_name}")

        # Send success response
        response = StorageManagementResponse(error=None)
        await self.storage_response_producer.send(response)

    except Exception as e:
        # Report the failure to the requester instead of propagating it
        logger.error(f"Failed to create collection: {e}", exc_info=True)
        response = StorageManagementResponse(
            error=Error(
                type="creation_error",
                message=str(e)
            )
        )
        await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, request):
|
||||
"""Delete the collection for document embeddings"""
|
||||
try:
|
||||
index_name = f"d-{message.user}-{message.collection}"
|
||||
index_name = f"d-{request.user}-{request.collection}"
|
||||
|
||||
if self.pinecone.has_index(index_name):
|
||||
self.pinecone.delete_index(index_name)
|
||||
|
|
@ -234,7 +261,7 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
|
|
|||
|
|
@ -36,8 +36,6 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
}
|
||||
)
|
||||
|
||||
self.last_collection = None
|
||||
|
||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
||||
|
||||
# Set up storage management if base class attributes are available
|
||||
|
|
@ -71,8 +69,30 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
metrics=storage_response_metrics,
|
||||
)
|
||||
|
||||
async def start(self):
    """Start the base processor and any storage-management endpoints

    The request consumer and response producer are optional on this
    instance: each one is started only if the attribute is present.
    """
    await super().start()

    # Consumer first, then producer; absent attributes are skipped
    for attr in ('storage_request_consumer', 'storage_response_producer'):
        if hasattr(self, attr):
            await getattr(self, attr).start()
|
||||
|
||||
async def store_document_embeddings(self, message):
|
||||
|
||||
# Validate collection exists before accepting writes
|
||||
collection = (
|
||||
"d_" + message.metadata.user + "_" +
|
||||
message.metadata.collection
|
||||
)
|
||||
|
||||
if not self.qdrant.collection_exists(collection):
|
||||
error_msg = (
|
||||
f"Collection {message.metadata.collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
for emb in message.chunks:
|
||||
|
||||
chunk = emb.chunk.decode("utf-8")
|
||||
|
|
@ -80,29 +100,6 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
|
||||
for vec in emb.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
collection = (
|
||||
"d_" + message.metadata.user + "_" +
|
||||
message.metadata.collection
|
||||
)
|
||||
|
||||
if collection != self.last_collection:
|
||||
|
||||
if not self.qdrant.collection_exists(collection):
|
||||
|
||||
try:
|
||||
self.qdrant.create_collection(
|
||||
collection_name=collection,
|
||||
vectors_config=VectorParams(
|
||||
size=dim, distance=Distance.COSINE
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Qdrant collection creation failed")
|
||||
raise e
|
||||
|
||||
self.last_collection = collection
|
||||
|
||||
self.qdrant.upsert(
|
||||
collection_name=collection,
|
||||
points=[
|
||||
|
|
@ -133,18 +130,21 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
help=f'Qdrant API key (default: None)'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
|
||||
async def on_storage_management(self, message, consumer, flow):
|
||||
"""Handle storage management requests"""
|
||||
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
|
||||
request = message.value()
|
||||
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if message.operation == "delete-collection":
|
||||
await self.handle_delete_collection(message)
|
||||
if request.operation == "create-collection":
|
||||
await self.handle_create_collection(request)
|
||||
elif request.operation == "delete-collection":
|
||||
await self.handle_delete_collection(request)
|
||||
else:
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="invalid_operation",
|
||||
message=f"Unknown operation: {message.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -159,10 +159,43 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, message):
|
||||
async def handle_create_collection(self, request, dim=384):
    """Create a Qdrant collection for document embeddings

    The collection is named "d_<user>_<collection>" and uses cosine
    distance.  If it already exists nothing is created.  A
    StorageManagementResponse is always sent on the response producer:
    error=None on success, or a "creation_error" Error carrying the
    exception text on failure.

    request: storage management request; `user` and `collection` are read.
    dim: vector size used when a new collection has to be created.  The
        default of 384 is the value that used to be hard-coded here.
    """
    try:
        collection_name = f"d_{request.user}_{request.collection}"

        if self.qdrant.collection_exists(collection_name):
            logger.info(f"Qdrant collection {collection_name} already exists")
        else:
            # The vector size is fixed at creation time.
            # NOTE(review): the write path visible next to this handler
            # only checks that the collection exists before upserting; it
            # does not recreate it.  Embeddings whose size differs from
            # `dim` would therefore need the collection to be recreated
            # explicitly - confirm how callers are expected to handle that.
            self.qdrant.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=dim,
                    distance=Distance.COSINE
                )
            )
            logger.info(f"Created Qdrant collection: {collection_name}")

        # Send success response
        response = StorageManagementResponse(error=None)
        await self.storage_response_producer.send(response)

    except Exception as e:
        # Report the failure to the requester instead of propagating it
        logger.error(f"Failed to create collection: {e}", exc_info=True)
        response = StorageManagementResponse(
            error=Error(
                type="creation_error",
                message=str(e)
            )
        )
        await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, request):
|
||||
"""Delete the collection for document embeddings"""
|
||||
try:
|
||||
collection_name = f"d_{message.user}_{message.collection}"
|
||||
collection_name = f"d_{request.user}_{request.collection}"
|
||||
|
||||
if self.qdrant.collection_exists(collection_name):
|
||||
self.qdrant.delete_collection(collection_name)
|
||||
|
|
@ -175,7 +208,7 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
|
|
|||
|
|
@ -60,8 +60,23 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
metrics=storage_response_metrics,
|
||||
)
|
||||
|
||||
async def start(self):
    """Run the inherited start-up, then start storage-management I/O

    After the base class has started, the storage request consumer is
    started, followed by the storage response producer.
    """
    await super().start()

    incoming = self.storage_request_consumer
    await incoming.start()

    outgoing = self.storage_response_producer
    await outgoing.start()
|
||||
|
||||
async def store_graph_embeddings(self, message):
|
||||
|
||||
# Validate collection exists before accepting writes
|
||||
if not self.vecstore.collection_exists(message.metadata.user, message.metadata.collection):
|
||||
error_msg = (
|
||||
f"Collection {message.metadata.collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
for entity in message.entities:
|
||||
|
||||
if entity.entity.value != "" and entity.entity.value is not None:
|
||||
|
|
@ -83,18 +98,21 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
help=f'Milvus store URI (default: {default_store_uri})'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
|
||||
async def on_storage_management(self, message, consumer, flow):
|
||||
"""Handle storage management requests"""
|
||||
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
|
||||
request = message.value()
|
||||
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if message.operation == "delete-collection":
|
||||
await self.handle_delete_collection(message)
|
||||
if request.operation == "create-collection":
|
||||
await self.handle_create_collection(request)
|
||||
elif request.operation == "delete-collection":
|
||||
await self.handle_delete_collection(request)
|
||||
else:
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="invalid_operation",
|
||||
message=f"Unknown operation: {message.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -109,17 +127,40 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, message):
|
||||
async def handle_create_collection(self, request):
    """Ensure the Milvus graph-embeddings collection exists

    Uses `request.user` and `request.collection` to identify the
    collection.  If the vector store already has it, that is only logged;
    otherwise the collection is created.  The outcome is reported with a
    StorageManagementResponse on the response producer: error=None on
    success, a "creation_error" Error with the exception text on failure.
    """
    try:
        exists = self.vecstore.collection_exists(
            request.user, request.collection
        )

        if not exists:
            self.vecstore.create_collection(request.user, request.collection)
            logger.info(f"Created collection {request.user}/{request.collection}")
        else:
            logger.info(f"Collection {request.user}/{request.collection} already exists")

        # Tell the requester it worked
        success = StorageManagementResponse(error=None)
        await self.storage_response_producer.send(success)

    except Exception as e:
        # Any failure above is turned into an error response
        logger.error(f"Failed to create collection: {e}", exc_info=True)
        problem = Error(
            type="creation_error",
            message=str(e)
        )
        await self.storage_response_producer.send(
            StorageManagementResponse(error=problem)
        )
|
||||
|
||||
async def handle_delete_collection(self, request):
|
||||
"""Delete the collection for graph embeddings"""
|
||||
try:
|
||||
self.vecstore.delete_collection(message.user, message.collection)
|
||||
self.vecstore.delete_collection(request.user, request.collection)
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(
|
||||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
|
|
|||
|
|
@ -115,8 +115,27 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
"Gave up waiting for index creation"
|
||||
)
|
||||
|
||||
async def start(self):
    """Start this processor together with its storage-management channel

    The parent class is started first; the request consumer and the
    response producer follow, in that order.
    """
    await super().start()

    # Consumer before producer
    for attr in ("storage_request_consumer", "storage_response_producer"):
        await getattr(self, attr).start()
|
||||
|
||||
async def store_graph_embeddings(self, message):
|
||||
|
||||
index_name = (
|
||||
"t-" + message.metadata.user + "-" + message.metadata.collection
|
||||
)
|
||||
|
||||
# Validate collection exists before accepting writes
|
||||
if not self.pinecone.has_index(index_name):
|
||||
error_msg = (
|
||||
f"Collection {message.metadata.collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
for entity in message.entities:
|
||||
|
||||
if entity.entity.value == "" or entity.entity.value is None:
|
||||
|
|
@ -124,28 +143,6 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
|
||||
for vec in entity.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
|
||||
index_name = (
|
||||
"t-" + message.metadata.user + "-" + message.metadata.collection
|
||||
)
|
||||
|
||||
if index_name != self.last_index_name:
|
||||
|
||||
if not self.pinecone.has_index(index_name):
|
||||
|
||||
try:
|
||||
|
||||
self.create_index(index_name, dim)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Pinecone index creation failed")
|
||||
raise e
|
||||
|
||||
logger.info(f"Index {index_name} created")
|
||||
|
||||
self.last_index_name = index_name
|
||||
|
||||
index = self.pinecone.Index(index_name)
|
||||
|
||||
# Generate unique ID for each vector
|
||||
|
|
@ -191,18 +188,21 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
help=f'Pinecone region, (default: {default_region}'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
|
||||
async def on_storage_management(self, message, consumer, flow):
|
||||
"""Handle storage management requests"""
|
||||
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
|
||||
request = message.value()
|
||||
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if message.operation == "delete-collection":
|
||||
await self.handle_delete_collection(message)
|
||||
if request.operation == "create-collection":
|
||||
await self.handle_create_collection(request)
|
||||
elif request.operation == "delete-collection":
|
||||
await self.handle_delete_collection(request)
|
||||
else:
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="invalid_operation",
|
||||
message=f"Unknown operation: {message.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -217,10 +217,36 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, message):
|
||||
async def handle_create_collection(self, request):
|
||||
"""Create a Pinecone index for graph embeddings"""
|
||||
try:
|
||||
index_name = f"t-{request.user}-{request.collection}"
|
||||
|
||||
if self.pinecone.has_index(index_name):
|
||||
logger.info(f"Pinecone index {index_name} already exists")
|
||||
else:
|
||||
# Create with default dimension - will need to be recreated if dimension doesn't match
|
||||
self.create_index(index_name, dim=384)
|
||||
logger.info(f"Created Pinecone index: {index_name}")
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(error=None)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create collection: {e}", exc_info=True)
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="creation_error",
|
||||
message=str(e)
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, request):
|
||||
"""Delete the collection for graph embeddings"""
|
||||
try:
|
||||
index_name = f"t-{message.user}-{message.collection}"
|
||||
index_name = f"t-{request.user}-{request.collection}"
|
||||
|
||||
if self.pinecone.has_index(index_name):
|
||||
self.pinecone.delete_index(index_name)
|
||||
|
|
@ -233,7 +259,7 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
|
|
|||
|
|
@ -36,8 +36,6 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
}
|
||||
)
|
||||
|
||||
self.last_collection = None
|
||||
|
||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
||||
|
||||
# Set up storage management if base class attributes are available
|
||||
|
|
@ -71,31 +69,30 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
metrics=storage_response_metrics,
|
||||
)
|
||||
|
||||
def get_collection(self, dim, user, collection):
|
||||
|
||||
def get_collection(self, user, collection):
|
||||
"""Get collection name and validate it exists"""
|
||||
cname = (
|
||||
"t_" + user + "_" + collection
|
||||
)
|
||||
|
||||
if cname != self.last_collection:
|
||||
|
||||
if not self.qdrant.collection_exists(cname):
|
||||
|
||||
try:
|
||||
self.qdrant.create_collection(
|
||||
collection_name=cname,
|
||||
vectors_config=VectorParams(
|
||||
size=dim, distance=Distance.COSINE
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Qdrant collection creation failed")
|
||||
raise e
|
||||
|
||||
self.last_collection = cname
|
||||
if not self.qdrant.collection_exists(cname):
|
||||
error_msg = (
|
||||
f"Collection {collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
return cname
|
||||
|
||||
async def start(self):
|
||||
"""Start the processor and its storage management consumer"""
|
||||
await super().start()
|
||||
if hasattr(self, 'storage_request_consumer'):
|
||||
await self.storage_request_consumer.start()
|
||||
if hasattr(self, 'storage_response_producer'):
|
||||
await self.storage_response_producer.start()
|
||||
|
||||
async def store_graph_embeddings(self, message):
|
||||
|
||||
for entity in message.entities:
|
||||
|
|
@ -104,10 +101,8 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
|
||||
for vec in entity.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
|
||||
collection = self.get_collection(
|
||||
dim, message.metadata.user, message.metadata.collection
|
||||
message.metadata.user, message.metadata.collection
|
||||
)
|
||||
|
||||
self.qdrant.upsert(
|
||||
|
|
@ -140,18 +135,21 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
help=f'Qdrant API key'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
|
||||
async def on_storage_management(self, message, consumer, flow):
|
||||
"""Handle storage management requests"""
|
||||
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
|
||||
request = message.value()
|
||||
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if message.operation == "delete-collection":
|
||||
await self.handle_delete_collection(message)
|
||||
if request.operation == "create-collection":
|
||||
await self.handle_create_collection(request)
|
||||
elif request.operation == "delete-collection":
|
||||
await self.handle_delete_collection(request)
|
||||
else:
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="invalid_operation",
|
||||
message=f"Unknown operation: {message.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -166,10 +164,43 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, message):
|
||||
async def handle_create_collection(self, request):
|
||||
"""Create a Qdrant collection for graph embeddings"""
|
||||
try:
|
||||
collection_name = f"t_{request.user}_{request.collection}"
|
||||
|
||||
if self.qdrant.collection_exists(collection_name):
|
||||
logger.info(f"Qdrant collection {collection_name} already exists")
|
||||
else:
|
||||
# Create collection with default dimension (will be recreated with correct dim on first write if needed)
|
||||
# Using a placeholder dimension - actual dimension determined by first embedding
|
||||
self.qdrant.create_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=VectorParams(
|
||||
size=384, # Default dimension, common for many models
|
||||
distance=Distance.COSINE
|
||||
)
|
||||
)
|
||||
logger.info(f"Created Qdrant collection: {collection_name}")
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(error=None)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create collection: {e}", exc_info=True)
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="creation_error",
|
||||
message=str(e)
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, request):
|
||||
"""Delete the collection for graph embeddings"""
|
||||
try:
|
||||
collection_name = f"t_{message.user}_{message.collection}"
|
||||
collection_name = f"t_{request.user}_{request.collection}"
|
||||
|
||||
if self.qdrant.collection_exists(collection_name):
|
||||
self.qdrant.delete_collection(collection_name)
|
||||
|
|
@ -182,7 +213,7 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
|
|
|||
|
|
@ -295,6 +295,8 @@ class Processor(FlowProcessor):
|
|||
|
||||
try:
|
||||
self.session.execute(create_table_cql)
|
||||
if keyspace not in self.known_tables:
|
||||
self.known_tables[keyspace] = set()
|
||||
self.known_tables[keyspace].add(table_key)
|
||||
logger.info(f"Ensured table exists: {safe_keyspace}.{safe_table}")
|
||||
|
||||
|
|
@ -340,18 +342,47 @@ class Processor(FlowProcessor):
|
|||
logger.warning(f"Failed to convert value {value} to type {field_type}: {e}")
|
||||
return str(value)
|
||||
|
||||
async def start(self):
|
||||
"""Start the processor and its storage management consumer"""
|
||||
await super().start()
|
||||
await self.storage_request_consumer.start()
|
||||
await self.storage_response_producer.start()
|
||||
|
||||
async def on_object(self, msg, consumer, flow):
|
||||
"""Process incoming ExtractedObject and store in Cassandra"""
|
||||
|
||||
|
||||
obj = msg.value()
|
||||
logger.info(f"Storing {len(obj.values)} objects for schema {obj.schema_name} from {obj.metadata.id}")
|
||||
|
||||
|
||||
# Validate collection/keyspace exists before accepting writes
|
||||
safe_keyspace = self.sanitize_name(obj.metadata.user)
|
||||
if safe_keyspace not in self.known_keyspaces:
|
||||
# Check if keyspace actually exists in Cassandra
|
||||
self.connect_cassandra()
|
||||
check_keyspace_cql = """
|
||||
SELECT keyspace_name FROM system_schema.keyspaces
|
||||
WHERE keyspace_name = %s
|
||||
"""
|
||||
result = self.session.execute(check_keyspace_cql, (safe_keyspace,))
|
||||
# Check if result is None (mock case) or has no rows
|
||||
if result is None or not result.one():
|
||||
error_msg = (
|
||||
f"Collection {obj.metadata.collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
# Cache it if it exists
|
||||
self.known_keyspaces.add(safe_keyspace)
|
||||
if safe_keyspace not in self.known_tables:
|
||||
self.known_tables[safe_keyspace] = set()
|
||||
|
||||
# Get schema definition
|
||||
schema = self.schemas.get(obj.schema_name)
|
||||
if not schema:
|
||||
logger.warning(f"No schema found for {obj.schema_name} - skipping")
|
||||
return
|
||||
|
||||
|
||||
# Ensure table exists
|
||||
keyspace = obj.metadata.user
|
||||
table_name = obj.schema_name
|
||||
|
|
@ -425,26 +456,36 @@ class Processor(FlowProcessor):
|
|||
|
||||
async def on_storage_management(self, msg, consumer, flow):
|
||||
"""Handle storage management requests for collection operations"""
|
||||
logger.info(f"Received storage management request: {msg.operation} for {msg.user}/{msg.collection}")
|
||||
request = msg.value()
|
||||
logger.info(f"Received storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if msg.operation == "delete-collection":
|
||||
await self.delete_collection(msg.user, msg.collection)
|
||||
if request.operation == "create-collection":
|
||||
await self.create_collection(request.user, request.collection)
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(
|
||||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {msg.user}/{msg.collection}")
|
||||
logger.info(f"Successfully created collection {request.user}/{request.collection}")
|
||||
elif request.operation == "delete-collection":
|
||||
await self.delete_collection(request.user, request.collection)
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(
|
||||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
else:
|
||||
logger.warning(f"Unknown storage management operation: {msg.operation}")
|
||||
logger.warning(f"Unknown storage management operation: {request.operation}")
|
||||
# Send error response
|
||||
from .... schema import Error
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="unknown_operation",
|
||||
message=f"Unknown operation: {msg.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -459,10 +500,28 @@ class Processor(FlowProcessor):
|
|||
message=str(e)
|
||||
)
|
||||
)
|
||||
await self.send("storage-response", response)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def create_collection(self, user: str, collection: str):
|
||||
"""Create/verify collection exists in Cassandra object store"""
|
||||
# Connect if not already connected
|
||||
self.connect_cassandra()
|
||||
|
||||
# Sanitize names for safety
|
||||
safe_keyspace = self.sanitize_name(user)
|
||||
|
||||
# Ensure keyspace exists
|
||||
if safe_keyspace not in self.known_keyspaces:
|
||||
self.ensure_keyspace(safe_keyspace)
|
||||
self.known_keyspaces.add(safe_keyspace)
|
||||
|
||||
# For Cassandra objects, collection is just a property in rows
|
||||
# No need to create separate tables per collection
|
||||
# Just mark that we've seen this collection
|
||||
logger.info(f"Collection {collection} ready for user {user} (using keyspace {safe_keyspace})")
|
||||
|
||||
async def delete_collection(self, user: str, collection: str):
|
||||
"""Delete all data for a specific collection"""
|
||||
"""Delete all data for a specific collection using schema information"""
|
||||
# Connect if not already connected
|
||||
self.connect_cassandra()
|
||||
|
||||
|
|
@ -482,40 +541,78 @@ class Processor(FlowProcessor):
|
|||
return
|
||||
self.known_keyspaces.add(safe_keyspace)
|
||||
|
||||
# Get all tables in the keyspace that might contain collection data
|
||||
get_tables_cql = """
|
||||
SELECT table_name FROM system_schema.tables
|
||||
WHERE keyspace_name = %s
|
||||
"""
|
||||
|
||||
tables = self.session.execute(get_tables_cql, (safe_keyspace,))
|
||||
# Iterate over schemas we manage to delete from relevant tables
|
||||
tables_deleted = 0
|
||||
|
||||
for row in tables:
|
||||
table_name = row.table_name
|
||||
for schema_name, schema in self.schemas.items():
|
||||
safe_table = self.sanitize_table(schema_name)
|
||||
|
||||
# Check if the table has a collection column
|
||||
check_column_cql = """
|
||||
SELECT column_name FROM system_schema.columns
|
||||
WHERE keyspace_name = %s AND table_name = %s AND column_name = 'collection'
|
||||
"""
|
||||
# Check if table exists
|
||||
table_key = f"{user}.{schema_name}"
|
||||
if table_key not in self.known_tables.get(user, set()):
|
||||
logger.debug(f"Table {safe_keyspace}.{safe_table} not in known tables, skipping")
|
||||
continue
|
||||
|
||||
result = self.session.execute(check_column_cql, (safe_keyspace, table_name))
|
||||
if result.one():
|
||||
# Table has collection column, delete data for this collection
|
||||
try:
|
||||
delete_cql = f"""
|
||||
DELETE FROM {safe_keyspace}.{table_name}
|
||||
try:
|
||||
# Get primary key fields from schema
|
||||
primary_key_fields = [field for field in schema.fields if field.primary]
|
||||
|
||||
if primary_key_fields:
|
||||
# Schema has primary keys: need to query for partition keys first
|
||||
# Build SELECT query for primary key fields
|
||||
pk_field_names = [self.sanitize_name(field.name) for field in primary_key_fields]
|
||||
select_cql = f"""
|
||||
SELECT {', '.join(pk_field_names)}
|
||||
FROM {safe_keyspace}.{safe_table}
|
||||
WHERE collection = %s
|
||||
ALLOW FILTERING
|
||||
"""
|
||||
self.session.execute(delete_cql, (collection,))
|
||||
tables_deleted += 1
|
||||
logger.info(f"Deleted collection {collection} from table {safe_keyspace}.{table_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete from table {safe_keyspace}.{table_name}: {e}")
|
||||
raise
|
||||
|
||||
logger.info(f"Deleted collection {collection} from {tables_deleted} tables in keyspace {safe_keyspace}")
|
||||
rows = self.session.execute(select_cql, (collection,))
|
||||
|
||||
# Delete each row using full partition key
|
||||
for row in rows:
|
||||
where_clauses = ["collection = %s"]
|
||||
values = [collection]
|
||||
|
||||
for field_name in pk_field_names:
|
||||
where_clauses.append(f"{field_name} = %s")
|
||||
values.append(getattr(row, field_name))
|
||||
|
||||
delete_cql = f"""
|
||||
DELETE FROM {safe_keyspace}.{safe_table}
|
||||
WHERE {' AND '.join(where_clauses)}
|
||||
"""
|
||||
|
||||
self.session.execute(delete_cql, tuple(values))
|
||||
else:
|
||||
# No primary keys, uses synthetic_id
|
||||
# Need to query for synthetic_ids first
|
||||
select_cql = f"""
|
||||
SELECT synthetic_id
|
||||
FROM {safe_keyspace}.{safe_table}
|
||||
WHERE collection = %s
|
||||
ALLOW FILTERING
|
||||
"""
|
||||
|
||||
rows = self.session.execute(select_cql, (collection,))
|
||||
|
||||
# Delete each row using collection and synthetic_id
|
||||
for row in rows:
|
||||
delete_cql = f"""
|
||||
DELETE FROM {safe_keyspace}.{safe_table}
|
||||
WHERE collection = %s AND synthetic_id = %s
|
||||
"""
|
||||
self.session.execute(delete_cql, (collection, row.synthetic_id))
|
||||
|
||||
tables_deleted += 1
|
||||
logger.info(f"Deleted collection {collection} from table {safe_keyspace}.{safe_table}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete from table {safe_keyspace}.{safe_table}: {e}")
|
||||
raise
|
||||
|
||||
logger.info(f"Deleted collection {collection} from {tables_deleted} schema-based tables in keyspace {safe_keyspace}")
|
||||
|
||||
def close(self):
|
||||
"""Clean up Cassandra connections"""
|
||||
|
|
|
|||
|
|
@ -109,6 +109,15 @@ class Processor(TriplesStoreService):
|
|||
|
||||
self.table = user
|
||||
|
||||
# Validate collection exists before accepting writes
|
||||
if not self.tg.collection_exists(message.metadata.collection):
|
||||
error_msg = (
|
||||
f"Collection {message.metadata.collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
for t in message.triples:
|
||||
self.tg.insert(
|
||||
message.metadata.collection,
|
||||
|
|
@ -117,18 +126,27 @@ class Processor(TriplesStoreService):
|
|||
t.o.value
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
|
||||
async def start(self):
|
||||
"""Start the processor and its storage management consumer"""
|
||||
await super().start()
|
||||
await self.storage_request_consumer.start()
|
||||
await self.storage_response_producer.start()
|
||||
|
||||
async def on_storage_management(self, message, consumer, flow):
|
||||
"""Handle storage management requests"""
|
||||
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
|
||||
request = message.value()
|
||||
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if message.operation == "delete-collection":
|
||||
await self.handle_delete_collection(message)
|
||||
if request.operation == "create-collection":
|
||||
await self.handle_create_collection(request)
|
||||
elif request.operation == "delete-collection":
|
||||
await self.handle_delete_collection(request)
|
||||
else:
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="invalid_operation",
|
||||
message=f"Unknown operation: {message.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -143,42 +161,85 @@ class Processor(TriplesStoreService):
|
|||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, message):
|
||||
"""Delete all data for a specific collection from the unified triples table"""
|
||||
async def handle_create_collection(self, request):
|
||||
"""Create a collection in Cassandra triple store"""
|
||||
try:
|
||||
# Create or reuse connection for this user's keyspace
|
||||
if self.table is None or self.table != message.user:
|
||||
if self.table is None or self.table != request.user:
|
||||
self.tg = None
|
||||
|
||||
try:
|
||||
if self.cassandra_username and self.cassandra_password:
|
||||
self.tg = KnowledgeGraph(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=message.user,
|
||||
keyspace=request.user,
|
||||
username=self.cassandra_username,
|
||||
password=self.cassandra_password
|
||||
)
|
||||
else:
|
||||
self.tg = KnowledgeGraph(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=message.user,
|
||||
keyspace=request.user,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Cassandra for user {message.user}: {e}")
|
||||
logger.error(f"Failed to connect to Cassandra for user {request.user}: {e}")
|
||||
raise
|
||||
|
||||
self.table = message.user
|
||||
self.table = request.user
|
||||
|
||||
# Delete all triples for this collection from the unified table
|
||||
# In the unified table schema, collection is the partition key
|
||||
delete_cql = """
|
||||
DELETE FROM triples
|
||||
WHERE collection = ?
|
||||
"""
|
||||
# Create collection using the built-in method
|
||||
logger.info(f"Creating collection {request.collection} for user {request.user}")
|
||||
|
||||
if self.tg.collection_exists(request.collection):
|
||||
logger.info(f"Collection {request.collection} already exists")
|
||||
else:
|
||||
self.tg.create_collection(request.collection)
|
||||
logger.info(f"Created collection {request.collection}")
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(error=None)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create collection: {e}", exc_info=True)
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="creation_error",
|
||||
message=str(e)
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, request):
|
||||
"""Delete all data for a specific collection from the unified triples table"""
|
||||
try:
|
||||
# Create or reuse connection for this user's keyspace
|
||||
if self.table is None or self.table != request.user:
|
||||
self.tg = None
|
||||
|
||||
try:
|
||||
if self.cassandra_username and self.cassandra_password:
|
||||
self.tg = KnowledgeGraph(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=request.user,
|
||||
username=self.cassandra_username,
|
||||
password=self.cassandra_password
|
||||
)
|
||||
else:
|
||||
self.tg = KnowledgeGraph(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=request.user,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Cassandra for user {request.user}: {e}")
|
||||
raise
|
||||
|
||||
self.table = request.user
|
||||
|
||||
# Delete all triples for this collection using the built-in method
|
||||
try:
|
||||
self.tg.session.execute(delete_cql, (message.collection,))
|
||||
logger.info(f"Deleted all triples for collection {message.collection} from keyspace {message.user}")
|
||||
self.tg.delete_collection(request.collection)
|
||||
logger.info(f"Deleted all triples for collection {request.collection} from keyspace {request.user}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection data: {e}")
|
||||
raise
|
||||
|
|
@ -188,7 +249,7 @@ class Processor(TriplesStoreService):
|
|||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
|
|
|||
|
|
@ -152,11 +152,43 @@ class Processor(TriplesStoreService):
|
|||
time=res.run_time_ms
|
||||
))
|
||||
|
||||
def collection_exists(self, user, collection):
|
||||
"""Check if collection metadata node exists"""
|
||||
result = self.io.query(
|
||||
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
|
||||
"RETURN c LIMIT 1",
|
||||
params={"user": user, "collection": collection}
|
||||
)
|
||||
return result.result_set is not None and len(result.result_set) > 0
|
||||
|
||||
def create_collection(self, user, collection):
|
||||
"""Create collection metadata node"""
|
||||
import datetime
|
||||
self.io.query(
|
||||
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) "
|
||||
"SET c.created_at = $created_at",
|
||||
params={
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"created_at": datetime.datetime.now().isoformat()
|
||||
}
|
||||
)
|
||||
logger.info(f"Created collection metadata node for {user}/{collection}")
|
||||
|
||||
async def store_triples(self, message):
|
||||
# Extract user and collection from metadata
|
||||
user = message.metadata.user if message.metadata.user else "default"
|
||||
collection = message.metadata.collection if message.metadata.collection else "default"
|
||||
|
||||
# Validate collection exists before accepting writes
|
||||
if not self.collection_exists(user, collection):
|
||||
error_msg = (
|
||||
f"Collection {collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
for t in message.triples:
|
||||
|
||||
self.create_node(t.s.value, user, collection)
|
||||
|
|
@ -185,18 +217,27 @@ class Processor(TriplesStoreService):
|
|||
help=f'FalkorDB database (default: {default_database})'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
|
||||
async def start(self):
|
||||
"""Start the processor and its storage management consumer"""
|
||||
await super().start()
|
||||
await self.storage_request_consumer.start()
|
||||
await self.storage_response_producer.start()
|
||||
|
||||
async def on_storage_management(self, message, consumer, flow):
|
||||
"""Handle storage management requests"""
|
||||
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
|
||||
request = message.value()
|
||||
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if message.operation == "delete-collection":
|
||||
await self.handle_delete_collection(message)
|
||||
if request.operation == "create-collection":
|
||||
await self.handle_create_collection(request)
|
||||
elif request.operation == "delete-collection":
|
||||
await self.handle_delete_collection(request)
|
||||
else:
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="invalid_operation",
|
||||
message=f"Unknown operation: {message.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -211,28 +252,57 @@ class Processor(TriplesStoreService):
|
|||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, message):
|
||||
async def handle_create_collection(self, request):
|
||||
"""Create collection metadata in FalkorDB"""
|
||||
try:
|
||||
if self.collection_exists(request.user, request.collection):
|
||||
logger.info(f"Collection {request.user}/{request.collection} already exists")
|
||||
else:
|
||||
self.create_collection(request.user, request.collection)
|
||||
logger.info(f"Created collection {request.user}/{request.collection}")
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(error=None)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create collection: {e}", exc_info=True)
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="creation_error",
|
||||
message=str(e)
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, request):
|
||||
"""Delete the collection for FalkorDB triples"""
|
||||
try:
|
||||
# Delete all nodes and literals for this user/collection
|
||||
node_result = self.io.query(
|
||||
"MATCH (n:Node {user: $user, collection: $collection}) DETACH DELETE n",
|
||||
params={"user": message.user, "collection": message.collection}
|
||||
params={"user": request.user, "collection": request.collection}
|
||||
)
|
||||
|
||||
literal_result = self.io.query(
|
||||
"MATCH (n:Literal {user: $user, collection: $collection}) DETACH DELETE n",
|
||||
params={"user": message.user, "collection": message.collection}
|
||||
params={"user": request.user, "collection": request.collection}
|
||||
)
|
||||
|
||||
logger.info(f"Deleted {node_result.nodes_deleted} nodes and {literal_result.nodes_deleted} literals for collection {message.user}/{message.collection}")
|
||||
# Delete collection metadata node
|
||||
metadata_result = self.io.query(
|
||||
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) DELETE c",
|
||||
params={"user": request.user, "collection": request.collection}
|
||||
)
|
||||
|
||||
logger.info(f"Deleted {node_result.nodes_deleted} nodes, {literal_result.nodes_deleted} literals, and {metadata_result.nodes_deleted} metadata nodes for collection {request.user}/{request.collection}")
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(
|
||||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
|
|
|||
|
|
@ -267,12 +267,43 @@ class Processor(TriplesStoreService):
|
|||
src=t.s.value, dest=t.o.value, uri=t.p.value, user=user, collection=collection,
|
||||
)
|
||||
|
||||
def collection_exists(self, user, collection):
    """Report whether a CollectionMetadata node exists for user/collection.

    Opens a session on ``self.db``, looks up at most one matching
    ``CollectionMetadata`` node and returns ``True`` if a record came back,
    ``False`` otherwise.
    """
    lookup = (
        "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
        "RETURN c LIMIT 1"
    )

    with self.io.session(database=self.db) as session:
        # Materialise the result while the session is still open.
        records = list(session.run(lookup, user=user, collection=collection))

    return bool(records)
|
||||
|
||||
def create_collection(self, user, collection):
    """Create the CollectionMetadata node for user/collection.

    The node is MERGEd on ``user`` and ``collection``, so calling this for
    an existing pair does not create a duplicate node; ``created_at`` is
    (re)written on every call.
    """
    import datetime

    # Timezone-aware UTC timestamp: a naive local-time string is ambiguous
    # once it is read on a host with a different timezone setting.
    created_at = datetime.datetime.now(datetime.timezone.utc).isoformat()

    with self.io.session(database=self.db) as session:
        session.run(
            "MERGE (c:CollectionMetadata {user: $user, collection: $collection}) "
            "SET c.created_at = $created_at",
            user=user, collection=collection,
            created_at=created_at,
        )
        logger.info(f"Created collection metadata node for {user}/{collection}")
|
||||
|
||||
async def store_triples(self, message):
|
||||
|
||||
# Extract user and collection from metadata
|
||||
user = message.metadata.user if message.metadata.user else "default"
|
||||
collection = message.metadata.collection if message.metadata.collection else "default"
|
||||
|
||||
# Validate collection exists before accepting writes
|
||||
if not self.collection_exists(user, collection):
|
||||
error_msg = (
|
||||
f"Collection {collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
for t in message.triples:
|
||||
|
||||
self.create_node(t.s.value, user, collection)
|
||||
|
|
@ -317,18 +348,27 @@ class Processor(TriplesStoreService):
|
|||
help=f'Memgraph database (default: {default_database})'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
|
||||
async def start(self):
    """Start the base processor, then the storage-management endpoints.

    The request consumer is started before the response producer, after
    the parent class's own ``start()`` has completed.
    """
    await super().start()

    # Storage-management channels are started explicitly, in this order.
    request_consumer = self.storage_request_consumer
    await request_consumer.start()

    response_producer = self.storage_response_producer
    await response_producer.start()
|
||||
|
||||
async def on_storage_management(self, message, consumer, flow):
|
||||
"""Handle storage management requests"""
|
||||
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
|
||||
request = message.value()
|
||||
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if message.operation == "delete-collection":
|
||||
await self.handle_delete_collection(message)
|
||||
if request.operation == "create-collection":
|
||||
await self.handle_create_collection(request)
|
||||
elif request.operation == "delete-collection":
|
||||
await self.handle_delete_collection(request)
|
||||
else:
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="invalid_operation",
|
||||
message=f"Unknown operation: {message.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -343,7 +383,30 @@ class Processor(TriplesStoreService):
|
|||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, message):
|
||||
async def handle_create_collection(self, request):
    """Ensure the requested collection has metadata in Memgraph and reply.

    If no metadata node exists for ``request.user`` / ``request.collection``
    one is created; otherwise the request is treated as already satisfied.
    In both cases a response with ``error=None`` is sent. Any exception is
    logged with its traceback and reported back as a ``creation_error``
    response rather than being raised to the caller.
    """
    try:
        already_present = self.collection_exists(request.user, request.collection)

        if not already_present:
            self.create_collection(request.user, request.collection)
            logger.info(f"Created collection {request.user}/{request.collection}")
        else:
            logger.info(f"Collection {request.user}/{request.collection} already exists")

        # An empty error field signals success
        ok = StorageManagementResponse(error=None)
        await self.storage_response_producer.send(ok)

    except Exception as e:
        logger.error(f"Failed to create collection: {e}", exc_info=True)
        failure = Error(
            type="creation_error",
            message=str(e)
        )
        await self.storage_response_producer.send(
            StorageManagementResponse(error=failure)
        )
|
||||
|
||||
async def handle_delete_collection(self, request):
|
||||
"""Delete all data for a specific collection"""
|
||||
try:
|
||||
with self.io.session(database=self.db) as session:
|
||||
|
|
@ -351,7 +414,7 @@ class Processor(TriplesStoreService):
|
|||
node_result = session.run(
|
||||
"MATCH (n:Node {user: $user, collection: $collection}) "
|
||||
"DETACH DELETE n",
|
||||
user=message.user, collection=message.collection
|
||||
user=request.user, collection=request.collection
|
||||
)
|
||||
nodes_deleted = node_result.consume().counters.nodes_deleted
|
||||
|
||||
|
|
@ -359,20 +422,28 @@ class Processor(TriplesStoreService):
|
|||
literal_result = session.run(
|
||||
"MATCH (n:Literal {user: $user, collection: $collection}) "
|
||||
"DETACH DELETE n",
|
||||
user=message.user, collection=message.collection
|
||||
user=request.user, collection=request.collection
|
||||
)
|
||||
literals_deleted = literal_result.consume().counters.nodes_deleted
|
||||
|
||||
# Delete collection metadata node
|
||||
metadata_result = session.run(
|
||||
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
|
||||
"DELETE c",
|
||||
user=request.user, collection=request.collection
|
||||
)
|
||||
metadata_deleted = metadata_result.consume().counters.nodes_deleted
|
||||
|
||||
# Note: Relationships are automatically deleted with DETACH DELETE
|
||||
|
||||
logger.info(f"Deleted {nodes_deleted} nodes and {literals_deleted} literals for {message.user}/{message.collection}")
|
||||
logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {request.user}/{request.collection}")
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(
|
||||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
|
|
|||
|
|
@ -228,6 +228,15 @@ class Processor(TriplesStoreService):
|
|||
user = message.metadata.user if message.metadata.user else "default"
|
||||
collection = message.metadata.collection if message.metadata.collection else "default"
|
||||
|
||||
# Validate collection exists before accepting writes
|
||||
if not self.collection_exists(user, collection):
|
||||
error_msg = (
|
||||
f"Collection {collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
for t in message.triples:
|
||||
|
||||
self.create_node(t.s.value, user, collection)
|
||||
|
|
@ -268,18 +277,27 @@ class Processor(TriplesStoreService):
|
|||
help=f'Neo4j database (default: {default_database})'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
|
||||
async def start(self):
    """Start the base processor, then the storage-management endpoints.

    The request consumer is started before the response producer, after
    the parent class's own ``start()`` has completed.
    """
    await super().start()

    # Storage-management channels are started explicitly, in this order.
    request_consumer = self.storage_request_consumer
    await request_consumer.start()

    response_producer = self.storage_response_producer
    await response_producer.start()
|
||||
|
||||
async def on_storage_management(self, message, consumer, flow):
|
||||
"""Handle storage management requests"""
|
||||
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
|
||||
request = message.value()
|
||||
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if message.operation == "delete-collection":
|
||||
await self.handle_delete_collection(message)
|
||||
if request.operation == "create-collection":
|
||||
await self.handle_create_collection(request)
|
||||
elif request.operation == "delete-collection":
|
||||
await self.handle_delete_collection(request)
|
||||
else:
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="invalid_operation",
|
||||
message=f"Unknown operation: {message.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -294,7 +312,52 @@ class Processor(TriplesStoreService):
|
|||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, message):
|
||||
def collection_exists(self, user, collection):
    """Report whether a CollectionMetadata node exists for user/collection.

    Opens a session on ``self.db``, looks up at most one matching
    ``CollectionMetadata`` node and returns ``True`` if a record came back,
    ``False`` otherwise.
    """
    lookup = (
        "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
        "RETURN c LIMIT 1"
    )

    with self.io.session(database=self.db) as session:
        # Materialise the result while the session is still open.
        records = list(session.run(lookup, user=user, collection=collection))

    return bool(records)
|
||||
|
||||
def create_collection(self, user, collection):
    """Create the CollectionMetadata node for user/collection.

    The node is MERGEd on ``user`` and ``collection``, so calling this for
    an existing pair does not create a duplicate node; ``created_at`` is
    (re)written on every call.
    """
    import datetime

    # Timezone-aware UTC timestamp: a naive local-time string is ambiguous
    # once it is read on a host with a different timezone setting.
    created_at = datetime.datetime.now(datetime.timezone.utc).isoformat()

    with self.io.session(database=self.db) as session:
        session.run(
            "MERGE (c:CollectionMetadata {user: $user, collection: $collection}) "
            "SET c.created_at = $created_at",
            user=user, collection=collection,
            created_at=created_at,
        )
        logger.info(f"Created collection metadata node for {user}/{collection}")
|
||||
|
||||
async def handle_create_collection(self, request):
    """Ensure the requested collection has metadata in Neo4j and reply.

    If no metadata node exists for ``request.user`` / ``request.collection``
    one is created; otherwise the request is treated as already satisfied.
    In both cases a response with ``error=None`` is sent. Any exception is
    logged with its traceback and reported back as a ``creation_error``
    response rather than being raised to the caller.
    """
    try:
        already_present = self.collection_exists(request.user, request.collection)

        if not already_present:
            self.create_collection(request.user, request.collection)
            logger.info(f"Created collection {request.user}/{request.collection}")
        else:
            logger.info(f"Collection {request.user}/{request.collection} already exists")

        # An empty error field signals success
        ok = StorageManagementResponse(error=None)
        await self.storage_response_producer.send(ok)

    except Exception as e:
        logger.error(f"Failed to create collection: {e}", exc_info=True)
        failure = Error(
            type="creation_error",
            message=str(e)
        )
        await self.storage_response_producer.send(
            StorageManagementResponse(error=failure)
        )
|
||||
|
||||
async def handle_delete_collection(self, request):
|
||||
"""Delete all data for a specific collection"""
|
||||
try:
|
||||
with self.io.session(database=self.db) as session:
|
||||
|
|
@ -302,7 +365,7 @@ class Processor(TriplesStoreService):
|
|||
node_result = session.run(
|
||||
"MATCH (n:Node {user: $user, collection: $collection}) "
|
||||
"DETACH DELETE n",
|
||||
user=message.user, collection=message.collection
|
||||
user=request.user, collection=request.collection
|
||||
)
|
||||
nodes_deleted = node_result.consume().counters.nodes_deleted
|
||||
|
||||
|
|
@ -310,20 +373,28 @@ class Processor(TriplesStoreService):
|
|||
literal_result = session.run(
|
||||
"MATCH (n:Literal {user: $user, collection: $collection}) "
|
||||
"DETACH DELETE n",
|
||||
user=message.user, collection=message.collection
|
||||
user=request.user, collection=request.collection
|
||||
)
|
||||
literals_deleted = literal_result.consume().counters.nodes_deleted
|
||||
|
||||
# Note: Relationships are automatically deleted with DETACH DELETE
|
||||
|
||||
logger.info(f"Deleted {nodes_deleted} nodes and {literals_deleted} literals for {message.user}/{message.collection}")
|
||||
# Delete collection metadata node
|
||||
metadata_result = session.run(
|
||||
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
|
||||
"DELETE c",
|
||||
user=request.user, collection=request.collection
|
||||
)
|
||||
metadata_deleted = metadata_result.consume().counters.nodes_deleted
|
||||
|
||||
logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {request.user}/{request.collection}")
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(
|
||||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
|
|
|||
|
|
@ -145,7 +145,7 @@ class ConfigTableStore:
|
|||
""")
|
||||
|
||||
self.get_all_stmt = self.cassandra.prepare("""
|
||||
SELECT class, key, value FROM config;
|
||||
SELECT class AS cls, key, value FROM config;
|
||||
""")
|
||||
|
||||
self.get_values_stmt = self.cassandra.prepare("""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue