Merge 2.0 to master (#651)

cybermaggedon 2026-02-28 11:03:14 +00:00 committed by GitHub
parent 3666ece2c5
commit b9d7bf9a8b
212 changed files with 13940 additions and 6180 deletions


@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=1.8,<1.9",
"trustgraph-base>=2.0,<2.1",
"aiohttp",
"anthropic",
"scylla-driver",
@ -19,7 +19,6 @@ dependencies = [
"faiss-cpu",
"falkordb",
"fastembed",
"google-genai",
"ibis",
"jsonschema",
"langchain",
@ -61,27 +60,27 @@ api-gateway = "trustgraph.gateway:run"
chunker-recursive = "trustgraph.chunking.recursive:run"
chunker-token = "trustgraph.chunking.token:run"
config-svc = "trustgraph.config.service:run"
de-query-milvus = "trustgraph.query.doc_embeddings.milvus:run"
de-query-pinecone = "trustgraph.query.doc_embeddings.pinecone:run"
de-query-qdrant = "trustgraph.query.doc_embeddings.qdrant:run"
de-write-milvus = "trustgraph.storage.doc_embeddings.milvus:run"
de-write-pinecone = "trustgraph.storage.doc_embeddings.pinecone:run"
de-write-qdrant = "trustgraph.storage.doc_embeddings.qdrant:run"
doc-embeddings-query-milvus = "trustgraph.query.doc_embeddings.milvus:run"
doc-embeddings-query-pinecone = "trustgraph.query.doc_embeddings.pinecone:run"
doc-embeddings-query-qdrant = "trustgraph.query.doc_embeddings.qdrant:run"
doc-embeddings-write-milvus = "trustgraph.storage.doc_embeddings.milvus:run"
doc-embeddings-write-pinecone = "trustgraph.storage.doc_embeddings.pinecone:run"
doc-embeddings-write-qdrant = "trustgraph.storage.doc_embeddings.qdrant:run"
document-embeddings = "trustgraph.embeddings.document_embeddings:run"
document-rag = "trustgraph.retrieval.document_rag:run"
embeddings-fastembed = "trustgraph.embeddings.fastembed:run"
embeddings-ollama = "trustgraph.embeddings.ollama:run"
ge-query-milvus = "trustgraph.query.graph_embeddings.milvus:run"
ge-query-pinecone = "trustgraph.query.graph_embeddings.pinecone:run"
ge-query-qdrant = "trustgraph.query.graph_embeddings.qdrant:run"
ge-write-milvus = "trustgraph.storage.graph_embeddings.milvus:run"
ge-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:run"
ge-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run"
graph-embeddings-query-milvus = "trustgraph.query.graph_embeddings.milvus:run"
graph-embeddings-query-pinecone = "trustgraph.query.graph_embeddings.pinecone:run"
graph-embeddings-query-qdrant = "trustgraph.query.graph_embeddings.qdrant:run"
graph-embeddings-write-milvus = "trustgraph.storage.graph_embeddings.milvus:run"
graph-embeddings-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:run"
graph-embeddings-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run"
graph-embeddings = "trustgraph.embeddings.graph_embeddings:run"
graph-rag = "trustgraph.retrieval.graph_rag:run"
kg-extract-agent = "trustgraph.extract.kg.agent:run"
kg-extract-definitions = "trustgraph.extract.kg.definitions:run"
kg-extract-objects = "trustgraph.extract.kg.objects:run"
kg-extract-rows = "trustgraph.extract.kg.rows:run"
kg-extract-relationships = "trustgraph.extract.kg.relationships:run"
kg-extract-topics = "trustgraph.extract.kg.topics:run"
kg-extract-ontology = "trustgraph.extract.kg.ontology:run"
@ -91,8 +90,11 @@ librarian = "trustgraph.librarian:run"
mcp-tool = "trustgraph.agent.mcp_tool:run"
metering = "trustgraph.metering:run"
nlp-query = "trustgraph.retrieval.nlp_query:run"
objects-write-cassandra = "trustgraph.storage.objects.cassandra:run"
objects-query-cassandra = "trustgraph.query.objects.cassandra:run"
rows-write-cassandra = "trustgraph.storage.rows.cassandra:run"
rows-query-cassandra = "trustgraph.query.rows.cassandra:run"
row-embeddings = "trustgraph.embeddings.row_embeddings:run"
row-embeddings-write-qdrant = "trustgraph.storage.row_embeddings.qdrant:run"
row-embeddings-query-qdrant = "trustgraph.query.row_embeddings.qdrant:run"
pdf-decoder = "trustgraph.decoding.pdf:run"
pdf-ocr-mistral = "trustgraph.decoding.mistral_ocr:run"
prompt-template = "trustgraph.prompt.template:run"
@ -104,7 +106,6 @@ text-completion-azure = "trustgraph.model.text_completion.azure:run"
text-completion-azure-openai = "trustgraph.model.text_completion.azure_openai:run"
text-completion-claude = "trustgraph.model.text_completion.claude:run"
text-completion-cohere = "trustgraph.model.text_completion.cohere:run"
text-completion-googleaistudio = "trustgraph.model.text_completion.googleaistudio:run"
text-completion-llamafile = "trustgraph.model.text_completion.llamafile:run"
text-completion-lmstudio = "trustgraph.model.text_completion.lmstudio:run"
text-completion-mistral = "trustgraph.model.text_completion.mistral:run"


@ -13,10 +13,11 @@ logger = logging.getLogger(__name__)
from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl
from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl, RowEmbeddingsQueryImpl
from . agent_manager import AgentManager
from ..tool_filter import validate_tool_config, filter_tools_by_group_and_state, get_next_state
@ -87,6 +88,20 @@ class Processor(AgentService):
)
)
self.register_specification(
EmbeddingsClientSpec(
request_name = "embeddings-request",
response_name = "embeddings-response",
)
)
self.register_specification(
RowEmbeddingsQueryClientSpec(
request_name = "row-embeddings-query-request",
response_name = "row-embeddings-query-response",
)
)
async def on_tools_config(self, config, version):
logger.info(f"Loading configuration version {version}")
@ -147,11 +162,21 @@ class Processor(AgentService):
)
elif impl_id == "structured-query":
impl = functools.partial(
StructuredQueryImpl,
collection=data.get("collection"),
user=None # User will be provided dynamically via context
)
arguments = StructuredQueryImpl.get_arguments()
elif impl_id == "row-embeddings-query":
impl = functools.partial(
RowEmbeddingsQueryImpl,
schema_name=data.get("schema-name"),
collection=data.get("collection"),
user=None, # User will be provided dynamically via context
index_name=data.get("index-name"), # Optional filter
limit=int(data.get("limit", 10)) # Max results
)
arguments = RowEmbeddingsQueryImpl.get_arguments()
else:
raise RuntimeError(
f"Tool type {impl_id} not known"
@ -327,11 +352,11 @@ class Processor(AgentService):
def __init__(self, flow, user):
self._flow = flow
self._user = user
def __call__(self, service_name):
client = self._flow(service_name)
# For structured query clients, store user context
if service_name == "structured-query-request":
# For query clients that need user context, store it
if service_name in ("structured-query-request", "row-embeddings-query-request"):
client._current_user = self._user
return client
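# Illustrative tool configuration entry for the new "row-embeddings-query"
# tool type (names hypothetical); keys mirror the data.get() calls in the
# loader above:
#   {"type": "row-embeddings-query",
#    "schema-name": "products",
#    "collection": "default",
#    "index-name": "name",
#    "limit": 10}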


@ -128,6 +128,62 @@ class StructuredQueryImpl:
return str(result)
# This tool implementation knows how to query row embeddings for semantic search
class RowEmbeddingsQueryImpl:
def __init__(self, context, schema_name, collection=None, user=None, index_name=None, limit=10):
self.context = context
self.schema_name = schema_name
self.collection = collection
self.user = user
self.index_name = index_name # Optional: filter to specific index
self.limit = limit # Max results to return
@staticmethod
def get_arguments():
return [
Argument(
name="query",
type="string",
description="Text to search for semantically similar values in the structured data index"
)
]
async def invoke(self, **arguments):
# First get embeddings for the query text
embeddings_client = self.context("embeddings-request")
logger.debug("Getting embeddings for row query...")
query_text = arguments.get("query")
vectors = await embeddings_client.embed(query_text)
# Now query row embeddings
client = self.context("row-embeddings-query-request")
logger.debug("Row embeddings query...")
# Get user from client context if available
user = getattr(client, '_current_user', self.user or "trustgraph")
matches = await client.row_embeddings_query(
vectors=vectors,
schema_name=self.schema_name,
user=user,
collection=self.collection or "default",
index_name=self.index_name,
limit=self.limit
)
# Format results for agent consumption
if not matches:
return "No matching records found"
results = []
for match in matches:
result = f"- {match['index_name']}: {', '.join(match['index_value'])} (score: {match['score']:.3f})"
results.append(result)
return "Matching records:\n" + "\n".join(results)
# This tool implementation knows how to execute prompt templates
class PromptImpl:
def __init__(self, context, template_id, arguments=None):


@ -124,12 +124,13 @@ class Processor(AsyncProcessor):
logger.info(f"Configuration version: {version}")
if "flows" in config:
if "flow" in config:
self.flows = {
k: json.loads(v)
for k, v in config["flows"].items()
for k, v in config["flow"].items()
}
else:
self.flows = {}
logger.debug(f"Flows: {self.flows}")

File diff suppressed because it is too large


@ -16,12 +16,14 @@ import logging
logger = logging.getLogger(__name__)
default_ident = "graph-embeddings"
default_batch_size = 5
class Processor(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
self.batch_size = params.get("batch_size", default_batch_size)
super(Processor, self).__init__(
**params | {
@ -73,12 +75,14 @@ class Processor(FlowProcessor):
)
)
r = GraphEmbeddings(
metadata=v.metadata,
entities=entities,
)
await flow("output").send(r)
# Send in batches to avoid oversized messages
for i in range(0, len(entities), self.batch_size):
batch = entities[i:i + self.batch_size]
r = GraphEmbeddings(
metadata=v.metadata,
entities=batch,
)
await flow("output").send(r)
except Exception as e:
logger.error("Exception occurred", exc_info=True)
@ -91,6 +95,13 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
parser.add_argument(
'--batch-size',
type=int,
default=default_batch_size,
help=f'Maximum entities per output message (default: {default_batch_size})'
)
FlowProcessor.add_args(parser)
def run():


@ -0,0 +1,3 @@
from . embeddings import *


@ -0,0 +1,6 @@
from . embeddings import run
if __name__ == '__main__':
run()


@ -0,0 +1,263 @@
"""
Row embeddings processor. Calls the embeddings service to compute embeddings
for indexed field values in extracted row data.
Input is ExtractedObject (structured row data with schema).
Output is RowEmbeddings (row data with embeddings for indexed fields).
This follows the two-stage pattern used by graph-embeddings and document-embeddings:
Stage 1 (this processor): Compute embeddings
Stage 2 (row-embeddings-write-*): Store embeddings
"""
import json
import logging
from typing import Dict, List, Set
from ... schema import ExtractedObject, RowEmbeddings, RowIndexEmbedding
from ... schema import RowSchema, Field
from ... base import FlowProcessor, EmbeddingsClientSpec, ConsumerSpec
from ... base import ProducerSpec, CollectionConfigHandler
logger = logging.getLogger(__name__)
default_ident = "row-embeddings"
default_batch_size = 10
class Processor(CollectionConfigHandler, FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
self.batch_size = params.get("batch_size", default_batch_size)
# Config key for schemas
self.config_key = params.get("config_type", "schema")
super(Processor, self).__init__(
**params | {
"id": id,
"config_type": self.config_key,
}
)
self.register_specification(
ConsumerSpec(
name="input",
schema=ExtractedObject,
handler=self.on_message,
)
)
self.register_specification(
EmbeddingsClientSpec(
request_name="embeddings-request",
response_name="embeddings-response",
)
)
self.register_specification(
ProducerSpec(
name="output",
schema=RowEmbeddings
)
)
# Register config handlers
self.register_config_handler(self.on_schema_config)
self.register_config_handler(self.on_collection_config)
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
async def on_schema_config(self, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
# Clear existing schemas
self.schemas = {}
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
field = Field(
name=field_def["name"],
type=field_def["type"],
size=field_def.get("size", 0),
primary=field_def.get("primary_key", False),
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
def get_index_names(self, schema: RowSchema) -> List[str]:
"""Get all index names for a schema."""
index_names = []
for field in schema.fields:
if field.primary or field.indexed:
index_names.append(field.name)
return index_names
def build_index_value(self, value_map: Dict[str, str], index_name: str) -> List[str]:
"""Build the index_value list for a given index."""
field_names = [f.strip() for f in index_name.split(',')]
values = []
for field_name in field_names:
value = value_map.get(field_name)
values.append(str(value) if value is not None else "")
return values
def build_text_for_embedding(self, index_value: List[str]) -> str:
"""Build text representation for embedding from index values."""
# Space-join the values for composite indexes
return " ".join(index_value)
async def on_message(self, msg, consumer, flow):
"""Process incoming ExtractedObject and compute embeddings"""
obj = msg.value()
logger.info(
f"Computing embeddings for {len(obj.values)} rows, "
f"schema {obj.schema_name}, doc {obj.metadata.id}"
)
# Validate collection exists before processing
if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
logger.warning(
f"Collection {obj.metadata.collection} for user {obj.metadata.user} "
f"does not exist in config. Dropping message."
)
return
# Get schema definition
schema = self.schemas.get(obj.schema_name)
if not schema:
logger.warning(f"No schema found for {obj.schema_name} - skipping")
return
# Get all index names for this schema
index_names = self.get_index_names(schema)
if not index_names:
logger.warning(f"Schema {obj.schema_name} has no indexed fields - skipping")
return
# Track unique texts to avoid duplicate embeddings
# text -> (index_name, index_value)
texts_to_embed: Dict[str, tuple] = {}
# Collect all texts that need embeddings
for value_map in obj.values:
for index_name in index_names:
index_value = self.build_index_value(value_map, index_name)
# Skip empty values
if not index_value or all(v == "" for v in index_value):
continue
text = self.build_text_for_embedding(index_value)
if text and text not in texts_to_embed:
texts_to_embed[text] = (index_name, index_value)
if not texts_to_embed:
logger.info("No texts to embed")
return
# Compute embeddings
embeddings_list = []
try:
for text, (index_name, index_value) in texts_to_embed.items():
vectors = await flow("embeddings-request").embed(text=text)
embeddings_list.append(
RowIndexEmbedding(
index_name=index_name,
index_value=index_value,
text=text,
vectors=vectors
)
)
# Send in batches to avoid oversized messages
for i in range(0, len(embeddings_list), self.batch_size):
batch = embeddings_list[i:i + self.batch_size]
result = RowEmbeddings(
metadata=obj.metadata,
schema_name=obj.schema_name,
embeddings=batch,
)
await flow("output").send(result)
logger.info(
f"Computed {len(embeddings_list)} embeddings for "
f"{len(obj.values)} rows ({len(index_names)} indexes)"
)
except Exception as e:
logger.error("Exception during embedding computation", exc_info=True)
raise e
async def create_collection(self, user: str, collection: str, metadata: dict):
"""Collection creation notification - no action needed for embedding stage"""
logger.debug(f"Row embeddings collection notification for {user}/{collection}")
async def delete_collection(self, user: str, collection: str):
"""Collection deletion notification - no action needed for embedding stage"""
logger.debug(f"Row embeddings collection delete notification for {user}/{collection}")
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
parser.add_argument(
'--batch-size',
type=int,
default=default_batch_size,
help=f'Maximum embeddings per output message (default: {default_batch_size})'
)
parser.add_argument(
'--config-type',
default='schema',
help='Configuration type prefix for schemas (default: schema)'
)
def run():
Processor.launch(default_ident, __doc__)
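# Example invocation via the console script registered in pyproject.toml
# (flag values illustrative):
#   row-embeddings --batch-size 10 --config-type schema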


@ -3,7 +3,7 @@ import json
import urllib.parse
import logging
from ....schema import Chunk, Triple, Triples, Metadata, Value
from ....schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL
from ....schema import EntityContext, EntityContexts
from ....rdf import TRUSTGRAPH_ENTITIES, RDF_LABEL, SUBJECT_OF, DEFINITION
@ -126,16 +126,42 @@ class Processor(FlowProcessor):
await pub.send(ecs)
def parse_json(self, text):
json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL)
if json_match:
json_str = json_match.group(1).strip()
else:
# If no delimiters, assume the entire output is JSON
json_str = text.strip()
return json.loads(json_str)
def parse_jsonl(self, text):
"""
Parse JSONL response, returning list of valid objects.
Invalid lines (malformed JSON, empty lines) are skipped with warnings.
This provides truncation resilience - partial output yields partial results.
"""
results = []
# Strip markdown code fences if present
text = text.strip()
if text.startswith('```'):
# Remove opening fence (possibly with language hint)
text = re.sub(r'^```(?:json|jsonl)?\s*\n?', '', text)
if text.endswith('```'):
text = text[:-3]
for line_num, line in enumerate(text.strip().split('\n'), 1):
line = line.strip()
# Skip empty lines
if not line:
continue
# Skip any remaining fence markers
if line.startswith('```'):
continue
try:
obj = json.loads(line)
results.append(obj)
except json.JSONDecodeError as e:
# Log warning but continue - this provides truncation resilience
logger.warning(f"JSONL parse error on line {line_num}: {e}")
return results
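# Sketch of the truncation resilience described above (payload hypothetical):
# a response cut off mid-object still yields every complete line.
#   text = '{"type": "definition", "entity": "A", "definition": "x"}\n{"type": "rel'
#   self.parse_jsonl(text)  # -> one parsed object; a warning is logged for line 2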
async def on_message(self, msg, consumer, flow):
@ -178,11 +204,12 @@ class Processor(FlowProcessor):
question = prompt
)
# Parse JSON response
try:
extraction_data = self.parse_json(agent_response)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON response from agent: {e}")
# Parse JSONL response
extraction_data = self.parse_jsonl(agent_response)
if not extraction_data:
logger.warning("JSONL parse returned no valid objects")
return
# Process extraction data
triples, entity_contexts = self.process_extraction_data(
@ -209,103 +236,113 @@ class Processor(FlowProcessor):
raise
def process_extraction_data(self, data, metadata):
"""Process combined extraction data to generate triples and entity contexts"""
"""Process JSONL extraction data to generate triples and entity contexts.
Data is a flat list of objects with 'type' discriminator field:
- {"type": "definition", "entity": "...", "definition": "..."}
- {"type": "relationship", "subject": "...", "predicate": "...", "object": "...", "object-entity": bool}
"""
triples = []
entity_contexts = []
# Categorize items by type
definitions = [item for item in data if item.get("type") == "definition"]
relationships = [item for item in data if item.get("type") == "relationship"]
# Process definitions
for defn in data.get("definitions", []):
for defn in definitions:
entity_uri = self.to_uri(defn["entity"])
# Add entity label
triples.append(Triple(
s = Value(value=entity_uri, is_uri=True),
p = Value(value=RDF_LABEL, is_uri=True),
o = Value(value=defn["entity"], is_uri=False),
s = Term(type=IRI, iri=entity_uri),
p = Term(type=IRI, iri=RDF_LABEL),
o = Term(type=LITERAL, value=defn["entity"]),
))
# Add definition
triples.append(Triple(
s = Value(value=entity_uri, is_uri=True),
p = Value(value=DEFINITION, is_uri=True),
o = Value(value=defn["definition"], is_uri=False),
s = Term(type=IRI, iri=entity_uri),
p = Term(type=IRI, iri=DEFINITION),
o = Term(type=LITERAL, value=defn["definition"]),
))
# Add subject-of relationship to document
if metadata.id:
triples.append(Triple(
s = Value(value=entity_uri, is_uri=True),
p = Value(value=SUBJECT_OF, is_uri=True),
o = Value(value=metadata.id, is_uri=True),
s = Term(type=IRI, iri=entity_uri),
p = Term(type=IRI, iri=SUBJECT_OF),
o = Term(type=IRI, iri=metadata.id),
))
# Create entity context for embeddings
entity_contexts.append(EntityContext(
entity=Value(value=entity_uri, is_uri=True),
entity=Term(type=IRI, iri=entity_uri),
context=defn["definition"]
))
# Process relationships
for rel in data.get("relationships", []):
for rel in relationships:
subject_uri = self.to_uri(rel["subject"])
predicate_uri = self.to_uri(rel["predicate"])
subject_value = Value(value=subject_uri, is_uri=True)
predicate_value = Value(value=predicate_uri, is_uri=True)
if data.get("object-entity", False):
object_value = Value(value=predicate_uri, is_uri=True)
subject_value = Term(type=IRI, iri=subject_uri)
predicate_value = Term(type=IRI, iri=predicate_uri)
if rel.get("object-entity", True):
object_uri = self.to_uri(rel["object"])
object_value = Term(type=IRI, iri=object_uri)
else:
object_value = Value(value=predicate_uri, is_uri=False)
object_value = Term(type=LITERAL, value=rel["object"])
# Add subject and predicate labels
triples.append(Triple(
s = subject_value,
p = Value(value=RDF_LABEL, is_uri=True),
o = Value(value=rel["subject"], is_uri=False),
p = Term(type=IRI, iri=RDF_LABEL),
o = Term(type=LITERAL, value=rel["subject"]),
))
triples.append(Triple(
s = predicate_value,
p = Value(value=RDF_LABEL, is_uri=True),
o = Value(value=rel["predicate"], is_uri=False),
p = Term(type=IRI, iri=RDF_LABEL),
o = Term(type=LITERAL, value=rel["predicate"]),
))
# Handle object (entity vs literal)
if rel.get("object-entity", True):
triples.append(Triple(
s = object_value,
p = Value(value=RDF_LABEL, is_uri=True),
o = Value(value=rel["object"], is_uri=True),
p = Term(type=IRI, iri=RDF_LABEL),
o = Term(type=LITERAL, value=rel["object"]),
))
# Add the main relationship triple
triples.append(Triple(
s = subject_value,
p = predicate_value,
o = object_value
))
# Add subject-of relationships to document
if metadata.id:
triples.append(Triple(
s = subject_value,
p = Value(value=SUBJECT_OF, is_uri=True),
o = Value(value=metadata.id, is_uri=True),
p = Term(type=IRI, iri=SUBJECT_OF),
o = Term(type=IRI, iri=metadata.id),
))
triples.append(Triple(
s = predicate_value,
p = Value(value=SUBJECT_OF, is_uri=True),
o = Value(value=metadata.id, is_uri=True),
p = Term(type=IRI, iri=SUBJECT_OF),
o = Term(type=IRI, iri=metadata.id),
))
if rel.get("object-entity", True):
triples.append(Triple(
s = object_value,
p = Value(value=SUBJECT_OF, is_uri=True),
o = Value(value=metadata.id, is_uri=True),
p = Term(type=IRI, iri=SUBJECT_OF),
o = Term(type=IRI, iri=metadata.id),
))
return triples, entity_contexts
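# Worked example (hypothetical content): two JSONL lines in the shapes
# documented above,
#   {"type": "definition", "entity": "Cornish pasty", "definition": "A baked pastry."}
#   {"type": "relationship", "subject": "Cornish pasty", "predicate": "contains",
#    "object": "beef", "object-entity": true}
# The first produces label and definition triples plus an EntityContext; the
# second produces the relationship triple, labels for each term, and
# subject-of links back to the source document.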


@ -9,7 +9,7 @@ import json
import urllib.parse
import logging
from .... schema import Chunk, Triple, Triples, Metadata, Value
from .... schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL
# Module logger
logger = logging.getLogger(__name__)
@ -20,12 +20,14 @@ from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base import PromptClientSpec
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
DEFINITION_VALUE = Term(type=IRI, iri=DEFINITION)
RDF_LABEL_VALUE = Term(type=IRI, iri=RDF_LABEL)
SUBJECT_OF_VALUE = Term(type=IRI, iri=SUBJECT_OF)
default_ident = "kg-extract-definitions"
default_concurrency = 1
default_triples_batch_size = 50
default_entity_batch_size = 5
class Processor(FlowProcessor):
@ -33,6 +35,8 @@ class Processor(FlowProcessor):
id = params.get("id")
concurrency = params.get("concurrency", 1)
self.triples_batch_size = params.get("triples_batch_size", default_triples_batch_size)
self.entity_batch_size = params.get("entity_batch_size", default_entity_batch_size)
super(Processor, self).__init__(
**params | {
@ -142,13 +146,13 @@ class Processor(FlowProcessor):
s_uri = self.to_uri(s)
s_value = Value(value=str(s_uri), is_uri=True)
o_value = Value(value=str(o), is_uri=False)
s_value = Term(type=IRI, iri=str(s_uri))
o_value = Term(type=LITERAL, value=str(o))
triples.append(Triple(
s=s_value,
p=RDF_LABEL_VALUE,
o=Value(value=s, is_uri=False),
o=Term(type=LITERAL, value=s),
))
triples.append(Triple(
@ -158,37 +162,48 @@ class Processor(FlowProcessor):
triples.append(Triple(
s=s_value,
p=SUBJECT_OF_VALUE,
o=Value(value=v.metadata.id, is_uri=True)
o=Term(type=IRI, iri=v.metadata.id)
))
ec = EntityContext(
entity=s_value,
context=defn["definition"],
)
entities.append(ec)
await self.emit_triples(
flow("triples"),
Metadata(
id=v.metadata.id,
metadata=[],
user=v.metadata.user,
collection=v.metadata.collection,
),
triples
)
await self.emit_ecs(
flow("entity-contexts"),
Metadata(
id=v.metadata.id,
metadata=[],
user=v.metadata.user,
collection=v.metadata.collection,
),
entities
)
# Output entity name as context for direct name matching
entities.append(EntityContext(
entity=s_value,
context=s,
))
# Output definition as context for semantic matching
entities.append(EntityContext(
entity=s_value,
context=defn["definition"],
))
# Send triples in batches
for i in range(0, len(triples), self.triples_batch_size):
batch = triples[i:i + self.triples_batch_size]
await self.emit_triples(
flow("triples"),
Metadata(
id=v.metadata.id,
metadata=[],
user=v.metadata.user,
collection=v.metadata.collection,
),
batch
)
# Send entity contexts in batches
for i in range(0, len(entities), self.entity_batch_size):
batch = entities[i:i + self.entity_batch_size]
await self.emit_ecs(
flow("entity-contexts"),
Metadata(
id=v.metadata.id,
metadata=[],
user=v.metadata.user,
collection=v.metadata.collection,
),
batch
)
except Exception as e:
logger.error(f"Definitions extraction exception: {e}", exc_info=True)
@ -205,6 +220,20 @@ class Processor(FlowProcessor):
help=f'Concurrent processing threads (default: {default_concurrency})'
)
parser.add_argument(
'--triples-batch-size',
type=int,
default=default_triples_batch_size,
help=f'Maximum triples per output message (default: {default_triples_batch_size})'
)
parser.add_argument(
'--entity-batch-size',
type=int,
default=default_entity_batch_size,
help=f'Maximum entity contexts per output message (default: {default_entity_batch_size})'
)
FlowProcessor.add_args(parser)
def run():


@ -74,23 +74,27 @@ def build_entity_uri(entity_name: str, entity_type: str, ontology_id: str,
Args:
entity_name: Natural language entity name (e.g., "Cornish pasty")
entity_type: Ontology type (e.g., "fo/Recipe")
entity_type: Ontology type (e.g., "fo/Recipe" or "Recipe")
ontology_id: Ontology identifier (e.g., "food")
base_uri: Base URI for entity URIs (default: "https://trustgraph.ai")
Returns:
Full entity URI (e.g., "https://trustgraph.ai/food/fo-recipe-cornish-pasty")
Full entity URI (e.g., "https://trustgraph.ai/food/recipe-cornish-pasty")
Examples:
>>> build_entity_uri("Cornish pasty", "fo/Recipe", "food")
'https://trustgraph.ai/food/fo-recipe-cornish-pasty'
'https://trustgraph.ai/food/recipe-cornish-pasty'
>>> build_entity_uri("Cornish pasty", "fo/Food", "food")
'https://trustgraph.ai/food/fo-food-cornish-pasty'
>>> build_entity_uri("Cornish pasty", "Food", "food")
'https://trustgraph.ai/food/food-cornish-pasty'
>>> build_entity_uri("beef", "fo/Food", "food")
'https://trustgraph.ai/food/fo-food-beef'
'https://trustgraph.ai/food/food-beef'
"""
# Strip ontology prefix from type if present (e.g., "fo/Recipe" -> "Recipe")
if "/" in entity_type:
entity_type = entity_type.split("/")[-1]
type_part = normalize_type_identifier(entity_type)
name_part = normalize_entity_name(entity_name)


@ -8,7 +8,7 @@ import logging
import asyncio
from typing import List, Dict, Any, Optional
from .... schema import Chunk, Triple, Triples, Metadata, Value
from .... schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL
from .... schema import EntityContext, EntityContexts
from .... schema import PromptRequest, PromptResponse
from .... rdf import TRUSTGRAPH_ENTITIES, RDF_TYPE, RDF_LABEL, DEFINITION
@ -27,6 +27,8 @@ logger = logging.getLogger(__name__)
default_ident = "kg-extract-ontology"
default_concurrency = 1
default_triples_batch_size = 50
default_entity_batch_size = 5
# URI prefix mappings for common namespaces
URI_PREFIXES = {
@ -39,12 +41,22 @@ URI_PREFIXES = {
}
def make_term(v, is_uri):
"""Helper to create Term from value and is_uri flag."""
if is_uri:
return Term(type=IRI, iri=v)
else:
return Term(type=LITERAL, value=v)
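# A quick sketch of the helper's mapping (values illustrative):
#   make_term("http://www.w3.org/2000/01/rdf-schema#label", is_uri=True)
#       -> Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label")
#   make_term("Cornish pasty", is_uri=False)
#       -> Term(type=LITERAL, value="Cornish pasty")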
class Processor(FlowProcessor):
"""Main OntoRAG extraction processor."""
def __init__(self, **params):
id = params.get("id", default_ident)
concurrency = params.get("concurrency", default_concurrency)
self.triples_batch_size = params.get("triples_batch_size", default_triples_batch_size)
self.entity_batch_size = params.get("entity_batch_size", default_entity_batch_size)
super(Processor, self).__init__(
**params | {
@ -274,17 +286,6 @@ class Processor(FlowProcessor):
if not ontology_subsets:
logger.warning("No relevant ontology elements found for chunk")
# Emit empty outputs
await self.emit_triples(
flow("triples"),
v.metadata,
[]
)
await self.emit_entity_contexts(
flow("entity-contexts"),
v.metadata,
[]
)
return
# Merge subsets if multiple ontologies matched
@ -318,36 +319,29 @@ class Processor(FlowProcessor):
# Build entity contexts from all triples (including ontology elements)
entity_contexts = self.build_entity_contexts(all_triples)
# Emit all triples (extracted + ontology definitions)
await self.emit_triples(
flow("triples"),
v.metadata,
all_triples
)
# Emit triples in batches
for i in range(0, len(all_triples), self.triples_batch_size):
batch = all_triples[i:i + self.triples_batch_size]
await self.emit_triples(
flow("triples"),
v.metadata,
batch
)
# Emit entity contexts
await self.emit_entity_contexts(
flow("entity-contexts"),
v.metadata,
entity_contexts
)
# Emit entity contexts in batches
for i in range(0, len(entity_contexts), self.entity_batch_size):
batch = entity_contexts[i:i + self.entity_batch_size]
await self.emit_entity_contexts(
flow("entity-contexts"),
v.metadata,
batch
)
logger.info(f"Extracted {len(triples)} content triples + {len(ontology_triples)} ontology triples "
f"= {len(all_triples)} total triples and {len(entity_contexts)} entity contexts")
except Exception as e:
logger.error(f"OntoRAG extraction exception: {e}", exc_info=True)
# Emit empty outputs on error
await self.emit_triples(
flow("triples"),
v.metadata,
[]
)
await self.emit_entity_contexts(
flow("entity-contexts"),
v.metadata,
[]
)
async def extract_with_simplified_format(
self,
@ -446,9 +440,9 @@ class Processor(FlowProcessor):
is_object_uri = False
# Create Triple object with expanded URIs
s_value = Value(value=subject_uri, is_uri=True)
p_value = Value(value=predicate_uri, is_uri=True)
o_value = Value(value=object_uri, is_uri=is_object_uri)
s_value = make_term(subject_uri, is_uri=True)
p_value = make_term(predicate_uri, is_uri=True)
o_value = make_term(object_uri, is_uri=is_object_uri)
validated_triples.append(Triple(
s=s_value,
@ -609,9 +603,9 @@ class Processor(FlowProcessor):
# rdf:type owl:Class
ontology_triples.append(Triple(
s=Value(value=class_uri, is_uri=True),
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
o=Value(value="http://www.w3.org/2002/07/owl#Class", is_uri=True)
s=make_term(class_uri, is_uri=True),
p=make_term("http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
o=make_term("http://www.w3.org/2002/07/owl#Class", is_uri=True)
))
# rdfs:label (stored as 'labels' in OntologyClass.__dict__)
@ -620,18 +614,18 @@ class Processor(FlowProcessor):
if isinstance(labels, list) and labels:
label_val = labels[0].get('value', class_id) if isinstance(labels[0], dict) else str(labels[0])
ontology_triples.append(Triple(
s=Value(value=class_uri, is_uri=True),
p=Value(value=RDF_LABEL, is_uri=True),
o=Value(value=label_val, is_uri=False)
s=make_term(class_uri, is_uri=True),
p=make_term(RDF_LABEL, is_uri=True),
o=make_term(label_val, is_uri=False)
))
# rdfs:comment (stored as 'comment' in OntologyClass.__dict__)
if isinstance(class_def, dict) and 'comment' in class_def and class_def['comment']:
comment = class_def['comment']
ontology_triples.append(Triple(
s=Value(value=class_uri, is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True),
o=Value(value=comment, is_uri=False)
s=make_term(class_uri, is_uri=True),
p=make_term("http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True),
o=make_term(comment, is_uri=False)
))
# rdfs:subClassOf (stored as 'subclass_of' in OntologyClass.__dict__)
@ -648,9 +642,9 @@ class Processor(FlowProcessor):
parent_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{parent}"
ontology_triples.append(Triple(
s=Value(value=class_uri, is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#subClassOf", is_uri=True),
o=Value(value=parent_uri, is_uri=True)
s=make_term(class_uri, is_uri=True),
p=make_term("http://www.w3.org/2000/01/rdf-schema#subClassOf", is_uri=True),
o=make_term(parent_uri, is_uri=True)
))
# Generate triples for object properties
@ -663,9 +657,9 @@ class Processor(FlowProcessor):
# rdf:type owl:ObjectProperty
ontology_triples.append(Triple(
s=Value(value=prop_uri, is_uri=True),
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
o=Value(value="http://www.w3.org/2002/07/owl#ObjectProperty", is_uri=True)
s=make_term(prop_uri, is_uri=True),
p=make_term("http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
o=make_term("http://www.w3.org/2002/07/owl#ObjectProperty", is_uri=True)
))
# rdfs:label (stored as 'labels' in OntologyProperty.__dict__)
@ -674,18 +668,18 @@ class Processor(FlowProcessor):
if isinstance(labels, list) and labels:
label_val = labels[0].get('value', prop_id) if isinstance(labels[0], dict) else str(labels[0])
ontology_triples.append(Triple(
s=Value(value=prop_uri, is_uri=True),
p=Value(value=RDF_LABEL, is_uri=True),
o=Value(value=label_val, is_uri=False)
s=make_term(prop_uri, is_uri=True),
p=make_term(RDF_LABEL, is_uri=True),
o=make_term(label_val, is_uri=False)
))
# rdfs:comment (stored as 'comment' in OntologyProperty.__dict__)
if isinstance(prop_def, dict) and 'comment' in prop_def and prop_def['comment']:
comment = prop_def['comment']
ontology_triples.append(Triple(
s=Value(value=prop_uri, is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True),
o=Value(value=comment, is_uri=False)
s=make_term(prop_uri, is_uri=True),
p=make_term("http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True),
o=make_term(comment, is_uri=False)
))
# rdfs:domain (stored as 'domain' in OntologyProperty.__dict__)
@ -702,9 +696,9 @@ class Processor(FlowProcessor):
domain_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{domain}"
ontology_triples.append(Triple(
s=Value(value=prop_uri, is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#domain", is_uri=True),
o=Value(value=domain_uri, is_uri=True)
s=make_term(prop_uri, is_uri=True),
p=make_term("http://www.w3.org/2000/01/rdf-schema#domain", is_uri=True),
o=make_term(domain_uri, is_uri=True)
))
# rdfs:range (stored as 'range' in OntologyProperty.__dict__)
@ -721,9 +715,9 @@ class Processor(FlowProcessor):
range_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{range_val}"
ontology_triples.append(Triple(
s=Value(value=prop_uri, is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#range", is_uri=True),
o=Value(value=range_uri, is_uri=True)
s=make_term(prop_uri, is_uri=True),
p=make_term("http://www.w3.org/2000/01/rdf-schema#range", is_uri=True),
o=make_term(range_uri, is_uri=True)
))
# Generate triples for datatype properties
@ -736,9 +730,9 @@ class Processor(FlowProcessor):
# rdf:type owl:DatatypeProperty
ontology_triples.append(Triple(
s=Value(value=prop_uri, is_uri=True),
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
o=Value(value="http://www.w3.org/2002/07/owl#DatatypeProperty", is_uri=True)
s=make_term(prop_uri, is_uri=True),
p=make_term("http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
o=make_term("http://www.w3.org/2002/07/owl#DatatypeProperty", is_uri=True)
))
# rdfs:label (stored as 'labels' in OntologyProperty.__dict__)
@ -747,18 +741,18 @@ class Processor(FlowProcessor):
if isinstance(labels, list) and labels:
label_val = labels[0].get('value', prop_id) if isinstance(labels[0], dict) else str(labels[0])
ontology_triples.append(Triple(
s=Value(value=prop_uri, is_uri=True),
p=Value(value=RDF_LABEL, is_uri=True),
o=Value(value=label_val, is_uri=False)
s=make_term(prop_uri, is_uri=True),
p=make_term(RDF_LABEL, is_uri=True),
o=make_term(label_val, is_uri=False)
))
# rdfs:comment (stored as 'comment' in OntologyProperty.__dict__)
if isinstance(prop_def, dict) and 'comment' in prop_def and prop_def['comment']:
comment = prop_def['comment']
ontology_triples.append(Triple(
s=Value(value=prop_uri, is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True),
o=Value(value=comment, is_uri=False)
s=make_term(prop_uri, is_uri=True),
p=make_term("http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True),
o=make_term(comment, is_uri=False)
))
# rdfs:domain (stored as 'domain' in OntologyProperty.__dict__)
@ -775,9 +769,9 @@ class Processor(FlowProcessor):
domain_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{domain}"
ontology_triples.append(Triple(
s=Value(value=prop_uri, is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#domain", is_uri=True),
o=Value(value=domain_uri, is_uri=True)
s=make_term(prop_uri, is_uri=True),
p=make_term("http://www.w3.org/2000/01/rdf-schema#domain", is_uri=True),
o=make_term(domain_uri, is_uri=True)
))
# rdfs:range (datatype)
@ -790,9 +784,9 @@ class Processor(FlowProcessor):
range_uri = range_val
ontology_triples.append(Triple(
s=Value(value=prop_uri, is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#range", is_uri=True),
o=Value(value=range_uri, is_uri=True)
s=make_term(prop_uri, is_uri=True),
p=make_term("http://www.w3.org/2000/01/rdf-schema#range", is_uri=True),
o=make_term(range_uri, is_uri=True)
))
logger.info(f"Generated {len(ontology_triples)} triples describing ontology elements")
@ -814,9 +808,9 @@ class Processor(FlowProcessor):
entity_data = {} # subject_uri -> {labels: [], definitions: []}
for triple in triples:
subject_uri = triple.s.value
predicate_uri = triple.p.value
object_val = triple.o.value
subject_uri = triple.s.iri if triple.s.type == IRI else triple.s.value
predicate_uri = triple.p.iri if triple.p.type == IRI else triple.p.value
object_val = triple.o.value if triple.o.type == LITERAL else triple.o.iri
# Initialize entity data if not exists
if subject_uri not in entity_data:
@ -824,12 +818,12 @@ class Processor(FlowProcessor):
# Collect labels (rdfs:label)
if predicate_uri == RDF_LABEL:
if not triple.o.is_uri: # Labels are literals
if triple.o.type == LITERAL: # Labels are literals
entity_data[subject_uri]['labels'].append(object_val)
# Collect definitions (skos:definition, schema:description)
elif predicate_uri == DEFINITION or predicate_uri == "https://schema.org/description":
if not triple.o.is_uri:
if triple.o.type == LITERAL:
entity_data[subject_uri]['definitions'].append(object_val)
# Build EntityContext objects
@ -848,7 +842,7 @@ class Processor(FlowProcessor):
if context_parts:
context_text = ". ".join(context_parts)
entity_contexts.append(EntityContext(
entity=Value(value=subject_uri, is_uri=True),
entity=make_term(subject_uri, is_uri=True),
context=context_text
))
@ -876,6 +870,18 @@ class Processor(FlowProcessor):
default=0.3,
help='Similarity threshold for ontology matching (default: 0.3, range: 0.0-1.0)'
)
parser.add_argument(
'--triples-batch-size',
type=int,
default=default_triples_batch_size,
help=f'Maximum triples per output message (default: {default_triples_batch_size})'
)
parser.add_argument(
'--entity-batch-size',
type=int,
default=default_entity_batch_size,
help=f'Maximum entity contexts per output message (default: {default_entity_batch_size})'
)
FlowProcessor.add_args(parser)


@ -49,8 +49,17 @@ class ExtractionResult:
def parse_extraction_response(response: Any) -> Optional[ExtractionResult]:
"""Parse LLM extraction response into structured format.
Supports two formats:
1. JSONL format (list): Flat list of objects with 'type' discriminator field
[{"type": "entity", ...}, {"type": "relationship", ...}, {"type": "attribute", ...}]
2. Legacy format (dict): Nested structure with separate arrays
{"entities": [...], "relationships": [...], "attributes": [...]}
Args:
response: LLM response (string JSON or already parsed dict)
response: LLM response - can be:
- string (JSON to parse)
- dict (legacy nested format)
- list (JSONL format - flat list with type discriminators)
Returns:
ExtractionResult with parsed entities/relationships/attributes,
@ -64,17 +73,89 @@ def parse_extraction_response(response: Any) -> Optional[ExtractionResult]:
logger.error(f"Failed to parse JSON response: {e}")
logger.debug(f"Response was: {response[:500]}")
return None
elif isinstance(response, dict):
elif isinstance(response, (dict, list)):
data = response
else:
logger.error(f"Unexpected response type: {type(response)}")
return None
# Validate structure
if not isinstance(data, dict):
logger.error(f"Expected dict, got {type(data)}")
return None
# Handle JSONL format (flat list with type discriminators)
if isinstance(data, list):
return parse_jsonl_format(data)
# Handle legacy format (nested dict)
if isinstance(data, dict):
return parse_legacy_format(data)
logger.error(f"Expected dict or list, got {type(data)}")
return None
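# For reference, both accepted shapes (field values hypothetical). The JSONL
# format uses 'entity_type' because 'type' is the discriminator:
#   [{"type": "entity", "entity": "Cornish pasty", "entity_type": "Recipe"},
#    {"type": "relationship", "subject": "Cornish pasty",
#     "predicate": "contains", "object": "beef"}]
# Legacy nested format:
#   {"entities": [{"entity": "Cornish pasty", "type": "Recipe"}],
#    "relationships": [], "attributes": []}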
def parse_jsonl_format(data: List[Dict[str, Any]]) -> ExtractionResult:
"""Parse JSONL format response (flat list with type discriminators).
Each item has a 'type' field: 'entity', 'relationship', or 'attribute'.
Args:
data: List of dicts with type discriminator
Returns:
ExtractionResult with categorized items
"""
entities = []
relationships = []
attributes = []
for item in data:
if not isinstance(item, dict):
logger.warning(f"Skipping non-dict item: {type(item)}")
continue
item_type = item.get('type')
if item_type == 'entity':
try:
entity = parse_entity_jsonl(item)
if entity:
entities.append(entity)
except Exception as e:
logger.warning(f"Failed to parse entity {item}: {e}")
elif item_type == 'relationship':
try:
relationship = parse_relationship(item)
if relationship:
relationships.append(relationship)
except Exception as e:
logger.warning(f"Failed to parse relationship {item}: {e}")
elif item_type == 'attribute':
try:
attribute = parse_attribute(item)
if attribute:
attributes.append(attribute)
except Exception as e:
logger.warning(f"Failed to parse attribute {item}: {e}")
else:
logger.warning(f"Unknown item type '{item_type}': {item}")
return ExtractionResult(
entities=entities,
relationships=relationships,
attributes=attributes
)
def parse_legacy_format(data: Dict[str, Any]) -> ExtractionResult:
"""Parse legacy format response (nested dict with arrays).
Args:
data: Dict with 'entities', 'relationships', 'attributes' arrays
Returns:
ExtractionResult with parsed items
"""
# Parse entities
entities = []
entities_data = data.get('entities', [])
@ -127,6 +208,37 @@ def parse_extraction_response(response: Any) -> Optional[ExtractionResult]:
)
def parse_entity_jsonl(data: Dict[str, Any]) -> Optional[Entity]:
"""Parse entity from JSONL format dict.
JSONL format uses 'entity_type' instead of 'type' for the entity's type
(since 'type' is the discriminator field).
Args:
data: Entity dict with 'entity' and 'entity_type' fields
Returns:
Entity object or None if invalid
"""
if not isinstance(data, dict):
logger.warning(f"Entity data is not a dict: {type(data)}")
return None
entity = data.get('entity')
# JSONL format uses 'entity_type' since 'type' is the discriminator
entity_type = data.get('entity_type')
if not entity or not entity_type:
logger.warning(f"Missing required fields in entity: {data}")
return None
if not isinstance(entity, str) or not isinstance(entity_type, str):
logger.warning(f"Entity fields must be strings: {data}")
return None
return Entity(entity=entity, type=entity_type)
def parse_entity(data: Dict[str, Any]) -> Optional[Entity]:
"""Parse entity from dict.


@ -8,7 +8,7 @@ with full URIs and correct is_uri flags.
import logging
from typing import List, Optional
from .... schema import Triple, Value
from .... schema import Triple, Term, IRI, LITERAL
from .... rdf import RDF_TYPE, RDF_LABEL
from .simplified_parser import Entity, Relationship, Attribute, ExtractionResult
@ -87,17 +87,17 @@ class TripleConverter:
# Generate type triple: entity rdf:type ClassURI
type_triple = Triple(
s=Value(value=entity_uri, is_uri=True),
p=Value(value=RDF_TYPE, is_uri=True),
o=Value(value=class_uri, is_uri=True)
s=Term(type=IRI, iri=entity_uri),
p=Term(type=IRI, iri=RDF_TYPE),
o=Term(type=IRI, iri=class_uri)
)
triples.append(type_triple)
# Generate label triple: entity rdfs:label "entity name"
label_triple = Triple(
s=Value(value=entity_uri, is_uri=True),
p=Value(value=RDF_LABEL, is_uri=True),
o=Value(value=entity.entity, is_uri=False) # Literal!
s=Term(type=IRI, iri=entity_uri),
p=Term(type=IRI, iri=RDF_LABEL),
o=Term(type=LITERAL, value=entity.entity) # Literal!
)
triples.append(label_triple)
@ -131,9 +131,9 @@ class TripleConverter:
# Generate triple: subject property object
return Triple(
s=Value(value=subject_uri, is_uri=True),
p=Value(value=property_uri, is_uri=True),
o=Value(value=object_uri, is_uri=True)
s=Term(type=IRI, iri=subject_uri),
p=Term(type=IRI, iri=property_uri),
o=Term(type=IRI, iri=object_uri)
)
def convert_attribute(self, attribute: Attribute) -> Optional[Triple]:
@ -159,9 +159,9 @@ class TripleConverter:
# Generate triple: entity property "literal value"
return Triple(
s=Value(value=entity_uri, is_uri=True),
p=Value(value=property_uri, is_uri=True),
o=Value(value=attribute.value, is_uri=False) # Literal!
s=Term(type=IRI, iri=entity_uri),
p=Term(type=IRI, iri=property_uri),
o=Term(type=LITERAL, value=attribute.value) # Literal!
)
def _get_class_uri(self, class_id: str) -> Optional[str]:


@ -13,18 +13,19 @@ import urllib.parse
logger = logging.getLogger(__name__)
from .... schema import Chunk, Triple, Triples
from .... schema import Metadata, Value
from .... schema import Metadata, Term, IRI, LITERAL
from .... schema import PromptRequest, PromptResponse
from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES, SUBJECT_OF
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base import PromptClientSpec
RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
RDF_LABEL_VALUE = Term(type=IRI, iri=RDF_LABEL)
SUBJECT_OF_VALUE = Term(type=IRI, iri=SUBJECT_OF)
default_ident = "kg-extract-relationships"
default_concurrency = 1
default_triples_batch_size = 50
class Processor(FlowProcessor):
@ -32,6 +33,7 @@ class Processor(FlowProcessor):
id = params.get("id")
concurrency = params.get("concurrency", 1)
self.triples_batch_size = params.get("triples_batch_size", default_triples_batch_size)
super(Processor, self).__init__(
**params | {
@ -127,16 +129,16 @@ class Processor(FlowProcessor):
if o is None: continue
s_uri = self.to_uri(s)
s_value = Value(value=str(s_uri), is_uri=True)
s_value = Term(type=IRI, iri=str(s_uri))
p_uri = self.to_uri(p)
p_value = Value(value=str(p_uri), is_uri=True)
p_value = Term(type=IRI, iri=str(p_uri))
if rel["object-entity"]:
if rel["object-entity"]:
o_uri = self.to_uri(o)
o_value = Value(value=str(o_uri), is_uri=True)
o_value = Term(type=IRI, iri=str(o_uri))
else:
o_value = Value(value=str(o), is_uri=False)
o_value = Term(type=LITERAL, value=str(o))
triples.append(Triple(
s=s_value,
@ -148,14 +150,14 @@ class Processor(FlowProcessor):
triples.append(Triple(
s=s_value,
p=RDF_LABEL_VALUE,
o=Value(value=str(s), is_uri=False)
o=Term(type=LITERAL, value=str(s))
))
# Label for p
triples.append(Triple(
s=p_value,
p=RDF_LABEL_VALUE,
o=Value(value=str(p), is_uri=False)
o=Term(type=LITERAL, value=str(p))
))
if rel["object-entity"]:
@ -163,14 +165,14 @@ class Processor(FlowProcessor):
triples.append(Triple(
s=o_value,
p=RDF_LABEL_VALUE,
o=Value(value=str(o), is_uri=False)
o=Term(type=LITERAL, value=str(o))
))
# 'Subject of' for s
triples.append(Triple(
s=s_value,
p=SUBJECT_OF_VALUE,
o=Value(value=v.metadata.id, is_uri=True)
o=Term(type=IRI, iri=v.metadata.id)
))
if rel["object-entity"]:
@ -178,19 +180,22 @@ class Processor(FlowProcessor):
triples.append(Triple(
s=o_value,
p=SUBJECT_OF_VALUE,
o=Value(value=v.metadata.id, is_uri=True)
o=Term(type=IRI, iri=v.metadata.id)
))
await self.emit_triples(
flow("triples"),
Metadata(
id=v.metadata.id,
metadata=[],
user=v.metadata.user,
collection=v.metadata.collection,
),
triples
)
# Send triples in batches
for i in range(0, len(triples), self.triples_batch_size):
batch = triples[i:i + self.triples_batch_size]
await self.emit_triples(
flow("triples"),
Metadata(
id=v.metadata.id,
metadata=[],
user=v.metadata.user,
collection=v.metadata.collection,
),
batch
)
except Exception as e:
logger.error(f"Relationship extraction exception: {e}", exc_info=True)
@ -207,6 +212,13 @@ class Processor(FlowProcessor):
help=f'Concurrent processing threads (default: {default_concurrency})'
)
parser.add_argument(
'--triples-batch-size',
type=int,
default=default_triples_batch_size,
help=f'Maximum triples per output message (default: {default_triples_batch_size})'
)
FlowProcessor.add_args(parser)
def run():


@ -1,5 +1,5 @@
"""
Object extraction service - extracts structured objects from text chunks
Row extraction service - extracts structured rows from text chunks
based on configured schemas.
"""
@ -18,7 +18,7 @@ from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base import PromptClientSpec
from .... messaging.translators import row_schema_translator
default_ident = "kg-extract-objects"
default_ident = "kg-extract-rows"
def convert_values_to_strings(obj: Dict[str, Any]) -> Dict[str, str]:
@ -310,5 +310,5 @@ class Processor(FlowProcessor):
FlowProcessor.add_args(parser)
def run():
"""Entry point for kg-extract-objects command"""
"""Entry point for kg-extract-rows command"""
Processor.launch(default_ident, __doc__)


@ -11,7 +11,7 @@ import logging
# Module logger
logger = logging.getLogger(__name__)
from .... schema import Chunk, Triple, Triples, Metadata, Value
from .... schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL
from .... schema import chunk_ingest_queue, triples_store_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
@ -20,7 +20,7 @@ from .... clients.prompt_client import PromptClient
from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION
from .... base import ConsumerProducer
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
DEFINITION_VALUE = Term(type=IRI, iri=DEFINITION)
module = "kg-extract-topics"
@ -106,8 +106,8 @@ class Processor(ConsumerProducer):
s_uri = self.to_uri(s)
s_value = Value(value=str(s_uri), is_uri=True)
o_value = Value(value=str(o), is_uri=False)
s_value = Term(type=IRI, iri=str(s_uri))
o_value = Term(type=LITERAL, value=str(o))
await self.emit_edge(
v.metadata, s_value, DEFINITION_VALUE, o_value


@ -0,0 +1,31 @@
from ... schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
class DocumentEmbeddingsQueryRequestor(ServiceRequestor):
def __init__(
self, backend, request_queue, response_queue, timeout,
consumer, subscriber,
):
super(DocumentEmbeddingsQueryRequestor, self).__init__(
backend=backend,
request_queue=request_queue,
response_queue=response_queue,
request_schema=DocumentEmbeddingsRequest,
response_schema=DocumentEmbeddingsResponse,
subscription = subscriber,
consumer_name = consumer,
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("document-embeddings-query")
self.response_translator = TranslatorRegistry.get_response_translator("document-embeddings-query")
def to_request(self, body):
return self.request_translator.to_pulsar(body)
def from_response(self, message):
return self.response_translator.from_response_with_completion(message)


@ -20,12 +20,14 @@ from . prompt import PromptRequestor
from . graph_rag import GraphRagRequestor
from . document_rag import DocumentRagRequestor
from . triples_query import TriplesQueryRequestor
from . objects_query import ObjectsQueryRequestor
from . rows_query import RowsQueryRequestor
from . nlp_query import NLPQueryRequestor
from . structured_query import StructuredQueryRequestor
from . structured_diag import StructuredDiagRequestor
from . embeddings import EmbeddingsRequestor
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
from . document_embeddings_query import DocumentEmbeddingsQueryRequestor
from . row_embeddings_query import RowEmbeddingsQueryRequestor
from . mcp_tool import McpToolRequestor
from . text_load import TextLoad
from . document_load import DocumentLoad
@ -39,7 +41,7 @@ from . triples_import import TriplesImport
from . graph_embeddings_import import GraphEmbeddingsImport
from . document_embeddings_import import DocumentEmbeddingsImport
from . entity_contexts_import import EntityContextsImport
from . objects_import import ObjectsImport
from . rows_import import RowsImport
from . core_export import CoreExport
from . core_import import CoreImport
@ -55,11 +57,13 @@ request_response_dispatchers = {
"document-rag": DocumentRagRequestor,
"embeddings": EmbeddingsRequestor,
"graph-embeddings": GraphEmbeddingsQueryRequestor,
"document-embeddings": DocumentEmbeddingsQueryRequestor,
"triples": TriplesQueryRequestor,
"objects": ObjectsQueryRequestor,
"rows": RowsQueryRequestor,
"nlp-query": NLPQueryRequestor,
"structured-query": StructuredQueryRequestor,
"structured-diag": StructuredDiagRequestor,
"row-embeddings": RowEmbeddingsQueryRequestor,
}
global_dispatchers = {
@ -87,7 +91,7 @@ import_dispatchers = {
"graph-embeddings": GraphEmbeddingsImport,
"document-embeddings": DocumentEmbeddingsImport,
"entity-contexts": EntityContextsImport,
"objects": ObjectsImport,
"rows": RowsImport,
}
class DispatcherWrapper:


@ -0,0 +1,31 @@
from ... schema import RowEmbeddingsRequest, RowEmbeddingsResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
class RowEmbeddingsQueryRequestor(ServiceRequestor):
def __init__(
self, backend, request_queue, response_queue, timeout,
consumer, subscriber,
):
super(RowEmbeddingsQueryRequestor, self).__init__(
backend=backend,
request_queue=request_queue,
response_queue=response_queue,
request_schema=RowEmbeddingsRequest,
response_schema=RowEmbeddingsResponse,
subscription = subscriber,
consumer_name = consumer,
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("row-embeddings-query")
self.response_translator = TranslatorRegistry.get_response_translator("row-embeddings-query")
def to_request(self, body):
return self.request_translator.to_pulsar(body)
def from_response(self, message):
return self.response_translator.from_response_with_completion(message)


@ -12,7 +12,7 @@ from . serialize import to_subgraph
# Module logger
logger = logging.getLogger(__name__)
class ObjectsImport:
class RowsImport:
def __init__(
self, ws, running, backend, queue
@@ -20,7 +20,7 @@ class ObjectsImport:
self.ws = ws
self.running = running
self.publisher = Publisher(
backend, topic = queue, schema = ExtractedObject
)
@@ -73,4 +73,4 @@ class ObjectsImport:
if self.ws:
await self.ws.close()
self.ws = None
self.ws = None

View file

@@ -1,30 +1,30 @@
from ... schema import ObjectsQueryRequest, ObjectsQueryResponse
from ... schema import RowsQueryRequest, RowsQueryResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
class ObjectsQueryRequestor(ServiceRequestor):
class RowsQueryRequestor(ServiceRequestor):
def __init__(
self, backend, request_queue, response_queue, timeout,
consumer, subscriber,
):
super(ObjectsQueryRequestor, self).__init__(
super(RowsQueryRequestor, self).__init__(
backend=backend,
request_queue=request_queue,
response_queue=response_queue,
request_schema=ObjectsQueryRequest,
response_schema=ObjectsQueryResponse,
request_schema=RowsQueryRequest,
response_schema=RowsQueryResponse,
subscription = subscriber,
consumer_name = consumer,
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("objects-query")
self.response_translator = TranslatorRegistry.get_response_translator("objects-query")
self.request_translator = TranslatorRegistry.get_request_translator("rows-query")
self.response_translator = TranslatorRegistry.get_response_translator("rows-query")
def to_request(self, body):
return self.request_translator.to_pulsar(body)
def from_response(self, message):
return self.response_translator.from_response_with_completion(message)
return self.response_translator.from_response_with_completion(message)

View file

@@ -1,46 +1,37 @@
import base64
from ... schema import Value, Triple, DocumentMetadata, ProcessingMetadata
from ... schema import Term, Triple, DocumentMetadata, ProcessingMetadata
from ... messaging.translators.primitives import TermTranslator, TripleTranslator
# Singleton translator instances
_term_translator = TermTranslator()
_triple_translator = TripleTranslator()
# DEPRECATED: These functions have been moved to trustgraph.messaging.translators.
# Use the new messaging translation system instead for consistency and reusability.
# Example:
# from trustgraph.messaging.translators.primitives import TermTranslator
# term_translator = TermTranslator()
# pulsar_term = term_translator.to_pulsar({"v": "example", "e": True})
def to_value(x):
return Value(value=x["v"], is_uri=x["e"])
"""Convert dict to Term. Delegates to TermTranslator."""
return _term_translator.to_pulsar(x)
def to_subgraph(x):
return [
Triple(
s=to_value(t["s"]),
p=to_value(t["p"]),
o=to_value(t["o"])
)
for t in x
]
"""Convert list of dicts to list of Triples. Delegates to TripleTranslator."""
return [_triple_translator.to_pulsar(t) for t in x]
def serialize_value(v):
return {
"v": v.value,
"e": v.is_uri,
}
"""Convert Term to dict. Delegates to TermTranslator."""
return _term_translator.from_pulsar(v)
def serialize_triple(t):
return {
"s": serialize_value(t.s),
"p": serialize_value(t.p),
"o": serialize_value(t.o)
}
"""Convert Triple to dict. Delegates to TripleTranslator."""
return _triple_translator.from_pulsar(t)
def serialize_subgraph(sg):
return [
serialize_triple(t)
for t in sg
]
"""Convert list of Triples to list of dicts."""
return [serialize_triple(t) for t in sg]
def serialize_triples(message):
return {

View file

@@ -75,6 +75,7 @@ class Processor(LlmService):
if stream:
data["stream"] = True
data["stream_options"] = {"include_usage": True}
body = json.dumps(data)
@@ -191,6 +192,9 @@ class Processor(LlmService):
if response.status_code != 200:
raise RuntimeError("LLM failure")
total_input_tokens = 0
total_output_tokens = 0
# Parse SSE stream
for line in response.iter_lines():
if line:
@@ -215,15 +219,21 @@
model=model_name,
is_final=False
)
# Capture usage from final chunk
if 'usage' in chunk_data and chunk_data['usage']:
total_input_tokens = chunk_data['usage'].get('prompt_tokens', 0)
total_output_tokens = chunk_data['usage'].get('completion_tokens', 0)
except json.JSONDecodeError:
logger.warning(f"Failed to parse chunk: {data}")
continue
# Send final chunk
# Send final chunk with token counts
yield LlmChunk(
text="",
in_token=None,
out_token=None,
in_token=total_input_tokens,
out_token=total_output_tokens,
model=model_name,
is_final=True
)
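The usage-capture change above relies on the OpenAI-style SSE contract: when the request carries stream_options={"include_usage": True}, the final data: event includes a usage block with the token counts. A minimal sketch of that parsing loop, outside the processor class (function name and input format are illustrative assumptions):

import json

def parse_sse_usage(lines):
    # Accumulate token usage from an OpenAI-compatible SSE byte stream,
    # where each event is a line of the form b'data: {...json...}'.
    total_in = total_out = 0
    for raw in lines:
        if not raw or not raw.startswith(b"data: "):
            continue
        payload = raw[len(b"data: "):]
        if payload == b"[DONE]":
            break
        try:
            chunk = json.loads(payload)
        except json.JSONDecodeError:
            continue  # mirror the service: skip unparseable chunks
        usage = chunk.get("usage")
        if usage:
            total_in = usage.get("prompt_tokens", 0)
            total_out = usage.get("completion_tokens", 0)
    return total_in, total_out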

View file

@@ -161,9 +161,13 @@ class Processor(LlmService):
temperature=effective_temperature,
max_tokens=self.max_output,
top_p=1,
stream=True # Enable streaming
stream=True,
stream_options={"include_usage": True}
)
total_input_tokens = 0
total_output_tokens = 0
# Stream chunks
for chunk in response:
if chunk.choices and chunk.choices[0].delta.content:
@@ -175,11 +179,16 @@
is_final=False
)
# Send final chunk
# Capture usage from final chunk
if chunk.usage:
total_input_tokens = chunk.usage.prompt_tokens
total_output_tokens = chunk.usage.completion_tokens
# Send final chunk with token counts
yield LlmChunk(
text="",
in_token=None,
out_token=None,
in_token=total_input_tokens,
out_token=total_output_tokens,
model=model_name,
is_final=True
)

View file

@@ -1,3 +0,0 @@
from . llm import *

View file

@@ -1,7 +0,0 @@
#!/usr/bin/env python3
from . llm import run
if __name__ == '__main__':
run()

View file

@@ -1,257 +0,0 @@
"""
Simple LLM service, performs text prompt completion using GoogleAIStudio.
Input is prompt, output is response.
"""
#
# Using this SDK:
# https://googleapis.github.io/python-genai/genai.html#module-genai.client
#
# Seems to have simpler dependencies on the 'VertexAI' service, which
# TrustGraph implements in the trustgraph-vertexai package.
#
from google import genai
from google.genai import types
from google.genai.types import HarmCategory, HarmBlockThreshold
from google.api_core.exceptions import ResourceExhausted
import os
import logging
# Module logger
logger = logging.getLogger(__name__)
from .... exceptions import TooManyRequests
from .... base import LlmService, LlmResult, LlmChunk
default_ident = "text-completion"
default_model = 'gemini-2.0-flash-001'
default_temperature = 0.0
default_max_output = 8192
default_api_key = os.getenv("GOOGLE_AI_STUDIO_KEY")
class Processor(LlmService):
def __init__(self, **params):
model = params.get("model", default_model)
api_key = params.get("api_key", default_api_key)
temperature = params.get("temperature", default_temperature)
max_output = params.get("max_output", default_max_output)
if api_key is None:
raise RuntimeError("Google AI Studio API key not specified")
super(Processor, self).__init__(
**params | {
"model": model,
"temperature": temperature,
"max_output": max_output,
}
)
self.client = genai.Client(api_key=api_key)
self.default_model = model
self.temperature = temperature
self.max_output = max_output
# Cache for generation configs per model
self.generation_configs = {}
block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
self.safety_settings = [
types.SafetySetting(
category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold = block_level,
),
types.SafetySetting(
category = HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold = block_level,
),
types.SafetySetting(
category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold = block_level,
),
types.SafetySetting(
category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold = block_level,
),
# There is a documentation conflict on whether or not
# CIVIC_INTEGRITY is a valid category
# HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY: block_level,
]
logger.info("GoogleAIStudio LLM service initialized")
def _get_or_create_config(self, model_name, temperature=None):
"""Get or create generation config with dynamic temperature"""
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
# Create cache key that includes temperature to avoid conflicts
cache_key = f"{model_name}:{effective_temperature}"
if cache_key not in self.generation_configs:
logger.info(f"Creating generation config for '{model_name}' with temperature {effective_temperature}")
self.generation_configs[cache_key] = types.GenerateContentConfig(
temperature = effective_temperature,
top_p = 1,
top_k = 40,
max_output_tokens = self.max_output,
response_mime_type = "text/plain",
safety_settings = self.safety_settings,
)
return self.generation_configs[cache_key]
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
generation_config = self._get_or_create_config(model_name, effective_temperature)
# Set system instruction per request (can't be cached)
generation_config.system_instruction = system
try:
response = self.client.models.generate_content(
model=model_name,
config=generation_config,
contents=prompt,
)
resp = response.text
inputtokens = int(response.usage_metadata.prompt_token_count)
outputtokens = int(response.usage_metadata.candidates_token_count)
logger.debug(f"LLM response: {resp}")
logger.info(f"Input Tokens: {inputtokens}")
logger.info(f"Output Tokens: {outputtokens}")
resp = LlmResult(
text = resp,
in_token = inputtokens,
out_token = outputtokens,
model = model_name
)
return resp
except ResourceExhausted as e:
logger.warning("Rate limit exceeded")
# Leave rate limit retries to the default handler
raise TooManyRequests()
except Exception as e:
# Apart from rate limits, treat all exceptions as unrecoverable
logger.error(f"GoogleAIStudio LLM exception ({type(e).__name__}): {e}", exc_info=True)
raise e
def supports_streaming(self):
"""Google AI Studio supports streaming"""
return True
async def generate_content_stream(self, system, prompt, model=None, temperature=None):
"""Stream content generation from Google AI Studio"""
model_name = model or self.default_model
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model (streaming): {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
generation_config = self._get_or_create_config(model_name, effective_temperature)
generation_config.system_instruction = system
try:
response = self.client.models.generate_content_stream(
model=model_name,
config=generation_config,
contents=prompt,
)
total_input_tokens = 0
total_output_tokens = 0
for chunk in response:
if hasattr(chunk, 'text') and chunk.text:
yield LlmChunk(
text=chunk.text,
in_token=None,
out_token=None,
model=model_name,
is_final=False
)
# Accumulate token counts if available
if hasattr(chunk, 'usage_metadata'):
if hasattr(chunk.usage_metadata, 'prompt_token_count'):
total_input_tokens = int(chunk.usage_metadata.prompt_token_count)
if hasattr(chunk.usage_metadata, 'candidates_token_count'):
total_output_tokens = int(chunk.usage_metadata.candidates_token_count)
# Send final chunk with token counts
yield LlmChunk(
text="",
in_token=total_input_tokens,
out_token=total_output_tokens,
model=model_name,
is_final=True
)
logger.debug("Streaming complete")
except ResourceExhausted:
logger.warning("Rate limit exceeded during streaming")
raise TooManyRequests()
except Exception as e:
logger.error(f"GoogleAIStudio streaming exception ({type(e).__name__}): {e}", exc_info=True)
raise e
@staticmethod
def add_args(parser):
LlmService.add_args(parser)
parser.add_argument(
'-m', '--model',
default=default_model,
help=f'LLM model (default: {default_model})'
)
parser.add_argument(
'-k', '--api-key',
default=default_api_key,
help=f'GoogleAIStudio API key'
)
parser.add_argument(
'-t', '--temperature',
type=float,
default=default_temperature,
help=f'LLM temperature parameter (default: {default_temperature})'
)
parser.add_argument(
'-x', '--max-output',
type=int,
default=default_max_output,
help=f'LLM max output tokens (default: {default_max_output})'
)
def run():
Processor.launch(default_ident, __doc__)

View file

@@ -126,9 +126,13 @@ class Processor(LlmService):
frequency_penalty=0,
presence_penalty=0,
response_format={"type": "text"},
stream=True
stream=True,
stream_options={"include_usage": True}
)
total_input_tokens = 0
total_output_tokens = 0
for chunk in response:
if chunk.choices and chunk.choices[0].delta.content:
yield LlmChunk(
@@ -139,10 +143,15 @@ class Processor(LlmService):
is_final=False
)
# Capture usage from final chunk
if chunk.usage:
total_input_tokens = chunk.usage.prompt_tokens
total_output_tokens = chunk.usage.completion_tokens
yield LlmChunk(
text="",
in_token=None,
out_token=None,
in_token=total_input_tokens,
out_token=total_output_tokens,
model=model_name,
is_final=True
)
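The same pattern through an OpenAI-compatible Python client: usage is only populated on the final streamed chunk, so the totals stay zero if the server ignores stream_options. A hedged sketch (client setup, model name, and prompt are placeholders):

from openai import OpenAI

client = OpenAI()
stream = client.chat.completions.create(
    model="gpt-4o-mini",  # illustrative model name
    messages=[{"role": "user", "content": "hello"}],
    stream=True,
    stream_options={"include_usage": True},
)
total_in = total_out = 0
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
    if chunk.usage:  # populated on the final chunk only
        total_in = chunk.usage.prompt_tokens
        total_out = chunk.usage.completion_tokens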

View file

@@ -40,7 +40,7 @@ class Processor(LlmService):
)
self.default_model = model
self.url = url + "v1/"
self.url = url.rstrip('/') + "/v1/"
self.temperature = temperature
self.max_output = max_output
self.openai = OpenAI(
@@ -130,9 +130,13 @@ class Processor(LlmService):
frequency_penalty=0,
presence_penalty=0,
response_format={"type": "text"},
stream=True
stream=True,
stream_options={"include_usage": True}
)
total_input_tokens = 0
total_output_tokens = 0
for chunk in response:
if chunk.choices and chunk.choices[0].delta.content:
yield LlmChunk(
@@ -143,10 +147,15 @@ class Processor(LlmService):
is_final=False
)
# Capture usage from final chunk
if chunk.usage:
total_input_tokens = chunk.usage.prompt_tokens
total_output_tokens = chunk.usage.completion_tokens
yield LlmChunk(
text="",
in_token=None,
out_token=None,
in_token=total_input_tokens,
out_token=total_output_tokens,
model=model_name,
is_final=True
)
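The url.rstrip('/') change above makes the endpoint join tolerant of base URLs configured with or without a trailing slash:

for url in ("http://host:8000", "http://host:8000/"):
    assert url.rstrip('/') + "/v1/" == "http://host:8000/v1/"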

View file

@@ -156,6 +156,9 @@ class Processor(LlmService):
response_format={"type": "text"}
)
total_input_tokens = 0
total_output_tokens = 0
for chunk in stream:
if chunk.data.choices and chunk.data.choices[0].delta.content:
yield LlmChunk(
@@ -166,11 +169,16 @@
is_final=False
)
# Send final chunk
# Capture usage data when available (typically in final chunk)
if chunk.data.usage:
total_input_tokens = chunk.data.usage.prompt_tokens
total_output_tokens = chunk.data.usage.completion_tokens
# Send final chunk with token counts
yield LlmChunk(
text="",
in_token=None,
out_token=None,
in_token=total_input_tokens,
out_token=total_output_tokens,
model=model_name,
is_final=True
)

View file

@@ -153,9 +153,13 @@ class Processor(LlmService):
],
temperature=effective_temperature,
max_tokens=self.max_output,
stream=True # Enable streaming
stream=True,
stream_options={"include_usage": True}
)
total_input_tokens = 0
total_output_tokens = 0
# Stream chunks
for chunk in response:
if chunk.choices and chunk.choices[0].delta.content:
@@ -167,12 +171,16 @@
is_final=False
)
# Note: OpenAI doesn't provide token counts in streaming mode
# Send final chunk without token counts
# Capture usage from final chunk
if chunk.usage:
total_input_tokens = chunk.usage.prompt_tokens
total_output_tokens = chunk.usage.completion_tokens
# Send final chunk with token counts
yield LlmChunk(
text="",
in_token=None,
out_token=None,
in_token=total_input_tokens,
out_token=total_output_tokens,
model=model_name,
is_final=True
)

View file

@@ -83,7 +83,7 @@ class Processor(LlmService):
try:
url = f"{self.base_url}/chat/completions"
url = f"{self.base_url.rstrip('/')}/chat/completions"
async with self.session.post(
url,
@@ -152,10 +152,14 @@
"max_tokens": self.max_output,
"temperature": effective_temperature,
"stream": True,
"stream_options": {"include_usage": True},
}
try:
url = f"{self.base_url}/chat/completions"
url = f"{self.base_url.rstrip('/')}/chat/completions"
total_input_tokens = 0
total_output_tokens = 0
async with self.session.post(
url,
@@ -196,15 +200,21 @@
model=model_name,
is_final=False
)
# Capture usage from final chunk
if 'usage' in chunk_data and chunk_data['usage']:
total_input_tokens = chunk_data['usage'].get('prompt_tokens', 0)
total_output_tokens = chunk_data['usage'].get('completion_tokens', 0)
except json.JSONDecodeError:
logger.warning(f"Failed to parse chunk: {data}")
continue
# Send final chunk
# Send final chunk with token counts
yield LlmChunk(
text="",
in_token=None,
out_token=None,
in_token=total_input_tokens,
out_token=total_output_tokens,
model=model_name,
is_final=True
)

View file

@@ -75,7 +75,7 @@ class Processor(LlmService):
try:
url = f"{self.base_url}/completions"
url = f"{self.base_url.rstrip('/')}/completions"
async with self.session.post(
url,
@@ -135,10 +135,14 @@
"max_tokens": self.max_output,
"temperature": effective_temperature,
"stream": True,
"stream_options": {"include_usage": True},
}
try:
url = f"{self.base_url}/completions"
url = f"{self.base_url.rstrip('/')}/completions"
total_input_tokens = 0
total_output_tokens = 0
async with self.session.post(
url,
@@ -177,15 +181,21 @@
model=model_name,
is_final=False
)
# Capture usage from final chunk
if 'usage' in chunk_data and chunk_data['usage']:
total_input_tokens = chunk_data['usage'].get('prompt_tokens', 0)
total_output_tokens = chunk_data['usage'].get('completion_tokens', 0)
except json.JSONDecodeError:
logger.warning(f"Failed to parse chunk: {data}")
continue
# Send final chunk
# Send final chunk with token counts
yield LlmChunk(
text="",
in_token=None,
out_token=None,
in_token=total_input_tokens,
out_token=total_output_tokens,
model=model_name,
is_final=True
)

View file

@@ -8,13 +8,13 @@ import logging
from .... direct.milvus_doc_embeddings import DocVectors
from .... schema import DocumentEmbeddingsResponse
from .... schema import Error, Value
from .... schema import Error
from .... base import DocumentEmbeddingsQueryService
# Module logger
logger = logging.getLogger(__name__)
default_ident = "de-query"
default_ident = "doc-embeddings-query"
default_store_uri = 'http://localhost:19530'
class Processor(DocumentEmbeddingsQueryService):

View file

@@ -16,7 +16,7 @@ from .... base import DocumentEmbeddingsQueryService
# Module logger
logger = logging.getLogger(__name__)
default_ident = "de-query"
default_ident = "doc-embeddings-query"
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
class Processor(DocumentEmbeddingsQueryService):

View file

@@ -11,13 +11,13 @@ from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
from .... schema import DocumentEmbeddingsResponse
from .... schema import Error, Value
from .... schema import Error
from .... base import DocumentEmbeddingsQueryService
# Module logger
logger = logging.getLogger(__name__)
default_ident = "de-query"
default_ident = "doc-embeddings-query"
default_store_uri = 'http://localhost:6333'

View file

@@ -8,13 +8,13 @@ import logging
from .... direct.milvus_graph_embeddings import EntityVectors
from .... schema import GraphEmbeddingsResponse
from .... schema import Error, Value
from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService
# Module logger
logger = logging.getLogger(__name__)
default_ident = "ge-query"
default_ident = "graph-embeddings-query"
default_store_uri = 'http://localhost:19530'
class Processor(GraphEmbeddingsQueryService):
@@ -33,9 +33,9 @@ class Processor(GraphEmbeddingsQueryService):
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
return Value(value=ent, is_uri=True)
return Term(type=IRI, iri=ent)
else:
return Value(value=ent, is_uri=False)
return Term(type=LITERAL, value=ent)
async def query_graph_embeddings(self, msg):

View file

@@ -12,13 +12,13 @@ from pinecone import Pinecone, ServerlessSpec
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
from .... schema import GraphEmbeddingsResponse
from .... schema import Error, Value
from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService
# Module logger
logger = logging.getLogger(__name__)
default_ident = "ge-query"
default_ident = "graph-embeddings-query"
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
class Processor(GraphEmbeddingsQueryService):
@@ -51,9 +51,9 @@ class Processor(GraphEmbeddingsQueryService):
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
return Value(value=ent, is_uri=True)
return Term(type=IRI, iri=ent)
else:
return Value(value=ent, is_uri=False)
return Term(type=LITERAL, value=ent)
async def query_graph_embeddings(self, msg):

View file

@@ -11,13 +11,13 @@ from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
from .... schema import GraphEmbeddingsResponse
from .... schema import Error, Value
from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService
# Module logger
logger = logging.getLogger(__name__)
default_ident = "ge-query"
default_ident = "graph-embeddings-query"
default_store_uri = 'http://localhost:6333'
@@ -67,9 +67,9 @@ class Processor(GraphEmbeddingsQueryService):
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
return Value(value=ent, is_uri=True)
return Term(type=IRI, iri=ent)
else:
return Value(value=ent, is_uri=False)
return Term(type=LITERAL, value=ent)
async def query_graph_embeddings(self, msg):

View file

@@ -0,0 +1,22 @@
"""
Shared GraphQL utilities for row query services.
This module provides reusable GraphQL components including:
- Filter types (IntFilter, StringFilter, FloatFilter)
- Dynamic schema generation from RowSchema definitions
- Filter parsing utilities
"""
from .types import IntFilter, StringFilter, FloatFilter, SortDirection
from .schema import GraphQLSchemaBuilder
from .filters import parse_filter_key, parse_where_clause
__all__ = [
"IntFilter",
"StringFilter",
"FloatFilter",
"SortDirection",
"GraphQLSchemaBuilder",
"parse_filter_key",
"parse_where_clause",
]

View file

@@ -0,0 +1,104 @@
"""
Filter parsing utilities for GraphQL row queries.
Provides functions to parse GraphQL filter objects into a normalized
format that can be used by different query backends.
"""
import logging
from typing import Dict, Any, Tuple
logger = logging.getLogger(__name__)
def parse_filter_key(filter_key: str) -> Tuple[str, str]:
"""
Parse GraphQL filter key into field name and operator.
Supports common GraphQL filter patterns:
- field_name -> (field_name, "eq")
- field_name_gt -> (field_name, "gt")
- field_name_gte -> (field_name, "gte")
- field_name_lt -> (field_name, "lt")
- field_name_lte -> (field_name, "lte")
- field_name_in -> (field_name, "in")
Args:
filter_key: The filter key string from GraphQL
Returns:
Tuple of (field_name, operator)
"""
if not filter_key:
return ("", "eq")
operators = ["_gte", "_lte", "_gt", "_lt", "_in", "_eq"]
for op_suffix in operators:
if filter_key.endswith(op_suffix):
field_name = filter_key[:-len(op_suffix)]
operator = op_suffix[1:] # Remove the leading underscore
return (field_name, operator)
# Default to equality if no operator suffix found
return (filter_key, "eq")
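A few concrete cases, matching the docstring above; the longer suffixes _gte/_lte are tested before _gt/_lt, so keys are never mis-split:

assert parse_filter_key("age_gte") == ("age", "gte")
assert parse_filter_key("tag_in") == ("tag", "in")
assert parse_filter_key("email") == ("email", "eq")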
def parse_where_clause(where_obj) -> Dict[str, Any]:
"""
Parse the idiomatic nested GraphQL filter structure into a flat dict.
Converts Strawberry filter objects (StringFilter, IntFilter, etc.)
into a dictionary mapping field names with operators to values.
Example:
Input: where_obj with email.eq = "foo@bar.com"
Output: {"email": "foo@bar.com"}
Input: where_obj with age.gt = 21
Output: {"age_gt": 21}
Args:
where_obj: The GraphQL where clause object
Returns:
Dictionary mapping field_operator keys to values
"""
if not where_obj:
return {}
conditions = {}
logger.debug(f"Parsing where clause: {where_obj}")
for field_name, filter_obj in where_obj.__dict__.items():
if filter_obj is None:
continue
logger.debug(f"Processing field {field_name} with filter_obj: {filter_obj}")
if hasattr(filter_obj, '__dict__'):
# This is a filter object (StringFilter, IntFilter, etc.)
for operator, value in filter_obj.__dict__.items():
if value is not None:
logger.debug(f"Found operator {operator} with value {value}")
# Map GraphQL operators to our internal format
if operator == "eq":
conditions[field_name] = value
elif operator in ["gt", "gte", "lt", "lte"]:
conditions[f"{field_name}_{operator}"] = value
elif operator == "in_":
conditions[f"{field_name}_in"] = value
elif operator == "contains":
conditions[f"{field_name}_contains"] = value
elif operator == "startsWith":
conditions[f"{field_name}_startsWith"] = value
elif operator == "endsWith":
conditions[f"{field_name}_endsWith"] = value
elif operator == "not_":
conditions[f"{field_name}_not"] = value
elif operator == "not_in":
conditions[f"{field_name}_not_in"] = value
logger.debug(f"Final parsed conditions: {conditions}")
return conditions
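A quick check of the flattening behaviour; SimpleNamespace stands in for the Strawberry filter inputs, which expose the same __dict__ shape:

from types import SimpleNamespace

where = SimpleNamespace(
    email=SimpleNamespace(eq="foo@bar.com"),
    age=SimpleNamespace(gt=21),
)
assert parse_where_clause(where) == {"email": "foo@bar.com", "age_gt": 21}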

View file

@@ -0,0 +1,251 @@
"""
Dynamic GraphQL schema generation from RowSchema definitions.
Provides a builder class that creates Strawberry GraphQL schemas
from TrustGraph RowSchema definitions, with pluggable query backends.
"""
import logging
from typing import Dict, Any, Optional, List, Callable, Awaitable
import strawberry
from strawberry import Schema
from strawberry.types import Info
from .types import IntFilter, StringFilter, FloatFilter, SortDirection
logger = logging.getLogger(__name__)
# Type alias for query callback function
QueryCallback = Callable[
[str, str, str, Any, Dict[str, Any], int, Optional[str], Optional[SortDirection]],
Awaitable[List[Dict[str, Any]]]
]
class GraphQLSchemaBuilder:
"""
Builds GraphQL schemas from RowSchema definitions.
This class extracts the GraphQL schema generation logic so it can be
reused across different query backends (Cassandra, etc.).
Usage:
builder = GraphQLSchemaBuilder()
# Add schemas
for name, row_schema in schemas.items():
builder.add_schema(name, row_schema)
# Build with a query callback
schema = builder.build(query_callback)
"""
def __init__(self):
self.schemas: Dict[str, Any] = {} # name -> RowSchema
self.graphql_types: Dict[str, type] = {}
self.filter_types: Dict[str, type] = {}
def add_schema(self, name: str, row_schema) -> None:
"""
Add a RowSchema to the builder.
Args:
name: The schema name (used as the GraphQL query field name)
row_schema: The RowSchema object defining fields
"""
self.schemas[name] = row_schema
self.graphql_types[name] = self._create_graphql_type(name, row_schema)
self.filter_types[name] = self._create_filter_type(name, row_schema)
logger.debug(f"Added schema {name} with {len(row_schema.fields)} fields")
def clear(self) -> None:
"""Clear all schemas from the builder."""
self.schemas = {}
self.graphql_types = {}
self.filter_types = {}
def build(self, query_callback: QueryCallback) -> Optional[Schema]:
"""
Build the GraphQL schema with the provided query callback.
The query callback will be invoked when resolving queries, with:
- user: str
- collection: str
- schema_name: str
- row_schema: RowSchema
- filters: Dict[str, Any]
- limit: int
- order_by: Optional[str]
- direction: Optional[SortDirection]
It should return a list of row dictionaries.
Args:
query_callback: Async function to execute queries
Returns:
Strawberry Schema, or None if no schemas are loaded
"""
if not self.schemas:
logger.warning("No schemas loaded, cannot generate GraphQL schema")
return None
# Create the Query class with resolvers
query_dict = {'__annotations__': {}}
for schema_name, row_schema in self.schemas.items():
graphql_type = self.graphql_types[schema_name]
filter_type = self.filter_types[schema_name]
# Create resolver function for this schema
resolver_func = self._make_resolver(
schema_name, row_schema, graphql_type, filter_type, query_callback
)
# Add field to query dictionary
query_dict[schema_name] = strawberry.field(resolver=resolver_func)
query_dict['__annotations__'][schema_name] = List[graphql_type]
# Create the Query class
Query = type('Query', (), query_dict)
Query = strawberry.type(Query)
# Create the schema with auto_camel_case disabled to keep snake_case field names
schema = strawberry.Schema(
query=Query,
config=strawberry.schema.config.StrawberryConfig(auto_camel_case=False)
)
logger.info(f"Generated GraphQL schema with {len(self.schemas)} types")
return schema
def _get_python_type(self, field_type: str):
"""Convert schema field type to Python type for GraphQL."""
type_mapping = {
"string": str,
"integer": int,
"float": float,
"boolean": bool,
"timestamp": str, # Use string for timestamps in GraphQL
"date": str,
"time": str,
"uuid": str
}
return type_mapping.get(field_type, str)
def _create_graphql_type(self, schema_name: str, row_schema) -> type:
"""Create a GraphQL output type from a RowSchema."""
# Create annotations for the GraphQL type
annotations = {}
defaults = {}
for field in row_schema.fields:
python_type = self._get_python_type(field.type)
# Make field optional if not required
if not field.required and not field.primary:
annotations[field.name] = Optional[python_type]
defaults[field.name] = None
else:
annotations[field.name] = python_type
# Create the class dynamically
type_name = f"{schema_name.capitalize()}Type"
graphql_class = type(
type_name,
(),
{
"__annotations__": annotations,
**defaults
}
)
# Apply strawberry decorator
return strawberry.type(graphql_class)
def _create_filter_type(self, schema_name: str, row_schema) -> type:
"""Create a dynamic filter input type for a schema."""
filter_type_name = f"{schema_name.capitalize()}Filter"
# Add __annotations__ and defaults for the fields
annotations = {}
defaults = {}
logger.debug(f"Creating filter type {filter_type_name} for schema {schema_name}")
for field in row_schema.fields:
logger.debug(
f"Field {field.name}: type={field.type}, "
f"indexed={field.indexed}, primary={field.primary}"
)
# Allow filtering on any field
if field.type == "integer":
annotations[field.name] = Optional[IntFilter]
defaults[field.name] = None
elif field.type == "float":
annotations[field.name] = Optional[FloatFilter]
defaults[field.name] = None
elif field.type == "string":
annotations[field.name] = Optional[StringFilter]
defaults[field.name] = None
logger.debug(
f"Filter type {filter_type_name} will have fields: {list(annotations.keys())}"
)
# Create the class dynamically
FilterType = type(
filter_type_name,
(),
{
"__annotations__": annotations,
**defaults
}
)
# Apply strawberry input decorator
FilterType = strawberry.input(FilterType)
return FilterType
def _make_resolver(
self,
schema_name: str,
row_schema,
graphql_type: type,
filter_type: type,
query_callback: QueryCallback
):
"""Create a resolver function for a schema."""
from .filters import parse_where_clause
async def resolver(
info: Info,
where: Optional[filter_type] = None,
order_by: Optional[str] = None,
direction: Optional[SortDirection] = None,
limit: Optional[int] = 100
) -> List[graphql_type]:
# Get context values
user = info.context["user"]
collection = info.context["collection"]
# Parse the where clause
filters = parse_where_clause(where)
# Call the query backend
results = await query_callback(
user, collection, schema_name, row_schema,
filters, limit, order_by, direction
)
# Convert to GraphQL types
graphql_results = []
for row in results:
graphql_obj = graphql_type(**row)
graphql_results.append(graphql_obj)
return graphql_results
return resolver
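A hedged sketch of wiring the builder to a backend; the "people" schema and the canned row are hypothetical, and a real query_callback (such as the Cassandra rows-query service below) would translate the filters into storage queries:

async def stub_query(user, collection, schema_name, row_schema,
                     filters, limit, order_by, direction):
    # A real backend would build a SELECT from `filters`; return a canned row.
    return [{"name": "alice", "age": 30}]

async def demo(people_row_schema):  # people_row_schema: an assumed RowSchema
    builder = GraphQLSchemaBuilder()
    builder.add_schema("people", people_row_schema)
    schema = builder.build(stub_query)
    result = await schema.execute(
        'query { people(where: {age: {gt: 21}}) { name age } }',
        context_value={"user": "u", "collection": "c"},
    )
    print(result.data)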

View file

@@ -0,0 +1,56 @@
"""
GraphQL filter and sort types for row queries.
These types are used to build dynamic GraphQL schemas for querying
structured row data.
"""
from typing import Optional, List
from enum import Enum
import strawberry
@strawberry.input
class IntFilter:
"""Filter type for integer fields."""
eq: Optional[int] = None
gt: Optional[int] = None
gte: Optional[int] = None
lt: Optional[int] = None
lte: Optional[int] = None
in_: Optional[List[int]] = strawberry.field(name="in", default=None)
not_: Optional[int] = strawberry.field(name="not", default=None)
not_in: Optional[List[int]] = None
@strawberry.input
class StringFilter:
"""Filter type for string fields."""
eq: Optional[str] = None
contains: Optional[str] = None
startsWith: Optional[str] = None
endsWith: Optional[str] = None
in_: Optional[List[str]] = strawberry.field(name="in", default=None)
not_: Optional[str] = strawberry.field(name="not", default=None)
not_in: Optional[List[str]] = None
@strawberry.input
class FloatFilter:
"""Filter type for float fields."""
eq: Optional[float] = None
gt: Optional[float] = None
gte: Optional[float] = None
lt: Optional[float] = None
lte: Optional[float] = None
in_: Optional[List[float]] = strawberry.field(name="in", default=None)
not_: Optional[float] = strawberry.field(name="not", default=None)
not_in: Optional[List[float]] = None
@strawberry.enum
class SortDirection(Enum):
"""Sort direction for query results."""
ASC = "asc"
DESC = "desc"

View file

@@ -1,738 +0,0 @@
"""
Objects query service using GraphQL. Input is a GraphQL query with variables.
Output is GraphQL response data with any errors.
"""
import json
import logging
import asyncio
from typing import Dict, Any, Optional, List, Set
from enum import Enum
from dataclasses import dataclass, field
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
import strawberry
from strawberry import Schema
from strawberry.types import Info
from strawberry.scalars import JSON
from strawberry.tools import create_type
from .... schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
from .... schema import Error, RowSchema, Field as SchemaField
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "objects-query"
# GraphQL filter input types
@strawberry.input
class IntFilter:
eq: Optional[int] = None
gt: Optional[int] = None
gte: Optional[int] = None
lt: Optional[int] = None
lte: Optional[int] = None
in_: Optional[List[int]] = strawberry.field(name="in", default=None)
not_: Optional[int] = strawberry.field(name="not", default=None)
not_in: Optional[List[int]] = None
@strawberry.input
class StringFilter:
eq: Optional[str] = None
contains: Optional[str] = None
startsWith: Optional[str] = None
endsWith: Optional[str] = None
in_: Optional[List[str]] = strawberry.field(name="in", default=None)
not_: Optional[str] = strawberry.field(name="not", default=None)
not_in: Optional[List[str]] = None
@strawberry.input
class FloatFilter:
eq: Optional[float] = None
gt: Optional[float] = None
gte: Optional[float] = None
lt: Optional[float] = None
lte: Optional[float] = None
in_: Optional[List[float]] = strawberry.field(name="in", default=None)
not_: Optional[float] = strawberry.field(name="not", default=None)
not_in: Optional[List[float]] = None
class Processor(FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
# Get Cassandra parameters
cassandra_host = params.get("cassandra_host")
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password, keyspace = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password
)
# Store resolved configuration with proper names
self.cassandra_host = hosts # Store as list
self.cassandra_username = username
self.cassandra_password = password
# Config key for schemas
self.config_key = params.get("config_type", "schema")
super(Processor, self).__init__(
**params | {
"id": id,
"config_type": self.config_key,
}
)
self.register_specification(
ConsumerSpec(
name = "request",
schema = ObjectsQueryRequest,
handler = self.on_message
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = ObjectsQueryResponse,
)
)
# Register config handler for schema updates
self.register_config_handler(self.on_schema_config)
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# GraphQL schema
self.graphql_schema: Optional[Schema] = None
# GraphQL types cache
self.graphql_types: Dict[str, type] = {}
# Cassandra session
self.cluster = None
self.session = None
# Known keyspaces and tables
self.known_keyspaces: Set[str] = set()
self.known_tables: Dict[str, Set[str]] = {}
def connect_cassandra(self):
"""Connect to Cassandra cluster"""
if self.session:
return
try:
if self.cassandra_username and self.cassandra_password:
auth_provider = PlainTextAuthProvider(
username=self.cassandra_username,
password=self.cassandra_password
)
self.cluster = Cluster(
contact_points=self.cassandra_host,
auth_provider=auth_provider
)
else:
self.cluster = Cluster(contact_points=self.cassandra_host)
self.session = self.cluster.connect()
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
except Exception as e:
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
raise
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Cassandra compatibility"""
import re
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
if safe_name and not safe_name[0].isalpha():
safe_name = 'o_' + safe_name
return safe_name.lower()
def sanitize_table(self, name: str) -> str:
"""Sanitize table names for Cassandra compatibility"""
import re
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
safe_name = 'o_' + safe_name
return safe_name.lower()
def parse_filter_key(self, filter_key: str) -> tuple[str, str]:
"""Parse GraphQL filter key into field name and operator"""
if not filter_key:
return ("", "eq")
# Support common GraphQL filter patterns:
# field_name -> (field_name, "eq")
# field_name_gt -> (field_name, "gt")
# field_name_gte -> (field_name, "gte")
# field_name_lt -> (field_name, "lt")
# field_name_lte -> (field_name, "lte")
# field_name_in -> (field_name, "in")
operators = ["_gte", "_lte", "_gt", "_lt", "_in", "_eq"]
for op_suffix in operators:
if filter_key.endswith(op_suffix):
field_name = filter_key[:-len(op_suffix)]
operator = op_suffix[1:] # Remove the leading underscore
return (field_name, operator)
# Default to equality if no operator suffix found
return (filter_key, "eq")
async def on_schema_config(self, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
# Clear existing schemas
self.schemas = {}
self.graphql_types = {}
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
field = SchemaField(
name=field_def["name"],
type=field_def["type"],
size=field_def.get("size", 0),
primary=field_def.get("primary_key", False),
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
# Regenerate GraphQL schema
self.generate_graphql_schema()
def get_python_type(self, field_type: str):
"""Convert schema field type to Python type for GraphQL"""
type_mapping = {
"string": str,
"integer": int,
"float": float,
"boolean": bool,
"timestamp": str, # Use string for timestamps in GraphQL
"date": str,
"time": str,
"uuid": str
}
return type_mapping.get(field_type, str)
def create_graphql_type(self, schema_name: str, row_schema: RowSchema) -> type:
"""Create a GraphQL type from a RowSchema"""
# Create annotations for the GraphQL type
annotations = {}
defaults = {}
for field in row_schema.fields:
python_type = self.get_python_type(field.type)
# Make field optional if not required
if not field.required and not field.primary:
annotations[field.name] = Optional[python_type]
defaults[field.name] = None
else:
annotations[field.name] = python_type
# Create the class dynamically
type_name = f"{schema_name.capitalize()}Type"
graphql_class = type(
type_name,
(),
{
"__annotations__": annotations,
**defaults
}
)
# Apply strawberry decorator
return strawberry.type(graphql_class)
def create_filter_type_for_schema(self, schema_name: str, row_schema: RowSchema):
"""Create a dynamic filter input type for a schema"""
# Create the filter type dynamically
filter_type_name = f"{schema_name.capitalize()}Filter"
# Add __annotations__ and defaults for the fields
annotations = {}
defaults = {}
logger.info(f"Creating filter type {filter_type_name} for schema {schema_name}")
for field in row_schema.fields:
logger.info(f"Field {field.name}: type={field.type}, indexed={field.indexed}, primary={field.primary}")
# Allow filtering on any field for now, not just indexed/primary
# if field.indexed or field.primary:
if field.type == "integer":
annotations[field.name] = Optional[IntFilter]
defaults[field.name] = None
logger.info(f"Added IntFilter for {field.name}")
elif field.type == "float":
annotations[field.name] = Optional[FloatFilter]
defaults[field.name] = None
logger.info(f"Added FloatFilter for {field.name}")
elif field.type == "string":
annotations[field.name] = Optional[StringFilter]
defaults[field.name] = None
logger.info(f"Added StringFilter for {field.name}")
logger.info(f"Filter type {filter_type_name} will have fields: {list(annotations.keys())}")
# Create the class dynamically
FilterType = type(
filter_type_name,
(),
{
"__annotations__": annotations,
**defaults
}
)
# Apply strawberry input decorator
FilterType = strawberry.input(FilterType)
return FilterType
def create_sort_direction_enum(self):
"""Create sort direction enum"""
@strawberry.enum
class SortDirection(Enum):
ASC = "asc"
DESC = "desc"
return SortDirection
def parse_idiomatic_where_clause(self, where_obj) -> Dict[str, Any]:
"""Parse the idiomatic nested filter structure"""
if not where_obj:
return {}
conditions = {}
logger.info(f"Parsing where clause: {where_obj}")
for field_name, filter_obj in where_obj.__dict__.items():
if filter_obj is None:
continue
logger.info(f"Processing field {field_name} with filter_obj: {filter_obj}")
if hasattr(filter_obj, '__dict__'):
# This is a filter object (StringFilter, IntFilter, etc.)
for operator, value in filter_obj.__dict__.items():
if value is not None:
logger.info(f"Found operator {operator} with value {value}")
# Map GraphQL operators to our internal format
if operator == "eq":
conditions[field_name] = value
elif operator in ["gt", "gte", "lt", "lte"]:
conditions[f"{field_name}_{operator}"] = value
elif operator == "in_":
conditions[f"{field_name}_in"] = value
elif operator == "contains":
conditions[f"{field_name}_contains"] = value
logger.info(f"Final parsed conditions: {conditions}")
return conditions
def generate_graphql_schema(self):
"""Generate GraphQL schema from loaded schemas using dynamic filter types"""
if not self.schemas:
logger.warning("No schemas loaded, cannot generate GraphQL schema")
self.graphql_schema = None
return
# Create GraphQL types and filter types for each schema
filter_types = {}
sort_direction_enum = self.create_sort_direction_enum()
for schema_name, row_schema in self.schemas.items():
graphql_type = self.create_graphql_type(schema_name, row_schema)
filter_type = self.create_filter_type_for_schema(schema_name, row_schema)
self.graphql_types[schema_name] = graphql_type
filter_types[schema_name] = filter_type
# Create the Query class with resolvers
query_dict = {'__annotations__': {}}
for schema_name, row_schema in self.schemas.items():
graphql_type = self.graphql_types[schema_name]
filter_type = filter_types[schema_name]
# Create resolver function for this schema
def make_resolver(s_name, r_schema, g_type, f_type, sort_enum):
async def resolver(
info: Info,
where: Optional[f_type] = None,
order_by: Optional[str] = None,
direction: Optional[sort_enum] = None,
limit: Optional[int] = 100
) -> List[g_type]:
# Get the processor instance from context
processor = info.context["processor"]
user = info.context["user"]
collection = info.context["collection"]
# Parse the idiomatic where clause
filters = processor.parse_idiomatic_where_clause(where)
# Query Cassandra
results = await processor.query_cassandra(
user, collection, s_name, r_schema,
filters, limit, order_by, direction
)
# Convert to GraphQL types
graphql_results = []
for row in results:
graphql_obj = g_type(**row)
graphql_results.append(graphql_obj)
return graphql_results
return resolver
# Add resolver to query
resolver_name = schema_name
resolver_func = make_resolver(schema_name, row_schema, graphql_type, filter_type, sort_direction_enum)
# Add field to query dictionary
query_dict[resolver_name] = strawberry.field(resolver=resolver_func)
query_dict['__annotations__'][resolver_name] = List[graphql_type]
# Create the Query class
Query = type('Query', (), query_dict)
Query = strawberry.type(Query)
# Create the schema with auto_camel_case disabled to keep snake_case field names
self.graphql_schema = strawberry.Schema(
query=Query,
config=strawberry.schema.config.StrawberryConfig(auto_camel_case=False)
)
logger.info(f"Generated GraphQL schema with {len(self.schemas)} types")
async def query_cassandra(
self,
user: str,
collection: str,
schema_name: str,
row_schema: RowSchema,
filters: Dict[str, Any],
limit: int,
order_by: Optional[str] = None,
direction: Optional[Any] = None
) -> List[Dict[str, Any]]:
"""Execute a query against Cassandra"""
# Connect if needed
self.connect_cassandra()
# Build the query
keyspace = self.sanitize_name(user)
table = self.sanitize_table(schema_name)
# Start with basic SELECT
query = f"SELECT * FROM {keyspace}.{table}"
# Add WHERE clauses
where_clauses = [f"collection = %s"]
params = [collection]
# Add filters for indexed or primary key fields
for filter_key, value in filters.items():
if value is not None:
# Parse field name and operator from filter key
logger.debug(f"Parsing filter key: '{filter_key}' (type: {type(filter_key)})")
result = self.parse_filter_key(filter_key)
logger.debug(f"parse_filter_key returned: {result} (type: {type(result)}, len: {len(result) if hasattr(result, '__len__') else 'N/A'})")
if not result or len(result) != 2:
logger.error(f"parse_filter_key returned invalid result: {result}")
continue # Skip this filter
field_name, operator = result
# Find the field in schema
schema_field = None
for f in row_schema.fields:
if f.name == field_name:
schema_field = f
break
if schema_field:
safe_field = self.sanitize_name(field_name)
# Build WHERE clause based on operator
if operator == "eq":
where_clauses.append(f"{safe_field} = %s")
params.append(value)
elif operator == "gt":
where_clauses.append(f"{safe_field} > %s")
params.append(value)
elif operator == "gte":
where_clauses.append(f"{safe_field} >= %s")
params.append(value)
elif operator == "lt":
where_clauses.append(f"{safe_field} < %s")
params.append(value)
elif operator == "lte":
where_clauses.append(f"{safe_field} <= %s")
params.append(value)
elif operator == "in":
if isinstance(value, list):
placeholders = ",".join(["%s"] * len(value))
where_clauses.append(f"{safe_field} IN ({placeholders})")
params.extend(value)
else:
# Default to equality for unknown operators
where_clauses.append(f"{safe_field} = %s")
params.append(value)
if where_clauses:
query += " WHERE " + " AND ".join(where_clauses)
# Add ORDER BY if requested (will try Cassandra first, then fall back to post-query sort)
cassandra_order_by_added = False
if order_by and direction:
# Validate that order_by field exists in schema
order_field_exists = any(f.name == order_by for f in row_schema.fields)
if order_field_exists:
safe_order_field = self.sanitize_name(order_by)
direction_str = "ASC" if direction.value == "asc" else "DESC"
# Add ORDER BY - if Cassandra rejects it, we'll catch the error during execution
query += f" ORDER BY {safe_order_field} {direction_str}"
# Add limit first (must come before ALLOW FILTERING)
if limit:
query += f" LIMIT {limit}"
# Add ALLOW FILTERING for now (should optimize with proper indexes later)
query += " ALLOW FILTERING"
# Execute query
try:
result = self.session.execute(query, params)
cassandra_order_by_added = True # If we get here, Cassandra handled ORDER BY
except Exception as e:
# If ORDER BY fails, try without it
if order_by and direction and "ORDER BY" in query:
logger.info(f"Cassandra rejected ORDER BY, falling back to post-query sorting: {e}")
# Remove ORDER BY clause and retry
query_parts = query.split(" ORDER BY ")
if len(query_parts) == 2:
query_without_order = query_parts[0] + " LIMIT " + str(limit) + " ALLOW FILTERING" if limit else " ALLOW FILTERING"
result = self.session.execute(query_without_order, params)
cassandra_order_by_added = False
else:
raise
else:
raise
# Convert rows to dicts
results = []
for row in result:
row_dict = {}
for field in row_schema.fields:
safe_field = self.sanitize_name(field.name)
if hasattr(row, safe_field):
value = getattr(row, safe_field)
# Use original field name in result
row_dict[field.name] = value
results.append(row_dict)
# Post-query sorting if Cassandra didn't handle ORDER BY
if order_by and direction and not cassandra_order_by_added:
reverse_order = (direction.value == "desc")
try:
results.sort(key=lambda x: x.get(order_by, 0), reverse=reverse_order)
except Exception as e:
logger.warning(f"Failed to sort results by {order_by}: {e}")
return results
async def execute_graphql_query(
self,
query: str,
variables: Dict[str, Any],
operation_name: Optional[str],
user: str,
collection: str
) -> Dict[str, Any]:
"""Execute a GraphQL query"""
if not self.graphql_schema:
raise RuntimeError("No GraphQL schema available - no schemas loaded")
# Create context for the query
context = {
"processor": self,
"user": user,
"collection": collection
}
# Execute the query
result = await self.graphql_schema.execute(
query,
variable_values=variables,
operation_name=operation_name,
context_value=context
)
# Build response
response = {}
if result.data:
response["data"] = result.data
else:
response["data"] = None
if result.errors:
response["errors"] = [
{
"message": str(error),
"path": getattr(error, "path", []),
"extensions": getattr(error, "extensions", {})
}
for error in result.errors
]
else:
response["errors"] = []
# Add extensions if any
if hasattr(result, "extensions") and result.extensions:
response["extensions"] = result.extensions
return response
async def on_message(self, msg, consumer, flow):
"""Handle incoming query request"""
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
logger.debug(f"Handling objects query request {id}...")
# Execute GraphQL query
result = await self.execute_graphql_query(
query=request.query,
variables=dict(request.variables) if request.variables else {},
operation_name=request.operation_name,
user=request.user,
collection=request.collection
)
# Create response
graphql_errors = []
if "errors" in result and result["errors"]:
for err in result["errors"]:
graphql_error = GraphQLError(
message=err.get("message", ""),
path=err.get("path", []),
extensions=err.get("extensions", {})
)
graphql_errors.append(graphql_error)
response = ObjectsQueryResponse(
error=None,
data=json.dumps(result.get("data")) if result.get("data") else "null",
errors=graphql_errors,
extensions=result.get("extensions", {})
)
logger.debug("Sending objects query response...")
await flow("response").send(response, properties={"id": id})
logger.debug("Objects query request completed")
except Exception as e:
logger.error(f"Exception in objects query service: {e}", exc_info=True)
logger.info("Sending error response...")
response = ObjectsQueryResponse(
error = Error(
type = "objects-query-error",
message = str(e),
),
data = None,
errors = [],
extensions = {}
)
await flow("response").send(response, properties={"id": id})
def close(self):
"""Clean up Cassandra connections"""
if self.cluster:
self.cluster.shutdown()
logger.info("Closed Cassandra connection")
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
add_cassandra_args(parser)
parser.add_argument(
'--config-type',
default='schema',
help='Configuration type prefix for schemas (default: schema)'
)
def run():
"""Entry point for objects-query-graphql-cassandra command"""
Processor.launch(default_ident, __doc__)

View file

@@ -0,0 +1,3 @@
"""
Row embeddings query modules.
"""

View file

@@ -0,0 +1,5 @@
"""
Qdrant row embeddings query service.
"""
from .service import Processor, run, default_ident

View file

@@ -0,0 +1,4 @@
from .service import run
run()

View file

@@ -0,0 +1,209 @@
"""
Row embeddings query service for Qdrant.
Input is query vectors plus user/collection/schema context.
Output is matching row index information (index_name, index_value) for
use in subsequent Cassandra lookups.
"""
import logging
import re
from typing import Optional
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue
from .... schema import (
RowEmbeddingsRequest, RowEmbeddingsResponse,
RowIndexMatch, Error
)
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
# Module logger
logger = logging.getLogger(__name__)
default_ident = "row-embeddings-query"
default_store_uri = 'http://localhost:6333'
class Processor(FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
store_uri = params.get("store_uri", default_store_uri)
api_key = params.get("api_key", None)
super(Processor, self).__init__(
**params | {
"id": id,
"store_uri": store_uri,
"api_key": api_key,
}
)
self.register_specification(
ConsumerSpec(
name="request",
schema=RowEmbeddingsRequest,
handler=self.on_message
)
)
self.register_specification(
ProducerSpec(
name="response",
schema=RowEmbeddingsResponse
)
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Qdrant collection naming"""
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
if safe_name and not safe_name[0].isalpha():
safe_name = 'r_' + safe_name
return safe_name.lower()
def find_collection(self, user: str, collection: str, schema_name: str) -> Optional[str]:
"""Find the Qdrant collection for a given user/collection/schema"""
prefix = (
f"rows_{self.sanitize_name(user)}_"
f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
)
try:
all_collections = self.qdrant.get_collections().collections
matching = [
coll.name for coll in all_collections
if coll.name.startswith(prefix)
]
if matching:
# Return first match (there should typically be only one per dimension)
return matching[0]
except Exception as e:
logger.error(f"Failed to list Qdrant collections: {e}", exc_info=True)
return None
async def query_row_embeddings(self, request: RowEmbeddingsRequest):
"""Execute row embeddings query"""
matches = []
# Find the collection for this user/collection/schema
qdrant_collection = self.find_collection(
request.user, request.collection, request.schema_name
)
if not qdrant_collection:
logger.info(
f"No Qdrant collection found for "
f"{request.user}/{request.collection}/{request.schema_name}"
)
return matches
for vec in request.vectors:
try:
# Build optional filter for index_name
query_filter = None
if request.index_name:
query_filter = Filter(
must=[
FieldCondition(
key="index_name",
match=MatchValue(value=request.index_name)
)
]
)
# Query Qdrant
search_result = self.qdrant.query_points(
collection_name=qdrant_collection,
query=vec,
limit=request.limit,
with_payload=True,
query_filter=query_filter,
).points
# Convert to RowIndexMatch objects
for point in search_result:
payload = point.payload or {}
match = RowIndexMatch(
index_name=payload.get("index_name", ""),
index_value=payload.get("index_value", []),
text=payload.get("text", ""),
score=point.score if hasattr(point, 'score') else 0.0
)
matches.append(match)
except Exception as e:
logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
raise
return matches
async def on_message(self, msg, consumer, flow):
"""Handle incoming query request"""
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
logger.debug(
f"Handling row embeddings query for "
f"{request.user}/{request.collection}/{request.schema_name}..."
)
# Execute query
matches = await self.query_row_embeddings(request)
response = RowEmbeddingsResponse(
error=None,
matches=matches
)
logger.debug(f"Returning {len(matches)} matches")
await flow("response").send(response, properties={"id": id})
except Exception as e:
logger.error(f"Exception in row embeddings query: {e}", exc_info=True)
response = RowEmbeddingsResponse(
error=Error(
type="row-embeddings-query-error",
message=str(e)
),
matches=[]
)
await flow("response").send(response, properties={"id": id})
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant store URI (default: {default_store_uri})'
)
parser.add_argument(
'-k', '--api-key',
default=None,
help='API key for Qdrant (default: None)'
)
def run():
"""Entry point for row-embeddings-query-qdrant command"""
Processor.launch(default_ident, __doc__)

View file

@ -0,0 +1,523 @@
"""
Row query service using GraphQL. Input is a GraphQL query with variables.
Output is GraphQL response data with any errors.
Queries against the unified 'rows' table with schema:
- collection: text
- schema_name: text
- index_name: text
- index_value: frozen<list<text>>
- data: map<text, text>
- source: text
"""
import json
import logging
import re
from typing import Dict, Any, Optional, List, Set
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from .... schema import Error, RowSchema, Field as SchemaField
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from ... graphql import GraphQLSchemaBuilder, SortDirection
# Module logger
logger = logging.getLogger(__name__)
default_ident = "rows-query"
class Processor(FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
# Get Cassandra parameters
cassandra_host = params.get("cassandra_host")
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password, keyspace = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password
)
# Store resolved configuration with proper names
self.cassandra_host = hosts # Store as list
self.cassandra_username = username
self.cassandra_password = password
# Config key for schemas
self.config_key = params.get("config_type", "schema")
super(Processor, self).__init__(
**params | {
"id": id,
"config_type": self.config_key,
}
)
self.register_specification(
ConsumerSpec(
name="request",
schema=RowsQueryRequest,
handler=self.on_message
)
)
self.register_specification(
ProducerSpec(
name="response",
schema=RowsQueryResponse,
)
)
# Register config handler for schema updates
self.register_config_handler(self.on_schema_config)
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# GraphQL schema builder and generated schema
self.schema_builder = GraphQLSchemaBuilder()
self.graphql_schema = None
# Cassandra session
self.cluster = None
self.session = None
# Known keyspaces
self.known_keyspaces: Set[str] = set()
def connect_cassandra(self):
"""Connect to Cassandra cluster"""
if self.session:
return
try:
if self.cassandra_username and self.cassandra_password:
auth_provider = PlainTextAuthProvider(
username=self.cassandra_username,
password=self.cassandra_password
)
self.cluster = Cluster(
contact_points=self.cassandra_host,
auth_provider=auth_provider
)
else:
self.cluster = Cluster(contact_points=self.cassandra_host)
self.session = self.cluster.connect()
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
except Exception as e:
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
raise
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Cassandra compatibility"""
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
if safe_name and not safe_name[0].isalpha():
safe_name = 'r_' + safe_name
return safe_name.lower()
async def on_schema_config(self, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
# Clear existing schemas
self.schemas = {}
self.schema_builder.clear()
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
field = SchemaField(
name=field_def["name"],
type=field_def["type"],
size=field_def.get("size", 0),
primary=field_def.get("primary_key", False),
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
self.schema_builder.add_schema(schema_name, row_schema)
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
# Regenerate GraphQL schema
self.graphql_schema = self.schema_builder.build(self.query_cassandra)
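# Illustrative config entry under the "schema" type, using exactly the keys
# the parser above reads (field values are hypothetical):
#
#   {
#     "name": "customer",
#     "description": "Customer records",
#     "fields": [
#       {"name": "id", "type": "string", "primary_key": true},
#       {"name": "email", "type": "string", "indexed": true},
#       {"name": "age", "type": "integer", "required": false}
#     ]
#   }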
def get_index_names(self, schema: RowSchema) -> List[str]:
"""Get all index names for a schema."""
index_names = []
for field in schema.fields:
if field.primary or field.indexed:
index_names.append(field.name)
return index_names
def find_matching_index(
self,
schema: RowSchema,
filters: Dict[str, Any]
) -> Optional[tuple]:
"""
Find an index that can satisfy the query filters.
Returns (index_name, index_value) if found, None otherwise.
For exact match queries, we need a filter on an indexed field.
"""
index_names = self.get_index_names(schema)
# Look for an exact match filter on an indexed field
for index_name in index_names:
if index_name in filters:
value = filters[index_name]
# Single field index -> single element list
index_value = [str(value)]
return (index_name, index_value)
return None
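# Illustration with hypothetical filters: given indexed fields ["id", "email"],
# filters {"email": "a@b.c", "age_gt": 30} yield ("email", ["a@b.c"]);
# {"age_gt": 30} alone yields None, forcing the scan-and-post-filter path.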
async def query_cassandra(
self,
user: str,
collection: str,
schema_name: str,
row_schema: RowSchema,
filters: Dict[str, Any],
limit: int,
order_by: Optional[str] = None,
direction: Optional[SortDirection] = None
) -> List[Dict[str, Any]]:
"""
Execute a query against the unified Cassandra rows table.
For exact match queries on indexed fields, we can query directly.
For other queries, we need to scan and post-filter.
"""
# Connect if needed
self.connect_cassandra()
safe_keyspace = self.sanitize_name(user)
# Try to find an index that matches the filters
index_match = self.find_matching_index(row_schema, filters)
results = []
if index_match:
# Direct query using index
index_name, index_value = index_match
query = f"""
SELECT data, source FROM {safe_keyspace}.rows
WHERE collection = %s
AND schema_name = %s
AND index_name = %s
AND index_value = %s
"""
params = [collection, schema_name, index_name, index_value]
if limit:
query += f" LIMIT {limit}"
try:
rows = self.session.execute(query, params)
for row in rows:
# Convert data map to dict with proper field names
row_dict = dict(row.data) if row.data else {}
results.append(row_dict)
except Exception as e:
logger.error(f"Failed to query rows: {e}", exc_info=True)
raise
else:
# No direct index match - scan all rows for this schema
# This is less efficient but necessary for non-indexed queries
logger.warning(
f"No index match for filters {filters} - falling back to a scan with post-filtering"
)
# Get all index names for this schema
index_names = self.get_index_names(row_schema)
if not index_names:
logger.warning(f"Schema {schema_name} has no indexes")
return []
# Query using the first index (arbitrary choice for scan)
primary_index = index_names[0]
# We need to scan all values for this index
# This requires ALLOW FILTERING or a different approach
query = f"""
SELECT data, source FROM {safe_keyspace}.rows
WHERE collection = %s
AND schema_name = %s
AND index_name = %s
ALLOW FILTERING
"""
params = [collection, schema_name, primary_index]
try:
rows = self.session.execute(query, params)
for row in rows:
row_dict = dict(row.data) if row.data else {}
# Apply post-filters
if self._matches_filters(row_dict, filters, row_schema):
results.append(row_dict)
if limit and len(results) >= limit:
break
except Exception as e:
logger.error(f"Failed to scan rows: {e}", exc_info=True)
raise
# Post-query sorting if requested
if order_by and results:
reverse_order = bool(direction and direction.value == "desc")
try:
results.sort(
key=lambda x: x.get(order_by, ""),
reverse=reverse_order
)
except Exception as e:
logger.warning(f"Failed to sort results by {order_by}: {e}")
return results
def _matches_filters(
self,
row_dict: Dict[str, Any],
filters: Dict[str, Any],
row_schema: RowSchema
) -> bool:
"""Check if a row matches the given filters."""
for filter_key, filter_value in filters.items():
if filter_value is None:
continue
# Parse filter key for operator
if '_' in filter_key:
parts = filter_key.rsplit('_', 1)
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in']:
field_name = parts[0]
operator = parts[1]
else:
field_name = filter_key
operator = 'eq'
else:
field_name = filter_key
operator = 'eq'
row_value = row_dict.get(field_name)
if row_value is None:
return False
# Convert types for comparison
try:
if operator == 'eq':
if str(row_value) != str(filter_value):
return False
elif operator == 'gt':
if float(row_value) <= float(filter_value):
return False
elif operator == 'gte':
if float(row_value) < float(filter_value):
return False
elif operator == 'lt':
if float(row_value) >= float(filter_value):
return False
elif operator == 'lte':
if float(row_value) > float(filter_value):
return False
elif operator == 'contains':
if str(filter_value) not in str(row_value):
return False
elif operator == 'in':
if str(row_value) not in [str(v) for v in filter_value]:
return False
except (ValueError, TypeError):
return False
return True
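# Sketch of the operator-suffix convention handled above (rows and filters
# hypothetical):
#   row {"age": "42"}   vs filter {"age_gte": 40}         -> True
#   row {"age": "42"}   vs filter {"age_lt": 40}          -> False
#   row {"name": "Ada"} vs filter {"name_contains": "da"} -> True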
async def execute_graphql_query(
self,
query: str,
variables: Dict[str, Any],
operation_name: Optional[str],
user: str,
collection: str
) -> Dict[str, Any]:
"""Execute a GraphQL query"""
if not self.graphql_schema:
raise RuntimeError("No GraphQL schema available - no schemas loaded")
# Create context for the query
context = {
"processor": self,
"user": user,
"collection": collection
}
# Execute the query
result = await self.graphql_schema.execute(
query,
variable_values=variables,
operation_name=operation_name,
context_value=context
)
# Build response
response = {}
if result.data:
response["data"] = result.data
else:
response["data"] = None
if result.errors:
response["errors"] = [
{
"message": str(error),
"path": getattr(error, "path", []),
"extensions": getattr(error, "extensions", {})
}
for error in result.errors
]
else:
response["errors"] = []
# Add extensions if any
if hasattr(result, "extensions") and result.extensions:
response["extensions"] = result.extensions
return response
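# Hypothetical request, assuming the generated schema exposes one query field
# per loaded RowSchema (actual field and argument names depend on
# GraphQLSchemaBuilder):
#   query     = "query($e: String) { customer(email: $e) { id email } }"
#   variables = {"e": "a@b.c"}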
async def on_message(self, msg, consumer, flow):
"""Handle incoming query request"""
# Sender-produced ID; extracted before the try block so the error
# handler below can still send a correlated response
id = msg.properties()["id"]
try:
request = msg.value()
logger.debug(f"Handling rows query request {id}...")
# Execute GraphQL query
result = await self.execute_graphql_query(
query=request.query,
variables=dict(request.variables) if request.variables else {},
operation_name=request.operation_name,
user=request.user,
collection=request.collection
)
# Create response
graphql_errors = []
if "errors" in result and result["errors"]:
for err in result["errors"]:
graphql_error = GraphQLError(
message=err.get("message", ""),
path=err.get("path", []),
extensions=err.get("extensions", {})
)
graphql_errors.append(graphql_error)
response = RowsQueryResponse(
error=None,
data=json.dumps(result.get("data")) if result.get("data") else "null",
errors=graphql_errors,
extensions=result.get("extensions", {})
)
logger.debug("Sending objects query response...")
await flow("response").send(response, properties={"id": id})
logger.debug("Objects query request completed")
except Exception as e:
logger.error(f"Exception in rows query service: {e}", exc_info=True)
logger.info("Sending error response...")
response = RowsQueryResponse(
error=Error(
type="rows-query-error",
message=str(e),
),
data=None,
errors=[],
extensions={}
)
await flow("response").send(response, properties={"id": id})
def close(self):
"""Clean up Cassandra connections"""
if self.cluster:
self.cluster.shutdown()
logger.info("Closed Cassandra connection")
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
add_cassandra_args(parser)
parser.add_argument(
'--config-type',
default='schema',
help='Configuration type prefix for schemas (default: schema)'
)
def run():
"""Entry point for rows-query-cassandra command"""
Processor.launch(default_ident, __doc__)

View file

@ -1,14 +1,16 @@
"""
Triples query service. Input is a (s, p, o) triple, some values may be
null. Output is a list of triples.
Triples query service. Input is an (s, p, o, g) quad pattern; some values may be
null. Output is a list of quads.
"""
import logging
from .... direct.cassandra_kg import KnowledgeGraph
from .... direct.cassandra_kg import (
EntityCentricKnowledgeGraph, GRAPH_WILDCARD, DEFAULT_GRAPH
)
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .... schema import Value, Triple
from .... schema import Term, Triple, IRI, LITERAL
from .... base import TriplesQueryService
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
@ -18,6 +20,56 @@ logger = logging.getLogger(__name__)
default_ident = "triples-query"
def get_term_value(term):
"""Extract the string value from a Term"""
if term is None:
return None
if term.type == IRI:
return term.iri
elif term.type == LITERAL:
return term.value
else:
# For blank nodes or other types, use id or value
return term.id or term.value
def create_term(value, otype=None, dtype=None, lang=None):
"""
Create a Term from a string value, optionally using type metadata.
Args:
value: The string value
otype: Object type - 'u' (URI), 'l' (literal), 't' (triple)
dtype: XSD datatype (for literals)
lang: Language tag (for literals)
If otype is provided, uses it to determine Term type.
Otherwise falls back to URL detection heuristic.
"""
if otype is not None:
if otype == 'u':
return Term(type=IRI, iri=value)
elif otype == 'l':
return Term(
type=LITERAL,
value=value,
datatype=dtype or "",
language=lang or ""
)
elif otype == 't':
# Triple/reification - treat as IRI for now
return Term(type=IRI, iri=value)
else:
# Unknown otype, fall back to heuristic
pass
# Heuristic fallback for backwards compatibility
if value.startswith("http://") or value.startswith("https://"):
return Term(type=IRI, iri=value)
else:
return Term(type=LITERAL, value=value)
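# Examples of the mapping above (values illustrative):
#   create_term("http://example.org/x")               -> Term(type=IRI, iri=...)
#   create_term("42", otype='l', dtype="xsd:integer") -> typed LITERAL
#   create_term("hello")                              -> plain LITERAL (heuristic)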
class Processor(TriplesQueryService):
def __init__(self, **params):
@ -46,12 +98,6 @@ class Processor(TriplesQueryService):
self.cassandra_password = password
self.table = None
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
return Value(value=ent, is_uri=True)
else:
return Value(value=ent, is_uri=False)
async def query_triples(self, query):
try:
@ -59,90 +105,137 @@ class Processor(TriplesQueryService):
user = query.user
if user != self.table:
# Use factory function to select implementation
KGClass = EntityCentricKnowledgeGraph
if self.cassandra_username and self.cassandra_password:
self.tg = KnowledgeGraph(
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=query.user,
username=self.cassandra_username, password=self.cassandra_password
)
else:
self.tg = KnowledgeGraph(
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=query.user,
)
self.table = user
triples = []
# Extract values from query
s_val = get_term_value(query.s)
p_val = get_term_value(query.p)
o_val = get_term_value(query.o)
g_val = query.g # Already a string or None
if query.s is not None:
if query.p is not None:
if query.o is not None:
# Helper to extract object metadata from result row
def get_o_metadata(t):
"""Extract otype/dtype/lang from result row if available"""
otype = getattr(t, 'otype', None)
dtype = getattr(t, 'dtype', None)
lang = getattr(t, 'lang', None)
return otype, dtype, lang
quads = []
# Route to appropriate query method based on which fields are specified
if s_val is not None:
if p_val is not None:
if o_val is not None:
# SPO specified - find matching graphs
resp = self.tg.get_spo(
query.collection, query.s.value, query.p.value, query.o.value,
query.collection, s_val, p_val, o_val, g=g_val,
limit=query.limit
)
triples.append((query.s.value, query.p.value, query.o.value))
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((s_val, p_val, o_val, g, otype, dtype, lang))
else:
# SP specified
resp = self.tg.get_sp(
query.collection, query.s.value, query.p.value,
query.collection, s_val, p_val, g=g_val,
limit=query.limit
)
for t in resp:
triples.append((query.s.value, query.p.value, t.o))
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((s_val, p_val, t.o, g, otype, dtype, lang))
else:
if query.o is not None:
if o_val is not None:
# SO specified
resp = self.tg.get_os(
query.collection, query.o.value, query.s.value,
query.collection, o_val, s_val, g=g_val,
limit=query.limit
)
for t in resp:
triples.append((query.s.value, t.p, query.o.value))
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((s_val, t.p, o_val, g, otype, dtype, lang))
else:
# S only
resp = self.tg.get_s(
query.collection, query.s.value,
query.collection, s_val, g=g_val,
limit=query.limit
)
for t in resp:
triples.append((query.s.value, t.p, t.o))
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((s_val, t.p, t.o, g, otype, dtype, lang))
else:
if query.p is not None:
if query.o is not None:
if p_val is not None:
if o_val is not None:
# PO specified
resp = self.tg.get_po(
query.collection, query.p.value, query.o.value,
query.collection, p_val, o_val, g=g_val,
limit=query.limit
)
for t in resp:
triples.append((t.s, query.p.value, query.o.value))
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((t.s, p_val, o_val, g, otype, dtype, lang))
else:
# P only
resp = self.tg.get_p(
query.collection, query.p.value,
query.collection, p_val, g=g_val,
limit=query.limit
)
for t in resp:
triples.append((t.s, query.p.value, t.o))
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((t.s, p_val, t.o, g, otype, dtype, lang))
else:
if query.o is not None:
if o_val is not None:
# O only
resp = self.tg.get_o(
query.collection, query.o.value,
query.collection, o_val, g=g_val,
limit=query.limit
)
for t in resp:
triples.append((t.s, t.p, query.o.value))
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((t.s, t.p, o_val, g, otype, dtype, lang))
else:
# Nothing specified - get all
resp = self.tg.get_all(
query.collection,
limit=query.limit
)
for t in resp:
triples.append((t.s, t.p, t.o))
# Note: quads_by_collection uses 'd' for graph field
g = t.d if hasattr(t, 'd') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((t.s, t.p, t.o, g, otype, dtype, lang))
# Convert to Triple objects (with g field)
# Use otype/dtype/lang for proper Term reconstruction if available
triples = [
Triple(
s=self.create_value(t[0]),
p=self.create_value(t[1]),
o=self.create_value(t[2])
s=create_term(q[0]),
p=create_term(q[1]),
o=create_term(q[2], otype=q[4], dtype=q[5], lang=q[6]),
g=q[3] if q[3] != DEFAULT_GRAPH else None
)
for t in triples
for q in quads
]
return triples
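# Each quad tuple assembled above is (s, p, o, g, otype, dtype, lang); the
# trailing metadata feeds create_term() so literal objects keep their
# datatype and language tag in the response.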
@ -162,4 +255,3 @@ class Processor(TriplesQueryService):
def run():
Processor.launch(default_ident, __doc__)

View file

@ -10,12 +10,24 @@ import logging
from falkordb import FalkorDB
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .... schema import Value, Triple
from .... schema import Term, Triple, IRI, LITERAL
from .... base import TriplesQueryService
# Module logger
logger = logging.getLogger(__name__)
def get_term_value(term):
"""Extract the string value from a Term"""
if term is None:
return None
if term.type == IRI:
return term.iri
elif term.type == LITERAL:
return term.value
else:
return term.id or term.value
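# For illustration: get_term_value(Term(type=IRI, iri="http://x")) -> "http://x";
# get_term_value(Term(type=LITERAL, value="42")) -> "42"; None passes through.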
default_ident = "triples-query"
default_graph_url = 'falkor://falkordb:6379'
@ -42,9 +54,9 @@ class Processor(TriplesQueryService):
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
return Value(value=ent, is_uri=True)
return Term(type=IRI, iri=ent)
else:
return Value(value=ent, is_uri=False)
return Term(type=LITERAL, value=ent)
async def query_triples(self, query):
@ -63,28 +75,28 @@ class Processor(TriplesQueryService):
"RETURN $src as src "
"LIMIT " + str(query.limit),
params={
"src": query.s.value,
"rel": query.p.value,
"value": query.o.value,
"src": get_term_value(query.s),
"rel": get_term_value(query.p),
"value": get_term_value(query.o),
},
).result_set
for rec in records:
triples.append((query.s.value, query.p.value, query.o.value))
triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o)))
records = self.io.query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) "
"RETURN $src as src "
"LIMIT " + str(query.limit),
params={
"src": query.s.value,
"rel": query.p.value,
"uri": query.o.value,
"src": get_term_value(query.s),
"rel": get_term_value(query.p),
"uri": get_term_value(query.o),
},
).result_set
for rec in records:
triples.append((query.s.value, query.p.value, query.o.value))
triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o)))
else:
@ -95,26 +107,26 @@ class Processor(TriplesQueryService):
"RETURN dest.value as dest "
"LIMIT " + str(query.limit),
params={
"src": query.s.value,
"rel": query.p.value,
"src": get_term_value(query.s),
"rel": get_term_value(query.p),
},
).result_set
for rec in records:
triples.append((query.s.value, query.p.value, rec[0]))
triples.append((get_term_value(query.s), get_term_value(query.p), rec[0]))
records = self.io.query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) "
"RETURN dest.uri as dest "
"LIMIT " + str(query.limit),
params={
"src": query.s.value,
"rel": query.p.value,
"src": get_term_value(query.s),
"rel": get_term_value(query.p),
},
).result_set
for rec in records:
triples.append((query.s.value, query.p.value, rec[0]))
triples.append((get_term_value(query.s), get_term_value(query.p), rec[0]))
else:
@ -127,26 +139,26 @@ class Processor(TriplesQueryService):
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
params={
"src": query.s.value,
"value": query.o.value,
"src": get_term_value(query.s),
"value": get_term_value(query.o),
},
).result_set
for rec in records:
triples.append((query.s.value, rec[0], query.o.value))
triples.append((get_term_value(query.s), rec[0], get_term_value(query.o)))
records = self.io.query(
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) "
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
params={
"src": query.s.value,
"uri": query.o.value,
"src": get_term_value(query.s),
"uri": get_term_value(query.o),
},
).result_set
for rec in records:
triples.append((query.s.value, rec[0], query.o.value))
triples.append((get_term_value(query.s), rec[0], get_term_value(query.o)))
else:
@ -157,24 +169,24 @@ class Processor(TriplesQueryService):
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit),
params={
"src": query.s.value,
"src": get_term_value(query.s),
},
).result_set
for rec in records:
triples.append((query.s.value, rec[0], rec[1]))
triples.append((get_term_value(query.s), rec[0], rec[1]))
records = self.io.query(
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) "
"RETURN rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit),
params={
"src": query.s.value,
"src": get_term_value(query.s),
},
).result_set
for rec in records:
triples.append((query.s.value, rec[0], rec[1]))
triples.append((get_term_value(query.s), rec[0], rec[1]))
else:
@ -190,26 +202,26 @@ class Processor(TriplesQueryService):
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
params={
"uri": query.p.value,
"value": query.o.value,
"uri": get_term_value(query.p),
"value": get_term_value(query.o),
},
).result_set
for rec in records:
triples.append((rec[0], query.p.value, query.o.value))
triples.append((rec[0], get_term_value(query.p), get_term_value(query.o)))
records = self.io.query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $dest}) "
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
params={
"uri": query.p.value,
"dest": query.o.value,
"uri": get_term_value(query.p),
"dest": get_term_value(query.o),
},
).result_set
for rec in records:
triples.append((rec[0], query.p.value, query.o.value))
triples.append((rec[0], get_term_value(query.p), get_term_value(query.o)))
else:
@ -220,24 +232,24 @@ class Processor(TriplesQueryService):
"RETURN src.uri as src, dest.value as dest "
"LIMIT " + str(query.limit),
params={
"uri": query.p.value,
"uri": get_term_value(query.p),
},
).result_set
for rec in records:
triples.append((rec[0], query.p.value, rec[1]))
triples.append((rec[0], get_term_value(query.p), rec[1]))
records = self.io.query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) "
"RETURN src.uri as src, dest.uri as dest "
"LIMIT " + str(query.limit),
params={
"uri": query.p.value,
"uri": get_term_value(query.p),
},
).result_set
for rec in records:
triples.append((rec[0], query.p.value, rec[1]))
triples.append((rec[0], get_term_value(query.p), rec[1]))
else:
@ -250,24 +262,24 @@ class Processor(TriplesQueryService):
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
params={
"value": query.o.value,
"value": get_term_value(query.o),
},
).result_set
for rec in records:
triples.append((rec[0], rec[1], query.o.value))
triples.append((rec[0], rec[1], get_term_value(query.o)))
records = self.io.query(
"MATCH (src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
params={
"uri": query.o.value,
"uri": get_term_value(query.o),
},
).result_set
for rec in records:
triples.append((rec[0], rec[1], query.o.value))
triples.append((rec[0], rec[1], get_term_value(query.o)))
else:

View file

@ -10,12 +10,24 @@ import logging
from neo4j import GraphDatabase
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .... schema import Value, Triple
from .... schema import Term, Triple, IRI, LITERAL
from .... base import TriplesQueryService
# Module logger
logger = logging.getLogger(__name__)
def get_term_value(term):
"""Extract the string value from a Term"""
if term is None:
return None
if term.type == IRI:
return term.iri
elif term.type == LITERAL:
return term.value
else:
return term.id or term.value
default_ident = "triples-query"
default_graph_host = 'bolt://memgraph:7687'
@ -47,9 +59,9 @@ class Processor(TriplesQueryService):
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
return Value(value=ent, is_uri=True)
return Term(type=IRI, iri=ent)
else:
return Value(value=ent, is_uri=False)
return Term(type=LITERAL, value=ent)
async def query_triples(self, query):
@ -73,13 +85,13 @@ class Processor(TriplesQueryService):
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN $src as src "
"LIMIT " + str(query.limit),
src=query.s.value, rel=query.p.value, value=query.o.value,
src=get_term_value(query.s), rel=get_term_value(query.p), value=get_term_value(query.o),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
triples.append((query.s.value, query.p.value, query.o.value))
triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
@ -87,13 +99,13 @@ class Processor(TriplesQueryService):
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"RETURN $src as src "
"LIMIT " + str(query.limit),
src=query.s.value, rel=query.p.value, uri=query.o.value,
src=get_term_value(query.s), rel=get_term_value(query.p), uri=get_term_value(query.o),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
triples.append((query.s.value, query.p.value, query.o.value))
triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o)))
else:
@ -105,14 +117,14 @@ class Processor(TriplesQueryService):
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN dest.value as dest "
"LIMIT " + str(query.limit),
src=query.s.value, rel=query.p.value,
src=get_term_value(query.s), rel=get_term_value(query.p),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((query.s.value, query.p.value, data["dest"]))
triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
@ -120,14 +132,14 @@ class Processor(TriplesQueryService):
"(dest:Node {user: $user, collection: $collection}) "
"RETURN dest.uri as dest "
"LIMIT " + str(query.limit),
src=query.s.value, rel=query.p.value,
src=get_term_value(query.s), rel=get_term_value(query.p),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((query.s.value, query.p.value, data["dest"]))
triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"]))
else:
@ -141,14 +153,14 @@ class Processor(TriplesQueryService):
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
src=query.s.value, value=query.o.value,
src=get_term_value(query.s), value=get_term_value(query.o),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((query.s.value, data["rel"], query.o.value))
triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
@ -156,14 +168,14 @@ class Processor(TriplesQueryService):
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
src=query.s.value, uri=query.o.value,
src=get_term_value(query.s), uri=get_term_value(query.o),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((query.s.value, data["rel"], query.o.value))
triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o)))
else:
@ -175,14 +187,14 @@ class Processor(TriplesQueryService):
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit),
src=query.s.value,
src=get_term_value(query.s),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((query.s.value, data["rel"], data["dest"]))
triples.append((get_term_value(query.s), data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
@ -190,14 +202,14 @@ class Processor(TriplesQueryService):
"(dest:Node {user: $user, collection: $collection}) "
"RETURN rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit),
src=query.s.value,
src=get_term_value(query.s),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((query.s.value, data["rel"], data["dest"]))
triples.append((get_term_value(query.s), data["rel"], data["dest"]))
else:
@ -214,14 +226,14 @@ class Processor(TriplesQueryService):
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
uri=query.p.value, value=query.o.value,
uri=get_term_value(query.p), value=get_term_value(query.o),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], query.p.value, query.o.value))
triples.append((data["src"], get_term_value(query.p), get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
@ -229,14 +241,14 @@ class Processor(TriplesQueryService):
"(dest:Node {uri: $dest, user: $user, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
uri=query.p.value, dest=query.o.value,
uri=get_term_value(query.p), dest=get_term_value(query.o),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], query.p.value, query.o.value))
triples.append((data["src"], get_term_value(query.p), get_term_value(query.o)))
else:
@ -248,14 +260,14 @@ class Processor(TriplesQueryService):
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest "
"LIMIT " + str(query.limit),
uri=query.p.value,
uri=get_term_value(query.p),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], query.p.value, data["dest"]))
triples.append((data["src"], get_term_value(query.p), data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
@ -263,14 +275,14 @@ class Processor(TriplesQueryService):
"(dest:Node {user: $user, collection: $collection}) "
"RETURN src.uri as src, dest.uri as dest "
"LIMIT " + str(query.limit),
uri=query.p.value,
uri=get_term_value(query.p),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], query.p.value, data["dest"]))
triples.append((data["src"], get_term_value(query.p), data["dest"]))
else:
@ -284,14 +296,14 @@ class Processor(TriplesQueryService):
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
value=query.o.value,
value=get_term_value(query.o),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], data["rel"], query.o.value))
triples.append((data["src"], data["rel"], get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
@ -299,14 +311,14 @@ class Processor(TriplesQueryService):
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
uri=query.o.value,
uri=get_term_value(query.o),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], data["rel"], query.o.value))
triples.append((data["src"], data["rel"], get_term_value(query.o)))
else:

View file

@ -10,12 +10,24 @@ import logging
from neo4j import GraphDatabase
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .... schema import Value, Triple
from .... schema import Term, Triple, IRI, LITERAL
from .... base import TriplesQueryService
# Module logger
logger = logging.getLogger(__name__)
def get_term_value(term):
"""Extract the string value from a Term"""
if term is None:
return None
if term.type == IRI:
return term.iri
elif term.type == LITERAL:
return term.value
else:
return term.id or term.value
default_ident = "triples-query"
default_graph_host = 'bolt://neo4j:7687'
@ -47,9 +59,9 @@ class Processor(TriplesQueryService):
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
return Value(value=ent, is_uri=True)
return Term(type=IRI, iri=ent)
else:
return Value(value=ent, is_uri=False)
return Term(type=LITERAL, value=ent)
async def query_triples(self, query):
@ -71,27 +83,29 @@ class Processor(TriplesQueryService):
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN $src as src",
src=query.s.value, rel=query.p.value, value=query.o.value,
"RETURN $src as src "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), value=get_term_value(query.o),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
triples.append((query.s.value, query.p.value, query.o.value))
triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"RETURN $src as src",
src=query.s.value, rel=query.p.value, uri=query.o.value,
"RETURN $src as src "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), uri=get_term_value(query.o),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
triples.append((query.s.value, query.p.value, query.o.value))
triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o)))
else:
@ -101,29 +115,31 @@ class Processor(TriplesQueryService):
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN dest.value as dest",
src=query.s.value, rel=query.p.value,
"RETURN dest.value as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((query.s.value, query.p.value, data["dest"]))
triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN dest.uri as dest",
src=query.s.value, rel=query.p.value,
"RETURN dest.uri as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((query.s.value, query.p.value, data["dest"]))
triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"]))
else:
@ -135,29 +151,31 @@ class Processor(TriplesQueryService):
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN rel.uri as rel",
src=query.s.value, value=query.o.value,
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), value=get_term_value(query.o),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((query.s.value, data["rel"], query.o.value))
triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"RETURN rel.uri as rel",
src=query.s.value, uri=query.o.value,
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), uri=get_term_value(query.o),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((query.s.value, data["rel"], query.o.value))
triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o)))
else:
@ -167,29 +185,31 @@ class Processor(TriplesQueryService):
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest",
src=query.s.value,
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((query.s.value, data["rel"], data["dest"]))
triples.append((get_term_value(query.s), data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN rel.uri as rel, dest.uri as dest",
src=query.s.value,
"RETURN rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((query.s.value, data["rel"], data["dest"]))
triples.append((get_term_value(query.s), data["rel"], data["dest"]))
else:
@ -204,29 +224,31 @@ class Processor(TriplesQueryService):
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN src.uri as src",
uri=query.p.value, value=query.o.value,
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p), value=get_term_value(query.o),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], query.p.value, query.o.value))
triples.append((data["src"], get_term_value(query.p), get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Node {uri: $dest, user: $user, collection: $collection}) "
"RETURN src.uri as src",
uri=query.p.value, dest=query.o.value,
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p), dest=get_term_value(query.o),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], query.p.value, query.o.value))
triples.append((data["src"], get_term_value(query.p), get_term_value(query.o)))
else:
@ -236,29 +258,31 @@ class Processor(TriplesQueryService):
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest",
uri=query.p.value,
"RETURN src.uri as src, dest.value as dest "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], query.p.value, data["dest"]))
triples.append((data["src"], get_term_value(query.p), data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN src.uri as src, dest.uri as dest",
uri=query.p.value,
"RETURN src.uri as src, dest.uri as dest "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], query.p.value, data["dest"]))
triples.append((data["src"], get_term_value(query.p), data["dest"]))
else:
@ -270,29 +294,31 @@ class Processor(TriplesQueryService):
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel",
value=query.o.value,
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
value=get_term_value(query.o),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], data["rel"], query.o.value))
triples.append((data["src"], data["rel"], get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel",
uri=query.o.value,
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
uri=get_term_value(query.o),
user=user, collection=collection,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], data["rel"], query.o.value))
triples.append((data["src"], data["rel"], get_term_value(query.o)))
else:
@ -302,7 +328,8 @@ class Processor(TriplesQueryService):
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest",
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit),
user=user, collection=collection,
database_=self.db,
)
@ -315,7 +342,8 @@ class Processor(TriplesQueryService):
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest",
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit),
user=user, collection=collection,
database_=self.db,
)
@ -327,10 +355,10 @@ class Processor(TriplesQueryService):
triples = [
Triple(
s=self.create_value(t[0]),
p=self.create_value(t[1]),
p=self.create_value(t[1]),
o=self.create_value(t[2])
)
for t in triples
for t in triples[:query.limit]
]
return triples

View file

@ -1,6 +1,6 @@
"""
Structured Query Service - orchestrates natural language question processing.
Takes a question, converts it to GraphQL via nlp-query, executes via objects-query,
Takes a question, converts it to GraphQL via nlp-query, executes via rows-query,
and returns the results.
"""
@ -10,7 +10,7 @@ from typing import Dict, Any, Optional
from ...schema import StructuredQueryRequest, StructuredQueryResponse
from ...schema import QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse
from ...schema import ObjectsQueryRequest, ObjectsQueryResponse
from ...schema import RowsQueryRequest, RowsQueryResponse
from ...schema import Error
from ...base import FlowProcessor, ConsumerSpec, ProducerSpec, RequestResponseSpec
@ -57,13 +57,13 @@ class Processor(FlowProcessor):
)
)
# Client spec for calling objects query service
# Client spec for calling rows query service
self.register_specification(
RequestResponseSpec(
request_name = "objects-query-request",
response_name = "objects-query-response",
request_schema = ObjectsQueryRequest,
response_schema = ObjectsQueryResponse
request_name = "rows-query-request",
response_name = "rows-query-response",
request_schema = RowsQueryRequest,
response_schema = RowsQueryResponse
)
)
@ -112,7 +112,7 @@ class Processor(FlowProcessor):
variables_as_strings[key] = str(value)
# Use user/collection values from request
objects_request = ObjectsQueryRequest(
objects_request = RowsQueryRequest(
user=request.user,
collection=request.collection,
query=nlp_response.graphql_query,
@ -120,12 +120,12 @@ class Processor(FlowProcessor):
operation_name=None
)
objects_response = await flow("objects-query-request").request(objects_request)
objects_response = await flow("rows-query-request").request(objects_request)
if objects_response.error is not None:
raise Exception(f"Objects query service error: {objects_response.error.message}")
# Handle GraphQL errors from the objects query service
raise Exception(f"Rows query service error: {objects_response.error.message}")
# Handle GraphQL errors from the rows query service
graphql_errors = []
if objects_response.errors:
for gql_error in objects_response.errors:

View file

@ -13,7 +13,7 @@ from .... base import ConsumerMetrics, ProducerMetrics
# Module logger
logger = logging.getLogger(__name__)
default_ident = "de-write"
default_ident = "doc-embeddings-write"
default_store_uri = 'http://localhost:19530'
class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):

View file

@ -18,7 +18,7 @@ from .... base import ConsumerMetrics, ProducerMetrics
# Module logger
logger = logging.getLogger(__name__)
default_ident = "de-write"
default_ident = "doc-embeddings-write"
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
default_cloud = "aws"
default_region = "us-east-1"

View file

@ -16,7 +16,7 @@ from .... base import ConsumerMetrics, ProducerMetrics
# Module logger
logger = logging.getLogger(__name__)
default_ident = "de-write"
default_ident = "doc-embeddings-write"
default_store_uri = 'http://localhost:6333'

View file

@ -9,11 +9,25 @@ from .... direct.milvus_graph_embeddings import EntityVectors
from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import IRI, LITERAL
# Module logger
logger = logging.getLogger(__name__)
default_ident = "ge-write"
def get_term_value(term):
"""Extract the string value from a Term"""
if term is None:
return None
if term.type == IRI:
return term.iri
elif term.type == LITERAL:
return term.value
else:
# For blank nodes or other types, use id or value
return term.id or term.value
default_ident = "graph-embeddings-write"
default_store_uri = 'http://localhost:19530'
class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
@ -36,11 +50,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
async def store_graph_embeddings(self, message):
for entity in message.entities:
entity_value = get_term_value(entity.entity)
if entity.entity.value != "" and entity.entity.value is not None:
if entity_value != "" and entity_value is not None:
for vec in entity.vectors:
self.vecstore.insert(
vec, entity.entity.value,
vec, entity_value,
message.metadata.user,
message.metadata.collection
)

View file

@ -14,11 +14,25 @@ import logging
from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import IRI, LITERAL
# Module logger
logger = logging.getLogger(__name__)
default_ident = "ge-write"
def get_term_value(term):
"""Extract the string value from a Term"""
if term is None:
return None
if term.type == IRI:
return term.iri
elif term.type == LITERAL:
return term.value
else:
# For blank nodes or other types, use id or value
return term.id or term.value
default_ident = "graph-embeddings-write"
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
default_cloud = "aws"
default_region = "us-east-1"
@ -100,8 +114,9 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
return
for entity in message.entities:
entity_value = get_term_value(entity.entity)
if entity.entity.value == "" or entity.entity.value is None:
if entity_value == "" or entity_value is None:
continue
for vec in entity.vectors:
@ -126,7 +141,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
{
"id": vector_id,
"values": vec,
"metadata": { "entity": entity.entity.value },
"metadata": { "entity": entity_value },
}
]

View file

@ -12,11 +12,26 @@ import logging
from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import IRI, LITERAL
# Module logger
logger = logging.getLogger(__name__)
default_ident = "ge-write"
def get_term_value(term):
"""Extract the string value from a Term"""
if term is None:
return None
if term.type == IRI:
return term.iri
elif term.type == LITERAL:
return term.value
else:
# For blank nodes or other types, use id or value
return term.id or term.value
default_ident = "graph-embeddings-write"
default_store_uri = 'http://localhost:6333'
@ -51,8 +66,10 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
return
for entity in message.entities:
entity_value = get_term_value(entity.entity)
if entity.entity.value == "" or entity.entity.value is None: return
if entity_value == "" or entity_value is None:
continue
for vec in entity.vectors:
@ -80,7 +97,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
id=str(uuid.uuid4()),
vector=vec,
payload={
"entity": entity.entity.value,
"entity": entity_value,
}
)
]

View file

@ -64,12 +64,14 @@ class Processor(FlowProcessor):
async def on_triples(self, msg, consumer, flow):
v = msg.value()
await self.table_store.add_triples(v)
if v.triples:
await self.table_store.add_triples(v)
async def on_graph_embeddings(self, msg, consumer, flow):
v = msg.value()
await self.table_store.add_graph_embeddings(v)
if v.entities:
await self.table_store.add_graph_embeddings(v)
@staticmethod
def add_args(parser):

View file

@ -1 +0,0 @@
# Objects storage module

View file

@ -1 +0,0 @@
from . write import *

View file

@ -1,3 +0,0 @@
from . write import run
run()

View file

@ -1,538 +0,0 @@
"""
Object writer for Cassandra. Input is ExtractedObject.
Writes structured objects to Cassandra tables based on schema definitions.
"""
import json
import logging
from typing import Dict, Set, Optional, Any
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.cqlengine import connection
from cassandra import ConsistencyLevel
from .... schema import ExtractedObject
from .... schema import RowSchema, Field
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base import CollectionConfigHandler
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "objects-write"
class Processor(CollectionConfigHandler, FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
# Get Cassandra parameters
cassandra_host = params.get("cassandra_host")
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password, keyspace = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password
)
# Store resolved configuration with proper names
self.cassandra_host = hosts # Store as list
self.cassandra_username = username
self.cassandra_password = password
# Config key for schemas
self.config_key = params.get("config_type", "schema")
super(Processor, self).__init__(
**params | {
"id": id,
"config_type": self.config_key,
}
)
self.register_specification(
ConsumerSpec(
name = "input",
schema = ExtractedObject,
handler = self.on_object
)
)
# Register config handlers
self.register_config_handler(self.on_schema_config)
self.register_config_handler(self.on_collection_config)
# Cache of known keyspaces/tables
self.known_keyspaces: Set[str] = set()
self.known_tables: Dict[str, Set[str]] = {} # keyspace -> set of tables
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# Cassandra session
self.cluster = None
self.session = None
def connect_cassandra(self):
"""Connect to Cassandra cluster"""
if self.session:
return
try:
if self.cassandra_username and self.cassandra_password:
auth_provider = PlainTextAuthProvider(
username=self.cassandra_username,
password=self.cassandra_password
)
self.cluster = Cluster(
contact_points=self.cassandra_host,
auth_provider=auth_provider
)
else:
self.cluster = Cluster(contact_points=self.cassandra_host)
self.session = self.cluster.connect()
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
except Exception as e:
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
raise
async def on_schema_config(self, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
# Clear existing schemas
self.schemas = {}
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
field = Field(
name=field_def["name"],
type=field_def["type"],
size=field_def.get("size", 0),
primary=field_def.get("primary_key", False),
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
def ensure_keyspace(self, keyspace: str):
"""Ensure keyspace exists in Cassandra"""
if keyspace in self.known_keyspaces:
return
# Connect if needed
self.connect_cassandra()
# Sanitize keyspace name
safe_keyspace = self.sanitize_name(keyspace)
# Create keyspace if not exists
create_keyspace_cql = f"""
CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
WITH REPLICATION = {{
'class': 'SimpleStrategy',
'replication_factor': 1
}}
"""
try:
self.session.execute(create_keyspace_cql)
self.known_keyspaces.add(keyspace)
self.known_tables[keyspace] = set()
logger.info(f"Ensured keyspace exists: {safe_keyspace}")
except Exception as e:
logger.error(f"Failed to create keyspace {safe_keyspace}: {e}", exc_info=True)
raise
def get_cassandra_type(self, field_type: str, size: int = 0) -> str:
"""Convert schema field type to Cassandra type"""
# Handle None size
if size is None:
size = 0
type_mapping = {
"string": "text",
"integer": "bigint" if size > 4 else "int",
"float": "double" if size > 4 else "float",
"boolean": "boolean",
"timestamp": "timestamp",
"date": "date",
"time": "time",
"uuid": "uuid"
}
return type_mapping.get(field_type, "text")
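# Worked example (illustrative only, values assumed): under the mapping
# above, get_cassandra_type("integer", 8) -> "bigint" (size > 4),
# get_cassandra_type("integer", 0) -> "int",
# get_cassandra_type("float", 8) -> "double", and any unmapped type such
# as get_cassandra_type("geo", 0) falls back to "text".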
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Cassandra compatibility"""
# Replace non-alphanumeric characters with underscore
import re
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
# Ensure it starts with a letter
if safe_name and not safe_name[0].isalpha():
safe_name = 'o_' + safe_name
return safe_name.lower()
def sanitize_table(self, name: str) -> str:
"""Sanitize names for Cassandra compatibility"""
# Replace non-alphanumeric characters with underscore
import re
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
# Ensure it starts with a letter
safe_name = 'o_' + safe_name
return safe_name.lower()
def ensure_table(self, keyspace: str, table_name: str, schema: RowSchema):
"""Ensure table exists with proper structure"""
table_key = f"{keyspace}.{table_name}"
if table_key in self.known_tables.get(keyspace, set()):
return
# Ensure keyspace exists first
self.ensure_keyspace(keyspace)
safe_keyspace = self.sanitize_name(keyspace)
safe_table = self.sanitize_table(table_name)
# Build column definitions
columns = ["collection text"] # Collection is always part of table
primary_key_fields = []
clustering_fields = []
for field in schema.fields:
safe_field_name = self.sanitize_name(field.name)
cassandra_type = self.get_cassandra_type(field.type, field.size)
columns.append(f"{safe_field_name} {cassandra_type}")
if field.primary:
primary_key_fields.append(safe_field_name)
# Build primary key - collection is always first in partition key
if primary_key_fields:
primary_key = f"PRIMARY KEY ((collection, {', '.join(primary_key_fields)}))"
else:
# If no primary key defined, use collection and a synthetic id
columns.append("synthetic_id uuid")
primary_key = "PRIMARY KEY ((collection, synthetic_id))"
# Create table
create_table_cql = f"""
CREATE TABLE IF NOT EXISTS {safe_keyspace}.{safe_table} (
{', '.join(columns)},
{primary_key}
)
"""
try:
self.session.execute(create_table_cql)
if keyspace not in self.known_tables:
self.known_tables[keyspace] = set()
self.known_tables[keyspace].add(table_key)
logger.info(f"Ensured table exists: {safe_keyspace}.{safe_table}")
# Create secondary indexes for indexed fields
for field in schema.fields:
if field.indexed and not field.primary:
safe_field_name = self.sanitize_name(field.name)
index_name = f"{safe_table}_{safe_field_name}_idx"
create_index_cql = f"""
CREATE INDEX IF NOT EXISTS {index_name}
ON {safe_keyspace}.{safe_table} ({safe_field_name})
"""
try:
self.session.execute(create_index_cql)
logger.info(f"Created index: {index_name}")
except Exception as e:
logger.warning(f"Failed to create index {index_name}: {e}")
except Exception as e:
logger.error(f"Failed to create table {safe_keyspace}.{safe_table}: {e}", exc_info=True)
raise
def convert_value(self, value: Any, field_type: str) -> Any:
"""Convert value to appropriate type for Cassandra"""
if value is None:
return None
try:
if field_type == "integer":
return int(value)
elif field_type == "float":
return float(value)
elif field_type == "boolean":
if isinstance(value, str):
return value.lower() in ('true', '1', 'yes')
return bool(value)
elif field_type == "timestamp":
# Handle timestamp conversion if needed
return value
else:
return str(value)
except Exception as e:
logger.warning(f"Failed to convert value {value} to type {field_type}: {e}")
return str(value)
async def on_object(self, msg, consumer, flow):
"""Process incoming ExtractedObject and store in Cassandra"""
obj = msg.value()
logger.info(f"Storing {len(obj.values)} objects for schema {obj.schema_name} from {obj.metadata.id}")
# Validate collection exists before accepting writes
if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
error_msg = (
f"Collection {obj.metadata.collection} does not exist. "
f"Create it first via collection management API."
)
logger.error(error_msg)
raise ValueError(error_msg)
# Get schema definition
schema = self.schemas.get(obj.schema_name)
if not schema:
logger.warning(f"No schema found for {obj.schema_name} - skipping")
return
# Ensure table exists
keyspace = obj.metadata.user
table_name = obj.schema_name
self.ensure_table(keyspace, table_name, schema)
# Prepare data for insertion
safe_keyspace = self.sanitize_name(keyspace)
safe_table = self.sanitize_table(table_name)
# Process each object in the batch
for obj_index, value_map in enumerate(obj.values):
# Build column names and values for this object
columns = ["collection"]
values = [obj.metadata.collection]
placeholders = ["%s"]
# Check if we need a synthetic ID
has_primary_key = any(field.primary for field in schema.fields)
if not has_primary_key:
import uuid
columns.append("synthetic_id")
values.append(uuid.uuid4())
placeholders.append("%s")
# Process fields for this object
skip_object = False
for field in schema.fields:
safe_field_name = self.sanitize_name(field.name)
raw_value = value_map.get(field.name)
# Handle required fields
if field.required and raw_value is None:
logger.warning(f"Required field {field.name} is missing in object {obj_index}")
# Continue anyway - Cassandra doesn't enforce NOT NULL
# Check if primary key field is NULL
if field.primary and raw_value is None:
logger.error(f"Primary key field {field.name} cannot be NULL - skipping object {obj_index}")
skip_object = True
break
# Convert value to appropriate type
converted_value = self.convert_value(raw_value, field.type)
columns.append(safe_field_name)
values.append(converted_value)
placeholders.append("%s")
# Skip this object if primary key validation failed
if skip_object:
continue
# Build and execute insert query for this object
insert_cql = f"""
INSERT INTO {safe_keyspace}.{safe_table} ({', '.join(columns)})
VALUES ({', '.join(placeholders)})
"""
# Debug: Show data being inserted
logger.debug(f"Storing {obj.schema_name} object {obj_index}: {dict(zip(columns, values))}")
if len(columns) != len(values) or len(columns) != len(placeholders):
raise ValueError(f"Mismatch in counts - columns: {len(columns)}, values: {len(values)}, placeholders: {len(placeholders)}")
try:
# Convert to tuple - Cassandra driver requires tuple for parameters
self.session.execute(insert_cql, tuple(values))
except Exception as e:
logger.error(f"Failed to insert object {obj_index}: {e}", exc_info=True)
raise
async def create_collection(self, user: str, collection: str, metadata: dict):
"""Create/verify collection exists in Cassandra object store"""
# Connect if not already connected
self.connect_cassandra()
# Sanitize names for safety
safe_keyspace = self.sanitize_name(user)
# Ensure keyspace exists
if safe_keyspace not in self.known_keyspaces:
self.ensure_keyspace(safe_keyspace)
self.known_keyspaces.add(safe_keyspace)
# For Cassandra objects, collection is just a property in rows
# No need to create separate tables per collection
# Just mark that we've seen this collection
logger.info(f"Collection {collection} ready for user {user} (using keyspace {safe_keyspace})")
async def delete_collection(self, user: str, collection: str):
"""Delete all data for a specific collection using schema information"""
# Connect if not already connected
self.connect_cassandra()
# Sanitize names for safety
safe_keyspace = self.sanitize_name(user)
# Check if keyspace exists
if safe_keyspace not in self.known_keyspaces:
# Query to verify keyspace exists
check_keyspace_cql = """
SELECT keyspace_name FROM system_schema.keyspaces
WHERE keyspace_name = %s
"""
result = self.session.execute(check_keyspace_cql, (safe_keyspace,))
if not result.one():
logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
return
self.known_keyspaces.add(safe_keyspace)
# Iterate over schemas we manage to delete from relevant tables
tables_deleted = 0
for schema_name, schema in self.schemas.items():
safe_table = self.sanitize_table(schema_name)
# Check if table exists
table_key = f"{user}.{schema_name}"
if table_key not in self.known_tables.get(user, set()):
logger.debug(f"Table {safe_keyspace}.{safe_table} not in known tables, skipping")
continue
try:
# Get primary key fields from schema
primary_key_fields = [field for field in schema.fields if field.primary]
if primary_key_fields:
# Schema has primary keys: need to query for partition keys first
# Build SELECT query for primary key fields
pk_field_names = [self.sanitize_name(field.name) for field in primary_key_fields]
select_cql = f"""
SELECT {', '.join(pk_field_names)}
FROM {safe_keyspace}.{safe_table}
WHERE collection = %s
ALLOW FILTERING
"""
rows = self.session.execute(select_cql, (collection,))
# Delete each row using full partition key
for row in rows:
where_clauses = ["collection = %s"]
values = [collection]
for field_name in pk_field_names:
where_clauses.append(f"{field_name} = %s")
values.append(getattr(row, field_name))
delete_cql = f"""
DELETE FROM {safe_keyspace}.{safe_table}
WHERE {' AND '.join(where_clauses)}
"""
self.session.execute(delete_cql, tuple(values))
else:
# No primary keys - use synthetic_id
# Need to query for synthetic_ids first
select_cql = f"""
SELECT synthetic_id
FROM {safe_keyspace}.{safe_table}
WHERE collection = %s
ALLOW FILTERING
"""
rows = self.session.execute(select_cql, (collection,))
# Delete each row using collection and synthetic_id
for row in rows:
delete_cql = f"""
DELETE FROM {safe_keyspace}.{safe_table}
WHERE collection = %s AND synthetic_id = %s
"""
self.session.execute(delete_cql, (collection, row.synthetic_id))
tables_deleted += 1
logger.info(f"Deleted collection {collection} from table {safe_keyspace}.{safe_table}")
except Exception as e:
logger.error(f"Failed to delete from table {safe_keyspace}.{safe_table}: {e}")
raise
logger.info(f"Deleted collection {collection} from {tables_deleted} schema-based tables in keyspace {safe_keyspace}")
def close(self):
"""Clean up Cassandra connections"""
if self.cluster:
self.cluster.shutdown()
logger.info("Closed Cassandra connection")
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
add_cassandra_args(parser)
parser.add_argument(
'--config-type',
default='schema',
help='Configuration type prefix for schemas (default: schema)'
)
def run():
"""Entry point for objects-write-cassandra command"""
Processor.launch(default_ident, __doc__)

View file

@@ -0,0 +1,3 @@
"""
Row embeddings storage modules.
"""

View file

@@ -0,0 +1,5 @@
"""
Qdrant storage for row embeddings.
"""
from .write import Processor, run, default_ident

View file

@@ -0,0 +1,4 @@
from .write import run
run()

View file

@@ -0,0 +1,264 @@
"""
Row embeddings writer for Qdrant (Stage 2).
Consumes RowEmbeddings messages (which already contain computed vectors)
and writes them to Qdrant. One Qdrant collection per (user, collection, schema_name) combination.
This follows the two-stage pattern used by graph-embeddings and document-embeddings:
Stage 1 (row-embeddings): Compute embeddings
Stage 2 (this processor): Store embeddings
Collection naming: rows_{user}_{collection}_{schema_name}_{dimension}
Payload structure:
- index_name: The indexed field(s) this embedding represents
- index_value: The original list of values (for Cassandra lookup)
- text: The text that was embedded (for debugging/display)
"""
import logging
import re
import uuid
from typing import Set, Tuple
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct, Distance, VectorParams
from .... schema import RowEmbeddings
from .... base import FlowProcessor, ConsumerSpec
from .... base import CollectionConfigHandler
# Module logger
logger = logging.getLogger(__name__)
default_ident = "row-embeddings-write"
default_store_uri = 'http://localhost:6333'
class Processor(CollectionConfigHandler, FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
store_uri = params.get("store_uri", default_store_uri)
api_key = params.get("api_key", None)
super(Processor, self).__init__(
**params | {
"id": id,
"store_uri": store_uri,
"api_key": api_key,
}
)
self.register_specification(
ConsumerSpec(
name="input",
schema=RowEmbeddings,
handler=self.on_embeddings
)
)
# Register config handler for collection management
self.register_config_handler(self.on_collection_config)
# Cache of created Qdrant collections
self.created_collections: Set[str] = set()
# Qdrant client
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Qdrant collection naming"""
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
if safe_name and not safe_name[0].isalpha():
safe_name = 'r_' + safe_name
return safe_name.lower()
def get_collection_name(
self, user: str, collection: str, schema_name: str, dimension: int
) -> str:
"""Generate Qdrant collection name"""
safe_user = self.sanitize_name(user)
safe_collection = self.sanitize_name(collection)
safe_schema = self.sanitize_name(schema_name)
return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}"
def ensure_collection(self, collection_name: str, dimension: int):
"""Create Qdrant collection if it doesn't exist"""
if collection_name in self.created_collections:
return
if not self.qdrant.collection_exists(collection_name):
logger.info(
f"Creating Qdrant collection {collection_name} "
f"with dimension {dimension}"
)
self.qdrant.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(
size=dimension,
distance=Distance.COSINE
)
)
self.created_collections.add(collection_name)
async def on_embeddings(self, msg, consumer, flow):
"""Process incoming RowEmbeddings and write to Qdrant"""
embeddings = msg.value()
logger.info(
f"Writing {len(embeddings.embeddings)} embeddings for schema "
f"{embeddings.schema_name} from {embeddings.metadata.id}"
)
# Validate collection exists in config before processing
if not self.collection_exists(
embeddings.metadata.user, embeddings.metadata.collection
):
logger.warning(
f"Collection {embeddings.metadata.collection} for user "
f"{embeddings.metadata.user} does not exist in config. "
f"Dropping message."
)
return
user = embeddings.metadata.user
collection = embeddings.metadata.collection
schema_name = embeddings.schema_name
embeddings_written = 0
qdrant_collection = None
for row_emb in embeddings.embeddings:
if not row_emb.vectors:
logger.warning(
f"No vectors for index {row_emb.index_name} - skipping"
)
continue
# Write every vector (there may be multiple from different models)
for vector in row_emb.vectors:
dimension = len(vector)
# Create/get collection name (lazily on first vector)
if qdrant_collection is None:
qdrant_collection = self.get_collection_name(
user, collection, schema_name, dimension
)
self.ensure_collection(qdrant_collection, dimension)
# Write to Qdrant
self.qdrant.upsert(
collection_name=qdrant_collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vector,
payload={
"index_name": row_emb.index_name,
"index_value": row_emb.index_value,
"text": row_emb.text
}
)
]
)
embeddings_written += 1
logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")
async def create_collection(self, user: str, collection: str, metadata: dict):
"""Collection creation via config push - collections created lazily on first write"""
logger.info(
f"Row embeddings collection create request for {user}/{collection} - "
f"will be created lazily on first write"
)
async def delete_collection(self, user: str, collection: str):
"""Delete all Qdrant collections for a given user/collection"""
try:
prefix = f"rows_{self.sanitize_name(user)}_{self.sanitize_name(collection)}_"
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
matching_collections = [
coll.name for coll in all_collections
if coll.name.startswith(prefix)
]
if not matching_collections:
logger.info(f"No Qdrant collections found matching prefix {prefix}")
else:
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
self.created_collections.discard(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(
f"Deleted {len(matching_collections)} collection(s) "
f"for {user}/{collection}"
)
except Exception as e:
logger.error(
f"Failed to delete collection {user}/{collection}: {e}",
exc_info=True
)
raise
async def delete_collection_schema(
self, user: str, collection: str, schema_name: str
):
"""Delete Qdrant collection for a specific user/collection/schema"""
try:
prefix = (
f"rows_{self.sanitize_name(user)}_"
f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
)
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
matching_collections = [
coll.name for coll in all_collections
if coll.name.startswith(prefix)
]
if not matching_collections:
logger.info(f"No Qdrant collections found matching prefix {prefix}")
else:
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
self.created_collections.discard(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
except Exception as e:
logger.error(
f"Failed to delete collection {user}/{collection}/{schema_name}: {e}",
exc_info=True
)
raise
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant URI (default: {default_store_uri})'
)
parser.add_argument(
'-k', '--api-key',
default=None,
help='Qdrant API key (default: None)'
)
def run():
"""Entry point for row-embeddings-write-qdrant command"""
Processor.launch(default_ident, __doc__)

View file

@@ -1,46 +1,49 @@
"""
Graph writer. Input is graph edge. Writes edges to Cassandra graph.
Row writer for Cassandra. Input is ExtractedObject.
Writes structured rows to a unified Cassandra table with multi-index support.
Uses a single 'rows' table with the schema:
- collection: text
- schema_name: text
- index_name: text
- index_value: frozen<list<text>>
- data: map<text, text>
- source: text
Each row is written multiple times - once per indexed field defined in the schema.
"""
raise RuntimeError("This code is no longer in use")
import pulsar
import base64
import os
import argparse
import time
import json
import logging
import re
from typing import Dict, Set, Optional, Any, List, Tuple
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from ssl import SSLContext, PROTOCOL_TLSv1_2
from .... schema import Rows
from .... log_level import LogLevel
from .... base import Consumer
from .... schema import ExtractedObject
from .... schema import RowSchema, Field
from .... base import FlowProcessor, ConsumerSpec
from .... base import CollectionConfigHandler
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
# Module logger
logger = logging.getLogger(__name__)
module = "rows-write"
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
default_ident = "rows-write"
default_input_queue = "rows-store" # Default queue name
default_subscriber = module
class Processor(Consumer):
class Processor(CollectionConfigHandler, FlowProcessor):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
id = params.get("id", default_ident)
# Get Cassandra parameters
cassandra_host = params.get("cassandra_host")
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password, keyspace = resolve_cassandra_config(
host=cassandra_host,
@@ -48,99 +51,549 @@ class Processor(Consumer):
password=cassandra_password
)
# Store resolved configuration with proper names
self.cassandra_host = hosts # Store as list
self.cassandra_username = username
self.cassandra_password = password
# Config key for schemas
self.config_key = params.get("config_type", "schema")
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": Rows,
"cassandra_host": ','.join(hosts),
"cassandra_username": username,
"cassandra_password": password,
"id": id,
"config_type": self.config_key,
}
)
if username and password:
auth_provider = PlainTextAuthProvider(username=username, password=password)
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
else:
self.cluster = Cluster(hosts)
self.session = self.cluster.connect()
self.tables = set()
self.register_specification(
ConsumerSpec(
name="input",
schema=ExtractedObject,
handler=self.on_object
)
)
self.session.execute("""
create keyspace if not exists trustgraph
with replication = {
'class' : 'SimpleStrategy',
'replication_factor' : 1
};
""");
# Register config handlers
self.register_config_handler(self.on_schema_config)
self.register_config_handler(self.on_collection_config)
self.session.execute("use trustgraph");
# Cache of known keyspaces and whether tables exist
self.known_keyspaces: Set[str] = set()
self.tables_initialized: Set[str] = set() # keyspaces with rows/row_partitions tables
async def handle(self, msg):
# Cache of registered (collection, schema_name) pairs
self.registered_partitions: Set[Tuple[str, str]] = set()
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# Cassandra session
self.cluster = None
self.session = None
def connect_cassandra(self):
"""Connect to Cassandra cluster"""
if self.session:
return
try:
v = msg.value()
name = v.row_schema.name
if name not in self.tables:
# FIXME: SQL injection?
pkey = []
stmt = "create table if not exists " + name + " ( "
for field in v.row_schema.fields:
stmt += field.name + " text, "
if field.primary:
pkey.append(field.name)
stmt += "PRIMARY KEY (" + ", ".join(pkey) + "));"
self.session.execute(stmt)
self.tables.add(name);
for row in v.rows:
field_names = []
values = []
for field in v.row_schema.fields:
field_names.append(field.name)
values.append(row[field.name])
# FIXME: SQL injection?
stmt = (
"insert into " + name + " (" + ", ".join(field_names) +
") values (" + ",".join(["%s"] * len(values)) + ")"
if self.cassandra_username and self.cassandra_password:
auth_provider = PlainTextAuthProvider(
username=self.cassandra_username,
password=self.cassandra_password
)
self.cluster = Cluster(
contact_points=self.cassandra_host,
auth_provider=auth_provider
)
else:
self.cluster = Cluster(contact_points=self.cassandra_host)
self.session.execute(stmt, values)
self.session = self.cluster.connect()
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
except Exception as e:
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
raise
logger.error(f"Exception: {str(e)}", exc_info=True)
async def on_schema_config(self, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
# If there's an error make sure to do table creation etc.
self.tables.remove(name)
# Track which schemas changed so we can clear partition cache
old_schema_names = set(self.schemas.keys())
raise e
# Clear existing schemas
self.schemas = {}
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
field = Field(
name=field_def["name"],
type=field_def["type"],
size=field_def.get("size", 0),
primary=field_def.get("primary_key", False),
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
# Clear partition cache for schemas that changed
# This ensures next write will re-register partitions
new_schema_names = set(self.schemas.keys())
changed_schemas = old_schema_names.symmetric_difference(new_schema_names)
if changed_schemas:
self.registered_partitions = {
(col, sch) for col, sch in self.registered_partitions
if sch not in changed_schemas
}
logger.info(f"Cleared partition cache for changed schemas: {changed_schemas}")
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Cassandra compatibility"""
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
# Ensure it starts with a letter
if safe_name and not safe_name[0].isalpha():
safe_name = 'r_' + safe_name
return safe_name.lower()
def ensure_keyspace(self, keyspace: str):
"""Ensure keyspace exists in Cassandra"""
if keyspace in self.known_keyspaces:
return
# Connect if needed
self.connect_cassandra()
# Sanitize keyspace name
safe_keyspace = self.sanitize_name(keyspace)
# Create keyspace if not exists
create_keyspace_cql = f"""
CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
WITH REPLICATION = {{
'class': 'SimpleStrategy',
'replication_factor': 1
}}
"""
try:
self.session.execute(create_keyspace_cql)
self.known_keyspaces.add(keyspace)
logger.info(f"Ensured keyspace exists: {safe_keyspace}")
except Exception as e:
logger.error(f"Failed to create keyspace {safe_keyspace}: {e}", exc_info=True)
raise
def ensure_tables(self, keyspace: str):
"""Ensure unified rows and row_partitions tables exist"""
if keyspace in self.tables_initialized:
return
# Ensure keyspace exists first
self.ensure_keyspace(keyspace)
safe_keyspace = self.sanitize_name(keyspace)
# Create unified rows table
create_rows_cql = f"""
CREATE TABLE IF NOT EXISTS {safe_keyspace}.rows (
collection text,
schema_name text,
index_name text,
index_value frozen<list<text>>,
data map<text, text>,
source text,
PRIMARY KEY ((collection, schema_name, index_name), index_value)
)
"""
# Create row_partitions tracking table
create_partitions_cql = f"""
CREATE TABLE IF NOT EXISTS {safe_keyspace}.row_partitions (
collection text,
schema_name text,
index_name text,
PRIMARY KEY ((collection), schema_name, index_name)
)
"""
try:
self.session.execute(create_rows_cql)
logger.info(f"Ensured rows table exists: {safe_keyspace}.rows")
self.session.execute(create_partitions_cql)
logger.info(f"Ensured row_partitions table exists: {safe_keyspace}.row_partitions")
self.tables_initialized.add(keyspace)
except Exception as e:
logger.error(f"Failed to create tables in {safe_keyspace}: {e}", exc_info=True)
raise
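# Orientation sketch (hypothetical data): a row {"id": "42", "email": "a@b.c"}
# for schema "customer" in collection "research", with "id" primary and
# "email" indexed, lands in rows twice:
#   ('research', 'customer', 'id',    ['42'],    {...data...}, source)
#   ('research', 'customer', 'email', ['a@b.c'], {...data...}, source)
# while row_partitions gains ('research', 'customer', 'id') and
# ('research', 'customer', 'email') so deletes can enumerate partitions.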
def get_index_names(self, schema: RowSchema) -> List[str]:
"""
Get all index names for a schema.
Returns list of index_name strings (single field names or comma-joined composites).
"""
index_names = []
for field in schema.fields:
# Primary key fields are treated as indexes
if field.primary:
index_names.append(field.name)
# Indexed fields
elif field.indexed:
index_names.append(field.name)
# TODO: Support composite indexes in the future
# For now, each indexed field is a single-field index
return index_names
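# Example (assumed schema): fields id (primary), email (indexed) and
# notes (neither) give index_names == ["id", "email"]; notes is not
# separately queryable.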
def register_partitions(self, keyspace: str, collection: str, schema_name: str):
"""
Register partition entries for a (collection, schema_name) pair.
Called once on first row for each pair.
"""
cache_key = (collection, schema_name)
if cache_key in self.registered_partitions:
return
schema = self.schemas.get(schema_name)
if not schema:
logger.warning(f"Cannot register partitions - schema {schema_name} not found")
return
safe_keyspace = self.sanitize_name(keyspace)
index_names = self.get_index_names(schema)
# Insert partition entries for each index
insert_cql = f"""
INSERT INTO {safe_keyspace}.row_partitions (collection, schema_name, index_name)
VALUES (%s, %s, %s)
"""
for index_name in index_names:
try:
self.session.execute(insert_cql, (collection, schema_name, index_name))
except Exception as e:
logger.warning(f"Failed to register partition {collection}/{schema_name}/{index_name}: {e}")
self.registered_partitions.add(cache_key)
logger.info(f"Registered partitions for {collection}/{schema_name}: {index_names}")
def build_index_value(self, value_map: Dict[str, str], index_name: str) -> List[str]:
"""
Build the index_value list for a given index.
For single-field indexes, returns a single-element list.
For composite indexes (comma-separated), returns multiple elements.
"""
field_names = [f.strip() for f in index_name.split(',')]
values = []
for field_name in field_names:
value = value_map.get(field_name)
# Convert to string for storage
values.append(str(value) if value is not None else "")
return values
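# Hedged example (values invented): for index_name "first_name, last_name"
# and value_map {"first_name": "Ada", "last_name": "Lovelace"}, this
# returns ["Ada", "Lovelace"]; a field absent from value_map contributes "".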
async def on_object(self, msg, consumer, flow):
"""Process incoming ExtractedObject and store in Cassandra"""
obj = msg.value()
logger.info(
f"Storing {len(obj.values)} rows for schema {obj.schema_name} "
f"from {obj.metadata.id}"
)
# Validate collection exists before accepting writes
if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
error_msg = (
f"Collection {obj.metadata.collection} does not exist. "
f"Create it first via collection management API."
)
logger.error(error_msg)
raise ValueError(error_msg)
# Get schema definition
schema = self.schemas.get(obj.schema_name)
if not schema:
logger.warning(f"No schema found for {obj.schema_name} - skipping")
return
keyspace = obj.metadata.user
collection = obj.metadata.collection
schema_name = obj.schema_name
source = getattr(obj.metadata, 'source', '') or ''
# Ensure tables exist
self.ensure_tables(keyspace)
# Register partitions if first time seeing this (collection, schema_name)
self.register_partitions(keyspace, collection, schema_name)
safe_keyspace = self.sanitize_name(keyspace)
# Get all index names for this schema
index_names = self.get_index_names(schema)
if not index_names:
logger.warning(f"Schema {schema_name} has no indexed fields - rows won't be queryable")
return
# Prepare insert statement
insert_cql = f"""
INSERT INTO {safe_keyspace}.rows
(collection, schema_name, index_name, index_value, data, source)
VALUES (%s, %s, %s, %s, %s, %s)
"""
# Process each row in the batch
rows_written = 0
for row_index, value_map in enumerate(obj.values):
# Convert all values to strings for the data map
data_map = {}
for field in schema.fields:
raw_value = value_map.get(field.name)
if raw_value is not None:
data_map[field.name] = str(raw_value)
# Write one copy per index
for index_name in index_names:
index_value = self.build_index_value(value_map, index_name)
# Skip if index value is empty/null
if not index_value or all(v == "" for v in index_value):
logger.debug(
f"Skipping index {index_name} for row {row_index} - "
f"empty index value"
)
continue
try:
self.session.execute(
insert_cql,
(collection, schema_name, index_name, index_value, data_map, source)
)
rows_written += 1
except Exception as e:
logger.error(
f"Failed to insert row {row_index} for index {index_name}: {e}",
exc_info=True
)
raise
logger.info(
f"Wrote {rows_written} index entries for {len(obj.values)} rows "
f"({len(index_names)} indexes per row)"
)
async def create_collection(self, user: str, collection: str, metadata: dict):
"""Create/verify collection exists in Cassandra row store"""
# Connect if not already connected
self.connect_cassandra()
# Ensure tables exist
self.ensure_tables(user)
logger.info(f"Collection {collection} ready for user {user}")
async def delete_collection(self, user: str, collection: str):
"""Delete all data for a specific collection using partition tracking"""
# Connect if not already connected
self.connect_cassandra()
safe_keyspace = self.sanitize_name(user)
# Check if keyspace exists
if user not in self.known_keyspaces:
check_keyspace_cql = """
SELECT keyspace_name FROM system_schema.keyspaces
WHERE keyspace_name = %s
"""
result = self.session.execute(check_keyspace_cql, (safe_keyspace,))
if not result.one():
logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
return
self.known_keyspaces.add(user)
# Discover all partitions for this collection
select_partitions_cql = f"""
SELECT schema_name, index_name FROM {safe_keyspace}.row_partitions
WHERE collection = %s
"""
try:
partitions = self.session.execute(select_partitions_cql, (collection,))
partition_list = list(partitions)
except Exception as e:
logger.error(f"Failed to query partitions for collection {collection}: {e}")
raise
# Delete each partition from rows table
delete_rows_cql = f"""
DELETE FROM {safe_keyspace}.rows
WHERE collection = %s AND schema_name = %s AND index_name = %s
"""
partitions_deleted = 0
for partition in partition_list:
try:
self.session.execute(
delete_rows_cql,
(collection, partition.schema_name, partition.index_name)
)
partitions_deleted += 1
except Exception as e:
logger.error(
f"Failed to delete partition {collection}/{partition.schema_name}/"
f"{partition.index_name}: {e}"
)
raise
# Clean up row_partitions entries
delete_partitions_cql = f"""
DELETE FROM {safe_keyspace}.row_partitions
WHERE collection = %s
"""
try:
self.session.execute(delete_partitions_cql, (collection,))
except Exception as e:
logger.error(f"Failed to clean up row_partitions for {collection}: {e}")
raise
# Clear from local cache
self.registered_partitions = {
(col, sch) for col, sch in self.registered_partitions
if col != collection
}
logger.info(
f"Deleted collection {collection}: {partitions_deleted} partitions "
f"from keyspace {safe_keyspace}"
)
async def delete_collection_schema(self, user: str, collection: str, schema_name: str):
"""Delete all data for a specific collection + schema combination"""
# Connect if not already connected
self.connect_cassandra()
safe_keyspace = self.sanitize_name(user)
# Discover partitions for this collection + schema
select_partitions_cql = f"""
SELECT index_name FROM {safe_keyspace}.row_partitions
WHERE collection = %s AND schema_name = %s
"""
try:
partitions = self.session.execute(select_partitions_cql, (collection, schema_name))
partition_list = list(partitions)
except Exception as e:
logger.error(
f"Failed to query partitions for {collection}/{schema_name}: {e}"
)
raise
# Delete each partition from rows table
delete_rows_cql = f"""
DELETE FROM {safe_keyspace}.rows
WHERE collection = %s AND schema_name = %s AND index_name = %s
"""
partitions_deleted = 0
for partition in partition_list:
try:
self.session.execute(
delete_rows_cql,
(collection, schema_name, partition.index_name)
)
partitions_deleted += 1
except Exception as e:
logger.error(
f"Failed to delete partition {collection}/{schema_name}/"
f"{partition.index_name}: {e}"
)
raise
# Clean up row_partitions entries for this schema
delete_partitions_cql = f"""
DELETE FROM {safe_keyspace}.row_partitions
WHERE collection = %s AND schema_name = %s
"""
try:
self.session.execute(delete_partitions_cql, (collection, schema_name))
except Exception as e:
logger.error(
f"Failed to clean up row_partitions for {collection}/{schema_name}: {e}"
)
raise
# Clear from local cache
self.registered_partitions.discard((collection, schema_name))
logger.info(
f"Deleted {collection}/{schema_name}: {partitions_deleted} partitions "
f"from keyspace {safe_keyspace}"
)
def close(self):
"""Clean up Cassandra connections"""
if self.cluster:
self.cluster.shutdown()
logger.info("Closed Cassandra connection")
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
FlowProcessor.add_args(parser)
add_cassandra_args(parser)
parser.add_argument(
'--config-type',
default='schema',
help='Configuration type prefix for schemas (default: schema)'
)
def run():
Processor.launch(module, __doc__)
"""Entry point for rows-write-cassandra command"""
Processor.launch(default_ident, __doc__)

View file

@@ -10,11 +10,14 @@ import argparse
import time
import logging
from .... direct.cassandra_kg import KnowledgeGraph
from .... direct.cassandra_kg import (
EntityCentricKnowledgeGraph, DEFAULT_GRAPH
)
from .... base import TriplesStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from .... schema import IRI, LITERAL, BLANK, TRIPLE
# Module logger
logger = logging.getLogger(__name__)
@@ -22,6 +25,59 @@ logger = logging.getLogger(__name__)
default_ident = "triples-write"
def get_term_value(term):
"""Extract the string value from a Term"""
if term is None:
return None
if term.type == IRI:
return term.iri
elif term.type == LITERAL:
return term.value
else:
# For blank nodes or other types, use id or value
return term.id or term.value
def get_term_otype(term):
"""
Get object type code from a Term for entity-centric storage.
Maps Term.type to otype:
- IRI ("i") "u" (URI)
- BLANK ("b") "u" (treated as URI)
- LITERAL ("l") "l" (Literal)
- TRIPLE ("t") "t" (Triple/reification)
"""
if term is None:
return "u"
if term.type == IRI or term.type == BLANK:
return "u"
elif term.type == LITERAL:
return "l"
elif term.type == TRIPLE:
return "t"
else:
return "u"
def get_term_dtype(term):
"""Extract datatype from a Term (for literals)"""
if term is None:
return ""
if term.type == LITERAL:
return term.datatype or ""
return ""
def get_term_lang(term):
"""Extract language tag from a Term (for literals)"""
if term is None:
return ""
if term.type == LITERAL:
return term.language or ""
return ""
class Processor(CollectionConfigHandler, TriplesStoreService):
def __init__(self, **params):
@@ -64,15 +120,18 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
self.tg = None
# Use factory function to select implementation
KGClass = EntityCentricKnowledgeGraph
try:
if self.cassandra_username and self.cassandra_password:
self.tg = KnowledgeGraph(
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=message.metadata.user,
username=self.cassandra_username, password=self.cassandra_password
)
else:
self.tg = KnowledgeGraph(
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=message.metadata.user,
)
@@ -84,11 +143,27 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
self.table = user
for t in message.triples:
# Extract values from Term objects
s_val = get_term_value(t.s)
p_val = get_term_value(t.p)
o_val = get_term_value(t.o)
# t.g is None for default graph, or a graph IRI
g_val = t.g if t.g is not None else DEFAULT_GRAPH
# Extract object type metadata for entity-centric storage
otype = get_term_otype(t.o)
dtype = get_term_dtype(t.o)
lang = get_term_lang(t.o)
self.tg.insert(
message.metadata.collection,
t.s.value,
t.p.value,
t.o.value
s_val,
p_val,
o_val,
g=g_val,
otype=otype,
dtype=dtype,
lang=lang
)
async def create_collection(self, user: str, collection: str, metadata: dict):
@@ -98,16 +173,19 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
if self.table is None or self.table != user:
self.tg = None
# Use factory function to select implementation
KGClass = EntityCentricKnowledgeGraph
try:
if self.cassandra_username and self.cassandra_password:
self.tg = KnowledgeGraph(
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
username=self.cassandra_username,
password=self.cassandra_password
)
else:
self.tg = KnowledgeGraph(
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
)
@@ -137,16 +215,19 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
if self.table is None or self.table != user:
self.tg = None
# Use factory function to select implementation
KGClass = EntityCentricKnowledgeGraph
try:
if self.cassandra_username and self.cassandra_password:
self.tg = KnowledgeGraph(
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
username=self.cassandra_username,
password=self.cassandra_password
)
else:
self.tg = KnowledgeGraph(
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
)

View file

@@ -15,12 +15,27 @@ from falkordb import FalkorDB
from .... base import TriplesStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import IRI, LITERAL
# Module logger
logger = logging.getLogger(__name__)
default_ident = "triples-write"
def get_term_value(term):
"""Extract the string value from a Term"""
if term is None:
return None
if term.type == IRI:
return term.iri
elif term.type == LITERAL:
return term.value
else:
# For blank nodes or other types, use id or value
return term.id or term.value
default_graph_url = 'falkor://falkordb:6379'
default_database = 'falkordb'
@@ -164,14 +179,18 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
for t in message.triples:
self.create_node(t.s.value, user, collection)
s_val = get_term_value(t.s)
p_val = get_term_value(t.p)
o_val = get_term_value(t.o)
if t.o.is_uri:
self.create_node(t.o.value, user, collection)
self.relate_node(t.s.value, t.p.value, t.o.value, user, collection)
self.create_node(s_val, user, collection)
if t.o.type == IRI:
self.create_node(o_val, user, collection)
self.relate_node(s_val, p_val, o_val, user, collection)
else:
self.create_literal(t.o.value, user, collection)
self.relate_literal(t.s.value, t.p.value, t.o.value, user, collection)
self.create_literal(o_val, user, collection)
self.relate_literal(s_val, p_val, o_val, user, collection)
@staticmethod
def add_args(parser):

View file

@@ -15,12 +15,27 @@ from neo4j import GraphDatabase
from .... base import TriplesStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import IRI, LITERAL
# Module logger
logger = logging.getLogger(__name__)
default_ident = "triples-write"
def get_term_value(term):
"""Extract the string value from a Term"""
if term is None:
return None
if term.type == IRI:
return term.iri
elif term.type == LITERAL:
return term.value
else:
# For blank nodes or other types, use id or value
return term.id or term.value
default_graph_host = 'bolt://memgraph:7687'
default_username = 'memgraph'
default_password = 'password'
@@ -204,40 +219,44 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
def create_triple(self, tx, t, user, collection):
s_val = get_term_value(t.s)
p_val = get_term_value(t.p)
o_val = get_term_value(t.o)
# Create new s node with given uri, if not exists
result = tx.run(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
uri=t.s.value, user=user, collection=collection
uri=s_val, user=user, collection=collection
)
if t.o.is_uri:
if t.o.type == IRI:
# Create new o node with given uri, if not exists
result = tx.run(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
uri=t.o.value, user=user, collection=collection
uri=o_val, user=user, collection=collection
)
result = tx.run(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=t.s.value, dest=t.o.value, uri=t.p.value, user=user, collection=collection,
src=s_val, dest=o_val, uri=p_val, user=user, collection=collection,
)
else:
# Create new o literal with given uri, if not exists
result = tx.run(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
value=t.o.value, user=user, collection=collection
value=o_val, user=user, collection=collection
)
result = tx.run(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=t.s.value, dest=t.o.value, uri=t.p.value, user=user, collection=collection,
src=s_val, dest=o_val, uri=p_val, user=user, collection=collection,
)
async def store_triples(self, message):
@@ -257,14 +276,18 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
for t in message.triples:
self.create_node(t.s.value, user, collection)
s_val = get_term_value(t.s)
p_val = get_term_value(t.p)
o_val = get_term_value(t.o)
if t.o.is_uri:
self.create_node(t.o.value, user, collection)
self.relate_node(t.s.value, t.p.value, t.o.value, user, collection)
self.create_node(s_val, user, collection)
if t.o.type == IRI:
self.create_node(o_val, user, collection)
self.relate_node(s_val, p_val, o_val, user, collection)
else:
self.create_literal(t.o.value, user, collection)
self.relate_literal(t.s.value, t.p.value, t.o.value, user, collection)
self.create_literal(o_val, user, collection)
self.relate_literal(s_val, p_val, o_val, user, collection)
# Alternative implementation using transactions
# with self.io.session(database=self.db) as session:

View file

@@ -14,12 +14,27 @@ from neo4j import GraphDatabase
from .... base import TriplesStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import IRI, LITERAL
# Module logger
logger = logging.getLogger(__name__)
default_ident = "triples-write"
def get_term_value(term):
"""Extract the string value from a Term"""
if term is None:
return None
if term.type == IRI:
return term.iri
elif term.type == LITERAL:
return term.value
else:
# For blank nodes or other types, use id or value
return term.id or term.value
default_graph_host = 'bolt://neo4j:7687'
default_username = 'neo4j'
default_password = 'password'
@@ -212,14 +227,18 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
for t in message.triples:
self.create_node(t.s.value, user, collection)
s_val = get_term_value(t.s)
p_val = get_term_value(t.p)
o_val = get_term_value(t.o)
if t.o.is_uri:
self.create_node(t.o.value, user, collection)
self.relate_node(t.s.value, t.p.value, t.o.value, user, collection)
self.create_node(s_val, user, collection)
if t.o.type == IRI:
self.create_node(o_val, user, collection)
self.relate_node(s_val, p_val, o_val, user, collection)
else:
self.create_literal(t.o.value, user, collection)
self.relate_literal(t.s.value, t.p.value, t.o.value, user, collection)
self.create_literal(o_val, user, collection)
self.relate_literal(s_val, p_val, o_val, user, collection)
@staticmethod
def add_args(parser):

View file

@@ -1,6 +1,6 @@
from .. schema import KnowledgeResponse, Triple, Triples, EntityEmbeddings
from .. schema import Metadata, Value, GraphEmbeddings
from .. schema import Metadata, GraphEmbeddings
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider

View file

@@ -1,8 +1,24 @@
from .. schema import KnowledgeResponse, Triple, Triples, EntityEmbeddings
from .. schema import Metadata, Value, GraphEmbeddings
from .. schema import Metadata, Term, IRI, LITERAL, GraphEmbeddings
from cassandra.cluster import Cluster
def term_to_tuple(term):
"""Convert Term to (value, is_uri) tuple for database storage."""
if term.type == IRI:
return (term.iri, True)
else: # LITERAL
return (term.value, False)
def tuple_to_term(value, is_uri):
"""Convert (value, is_uri) tuple from database to Term."""
if is_uri:
return Term(type=IRI, iri=value)
else:
return Term(type=LITERAL, value=value)
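# Round-trip sketch (hypothetical IRI): the helpers are inverses for the
# two supported cases, and star-unpacking flattens a Triple to a 6-tuple
# as in the comprehensions below:
#   term_to_tuple(Term(type=IRI, iri="http://x/y")) -> ("http://x/y", True)
#   tuple_to_term("http://x/y", True).iri           -> "http://x/y"
#   (*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o))
#     == (s_val, s_is_uri, p_val, p_is_uri, o_val, o_is_uri)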
from cassandra.auth import PlainTextAuthProvider
from ssl import SSLContext, PROTOCOL_TLSv1_2
@@ -205,8 +221,7 @@ class KnowledgeTableStore:
if m.metadata.metadata:
metadata = [
(
v.s.value, v.s.is_uri, v.p.value, v.p.is_uri,
v.o.value, v.o.is_uri
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
)
for v in m.metadata.metadata
]
@@ -215,8 +230,7 @@ class KnowledgeTableStore:
triples = [
(
v.s.value, v.s.is_uri, v.p.value, v.p.is_uri,
v.o.value, v.o.is_uri
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
)
for v in m.triples
]
@@ -248,8 +262,7 @@ class KnowledgeTableStore:
if m.metadata.metadata:
metadata = [
(
v.s.value, v.s.is_uri, v.p.value, v.p.is_uri,
v.o.value, v.o.is_uri
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
)
for v in m.metadata.metadata
]
@@ -258,7 +271,7 @@ class KnowledgeTableStore:
entities = [
(
(v.entity.value, v.entity.is_uri),
term_to_tuple(v.entity),
v.vectors
)
for v in m.entities
@@ -291,8 +304,7 @@ class KnowledgeTableStore:
if m.metadata.metadata:
metadata = [
(
v.s.value, v.s.is_uri, v.p.value, v.p.is_uri,
v.o.value, v.o.is_uri
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
)
for v in m.metadata.metadata
]
@@ -414,23 +426,26 @@ class KnowledgeTableStore:
if row[2]:
metadata = [
Triple(
s = Value(value = elt[0], is_uri = elt[1]),
p = Value(value = elt[2], is_uri = elt[3]),
o = Value(value = elt[4], is_uri = elt[5]),
s = tuple_to_term(elt[0], elt[1]),
p = tuple_to_term(elt[2], elt[3]),
o = tuple_to_term(elt[4], elt[5]),
)
for elt in row[2]
]
else:
metadata = []
triples = [
Triple(
s = Value(value = elt[0], is_uri = elt[1]),
p = Value(value = elt[2], is_uri = elt[3]),
o = Value(value = elt[4], is_uri = elt[5]),
)
for elt in row[3]
]
if row[3]:
triples = [
Triple(
s = tuple_to_term(elt[0], elt[1]),
p = tuple_to_term(elt[2], elt[3]),
o = tuple_to_term(elt[4], elt[5]),
)
for elt in row[3]
]
else:
triples = []
await receiver(
Triples(
@@ -470,22 +485,25 @@ class KnowledgeTableStore:
if row[2]:
metadata = [
Triple(
s = Value(value = elt[0], is_uri = elt[1]),
p = Value(value = elt[2], is_uri = elt[3]),
o = Value(value = elt[4], is_uri = elt[5]),
s = tuple_to_term(elt[0], elt[1]),
p = tuple_to_term(elt[2], elt[3]),
o = tuple_to_term(elt[4], elt[5]),
)
for elt in row[2]
]
else:
metadata = []
entities = [
EntityEmbeddings(
entity = Value(value = ent[0][0], is_uri = ent[0][1]),
vectors = ent[1]
)
for ent in row[3]
]
if row[3]:
entities = [
EntityEmbeddings(
entity = tuple_to_term(ent[0][0], ent[0][1]),
vectors = ent[1]
)
for ent in row[3]
]
else:
entities = []
await receiver(
GraphEmbeddings(

View file

@@ -1,8 +1,24 @@
from .. schema import LibrarianRequest, LibrarianResponse
from .. schema import DocumentMetadata, ProcessingMetadata
from .. schema import Error, Triple, Value
from .. schema import Error, Triple, Term, IRI, LITERAL
from .. knowledge import hash
def term_to_tuple(term):
"""Convert Term to (value, is_uri) tuple for database storage."""
if term.type == IRI:
return (term.iri, True)
else: # LITERAL
return (term.value, False)
def tuple_to_term(value, is_uri):
"""Convert (value, is_uri) tuple from database to Term."""
if is_uri:
return Term(type=IRI, iri=value)
else:
return Term(type=LITERAL, value=value)
from .. exceptions import RequestError
from cassandra.cluster import Cluster
@@ -215,8 +231,7 @@ class LibraryTableStore:
metadata = [
(
v.s.value, v.s.is_uri, v.p.value, v.p.is_uri,
v.o.value, v.o.is_uri
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
)
for v in document.metadata
]
@@ -249,8 +264,7 @@ class LibraryTableStore:
metadata = [
(
v.s.value, v.s.is_uri, v.p.value, v.p.is_uri,
v.o.value, v.o.is_uri
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
)
for v in document.metadata
]
@@ -331,9 +345,9 @@ class LibraryTableStore:
comments = row[4],
metadata = [
Triple(
s=Value(value=m[0], is_uri=m[1]),
p=Value(value=m[2], is_uri=m[3]),
o=Value(value=m[4], is_uri=m[5])
s=tuple_to_term(m[0], m[1]),
p=tuple_to_term(m[2], m[3]),
o=tuple_to_term(m[4], m[5])
)
for m in row[5]
],
@@ -376,9 +390,9 @@ class LibraryTableStore:
comments = row[3],
metadata = [
Triple(
s=Value(value=m[0], is_uri=m[1]),
p=Value(value=m[2], is_uri=m[3]),
o=Value(value=m[4], is_uri=m[5])
s=tuple_to_term(m[0], m[1]),
p=tuple_to_term(m[2], m[3]),
o=tuple_to_term(m[4], m[5])
)
for m in row[4]
],

View file

@@ -83,7 +83,7 @@ class PromptManager:
def parse_json(self, text):
json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL)
if json_match:
json_str = json_match.group(1).strip()
else:
@@ -92,6 +92,43 @@ class PromptManager:
return json.loads(json_str)
def parse_jsonl(self, text):
"""
Parse JSONL response, returning list of valid objects.
Invalid lines (malformed JSON, empty lines) are skipped with warnings.
This provides truncation resilience - partial output yields partial results.
"""
results = []
# Strip markdown code fences if present
text = text.strip()
if text.startswith('```'):
# Remove opening fence (possibly with language hint)
text = re.sub(r'^```(?:json|jsonl)?\s*\n?', '', text)
if text.endswith('```'):
text = text[:-3]
for line_num, line in enumerate(text.strip().split('\n'), 1):
line = line.strip()
# Skip empty lines
if not line:
continue
# Skip any remaining fence markers
if line.startswith('```'):
continue
try:
obj = json.loads(line)
results.append(obj)
except json.JSONDecodeError as e:
# Log warning but continue - this provides truncation resilience
logger.warning(f"JSONL parse error on line {line_num}: {e}")
return results
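# Usage sketch (input invented): a response truncated mid-object still
# yields the complete lines, which is the resilience noted above:
#   pm.parse_jsonl('{"a": 1}\n{"b": 2}\n{"c": ')
#   -> [{"a": 1}, {"b": 2}]   (the third line logs a parse warning)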
def render(self, id, input):
if id not in self.prompts:
@@ -121,21 +158,41 @@ class PromptManager:
if resp_type == "text":
return resp
if resp_type != "json":
raise RuntimeError(f"Response type {resp_type} not known")
try:
obj = self.parse_json(resp)
except:
logger.error(f"JSON parse failed: {resp}")
raise RuntimeError("JSON parse fail")
if self.prompts[id].schema:
if resp_type == "json":
try:
validate(instance=obj, schema=self.prompts[id].schema)
logger.debug("Schema validation successful")
except Exception as e:
raise RuntimeError(f"Schema validation fail: {e}")
obj = self.parse_json(resp)
except:
logger.error(f"JSON parse failed: {resp}")
raise RuntimeError("JSON parse fail")
return obj
if self.prompts[id].schema:
try:
validate(instance=obj, schema=self.prompts[id].schema)
logger.debug("Schema validation successful")
except Exception as e:
raise RuntimeError(f"Schema validation fail: {e}")
return obj
if resp_type == "jsonl":
objects = self.parse_jsonl(resp)
if not objects:
logger.warning("JSONL parse returned no valid objects")
return []
# Validate each object against schema if provided
if self.prompts[id].schema:
validated = []
for i, obj in enumerate(objects):
try:
validate(instance=obj, schema=self.prompts[id].schema)
validated.append(obj)
except Exception as e:
logger.warning(f"Object {i} failed schema validation: {e}")
return validated
return objects
raise RuntimeError(f"Response type {resp_type} not known")