Merge 2.0 to master (#651)
parent 3666ece2c5, commit b9d7bf9a8b
212 changed files with 13940 additions and 6180 deletions
@@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
    "trustgraph-base>=1.8,<1.9",
    "trustgraph-base>=2.0,<2.1",
    "aiohttp",
    "anthropic",
    "scylla-driver",

@@ -19,7 +19,6 @@ dependencies = [
    "faiss-cpu",
    "falkordb",
    "fastembed",
    "google-genai",
    "ibis",
    "jsonschema",
    "langchain",

@@ -61,27 +60,27 @@ api-gateway = "trustgraph.gateway:run"
chunker-recursive = "trustgraph.chunking.recursive:run"
chunker-token = "trustgraph.chunking.token:run"
config-svc = "trustgraph.config.service:run"
de-query-milvus = "trustgraph.query.doc_embeddings.milvus:run"
de-query-pinecone = "trustgraph.query.doc_embeddings.pinecone:run"
de-query-qdrant = "trustgraph.query.doc_embeddings.qdrant:run"
de-write-milvus = "trustgraph.storage.doc_embeddings.milvus:run"
de-write-pinecone = "trustgraph.storage.doc_embeddings.pinecone:run"
de-write-qdrant = "trustgraph.storage.doc_embeddings.qdrant:run"
doc-embeddings-query-milvus = "trustgraph.query.doc_embeddings.milvus:run"
doc-embeddings-query-pinecone = "trustgraph.query.doc_embeddings.pinecone:run"
doc-embeddings-query-qdrant = "trustgraph.query.doc_embeddings.qdrant:run"
doc-embeddings-write-milvus = "trustgraph.storage.doc_embeddings.milvus:run"
doc-embeddings-write-pinecone = "trustgraph.storage.doc_embeddings.pinecone:run"
doc-embeddings-write-qdrant = "trustgraph.storage.doc_embeddings.qdrant:run"
document-embeddings = "trustgraph.embeddings.document_embeddings:run"
document-rag = "trustgraph.retrieval.document_rag:run"
embeddings-fastembed = "trustgraph.embeddings.fastembed:run"
embeddings-ollama = "trustgraph.embeddings.ollama:run"
ge-query-milvus = "trustgraph.query.graph_embeddings.milvus:run"
ge-query-pinecone = "trustgraph.query.graph_embeddings.pinecone:run"
ge-query-qdrant = "trustgraph.query.graph_embeddings.qdrant:run"
ge-write-milvus = "trustgraph.storage.graph_embeddings.milvus:run"
ge-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:run"
ge-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run"
graph-embeddings-query-milvus = "trustgraph.query.graph_embeddings.milvus:run"
graph-embeddings-query-pinecone = "trustgraph.query.graph_embeddings.pinecone:run"
graph-embeddings-query-qdrant = "trustgraph.query.graph_embeddings.qdrant:run"
graph-embeddings-write-milvus = "trustgraph.storage.graph_embeddings.milvus:run"
graph-embeddings-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:run"
graph-embeddings-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run"
graph-embeddings = "trustgraph.embeddings.graph_embeddings:run"
graph-rag = "trustgraph.retrieval.graph_rag:run"
kg-extract-agent = "trustgraph.extract.kg.agent:run"
kg-extract-definitions = "trustgraph.extract.kg.definitions:run"
kg-extract-objects = "trustgraph.extract.kg.objects:run"
kg-extract-rows = "trustgraph.extract.kg.rows:run"
kg-extract-relationships = "trustgraph.extract.kg.relationships:run"
kg-extract-topics = "trustgraph.extract.kg.topics:run"
kg-extract-ontology = "trustgraph.extract.kg.ontology:run"

@@ -91,8 +90,11 @@ librarian = "trustgraph.librarian:run"
mcp-tool = "trustgraph.agent.mcp_tool:run"
metering = "trustgraph.metering:run"
nlp-query = "trustgraph.retrieval.nlp_query:run"
objects-write-cassandra = "trustgraph.storage.objects.cassandra:run"
objects-query-cassandra = "trustgraph.query.objects.cassandra:run"
rows-write-cassandra = "trustgraph.storage.rows.cassandra:run"
rows-query-cassandra = "trustgraph.query.rows.cassandra:run"
row-embeddings = "trustgraph.embeddings.row_embeddings:run"
row-embeddings-write-qdrant = "trustgraph.storage.row_embeddings.qdrant:run"
row-embeddings-query-qdrant = "trustgraph.query.row_embeddings.qdrant:run"
pdf-decoder = "trustgraph.decoding.pdf:run"
pdf-ocr-mistral = "trustgraph.decoding.mistral_ocr:run"
prompt-template = "trustgraph.prompt.template:run"

@@ -104,7 +106,6 @@ text-completion-azure = "trustgraph.model.text_completion.azure:run"
text-completion-azure-openai = "trustgraph.model.text_completion.azure_openai:run"
text-completion-claude = "trustgraph.model.text_completion.claude:run"
text-completion-cohere = "trustgraph.model.text_completion.cohere:run"
text-completion-googleaistudio = "trustgraph.model.text_completion.googleaistudio:run"
text-completion-llamafile = "trustgraph.model.text_completion.llamafile:run"
text-completion-lmstudio = "trustgraph.model.text_completion.lmstudio:run"
text-completion-mistral = "trustgraph.model.text_completion.mistral:run"
@@ -13,10 +13,11 @@ logger = logging.getLogger(__name__)

from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec

from ... schema import AgentRequest, AgentResponse, AgentStep, Error

from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl
from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl, RowEmbeddingsQueryImpl
from . agent_manager import AgentManager
from ..tool_filter import validate_tool_config, filter_tools_by_group_and_state, get_next_state

@@ -87,6 +88,20 @@ class Processor(AgentService):
            )
        )

        self.register_specification(
            EmbeddingsClientSpec(
                request_name = "embeddings-request",
                response_name = "embeddings-response",
            )
        )

        self.register_specification(
            RowEmbeddingsQueryClientSpec(
                request_name = "row-embeddings-query-request",
                response_name = "row-embeddings-query-response",
            )
        )

    async def on_tools_config(self, config, version):

        logger.info(f"Loading configuration version {version}")

@@ -147,11 +162,21 @@ class Processor(AgentService):
                )
            elif impl_id == "structured-query":
                impl = functools.partial(
                    StructuredQueryImpl,
                    StructuredQueryImpl,
                    collection=data.get("collection"),
                    user=None  # User will be provided dynamically via context
                )
                arguments = StructuredQueryImpl.get_arguments()
            elif impl_id == "row-embeddings-query":
                impl = functools.partial(
                    RowEmbeddingsQueryImpl,
                    schema_name=data.get("schema-name"),
                    collection=data.get("collection"),
                    user=None,  # User will be provided dynamically via context
                    index_name=data.get("index-name"),  # Optional filter
                    limit=int(data.get("limit", 10))  # Max results
                )
                arguments = RowEmbeddingsQueryImpl.get_arguments()
            else:
                raise RuntimeError(
                    f"Tool type {impl_id} not known"

@@ -327,11 +352,11 @@ class Processor(AgentService):
    def __init__(self, flow, user):
        self._flow = flow
        self._user = user

    def __call__(self, service_name):
        client = self._flow(service_name)
        # For structured query clients, store user context
        if service_name == "structured-query-request":
        # For query clients that need user context, store it
        if service_name in ("structured-query-request", "row-embeddings-query-request"):
            client._current_user = self._user
        return client
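For reference, a sketch of the tool configuration the new row-embeddings-query branch above would accept. Only the keys read via data.get() (schema-name, collection, index-name, limit) come from the diff; the surrounding entry shape and the example values are assumptions.

    # Hypothetical tool entry for the agent's tool configuration (illustrative only).
    row_query_tool = {
        "type": "row-embeddings-query",   # impl_id matched above
        "schema-name": "products",        # passed to RowEmbeddingsQueryImpl as schema_name
        "collection": "default",          # optional
        "index-name": "name",             # optional index filter
        "limit": 10,                      # max results
    }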
@@ -128,6 +128,62 @@ class StructuredQueryImpl:
        return str(result)


# This tool implementation knows how to query row embeddings for semantic search
class RowEmbeddingsQueryImpl:
    def __init__(self, context, schema_name, collection=None, user=None, index_name=None, limit=10):
        self.context = context
        self.schema_name = schema_name
        self.collection = collection
        self.user = user
        self.index_name = index_name  # Optional: filter to specific index
        self.limit = limit  # Max results to return

    @staticmethod
    def get_arguments():
        return [
            Argument(
                name="query",
                type="string",
                description="Text to search for semantically similar values in the structured data index"
            )
        ]

    async def invoke(self, **arguments):
        # First get embeddings for the query text
        embeddings_client = self.context("embeddings-request")
        logger.debug("Getting embeddings for row query...")

        query_text = arguments.get("query")
        vectors = await embeddings_client.embed(query_text)

        # Now query row embeddings
        client = self.context("row-embeddings-query-request")
        logger.debug("Row embeddings query...")

        # Get user from client context if available
        user = getattr(client, '_current_user', self.user or "trustgraph")

        matches = await client.row_embeddings_query(
            vectors=vectors,
            schema_name=self.schema_name,
            user=user,
            collection=self.collection or "default",
            index_name=self.index_name,
            limit=self.limit
        )

        # Format results for agent consumption
        if not matches:
            return "No matching records found"

        results = []
        for match in matches:
            result = f"- {match['index_name']}: {', '.join(match['index_value'])} (score: {match['score']:.3f})"
            results.append(result)

        return "Matching records:\n" + "\n".join(results)


# This tool implementation knows how to execute prompt templates
class PromptImpl:
    def __init__(self, context, template_id, arguments=None):
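A minimal usage sketch of the new tool, assuming a context callable that returns the embeddings and row-embeddings-query clients; only the "query" argument name is taken from get_arguments() above, everything else is illustrative.

    async def demo(context):
        tool = RowEmbeddingsQueryImpl(
            context,                  # callable returning service clients
            schema_name="products",   # assumed schema name
            index_name="name",
            limit=5,
        )
        # Returns either "No matching records found" or a "Matching records:" listing
        print(await tool.invoke(query="waterproof hiking boots"))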
@@ -124,12 +124,13 @@ class Processor(AsyncProcessor):

        logger.info(f"Configuration version: {version}")

        if "flows" in config:
        if "flow" in config:
            self.flows = {
                k: json.loads(v)
                for k, v in config["flows"].items()
                for k, v in config["flow"].items()
            }
        else:
            self.flows = {}

        logger.debug(f"Flows: {self.flows}")
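The handler above now reads the "flow" config group, whose values are JSON-encoded flow definitions. A minimal sketch of the expected shape; the definition body is a placeholder, not taken from this diff.

    import json

    config = {
        "flow": {
            "default": '{"description": "placeholder flow definition"}',
        }
    }
    flows = {k: json.loads(v) for k, v in config["flow"].items()}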
(File diff suppressed because it is too large.)
@@ -16,12 +16,14 @@ import logging
logger = logging.getLogger(__name__)

default_ident = "graph-embeddings"
default_batch_size = 5

class Processor(FlowProcessor):

    def __init__(self, **params):

        id = params.get("id")
        self.batch_size = params.get("batch_size", default_batch_size)

        super(Processor, self).__init__(
            **params | {

@@ -73,12 +75,14 @@ class Processor(FlowProcessor):
                )
            )

            r = GraphEmbeddings(
                metadata=v.metadata,
                entities=entities,
            )

            await flow("output").send(r)
            # Send in batches to avoid oversized messages
            for i in range(0, len(entities), self.batch_size):
                batch = entities[i:i + self.batch_size]
                r = GraphEmbeddings(
                    metadata=v.metadata,
                    entities=batch,
                )
                await flow("output").send(r)

        except Exception as e:
            logger.error("Exception occurred", exc_info=True)

@@ -91,6 +95,13 @@ class Processor(FlowProcessor):
    @staticmethod
    def add_args(parser):

        parser.add_argument(
            '--batch-size',
            type=int,
            default=default_batch_size,
            help=f'Maximum entities per output message (default: {default_batch_size})'
        )

        FlowProcessor.add_args(parser)

def run():
@@ -0,0 +1,3 @@

from . embeddings import *

@@ -0,0 +1,6 @@

from . embeddings import run

if __name__ == '__main__':
    run()
@@ -0,0 +1,263 @@

"""
Row embeddings processor. Calls the embeddings service to compute embeddings
for indexed field values in extracted row data.

Input is ExtractedObject (structured row data with schema).
Output is RowEmbeddings (row data with embeddings for indexed fields).

This follows the two-stage pattern used by graph-embeddings and document-embeddings:
Stage 1 (this processor): Compute embeddings
Stage 2 (row-embeddings-write-*): Store embeddings
"""

import json
import logging
from typing import Dict, List, Set

from ... schema import ExtractedObject, RowEmbeddings, RowIndexEmbedding
from ... schema import RowSchema, Field
from ... base import FlowProcessor, EmbeddingsClientSpec, ConsumerSpec
from ... base import ProducerSpec, CollectionConfigHandler

logger = logging.getLogger(__name__)

default_ident = "row-embeddings"
default_batch_size = 10


class Processor(CollectionConfigHandler, FlowProcessor):

    def __init__(self, **params):

        id = params.get("id", default_ident)
        self.batch_size = params.get("batch_size", default_batch_size)

        # Config key for schemas
        self.config_key = params.get("config_type", "schema")

        super(Processor, self).__init__(
            **params | {
                "id": id,
                "config_type": self.config_key,
            }
        )

        self.register_specification(
            ConsumerSpec(
                name="input",
                schema=ExtractedObject,
                handler=self.on_message,
            )
        )

        self.register_specification(
            EmbeddingsClientSpec(
                request_name="embeddings-request",
                response_name="embeddings-response",
            )
        )

        self.register_specification(
            ProducerSpec(
                name="output",
                schema=RowEmbeddings
            )
        )

        # Register config handlers
        self.register_config_handler(self.on_schema_config)
        self.register_config_handler(self.on_collection_config)

        # Schema storage: name -> RowSchema
        self.schemas: Dict[str, RowSchema] = {}

    async def on_schema_config(self, config, version):
        """Handle schema configuration updates"""
        logger.info(f"Loading schema configuration version {version}")

        # Clear existing schemas
        self.schemas = {}

        # Check if our config type exists
        if self.config_key not in config:
            logger.warning(f"No '{self.config_key}' type in configuration")
            return

        # Get the schemas dictionary for our type
        schemas_config = config[self.config_key]

        # Process each schema in the schemas config
        for schema_name, schema_json in schemas_config.items():
            try:
                # Parse the JSON schema definition
                schema_def = json.loads(schema_json)

                # Create Field objects
                fields = []
                for field_def in schema_def.get("fields", []):
                    field = Field(
                        name=field_def["name"],
                        type=field_def["type"],
                        size=field_def.get("size", 0),
                        primary=field_def.get("primary_key", False),
                        description=field_def.get("description", ""),
                        required=field_def.get("required", False),
                        enum_values=field_def.get("enum", []),
                        indexed=field_def.get("indexed", False)
                    )
                    fields.append(field)

                # Create RowSchema
                row_schema = RowSchema(
                    name=schema_def.get("name", schema_name),
                    description=schema_def.get("description", ""),
                    fields=fields
                )

                self.schemas[schema_name] = row_schema
                logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")

            except Exception as e:
                logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)

        logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")

    def get_index_names(self, schema: RowSchema) -> List[str]:
        """Get all index names for a schema."""
        index_names = []
        for field in schema.fields:
            if field.primary or field.indexed:
                index_names.append(field.name)
        return index_names

    def build_index_value(self, value_map: Dict[str, str], index_name: str) -> List[str]:
        """Build the index_value list for a given index."""
        field_names = [f.strip() for f in index_name.split(',')]
        values = []
        for field_name in field_names:
            value = value_map.get(field_name)
            values.append(str(value) if value is not None else "")
        return values

    def build_text_for_embedding(self, index_value: List[str]) -> str:
        """Build text representation for embedding from index values."""
        # Space-join the values for composite indexes
        return " ".join(index_value)

    async def on_message(self, msg, consumer, flow):
        """Process incoming ExtractedObject and compute embeddings"""

        obj = msg.value()
        logger.info(
            f"Computing embeddings for {len(obj.values)} rows, "
            f"schema {obj.schema_name}, doc {obj.metadata.id}"
        )

        # Validate collection exists before processing
        if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
            logger.warning(
                f"Collection {obj.metadata.collection} for user {obj.metadata.user} "
                f"does not exist in config. Dropping message."
            )
            return

        # Get schema definition
        schema = self.schemas.get(obj.schema_name)
        if not schema:
            logger.warning(f"No schema found for {obj.schema_name} - skipping")
            return

        # Get all index names for this schema
        index_names = self.get_index_names(schema)

        if not index_names:
            logger.warning(f"Schema {obj.schema_name} has no indexed fields - skipping")
            return

        # Track unique texts to avoid duplicate embeddings
        # text -> (index_name, index_value)
        texts_to_embed: Dict[str, tuple] = {}

        # Collect all texts that need embeddings
        for value_map in obj.values:
            for index_name in index_names:
                index_value = self.build_index_value(value_map, index_name)

                # Skip empty values
                if not index_value or all(v == "" for v in index_value):
                    continue

                text = self.build_text_for_embedding(index_value)
                if text and text not in texts_to_embed:
                    texts_to_embed[text] = (index_name, index_value)

        if not texts_to_embed:
            logger.info("No texts to embed")
            return

        # Compute embeddings
        embeddings_list = []

        try:
            for text, (index_name, index_value) in texts_to_embed.items():
                vectors = await flow("embeddings-request").embed(text=text)

                embeddings_list.append(
                    RowIndexEmbedding(
                        index_name=index_name,
                        index_value=index_value,
                        text=text,
                        vectors=vectors
                    )
                )

            # Send in batches to avoid oversized messages
            for i in range(0, len(embeddings_list), self.batch_size):
                batch = embeddings_list[i:i + self.batch_size]
                result = RowEmbeddings(
                    metadata=obj.metadata,
                    schema_name=obj.schema_name,
                    embeddings=batch,
                )
                await flow("output").send(result)

            logger.info(
                f"Computed {len(embeddings_list)} embeddings for "
                f"{len(obj.values)} rows ({len(index_names)} indexes)"
            )

        except Exception as e:
            logger.error("Exception during embedding computation", exc_info=True)
            raise e

    async def create_collection(self, user: str, collection: str, metadata: dict):
        """Collection creation notification - no action needed for embedding stage"""
        logger.debug(f"Row embeddings collection notification for {user}/{collection}")

    async def delete_collection(self, user: str, collection: str):
        """Collection deletion notification - no action needed for embedding stage"""
        logger.debug(f"Row embeddings collection delete notification for {user}/{collection}")

    @staticmethod
    def add_args(parser):

        FlowProcessor.add_args(parser)

        parser.add_argument(
            '--batch-size',
            type=int,
            default=default_batch_size,
            help=f'Maximum embeddings per output message (default: {default_batch_size})'
        )

        parser.add_argument(
            '--config-type',
            default='schema',
            help='Configuration type prefix for schemas (default: schema)'
        )


def run():
    Processor.launch(default_ident, __doc__)
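A sketch of a schema entry that on_schema_config() above would parse. The key names (fields, name, type, primary_key, indexed, description) come from the parsing code; the example schema itself is made up.

    import json

    schema_config = {
        "schema": {
            "products": json.dumps({
                "name": "products",
                "description": "Product catalogue rows",
                "fields": [
                    {"name": "id", "type": "string", "primary_key": True},
                    {"name": "name", "type": "string", "indexed": True},
                    {"name": "price", "type": "float"},
                ],
            }),
        }
    }
    # Only "id" and "name" yield index entries: get_index_names() keeps
    # primary or indexed fields.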
@@ -3,7 +3,7 @@ import json
import urllib.parse
import logging

from ....schema import Chunk, Triple, Triples, Metadata, Value
from ....schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL
from ....schema import EntityContext, EntityContexts

from ....rdf import TRUSTGRAPH_ENTITIES, RDF_LABEL, SUBJECT_OF, DEFINITION

@@ -126,16 +126,42 @@ class Processor(FlowProcessor):

        await pub.send(ecs)

    def parse_json(self, text):
        json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL)

        if json_match:
            json_str = json_match.group(1).strip()
        else:
            # If no delimiters, assume the entire output is JSON
            json_str = text.strip()
    def parse_jsonl(self, text):
        """
        Parse JSONL response, returning list of valid objects.

        return json.loads(json_str)
        Invalid lines (malformed JSON, empty lines) are skipped with warnings.
        This provides truncation resilience - partial output yields partial results.
        """
        results = []

        # Strip markdown code fences if present
        text = text.strip()
        if text.startswith('```'):
            # Remove opening fence (possibly with language hint)
            text = re.sub(r'^```(?:json|jsonl)?\s*\n?', '', text)
        if text.endswith('```'):
            text = text[:-3]

        for line_num, line in enumerate(text.strip().split('\n'), 1):
            line = line.strip()

            # Skip empty lines
            if not line:
                continue

            # Skip any remaining fence markers
            if line.startswith('```'):
                continue

            try:
                obj = json.loads(line)
                results.append(obj)
            except json.JSONDecodeError as e:
                # Log warning but continue - this provides truncation resilience
                logger.warning(f"JSONL parse error on line {line_num}: {e}")

        return results

    async def on_message(self, msg, consumer, flow):

@@ -178,11 +204,12 @@ class Processor(FlowProcessor):
            question = prompt
        )

        # Parse JSON response
        try:
            extraction_data = self.parse_json(agent_response)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON response from agent: {e}")
        # Parse JSONL response
        extraction_data = self.parse_jsonl(agent_response)

        if not extraction_data:
            logger.warning("JSONL parse returned no valid objects")
            return

        # Process extraction data
        triples, entity_contexts = self.process_extraction_data(

@@ -209,103 +236,113 @@ class Processor(FlowProcessor):
            raise

    def process_extraction_data(self, data, metadata):
        """Process combined extraction data to generate triples and entity contexts"""
        """Process JSONL extraction data to generate triples and entity contexts.

        Data is a flat list of objects with 'type' discriminator field:
        - {"type": "definition", "entity": "...", "definition": "..."}
        - {"type": "relationship", "subject": "...", "predicate": "...", "object": "...", "object-entity": bool}
        """
        triples = []
        entity_contexts = []

        # Categorize items by type
        definitions = [item for item in data if item.get("type") == "definition"]
        relationships = [item for item in data if item.get("type") == "relationship"]

        # Process definitions
        for defn in data.get("definitions", []):
        for defn in definitions:

            entity_uri = self.to_uri(defn["entity"])

            # Add entity label
            triples.append(Triple(
                s = Value(value=entity_uri, is_uri=True),
                p = Value(value=RDF_LABEL, is_uri=True),
                o = Value(value=defn["entity"], is_uri=False),
                s = Term(type=IRI, iri=entity_uri),
                p = Term(type=IRI, iri=RDF_LABEL),
                o = Term(type=LITERAL, value=defn["entity"]),
            ))

            # Add definition
            triples.append(Triple(
                s = Value(value=entity_uri, is_uri=True),
                p = Value(value=DEFINITION, is_uri=True),
                o = Value(value=defn["definition"], is_uri=False),
                s = Term(type=IRI, iri=entity_uri),
                p = Term(type=IRI, iri=DEFINITION),
                o = Term(type=LITERAL, value=defn["definition"]),
            ))

            # Add subject-of relationship to document
            if metadata.id:
                triples.append(Triple(
                    s = Value(value=entity_uri, is_uri=True),
                    p = Value(value=SUBJECT_OF, is_uri=True),
                    o = Value(value=metadata.id, is_uri=True),
                    s = Term(type=IRI, iri=entity_uri),
                    p = Term(type=IRI, iri=SUBJECT_OF),
                    o = Term(type=IRI, iri=metadata.id),
                ))

            # Create entity context for embeddings
            entity_contexts.append(EntityContext(
                entity=Value(value=entity_uri, is_uri=True),
                entity=Term(type=IRI, iri=entity_uri),
                context=defn["definition"]
            ))

        # Process relationships
        for rel in data.get("relationships", []):
        for rel in relationships:

            subject_uri = self.to_uri(rel["subject"])
            predicate_uri = self.to_uri(rel["predicate"])

            subject_value = Value(value=subject_uri, is_uri=True)
            predicate_value = Value(value=predicate_uri, is_uri=True)
            if data.get("object-entity", False):
                object_value = Value(value=predicate_uri, is_uri=True)
            subject_value = Term(type=IRI, iri=subject_uri)
            predicate_value = Term(type=IRI, iri=predicate_uri)
            if rel.get("object-entity", True):
                object_uri = self.to_uri(rel["object"])
                object_value = Term(type=IRI, iri=object_uri)
            else:
                object_value = Value(value=predicate_uri, is_uri=False)
                object_value = Term(type=LITERAL, value=rel["object"])

            # Add subject and predicate labels
            triples.append(Triple(
                s = subject_value,
                p = Value(value=RDF_LABEL, is_uri=True),
                o = Value(value=rel["subject"], is_uri=False),
                p = Term(type=IRI, iri=RDF_LABEL),
                o = Term(type=LITERAL, value=rel["subject"]),
            ))

            triples.append(Triple(
                s = predicate_value,
                p = Value(value=RDF_LABEL, is_uri=True),
                o = Value(value=rel["predicate"], is_uri=False),
                p = Term(type=IRI, iri=RDF_LABEL),
                o = Term(type=LITERAL, value=rel["predicate"]),
            ))

            # Handle object (entity vs literal)
            if rel.get("object-entity", True):
                triples.append(Triple(
                    s = object_value,
                    p = Value(value=RDF_LABEL, is_uri=True),
                    o = Value(value=rel["object"], is_uri=True),
                    p = Term(type=IRI, iri=RDF_LABEL),
                    o = Term(type=LITERAL, value=rel["object"]),
                ))

            # Add the main relationship triple
            triples.append(Triple(
                s = subject_value,
                p = predicate_value,
                o = object_value
            ))

            # Add subject-of relationships to document
            if metadata.id:
                triples.append(Triple(
                    s = subject_value,
                    p = Value(value=SUBJECT_OF, is_uri=True),
                    o = Value(value=metadata.id, is_uri=True),
                    p = Term(type=IRI, iri=SUBJECT_OF),
                    o = Term(type=IRI, iri=metadata.id),
                ))

                triples.append(Triple(
                    s = predicate_value,
                    p = Value(value=SUBJECT_OF, is_uri=True),
                    o = Value(value=metadata.id, is_uri=True),
                    p = Term(type=IRI, iri=SUBJECT_OF),
                    o = Term(type=IRI, iri=metadata.id),
                ))

                if rel.get("object-entity", True):
                    triples.append(Triple(
                        s = object_value,
                        p = Value(value=SUBJECT_OF, is_uri=True),
                        o = Value(value=metadata.id, is_uri=True),
                        p = Term(type=IRI, iri=SUBJECT_OF),
                        o = Term(type=IRI, iri=metadata.id),
                    ))

        return triples, entity_contexts
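An example of the JSONL lines parse_jsonl() accepts and process_extraction_data() consumes, following the object shapes in the docstring above; the values themselves are illustrative, and the final call assumes a Processor instance named processor.

    agent_response = "\n".join([
        '{"type": "definition", "entity": "Cornish pasty", '
        '"definition": "A baked pastry filled with beef and vegetables"}',
        '{"type": "relationship", "subject": "Cornish pasty", '
        '"predicate": "originates-from", "object": "Cornwall", "object-entity": true}',
    ])
    items = processor.parse_jsonl(agent_response)  # two dicts; invalid lines are skipped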
@ -9,7 +9,7 @@ import json
|
|||
import urllib.parse
|
||||
import logging
|
||||
|
||||
from .... schema import Chunk, Triple, Triples, Metadata, Value
|
||||
from .... schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -20,12 +20,14 @@ from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
|
|||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from .... base import PromptClientSpec
|
||||
|
||||
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
|
||||
RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
|
||||
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
|
||||
DEFINITION_VALUE = Term(type=IRI, iri=DEFINITION)
|
||||
RDF_LABEL_VALUE = Term(type=IRI, iri=RDF_LABEL)
|
||||
SUBJECT_OF_VALUE = Term(type=IRI, iri=SUBJECT_OF)
|
||||
|
||||
default_ident = "kg-extract-definitions"
|
||||
default_concurrency = 1
|
||||
default_triples_batch_size = 50
|
||||
default_entity_batch_size = 5
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
|
|
@ -33,6 +35,8 @@ class Processor(FlowProcessor):
|
|||
|
||||
id = params.get("id")
|
||||
concurrency = params.get("concurrency", 1)
|
||||
self.triples_batch_size = params.get("triples_batch_size", default_triples_batch_size)
|
||||
self.entity_batch_size = params.get("entity_batch_size", default_entity_batch_size)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
|
|
@ -142,13 +146,13 @@ class Processor(FlowProcessor):
|
|||
|
||||
s_uri = self.to_uri(s)
|
||||
|
||||
s_value = Value(value=str(s_uri), is_uri=True)
|
||||
o_value = Value(value=str(o), is_uri=False)
|
||||
s_value = Term(type=IRI, iri=str(s_uri))
|
||||
o_value = Term(type=LITERAL, value=str(o))
|
||||
|
||||
triples.append(Triple(
|
||||
s=s_value,
|
||||
p=RDF_LABEL_VALUE,
|
||||
o=Value(value=s, is_uri=False),
|
||||
o=Term(type=LITERAL, value=s),
|
||||
))
|
||||
|
||||
triples.append(Triple(
|
||||
|
|
@ -158,37 +162,48 @@ class Processor(FlowProcessor):
|
|||
triples.append(Triple(
|
||||
s=s_value,
|
||||
p=SUBJECT_OF_VALUE,
|
||||
o=Value(value=v.metadata.id, is_uri=True)
|
||||
o=Term(type=IRI, iri=v.metadata.id)
|
||||
))
|
||||
|
||||
ec = EntityContext(
|
||||
# Output entity name as context for direct name matching
|
||||
entities.append(EntityContext(
|
||||
entity=s_value,
|
||||
context=s,
|
||||
))
|
||||
|
||||
# Output definition as context for semantic matching
|
||||
entities.append(EntityContext(
|
||||
entity=s_value,
|
||||
context=defn["definition"],
|
||||
))
|
||||
|
||||
# Send triples in batches
|
||||
for i in range(0, len(triples), self.triples_batch_size):
|
||||
batch = triples[i:i + self.triples_batch_size]
|
||||
await self.emit_triples(
|
||||
flow("triples"),
|
||||
Metadata(
|
||||
id=v.metadata.id,
|
||||
metadata=[],
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
batch
|
||||
)
|
||||
|
||||
entities.append(ec)
|
||||
|
||||
await self.emit_triples(
|
||||
flow("triples"),
|
||||
Metadata(
|
||||
id=v.metadata.id,
|
||||
metadata=[],
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
triples
|
||||
)
|
||||
|
||||
await self.emit_ecs(
|
||||
flow("entity-contexts"),
|
||||
Metadata(
|
||||
id=v.metadata.id,
|
||||
metadata=[],
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
entities
|
||||
)
|
||||
# Send entity contexts in batches
|
||||
for i in range(0, len(entities), self.entity_batch_size):
|
||||
batch = entities[i:i + self.entity_batch_size]
|
||||
await self.emit_ecs(
|
||||
flow("entity-contexts"),
|
||||
Metadata(
|
||||
id=v.metadata.id,
|
||||
metadata=[],
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
batch
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Definitions extraction exception: {e}", exc_info=True)
|
||||
|
|
@ -205,6 +220,20 @@ class Processor(FlowProcessor):
|
|||
help=f'Concurrent processing threads (default: {default_concurrency})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--triples-batch-size',
|
||||
type=int,
|
||||
default=default_triples_batch_size,
|
||||
help=f'Maximum triples per output message (default: {default_triples_batch_size})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--entity-batch-size',
|
||||
type=int,
|
||||
default=default_entity_batch_size,
|
||||
help=f'Maximum entity contexts per output message (default: {default_entity_batch_size})'
|
||||
)
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
def run():
|
||||
|
|
|
|||
|
|
@ -74,23 +74,27 @@ def build_entity_uri(entity_name: str, entity_type: str, ontology_id: str,
|
|||
|
||||
Args:
|
||||
entity_name: Natural language entity name (e.g., "Cornish pasty")
|
||||
entity_type: Ontology type (e.g., "fo/Recipe")
|
||||
entity_type: Ontology type (e.g., "fo/Recipe" or "Recipe")
|
||||
ontology_id: Ontology identifier (e.g., "food")
|
||||
base_uri: Base URI for entity URIs (default: "https://trustgraph.ai")
|
||||
|
||||
Returns:
|
||||
Full entity URI (e.g., "https://trustgraph.ai/food/fo-recipe-cornish-pasty")
|
||||
Full entity URI (e.g., "https://trustgraph.ai/food/recipe-cornish-pasty")
|
||||
|
||||
Examples:
|
||||
>>> build_entity_uri("Cornish pasty", "fo/Recipe", "food")
|
||||
'https://trustgraph.ai/food/fo-recipe-cornish-pasty'
|
||||
'https://trustgraph.ai/food/recipe-cornish-pasty'
|
||||
|
||||
>>> build_entity_uri("Cornish pasty", "fo/Food", "food")
|
||||
'https://trustgraph.ai/food/fo-food-cornish-pasty'
|
||||
>>> build_entity_uri("Cornish pasty", "Food", "food")
|
||||
'https://trustgraph.ai/food/food-cornish-pasty'
|
||||
|
||||
>>> build_entity_uri("beef", "fo/Food", "food")
|
||||
'https://trustgraph.ai/food/fo-food-beef'
|
||||
'https://trustgraph.ai/food/food-beef'
|
||||
"""
|
||||
# Strip ontology prefix from type if present (e.g., "fo/Recipe" -> "Recipe")
|
||||
if "/" in entity_type:
|
||||
entity_type = entity_type.split("/")[-1]
|
||||
|
||||
type_part = normalize_type_identifier(entity_type)
|
||||
name_part = normalize_entity_name(entity_name)
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import logging
|
|||
import asyncio
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from .... schema import Chunk, Triple, Triples, Metadata, Value
|
||||
from .... schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL
|
||||
from .... schema import EntityContext, EntityContexts
|
||||
from .... schema import PromptRequest, PromptResponse
|
||||
from .... rdf import TRUSTGRAPH_ENTITIES, RDF_TYPE, RDF_LABEL, DEFINITION
|
||||
|
|
@ -27,6 +27,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
default_ident = "kg-extract-ontology"
|
||||
default_concurrency = 1
|
||||
default_triples_batch_size = 50
|
||||
default_entity_batch_size = 5
|
||||
|
||||
# URI prefix mappings for common namespaces
|
||||
URI_PREFIXES = {
|
||||
|
|
@ -39,12 +41,22 @@ URI_PREFIXES = {
|
|||
}
|
||||
|
||||
|
||||
def make_term(v, is_uri):
|
||||
"""Helper to create Term from value and is_uri flag."""
|
||||
if is_uri:
|
||||
return Term(type=IRI, iri=v)
|
||||
else:
|
||||
return Term(type=LITERAL, value=v)
|
||||
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
"""Main OntoRAG extraction processor."""
|
||||
|
||||
def __init__(self, **params):
|
||||
id = params.get("id", default_ident)
|
||||
concurrency = params.get("concurrency", default_concurrency)
|
||||
self.triples_batch_size = params.get("triples_batch_size", default_triples_batch_size)
|
||||
self.entity_batch_size = params.get("entity_batch_size", default_entity_batch_size)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
|
|
@ -274,17 +286,6 @@ class Processor(FlowProcessor):
|
|||
|
||||
if not ontology_subsets:
|
||||
logger.warning("No relevant ontology elements found for chunk")
|
||||
# Emit empty outputs
|
||||
await self.emit_triples(
|
||||
flow("triples"),
|
||||
v.metadata,
|
||||
[]
|
||||
)
|
||||
await self.emit_entity_contexts(
|
||||
flow("entity-contexts"),
|
||||
v.metadata,
|
||||
[]
|
||||
)
|
||||
return
|
||||
|
||||
# Merge subsets if multiple ontologies matched
|
||||
|
|
@ -318,36 +319,29 @@ class Processor(FlowProcessor):
|
|||
# Build entity contexts from all triples (including ontology elements)
|
||||
entity_contexts = self.build_entity_contexts(all_triples)
|
||||
|
||||
# Emit all triples (extracted + ontology definitions)
|
||||
await self.emit_triples(
|
||||
flow("triples"),
|
||||
v.metadata,
|
||||
all_triples
|
||||
)
|
||||
# Emit triples in batches
|
||||
for i in range(0, len(all_triples), self.triples_batch_size):
|
||||
batch = all_triples[i:i + self.triples_batch_size]
|
||||
await self.emit_triples(
|
||||
flow("triples"),
|
||||
v.metadata,
|
||||
batch
|
||||
)
|
||||
|
||||
# Emit entity contexts
|
||||
await self.emit_entity_contexts(
|
||||
flow("entity-contexts"),
|
||||
v.metadata,
|
||||
entity_contexts
|
||||
)
|
||||
# Emit entity contexts in batches
|
||||
for i in range(0, len(entity_contexts), self.entity_batch_size):
|
||||
batch = entity_contexts[i:i + self.entity_batch_size]
|
||||
await self.emit_entity_contexts(
|
||||
flow("entity-contexts"),
|
||||
v.metadata,
|
||||
batch
|
||||
)
|
||||
|
||||
logger.info(f"Extracted {len(triples)} content triples + {len(ontology_triples)} ontology triples "
|
||||
f"= {len(all_triples)} total triples and {len(entity_contexts)} entity contexts")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OntoRAG extraction exception: {e}", exc_info=True)
|
||||
# Emit empty outputs on error
|
||||
await self.emit_triples(
|
||||
flow("triples"),
|
||||
v.metadata,
|
||||
[]
|
||||
)
|
||||
await self.emit_entity_contexts(
|
||||
flow("entity-contexts"),
|
||||
v.metadata,
|
||||
[]
|
||||
)
|
||||
|
||||
async def extract_with_simplified_format(
|
||||
self,
|
||||
|
|
@ -446,9 +440,9 @@ class Processor(FlowProcessor):
|
|||
is_object_uri = False
|
||||
|
||||
# Create Triple object with expanded URIs
|
||||
s_value = Value(value=subject_uri, is_uri=True)
|
||||
p_value = Value(value=predicate_uri, is_uri=True)
|
||||
o_value = Value(value=object_uri, is_uri=is_object_uri)
|
||||
s_value = make_term(subject_uri, is_uri=True)
|
||||
p_value = make_term(predicate_uri, is_uri=True)
|
||||
o_value = make_term(object_uri, is_uri=is_object_uri)
|
||||
|
||||
validated_triples.append(Triple(
|
||||
s=s_value,
|
||||
|
|
@ -609,9 +603,9 @@ class Processor(FlowProcessor):
|
|||
|
||||
# rdf:type owl:Class
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=class_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
|
||||
o=Value(value="http://www.w3.org/2002/07/owl#Class", is_uri=True)
|
||||
s=make_term(class_uri, is_uri=True),
|
||||
p=make_term("http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
|
||||
o=make_term("http://www.w3.org/2002/07/owl#Class", is_uri=True)
|
||||
))
|
||||
|
||||
# rdfs:label (stored as 'labels' in OntologyClass.__dict__)
|
||||
|
|
@ -620,18 +614,18 @@ class Processor(FlowProcessor):
|
|||
if isinstance(labels, list) and labels:
|
||||
label_val = labels[0].get('value', class_id) if isinstance(labels[0], dict) else str(labels[0])
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=class_uri, is_uri=True),
|
||||
p=Value(value=RDF_LABEL, is_uri=True),
|
||||
o=Value(value=label_val, is_uri=False)
|
||||
s=make_term(class_uri, is_uri=True),
|
||||
p=make_term(RDF_LABEL, is_uri=True),
|
||||
o=make_term(label_val, is_uri=False)
|
||||
))
|
||||
|
||||
# rdfs:comment (stored as 'comment' in OntologyClass.__dict__)
|
||||
if isinstance(class_def, dict) and 'comment' in class_def and class_def['comment']:
|
||||
comment = class_def['comment']
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=class_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True),
|
||||
o=Value(value=comment, is_uri=False)
|
||||
s=make_term(class_uri, is_uri=True),
|
||||
p=make_term("http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True),
|
||||
o=make_term(comment, is_uri=False)
|
||||
))
|
||||
|
||||
# rdfs:subClassOf (stored as 'subclass_of' in OntologyClass.__dict__)
|
||||
|
|
@ -648,9 +642,9 @@ class Processor(FlowProcessor):
|
|||
parent_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{parent}"
|
||||
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=class_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#subClassOf", is_uri=True),
|
||||
o=Value(value=parent_uri, is_uri=True)
|
||||
s=make_term(class_uri, is_uri=True),
|
||||
p=make_term("http://www.w3.org/2000/01/rdf-schema#subClassOf", is_uri=True),
|
||||
o=make_term(parent_uri, is_uri=True)
|
||||
))
|
||||
|
||||
# Generate triples for object properties
|
||||
|
|
@ -663,9 +657,9 @@ class Processor(FlowProcessor):
|
|||
|
||||
# rdf:type owl:ObjectProperty
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=prop_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
|
||||
o=Value(value="http://www.w3.org/2002/07/owl#ObjectProperty", is_uri=True)
|
||||
s=make_term(prop_uri, is_uri=True),
|
||||
p=make_term("http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
|
||||
o=make_term("http://www.w3.org/2002/07/owl#ObjectProperty", is_uri=True)
|
||||
))
|
||||
|
||||
# rdfs:label (stored as 'labels' in OntologyProperty.__dict__)
|
||||
|
|
@ -674,18 +668,18 @@ class Processor(FlowProcessor):
|
|||
if isinstance(labels, list) and labels:
|
||||
label_val = labels[0].get('value', prop_id) if isinstance(labels[0], dict) else str(labels[0])
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=prop_uri, is_uri=True),
|
||||
p=Value(value=RDF_LABEL, is_uri=True),
|
||||
o=Value(value=label_val, is_uri=False)
|
||||
s=make_term(prop_uri, is_uri=True),
|
||||
p=make_term(RDF_LABEL, is_uri=True),
|
||||
o=make_term(label_val, is_uri=False)
|
||||
))
|
||||
|
||||
# rdfs:comment (stored as 'comment' in OntologyProperty.__dict__)
|
||||
if isinstance(prop_def, dict) and 'comment' in prop_def and prop_def['comment']:
|
||||
comment = prop_def['comment']
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=prop_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True),
|
||||
o=Value(value=comment, is_uri=False)
|
||||
s=make_term(prop_uri, is_uri=True),
|
||||
p=make_term("http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True),
|
||||
o=make_term(comment, is_uri=False)
|
||||
))
|
||||
|
||||
# rdfs:domain (stored as 'domain' in OntologyProperty.__dict__)
|
||||
|
|
@ -702,9 +696,9 @@ class Processor(FlowProcessor):
|
|||
domain_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{domain}"
|
||||
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=prop_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#domain", is_uri=True),
|
||||
o=Value(value=domain_uri, is_uri=True)
|
||||
s=make_term(prop_uri, is_uri=True),
|
||||
p=make_term("http://www.w3.org/2000/01/rdf-schema#domain", is_uri=True),
|
||||
o=make_term(domain_uri, is_uri=True)
|
||||
))
|
||||
|
||||
# rdfs:range (stored as 'range' in OntologyProperty.__dict__)
|
||||
|
|
@ -721,9 +715,9 @@ class Processor(FlowProcessor):
|
|||
range_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{range_val}"
|
||||
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=prop_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#range", is_uri=True),
|
||||
o=Value(value=range_uri, is_uri=True)
|
||||
s=make_term(prop_uri, is_uri=True),
|
||||
p=make_term("http://www.w3.org/2000/01/rdf-schema#range", is_uri=True),
|
||||
o=make_term(range_uri, is_uri=True)
|
||||
))
|
||||
|
||||
# Generate triples for datatype properties
|
||||
|
|
@ -736,9 +730,9 @@ class Processor(FlowProcessor):
|
|||
|
||||
# rdf:type owl:DatatypeProperty
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=prop_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
|
||||
o=Value(value="http://www.w3.org/2002/07/owl#DatatypeProperty", is_uri=True)
|
||||
s=make_term(prop_uri, is_uri=True),
|
||||
p=make_term("http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
|
||||
o=make_term("http://www.w3.org/2002/07/owl#DatatypeProperty", is_uri=True)
|
||||
))
|
||||
|
||||
# rdfs:label (stored as 'labels' in OntologyProperty.__dict__)
|
||||
|
|
@ -747,18 +741,18 @@ class Processor(FlowProcessor):
|
|||
if isinstance(labels, list) and labels:
|
||||
label_val = labels[0].get('value', prop_id) if isinstance(labels[0], dict) else str(labels[0])
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=prop_uri, is_uri=True),
|
||||
p=Value(value=RDF_LABEL, is_uri=True),
|
||||
o=Value(value=label_val, is_uri=False)
|
||||
s=make_term(prop_uri, is_uri=True),
|
||||
p=make_term(RDF_LABEL, is_uri=True),
|
||||
o=make_term(label_val, is_uri=False)
|
||||
))
|
||||
|
||||
# rdfs:comment (stored as 'comment' in OntologyProperty.__dict__)
|
||||
if isinstance(prop_def, dict) and 'comment' in prop_def and prop_def['comment']:
|
||||
comment = prop_def['comment']
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=prop_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True),
|
||||
o=Value(value=comment, is_uri=False)
|
||||
s=make_term(prop_uri, is_uri=True),
|
||||
p=make_term("http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True),
|
||||
o=make_term(comment, is_uri=False)
|
||||
))
|
||||
|
||||
# rdfs:domain (stored as 'domain' in OntologyProperty.__dict__)
|
||||
|
|
@ -775,9 +769,9 @@ class Processor(FlowProcessor):
|
|||
domain_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{domain}"
|
||||
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=prop_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#domain", is_uri=True),
|
||||
o=Value(value=domain_uri, is_uri=True)
|
||||
s=make_term(prop_uri, is_uri=True),
|
||||
p=make_term("http://www.w3.org/2000/01/rdf-schema#domain", is_uri=True),
|
||||
o=make_term(domain_uri, is_uri=True)
|
||||
))
|
||||
|
||||
# rdfs:range (datatype)
|
||||
|
|
@ -790,9 +784,9 @@ class Processor(FlowProcessor):
|
|||
range_uri = range_val
|
||||
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=prop_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#range", is_uri=True),
|
||||
o=Value(value=range_uri, is_uri=True)
|
||||
s=make_term(prop_uri, is_uri=True),
|
||||
p=make_term("http://www.w3.org/2000/01/rdf-schema#range", is_uri=True),
|
||||
o=make_term(range_uri, is_uri=True)
|
||||
))
|
||||
|
||||
logger.info(f"Generated {len(ontology_triples)} triples describing ontology elements")
|
||||
|
|
@ -814,9 +808,9 @@ class Processor(FlowProcessor):
|
|||
entity_data = {} # subject_uri -> {labels: [], definitions: []}
|
||||
|
||||
for triple in triples:
|
||||
subject_uri = triple.s.value
|
||||
predicate_uri = triple.p.value
|
||||
object_val = triple.o.value
|
||||
subject_uri = triple.s.iri if triple.s.type == IRI else triple.s.value
|
||||
predicate_uri = triple.p.iri if triple.p.type == IRI else triple.p.value
|
||||
object_val = triple.o.value if triple.o.type == LITERAL else triple.o.iri
|
||||
|
||||
# Initialize entity data if not exists
|
||||
if subject_uri not in entity_data:
|
||||
|
|
@ -824,12 +818,12 @@ class Processor(FlowProcessor):
|
|||
|
||||
# Collect labels (rdfs:label)
|
||||
if predicate_uri == RDF_LABEL:
|
||||
if not triple.o.is_uri: # Labels are literals
|
||||
if triple.o.type == LITERAL: # Labels are literals
|
||||
entity_data[subject_uri]['labels'].append(object_val)
|
||||
|
||||
# Collect definitions (skos:definition, schema:description)
|
||||
elif predicate_uri == DEFINITION or predicate_uri == "https://schema.org/description":
|
||||
if not triple.o.is_uri:
|
||||
if triple.o.type == LITERAL:
|
||||
entity_data[subject_uri]['definitions'].append(object_val)
|
||||
|
||||
# Build EntityContext objects
|
||||
|
|
@ -848,7 +842,7 @@ class Processor(FlowProcessor):
|
|||
if context_parts:
|
||||
context_text = ". ".join(context_parts)
|
||||
entity_contexts.append(EntityContext(
|
||||
entity=Value(value=subject_uri, is_uri=True),
|
||||
entity=make_term(subject_uri, is_uri=True),
|
||||
context=context_text
|
||||
))
|
||||
|
||||
|
|
@ -876,6 +870,18 @@ class Processor(FlowProcessor):
|
|||
default=0.3,
|
||||
help='Similarity threshold for ontology matching (default: 0.3, range: 0.0-1.0)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--triples-batch-size',
|
||||
type=int,
|
||||
default=default_triples_batch_size,
|
||||
help=f'Maximum triples per output message (default: {default_triples_batch_size})'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--entity-batch-size',
|
||||
type=int,
|
||||
default=default_entity_batch_size,
|
||||
help=f'Maximum entity contexts per output message (default: {default_entity_batch_size})'
|
||||
)
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -49,8 +49,17 @@ class ExtractionResult:
|
|||
def parse_extraction_response(response: Any) -> Optional[ExtractionResult]:
|
||||
"""Parse LLM extraction response into structured format.
|
||||
|
||||
Supports two formats:
|
||||
1. JSONL format (list): Flat list of objects with 'type' discriminator field
|
||||
[{"type": "entity", ...}, {"type": "relationship", ...}, {"type": "attribute", ...}]
|
||||
2. Legacy format (dict): Nested structure with separate arrays
|
||||
{"entities": [...], "relationships": [...], "attributes": [...]}
|
||||
|
||||
Args:
|
||||
response: LLM response (string JSON or already parsed dict)
|
||||
response: LLM response - can be:
|
||||
- string (JSON to parse)
|
||||
- dict (legacy nested format)
|
||||
- list (JSONL format - flat list with type discriminators)
|
||||
|
||||
Returns:
|
||||
ExtractionResult with parsed entities/relationships/attributes,
|
||||
|
|
@ -64,17 +73,89 @@ def parse_extraction_response(response: Any) -> Optional[ExtractionResult]:
|
|||
logger.error(f"Failed to parse JSON response: {e}")
|
||||
logger.debug(f"Response was: {response[:500]}")
|
||||
return None
|
||||
elif isinstance(response, dict):
|
||||
elif isinstance(response, (dict, list)):
|
||||
data = response
|
||||
else:
|
||||
logger.error(f"Unexpected response type: {type(response)}")
|
||||
return None
|
||||
|
||||
# Validate structure
|
||||
if not isinstance(data, dict):
|
||||
logger.error(f"Expected dict, got {type(data)}")
|
||||
return None
|
||||
# Handle JSONL format (flat list with type discriminators)
|
||||
if isinstance(data, list):
|
||||
return parse_jsonl_format(data)
|
||||
|
||||
# Handle legacy format (nested dict)
|
||||
if isinstance(data, dict):
|
||||
return parse_legacy_format(data)
|
||||
|
||||
logger.error(f"Expected dict or list, got {type(data)}")
|
||||
return None
|
||||
|
||||
|
||||
def parse_jsonl_format(data: List[Dict[str, Any]]) -> ExtractionResult:
|
||||
"""Parse JSONL format response (flat list with type discriminators).
|
||||
|
||||
Each item has a 'type' field: 'entity', 'relationship', or 'attribute'.
|
||||
|
||||
Args:
|
||||
data: List of dicts with type discriminator
|
||||
|
||||
Returns:
|
||||
ExtractionResult with categorized items
|
||||
"""
|
||||
entities = []
|
||||
relationships = []
|
||||
attributes = []
|
||||
|
||||
for item in data:
|
||||
if not isinstance(item, dict):
|
||||
logger.warning(f"Skipping non-dict item: {type(item)}")
|
||||
continue
|
||||
|
||||
item_type = item.get('type')
|
||||
|
||||
if item_type == 'entity':
|
||||
try:
|
||||
entity = parse_entity_jsonl(item)
|
||||
if entity:
|
||||
entities.append(entity)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse entity {item}: {e}")
|
||||
|
||||
elif item_type == 'relationship':
|
||||
try:
|
||||
relationship = parse_relationship(item)
|
||||
if relationship:
|
||||
relationships.append(relationship)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse relationship {item}: {e}")
|
||||
|
||||
elif item_type == 'attribute':
|
||||
try:
|
||||
attribute = parse_attribute(item)
|
||||
if attribute:
|
||||
attributes.append(attribute)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse attribute {item}: {e}")
|
||||
|
||||
else:
|
||||
logger.warning(f"Unknown item type '{item_type}': {item}")
|
||||
|
||||
return ExtractionResult(
|
||||
entities=entities,
|
||||
relationships=relationships,
|
||||
attributes=attributes
|
||||
)
|
||||
|
||||
|
||||
def parse_legacy_format(data: Dict[str, Any]) -> ExtractionResult:
|
||||
"""Parse legacy format response (nested dict with arrays).
|
||||
|
||||
Args:
|
||||
data: Dict with 'entities', 'relationships', 'attributes' arrays
|
||||
|
||||
Returns:
|
||||
ExtractionResult with parsed items
|
||||
"""
|
||||
# Parse entities
|
||||
entities = []
|
||||
entities_data = data.get('entities', [])
|
||||
|
|
@ -127,6 +208,37 @@ def parse_extraction_response(response: Any) -> Optional[ExtractionResult]:
|
|||
)
|
||||
|
||||
|
||||
def parse_entity_jsonl(data: Dict[str, Any]) -> Optional[Entity]:
|
||||
"""Parse entity from JSONL format dict.
|
||||
|
||||
JSONL format uses 'entity_type' instead of 'type' for the entity's type
|
||||
(since 'type' is the discriminator field).
|
||||
|
||||
Args:
|
||||
data: Entity dict with 'entity' and 'entity_type' fields
|
||||
|
||||
Returns:
|
||||
Entity object or None if invalid
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
logger.warning(f"Entity data is not a dict: {type(data)}")
|
||||
return None
|
||||
|
||||
entity = data.get('entity')
|
||||
# JSONL format uses 'entity_type' since 'type' is the discriminator
|
||||
entity_type = data.get('entity_type')
|
||||
|
||||
if not entity or not entity_type:
|
||||
logger.warning(f"Missing required fields in entity: {data}")
|
||||
return None
|
||||
|
||||
if not isinstance(entity, str) or not isinstance(entity_type, str):
|
||||
logger.warning(f"Entity fields must be strings: {data}")
|
||||
return None
|
||||
|
||||
return Entity(entity=entity, type=entity_type)
|
||||
|
||||
|
||||
def parse_entity(data: Dict[str, Any]) -> Optional[Entity]:
|
||||
"""Parse entity from dict.
|
||||
|
||||
|
|
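An illustrative payload for the JSONL branch of parse_extraction_response()/parse_jsonl_format() above. Entities carry entity_type because type is the discriminator; relationship and attribute items use the same discriminator but their remaining fields are handled by parse_relationship()/parse_attribute(), which this diff does not show, so they are omitted here.

    items = [
        {"type": "entity", "entity": "Cornish pasty", "entity_type": "Recipe"},
        {"type": "entity", "entity": "beef", "entity_type": "Food"},
        # relationship / attribute items use the same "type" discriminator
    ]
    result = parse_jsonl_format(items)
    # result.entities -> [Entity(entity="Cornish pasty", type="Recipe"), ...]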
|
|||
|
|
@ -8,7 +8,7 @@ with full URIs and correct is_uri flags.
|
|||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from .... schema import Triple, Value
|
||||
from .... schema import Triple, Term, IRI, LITERAL
|
||||
from .... rdf import RDF_TYPE, RDF_LABEL
|
||||
|
||||
from .simplified_parser import Entity, Relationship, Attribute, ExtractionResult
|
||||
|
|
@ -87,17 +87,17 @@ class TripleConverter:
|
|||
|
||||
# Generate type triple: entity rdf:type ClassURI
|
||||
type_triple = Triple(
|
||||
s=Value(value=entity_uri, is_uri=True),
|
||||
p=Value(value=RDF_TYPE, is_uri=True),
|
||||
o=Value(value=class_uri, is_uri=True)
|
||||
s=Term(type=IRI, iri=entity_uri),
|
||||
p=Term(type=IRI, iri=RDF_TYPE),
|
||||
o=Term(type=IRI, iri=class_uri)
|
||||
)
|
||||
triples.append(type_triple)
|
||||
|
||||
# Generate label triple: entity rdfs:label "entity name"
|
||||
label_triple = Triple(
|
||||
s=Value(value=entity_uri, is_uri=True),
|
||||
p=Value(value=RDF_LABEL, is_uri=True),
|
||||
o=Value(value=entity.entity, is_uri=False) # Literal!
|
||||
s=Term(type=IRI, iri=entity_uri),
|
||||
p=Term(type=IRI, iri=RDF_LABEL),
|
||||
o=Term(type=LITERAL, value=entity.entity) # Literal!
|
||||
)
|
||||
triples.append(label_triple)
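
A minimal before/after sketch of the schema migration this hunk applies: the old Value carried an is_uri flag, while the new Term is discriminated by type, with IRI nodes populating iri and LITERAL nodes populating value. The entity URI below is illustrative only.

# Old representation
Triple(
    s=Value(value="http://trustgraph.ai/e/paris", is_uri=True),
    p=Value(value=RDF_LABEL, is_uri=True),
    o=Value(value="Paris", is_uri=False),
)

# New representation
Triple(
    s=Term(type=IRI, iri="http://trustgraph.ai/e/paris"),
    p=Term(type=IRI, iri=RDF_LABEL),
    o=Term(type=LITERAL, value="Paris"),
)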
|
||||
|
||||
|
|
@ -131,9 +131,9 @@ class TripleConverter:
|
|||
|
||||
# Generate triple: subject property object
|
||||
return Triple(
|
||||
s=Value(value=subject_uri, is_uri=True),
|
||||
p=Value(value=property_uri, is_uri=True),
|
||||
o=Value(value=object_uri, is_uri=True)
|
||||
s=Term(type=IRI, iri=subject_uri),
|
||||
p=Term(type=IRI, iri=property_uri),
|
||||
o=Term(type=IRI, iri=object_uri)
|
||||
)
|
||||
|
||||
def convert_attribute(self, attribute: Attribute) -> Optional[Triple]:
|
||||
|
|
@ -159,9 +159,9 @@ class TripleConverter:
|
|||
|
||||
# Generate triple: entity property "literal value"
|
||||
return Triple(
|
||||
s=Value(value=entity_uri, is_uri=True),
|
||||
p=Value(value=property_uri, is_uri=True),
|
||||
o=Value(value=attribute.value, is_uri=False) # Literal!
|
||||
s=Term(type=IRI, iri=entity_uri),
|
||||
p=Term(type=IRI, iri=property_uri),
|
||||
o=Term(type=LITERAL, value=attribute.value) # Literal!
|
||||
)
|
||||
|
||||
def _get_class_uri(self, class_id: str) -> Optional[str]:
|
||||
|
|
|
|||
|
|
@ -13,18 +13,19 @@ import urllib.parse
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .... schema import Chunk, Triple, Triples
|
||||
from .... schema import Metadata, Value
|
||||
from .... schema import Metadata, Term, IRI, LITERAL
|
||||
from .... schema import PromptRequest, PromptResponse
|
||||
from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES, SUBJECT_OF
|
||||
|
||||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from .... base import PromptClientSpec
|
||||
|
||||
RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
|
||||
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
|
||||
RDF_LABEL_VALUE = Term(type=IRI, iri=RDF_LABEL)
|
||||
SUBJECT_OF_VALUE = Term(type=IRI, iri=SUBJECT_OF)
|
||||
|
||||
default_ident = "kg-extract-relationships"
|
||||
default_concurrency = 1
|
||||
default_triples_batch_size = 50
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
|
|
@ -32,6 +33,7 @@ class Processor(FlowProcessor):
|
|||
|
||||
id = params.get("id")
|
||||
concurrency = params.get("concurrency", 1)
|
||||
self.triples_batch_size = params.get("triples_batch_size", default_triples_batch_size)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
|
|
@ -127,16 +129,16 @@ class Processor(FlowProcessor):
|
|||
if o is None: continue
|
||||
|
||||
s_uri = self.to_uri(s)
|
||||
s_value = Value(value=str(s_uri), is_uri=True)
|
||||
s_value = Term(type=IRI, iri=str(s_uri))
|
||||
|
||||
p_uri = self.to_uri(p)
|
||||
p_value = Value(value=str(p_uri), is_uri=True)
|
||||
p_value = Term(type=IRI, iri=str(p_uri))
|
||||
|
||||
if rel["object-entity"]:
|
||||
o_uri = self.to_uri(o)
|
||||
o_value = Value(value=str(o_uri), is_uri=True)
|
||||
o_value = Term(type=IRI, iri=str(o_uri))
|
||||
else:
|
||||
o_value = Value(value=str(o), is_uri=False)
|
||||
o_value = Term(type=LITERAL, value=str(o))
|
||||
|
||||
triples.append(Triple(
|
||||
s=s_value,
|
||||
|
|
@ -148,14 +150,14 @@ class Processor(FlowProcessor):
|
|||
triples.append(Triple(
|
||||
s=s_value,
|
||||
p=RDF_LABEL_VALUE,
|
||||
o=Value(value=str(s), is_uri=False)
|
||||
o=Term(type=LITERAL, value=str(s))
|
||||
))
|
||||
|
||||
# Label for p
|
||||
triples.append(Triple(
|
||||
s=p_value,
|
||||
p=RDF_LABEL_VALUE,
|
||||
o=Value(value=str(p), is_uri=False)
|
||||
o=Term(type=LITERAL, value=str(p))
|
||||
))
|
||||
|
||||
if rel["object-entity"]:
|
||||
|
|
@ -163,14 +165,14 @@ class Processor(FlowProcessor):
|
|||
triples.append(Triple(
|
||||
s=o_value,
|
||||
p=RDF_LABEL_VALUE,
|
||||
o=Value(value=str(o), is_uri=False)
|
||||
o=Term(type=LITERAL, value=str(o))
|
||||
))
|
||||
|
||||
# 'Subject of' for s
|
||||
triples.append(Triple(
|
||||
s=s_value,
|
||||
p=SUBJECT_OF_VALUE,
|
||||
o=Value(value=v.metadata.id, is_uri=True)
|
||||
o=Term(type=IRI, iri=v.metadata.id)
|
||||
))
|
||||
|
||||
if rel["object-entity"]:
|
||||
|
|
@ -178,19 +180,22 @@ class Processor(FlowProcessor):
|
|||
triples.append(Triple(
|
||||
s=o_value,
|
||||
p=SUBJECT_OF_VALUE,
|
||||
o=Value(value=v.metadata.id, is_uri=True)
|
||||
o=Term(type=IRI, iri=v.metadata.id)
|
||||
))
|
||||
|
||||
await self.emit_triples(
|
||||
flow("triples"),
|
||||
Metadata(
|
||||
id=v.metadata.id,
|
||||
metadata=[],
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
triples
|
||||
)
|
||||
# Send triples in batches
|
||||
for i in range(0, len(triples), self.triples_batch_size):
|
||||
batch = triples[i:i + self.triples_batch_size]
|
||||
await self.emit_triples(
|
||||
flow("triples"),
|
||||
Metadata(
|
||||
id=v.metadata.id,
|
||||
metadata=[],
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
batch
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Relationship extraction exception: {e}", exc_info=True)
|
||||
|
|
@ -207,6 +212,13 @@ class Processor(FlowProcessor):
|
|||
help=f'Concurrent processing threads (default: {default_concurrency})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--triples-batch-size',
|
||||
type=int,
|
||||
default=default_triples_batch_size,
|
||||
help=f'Maximum triples per output message (default: {default_triples_batch_size})'
|
||||
)
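
The batching itself is ordinary list slicing; a standalone sketch of the pattern used in the emit loop above, with an illustrative input size:

triples = list(range(125))            # stand-in for the accumulated Triple objects
triples_batch_size = 50               # the default documented above

batches = [
    triples[i:i + triples_batch_size]
    for i in range(0, len(triples), triples_batch_size)
]
# -> three batches of 50, 50 and 25 items, each emitted as a separate Triples message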
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
def run():
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
Object extraction service - extracts structured objects from text chunks
|
||||
Row extraction service - extracts structured rows from text chunks
|
||||
based on configured schemas.
|
||||
"""
|
||||
|
||||
|
|
@ -18,7 +18,7 @@ from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
|||
from .... base import PromptClientSpec
|
||||
from .... messaging.translators import row_schema_translator
|
||||
|
||||
default_ident = "kg-extract-objects"
|
||||
default_ident = "kg-extract-rows"
|
||||
|
||||
|
||||
def convert_values_to_strings(obj: Dict[str, Any]) -> Dict[str, str]:
|
||||
|
|
@ -310,5 +310,5 @@ class Processor(FlowProcessor):
|
|||
FlowProcessor.add_args(parser)
|
||||
|
||||
def run():
|
||||
"""Entry point for kg-extract-objects command"""
|
||||
"""Entry point for kg-extract-rows command"""
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
@ -11,7 +11,7 @@ import logging
|
|||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .... schema import Chunk, Triple, Triples, Metadata, Value
|
||||
from .... schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL
|
||||
from .... schema import chunk_ingest_queue, triples_store_queue
|
||||
from .... schema import prompt_request_queue
|
||||
from .... schema import prompt_response_queue
|
||||
|
|
@ -20,7 +20,7 @@ from .... clients.prompt_client import PromptClient
|
|||
from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
|
||||
DEFINITION_VALUE = Term(type=IRI, iri=DEFINITION)
|
||||
|
||||
module = "kg-extract-topics"
|
||||
|
||||
|
|
@ -106,8 +106,8 @@ class Processor(ConsumerProducer):
|
|||
|
||||
s_uri = self.to_uri(s)
|
||||
|
||||
s_value = Value(value=str(s_uri), is_uri=True)
|
||||
o_value = Value(value=str(o), is_uri=False)
|
||||
s_value = Term(type=IRI, iri=str(s_uri))
|
||||
o_value = Term(type=LITERAL, value=str(o))
|
||||
|
||||
await self.emit_edge(
|
||||
v.metadata, s_value, DEFINITION_VALUE, o_value
|
||||
|
|
|
|||
|
|
@ -0,0 +1,31 @@
|
|||
|
||||
from ... schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class DocumentEmbeddingsQueryRequestor(ServiceRequestor):
|
||||
def __init__(
|
||||
self, backend, request_queue, response_queue, timeout,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
super(DocumentEmbeddingsQueryRequestor, self).__init__(
|
||||
backend=backend,
|
||||
request_queue=request_queue,
|
||||
response_queue=response_queue,
|
||||
request_schema=DocumentEmbeddingsRequest,
|
||||
response_schema=DocumentEmbeddingsResponse,
|
||||
subscription = subscriber,
|
||||
consumer_name = consumer,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
self.request_translator = TranslatorRegistry.get_request_translator("document-embeddings-query")
|
||||
self.response_translator = TranslatorRegistry.get_response_translator("document-embeddings-query")
|
||||
|
||||
def to_request(self, body):
|
||||
return self.request_translator.to_pulsar(body)
|
||||
|
||||
def from_response(self, message):
|
||||
return self.response_translator.from_response_with_completion(message)
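
Each requestor added here follows the same shape: fetch a request/response translator pair by service name, translate the websocket body into the Pulsar schema on the way in, and translate the Pulsar response back on the way out. A hedged sketch, where the body fields and the response handling are assumptions rather than the documented request format:

req_tx = TranslatorRegistry.get_request_translator("document-embeddings-query")
resp_tx = TranslatorRegistry.get_response_translator("document-embeddings-query")

pulsar_request = req_tx.to_pulsar({"vectors": [[0.1, 0.2, 0.3]], "limit": 10})  # field names assumed
# ... publish pulsar_request, await the response message ...
result = resp_tx.from_response_with_completion(response_message)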
|
||||
|
|
@ -20,12 +20,14 @@ from . prompt import PromptRequestor
|
|||
from . graph_rag import GraphRagRequestor
|
||||
from . document_rag import DocumentRagRequestor
|
||||
from . triples_query import TriplesQueryRequestor
|
||||
from . objects_query import ObjectsQueryRequestor
|
||||
from . rows_query import RowsQueryRequestor
|
||||
from . nlp_query import NLPQueryRequestor
|
||||
from . structured_query import StructuredQueryRequestor
|
||||
from . structured_diag import StructuredDiagRequestor
|
||||
from . embeddings import EmbeddingsRequestor
|
||||
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
|
||||
from . document_embeddings_query import DocumentEmbeddingsQueryRequestor
|
||||
from . row_embeddings_query import RowEmbeddingsQueryRequestor
|
||||
from . mcp_tool import McpToolRequestor
|
||||
from . text_load import TextLoad
|
||||
from . document_load import DocumentLoad
|
||||
|
|
@ -39,7 +41,7 @@ from . triples_import import TriplesImport
|
|||
from . graph_embeddings_import import GraphEmbeddingsImport
|
||||
from . document_embeddings_import import DocumentEmbeddingsImport
|
||||
from . entity_contexts_import import EntityContextsImport
|
||||
from . objects_import import ObjectsImport
|
||||
from . rows_import import RowsImport
|
||||
|
||||
from . core_export import CoreExport
|
||||
from . core_import import CoreImport
|
||||
|
|
@ -55,11 +57,13 @@ request_response_dispatchers = {
|
|||
"document-rag": DocumentRagRequestor,
|
||||
"embeddings": EmbeddingsRequestor,
|
||||
"graph-embeddings": GraphEmbeddingsQueryRequestor,
|
||||
"document-embeddings": DocumentEmbeddingsQueryRequestor,
|
||||
"triples": TriplesQueryRequestor,
|
||||
"objects": ObjectsQueryRequestor,
|
||||
"rows": RowsQueryRequestor,
|
||||
"nlp-query": NLPQueryRequestor,
|
||||
"structured-query": StructuredQueryRequestor,
|
||||
"structured-diag": StructuredDiagRequestor,
|
||||
"row-embeddings": RowEmbeddingsQueryRequestor,
|
||||
}
|
||||
|
||||
global_dispatchers = {
|
||||
|
|
@ -87,7 +91,7 @@ import_dispatchers = {
|
|||
"graph-embeddings": GraphEmbeddingsImport,
|
||||
"document-embeddings": DocumentEmbeddingsImport,
|
||||
"entity-contexts": EntityContextsImport,
|
||||
"objects": ObjectsImport,
|
||||
"rows": RowsImport,
|
||||
}
|
||||
|
||||
class DispatcherWrapper:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,31 @@
|
|||
|
||||
from ... schema import RowEmbeddingsRequest, RowEmbeddingsResponse
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class RowEmbeddingsQueryRequestor(ServiceRequestor):
|
||||
def __init__(
|
||||
self, backend, request_queue, response_queue, timeout,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
super(RowEmbeddingsQueryRequestor, self).__init__(
|
||||
backend=backend,
|
||||
request_queue=request_queue,
|
||||
response_queue=response_queue,
|
||||
request_schema=RowEmbeddingsRequest,
|
||||
response_schema=RowEmbeddingsResponse,
|
||||
subscription = subscriber,
|
||||
consumer_name = consumer,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
self.request_translator = TranslatorRegistry.get_request_translator("row-embeddings-query")
|
||||
self.response_translator = TranslatorRegistry.get_response_translator("row-embeddings-query")
|
||||
|
||||
def to_request(self, body):
|
||||
return self.request_translator.to_pulsar(body)
|
||||
|
||||
def from_response(self, message):
|
||||
return self.response_translator.from_response_with_completion(message)
|
||||
|
|
@ -12,7 +12,7 @@ from . serialize import to_subgraph
|
|||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ObjectsImport:
|
||||
class RowsImport:
|
||||
|
||||
def __init__(
|
||||
self, ws, running, backend, queue
|
||||
|
|
@ -20,7 +20,7 @@ class ObjectsImport:
|
|||
|
||||
self.ws = ws
|
||||
self.running = running
|
||||
|
||||
|
||||
self.publisher = Publisher(
|
||||
backend, topic = queue, schema = ExtractedObject
|
||||
)
|
||||
|
|
@ -73,4 +73,4 @@ class ObjectsImport:
|
|||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
self.ws = None
|
||||
self.ws = None
|
||||
|
|
@ -1,30 +1,30 @@
|
|||
from ... schema import ObjectsQueryRequest, ObjectsQueryResponse
|
||||
from ... schema import RowsQueryRequest, RowsQueryResponse
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class ObjectsQueryRequestor(ServiceRequestor):
|
||||
class RowsQueryRequestor(ServiceRequestor):
|
||||
def __init__(
|
||||
self, backend, request_queue, response_queue, timeout,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
super(ObjectsQueryRequestor, self).__init__(
|
||||
super(RowsQueryRequestor, self).__init__(
|
||||
backend=backend,
|
||||
request_queue=request_queue,
|
||||
response_queue=response_queue,
|
||||
request_schema=ObjectsQueryRequest,
|
||||
response_schema=ObjectsQueryResponse,
|
||||
request_schema=RowsQueryRequest,
|
||||
response_schema=RowsQueryResponse,
|
||||
subscription = subscriber,
|
||||
consumer_name = consumer,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
self.request_translator = TranslatorRegistry.get_request_translator("objects-query")
|
||||
self.response_translator = TranslatorRegistry.get_response_translator("objects-query")
|
||||
self.request_translator = TranslatorRegistry.get_request_translator("rows-query")
|
||||
self.response_translator = TranslatorRegistry.get_response_translator("rows-query")
|
||||
|
||||
def to_request(self, body):
|
||||
return self.request_translator.to_pulsar(body)
|
||||
|
||||
def from_response(self, message):
|
||||
return self.response_translator.from_response_with_completion(message)
|
||||
return self.response_translator.from_response_with_completion(message)
|
||||
|
|
@ -1,46 +1,37 @@
|
|||
|
||||
import base64
|
||||
|
||||
from ... schema import Value, Triple, DocumentMetadata, ProcessingMetadata
|
||||
from ... schema import Term, Triple, DocumentMetadata, ProcessingMetadata
|
||||
from ... messaging.translators.primitives import TermTranslator, TripleTranslator
|
||||
|
||||
# Singleton translator instances
|
||||
_term_translator = TermTranslator()
|
||||
_triple_translator = TripleTranslator()
|
||||
|
||||
# DEPRECATED: These functions have been moved to trustgraph.... messaging.translators
|
||||
# Use the new messaging translation system instead for consistency and reusability.
|
||||
# Examples:
|
||||
# from trustgraph.... messaging.translators.primitives import ValueTranslator
|
||||
# value_translator = ValueTranslator()
|
||||
# pulsar_value = value_translator.to_pulsar({"v": "example", "e": True})
|
||||
|
||||
def to_value(x):
|
||||
return Value(value=x["v"], is_uri=x["e"])
|
||||
"""Convert dict to Term. Delegates to TermTranslator."""
|
||||
return _term_translator.to_pulsar(x)
|
||||
|
||||
|
||||
def to_subgraph(x):
|
||||
return [
|
||||
Triple(
|
||||
s=to_value(t["s"]),
|
||||
p=to_value(t["p"]),
|
||||
o=to_value(t["o"])
|
||||
)
|
||||
for t in x
|
||||
]
|
||||
"""Convert list of dicts to list of Triples. Delegates to TripleTranslator."""
|
||||
return [_triple_translator.to_pulsar(t) for t in x]
|
||||
|
||||
|
||||
def serialize_value(v):
|
||||
return {
|
||||
"v": v.value,
|
||||
"e": v.is_uri,
|
||||
}
|
||||
"""Convert Term to dict. Delegates to TermTranslator."""
|
||||
return _term_translator.from_pulsar(v)
|
||||
|
||||
|
||||
def serialize_triple(t):
|
||||
return {
|
||||
"s": serialize_value(t.s),
|
||||
"p": serialize_value(t.p),
|
||||
"o": serialize_value(t.o)
|
||||
}
|
||||
"""Convert Triple to dict. Delegates to TripleTranslator."""
|
||||
return _triple_translator.from_pulsar(t)
|
||||
|
||||
|
||||
def serialize_subgraph(sg):
|
||||
return [
|
||||
serialize_triple(t)
|
||||
for t in sg
|
||||
]
|
||||
"""Convert list of Triples to list of dicts."""
|
||||
return [serialize_triple(t) for t in sg]
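
A small round-trip through these shims; the exact dict layout is now defined by TermTranslator rather than the old {"v": ..., "e": ...} shape, so it is deliberately not spelled out here:

term = Term(type=LITERAL, value="Paris")
wire = serialize_value(term)   # Term -> plain dict, via TermTranslator
again = to_value(wire)         # plain dict -> Term, via TermTranslator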
|
||||
|
||||
def serialize_triples(message):
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -75,6 +75,7 @@ class Processor(LlmService):
|
|||
|
||||
if stream:
|
||||
data["stream"] = True
|
||||
data["stream_options"] = {"include_usage": True}
|
||||
|
||||
body = json.dumps(data)
|
||||
|
||||
|
|
@ -191,6 +192,9 @@ class Processor(LlmService):
|
|||
if response.status_code != 200:
|
||||
raise RuntimeError("LLM failure")
|
||||
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
|
||||
# Parse SSE stream
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
|
|
@ -215,15 +219,21 @@ class Processor(LlmService):
|
|||
model=model_name,
|
||||
is_final=False
|
||||
)
|
||||
|
||||
# Capture usage from final chunk
|
||||
if 'usage' in chunk_data and chunk_data['usage']:
|
||||
total_input_tokens = chunk_data['usage'].get('prompt_tokens', 0)
|
||||
total_output_tokens = chunk_data['usage'].get('completion_tokens', 0)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse chunk: {data}")
|
||||
continue
|
||||
|
||||
# Send final chunk
|
||||
# Send final chunk with token counts
|
||||
yield LlmChunk(
|
||||
text="",
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
in_token=total_input_tokens,
|
||||
out_token=total_output_tokens,
|
||||
model=model_name,
|
||||
is_final=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -161,9 +161,13 @@ class Processor(LlmService):
|
|||
temperature=effective_temperature,
|
||||
max_tokens=self.max_output,
|
||||
top_p=1,
|
||||
stream=True # Enable streaming
|
||||
stream=True,
|
||||
stream_options={"include_usage": True}
|
||||
)
|
||||
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
|
||||
# Stream chunks
|
||||
for chunk in response:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
|
|
@ -175,11 +179,16 @@ class Processor(LlmService):
|
|||
is_final=False
|
||||
)
|
||||
|
||||
# Send final chunk
|
||||
# Capture usage from final chunk
|
||||
if chunk.usage:
|
||||
total_input_tokens = chunk.usage.prompt_tokens
|
||||
total_output_tokens = chunk.usage.completion_tokens
|
||||
|
||||
# Send final chunk with token counts
|
||||
yield LlmChunk(
|
||||
text="",
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
in_token=total_input_tokens,
|
||||
out_token=total_output_tokens,
|
||||
model=model_name,
|
||||
is_final=True
|
||||
)
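
The pattern relied on here is the OpenAI-compatible one: with stream_options={"include_usage": True}, content arrives in the delta chunks and a usage object is attached only to the final chunk. A minimal standalone sketch, assuming a recent openai Python SDK and an illustrative model name:

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
stream = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
    stream_options={"include_usage": True},
)

in_tok = out_tok = 0
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
    if chunk.usage:  # populated only on the final chunk
        in_tok = chunk.usage.prompt_tokens
        out_tok = chunk.usage.completion_tokens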
|
||||
|
|
|
|||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . llm import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -1,257 +0,0 @@
|
|||
|
||||
"""
|
||||
Simple LLM service, performs text prompt completion using GoogleAIStudio.
|
||||
Input is prompt, output is response.
|
||||
"""
|
||||
|
||||
#
|
||||
# Using this SDK:
|
||||
# https://googleapis.github.io/python-genai/genai.html#module-genai.client
|
||||
#
|
||||
# Seems to have simpler dependencies on the 'VertexAI' service, which
|
||||
# TrustGraph implements in the trustgraph-vertexai package.
|
||||
#
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from google.genai.types import HarmCategory, HarmBlockThreshold
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
import os
|
||||
import logging
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .... exceptions import TooManyRequests
|
||||
from .... base import LlmService, LlmResult, LlmChunk
|
||||
|
||||
default_ident = "text-completion"
|
||||
|
||||
default_model = 'gemini-2.0-flash-001'
|
||||
default_temperature = 0.0
|
||||
default_max_output = 8192
|
||||
default_api_key = os.getenv("GOOGLE_AI_STUDIO_KEY")
|
||||
|
||||
class Processor(LlmService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
model = params.get("model", default_model)
|
||||
api_key = params.get("api_key", default_api_key)
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
|
||||
if api_key is None:
|
||||
raise RuntimeError("Google AI Studio API key not specified")
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"model": model,
|
||||
"temperature": temperature,
|
||||
"max_output": max_output,
|
||||
}
|
||||
)
|
||||
|
||||
self.client = genai.Client(api_key=api_key)
|
||||
self.default_model = model
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
|
||||
# Cache for generation configs per model
|
||||
self.generation_configs = {}
|
||||
|
||||
block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
|
||||
|
||||
self.safety_settings = [
|
||||
types.SafetySetting(
|
||||
category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||
threshold = block_level,
|
||||
),
|
||||
types.SafetySetting(
|
||||
category = HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||
threshold = block_level,
|
||||
),
|
||||
types.SafetySetting(
|
||||
category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||
threshold = block_level,
|
||||
),
|
||||
types.SafetySetting(
|
||||
category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
threshold = block_level,
|
||||
),
|
||||
# There is a documentation conflict on whether or not
|
||||
# CIVIC_INTEGRITY is a valid category
|
||||
# HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY: block_level,
|
||||
]
|
||||
|
||||
logger.info("GoogleAIStudio LLM service initialized")
|
||||
|
||||
def _get_or_create_config(self, model_name, temperature=None):
|
||||
"""Get or create generation config with dynamic temperature"""
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
# Create cache key that includes temperature to avoid conflicts
|
||||
cache_key = f"{model_name}:{effective_temperature}"
|
||||
|
||||
if cache_key not in self.generation_configs:
|
||||
logger.info(f"Creating generation config for '{model_name}' with temperature {effective_temperature}")
|
||||
self.generation_configs[cache_key] = types.GenerateContentConfig(
|
||||
temperature = effective_temperature,
|
||||
top_p = 1,
|
||||
top_k = 40,
|
||||
max_output_tokens = self.max_output,
|
||||
response_mime_type = "text/plain",
|
||||
safety_settings = self.safety_settings,
|
||||
)
|
||||
|
||||
return self.generation_configs[cache_key]
|
||||
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
generation_config = self._get_or_create_config(model_name, effective_temperature)
|
||||
# Set system instruction per request (can't be cached)
|
||||
generation_config.system_instruction = system
|
||||
|
||||
try:
|
||||
|
||||
response = self.client.models.generate_content(
|
||||
model=model_name,
|
||||
config=generation_config,
|
||||
contents=prompt,
|
||||
)
|
||||
|
||||
resp = response.text
|
||||
inputtokens = int(response.usage_metadata.prompt_token_count)
|
||||
outputtokens = int(response.usage_metadata.candidates_token_count)
|
||||
logger.debug(f"LLM response: {resp}")
|
||||
logger.info(f"Input Tokens: {inputtokens}")
|
||||
logger.info(f"Output Tokens: {outputtokens}")
|
||||
|
||||
resp = LlmResult(
|
||||
text = resp,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = model_name
|
||||
)
|
||||
|
||||
return resp
|
||||
|
||||
except ResourceExhausted as e:
|
||||
|
||||
logger.warning("Rate limit exceeded")
|
||||
|
||||
# Leave rate limit retries to the default handler
|
||||
raise TooManyRequests()
|
||||
|
||||
except Exception as e:
|
||||
|
||||
# Apart from rate limits, treat all exceptions as unrecoverable
|
||||
|
||||
logger.error(f"GoogleAIStudio LLM exception ({type(e).__name__}): {e}", exc_info=True)
|
||||
raise e
|
||||
|
||||
def supports_streaming(self):
|
||||
"""Google AI Studio supports streaming"""
|
||||
return True
|
||||
|
||||
async def generate_content_stream(self, system, prompt, model=None, temperature=None):
|
||||
"""Stream content generation from Google AI Studio"""
|
||||
model_name = model or self.default_model
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model (streaming): {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
generation_config = self._get_or_create_config(model_name, effective_temperature)
|
||||
generation_config.system_instruction = system
|
||||
|
||||
try:
|
||||
response = self.client.models.generate_content_stream(
|
||||
model=model_name,
|
||||
config=generation_config,
|
||||
contents=prompt,
|
||||
)
|
||||
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
|
||||
for chunk in response:
|
||||
if hasattr(chunk, 'text') and chunk.text:
|
||||
yield LlmChunk(
|
||||
text=chunk.text,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=model_name,
|
||||
is_final=False
|
||||
)
|
||||
|
||||
# Accumulate token counts if available
|
||||
if hasattr(chunk, 'usage_metadata'):
|
||||
if hasattr(chunk.usage_metadata, 'prompt_token_count'):
|
||||
total_input_tokens = int(chunk.usage_metadata.prompt_token_count)
|
||||
if hasattr(chunk.usage_metadata, 'candidates_token_count'):
|
||||
total_output_tokens = int(chunk.usage_metadata.candidates_token_count)
|
||||
|
||||
# Send final chunk with token counts
|
||||
yield LlmChunk(
|
||||
text="",
|
||||
in_token=total_input_tokens,
|
||||
out_token=total_output_tokens,
|
||||
model=model_name,
|
||||
is_final=True
|
||||
)
|
||||
|
||||
logger.debug("Streaming complete")
|
||||
|
||||
except ResourceExhausted:
|
||||
logger.warning("Rate limit exceeded during streaming")
|
||||
raise TooManyRequests()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"GoogleAIStudio streaming exception ({type(e).__name__}): {e}", exc_info=True)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
LlmService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-m', '--model',
|
||||
default=default_model,
|
||||
help=f'LLM model (default: {default_model})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-k', '--api-key',
|
||||
default=default_api_key,
|
||||
help=f'GoogleAIStudio API key'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--temperature',
|
||||
type=float,
|
||||
default=default_temperature,
|
||||
help=f'LLM temperature parameter (default: {default_temperature})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-x', '--max-output',
|
||||
type=int,
|
||||
default=default_max_output,
|
||||
help=f'LLM max output tokens (default: {default_max_output})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
@ -126,9 +126,13 @@ class Processor(LlmService):
|
|||
frequency_penalty=0,
|
||||
presence_penalty=0,
|
||||
response_format={"type": "text"},
|
||||
stream=True
|
||||
stream=True,
|
||||
stream_options={"include_usage": True}
|
||||
)
|
||||
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
|
||||
for chunk in response:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
yield LlmChunk(
|
||||
|
|
@ -139,10 +143,15 @@ class Processor(LlmService):
|
|||
is_final=False
|
||||
)
|
||||
|
||||
# Capture usage from final chunk
|
||||
if chunk.usage:
|
||||
total_input_tokens = chunk.usage.prompt_tokens
|
||||
total_output_tokens = chunk.usage.completion_tokens
|
||||
|
||||
yield LlmChunk(
|
||||
text="",
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
in_token=total_input_tokens,
|
||||
out_token=total_output_tokens,
|
||||
model=model_name,
|
||||
is_final=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class Processor(LlmService):
|
|||
)
|
||||
|
||||
self.default_model = model
|
||||
self.url = url + "v1/"
|
||||
self.url = url.rstrip('/') + "/v1/"
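
The rstrip guards against configured URLs both with and without a trailing slash; a tiny check (URLs illustrative):

for url in ("http://llm:8000", "http://llm:8000/"):
    old = url + "v1/"                # "http://llm:8000v1/" when there is no trailing slash
    new = url.rstrip('/') + "/v1/"   # "http://llm:8000/v1/" in both cases
    print(old, new)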
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
self.openai = OpenAI(
|
||||
|
|
@ -130,9 +130,13 @@ class Processor(LlmService):
|
|||
frequency_penalty=0,
|
||||
presence_penalty=0,
|
||||
response_format={"type": "text"},
|
||||
stream=True
|
||||
stream=True,
|
||||
stream_options={"include_usage": True}
|
||||
)
|
||||
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
|
||||
for chunk in response:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
yield LlmChunk(
|
||||
|
|
@ -143,10 +147,15 @@ class Processor(LlmService):
|
|||
is_final=False
|
||||
)
|
||||
|
||||
# Capture usage from final chunk
|
||||
if chunk.usage:
|
||||
total_input_tokens = chunk.usage.prompt_tokens
|
||||
total_output_tokens = chunk.usage.completion_tokens
|
||||
|
||||
yield LlmChunk(
|
||||
text="",
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
in_token=total_input_tokens,
|
||||
out_token=total_output_tokens,
|
||||
model=model_name,
|
||||
is_final=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -156,6 +156,9 @@ class Processor(LlmService):
|
|||
response_format={"type": "text"}
|
||||
)
|
||||
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
|
||||
for chunk in stream:
|
||||
if chunk.data.choices and chunk.data.choices[0].delta.content:
|
||||
yield LlmChunk(
|
||||
|
|
@ -166,11 +169,16 @@ class Processor(LlmService):
|
|||
is_final=False
|
||||
)
|
||||
|
||||
# Send final chunk
|
||||
# Capture usage data when available (typically in final chunk)
|
||||
if chunk.data.usage:
|
||||
total_input_tokens = chunk.data.usage.prompt_tokens
|
||||
total_output_tokens = chunk.data.usage.completion_tokens
|
||||
|
||||
# Send final chunk with token counts
|
||||
yield LlmChunk(
|
||||
text="",
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
in_token=total_input_tokens,
|
||||
out_token=total_output_tokens,
|
||||
model=model_name,
|
||||
is_final=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -153,9 +153,13 @@ class Processor(LlmService):
|
|||
],
|
||||
temperature=effective_temperature,
|
||||
max_tokens=self.max_output,
|
||||
stream=True # Enable streaming
|
||||
stream=True,
|
||||
stream_options={"include_usage": True}
|
||||
)
|
||||
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
|
||||
# Stream chunks
|
||||
for chunk in response:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
|
|
@ -167,12 +171,16 @@ class Processor(LlmService):
|
|||
is_final=False
|
||||
)
|
||||
|
||||
# Note: OpenAI doesn't provide token counts in streaming mode
|
||||
# Send final chunk without token counts
|
||||
# Capture usage from final chunk
|
||||
if chunk.usage:
|
||||
total_input_tokens = chunk.usage.prompt_tokens
|
||||
total_output_tokens = chunk.usage.completion_tokens
|
||||
|
||||
# Send final chunk with token counts
|
||||
yield LlmChunk(
|
||||
text="",
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
in_token=total_input_tokens,
|
||||
out_token=total_output_tokens,
|
||||
model=model_name,
|
||||
is_final=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ class Processor(LlmService):
|
|||
|
||||
try:
|
||||
|
||||
url = f"{self.base_url}/chat/completions"
|
||||
url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||
|
||||
async with self.session.post(
|
||||
url,
|
||||
|
|
@ -152,10 +152,14 @@ class Processor(LlmService):
|
|||
"max_tokens": self.max_output,
|
||||
"temperature": effective_temperature,
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True},
|
||||
}
|
||||
|
||||
try:
|
||||
url = f"{self.base_url}/chat/completions"
|
||||
url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
|
||||
async with self.session.post(
|
||||
url,
|
||||
|
|
@ -196,15 +200,21 @@ class Processor(LlmService):
|
|||
model=model_name,
|
||||
is_final=False
|
||||
)
|
||||
|
||||
# Capture usage from final chunk
|
||||
if 'usage' in chunk_data and chunk_data['usage']:
|
||||
total_input_tokens = chunk_data['usage'].get('prompt_tokens', 0)
|
||||
total_output_tokens = chunk_data['usage'].get('completion_tokens', 0)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse chunk: {data}")
|
||||
continue
|
||||
|
||||
# Send final chunk
|
||||
# Send final chunk with token counts
|
||||
yield LlmChunk(
|
||||
text="",
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
in_token=total_input_tokens,
|
||||
out_token=total_output_tokens,
|
||||
model=model_name,
|
||||
is_final=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ class Processor(LlmService):
|
|||
|
||||
try:
|
||||
|
||||
url = f"{self.base_url}/completions"
|
||||
url = f"{self.base_url.rstrip('/')}/completions"
|
||||
|
||||
async with self.session.post(
|
||||
url,
|
||||
|
|
@ -135,10 +135,14 @@ class Processor(LlmService):
|
|||
"max_tokens": self.max_output,
|
||||
"temperature": effective_temperature,
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True},
|
||||
}
|
||||
|
||||
try:
|
||||
url = f"{self.base_url}/completions"
|
||||
url = f"{self.base_url.rstrip('/')}/completions"
|
||||
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
|
||||
async with self.session.post(
|
||||
url,
|
||||
|
|
@ -177,15 +181,21 @@ class Processor(LlmService):
|
|||
model=model_name,
|
||||
is_final=False
|
||||
)
|
||||
|
||||
# Capture usage from final chunk
|
||||
if 'usage' in chunk_data and chunk_data['usage']:
|
||||
total_input_tokens = chunk_data['usage'].get('prompt_tokens', 0)
|
||||
total_output_tokens = chunk_data['usage'].get('completion_tokens', 0)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse chunk: {data}")
|
||||
continue
|
||||
|
||||
# Send final chunk
|
||||
# Send final chunk with token counts
|
||||
yield LlmChunk(
|
||||
text="",
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
in_token=total_input_tokens,
|
||||
out_token=total_output_tokens,
|
||||
model=model_name,
|
||||
is_final=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,13 +8,13 @@ import logging
|
|||
|
||||
from .... direct.milvus_doc_embeddings import DocVectors
|
||||
from .... schema import DocumentEmbeddingsResponse
|
||||
from .... schema import Error, Value
|
||||
from .... schema import Error
|
||||
from .... base import DocumentEmbeddingsQueryService
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "de-query"
|
||||
default_ident = "doc-embeddings-query"
|
||||
default_store_uri = 'http://localhost:19530'
|
||||
|
||||
class Processor(DocumentEmbeddingsQueryService):
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from .... base import DocumentEmbeddingsQueryService
|
|||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "de-query"
|
||||
default_ident = "doc-embeddings-query"
|
||||
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
|
||||
|
||||
class Processor(DocumentEmbeddingsQueryService):
|
||||
|
|
|
|||
|
|
@ -11,13 +11,13 @@ from qdrant_client.models import PointStruct
|
|||
from qdrant_client.models import Distance, VectorParams
|
||||
|
||||
from .... schema import DocumentEmbeddingsResponse
|
||||
from .... schema import Error, Value
|
||||
from .... schema import Error
|
||||
from .... base import DocumentEmbeddingsQueryService
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "de-query"
|
||||
default_ident = "doc-embeddings-query"
|
||||
|
||||
default_store_uri = 'http://localhost:6333'
|
||||
|
||||
|
|
|
|||
|
|
@ -8,13 +8,13 @@ import logging
|
|||
|
||||
from .... direct.milvus_graph_embeddings import EntityVectors
|
||||
from .... schema import GraphEmbeddingsResponse
|
||||
from .... schema import Error, Value
|
||||
from .... schema import Error, Term, IRI, LITERAL
|
||||
from .... base import GraphEmbeddingsQueryService
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "ge-query"
|
||||
default_ident = "graph-embeddings-query"
|
||||
default_store_uri = 'http://localhost:19530'
|
||||
|
||||
class Processor(GraphEmbeddingsQueryService):
|
||||
|
|
@ -33,9 +33,9 @@ class Processor(GraphEmbeddingsQueryService):
|
|||
|
||||
def create_value(self, ent):
|
||||
if ent.startswith("http://") or ent.startswith("https://"):
|
||||
return Value(value=ent, is_uri=True)
|
||||
return Term(type=IRI, iri=ent)
|
||||
else:
|
||||
return Value(value=ent, is_uri=False)
|
||||
return Term(type=LITERAL, value=ent)
|
||||
|
||||
async def query_graph_embeddings(self, msg):
|
||||
|
||||
|
|
|
|||
|
|
@ -12,13 +12,13 @@ from pinecone import Pinecone, ServerlessSpec
|
|||
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
|
||||
|
||||
from .... schema import GraphEmbeddingsResponse
|
||||
from .... schema import Error, Value
|
||||
from .... schema import Error, Term, IRI, LITERAL
|
||||
from .... base import GraphEmbeddingsQueryService
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "ge-query"
|
||||
default_ident = "graph-embeddings-query"
|
||||
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
|
||||
|
||||
class Processor(GraphEmbeddingsQueryService):
|
||||
|
|
@ -51,9 +51,9 @@ class Processor(GraphEmbeddingsQueryService):
|
|||
|
||||
def create_value(self, ent):
|
||||
if ent.startswith("http://") or ent.startswith("https://"):
|
||||
return Value(value=ent, is_uri=True)
|
||||
return Term(type=IRI, iri=ent)
|
||||
else:
|
||||
return Value(value=ent, is_uri=False)
|
||||
return Term(type=LITERAL, value=ent)
|
||||
|
||||
async def query_graph_embeddings(self, msg):
|
||||
|
||||
|
|
|
|||
|
|
@ -11,13 +11,13 @@ from qdrant_client.models import PointStruct
|
|||
from qdrant_client.models import Distance, VectorParams
|
||||
|
||||
from .... schema import GraphEmbeddingsResponse
|
||||
from .... schema import Error, Value
|
||||
from .... schema import Error, Term, IRI, LITERAL
|
||||
from .... base import GraphEmbeddingsQueryService
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "ge-query"
|
||||
default_ident = "graph-embeddings-query"
|
||||
|
||||
default_store_uri = 'http://localhost:6333'
|
||||
|
||||
|
|
@ -67,9 +67,9 @@ class Processor(GraphEmbeddingsQueryService):
|
|||
|
||||
def create_value(self, ent):
|
||||
if ent.startswith("http://") or ent.startswith("https://"):
|
||||
return Value(value=ent, is_uri=True)
|
||||
return Term(type=IRI, iri=ent)
|
||||
else:
|
||||
return Value(value=ent, is_uri=False)
|
||||
return Term(type=LITERAL, value=ent)
|
||||
|
||||
async def query_graph_embeddings(self, msg):
|
||||
|
||||
|
|
|
|||
trustgraph-flow/trustgraph/query/graphql/__init__.py (new file, 22 lines)
|
|
@ -0,0 +1,22 @@
|
|||
"""
|
||||
Shared GraphQL utilities for row query services.
|
||||
|
||||
This module provides reusable GraphQL components including:
|
||||
- Filter types (IntFilter, StringFilter, FloatFilter)
|
||||
- Dynamic schema generation from RowSchema definitions
|
||||
- Filter parsing utilities
|
||||
"""
|
||||
|
||||
from .types import IntFilter, StringFilter, FloatFilter, SortDirection
|
||||
from .schema import GraphQLSchemaBuilder
|
||||
from .filters import parse_filter_key, parse_where_clause
|
||||
|
||||
__all__ = [
|
||||
"IntFilter",
|
||||
"StringFilter",
|
||||
"FloatFilter",
|
||||
"SortDirection",
|
||||
"GraphQLSchemaBuilder",
|
||||
"parse_filter_key",
|
||||
"parse_where_clause",
|
||||
]
|
||||
trustgraph-flow/trustgraph/query/graphql/filters.py (new file, 104 lines)
|
|
@ -0,0 +1,104 @@
|
|||
"""
|
||||
Filter parsing utilities for GraphQL row queries.
|
||||
|
||||
Provides functions to parse GraphQL filter objects into a normalized
|
||||
format that can be used by different query backends.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_filter_key(filter_key: str) -> Tuple[str, str]:
    """
    Parse GraphQL filter key into field name and operator.

    Supports common GraphQL filter patterns:
    - field_name -> (field_name, "eq")
    - field_name_gt -> (field_name, "gt")
    - field_name_gte -> (field_name, "gte")
    - field_name_lt -> (field_name, "lt")
    - field_name_lte -> (field_name, "lte")
    - field_name_in -> (field_name, "in")

    Args:
        filter_key: The filter key string from GraphQL

    Returns:
        Tuple of (field_name, operator)
    """
    if not filter_key:
        return ("", "eq")

    operators = ["_gte", "_lte", "_gt", "_lt", "_in", "_eq"]

    for op_suffix in operators:
        if filter_key.endswith(op_suffix):
            field_name = filter_key[:-len(op_suffix)]
            operator = op_suffix[1:]  # Remove the leading underscore
            return (field_name, operator)

    # Default to equality if no operator suffix found
    return (filter_key, "eq")
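
For instance:

parse_filter_key("age_gte")   # -> ("age", "gte")
parse_filter_key("email")     # -> ("email", "eq")
parse_filter_key("score_in")  # -> ("score", "in")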
|
||||
|
||||
|
||||
def parse_where_clause(where_obj) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse the idiomatic nested GraphQL filter structure into a flat dict.
|
||||
|
||||
Converts Strawberry filter objects (StringFilter, IntFilter, etc.)
|
||||
into a dictionary mapping field names with operators to values.
|
||||
|
||||
Example:
|
||||
Input: where_obj with email.eq = "foo@bar.com"
|
||||
Output: {"email": "foo@bar.com"}
|
||||
|
||||
Input: where_obj with age.gt = 21
|
||||
Output: {"age_gt": 21}
|
||||
|
||||
Args:
|
||||
where_obj: The GraphQL where clause object
|
||||
|
||||
Returns:
|
||||
Dictionary mapping field_operator keys to values
|
||||
"""
|
||||
if not where_obj:
|
||||
return {}
|
||||
|
||||
conditions = {}
|
||||
|
||||
logger.debug(f"Parsing where clause: {where_obj}")
|
||||
|
||||
for field_name, filter_obj in where_obj.__dict__.items():
|
||||
if filter_obj is None:
|
||||
continue
|
||||
|
||||
logger.debug(f"Processing field {field_name} with filter_obj: {filter_obj}")
|
||||
|
||||
if hasattr(filter_obj, '__dict__'):
|
||||
# This is a filter object (StringFilter, IntFilter, etc.)
|
||||
for operator, value in filter_obj.__dict__.items():
|
||||
if value is not None:
|
||||
logger.debug(f"Found operator {operator} with value {value}")
|
||||
# Map GraphQL operators to our internal format
|
||||
if operator == "eq":
|
||||
conditions[field_name] = value
|
||||
elif operator in ["gt", "gte", "lt", "lte"]:
|
||||
conditions[f"{field_name}_{operator}"] = value
|
||||
elif operator == "in_":
|
||||
conditions[f"{field_name}_in"] = value
|
||||
elif operator == "contains":
|
||||
conditions[f"{field_name}_contains"] = value
|
||||
elif operator == "startsWith":
|
||||
conditions[f"{field_name}_startsWith"] = value
|
||||
elif operator == "endsWith":
|
||||
conditions[f"{field_name}_endsWith"] = value
|
||||
elif operator == "not_":
|
||||
conditions[f"{field_name}_not"] = value
|
||||
elif operator == "not_in":
|
||||
conditions[f"{field_name}_not_in"] = value
|
||||
|
||||
logger.debug(f"Final parsed conditions: {conditions}")
|
||||
return conditions
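
A hedged usage sketch; SimpleNamespace stands in for the Strawberry-generated filter objects, which expose their operators as attributes:

from types import SimpleNamespace as NS

where = NS(
    email=NS(eq="foo@bar.com", contains=None, startsWith=None, endsWith=None,
             in_=None, not_=None, not_in=None),
    age=NS(eq=None, gt=21, gte=None, lt=None, lte=None,
           in_=None, not_=None, not_in=None),
)
parse_where_clause(where)
# -> {"email": "foo@bar.com", "age_gt": 21}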
|
||||
trustgraph-flow/trustgraph/query/graphql/schema.py (new file, 251 lines)
|
|
@ -0,0 +1,251 @@
|
|||
"""
|
||||
Dynamic GraphQL schema generation from RowSchema definitions.
|
||||
|
||||
Provides a builder class that creates Strawberry GraphQL schemas
|
||||
from TrustGraph RowSchema definitions, with pluggable query backends.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List, Callable, Awaitable
|
||||
|
||||
import strawberry
|
||||
from strawberry import Schema
|
||||
from strawberry.types import Info
|
||||
|
||||
from .types import IntFilter, StringFilter, FloatFilter, SortDirection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type alias for query callback function
|
||||
QueryCallback = Callable[
|
||||
[str, str, str, Any, Dict[str, Any], int, Optional[str], Optional[SortDirection]],
|
||||
Awaitable[List[Dict[str, Any]]]
|
||||
]
|
||||
|
||||
|
||||
class GraphQLSchemaBuilder:
|
||||
"""
|
||||
Builds GraphQL schemas from RowSchema definitions.
|
||||
|
||||
This class extracts the GraphQL schema generation logic so it can be
|
||||
reused across different query backends (Cassandra, etc.).
|
||||
|
||||
Usage:
|
||||
builder = GraphQLSchemaBuilder()
|
||||
|
||||
# Add schemas
|
||||
for name, row_schema in schemas.items():
|
||||
builder.add_schema(name, row_schema)
|
||||
|
||||
# Build with a query callback
|
||||
schema = builder.build(query_callback)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.schemas: Dict[str, Any] = {} # name -> RowSchema
|
||||
self.graphql_types: Dict[str, type] = {}
|
||||
self.filter_types: Dict[str, type] = {}
|
||||
|
||||
def add_schema(self, name: str, row_schema) -> None:
|
||||
"""
|
||||
Add a RowSchema to the builder.
|
||||
|
||||
Args:
|
||||
name: The schema name (used as the GraphQL query field name)
|
||||
row_schema: The RowSchema object defining fields
|
||||
"""
|
||||
self.schemas[name] = row_schema
|
||||
self.graphql_types[name] = self._create_graphql_type(name, row_schema)
|
||||
self.filter_types[name] = self._create_filter_type(name, row_schema)
|
||||
logger.debug(f"Added schema {name} with {len(row_schema.fields)} fields")
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all schemas from the builder."""
|
||||
self.schemas = {}
|
||||
self.graphql_types = {}
|
||||
self.filter_types = {}
|
||||
|
||||
def build(self, query_callback: QueryCallback) -> Optional[Schema]:
|
||||
"""
|
||||
Build the GraphQL schema with the provided query callback.
|
||||
|
||||
The query callback will be invoked when resolving queries, with:
|
||||
- user: str
|
||||
- collection: str
|
||||
- schema_name: str
|
||||
- row_schema: RowSchema
|
||||
- filters: Dict[str, Any]
|
||||
- limit: int
|
||||
- order_by: Optional[str]
|
||||
- direction: Optional[SortDirection]
|
||||
|
||||
It should return a list of row dictionaries.
|
||||
|
||||
Args:
|
||||
query_callback: Async function to execute queries
|
||||
|
||||
Returns:
|
||||
Strawberry Schema, or None if no schemas are loaded
|
||||
"""
|
||||
if not self.schemas:
|
||||
logger.warning("No schemas loaded, cannot generate GraphQL schema")
|
||||
return None
|
||||
|
||||
# Create the Query class with resolvers
|
||||
query_dict = {'__annotations__': {}}
|
||||
|
||||
for schema_name, row_schema in self.schemas.items():
|
||||
graphql_type = self.graphql_types[schema_name]
|
||||
filter_type = self.filter_types[schema_name]
|
||||
|
||||
# Create resolver function for this schema
|
||||
resolver_func = self._make_resolver(
|
||||
schema_name, row_schema, graphql_type, filter_type, query_callback
|
||||
)
|
||||
|
||||
# Add field to query dictionary
|
||||
query_dict[schema_name] = strawberry.field(resolver=resolver_func)
|
||||
query_dict['__annotations__'][schema_name] = List[graphql_type]
|
||||
|
||||
# Create the Query class
|
||||
Query = type('Query', (), query_dict)
|
||||
Query = strawberry.type(Query)
|
||||
|
||||
# Create the schema with auto_camel_case disabled to keep snake_case field names
|
||||
schema = strawberry.Schema(
|
||||
query=Query,
|
||||
config=strawberry.schema.config.StrawberryConfig(auto_camel_case=False)
|
||||
)
|
||||
logger.info(f"Generated GraphQL schema with {len(self.schemas)} types")
|
||||
return schema
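
A hedged sketch of wiring the builder to a backend; the customers_row_schema object and the canned row are assumptions standing in for the real Cassandra-backed callback:

async def query_callback(user, collection, schema_name, row_schema,
                         filters, limit, order_by, direction):
    # A real backend would translate `filters` into a database query;
    # here a canned row is returned so the resolver has something to wrap.
    return [{"email": "foo@bar.com", "age": 42}][:limit]

builder = GraphQLSchemaBuilder()
builder.add_schema("customers", customers_row_schema)  # assumed RowSchema instance
schema = builder.build(query_callback)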
|
||||
|
||||
def _get_python_type(self, field_type: str):
|
||||
"""Convert schema field type to Python type for GraphQL."""
|
||||
type_mapping = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"boolean": bool,
|
||||
"timestamp": str, # Use string for timestamps in GraphQL
|
||||
"date": str,
|
||||
"time": str,
|
||||
"uuid": str
|
||||
}
|
||||
return type_mapping.get(field_type, str)
|
||||
|
||||
def _create_graphql_type(self, schema_name: str, row_schema) -> type:
|
||||
"""Create a GraphQL output type from a RowSchema."""
|
||||
# Create annotations for the GraphQL type
|
||||
annotations = {}
|
||||
defaults = {}
|
||||
|
||||
for field in row_schema.fields:
|
||||
python_type = self._get_python_type(field.type)
|
||||
|
||||
# Make field optional if not required
|
||||
if not field.required and not field.primary:
|
||||
annotations[field.name] = Optional[python_type]
|
||||
defaults[field.name] = None
|
||||
else:
|
||||
annotations[field.name] = python_type
|
||||
|
||||
# Create the class dynamically
|
||||
type_name = f"{schema_name.capitalize()}Type"
|
||||
graphql_class = type(
|
||||
type_name,
|
||||
(),
|
||||
{
|
||||
"__annotations__": annotations,
|
||||
**defaults
|
||||
}
|
||||
)
|
||||
|
||||
# Apply strawberry decorator
|
||||
return strawberry.type(graphql_class)
|
||||
|
||||
def _create_filter_type(self, schema_name: str, row_schema) -> type:
|
||||
"""Create a dynamic filter input type for a schema."""
|
||||
filter_type_name = f"{schema_name.capitalize()}Filter"
|
||||
|
||||
# Add __annotations__ and defaults for the fields
|
||||
annotations = {}
|
||||
defaults = {}
|
||||
|
||||
logger.debug(f"Creating filter type {filter_type_name} for schema {schema_name}")
|
||||
|
||||
for field in row_schema.fields:
|
||||
logger.debug(
|
||||
f"Field {field.name}: type={field.type}, "
|
||||
f"indexed={field.indexed}, primary={field.primary}"
|
||||
)
|
||||
|
||||
# Allow filtering on any field
|
||||
if field.type == "integer":
|
||||
annotations[field.name] = Optional[IntFilter]
|
||||
defaults[field.name] = None
|
||||
elif field.type == "float":
|
||||
annotations[field.name] = Optional[FloatFilter]
|
||||
defaults[field.name] = None
|
||||
elif field.type == "string":
|
||||
annotations[field.name] = Optional[StringFilter]
|
||||
defaults[field.name] = None
|
||||
|
||||
logger.debug(
|
||||
f"Filter type {filter_type_name} will have fields: {list(annotations.keys())}"
|
||||
)
|
||||
|
||||
# Create the class dynamically
|
||||
FilterType = type(
|
||||
filter_type_name,
|
||||
(),
|
||||
{
|
||||
"__annotations__": annotations,
|
||||
**defaults
|
||||
}
|
||||
)
|
||||
|
||||
# Apply strawberry input decorator
|
||||
FilterType = strawberry.input(FilterType)
|
||||
|
||||
return FilterType
|
||||
|
||||
def _make_resolver(
|
||||
self,
|
||||
schema_name: str,
|
||||
row_schema,
|
||||
graphql_type: type,
|
||||
filter_type: type,
|
||||
query_callback: QueryCallback
|
||||
):
|
||||
"""Create a resolver function for a schema."""
|
||||
from .filters import parse_where_clause
|
||||
|
||||
async def resolver(
|
||||
info: Info,
|
||||
where: Optional[filter_type] = None,
|
||||
order_by: Optional[str] = None,
|
||||
direction: Optional[SortDirection] = None,
|
||||
limit: Optional[int] = 100
|
||||
) -> List[graphql_type]:
|
||||
# Get context values
|
||||
user = info.context["user"]
|
||||
collection = info.context["collection"]
|
||||
|
||||
# Parse the where clause
|
||||
filters = parse_where_clause(where)
|
||||
|
||||
# Call the query backend
|
||||
results = await query_callback(
|
||||
user, collection, schema_name, row_schema,
|
||||
filters, limit, order_by, direction
|
||||
)
|
||||
|
||||
# Convert to GraphQL types
|
||||
graphql_results = []
|
||||
for row in results:
|
||||
graphql_obj = graphql_type(**row)
|
||||
graphql_results.append(graphql_obj)
|
||||
|
||||
return graphql_results
|
||||
|
||||
return resolver
|
||||
trustgraph-flow/trustgraph/query/graphql/types.py (new file, 56 lines)
|
|
@ -0,0 +1,56 @@
|
|||
"""
|
||||
GraphQL filter and sort types for row queries.
|
||||
|
||||
These types are used to build dynamic GraphQL schemas for querying
|
||||
structured row data.
|
||||
"""
|
||||
|
||||
from typing import Optional, List
|
||||
from enum import Enum
|
||||
|
||||
import strawberry
|
||||
|
||||
|
||||
@strawberry.input
|
||||
class IntFilter:
|
||||
"""Filter type for integer fields."""
|
||||
eq: Optional[int] = None
|
||||
gt: Optional[int] = None
|
||||
gte: Optional[int] = None
|
||||
lt: Optional[int] = None
|
||||
lte: Optional[int] = None
|
||||
in_: Optional[List[int]] = strawberry.field(name="in", default=None)
|
||||
not_: Optional[int] = strawberry.field(name="not", default=None)
|
||||
not_in: Optional[List[int]] = None
|
||||
|
||||
|
||||
@strawberry.input
|
||||
class StringFilter:
|
||||
"""Filter type for string fields."""
|
||||
eq: Optional[str] = None
|
||||
contains: Optional[str] = None
|
||||
startsWith: Optional[str] = None
|
||||
endsWith: Optional[str] = None
|
||||
in_: Optional[List[str]] = strawberry.field(name="in", default=None)
|
||||
not_: Optional[str] = strawberry.field(name="not", default=None)
|
||||
not_in: Optional[List[str]] = None
|
||||
|
||||
|
||||
@strawberry.input
|
||||
class FloatFilter:
|
||||
"""Filter type for float fields."""
|
||||
eq: Optional[float] = None
|
||||
gt: Optional[float] = None
|
||||
gte: Optional[float] = None
|
||||
lt: Optional[float] = None
|
||||
lte: Optional[float] = None
|
||||
in_: Optional[List[float]] = strawberry.field(name="in", default=None)
|
||||
not_: Optional[float] = strawberry.field(name="not", default=None)
|
||||
not_in: Optional[List[float]] = None
|
||||
|
||||
|
||||
@strawberry.enum
|
||||
class SortDirection(Enum):
|
||||
"""Sort direction for query results."""
|
||||
ASC = "asc"
|
||||
DESC = "desc"
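
In a generated schema these inputs surface in queries roughly as follows; the customers field and its columns are assumptions that depend on the loaded RowSchemas, and snake_case argument names are preserved because auto_camel_case is disabled:

query = """
{
  customers(
    where: { age: { gt: 21 }, email: { contains: "@example.org" } }
    order_by: "age"
    direction: DESC
    limit: 10
  ) {
    email
    age
  }
}
"""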
|
||||
|
|
@ -1,738 +0,0 @@
|
|||
"""
|
||||
Objects query service using GraphQL. Input is a GraphQL query with variables.
|
||||
Output is GraphQL response data with any errors.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional, List, Set
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass, field
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.auth import PlainTextAuthProvider
|
||||
|
||||
import strawberry
|
||||
from strawberry import Schema
|
||||
from strawberry.types import Info
|
||||
from strawberry.scalars import JSON
|
||||
from strawberry.tools import create_type
|
||||
|
||||
from .... schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
|
||||
from .... schema import Error, RowSchema, Field as SchemaField
|
||||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "objects-query"
|
||||
|
||||
# GraphQL filter input types
|
||||
@strawberry.input
|
||||
class IntFilter:
|
||||
eq: Optional[int] = None
|
||||
gt: Optional[int] = None
|
||||
gte: Optional[int] = None
|
||||
lt: Optional[int] = None
|
||||
lte: Optional[int] = None
|
||||
in_: Optional[List[int]] = strawberry.field(name="in", default=None)
|
||||
not_: Optional[int] = strawberry.field(name="not", default=None)
|
||||
not_in: Optional[List[int]] = None
|
||||
|
||||
@strawberry.input
|
||||
class StringFilter:
|
||||
eq: Optional[str] = None
|
||||
contains: Optional[str] = None
|
||||
startsWith: Optional[str] = None
|
||||
endsWith: Optional[str] = None
|
||||
in_: Optional[List[str]] = strawberry.field(name="in", default=None)
|
||||
not_: Optional[str] = strawberry.field(name="not", default=None)
|
||||
not_in: Optional[List[str]] = None
|
||||
|
||||
@strawberry.input
|
||||
class FloatFilter:
|
||||
eq: Optional[float] = None
|
||||
gt: Optional[float] = None
|
||||
gte: Optional[float] = None
|
||||
lt: Optional[float] = None
|
||||
lte: Optional[float] = None
|
||||
in_: Optional[List[float]] = strawberry.field(name="in", default=None)
|
||||
not_: Optional[float] = strawberry.field(name="not", default=None)
|
||||
not_in: Optional[List[float]] = None
|
||||
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id", default_ident)
|
||||
|
||||
# Get Cassandra parameters
|
||||
cassandra_host = params.get("cassandra_host")
|
||||
cassandra_username = params.get("cassandra_username")
|
||||
cassandra_password = params.get("cassandra_password")
|
||||
|
||||
# Resolve configuration with environment variable fallback
|
||||
hosts, username, password, keyspace = resolve_cassandra_config(
|
||||
host=cassandra_host,
|
||||
username=cassandra_username,
|
||||
password=cassandra_password
|
||||
)
|
||||
|
||||
# Store resolved configuration with proper names
|
||||
self.cassandra_host = hosts # Store as list
|
||||
self.cassandra_username = username
|
||||
self.cassandra_password = password
|
||||
|
||||
# Config key for schemas
|
||||
self.config_key = params.get("config_type", "schema")
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"config_type": self.config_key,
|
||||
}
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name = "request",
|
||||
schema = ObjectsQueryRequest,
|
||||
handler = self.on_message
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "response",
|
||||
schema = ObjectsQueryResponse,
|
||||
)
|
||||
)
|
||||
|
||||
# Register config handler for schema updates
|
||||
self.register_config_handler(self.on_schema_config)
|
||||
|
||||
# Schema storage: name -> RowSchema
|
||||
self.schemas: Dict[str, RowSchema] = {}
|
||||
|
||||
# GraphQL schema
|
||||
self.graphql_schema: Optional[Schema] = None
|
||||
|
||||
# GraphQL types cache
|
||||
self.graphql_types: Dict[str, type] = {}
|
||||
|
||||
# Cassandra session
|
||||
self.cluster = None
|
||||
self.session = None
|
||||
|
||||
# Known keyspaces and tables
|
||||
self.known_keyspaces: Set[str] = set()
|
||||
self.known_tables: Dict[str, Set[str]] = {}
|
||||
|
||||
def connect_cassandra(self):
|
||||
"""Connect to Cassandra cluster"""
|
||||
if self.session:
|
||||
return
|
||||
|
||||
try:
|
||||
if self.cassandra_username and self.cassandra_password:
|
||||
auth_provider = PlainTextAuthProvider(
|
||||
username=self.cassandra_username,
|
||||
password=self.cassandra_password
|
||||
)
|
||||
self.cluster = Cluster(
|
||||
contact_points=self.cassandra_host,
|
||||
auth_provider=auth_provider
|
||||
)
|
||||
else:
|
||||
self.cluster = Cluster(contact_points=self.cassandra_host)
|
||||
|
||||
self.session = self.cluster.connect()
|
||||
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def sanitize_name(self, name: str) -> str:
|
||||
"""Sanitize names for Cassandra compatibility"""
|
||||
import re
|
||||
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
|
||||
if safe_name and not safe_name[0].isalpha():
|
||||
safe_name = 'o_' + safe_name
|
||||
return safe_name.lower()
|
||||
|
||||
def sanitize_table(self, name: str) -> str:
|
||||
"""Sanitize table names for Cassandra compatibility"""
|
||||
import re
|
||||
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
|
||||
safe_name = 'o_' + safe_name
|
||||
return safe_name.lower()
|
||||
|
||||
def parse_filter_key(self, filter_key: str) -> tuple[str, str]:
|
||||
"""Parse GraphQL filter key into field name and operator"""
|
||||
if not filter_key:
|
||||
return ("", "eq")
|
||||
|
||||
# Support common GraphQL filter patterns:
|
||||
# field_name -> (field_name, "eq")
|
||||
# field_name_gt -> (field_name, "gt")
|
||||
# field_name_gte -> (field_name, "gte")
|
||||
# field_name_lt -> (field_name, "lt")
|
||||
# field_name_lte -> (field_name, "lte")
|
||||
# field_name_in -> (field_name, "in")
|
||||
|
||||
operators = ["_gte", "_lte", "_gt", "_lt", "_in", "_eq"]
|
||||
|
||||
for op_suffix in operators:
|
||||
if filter_key.endswith(op_suffix):
|
||||
field_name = filter_key[:-len(op_suffix)]
|
||||
operator = op_suffix[1:] # Remove the leading underscore
|
||||
return (field_name, operator)
|
||||
|
||||
# Default to equality if no operator suffix found
|
||||
return (filter_key, "eq")
|
||||
|
||||
async def on_schema_config(self, config, version):
|
||||
"""Handle schema configuration updates"""
|
||||
logger.info(f"Loading schema configuration version {version}")
|
||||
|
||||
# Clear existing schemas
|
||||
self.schemas = {}
|
||||
self.graphql_types = {}
|
||||
|
||||
# Check if our config type exists
|
||||
if self.config_key not in config:
|
||||
logger.warning(f"No '{self.config_key}' type in configuration")
|
||||
return
|
||||
|
||||
# Get the schemas dictionary for our type
|
||||
schemas_config = config[self.config_key]
|
||||
|
||||
# Process each schema in the schemas config
|
||||
for schema_name, schema_json in schemas_config.items():
|
||||
try:
|
||||
# Parse the JSON schema definition
|
||||
schema_def = json.loads(schema_json)
|
||||
|
||||
# Create Field objects
|
||||
fields = []
|
||||
for field_def in schema_def.get("fields", []):
|
||||
field = SchemaField(
|
||||
name=field_def["name"],
|
||||
type=field_def["type"],
|
||||
size=field_def.get("size", 0),
|
||||
primary=field_def.get("primary_key", False),
|
||||
description=field_def.get("description", ""),
|
||||
required=field_def.get("required", False),
|
||||
enum_values=field_def.get("enum", []),
|
||||
indexed=field_def.get("indexed", False)
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
# Create RowSchema
|
||||
row_schema = RowSchema(
|
||||
name=schema_def.get("name", schema_name),
|
||||
description=schema_def.get("description", ""),
|
||||
fields=fields
|
||||
)
|
||||
|
||||
self.schemas[schema_name] = row_schema
|
||||
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
|
||||
|
||||
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
|
||||
|
||||
# Regenerate GraphQL schema
|
||||
self.generate_graphql_schema()
|
||||
|
||||
def get_python_type(self, field_type: str):
|
||||
"""Convert schema field type to Python type for GraphQL"""
|
||||
type_mapping = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"boolean": bool,
|
||||
"timestamp": str, # Use string for timestamps in GraphQL
|
||||
"date": str,
|
||||
"time": str,
|
||||
"uuid": str
|
||||
}
|
||||
return type_mapping.get(field_type, str)
|
||||
|
||||
def create_graphql_type(self, schema_name: str, row_schema: RowSchema) -> type:
|
||||
"""Create a GraphQL type from a RowSchema"""
|
||||
|
||||
# Create annotations for the GraphQL type
|
||||
annotations = {}
|
||||
defaults = {}
|
||||
|
||||
for field in row_schema.fields:
|
||||
python_type = self.get_python_type(field.type)
|
||||
|
||||
# Make field optional if not required
|
||||
if not field.required and not field.primary:
|
||||
annotations[field.name] = Optional[python_type]
|
||||
defaults[field.name] = None
|
||||
else:
|
||||
annotations[field.name] = python_type
|
||||
|
||||
# Create the class dynamically
|
||||
type_name = f"{schema_name.capitalize()}Type"
|
||||
graphql_class = type(
|
||||
type_name,
|
||||
(),
|
||||
{
|
||||
"__annotations__": annotations,
|
||||
**defaults
|
||||
}
|
||||
)
|
||||
|
||||
# Apply strawberry decorator
|
||||
return strawberry.type(graphql_class)
|
||||
|
||||
def create_filter_type_for_schema(self, schema_name: str, row_schema: RowSchema):
|
||||
"""Create a dynamic filter input type for a schema"""
|
||||
# Create the filter type dynamically
|
||||
filter_type_name = f"{schema_name.capitalize()}Filter"
|
||||
|
||||
# Add __annotations__ and defaults for the fields
|
||||
annotations = {}
|
||||
defaults = {}
|
||||
|
||||
logger.info(f"Creating filter type {filter_type_name} for schema {schema_name}")
|
||||
|
||||
for field in row_schema.fields:
|
||||
logger.info(f"Field {field.name}: type={field.type}, indexed={field.indexed}, primary={field.primary}")
|
||||
|
||||
# Allow filtering on any field for now, not just indexed/primary
|
||||
# if field.indexed or field.primary:
|
||||
if field.type == "integer":
|
||||
annotations[field.name] = Optional[IntFilter]
|
||||
defaults[field.name] = None
|
||||
logger.info(f"Added IntFilter for {field.name}")
|
||||
elif field.type == "float":
|
||||
annotations[field.name] = Optional[FloatFilter]
|
||||
defaults[field.name] = None
|
||||
logger.info(f"Added FloatFilter for {field.name}")
|
||||
elif field.type == "string":
|
||||
annotations[field.name] = Optional[StringFilter]
|
||||
defaults[field.name] = None
|
||||
logger.info(f"Added StringFilter for {field.name}")
|
||||
|
||||
logger.info(f"Filter type {filter_type_name} will have fields: {list(annotations.keys())}")
|
||||
|
||||
# Create the class dynamically
|
||||
FilterType = type(
|
||||
filter_type_name,
|
||||
(),
|
||||
{
|
||||
"__annotations__": annotations,
|
||||
**defaults
|
||||
}
|
||||
)
|
||||
|
||||
# Apply strawberry input decorator
|
||||
FilterType = strawberry.input(FilterType)
|
||||
|
||||
return FilterType
|
||||
|
||||
def create_sort_direction_enum(self):
|
||||
"""Create sort direction enum"""
|
||||
@strawberry.enum
|
||||
class SortDirection(Enum):
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
return SortDirection
|
||||
|
||||
def parse_idiomatic_where_clause(self, where_obj) -> Dict[str, Any]:
|
||||
"""Parse the idiomatic nested filter structure"""
|
||||
if not where_obj:
|
||||
return {}
|
||||
|
||||
conditions = {}
|
||||
|
||||
logger.info(f"Parsing where clause: {where_obj}")
|
||||
|
||||
for field_name, filter_obj in where_obj.__dict__.items():
|
||||
if filter_obj is None:
|
||||
continue
|
||||
|
||||
logger.info(f"Processing field {field_name} with filter_obj: {filter_obj}")
|
||||
|
||||
if hasattr(filter_obj, '__dict__'):
|
||||
# This is a filter object (StringFilter, IntFilter, etc.)
|
||||
for operator, value in filter_obj.__dict__.items():
|
||||
if value is not None:
|
||||
logger.info(f"Found operator {operator} with value {value}")
|
||||
# Map GraphQL operators to our internal format
|
||||
if operator == "eq":
|
||||
conditions[field_name] = value
|
||||
elif operator in ["gt", "gte", "lt", "lte"]:
|
||||
conditions[f"{field_name}_{operator}"] = value
|
||||
elif operator == "in_":
|
||||
conditions[f"{field_name}_in"] = value
|
||||
elif operator == "contains":
|
||||
conditions[f"{field_name}_contains"] = value
|
||||
|
||||
logger.info(f"Final parsed conditions: {conditions}")
|
||||
return conditions
|
||||
|
||||
def generate_graphql_schema(self):
|
||||
"""Generate GraphQL schema from loaded schemas using dynamic filter types"""
|
||||
if not self.schemas:
|
||||
logger.warning("No schemas loaded, cannot generate GraphQL schema")
|
||||
self.graphql_schema = None
|
||||
return
|
||||
|
||||
# Create GraphQL types and filter types for each schema
|
||||
filter_types = {}
|
||||
sort_direction_enum = self.create_sort_direction_enum()
|
||||
|
||||
for schema_name, row_schema in self.schemas.items():
|
||||
graphql_type = self.create_graphql_type(schema_name, row_schema)
|
||||
filter_type = self.create_filter_type_for_schema(schema_name, row_schema)
|
||||
|
||||
self.graphql_types[schema_name] = graphql_type
|
||||
filter_types[schema_name] = filter_type
|
||||
|
||||
# Create the Query class with resolvers
|
||||
query_dict = {'__annotations__': {}}
|
||||
|
||||
for schema_name, row_schema in self.schemas.items():
|
||||
graphql_type = self.graphql_types[schema_name]
|
||||
filter_type = filter_types[schema_name]
|
||||
|
||||
# Create resolver function for this schema
|
||||
def make_resolver(s_name, r_schema, g_type, f_type, sort_enum):
|
||||
async def resolver(
|
||||
info: Info,
|
||||
where: Optional[f_type] = None,
|
||||
order_by: Optional[str] = None,
|
||||
direction: Optional[sort_enum] = None,
|
||||
limit: Optional[int] = 100
|
||||
) -> List[g_type]:
|
||||
# Get the processor instance from context
|
||||
processor = info.context["processor"]
|
||||
user = info.context["user"]
|
||||
collection = info.context["collection"]
|
||||
|
||||
# Parse the idiomatic where clause
|
||||
filters = processor.parse_idiomatic_where_clause(where)
|
||||
|
||||
# Query Cassandra
|
||||
results = await processor.query_cassandra(
|
||||
user, collection, s_name, r_schema,
|
||||
filters, limit, order_by, direction
|
||||
)
|
||||
|
||||
# Convert to GraphQL types
|
||||
graphql_results = []
|
||||
for row in results:
|
||||
graphql_obj = g_type(**row)
|
||||
graphql_results.append(graphql_obj)
|
||||
|
||||
return graphql_results
|
||||
|
||||
return resolver
|
||||
|
||||
# Add resolver to query
|
||||
resolver_name = schema_name
|
||||
resolver_func = make_resolver(schema_name, row_schema, graphql_type, filter_type, sort_direction_enum)
|
||||
|
||||
# Add field to query dictionary
|
||||
query_dict[resolver_name] = strawberry.field(resolver=resolver_func)
|
||||
query_dict['__annotations__'][resolver_name] = List[graphql_type]
|
||||
|
||||
# Create the Query class
|
||||
Query = type('Query', (), query_dict)
|
||||
Query = strawberry.type(Query)
|
||||
|
||||
# Create the schema with auto_camel_case disabled to keep snake_case field names
|
||||
self.graphql_schema = strawberry.Schema(
|
||||
query=Query,
|
||||
config=strawberry.schema.config.StrawberryConfig(auto_camel_case=False)
|
||||
)
|
||||
logger.info(f"Generated GraphQL schema with {len(self.schemas)} types")
|
||||
|
||||
async def query_cassandra(
|
||||
self,
|
||||
user: str,
|
||||
collection: str,
|
||||
schema_name: str,
|
||||
row_schema: RowSchema,
|
||||
filters: Dict[str, Any],
|
||||
limit: int,
|
||||
order_by: Optional[str] = None,
|
||||
direction: Optional[Any] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Execute a query against Cassandra"""
|
||||
|
||||
# Connect if needed
|
||||
self.connect_cassandra()
|
||||
|
||||
# Build the query
|
||||
keyspace = self.sanitize_name(user)
|
||||
table = self.sanitize_table(schema_name)
|
||||
|
||||
# Start with basic SELECT
|
||||
query = f"SELECT * FROM {keyspace}.{table}"
|
||||
|
||||
# Add WHERE clauses
|
||||
where_clauses = [f"collection = %s"]
|
||||
params = [collection]
|
||||
|
||||
# Add filters for indexed or primary key fields
|
||||
for filter_key, value in filters.items():
|
||||
if value is not None:
|
||||
# Parse field name and operator from filter key
|
||||
logger.debug(f"Parsing filter key: '{filter_key}' (type: {type(filter_key)})")
|
||||
result = self.parse_filter_key(filter_key)
|
||||
logger.debug(f"parse_filter_key returned: {result} (type: {type(result)}, len: {len(result) if hasattr(result, '__len__') else 'N/A'})")
|
||||
|
||||
if not result or len(result) != 2:
|
||||
logger.error(f"parse_filter_key returned invalid result: {result}")
|
||||
continue # Skip this filter
|
||||
|
||||
field_name, operator = result
|
||||
|
||||
# Find the field in schema
|
||||
schema_field = None
|
||||
for f in row_schema.fields:
|
||||
if f.name == field_name:
|
||||
schema_field = f
|
||||
break
|
||||
|
||||
if schema_field:
|
||||
safe_field = self.sanitize_name(field_name)
|
||||
|
||||
# Build WHERE clause based on operator
|
||||
if operator == "eq":
|
||||
where_clauses.append(f"{safe_field} = %s")
|
||||
params.append(value)
|
||||
elif operator == "gt":
|
||||
where_clauses.append(f"{safe_field} > %s")
|
||||
params.append(value)
|
||||
elif operator == "gte":
|
||||
where_clauses.append(f"{safe_field} >= %s")
|
||||
params.append(value)
|
||||
elif operator == "lt":
|
||||
where_clauses.append(f"{safe_field} < %s")
|
||||
params.append(value)
|
||||
elif operator == "lte":
|
||||
where_clauses.append(f"{safe_field} <= %s")
|
||||
params.append(value)
|
||||
elif operator == "in":
|
||||
if isinstance(value, list):
|
||||
placeholders = ",".join(["%s"] * len(value))
|
||||
where_clauses.append(f"{safe_field} IN ({placeholders})")
|
||||
params.extend(value)
|
||||
else:
|
||||
# Default to equality for unknown operators
|
||||
where_clauses.append(f"{safe_field} = %s")
|
||||
params.append(value)
|
||||
|
||||
if where_clauses:
|
||||
query += " WHERE " + " AND ".join(where_clauses)
|
||||
|
||||
# Add ORDER BY if requested (will try Cassandra first, then fall back to post-query sort)
|
||||
cassandra_order_by_added = False
|
||||
if order_by and direction:
|
||||
# Validate that order_by field exists in schema
|
||||
order_field_exists = any(f.name == order_by for f in row_schema.fields)
|
||||
if order_field_exists:
|
||||
safe_order_field = self.sanitize_name(order_by)
|
||||
direction_str = "ASC" if direction.value == "asc" else "DESC"
|
||||
# Add ORDER BY - if Cassandra rejects it, we'll catch the error during execution
|
||||
query += f" ORDER BY {safe_order_field} {direction_str}"
|
||||
|
||||
# Add limit first (must come before ALLOW FILTERING)
|
||||
if limit:
|
||||
query += f" LIMIT {limit}"
|
||||
|
||||
# Add ALLOW FILTERING for now (should optimize with proper indexes later)
|
||||
query += " ALLOW FILTERING"
|
||||
|
||||
# Execute query
|
||||
try:
|
||||
result = self.session.execute(query, params)
|
||||
cassandra_order_by_added = True # If we get here, Cassandra handled ORDER BY
|
||||
except Exception as e:
|
||||
# If ORDER BY fails, try without it
|
||||
if order_by and direction and "ORDER BY" in query:
|
||||
logger.info(f"Cassandra rejected ORDER BY, falling back to post-query sorting: {e}")
|
||||
# Remove ORDER BY clause and retry
|
||||
query_parts = query.split(" ORDER BY ")
|
||||
if len(query_parts) == 2:
|
||||
query_without_order = query_parts[0] + " LIMIT " + str(limit) + " ALLOW FILTERING" if limit else " ALLOW FILTERING"
|
||||
result = self.session.execute(query_without_order, params)
|
||||
cassandra_order_by_added = False
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
raise
|
||||
|
||||
# Convert rows to dicts
|
||||
results = []
|
||||
for row in result:
|
||||
row_dict = {}
|
||||
for field in row_schema.fields:
|
||||
safe_field = self.sanitize_name(field.name)
|
||||
if hasattr(row, safe_field):
|
||||
value = getattr(row, safe_field)
|
||||
# Use original field name in result
|
||||
row_dict[field.name] = value
|
||||
results.append(row_dict)
|
||||
|
||||
# Post-query sorting if Cassandra didn't handle ORDER BY
|
||||
if order_by and direction and not cassandra_order_by_added:
|
||||
reverse_order = (direction.value == "desc")
|
||||
try:
|
||||
results.sort(key=lambda x: x.get(order_by, 0), reverse=reverse_order)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to sort results by {order_by}: {e}")
|
||||
|
||||
return results
|
||||
|
||||
async def execute_graphql_query(
|
||||
self,
|
||||
query: str,
|
||||
variables: Dict[str, Any],
|
||||
operation_name: Optional[str],
|
||||
user: str,
|
||||
collection: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute a GraphQL query"""
|
||||
|
||||
if not self.graphql_schema:
|
||||
raise RuntimeError("No GraphQL schema available - no schemas loaded")
|
||||
|
||||
# Create context for the query
|
||||
context = {
|
||||
"processor": self,
|
||||
"user": user,
|
||||
"collection": collection
|
||||
}
|
||||
|
||||
# Execute the query
|
||||
result = await self.graphql_schema.execute(
|
||||
query,
|
||||
variable_values=variables,
|
||||
operation_name=operation_name,
|
||||
context_value=context
|
||||
)
|
||||
|
||||
# Build response
|
||||
response = {}
|
||||
|
||||
if result.data:
|
||||
response["data"] = result.data
|
||||
else:
|
||||
response["data"] = None
|
||||
|
||||
if result.errors:
|
||||
response["errors"] = [
|
||||
{
|
||||
"message": str(error),
|
||||
"path": getattr(error, "path", []),
|
||||
"extensions": getattr(error, "extensions", {})
|
||||
}
|
||||
for error in result.errors
|
||||
]
|
||||
else:
|
||||
response["errors"] = []
|
||||
|
||||
# Add extensions if any
|
||||
if hasattr(result, "extensions") and result.extensions:
|
||||
response["extensions"] = result.extensions
|
||||
|
||||
return response
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
"""Handle incoming query request"""
|
||||
|
||||
try:
|
||||
request = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
logger.debug(f"Handling objects query request {id}...")
|
||||
|
||||
# Execute GraphQL query
|
||||
result = await self.execute_graphql_query(
|
||||
query=request.query,
|
||||
variables=dict(request.variables) if request.variables else {},
|
||||
operation_name=request.operation_name,
|
||||
user=request.user,
|
||||
collection=request.collection
|
||||
)
|
||||
|
||||
# Create response
|
||||
graphql_errors = []
|
||||
if "errors" in result and result["errors"]:
|
||||
for err in result["errors"]:
|
||||
graphql_error = GraphQLError(
|
||||
message=err.get("message", ""),
|
||||
path=err.get("path", []),
|
||||
extensions=err.get("extensions", {})
|
||||
)
|
||||
graphql_errors.append(graphql_error)
|
||||
|
||||
response = ObjectsQueryResponse(
|
||||
error=None,
|
||||
data=json.dumps(result.get("data")) if result.get("data") else "null",
|
||||
errors=graphql_errors,
|
||||
extensions=result.get("extensions", {})
|
||||
)
|
||||
|
||||
logger.debug("Sending objects query response...")
|
||||
await flow("response").send(response, properties={"id": id})
|
||||
|
||||
logger.debug("Objects query request completed")
|
||||
|
||||
except Exception as e:
|
||||
|
||||
logger.error(f"Exception in objects query service: {e}", exc_info=True)
|
||||
|
||||
logger.info("Sending error response...")
|
||||
|
||||
response = ObjectsQueryResponse(
|
||||
error = Error(
|
||||
type = "objects-query-error",
|
||||
message = str(e),
|
||||
),
|
||||
data = None,
|
||||
errors = [],
|
||||
extensions = {}
|
||||
)
|
||||
|
||||
await flow("response").send(response, properties={"id": id})
|
||||
|
||||
def close(self):
|
||||
"""Clean up Cassandra connections"""
|
||||
if self.cluster:
|
||||
self.cluster.shutdown()
|
||||
logger.info("Closed Cassandra connection")
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add command-line arguments"""
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
add_cassandra_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'--config-type',
|
||||
default='schema',
|
||||
help='Configuration type prefix for schemas (default: schema)'
|
||||
)
|
||||
|
||||
def run():
|
||||
"""Entry point for objects-query-graphql-cassandra command"""
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
@ -0,0 +1,3 @@
"""
Row embeddings query modules.
"""

@ -0,0 +1,5 @@
"""
Qdrant row embeddings query service.
"""

from .service import Processor, run, default_ident

@ -0,0 +1,4 @@

from .service import run

run()

@ -0,0 +1,209 @@
"""
Row embeddings query service for Qdrant.

Input is query vectors plus user/collection/schema context.
Output is matching row index information (index_name, index_value) for
use in subsequent Cassandra lookups.
"""

import logging
import re
from typing import Optional

from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue

from .... schema import (
    RowEmbeddingsRequest, RowEmbeddingsResponse,
    RowIndexMatch, Error
)
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec

# Module logger
logger = logging.getLogger(__name__)

default_ident = "row-embeddings-query"
default_store_uri = 'http://localhost:6333'


class Processor(FlowProcessor):

    def __init__(self, **params):

        id = params.get("id", default_ident)

        store_uri = params.get("store_uri", default_store_uri)
        api_key = params.get("api_key", None)

        super(Processor, self).__init__(
            **params | {
                "id": id,
                "store_uri": store_uri,
                "api_key": api_key,
            }
        )

        self.register_specification(
            ConsumerSpec(
                name="request",
                schema=RowEmbeddingsRequest,
                handler=self.on_message
            )
        )

        self.register_specification(
            ProducerSpec(
                name="response",
                schema=RowEmbeddingsResponse
            )
        )

        self.qdrant = QdrantClient(url=store_uri, api_key=api_key)

    def sanitize_name(self, name: str) -> str:
        """Sanitize names for Qdrant collection naming"""
        safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
        if safe_name and not safe_name[0].isalpha():
            safe_name = 'r_' + safe_name
        return safe_name.lower()

    def find_collection(self, user: str, collection: str, schema_name: str) -> Optional[str]:
        """Find the Qdrant collection for a given user/collection/schema"""
        prefix = (
            f"rows_{self.sanitize_name(user)}_"
            f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
        )

        try:
            all_collections = self.qdrant.get_collections().collections
            matching = [
                coll.name for coll in all_collections
                if coll.name.startswith(prefix)
            ]

            if matching:
                # Return first match (there should typically be only one per dimension)
                return matching[0]

        except Exception as e:
            logger.error(f"Failed to list Qdrant collections: {e}", exc_info=True)

        return None

    async def query_row_embeddings(self, request: RowEmbeddingsRequest):
        """Execute row embeddings query"""

        matches = []

        # Find the collection for this user/collection/schema
        qdrant_collection = self.find_collection(
            request.user, request.collection, request.schema_name
        )

        if not qdrant_collection:
            logger.info(
                f"No Qdrant collection found for "
                f"{request.user}/{request.collection}/{request.schema_name}"
            )
            return matches

        for vec in request.vectors:
            try:
                # Build optional filter for index_name
                query_filter = None
                if request.index_name:
                    query_filter = Filter(
                        must=[
                            FieldCondition(
                                key="index_name",
                                match=MatchValue(value=request.index_name)
                            )
                        ]
                    )

                # Query Qdrant
                search_result = self.qdrant.query_points(
                    collection_name=qdrant_collection,
                    query=vec,
                    limit=request.limit,
                    with_payload=True,
                    query_filter=query_filter,
                ).points

                # Convert to RowIndexMatch objects
                for point in search_result:
                    payload = point.payload or {}
                    match = RowIndexMatch(
                        index_name=payload.get("index_name", ""),
                        index_value=payload.get("index_value", []),
                        text=payload.get("text", ""),
                        score=point.score if hasattr(point, 'score') else 0.0
                    )
                    matches.append(match)

            except Exception as e:
                logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
                raise

        return matches

    async def on_message(self, msg, consumer, flow):
        """Handle incoming query request"""

        try:
            request = msg.value()

            # Sender-produced ID
            id = msg.properties()["id"]

            logger.debug(
                f"Handling row embeddings query for "
                f"{request.user}/{request.collection}/{request.schema_name}..."
            )

            # Execute query
            matches = await self.query_row_embeddings(request)

            response = RowEmbeddingsResponse(
                error=None,
                matches=matches
            )

            logger.debug(f"Returning {len(matches)} matches")
            await flow("response").send(response, properties={"id": id})

        except Exception as e:
            logger.error(f"Exception in row embeddings query: {e}", exc_info=True)

            response = RowEmbeddingsResponse(
                error=Error(
                    type="row-embeddings-query-error",
                    message=str(e)
                ),
                matches=[]
            )

            await flow("response").send(response, properties={"id": id})

    @staticmethod
    def add_args(parser):
        """Add command-line arguments"""

        FlowProcessor.add_args(parser)

        parser.add_argument(
            '-t', '--store-uri',
            default=default_store_uri,
            help=f'Qdrant store URI (default: {default_store_uri})'
        )

        parser.add_argument(
            '-k', '--api-key',
            default=None,
            help='API key for Qdrant (default: None)'
        )


def run():
    """Entry point for row-embeddings-query-qdrant command"""
    Processor.launch(default_ident, __doc__)

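The prefix match in `find_collection` implies the write side names collections `rows_<user>_<collection>_<schema>_<suffix>`; going by the "one per dimension" comment above, the suffix is presumably the embedding dimension, though that is an assumption here. A rough sketch of the lookup against a local Qdrant, with invented names:

from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

client = QdrantClient(url="http://localhost:6333")

# Hypothetical collection as the write side might create it; the trailing
# "384" stands in for the assumed embedding-dimension suffix.
name = "rows_alice_default_invoice_384"
existing = {c.name for c in client.get_collections().collections}
if name not in existing:
    client.create_collection(
        collection_name=name,
        vectors_config=VectorParams(size=384, distance=Distance.COSINE),
    )

# The query side only knows user/collection/schema, so it matches by prefix,
# exactly as find_collection does above.
prefix = "rows_alice_default_invoice_"
matching = [
    c.name for c in client.get_collections().collections
    if c.name.startswith(prefix)
]
print(matching)   # ['rows_alice_default_invoice_384']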
trustgraph-flow/trustgraph/query/rows/cassandra/service.py
Normal file (523 lines added)
|
|
@ -0,0 +1,523 @@
|
|||
"""
|
||||
Row query service using GraphQL. Input is a GraphQL query with variables.
|
||||
Output is GraphQL response data with any errors.
|
||||
|
||||
Queries against the unified 'rows' table with schema:
|
||||
- collection: text
|
||||
- schema_name: text
|
||||
- index_name: text
|
||||
- index_value: frozen<list<text>>
|
||||
- data: map<text, text>
|
||||
- source: text
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, Any, Optional, List, Set
|
||||
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.auth import PlainTextAuthProvider
|
||||
|
||||
from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
|
||||
from .... schema import Error, RowSchema, Field as SchemaField
|
||||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||
|
||||
from ... graphql import GraphQLSchemaBuilder, SortDirection
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "rows-query"
|
||||
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id", default_ident)
|
||||
|
||||
# Get Cassandra parameters
|
||||
cassandra_host = params.get("cassandra_host")
|
||||
cassandra_username = params.get("cassandra_username")
|
||||
cassandra_password = params.get("cassandra_password")
|
||||
|
||||
# Resolve configuration with environment variable fallback
|
||||
hosts, username, password, keyspace = resolve_cassandra_config(
|
||||
host=cassandra_host,
|
||||
username=cassandra_username,
|
||||
password=cassandra_password
|
||||
)
|
||||
|
||||
# Store resolved configuration with proper names
|
||||
self.cassandra_host = hosts # Store as list
|
||||
self.cassandra_username = username
|
||||
self.cassandra_password = password
|
||||
|
||||
# Config key for schemas
|
||||
self.config_key = params.get("config_type", "schema")
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"config_type": self.config_key,
|
||||
}
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name="request",
|
||||
schema=RowsQueryRequest,
|
||||
handler=self.on_message
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name="response",
|
||||
schema=RowsQueryResponse,
|
||||
)
|
||||
)
|
||||
|
||||
# Register config handler for schema updates
|
||||
self.register_config_handler(self.on_schema_config)
|
||||
|
||||
# Schema storage: name -> RowSchema
|
||||
self.schemas: Dict[str, RowSchema] = {}
|
||||
|
||||
# GraphQL schema builder and generated schema
|
||||
self.schema_builder = GraphQLSchemaBuilder()
|
||||
self.graphql_schema = None
|
||||
|
||||
# Cassandra session
|
||||
self.cluster = None
|
||||
self.session = None
|
||||
|
||||
# Known keyspaces
|
||||
self.known_keyspaces: Set[str] = set()
|
||||
|
||||
def connect_cassandra(self):
|
||||
"""Connect to Cassandra cluster"""
|
||||
if self.session:
|
||||
return
|
||||
|
||||
try:
|
||||
if self.cassandra_username and self.cassandra_password:
|
||||
auth_provider = PlainTextAuthProvider(
|
||||
username=self.cassandra_username,
|
||||
password=self.cassandra_password
|
||||
)
|
||||
self.cluster = Cluster(
|
||||
contact_points=self.cassandra_host,
|
||||
auth_provider=auth_provider
|
||||
)
|
||||
else:
|
||||
self.cluster = Cluster(contact_points=self.cassandra_host)
|
||||
|
||||
self.session = self.cluster.connect()
|
||||
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def sanitize_name(self, name: str) -> str:
|
||||
"""Sanitize names for Cassandra compatibility"""
|
||||
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
|
||||
if safe_name and not safe_name[0].isalpha():
|
||||
safe_name = 'r_' + safe_name
|
||||
return safe_name.lower()
|
||||
|
||||
async def on_schema_config(self, config, version):
|
||||
"""Handle schema configuration updates"""
|
||||
logger.info(f"Loading schema configuration version {version}")
|
||||
|
||||
# Clear existing schemas
|
||||
self.schemas = {}
|
||||
self.schema_builder.clear()
|
||||
|
||||
# Check if our config type exists
|
||||
if self.config_key not in config:
|
||||
logger.warning(f"No '{self.config_key}' type in configuration")
|
||||
return
|
||||
|
||||
# Get the schemas dictionary for our type
|
||||
schemas_config = config[self.config_key]
|
||||
|
||||
# Process each schema in the schemas config
|
||||
for schema_name, schema_json in schemas_config.items():
|
||||
try:
|
||||
# Parse the JSON schema definition
|
||||
schema_def = json.loads(schema_json)
|
||||
|
||||
# Create Field objects
|
||||
fields = []
|
||||
for field_def in schema_def.get("fields", []):
|
||||
field = SchemaField(
|
||||
name=field_def["name"],
|
||||
type=field_def["type"],
|
||||
size=field_def.get("size", 0),
|
||||
primary=field_def.get("primary_key", False),
|
||||
description=field_def.get("description", ""),
|
||||
required=field_def.get("required", False),
|
||||
enum_values=field_def.get("enum", []),
|
||||
indexed=field_def.get("indexed", False)
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
# Create RowSchema
|
||||
row_schema = RowSchema(
|
||||
name=schema_def.get("name", schema_name),
|
||||
description=schema_def.get("description", ""),
|
||||
fields=fields
|
||||
)
|
||||
|
||||
self.schemas[schema_name] = row_schema
|
||||
self.schema_builder.add_schema(schema_name, row_schema)
|
||||
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
|
||||
|
||||
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
|
||||
|
||||
# Regenerate GraphQL schema
|
||||
self.graphql_schema = self.schema_builder.build(self.query_cassandra)
|
||||
|
||||
def get_index_names(self, schema: RowSchema) -> List[str]:
|
||||
"""Get all index names for a schema."""
|
||||
index_names = []
|
||||
for field in schema.fields:
|
||||
if field.primary or field.indexed:
|
||||
index_names.append(field.name)
|
||||
return index_names
|
||||
|
||||
def find_matching_index(
|
||||
self,
|
||||
schema: RowSchema,
|
||||
filters: Dict[str, Any]
|
||||
) -> Optional[tuple]:
|
||||
"""
|
||||
Find an index that can satisfy the query filters.
|
||||
Returns (index_name, index_value) if found, None otherwise.
|
||||
|
||||
For exact match queries, we need a filter on an indexed field.
|
||||
"""
|
||||
index_names = self.get_index_names(schema)
|
||||
|
||||
# Look for an exact match filter on an indexed field
|
||||
for index_name in index_names:
|
||||
if index_name in filters:
|
||||
value = filters[index_name]
|
||||
# Single field index -> single element list
|
||||
index_value = [str(value)]
|
||||
return (index_name, index_value)
|
||||
|
||||
return None
|
||||
|
||||
async def query_cassandra(
|
||||
self,
|
||||
user: str,
|
||||
collection: str,
|
||||
schema_name: str,
|
||||
row_schema: RowSchema,
|
||||
filters: Dict[str, Any],
|
||||
limit: int,
|
||||
order_by: Optional[str] = None,
|
||||
direction: Optional[SortDirection] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Execute a query against the unified Cassandra rows table.
|
||||
|
||||
For exact match queries on indexed fields, we can query directly.
|
||||
For other queries, we need to scan and post-filter.
|
||||
"""
|
||||
# Connect if needed
|
||||
self.connect_cassandra()
|
||||
|
||||
safe_keyspace = self.sanitize_name(user)
|
||||
|
||||
# Try to find an index that matches the filters
|
||||
index_match = self.find_matching_index(row_schema, filters)
|
||||
|
||||
results = []
|
||||
|
||||
if index_match:
|
||||
# Direct query using index
|
||||
index_name, index_value = index_match
|
||||
|
||||
query = f"""
|
||||
SELECT data, source FROM {safe_keyspace}.rows
|
||||
WHERE collection = %s
|
||||
AND schema_name = %s
|
||||
AND index_name = %s
|
||||
AND index_value = %s
|
||||
"""
|
||||
params = [collection, schema_name, index_name, index_value]
|
||||
|
||||
if limit:
|
||||
query += f" LIMIT {limit}"
|
||||
|
||||
try:
|
||||
rows = self.session.execute(query, params)
|
||||
for row in rows:
|
||||
# Convert data map to dict with proper field names
|
||||
row_dict = dict(row.data) if row.data else {}
|
||||
results.append(row_dict)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to query rows: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
else:
|
||||
# No direct index match - scan all rows for this schema
|
||||
# This is less efficient but necessary for non-indexed queries
|
||||
logger.warning(
|
||||
f"No index match for filters {filters} - scanning all indexes"
|
||||
)
|
||||
|
||||
# Get all index names for this schema
|
||||
index_names = self.get_index_names(row_schema)
|
||||
|
||||
if not index_names:
|
||||
logger.warning(f"Schema {schema_name} has no indexes")
|
||||
return []
|
||||
|
||||
# Query using the first index (arbitrary choice for scan)
|
||||
primary_index = index_names[0]
|
||||
|
||||
# We need to scan all values for this index
|
||||
# This requires ALLOW FILTERING or a different approach
|
||||
query = f"""
|
||||
SELECT data, source FROM {safe_keyspace}.rows
|
||||
WHERE collection = %s
|
||||
AND schema_name = %s
|
||||
AND index_name = %s
|
||||
ALLOW FILTERING
|
||||
"""
|
||||
params = [collection, schema_name, primary_index]
|
||||
|
||||
try:
|
||||
rows = self.session.execute(query, params)
|
||||
|
||||
for row in rows:
|
||||
row_dict = dict(row.data) if row.data else {}
|
||||
|
||||
# Apply post-filters
|
||||
if self._matches_filters(row_dict, filters, row_schema):
|
||||
results.append(row_dict)
|
||||
|
||||
if limit and len(results) >= limit:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to scan rows: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
# Post-query sorting if requested
|
||||
if order_by and results:
|
||||
reverse_order = direction and direction.value == "desc"
|
||||
try:
|
||||
results.sort(
|
||||
key=lambda x: x.get(order_by, ""),
|
||||
reverse=reverse_order
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to sort results by {order_by}: {e}")
|
||||
|
||||
return results
|
||||
|
||||
def _matches_filters(
|
||||
self,
|
||||
row_dict: Dict[str, Any],
|
||||
filters: Dict[str, Any],
|
||||
row_schema: RowSchema
|
||||
) -> bool:
|
||||
"""Check if a row matches the given filters."""
|
||||
for filter_key, filter_value in filters.items():
|
||||
if filter_value is None:
|
||||
continue
|
||||
|
||||
# Parse filter key for operator
|
||||
if '_' in filter_key:
|
||||
parts = filter_key.rsplit('_', 1)
|
||||
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in']:
|
||||
field_name = parts[0]
|
||||
operator = parts[1]
|
||||
else:
|
||||
field_name = filter_key
|
||||
operator = 'eq'
|
||||
else:
|
||||
field_name = filter_key
|
||||
operator = 'eq'
|
||||
|
||||
row_value = row_dict.get(field_name)
|
||||
if row_value is None:
|
||||
return False
|
||||
|
||||
# Convert types for comparison
|
||||
try:
|
||||
if operator == 'eq':
|
||||
if str(row_value) != str(filter_value):
|
||||
return False
|
||||
elif operator == 'gt':
|
||||
if float(row_value) <= float(filter_value):
|
||||
return False
|
||||
elif operator == 'gte':
|
||||
if float(row_value) < float(filter_value):
|
||||
return False
|
||||
elif operator == 'lt':
|
||||
if float(row_value) >= float(filter_value):
|
||||
return False
|
||||
elif operator == 'lte':
|
||||
if float(row_value) > float(filter_value):
|
||||
return False
|
||||
elif operator == 'contains':
|
||||
if str(filter_value) not in str(row_value):
|
||||
return False
|
||||
elif operator == 'in':
|
||||
if str(row_value) not in [str(v) for v in filter_value]:
|
||||
return False
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def execute_graphql_query(
|
||||
self,
|
||||
query: str,
|
||||
variables: Dict[str, Any],
|
||||
operation_name: Optional[str],
|
||||
user: str,
|
||||
collection: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute a GraphQL query"""
|
||||
|
||||
if not self.graphql_schema:
|
||||
raise RuntimeError("No GraphQL schema available - no schemas loaded")
|
||||
|
||||
# Create context for the query
|
||||
context = {
|
||||
"processor": self,
|
||||
"user": user,
|
||||
"collection": collection
|
||||
}
|
||||
|
||||
# Execute the query
|
||||
result = await self.graphql_schema.execute(
|
||||
query,
|
||||
variable_values=variables,
|
||||
operation_name=operation_name,
|
||||
context_value=context
|
||||
)
|
||||
|
||||
# Build response
|
||||
response = {}
|
||||
|
||||
if result.data:
|
||||
response["data"] = result.data
|
||||
else:
|
||||
response["data"] = None
|
||||
|
||||
if result.errors:
|
||||
response["errors"] = [
|
||||
{
|
||||
"message": str(error),
|
||||
"path": getattr(error, "path", []),
|
||||
"extensions": getattr(error, "extensions", {})
|
||||
}
|
||||
for error in result.errors
|
||||
]
|
||||
else:
|
||||
response["errors"] = []
|
||||
|
||||
# Add extensions if any
|
||||
if hasattr(result, "extensions") and result.extensions:
|
||||
response["extensions"] = result.extensions
|
||||
|
||||
return response
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
"""Handle incoming query request"""
|
||||
|
||||
try:
|
||||
request = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
logger.debug(f"Handling objects query request {id}...")
|
||||
|
||||
# Execute GraphQL query
|
||||
result = await self.execute_graphql_query(
|
||||
query=request.query,
|
||||
variables=dict(request.variables) if request.variables else {},
|
||||
operation_name=request.operation_name,
|
||||
user=request.user,
|
||||
collection=request.collection
|
||||
)
|
||||
|
||||
# Create response
|
||||
graphql_errors = []
|
||||
if "errors" in result and result["errors"]:
|
||||
for err in result["errors"]:
|
||||
graphql_error = GraphQLError(
|
||||
message=err.get("message", ""),
|
||||
path=err.get("path", []),
|
||||
extensions=err.get("extensions", {})
|
||||
)
|
||||
graphql_errors.append(graphql_error)
|
||||
|
||||
response = RowsQueryResponse(
|
||||
error=None,
|
||||
data=json.dumps(result.get("data")) if result.get("data") else "null",
|
||||
errors=graphql_errors,
|
||||
extensions=result.get("extensions", {})
|
||||
)
|
||||
|
||||
logger.debug("Sending objects query response...")
|
||||
await flow("response").send(response, properties={"id": id})
|
||||
|
||||
logger.debug("Objects query request completed")
|
||||
|
||||
except Exception as e:
|
||||
|
||||
logger.error(f"Exception in rows query service: {e}", exc_info=True)
|
||||
|
||||
logger.info("Sending error response...")
|
||||
|
||||
response = RowsQueryResponse(
|
||||
error=Error(
|
||||
type="rows-query-error",
|
||||
message=str(e),
|
||||
),
|
||||
data=None,
|
||||
errors=[],
|
||||
extensions={}
|
||||
)
|
||||
|
||||
await flow("response").send(response, properties={"id": id})
|
||||
|
||||
def close(self):
|
||||
"""Clean up Cassandra connections"""
|
||||
if self.cluster:
|
||||
self.cluster.shutdown()
|
||||
logger.info("Closed Cassandra connection")
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add command-line arguments"""
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
add_cassandra_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'--config-type',
|
||||
default='schema',
|
||||
help='Configuration type prefix for schemas (default: schema)'
|
||||
)
|
||||
|
||||
|
||||
def run():
|
||||
"""Entry point for rows-query-cassandra command"""
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
@ -1,14 +1,16 @@
"""
-Triples query service. Input is a (s, p, o) triple, some values may be
-null. Output is a list of triples.
+Triples query service. Input is a (s, p, o, g) quad pattern, some values may be
+null. Output is a list of quads.
"""

import logging

-from .... direct.cassandra_kg import KnowledgeGraph
+from .... direct.cassandra_kg import (
+    EntityCentricKnowledgeGraph, GRAPH_WILDCARD, DEFAULT_GRAPH
+)
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
-from .... schema import Value, Triple
+from .... schema import Term, Triple, IRI, LITERAL
from .... base import TriplesQueryService
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config

@ -18,6 +20,56 @@ logger = logging.getLogger(__name__)
default_ident = "triples-query"


+def get_term_value(term):
+    """Extract the string value from a Term"""
+    if term is None:
+        return None
+    if term.type == IRI:
+        return term.iri
+    elif term.type == LITERAL:
+        return term.value
+    else:
+        # For blank nodes or other types, use id or value
+        return term.id or term.value
+
+
+def create_term(value, otype=None, dtype=None, lang=None):
+    """
+    Create a Term from a string value, optionally using type metadata.
+
+    Args:
+        value: The string value
+        otype: Object type - 'u' (URI), 'l' (literal), 't' (triple)
+        dtype: XSD datatype (for literals)
+        lang: Language tag (for literals)
+
+    If otype is provided, uses it to determine Term type.
+    Otherwise falls back to URL detection heuristic.
+    """
+    if otype is not None:
+        if otype == 'u':
+            return Term(type=IRI, iri=value)
+        elif otype == 'l':
+            return Term(
+                type=LITERAL,
+                value=value,
+                datatype=dtype or "",
+                language=lang or ""
+            )
+        elif otype == 't':
+            # Triple/reification - treat as IRI for now
+            return Term(type=IRI, iri=value)
+        else:
+            # Unknown otype, fall back to heuristic
+            pass
+
+    # Heuristic fallback for backwards compatibility
+    if value.startswith("http://") or value.startswith("https://"):
+        return Term(type=IRI, iri=value)
+    else:
+        return Term(type=LITERAL, value=value)
+
+
class Processor(TriplesQueryService):

    def __init__(self, **params):
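The two helpers are roughly inverses in the common cases: `get_term_value` flattens a Term to the string stored in Cassandra, and `create_term` rebuilds a Term from that string plus whatever otype/dtype/lang metadata the row carries. A quick illustration, assuming the helpers above are in scope together with Term, IRI and LITERAL from the trustgraph schema (values invented):

# Illustration only: not part of the diff.

# URI round-trip via the heuristic path (no otype metadata available).
t = create_term("http://example.org/widget")
assert t.type == IRI and t.iri == "http://example.org/widget"
assert get_term_value(t) == "http://example.org/widget"

# Typed literal: otype='l' plus a datatype recovered from the stored row.
t = create_term("42", otype='l', dtype="http://www.w3.org/2001/XMLSchema#integer")
assert t.type == LITERAL and t.value == "42"
assert get_term_value(t) == "42"

# Plain strings fall back to LITERAL under the heuristic.
assert create_term("Widget Corp").type == LITERAL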
@ -46,12 +98,6 @@ class Processor(TriplesQueryService):
        self.cassandra_password = password
        self.table = None

-   def create_value(self, ent):
-       if ent.startswith("http://") or ent.startswith("https://"):
-           return Value(value=ent, is_uri=True)
-       else:
-           return Value(value=ent, is_uri=False)
-
    async def query_triples(self, query):

        try:

@ -59,90 +105,137 @@ class Processor(TriplesQueryService):
            user = query.user

            if user != self.table:
+               # Use factory function to select implementation
+               KGClass = EntityCentricKnowledgeGraph
+
                if self.cassandra_username and self.cassandra_password:
-                   self.tg = KnowledgeGraph(
+                   self.tg = KGClass(
                        hosts=self.cassandra_host,
                        keyspace=query.user,
                        username=self.cassandra_username, password=self.cassandra_password
                    )
                else:
-                   self.tg = KnowledgeGraph(
+                   self.tg = KGClass(
                        hosts=self.cassandra_host,
                        keyspace=query.user,
                    )
                self.table = user

-           triples = []
+           # Extract values from query
+           s_val = get_term_value(query.s)
+           p_val = get_term_value(query.p)
+           o_val = get_term_value(query.o)
+           g_val = query.g  # Already a string or None

-           if query.s is not None:
-               if query.p is not None:
-                   if query.o is not None:
+           # Helper to extract object metadata from result row
+           def get_o_metadata(t):
+               """Extract otype/dtype/lang from result row if available"""
+               otype = getattr(t, 'otype', None)
+               dtype = getattr(t, 'dtype', None)
+               lang = getattr(t, 'lang', None)
+               return otype, dtype, lang
+
+           quads = []
+
+           # Route to appropriate query method based on which fields are specified
+           if s_val is not None:
+               if p_val is not None:
+                   if o_val is not None:
+                       # SPO specified - find matching graphs
                        resp = self.tg.get_spo(
-                           query.collection, query.s.value, query.p.value, query.o.value,
+                           query.collection, s_val, p_val, o_val, g=g_val,
                            limit=query.limit
                        )
-                       triples.append((query.s.value, query.p.value, query.o.value))
+                       for t in resp:
+                           g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
+                           otype, dtype, lang = get_o_metadata(t)
+                           quads.append((s_val, p_val, o_val, g, otype, dtype, lang))
                    else:
+                       # SP specified
                        resp = self.tg.get_sp(
-                           query.collection, query.s.value, query.p.value,
+                           query.collection, s_val, p_val, g=g_val,
                            limit=query.limit
                        )
                        for t in resp:
-                           triples.append((query.s.value, query.p.value, t.o))
+                           g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
+                           otype, dtype, lang = get_o_metadata(t)
+                           quads.append((s_val, p_val, t.o, g, otype, dtype, lang))
                else:
-                   if query.o is not None:
+                   if o_val is not None:
+                       # SO specified
                        resp = self.tg.get_os(
-                           query.collection, query.o.value, query.s.value,
+                           query.collection, o_val, s_val, g=g_val,
                            limit=query.limit
                        )
                        for t in resp:
-                           triples.append((query.s.value, t.p, query.o.value))
+                           g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
+                           otype, dtype, lang = get_o_metadata(t)
+                           quads.append((s_val, t.p, o_val, g, otype, dtype, lang))
                    else:
+                       # S only
                        resp = self.tg.get_s(
-                           query.collection, query.s.value,
+                           query.collection, s_val, g=g_val,
                            limit=query.limit
                        )
                        for t in resp:
-                           triples.append((query.s.value, t.p, t.o))
+                           g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
+                           otype, dtype, lang = get_o_metadata(t)
+                           quads.append((s_val, t.p, t.o, g, otype, dtype, lang))
            else:
-               if query.p is not None:
-                   if query.o is not None:
+               if p_val is not None:
+                   if o_val is not None:
+                       # PO specified
                        resp = self.tg.get_po(
-                           query.collection, query.p.value, query.o.value,
+                           query.collection, p_val, o_val, g=g_val,
                            limit=query.limit
                        )
                        for t in resp:
-                           triples.append((t.s, query.p.value, query.o.value))
+                           g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
+                           otype, dtype, lang = get_o_metadata(t)
+                           quads.append((t.s, p_val, o_val, g, otype, dtype, lang))
                    else:
+                       # P only
                        resp = self.tg.get_p(
-                           query.collection, query.p.value,
+                           query.collection, p_val, g=g_val,
                            limit=query.limit
                        )
                        for t in resp:
-                           triples.append((t.s, query.p.value, t.o))
+                           g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
+                           otype, dtype, lang = get_o_metadata(t)
+                           quads.append((t.s, p_val, t.o, g, otype, dtype, lang))
                else:
-                   if query.o is not None:
+                   if o_val is not None:
+                       # O only
                        resp = self.tg.get_o(
-                           query.collection, query.o.value,
+                           query.collection, o_val, g=g_val,
                            limit=query.limit
                        )
                        for t in resp:
-                           triples.append((t.s, t.p, query.o.value))
+                           g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
+                           otype, dtype, lang = get_o_metadata(t)
+                           quads.append((t.s, t.p, o_val, g, otype, dtype, lang))
                    else:
+                       # Nothing specified - get all
                        resp = self.tg.get_all(
                            query.collection,
                            limit=query.limit
                        )
                        for t in resp:
-                           triples.append((t.s, t.p, t.o))
+                           # Note: quads_by_collection uses 'd' for graph field
+                           g = t.d if hasattr(t, 'd') else DEFAULT_GRAPH
+                           otype, dtype, lang = get_o_metadata(t)
+                           quads.append((t.s, t.p, t.o, g, otype, dtype, lang))

+           # Convert to Triple objects (with g field)
+           # Use otype/dtype/lang for proper Term reconstruction if available
            triples = [
                Triple(
-                   s=self.create_value(t[0]),
-                   p=self.create_value(t[1]),
-                   o=self.create_value(t[2])
+                   s=create_term(q[0]),
+                   p=create_term(q[1]),
+                   o=create_term(q[2], otype=q[4], dtype=q[5], lang=q[6]),
+                   g=q[3] if q[3] != DEFAULT_GRAPH else None
                )
-               for t in triples
+               for q in quads
            ]

            return triples

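The nested if/else ladder above is a dispatch on which of s/p/o are bound; every branch then normalises result rows into (s, p, o, g, otype, dtype, lang) tuples. A small sketch of that routing, for reference only, using the method names that appear above and assuming the KG object exposes them as EntityCentricKnowledgeGraph does:

def pick_query_method(kg, s_val, p_val, o_val):
    """Return the bound-variable pattern and the matching KG query method.

    Mirrors the dispatch in query_triples above; illustrative only.
    """
    dispatch = {
        (True,  True,  True):  ("spo", kg.get_spo),
        (True,  True,  False): ("sp",  kg.get_sp),
        (True,  False, True):  ("so",  kg.get_os),
        (True,  False, False): ("s",   kg.get_s),
        (False, True,  True):  ("po",  kg.get_po),
        (False, True,  False): ("p",   kg.get_p),
        (False, False, True):  ("o",   kg.get_o),
        (False, False, False): ("all", kg.get_all),
    }
    key = (s_val is not None, p_val is not None, o_val is not None)
    return dispatch[key]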
@ -162,4 +255,3 @@ class Processor(TriplesQueryService):
def run():

    Processor.launch(default_ident, __doc__)
-

@ -10,12 +10,24 @@ import logging
from falkordb import FalkorDB

from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
-from .... schema import Value, Triple
+from .... schema import Term, Triple, IRI, LITERAL
from .... base import TriplesQueryService

# Module logger
logger = logging.getLogger(__name__)


+def get_term_value(term):
+    """Extract the string value from a Term"""
+    if term is None:
+        return None
+    if term.type == IRI:
+        return term.iri
+    elif term.type == LITERAL:
+        return term.value
+    else:
+        return term.id or term.value
+
+
default_ident = "triples-query"

default_graph_url = 'falkor://falkordb:6379'

@ -42,9 +54,9 @@ class Processor(TriplesQueryService):
    def create_value(self, ent):

        if ent.startswith("http://") or ent.startswith("https://"):
-           return Value(value=ent, is_uri=True)
+           return Term(type=IRI, iri=ent)
        else:
-           return Value(value=ent, is_uri=False)
+           return Term(type=LITERAL, value=ent)

    async def query_triples(self, query):

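For context on the Cypher patterns in the hunks that follow: resource nodes are matched on a `uri` property, literal nodes on a `value` property, and predicates are `:Rel` relationships keyed by `uri`. A minimal sketch of one stored triple in that model, inferred from the MATCH patterns in this diff, using invented data against a local FalkorDB and a throwaway graph rather than one of the service's graphs:

from falkordb import FalkorDB

db = FalkorDB(host="localhost", port=6379)
g = db.select_graph("example")

# One triple: <http://example.org/alice> <http://xmlns.com/foaf/0.1/name> "Alice"
g.query(
    "MERGE (s:Node {uri: $s}) "
    "MERGE (o:Node {value: $o}) "
    "MERGE (s)-[:Rel {uri: $p}]->(o)",
    params={
        "s": "http://example.org/alice",
        "p": "http://xmlns.com/foaf/0.1/name",
        "o": "Alice",
    },
)

# The S+P query shape used by the service then recovers the literal object.
rows = g.query(
    "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) "
    "RETURN dest.value as dest LIMIT 10",
    params={"src": "http://example.org/alice", "rel": "http://xmlns.com/foaf/0.1/name"},
).result_set
print(rows)   # [['Alice']]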
@ -63,28 +75,28 @@ class Processor(TriplesQueryService):
                    "RETURN $src as src "
                    "LIMIT " + str(query.limit),
                    params={
-                       "src": query.s.value,
-                       "rel": query.p.value,
-                       "value": query.o.value,
+                       "src": get_term_value(query.s),
+                       "rel": get_term_value(query.p),
+                       "value": get_term_value(query.o),
                    },
                ).result_set

                for rec in records:
-                   triples.append((query.s.value, query.p.value, query.o.value))
+                   triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o)))

                records = self.io.query(
                    "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) "
                    "RETURN $src as src "
                    "LIMIT " + str(query.limit),
                    params={
-                       "src": query.s.value,
-                       "rel": query.p.value,
-                       "uri": query.o.value,
+                       "src": get_term_value(query.s),
+                       "rel": get_term_value(query.p),
+                       "uri": get_term_value(query.o),
                    },
                ).result_set

                for rec in records:
-                   triples.append((query.s.value, query.p.value, query.o.value))
+                   triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o)))

            else:

@ -95,26 +107,26 @@ class Processor(TriplesQueryService):
|
|||
"RETURN dest.value as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"src": query.s.value,
|
||||
"rel": query.p.value,
|
||||
"src": get_term_value(query.s),
|
||||
"rel": get_term_value(query.p),
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((query.s.value, query.p.value, rec[0]))
|
||||
triples.append((get_term_value(query.s), get_term_value(query.p), rec[0]))
|
||||
|
||||
records = self.io.query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) "
|
||||
"RETURN dest.uri as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"src": query.s.value,
|
||||
"rel": query.p.value,
|
||||
"src": get_term_value(query.s),
|
||||
"rel": get_term_value(query.p),
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((query.s.value, query.p.value, rec[0]))
|
||||
triples.append((get_term_value(query.s), get_term_value(query.p), rec[0]))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -127,26 +139,26 @@ class Processor(TriplesQueryService):
|
|||
"RETURN rel.uri as rel "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"src": query.s.value,
|
||||
"value": query.o.value,
|
||||
"src": get_term_value(query.s),
|
||||
"value": get_term_value(query.o),
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((query.s.value, rec[0], query.o.value))
|
||||
triples.append((get_term_value(query.s), rec[0], get_term_value(query.o)))
|
||||
|
||||
records = self.io.query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) "
|
||||
"RETURN rel.uri as rel "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"src": query.s.value,
|
||||
"uri": query.o.value,
|
||||
"src": get_term_value(query.s),
|
||||
"uri": get_term_value(query.o),
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((query.s.value, rec[0], query.o.value))
|
||||
triples.append((get_term_value(query.s), rec[0], get_term_value(query.o)))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -157,24 +169,24 @@ class Processor(TriplesQueryService):
|
|||
"RETURN rel.uri as rel, dest.value as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"src": query.s.value,
|
||||
"src": get_term_value(query.s),
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((query.s.value, rec[0], rec[1]))
|
||||
triples.append((get_term_value(query.s), rec[0], rec[1]))
|
||||
|
||||
records = self.io.query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) "
|
||||
"RETURN rel.uri as rel, dest.uri as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"src": query.s.value,
|
||||
"src": get_term_value(query.s),
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((query.s.value, rec[0], rec[1]))
|
||||
triples.append((get_term_value(query.s), rec[0], rec[1]))
|
||||
|
||||
|
||||
else:
|
||||
|
|
@ -190,26 +202,26 @@ class Processor(TriplesQueryService):
|
|||
"RETURN src.uri as src "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"uri": query.p.value,
|
||||
"value": query.o.value,
|
||||
"uri": get_term_value(query.p),
|
||||
"value": get_term_value(query.o),
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((rec[0], query.p.value, query.o.value))
|
||||
triples.append((rec[0], get_term_value(query.p), get_term_value(query.o)))
|
||||
|
||||
records = self.io.query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $dest}) "
|
||||
"RETURN src.uri as src "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"uri": query.p.value,
|
||||
"dest": query.o.value,
|
||||
"uri": get_term_value(query.p),
|
||||
"dest": get_term_value(query.o),
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((rec[0], query.p.value, query.o.value))
|
||||
triples.append((rec[0], get_term_value(query.p), get_term_value(query.o)))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -220,24 +232,24 @@ class Processor(TriplesQueryService):
|
|||
"RETURN src.uri as src, dest.value as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"uri": query.p.value,
|
||||
"uri": get_term_value(query.p),
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((rec[0], query.p.value, rec[1]))
|
||||
triples.append((rec[0], get_term_value(query.p), rec[1]))
|
||||
|
||||
records = self.io.query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) "
|
||||
"RETURN src.uri as src, dest.uri as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"uri": query.p.value,
|
||||
"uri": get_term_value(query.p),
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((rec[0], query.p.value, rec[1]))
|
||||
triples.append((rec[0], get_term_value(query.p), rec[1]))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -250,24 +262,24 @@ class Processor(TriplesQueryService):
|
|||
"RETURN src.uri as src, rel.uri as rel "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"value": query.o.value,
|
||||
"value": get_term_value(query.o),
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((rec[0], rec[1], query.o.value))
|
||||
triples.append((rec[0], rec[1], get_term_value(query.o)))
|
||||
|
||||
records = self.io.query(
|
||||
"MATCH (src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) "
|
||||
"RETURN src.uri as src, rel.uri as rel "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"uri": query.o.value,
|
||||
"uri": get_term_value(query.o),
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((rec[0], rec[1], query.o.value))
|
||||
triples.append((rec[0], rec[1], get_term_value(query.o)))
|
||||
|
||||
else:
|
||||
|
||||
|
|
|
|||
|
|
@ -10,12 +10,24 @@ import logging
|
|||
from neo4j import GraphDatabase
|
||||
|
||||
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
|
||||
from .... schema import Value, Triple
|
||||
from .... schema import Term, Triple, IRI, LITERAL
|
||||
from .... base import TriplesQueryService
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_term_value(term):
|
||||
"""Extract the string value from a Term"""
|
||||
if term is None:
|
||||
return None
|
||||
if term.type == IRI:
|
||||
return term.iri
|
||||
elif term.type == LITERAL:
|
||||
return term.value
|
||||
else:
|
||||
return term.id or term.value
|
||||
|
||||
default_ident = "triples-query"
|
||||
|
||||
default_graph_host = 'bolt://memgraph:7687'
|
||||
|
|
@ -47,9 +59,9 @@ class Processor(TriplesQueryService):
|
|||
def create_value(self, ent):
|
||||
|
||||
if ent.startswith("http://") or ent.startswith("https://"):
|
||||
return Value(value=ent, is_uri=True)
|
||||
return Term(type=IRI, iri=ent)
|
||||
else:
|
||||
return Value(value=ent, is_uri=False)
|
||||
return Term(type=LITERAL, value=ent)
|
||||
|
||||
async def query_triples(self, query):
|
||||
|
||||
|
|
@ -73,13 +85,13 @@ class Processor(TriplesQueryService):
|
|||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN $src as src "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value, rel=query.p.value, value=query.o.value,
|
||||
src=get_term_value(query.s), rel=get_term_value(query.p), value=get_term_value(query.o),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
triples.append((query.s.value, query.p.value, query.o.value))
|
||||
triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o)))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
|
|
@ -87,13 +99,13 @@ class Processor(TriplesQueryService):
|
|||
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
|
||||
"RETURN $src as src "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value, rel=query.p.value, uri=query.o.value,
|
||||
src=get_term_value(query.s), rel=get_term_value(query.p), uri=get_term_value(query.o),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
triples.append((query.s.value, query.p.value, query.o.value))
|
||||
triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o)))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -105,14 +117,14 @@ class Processor(TriplesQueryService):
|
|||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN dest.value as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value, rel=query.p.value,
|
||||
src=get_term_value(query.s), rel=get_term_value(query.p),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((query.s.value, query.p.value, data["dest"]))
|
||||
triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"]))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
|
|
@ -120,14 +132,14 @@ class Processor(TriplesQueryService):
|
|||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"RETURN dest.uri as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value, rel=query.p.value,
|
||||
src=get_term_value(query.s), rel=get_term_value(query.p),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((query.s.value, query.p.value, data["dest"]))
|
||||
triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"]))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -141,14 +153,14 @@ class Processor(TriplesQueryService):
|
|||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN rel.uri as rel "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value, value=query.o.value,
|
||||
src=get_term_value(query.s), value=get_term_value(query.o),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((query.s.value, data["rel"], query.o.value))
|
||||
triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o)))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
|
|
@ -156,14 +168,14 @@ class Processor(TriplesQueryService):
|
|||
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
|
||||
"RETURN rel.uri as rel "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value, uri=query.o.value,
|
||||
src=get_term_value(query.s), uri=get_term_value(query.o),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((query.s.value, data["rel"], query.o.value))
|
||||
triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o)))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -175,14 +187,14 @@ class Processor(TriplesQueryService):
|
|||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN rel.uri as rel, dest.value as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value,
|
||||
src=get_term_value(query.s),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((query.s.value, data["rel"], data["dest"]))
|
||||
triples.append((get_term_value(query.s), data["rel"], data["dest"]))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
|
|
@ -190,14 +202,14 @@ class Processor(TriplesQueryService):
|
|||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"RETURN rel.uri as rel, dest.uri as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value,
|
||||
src=get_term_value(query.s),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((query.s.value, data["rel"], data["dest"]))
|
||||
triples.append((get_term_value(query.s), data["rel"], data["dest"]))
|
||||
|
||||
|
||||
else:
|
||||
|
|
@ -214,14 +226,14 @@ class Processor(TriplesQueryService):
|
|||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src "
|
||||
"LIMIT " + str(query.limit),
|
||||
uri=query.p.value, value=query.o.value,
|
||||
uri=get_term_value(query.p), value=get_term_value(query.o),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], query.p.value, query.o.value))
|
||||
triples.append((data["src"], get_term_value(query.p), get_term_value(query.o)))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
|
|
@ -229,14 +241,14 @@ class Processor(TriplesQueryService):
|
|||
"(dest:Node {uri: $dest, user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src "
|
||||
"LIMIT " + str(query.limit),
|
||||
uri=query.p.value, dest=query.o.value,
|
||||
uri=get_term_value(query.p), dest=get_term_value(query.o),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], query.p.value, query.o.value))
|
||||
triples.append((data["src"], get_term_value(query.p), get_term_value(query.o)))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -248,14 +260,14 @@ class Processor(TriplesQueryService):
|
|||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, dest.value as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
uri=query.p.value,
|
||||
uri=get_term_value(query.p),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], query.p.value, data["dest"]))
|
||||
triples.append((data["src"], get_term_value(query.p), data["dest"]))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
|
|
@ -263,14 +275,14 @@ class Processor(TriplesQueryService):
|
|||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, dest.uri as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
uri=query.p.value,
|
||||
uri=get_term_value(query.p),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], query.p.value, data["dest"]))
|
||||
triples.append((data["src"], get_term_value(query.p), data["dest"]))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -284,14 +296,14 @@ class Processor(TriplesQueryService):
|
|||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel "
|
||||
"LIMIT " + str(query.limit),
|
||||
value=query.o.value,
|
||||
value=get_term_value(query.o),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], data["rel"], query.o.value))
|
||||
triples.append((data["src"], data["rel"], get_term_value(query.o)))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
|
|
@ -299,14 +311,14 @@ class Processor(TriplesQueryService):
|
|||
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel "
|
||||
"LIMIT " + str(query.limit),
|
||||
uri=query.o.value,
|
||||
uri=get_term_value(query.o),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], data["rel"], query.o.value))
|
||||
triples.append((data["src"], data["rel"], get_term_value(query.o)))
|
||||
|
||||
else:
|
||||
|
||||
|
|
|
|||
|
|
@ -10,12 +10,24 @@ import logging
|
|||
from neo4j import GraphDatabase
|
||||
|
||||
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
|
||||
from .... schema import Value, Triple
|
||||
from .... schema import Term, Triple, IRI, LITERAL
|
||||
from .... base import TriplesQueryService
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_term_value(term):
|
||||
"""Extract the string value from a Term"""
|
||||
if term is None:
|
||||
return None
|
||||
if term.type == IRI:
|
||||
return term.iri
|
||||
elif term.type == LITERAL:
|
||||
return term.value
|
||||
else:
|
||||
return term.id or term.value
|
||||
|
||||
default_ident = "triples-query"
|
||||
|
||||
default_graph_host = 'bolt://neo4j:7687'
|
||||
|
|
@ -47,9 +59,9 @@ class Processor(TriplesQueryService):
|
|||
def create_value(self, ent):
|
||||
|
||||
if ent.startswith("http://") or ent.startswith("https://"):
|
||||
return Value(value=ent, is_uri=True)
|
||||
return Term(type=IRI, iri=ent)
|
||||
else:
|
||||
return Value(value=ent, is_uri=False)
|
||||
return Term(type=LITERAL, value=ent)
|
||||
|
||||
async def query_triples(self, query):
|
||||
|
||||
|
|
@ -71,27 +83,29 @@ class Processor(TriplesQueryService):
|
|||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN $src as src",
|
||||
src=query.s.value, rel=query.p.value, value=query.o.value,
|
||||
"RETURN $src as src "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=get_term_value(query.s), rel=get_term_value(query.p), value=get_term_value(query.o),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
triples.append((query.s.value, query.p.value, query.o.value))
|
||||
triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o)))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
|
||||
"RETURN $src as src",
|
||||
src=query.s.value, rel=query.p.value, uri=query.o.value,
|
||||
"RETURN $src as src "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=get_term_value(query.s), rel=get_term_value(query.p), uri=get_term_value(query.o),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
triples.append((query.s.value, query.p.value, query.o.value))
|
||||
triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o)))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -101,29 +115,31 @@ class Processor(TriplesQueryService):
|
|||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN dest.value as dest",
|
||||
src=query.s.value, rel=query.p.value,
|
||||
"RETURN dest.value as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=get_term_value(query.s), rel=get_term_value(query.p),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((query.s.value, query.p.value, data["dest"]))
|
||||
triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"]))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"RETURN dest.uri as dest",
|
||||
src=query.s.value, rel=query.p.value,
|
||||
"RETURN dest.uri as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=get_term_value(query.s), rel=get_term_value(query.p),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((query.s.value, query.p.value, data["dest"]))
|
||||
triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"]))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -135,29 +151,31 @@ class Processor(TriplesQueryService):
|
|||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN rel.uri as rel",
|
||||
src=query.s.value, value=query.o.value,
|
||||
"RETURN rel.uri as rel "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=get_term_value(query.s), value=get_term_value(query.o),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((query.s.value, data["rel"], query.o.value))
|
||||
triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o)))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
|
||||
"RETURN rel.uri as rel",
|
||||
src=query.s.value, uri=query.o.value,
|
||||
"RETURN rel.uri as rel "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=get_term_value(query.s), uri=get_term_value(query.o),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((query.s.value, data["rel"], query.o.value))
|
||||
triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o)))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -167,29 +185,31 @@ class Processor(TriplesQueryService):
|
|||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN rel.uri as rel, dest.value as dest",
|
||||
src=query.s.value,
|
||||
"RETURN rel.uri as rel, dest.value as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=get_term_value(query.s),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((query.s.value, data["rel"], data["dest"]))
|
||||
triples.append((get_term_value(query.s), data["rel"], data["dest"]))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"RETURN rel.uri as rel, dest.uri as dest",
|
||||
src=query.s.value,
|
||||
"RETURN rel.uri as rel, dest.uri as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
src=get_term_value(query.s),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((query.s.value, data["rel"], data["dest"]))
|
||||
triples.append((get_term_value(query.s), data["rel"], data["dest"]))
|
||||
|
||||
|
||||
else:
|
||||
|
|
@ -204,29 +224,31 @@ class Processor(TriplesQueryService):
|
|||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src",
|
||||
uri=query.p.value, value=query.o.value,
|
||||
"RETURN src.uri as src "
|
||||
"LIMIT " + str(query.limit),
|
||||
uri=get_term_value(query.p), value=get_term_value(query.o),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], query.p.value, query.o.value))
|
||||
triples.append((data["src"], get_term_value(query.p), get_term_value(query.o)))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Node {uri: $dest, user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src",
|
||||
uri=query.p.value, dest=query.o.value,
|
||||
"RETURN src.uri as src "
|
||||
"LIMIT " + str(query.limit),
|
||||
uri=get_term_value(query.p), dest=get_term_value(query.o),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], query.p.value, query.o.value))
|
||||
triples.append((data["src"], get_term_value(query.p), get_term_value(query.o)))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -236,29 +258,31 @@ class Processor(TriplesQueryService):
|
|||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, dest.value as dest",
|
||||
uri=query.p.value,
|
||||
"RETURN src.uri as src, dest.value as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
uri=get_term_value(query.p),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], query.p.value, data["dest"]))
|
||||
triples.append((data["src"], get_term_value(query.p), data["dest"]))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, dest.uri as dest",
|
||||
uri=query.p.value,
|
||||
"RETURN src.uri as src, dest.uri as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
uri=get_term_value(query.p),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], query.p.value, data["dest"]))
|
||||
triples.append((data["src"], get_term_value(query.p), data["dest"]))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -270,29 +294,31 @@ class Processor(TriplesQueryService):
|
|||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel",
|
||||
value=query.o.value,
|
||||
"RETURN src.uri as src, rel.uri as rel "
|
||||
"LIMIT " + str(query.limit),
|
||||
value=get_term_value(query.o),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], data["rel"], query.o.value))
|
||||
triples.append((data["src"], data["rel"], get_term_value(query.o)))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel",
|
||||
uri=query.o.value,
|
||||
"RETURN src.uri as src, rel.uri as rel "
|
||||
"LIMIT " + str(query.limit),
|
||||
uri=get_term_value(query.o),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], data["rel"], query.o.value))
|
||||
triples.append((data["src"], data["rel"], get_term_value(query.o)))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -302,7 +328,8 @@ class Processor(TriplesQueryService):
|
|||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.value as dest",
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
|
@ -315,7 +342,8 @@ class Processor(TriplesQueryService):
|
|||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest",
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
user=user, collection=collection,
|
||||
database_=self.db,
|
||||
)
|
||||
|
|
@ -327,10 +355,10 @@ class Processor(TriplesQueryService):
|
|||
triples = [
|
||||
Triple(
|
||||
s=self.create_value(t[0]),
|
||||
p=self.create_value(t[1]),
|
||||
p=self.create_value(t[1]),
|
||||
o=self.create_value(t[2])
|
||||
)
|
||||
for t in triples
|
||||
for t in triples[:query.limit]
|
||||
]
|
||||
|
||||
return triples
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Structured Query Service - orchestrates natural language question processing.
|
||||
Takes a question, converts it to GraphQL via nlp-query, executes via objects-query,
|
||||
Takes a question, converts it to GraphQL via nlp-query, executes via rows-query,
|
||||
and returns the results.
|
||||
"""
|
||||
|
||||
|
|
@ -10,7 +10,7 @@ from typing import Dict, Any, Optional
|
|||
|
||||
from ...schema import StructuredQueryRequest, StructuredQueryResponse
|
||||
from ...schema import QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse
|
||||
from ...schema import ObjectsQueryRequest, ObjectsQueryResponse
|
||||
from ...schema import RowsQueryRequest, RowsQueryResponse
|
||||
from ...schema import Error
|
||||
|
||||
from ...base import FlowProcessor, ConsumerSpec, ProducerSpec, RequestResponseSpec
|
||||
|
|
@ -57,13 +57,13 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
)
|
||||
|
||||
# Client spec for calling objects query service
|
||||
# Client spec for calling rows query service
|
||||
self.register_specification(
|
||||
RequestResponseSpec(
|
||||
request_name = "objects-query-request",
|
||||
response_name = "objects-query-response",
|
||||
request_schema = ObjectsQueryRequest,
|
||||
response_schema = ObjectsQueryResponse
|
||||
request_name = "rows-query-request",
|
||||
response_name = "rows-query-response",
|
||||
request_schema = RowsQueryRequest,
|
||||
response_schema = RowsQueryResponse
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -112,7 +112,7 @@ class Processor(FlowProcessor):
|
|||
variables_as_strings[key] = str(value)
|
||||
|
||||
# Use user/collection values from request
|
||||
objects_request = ObjectsQueryRequest(
|
||||
objects_request = RowsQueryRequest(
|
||||
user=request.user,
|
||||
collection=request.collection,
|
||||
query=nlp_response.graphql_query,
|
||||
|
|
@ -120,12 +120,12 @@ class Processor(FlowProcessor):
|
|||
operation_name=None
|
||||
)
|
||||
|
||||
objects_response = await flow("objects-query-request").request(objects_request)
|
||||
|
||||
objects_response = await flow("rows-query-request").request(objects_request)
|
||||
|
||||
if objects_response.error is not None:
|
||||
raise Exception(f"Objects query service error: {objects_response.error.message}")
|
||||
|
||||
# Handle GraphQL errors from the objects query service
|
||||
raise Exception(f"Rows query service error: {objects_response.error.message}")
|
||||
|
||||
# Handle GraphQL errors from the rows query service
|
||||
graphql_errors = []
|
||||
if objects_response.errors:
|
||||
for gql_error in objects_response.errors:
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from .... base import ConsumerMetrics, ProducerMetrics
|
|||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "de-write"
|
||||
default_ident = "doc-embeddings-write"
|
||||
default_store_uri = 'http://localhost:19530'
|
||||
|
||||
class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from .... base import ConsumerMetrics, ProducerMetrics
|
|||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "de-write"
|
||||
default_ident = "doc-embeddings-write"
|
||||
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
|
||||
default_cloud = "aws"
|
||||
default_region = "us-east-1"
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from .... base import ConsumerMetrics, ProducerMetrics
|
|||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "de-write"
|
||||
default_ident = "doc-embeddings-write"
|
||||
|
||||
default_store_uri = 'http://localhost:6333'
|
||||
|
||||
|
|
|
|||
|
|
@ -9,11 +9,25 @@ from .... direct.milvus_graph_embeddings import EntityVectors
|
|||
from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler
|
||||
from .... base import AsyncProcessor, Consumer, Producer
|
||||
from .... base import ConsumerMetrics, ProducerMetrics
|
||||
from .... schema import IRI, LITERAL
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "ge-write"
|
||||
|
||||
def get_term_value(term):
|
||||
"""Extract the string value from a Term"""
|
||||
if term is None:
|
||||
return None
|
||||
if term.type == IRI:
|
||||
return term.iri
|
||||
elif term.type == LITERAL:
|
||||
return term.value
|
||||
else:
|
||||
# For blank nodes or other types, use id or value
|
||||
return term.id or term.value
|
||||
|
||||
default_ident = "graph-embeddings-write"
|
||||
default_store_uri = 'http://localhost:19530'
|
||||
|
||||
class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
||||
|
|
@ -36,11 +50,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
|||
async def store_graph_embeddings(self, message):
|
||||
|
||||
for entity in message.entities:
|
||||
entity_value = get_term_value(entity.entity)
|
||||
|
||||
if entity.entity.value != "" and entity.entity.value is not None:
|
||||
if entity_value != "" and entity_value is not None:
|
||||
for vec in entity.vectors:
|
||||
self.vecstore.insert(
|
||||
vec, entity.entity.value,
|
||||
vec, entity_value,
|
||||
message.metadata.user,
|
||||
message.metadata.collection
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,11 +14,25 @@ import logging
|
|||
from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler
|
||||
from .... base import AsyncProcessor, Consumer, Producer
|
||||
from .... base import ConsumerMetrics, ProducerMetrics
|
||||
from .... schema import IRI, LITERAL
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "ge-write"
|
||||
|
||||
def get_term_value(term):
|
||||
"""Extract the string value from a Term"""
|
||||
if term is None:
|
||||
return None
|
||||
if term.type == IRI:
|
||||
return term.iri
|
||||
elif term.type == LITERAL:
|
||||
return term.value
|
||||
else:
|
||||
# For blank nodes or other types, use id or value
|
||||
return term.id or term.value
|
||||
|
||||
default_ident = "graph-embeddings-write"
|
||||
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
|
||||
default_cloud = "aws"
|
||||
default_region = "us-east-1"
|
||||
|
|
@ -100,8 +114,9 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
|||
return
|
||||
|
||||
for entity in message.entities:
|
||||
entity_value = get_term_value(entity.entity)
|
||||
|
||||
if entity.entity.value == "" or entity.entity.value is None:
|
||||
if entity_value == "" or entity_value is None:
|
||||
continue
|
||||
|
||||
for vec in entity.vectors:
|
||||
|
|
@ -126,7 +141,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
|||
{
|
||||
"id": vector_id,
|
||||
"values": vec,
|
||||
"metadata": { "entity": entity.entity.value },
|
||||
"metadata": { "entity": entity_value },
|
||||
}
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -12,11 +12,26 @@ import logging
|
|||
from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler
|
||||
from .... base import AsyncProcessor, Consumer, Producer
|
||||
from .... base import ConsumerMetrics, ProducerMetrics
|
||||
from .... schema import IRI, LITERAL
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "ge-write"
|
||||
|
||||
def get_term_value(term):
|
||||
"""Extract the string value from a Term"""
|
||||
if term is None:
|
||||
return None
|
||||
if term.type == IRI:
|
||||
return term.iri
|
||||
elif term.type == LITERAL:
|
||||
return term.value
|
||||
else:
|
||||
# For blank nodes or other types, use id or value
|
||||
return term.id or term.value
|
||||
|
||||
|
||||
default_ident = "graph-embeddings-write"
|
||||
|
||||
default_store_uri = 'http://localhost:6333'
|
||||
|
||||
|
|
@ -51,8 +66,10 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
|||
return
|
||||
|
||||
for entity in message.entities:
|
||||
entity_value = get_term_value(entity.entity)
|
||||
|
||||
if entity.entity.value == "" or entity.entity.value is None: return
|
||||
if entity_value == "" or entity_value is None:
|
||||
continue
|
||||
|
||||
for vec in entity.vectors:
|
||||
|
||||
|
|
@ -80,7 +97,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
|||
id=str(uuid.uuid4()),
|
||||
vector=vec,
|
||||
payload={
|
||||
"entity": entity.entity.value,
|
||||
"entity": entity_value,
|
||||
}
|
||||
)
|
||||
]
|
||||
|
|
|
|||
|
|
@ -64,12 +64,14 @@ class Processor(FlowProcessor):
|
|||
async def on_triples(self, msg, consumer, flow):
|
||||
|
||||
v = msg.value()
|
||||
await self.table_store.add_triples(v)
|
||||
if v.triples:
|
||||
await self.table_store.add_triples(v)
|
||||
|
||||
async def on_graph_embeddings(self, msg, consumer, flow):
|
||||
|
||||
v = msg.value()
|
||||
await self.table_store.add_graph_embeddings(v)
|
||||
if v.entities:
|
||||
await self.table_store.add_graph_embeddings(v)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
# Objects storage module
|
||||
|
|
@ -1 +0,0 @@
|
|||
from . write import *
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
from . write import run
|
||||
|
||||
run()
|
||||
|
|
@ -1,538 +0,0 @@
|
|||
"""
|
||||
Object writer for Cassandra. Input is ExtractedObject.
|
||||
Writes structured objects to Cassandra tables based on schema definitions.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Set, Optional, Any
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.auth import PlainTextAuthProvider
|
||||
from cassandra.cqlengine import connection
|
||||
from cassandra import ConsistencyLevel
|
||||
|
||||
from .... schema import ExtractedObject
|
||||
from .... schema import RowSchema, Field
|
||||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from .... base import CollectionConfigHandler
|
||||
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "objects-write"
|
||||
|
||||
class Processor(CollectionConfigHandler, FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id", default_ident)
|
||||
|
||||
# Get Cassandra parameters
|
||||
cassandra_host = params.get("cassandra_host")
|
||||
cassandra_username = params.get("cassandra_username")
|
||||
cassandra_password = params.get("cassandra_password")
|
||||
|
||||
# Resolve configuration with environment variable fallback
|
||||
hosts, username, password, keyspace = resolve_cassandra_config(
|
||||
host=cassandra_host,
|
||||
username=cassandra_username,
|
||||
password=cassandra_password
|
||||
)
|
||||
|
||||
# Store resolved configuration with proper names
|
||||
self.cassandra_host = hosts # Store as list
|
||||
self.cassandra_username = username
|
||||
self.cassandra_password = password
|
||||
|
||||
# Config key for schemas
|
||||
self.config_key = params.get("config_type", "schema")
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"config_type": self.config_key,
|
||||
}
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name = "input",
|
||||
schema = ExtractedObject,
|
||||
handler = self.on_object
|
||||
)
|
||||
)
|
||||
|
||||
# Register config handlers
|
||||
self.register_config_handler(self.on_schema_config)
|
||||
self.register_config_handler(self.on_collection_config)
|
||||
|
||||
# Cache of known keyspaces/tables
|
||||
self.known_keyspaces: Set[str] = set()
|
||||
self.known_tables: Dict[str, Set[str]] = {} # keyspace -> set of tables
|
||||
|
||||
# Schema storage: name -> RowSchema
|
||||
self.schemas: Dict[str, RowSchema] = {}
|
||||
|
||||
# Cassandra session
|
||||
self.cluster = None
|
||||
self.session = None
|
||||
|
||||
def connect_cassandra(self):
|
||||
"""Connect to Cassandra cluster"""
|
||||
if self.session:
|
||||
return
|
||||
|
||||
try:
|
||||
if self.cassandra_username and self.cassandra_password:
|
||||
auth_provider = PlainTextAuthProvider(
|
||||
username=self.cassandra_username,
|
||||
password=self.cassandra_password
|
||||
)
|
||||
self.cluster = Cluster(
|
||||
contact_points=self.cassandra_host,
|
||||
auth_provider=auth_provider
|
||||
)
|
||||
else:
|
||||
self.cluster = Cluster(contact_points=self.cassandra_host)
|
||||
|
||||
self.session = self.cluster.connect()
|
||||
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def on_schema_config(self, config, version):
|
||||
"""Handle schema configuration updates"""
|
||||
logger.info(f"Loading schema configuration version {version}")
|
||||
|
||||
# Clear existing schemas
|
||||
self.schemas = {}
|
||||
|
||||
# Check if our config type exists
|
||||
if self.config_key not in config:
|
||||
logger.warning(f"No '{self.config_key}' type in configuration")
|
||||
return
|
||||
|
||||
# Get the schemas dictionary for our type
|
||||
schemas_config = config[self.config_key]
|
||||
|
||||
# Process each schema in the schemas config
|
||||
for schema_name, schema_json in schemas_config.items():
|
||||
try:
|
||||
# Parse the JSON schema definition
|
||||
schema_def = json.loads(schema_json)
|
||||
|
||||
# Create Field objects
|
||||
fields = []
|
||||
for field_def in schema_def.get("fields", []):
|
||||
field = Field(
|
||||
name=field_def["name"],
|
||||
type=field_def["type"],
|
||||
size=field_def.get("size", 0),
|
||||
primary=field_def.get("primary_key", False),
|
||||
description=field_def.get("description", ""),
|
||||
required=field_def.get("required", False),
|
||||
enum_values=field_def.get("enum", []),
|
||||
indexed=field_def.get("indexed", False)
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
# Create RowSchema
|
||||
row_schema = RowSchema(
|
||||
name=schema_def.get("name", schema_name),
|
||||
description=schema_def.get("description", ""),
|
||||
fields=fields
|
||||
)
|
||||
|
||||
self.schemas[schema_name] = row_schema
|
||||
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
|
||||
|
||||
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
|
||||
|
||||
def ensure_keyspace(self, keyspace: str):
|
||||
"""Ensure keyspace exists in Cassandra"""
|
||||
if keyspace in self.known_keyspaces:
|
||||
return
|
||||
|
||||
# Connect if needed
|
||||
self.connect_cassandra()
|
||||
|
||||
# Sanitize keyspace name
|
||||
safe_keyspace = self.sanitize_name(keyspace)
|
||||
|
||||
# Create keyspace if not exists
|
||||
create_keyspace_cql = f"""
|
||||
CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
|
||||
WITH REPLICATION = {{
|
||||
'class': 'SimpleStrategy',
|
||||
'replication_factor': 1
|
||||
}}
|
||||
"""
|
||||
|
||||
try:
|
||||
self.session.execute(create_keyspace_cql)
|
||||
self.known_keyspaces.add(keyspace)
|
||||
self.known_tables[keyspace] = set()
|
||||
logger.info(f"Ensured keyspace exists: {safe_keyspace}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create keyspace {safe_keyspace}: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def get_cassandra_type(self, field_type: str, size: int = 0) -> str:
|
||||
"""Convert schema field type to Cassandra type"""
|
||||
# Handle None size
|
||||
if size is None:
|
||||
size = 0
|
||||
|
||||
type_mapping = {
|
||||
"string": "text",
|
||||
"integer": "bigint" if size > 4 else "int",
|
||||
"float": "double" if size > 4 else "float",
|
||||
"boolean": "boolean",
|
||||
"timestamp": "timestamp",
|
||||
"date": "date",
|
||||
"time": "time",
|
||||
"uuid": "uuid"
|
||||
}
|
||||
|
||||
return type_mapping.get(field_type, "text")
|
||||
|
||||
def sanitize_name(self, name: str) -> str:
|
||||
"""Sanitize names for Cassandra compatibility"""
|
||||
# Replace non-alphanumeric characters with underscore
|
||||
import re
|
||||
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
|
||||
# Ensure it starts with a letter
|
||||
if safe_name and not safe_name[0].isalpha():
|
||||
safe_name = 'o_' + safe_name
|
||||
return safe_name.lower()
|
||||
|
||||
def sanitize_table(self, name: str) -> str:
|
||||
"""Sanitize names for Cassandra compatibility"""
|
||||
# Replace non-alphanumeric characters with underscore
|
||||
import re
|
||||
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
|
||||
# Ensure it starts with a letter
|
||||
safe_name = 'o_' + safe_name
|
||||
return safe_name.lower()
|
||||
|
||||
def ensure_table(self, keyspace: str, table_name: str, schema: RowSchema):
|
||||
"""Ensure table exists with proper structure"""
|
||||
table_key = f"{keyspace}.{table_name}"
|
||||
if table_key in self.known_tables.get(keyspace, set()):
|
||||
return
|
||||
|
||||
# Ensure keyspace exists first
|
||||
self.ensure_keyspace(keyspace)
|
||||
|
||||
safe_keyspace = self.sanitize_name(keyspace)
|
||||
safe_table = self.sanitize_table(table_name)
|
||||
|
||||
# Build column definitions
|
||||
columns = ["collection text"] # Collection is always part of table
|
||||
primary_key_fields = []
|
||||
clustering_fields = []
|
||||
|
||||
for field in schema.fields:
|
||||
safe_field_name = self.sanitize_name(field.name)
|
||||
cassandra_type = self.get_cassandra_type(field.type, field.size)
|
||||
columns.append(f"{safe_field_name} {cassandra_type}")
|
||||
|
||||
if field.primary:
|
||||
primary_key_fields.append(safe_field_name)
|
||||
|
||||
# Build primary key - collection is always first in partition key
|
||||
if primary_key_fields:
|
||||
primary_key = f"PRIMARY KEY ((collection, {', '.join(primary_key_fields)}))"
|
||||
else:
|
||||
# If no primary key defined, use collection and a synthetic id
|
||||
columns.append("synthetic_id uuid")
|
||||
primary_key = "PRIMARY KEY ((collection, synthetic_id))"
|
||||
|
||||
# Create table
|
||||
create_table_cql = f"""
|
||||
CREATE TABLE IF NOT EXISTS {safe_keyspace}.{safe_table} (
|
||||
{', '.join(columns)},
|
||||
{primary_key}
|
||||
)
|
||||
"""
|
||||
|
||||
try:
|
||||
self.session.execute(create_table_cql)
|
||||
if keyspace not in self.known_tables:
|
||||
self.known_tables[keyspace] = set()
|
||||
self.known_tables[keyspace].add(table_key)
|
||||
logger.info(f"Ensured table exists: {safe_keyspace}.{safe_table}")
|
||||
|
||||
# Create secondary indexes for indexed fields
|
||||
for field in schema.fields:
|
||||
if field.indexed and not field.primary:
|
||||
safe_field_name = self.sanitize_name(field.name)
|
||||
index_name = f"{safe_table}_{safe_field_name}_idx"
|
||||
create_index_cql = f"""
|
||||
CREATE INDEX IF NOT EXISTS {index_name}
|
||||
ON {safe_keyspace}.{safe_table} ({safe_field_name})
|
||||
"""
|
||||
try:
|
||||
self.session.execute(create_index_cql)
|
||||
logger.info(f"Created index: {index_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create index {index_name}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create table {safe_keyspace}.{safe_table}: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def convert_value(self, value: Any, field_type: str) -> Any:
|
||||
"""Convert value to appropriate type for Cassandra"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
if field_type == "integer":
|
||||
return int(value)
|
||||
elif field_type == "float":
|
||||
return float(value)
|
||||
elif field_type == "boolean":
|
||||
if isinstance(value, str):
|
||||
return value.lower() in ('true', '1', 'yes')
|
||||
return bool(value)
|
||||
elif field_type == "timestamp":
|
||||
# Handle timestamp conversion if needed
|
||||
return value
|
||||
else:
|
||||
return str(value)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to convert value {value} to type {field_type}: {e}")
|
||||
return str(value)
|
||||
async def on_object(self, msg, consumer, flow):
|
||||
"""Process incoming ExtractedObject and store in Cassandra"""
|
||||
|
||||
obj = msg.value()
|
||||
logger.info(f"Storing {len(obj.values)} objects for schema {obj.schema_name} from {obj.metadata.id}")
|
||||
|
||||
# Validate collection exists before accepting writes
|
||||
if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
|
||||
error_msg = (
|
||||
f"Collection {obj.metadata.collection} does not exist. "
|
||||
f"Create it first via collection management API."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Get schema definition
|
||||
schema = self.schemas.get(obj.schema_name)
|
||||
if not schema:
|
||||
logger.warning(f"No schema found for {obj.schema_name} - skipping")
|
||||
return
|
||||
|
||||
# Ensure table exists
|
||||
keyspace = obj.metadata.user
|
||||
table_name = obj.schema_name
|
||||
self.ensure_table(keyspace, table_name, schema)
|
||||
|
||||
# Prepare data for insertion
|
||||
safe_keyspace = self.sanitize_name(keyspace)
|
||||
safe_table = self.sanitize_table(table_name)
|
||||
|
||||
# Process each object in the batch
|
||||
for obj_index, value_map in enumerate(obj.values):
|
||||
# Build column names and values for this object
|
||||
columns = ["collection"]
|
||||
values = [obj.metadata.collection]
|
||||
placeholders = ["%s"]
|
||||
|
||||
# Check if we need a synthetic ID
|
||||
has_primary_key = any(field.primary for field in schema.fields)
|
||||
if not has_primary_key:
|
||||
import uuid
|
||||
columns.append("synthetic_id")
|
||||
values.append(uuid.uuid4())
|
||||
placeholders.append("%s")
|
||||
|
||||
# Process fields for this object
|
||||
skip_object = False
|
||||
for field in schema.fields:
|
||||
safe_field_name = self.sanitize_name(field.name)
|
||||
raw_value = value_map.get(field.name)
|
||||
|
||||
# Handle required fields
|
||||
if field.required and raw_value is None:
|
||||
logger.warning(f"Required field {field.name} is missing in object {obj_index}")
|
||||
# Continue anyway - Cassandra doesn't enforce NOT NULL
|
||||
|
||||
# Check if primary key field is NULL
|
||||
if field.primary and raw_value is None:
|
||||
logger.error(f"Primary key field {field.name} cannot be NULL - skipping object {obj_index}")
|
||||
skip_object = True
|
||||
break
|
||||
|
||||
# Convert value to appropriate type
|
||||
converted_value = self.convert_value(raw_value, field.type)
|
||||
|
||||
columns.append(safe_field_name)
|
||||
values.append(converted_value)
|
||||
placeholders.append("%s")
|
||||
|
||||
# Skip this object if primary key validation failed
|
||||
if skip_object:
|
||||
continue
|
||||
|
||||
# Build and execute insert query for this object
|
||||
insert_cql = f"""
|
||||
INSERT INTO {safe_keyspace}.{safe_table} ({', '.join(columns)})
|
||||
VALUES ({', '.join(placeholders)})
|
||||
"""
|
||||
|
||||
# Debug: Show data being inserted
|
||||
logger.debug(f"Storing {obj.schema_name} object {obj_index}: {dict(zip(columns, values))}")
|
||||
|
||||
if len(columns) != len(values) or len(columns) != len(placeholders):
|
||||
raise ValueError(f"Mismatch in counts - columns: {len(columns)}, values: {len(values)}, placeholders: {len(placeholders)}")
|
||||
|
||||
try:
|
||||
# Convert to tuple - Cassandra driver requires tuple for parameters
|
||||
self.session.execute(insert_cql, tuple(values))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to insert object {obj_index}: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def create_collection(self, user: str, collection: str, metadata: dict):
|
||||
"""Create/verify collection exists in Cassandra object store"""
|
||||
# Connect if not already connected
|
||||
self.connect_cassandra()
|
||||
|
||||
# Sanitize names for safety
|
||||
safe_keyspace = self.sanitize_name(user)
|
||||
|
||||
# Ensure keyspace exists
|
||||
if safe_keyspace not in self.known_keyspaces:
|
||||
self.ensure_keyspace(safe_keyspace)
|
||||
self.known_keyspaces.add(safe_keyspace)
|
||||
|
||||
# For Cassandra objects, collection is just a property in rows
|
||||
# No need to create separate tables per collection
|
||||
# Just mark that we've seen this collection
|
||||
logger.info(f"Collection {collection} ready for user {user} (using keyspace {safe_keyspace})")
|
||||
|
||||
async def delete_collection(self, user: str, collection: str):
|
||||
"""Delete all data for a specific collection using schema information"""
|
||||
# Connect if not already connected
|
||||
self.connect_cassandra()
|
||||
|
||||
# Sanitize names for safety
|
||||
safe_keyspace = self.sanitize_name(user)
|
||||
|
||||
# Check if keyspace exists
|
||||
if safe_keyspace not in self.known_keyspaces:
|
||||
# Query to verify keyspace exists
|
||||
check_keyspace_cql = """
|
||||
SELECT keyspace_name FROM system_schema.keyspaces
|
||||
WHERE keyspace_name = %s
|
||||
"""
|
||||
result = self.session.execute(check_keyspace_cql, (safe_keyspace,))
|
||||
if not result.one():
|
||||
logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
|
||||
return
|
||||
self.known_keyspaces.add(safe_keyspace)
|
||||
|
||||
# Iterate over schemas we manage to delete from relevant tables
|
||||
tables_deleted = 0
|
||||
|
||||
for schema_name, schema in self.schemas.items():
|
||||
safe_table = self.sanitize_table(schema_name)
|
||||
|
||||
# Check if table exists
|
||||
table_key = f"{user}.{schema_name}"
|
||||
if table_key not in self.known_tables.get(user, set()):
|
||||
logger.debug(f"Table {safe_keyspace}.{safe_table} not in known tables, skipping")
|
||||
continue
|
||||
|
||||
try:
|
||||
# Get primary key fields from schema
|
||||
primary_key_fields = [field for field in schema.fields if field.primary]
|
||||
|
||||
if primary_key_fields:
|
||||
# Schema has primary keys: need to query for partition keys first
|
||||
# Build SELECT query for primary key fields
|
||||
pk_field_names = [self.sanitize_name(field.name) for field in primary_key_fields]
|
||||
select_cql = f"""
|
||||
SELECT {', '.join(pk_field_names)}
|
||||
FROM {safe_keyspace}.{safe_table}
|
||||
WHERE collection = %s
|
||||
ALLOW FILTERING
|
||||
"""
|
||||
|
||||
rows = self.session.execute(select_cql, (collection,))
|
||||
|
||||
# Delete each row using full partition key
|
||||
for row in rows:
|
||||
where_clauses = ["collection = %s"]
|
||||
values = [collection]
|
||||
|
||||
for field_name in pk_field_names:
|
||||
where_clauses.append(f"{field_name} = %s")
|
||||
values.append(getattr(row, field_name))
|
||||
|
||||
delete_cql = f"""
|
||||
DELETE FROM {safe_keyspace}.{safe_table}
|
||||
WHERE {' AND '.join(where_clauses)}
|
||||
"""
|
||||
|
||||
self.session.execute(delete_cql, tuple(values))
|
||||
else:
|
||||
# No primary keys, uses synthetic_id
|
||||
# Need to query for synthetic_ids first
|
||||
select_cql = f"""
|
||||
SELECT synthetic_id
|
||||
FROM {safe_keyspace}.{safe_table}
|
||||
WHERE collection = %s
|
||||
ALLOW FILTERING
|
||||
"""
|
||||
|
||||
rows = self.session.execute(select_cql, (collection,))
|
||||
|
||||
# Delete each row using collection and synthetic_id
|
||||
for row in rows:
|
||||
delete_cql = f"""
|
||||
DELETE FROM {safe_keyspace}.{safe_table}
|
||||
WHERE collection = %s AND synthetic_id = %s
|
||||
"""
|
||||
self.session.execute(delete_cql, (collection, row.synthetic_id))
|
||||
|
||||
tables_deleted += 1
|
||||
logger.info(f"Deleted collection {collection} from table {safe_keyspace}.{safe_table}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete from table {safe_keyspace}.{safe_table}: {e}")
|
||||
raise
|
||||
|
||||
logger.info(f"Deleted collection {collection} from {tables_deleted} schema-based tables in keyspace {safe_keyspace}")
|
||||
|
||||
def close(self):
|
||||
"""Clean up Cassandra connections"""
|
||||
if self.cluster:
|
||||
self.cluster.shutdown()
|
||||
logger.info("Closed Cassandra connection")
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add command-line arguments"""
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
add_cassandra_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'--config-type',
|
||||
default='schema',
|
||||
help='Configuration type prefix for schemas (default: schema)'
|
||||
)
|
||||
|
||||
def run():
|
||||
"""Entry point for objects-write-cassandra command"""
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
"""
|
||||
Row embeddings storage modules.
|
||||
"""
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
"""
|
||||
Qdrant storage for row embeddings.
|
||||
"""
|
||||
|
||||
from .write import Processor, run, default_ident
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
|
||||
from .write import run
|
||||
|
||||
run()
|
||||
|
|
@ -0,0 +1,264 @@
|
|||
"""
|
||||
Row embeddings writer for Qdrant (Stage 2).
|
||||
|
||||
Consumes RowEmbeddings messages (which already contain computed vectors)
|
||||
and writes them to Qdrant. One Qdrant collection per (user, collection, schema_name, dimension) combination.
|
||||
|
||||
This follows the two-stage pattern used by graph-embeddings and document-embeddings:
|
||||
Stage 1 (row-embeddings): Compute embeddings
|
||||
Stage 2 (this processor): Store embeddings
|
||||
|
||||
Collection naming: rows_{user}_{collection}_{schema_name}_{dimension}
|
||||
|
||||
Payload structure:
|
||||
- index_name: The indexed field(s) this embedding represents
|
||||
- index_value: The original list of values (for Cassandra lookup)
|
||||
- text: The text that was embedded (for debugging/display)
|
||||
"""
|
||||
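A rough sketch of what this means in practice (user, collection and schema names below are invented for illustration):

# Illustrative sketch - hypothetical names, not part of this commit.
# For user="alice", collection="crm", schema_name="customer" and a
# 384-dimensional vector, the Qdrant collection name would be:
#
#     rows_alice_crm_customer_384
#
# and each stored point would carry a payload shaped like:
payload = {
    "index_name": "email",                            # indexed field this vector represents
    "index_value": ["alice@example.com"],             # original values, used for Cassandra lookup
    "text": "name: Alice email: alice@example.com",   # text that was embedded
}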
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Set, Tuple
|
||||
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import PointStruct, Distance, VectorParams
|
||||
|
||||
from .... schema import RowEmbeddings
|
||||
from .... base import FlowProcessor, ConsumerSpec
|
||||
from .... base import CollectionConfigHandler
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "row-embeddings-write"
|
||||
default_store_uri = 'http://localhost:6333'
|
||||
|
||||
|
||||
class Processor(CollectionConfigHandler, FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id", default_ident)
|
||||
|
||||
store_uri = params.get("store_uri", default_store_uri)
|
||||
api_key = params.get("api_key", None)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"store_uri": store_uri,
|
||||
"api_key": api_key,
|
||||
}
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name="input",
|
||||
schema=RowEmbeddings,
|
||||
handler=self.on_embeddings
|
||||
)
|
||||
)
|
||||
|
||||
# Register config handler for collection management
|
||||
self.register_config_handler(self.on_collection_config)
|
||||
|
||||
# Cache of created Qdrant collections
|
||||
self.created_collections: Set[str] = set()
|
||||
|
||||
# Qdrant client
|
||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
||||
|
||||
def sanitize_name(self, name: str) -> str:
|
||||
"""Sanitize names for Qdrant collection naming"""
|
||||
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
|
||||
if safe_name and not safe_name[0].isalpha():
|
||||
safe_name = 'r_' + safe_name
|
||||
return safe_name.lower()
|
||||
|
||||
def get_collection_name(
|
||||
self, user: str, collection: str, schema_name: str, dimension: int
|
||||
) -> str:
|
||||
"""Generate Qdrant collection name"""
|
||||
safe_user = self.sanitize_name(user)
|
||||
safe_collection = self.sanitize_name(collection)
|
||||
safe_schema = self.sanitize_name(schema_name)
|
||||
return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}"
|
||||
|
||||
def ensure_collection(self, collection_name: str, dimension: int):
|
||||
"""Create Qdrant collection if it doesn't exist"""
|
||||
if collection_name in self.created_collections:
|
||||
return
|
||||
|
||||
if not self.qdrant.collection_exists(collection_name):
|
||||
logger.info(
|
||||
f"Creating Qdrant collection {collection_name} "
|
||||
f"with dimension {dimension}"
|
||||
)
|
||||
self.qdrant.create_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=VectorParams(
|
||||
size=dimension,
|
||||
distance=Distance.COSINE
|
||||
)
|
||||
)
|
||||
|
||||
self.created_collections.add(collection_name)
|
||||
|
||||
async def on_embeddings(self, msg, consumer, flow):
|
||||
"""Process incoming RowEmbeddings and write to Qdrant"""
|
||||
|
||||
embeddings = msg.value()
|
||||
logger.info(
|
||||
f"Writing {len(embeddings.embeddings)} embeddings for schema "
|
||||
f"{embeddings.schema_name} from {embeddings.metadata.id}"
|
||||
)
|
||||
|
||||
# Validate collection exists in config before processing
|
||||
if not self.collection_exists(
|
||||
embeddings.metadata.user, embeddings.metadata.collection
|
||||
):
|
||||
logger.warning(
|
||||
f"Collection {embeddings.metadata.collection} for user "
|
||||
f"{embeddings.metadata.user} does not exist in config. "
|
||||
f"Dropping message."
|
||||
)
|
||||
return
|
||||
|
||||
user = embeddings.metadata.user
|
||||
collection = embeddings.metadata.collection
|
||||
schema_name = embeddings.schema_name
|
||||
|
||||
embeddings_written = 0
|
||||
qdrant_collection = None
|
||||
|
||||
for row_emb in embeddings.embeddings:
|
||||
if not row_emb.vectors:
|
||||
logger.warning(
|
||||
f"No vectors for index {row_emb.index_name} - skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
# Write every vector (there may be multiple from different models)
|
||||
for vector in row_emb.vectors:
|
||||
dimension = len(vector)
|
||||
|
||||
# Create/get collection name (lazily on first vector)
|
||||
if qdrant_collection is None:
|
||||
qdrant_collection = self.get_collection_name(
|
||||
user, collection, schema_name, dimension
|
||||
)
|
||||
self.ensure_collection(qdrant_collection, dimension)
|
||||
|
||||
# Write to Qdrant
|
||||
self.qdrant.upsert(
|
||||
collection_name=qdrant_collection,
|
||||
points=[
|
||||
PointStruct(
|
||||
id=str(uuid.uuid4()),
|
||||
vector=vector,
|
||||
payload={
|
||||
"index_name": row_emb.index_name,
|
||||
"index_value": row_emb.index_value,
|
||||
"text": row_emb.text
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
embeddings_written += 1
|
||||
|
||||
logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")
|
||||
|
||||
async def create_collection(self, user: str, collection: str, metadata: dict):
|
||||
"""Collection creation via config push - collections created lazily on first write"""
|
||||
logger.info(
|
||||
f"Row embeddings collection create request for {user}/{collection} - "
|
||||
f"will be created lazily on first write"
|
||||
)
|
||||
|
||||
async def delete_collection(self, user: str, collection: str):
|
||||
"""Delete all Qdrant collections for a given user/collection"""
|
||||
try:
|
||||
prefix = f"rows_{self.sanitize_name(user)}_{self.sanitize_name(collection)}_"
|
||||
|
||||
# Get all collections and filter for matches
|
||||
all_collections = self.qdrant.get_collections().collections
|
||||
matching_collections = [
|
||||
coll.name for coll in all_collections
|
||||
if coll.name.startswith(prefix)
|
||||
]
|
||||
|
||||
if not matching_collections:
|
||||
logger.info(f"No Qdrant collections found matching prefix {prefix}")
|
||||
else:
|
||||
for collection_name in matching_collections:
|
||||
self.qdrant.delete_collection(collection_name)
|
||||
self.created_collections.discard(collection_name)
|
||||
logger.info(f"Deleted Qdrant collection: {collection_name}")
|
||||
logger.info(
|
||||
f"Deleted {len(matching_collections)} collection(s) "
|
||||
f"for {user}/{collection}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete collection {user}/{collection}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def delete_collection_schema(
|
||||
self, user: str, collection: str, schema_name: str
|
||||
):
|
||||
"""Delete Qdrant collection for a specific user/collection/schema"""
|
||||
try:
|
||||
prefix = (
|
||||
f"rows_{self.sanitize_name(user)}_"
|
||||
f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
|
||||
)
|
||||
|
||||
# Get all collections and filter for matches
|
||||
all_collections = self.qdrant.get_collections().collections
|
||||
matching_collections = [
|
||||
coll.name for coll in all_collections
|
||||
if coll.name.startswith(prefix)
|
||||
]
|
||||
|
||||
if not matching_collections:
|
||||
logger.info(f"No Qdrant collections found matching prefix {prefix}")
|
||||
else:
|
||||
for collection_name in matching_collections:
|
||||
self.qdrant.delete_collection(collection_name)
|
||||
self.created_collections.discard(collection_name)
|
||||
logger.info(f"Deleted Qdrant collection: {collection_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete collection {user}/{collection}/{schema_name}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add command-line arguments"""
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--store-uri',
|
||||
default=default_store_uri,
|
||||
help=f'Qdrant URI (default: {default_store_uri})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-k', '--api-key',
|
||||
default=None,
|
||||
help='Qdrant API key (default: None)'
|
||||
)
|
||||
|
||||
|
||||
def run():
|
||||
"""Entry point for row-embeddings-write-qdrant command"""
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
@ -1,46 +1,49 @@
|
|||
|
||||
"""
|
||||
Graph writer. Input is graph edge. Writes edges to Cassandra graph.
|
||||
Row writer for Cassandra. Input is ExtractedObject.
|
||||
Writes structured rows to a unified Cassandra table with multi-index support.
|
||||
|
||||
Uses a single 'rows' table with the schema:
|
||||
- collection: text
|
||||
- schema_name: text
|
||||
- index_name: text
|
||||
- index_value: frozen<list<text>>
|
||||
- data: map<text, text>
|
||||
- source: text
|
||||
|
||||
Each row is written multiple times - once per indexed field defined in the schema.
|
||||
"""
|
||||
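As a sketch of the fan-out this describes, assuming a hypothetical schema with two indexed fields:

# Illustrative sketch - hypothetical 'customer' schema, not part of this commit.
# A single logical row {"id": "42", "email": "a@example.com", "name": "Alice"}
# with indexed fields "id" and "email" is written twice to the unified table,
# once per index:
#
#   (collection, schema_name, index_name, index_value,        data,  source)
#   ("default",  "customer",  "id",       ["42"],             {...}, "doc-1")
#   ("default",  "customer",  "email",    ["a@example.com"],  {...}, "doc-1")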
|
||||
raise RuntimeError("This code is no longer in use")
|
||||
|
||||
import pulsar
|
||||
import base64
|
||||
import os
|
||||
import argparse
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, Set, Optional, Any, List, Tuple
|
||||
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.auth import PlainTextAuthProvider
|
||||
from ssl import SSLContext, PROTOCOL_TLSv1_2
|
||||
|
||||
from .... schema import Rows
|
||||
from .... log_level import LogLevel
|
||||
from .... base import Consumer
|
||||
from .... schema import ExtractedObject
|
||||
from .... schema import RowSchema, Field
|
||||
from .... base import FlowProcessor, ConsumerSpec
|
||||
from .... base import CollectionConfigHandler
|
||||
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
module = "rows-write"
|
||||
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
|
||||
default_ident = "rows-write"
|
||||
|
||||
default_input_queue = "rows-store" # Default queue name
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(Consumer):
|
||||
class Processor(CollectionConfigHandler, FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
|
||||
|
||||
id = params.get("id", default_ident)
|
||||
|
||||
# Get Cassandra parameters
|
||||
cassandra_host = params.get("cassandra_host")
|
||||
cassandra_username = params.get("cassandra_username")
|
||||
cassandra_password = params.get("cassandra_password")
|
||||
|
||||
|
||||
# Resolve configuration with environment variable fallback
|
||||
hosts, username, password, keyspace = resolve_cassandra_config(
|
||||
host=cassandra_host,
|
||||
|
|
@ -48,99 +51,549 @@ class Processor(Consumer):
|
|||
password=cassandra_password
|
||||
)
|
||||
|
||||
# Store resolved configuration with proper names
|
||||
self.cassandra_host = hosts # Store as list
|
||||
self.cassandra_username = username
|
||||
self.cassandra_password = password
|
||||
|
||||
# Config key for schemas
|
||||
self.config_key = params.get("config_type", "schema")
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Rows,
|
||||
"cassandra_host": ','.join(hosts),
|
||||
"cassandra_username": username,
|
||||
"cassandra_password": password,
|
||||
"id": id,
|
||||
"config_type": self.config_key,
|
||||
}
|
||||
)
|
||||
|
||||
if username and password:
|
||||
auth_provider = PlainTextAuthProvider(username=username, password=password)
|
||||
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
|
||||
else:
|
||||
self.cluster = Cluster(hosts)
|
||||
self.session = self.cluster.connect()
|
||||
|
||||
self.tables = set()
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name="input",
|
||||
schema=ExtractedObject,
|
||||
handler=self.on_object
|
||||
)
|
||||
)
|
||||
|
||||
self.session.execute("""
|
||||
create keyspace if not exists trustgraph
|
||||
with replication = {
|
||||
'class' : 'SimpleStrategy',
|
||||
'replication_factor' : 1
|
||||
};
|
||||
""");
|
||||
# Register config handlers
|
||||
self.register_config_handler(self.on_schema_config)
|
||||
self.register_config_handler(self.on_collection_config)
|
||||
|
||||
self.session.execute("use trustgraph");
|
||||
# Cache of known keyspaces and whether tables exist
|
||||
self.known_keyspaces: Set[str] = set()
|
||||
self.tables_initialized: Set[str] = set() # keyspaces with rows/row_partitions tables
|
||||
|
||||
async def handle(self, msg):
|
||||
# Cache of registered (collection, schema_name) pairs
|
||||
self.registered_partitions: Set[Tuple[str, str]] = set()
|
||||
|
||||
# Schema storage: name -> RowSchema
|
||||
self.schemas: Dict[str, RowSchema] = {}
|
||||
|
||||
# Cassandra session
|
||||
self.cluster = None
|
||||
self.session = None
|
||||
|
||||
def connect_cassandra(self):
|
||||
"""Connect to Cassandra cluster"""
|
||||
if self.session:
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
name = v.row_schema.name
|
||||
|
||||
if name not in self.tables:
|
||||
|
||||
# FIXME: SQL injection?
|
||||
|
||||
pkey = []
|
||||
|
||||
stmt = "create table if not exists " + name + " ( "
|
||||
|
||||
for field in v.row_schema.fields:
|
||||
|
||||
stmt += field.name + " text, "
|
||||
|
||||
if field.primary:
|
||||
pkey.append(field.name)
|
||||
|
||||
stmt += "PRIMARY KEY (" + ", ".join(pkey) + "));"
|
||||
|
||||
self.session.execute(stmt)
|
||||
|
||||
self.tables.add(name);
|
||||
|
||||
for row in v.rows:
|
||||
|
||||
field_names = []
|
||||
values = []
|
||||
|
||||
for field in v.row_schema.fields:
|
||||
field_names.append(field.name)
|
||||
values.append(row[field.name])
|
||||
|
||||
# FIXME: SQL injection?
|
||||
stmt = (
|
||||
"insert into " + name + " (" + ", ".join(field_names) +
|
||||
") values (" + ",".join(["%s"] * len(values)) + ")"
|
||||
if self.cassandra_username and self.cassandra_password:
|
||||
auth_provider = PlainTextAuthProvider(
|
||||
username=self.cassandra_username,
|
||||
password=self.cassandra_password
|
||||
)
|
||||
self.cluster = Cluster(
|
||||
contact_points=self.cassandra_host,
|
||||
auth_provider=auth_provider
|
||||
)
|
||||
else:
|
||||
self.cluster = Cluster(contact_points=self.cassandra_host)
|
||||
|
||||
self.session.execute(stmt, values)
|
||||
self.session = self.cluster.connect()
|
||||
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
logger.error(f"Exception: {str(e)}", exc_info=True)
|
||||
async def on_schema_config(self, config, version):
|
||||
"""Handle schema configuration updates"""
|
||||
logger.info(f"Loading schema configuration version {version}")
|
||||
|
||||
# If there's an error make sure to do table creation etc.
|
||||
self.tables.remove(name)
|
||||
# Track which schemas changed so we can clear partition cache
|
||||
old_schema_names = set(self.schemas.keys())
|
||||
|
||||
raise e
|
||||
# Clear existing schemas
|
||||
self.schemas = {}
|
||||
|
||||
# Check if our config type exists
|
||||
if self.config_key not in config:
|
||||
logger.warning(f"No '{self.config_key}' type in configuration")
|
||||
return
|
||||
|
||||
# Get the schemas dictionary for our type
|
||||
schemas_config = config[self.config_key]
|
||||
|
||||
# Process each schema in the schemas config
|
||||
for schema_name, schema_json in schemas_config.items():
|
||||
try:
|
||||
# Parse the JSON schema definition
|
||||
schema_def = json.loads(schema_json)
|
||||
|
||||
# Create Field objects
|
||||
fields = []
|
||||
for field_def in schema_def.get("fields", []):
|
||||
field = Field(
|
||||
name=field_def["name"],
|
||||
type=field_def["type"],
|
||||
size=field_def.get("size", 0),
|
||||
primary=field_def.get("primary_key", False),
|
||||
description=field_def.get("description", ""),
|
||||
required=field_def.get("required", False),
|
||||
enum_values=field_def.get("enum", []),
|
||||
indexed=field_def.get("indexed", False)
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
# Create RowSchema
|
||||
row_schema = RowSchema(
|
||||
name=schema_def.get("name", schema_name),
|
||||
description=schema_def.get("description", ""),
|
||||
fields=fields
|
||||
)
|
||||
|
||||
self.schemas[schema_name] = row_schema
|
||||
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
|
||||
|
||||
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
|
||||
|
||||
# Clear partition cache for schemas that changed
|
||||
# This ensures next write will re-register partitions
|
||||
new_schema_names = set(self.schemas.keys())
|
||||
changed_schemas = old_schema_names.symmetric_difference(new_schema_names)
|
||||
if changed_schemas:
|
||||
self.registered_partitions = {
|
||||
(col, sch) for col, sch in self.registered_partitions
|
||||
if sch not in changed_schemas
|
||||
}
|
||||
logger.info(f"Cleared partition cache for changed schemas: {changed_schemas}")
|
||||
|
||||
def sanitize_name(self, name: str) -> str:
|
||||
"""Sanitize names for Cassandra compatibility"""
|
||||
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
|
||||
# Ensure it starts with a letter
|
||||
if safe_name and not safe_name[0].isalpha():
|
||||
safe_name = 'r_' + safe_name
|
||||
return safe_name.lower()
|
||||
|
||||
def ensure_keyspace(self, keyspace: str):
|
||||
"""Ensure keyspace exists in Cassandra"""
|
||||
if keyspace in self.known_keyspaces:
|
||||
return
|
||||
|
||||
# Connect if needed
|
||||
self.connect_cassandra()
|
||||
|
||||
# Sanitize keyspace name
|
||||
safe_keyspace = self.sanitize_name(keyspace)
|
||||
|
||||
# Create keyspace if not exists
|
||||
create_keyspace_cql = f"""
|
||||
CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
|
||||
WITH REPLICATION = {{
|
||||
'class': 'SimpleStrategy',
|
||||
'replication_factor': 1
|
||||
}}
|
||||
"""
|
||||
|
||||
try:
|
||||
self.session.execute(create_keyspace_cql)
|
||||
self.known_keyspaces.add(keyspace)
|
||||
logger.info(f"Ensured keyspace exists: {safe_keyspace}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create keyspace {safe_keyspace}: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def ensure_tables(self, keyspace: str):
|
||||
"""Ensure unified rows and row_partitions tables exist"""
|
||||
if keyspace in self.tables_initialized:
|
||||
return
|
||||
|
||||
# Ensure keyspace exists first
|
||||
self.ensure_keyspace(keyspace)
|
||||
|
||||
safe_keyspace = self.sanitize_name(keyspace)
|
||||
|
||||
# Create unified rows table
|
||||
create_rows_cql = f"""
|
||||
CREATE TABLE IF NOT EXISTS {safe_keyspace}.rows (
|
||||
collection text,
|
||||
schema_name text,
|
||||
index_name text,
|
||||
index_value frozen<list<text>>,
|
||||
data map<text, text>,
|
||||
source text,
|
||||
PRIMARY KEY ((collection, schema_name, index_name), index_value)
|
||||
)
|
||||
"""
|
||||
|
||||
# Create row_partitions tracking table
|
||||
create_partitions_cql = f"""
|
||||
CREATE TABLE IF NOT EXISTS {safe_keyspace}.row_partitions (
|
||||
collection text,
|
||||
schema_name text,
|
||||
index_name text,
|
||||
PRIMARY KEY ((collection), schema_name, index_name)
|
||||
)
|
||||
"""
|
||||
|
||||
try:
|
||||
self.session.execute(create_rows_cql)
|
||||
logger.info(f"Ensured rows table exists: {safe_keyspace}.rows")
|
||||
|
||||
self.session.execute(create_partitions_cql)
|
||||
logger.info(f"Ensured row_partitions table exists: {safe_keyspace}.row_partitions")
|
||||
|
||||
self.tables_initialized.add(keyspace)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create tables in {safe_keyspace}: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
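Because the partition key is (collection, schema_name, index_name), a lookup for one indexed value hits a single partition; a sketch of such a read (keyspace and values are assumptions, not part of this commit):

# Illustrative sketch - example lookup against the unified rows table.
select_cql = """
    SELECT data, source FROM my_keyspace.rows
    WHERE collection = %s AND schema_name = %s
      AND index_name = %s AND index_value = %s
"""
# rows = session.execute(
#     select_cql, ("default", "customer", "email", ["alice@example.com"])
# )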
def get_index_names(self, schema: RowSchema) -> List[str]:
|
||||
"""
|
||||
Get all index names for a schema.
|
||||
Returns list of index_name strings (single field names or comma-joined composites).
|
||||
"""
|
||||
index_names = []
|
||||
|
||||
for field in schema.fields:
|
||||
# Primary key fields are treated as indexes
|
||||
if field.primary:
|
||||
index_names.append(field.name)
|
||||
# Indexed fields
|
||||
elif field.indexed:
|
||||
index_names.append(field.name)
|
||||
|
||||
# TODO: Support composite indexes in the future
|
||||
# For now, each indexed field is a single-field index
|
||||
|
||||
return index_names
|
||||
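For example, with a hypothetical schema (field names invented):

# Illustrative sketch - hypothetical schema fields, not part of this commit.
#   id    (primary=True)              -> index "id"
#   email (indexed=True)              -> index "email"
#   notes (not primary, not indexed)  -> no index
#
# get_index_names(schema) == ["id", "email"]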
|
||||
def register_partitions(self, keyspace: str, collection: str, schema_name: str):
|
||||
"""
|
||||
Register partition entries for a (collection, schema_name) pair.
|
||||
Called once on first row for each pair.
|
||||
"""
|
||||
cache_key = (collection, schema_name)
|
||||
if cache_key in self.registered_partitions:
|
||||
return
|
||||
|
||||
schema = self.schemas.get(schema_name)
|
||||
if not schema:
|
||||
logger.warning(f"Cannot register partitions - schema {schema_name} not found")
|
||||
return
|
||||
|
||||
safe_keyspace = self.sanitize_name(keyspace)
|
||||
index_names = self.get_index_names(schema)
|
||||
|
||||
# Insert partition entries for each index
|
||||
insert_cql = f"""
|
||||
INSERT INTO {safe_keyspace}.row_partitions (collection, schema_name, index_name)
|
||||
VALUES (%s, %s, %s)
|
||||
"""
|
||||
|
||||
for index_name in index_names:
|
||||
try:
|
||||
self.session.execute(insert_cql, (collection, schema_name, index_name))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to register partition {collection}/{schema_name}/{index_name}: {e}")
|
||||
|
||||
self.registered_partitions.add(cache_key)
|
||||
logger.info(f"Registered partitions for {collection}/{schema_name}: {index_names}")
|
||||
|
||||
def build_index_value(self, value_map: Dict[str, str], index_name: str) -> List[str]:
|
||||
"""
|
||||
Build the index_value list for a given index.
|
||||
For single-field indexes, returns a single-element list.
|
||||
For composite indexes (comma-separated), returns multiple elements.
|
||||
"""
|
||||
field_names = [f.strip() for f in index_name.split(',')]
|
||||
values = []
|
||||
|
||||
for field_name in field_names:
|
||||
value = value_map.get(field_name)
|
||||
# Convert to string for storage
|
||||
values.append(str(value) if value is not None else "")
|
||||
|
||||
return values
|
||||
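A quick sketch of the two cases (field names invented for illustration):

# Illustrative sketch - hypothetical field names, not part of this commit.
value_map = {"city": "London", "country": "UK", "name": None}

# Single-field index:
#   build_index_value(value_map, "city")          == ["London"]
# Composite index (comma-separated field names):
#   build_index_value(value_map, "city, country") == ["London", "UK"]
# Missing or None values become empty strings:
#   build_index_value(value_map, "name")          == [""]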
|
||||
async def on_object(self, msg, consumer, flow):
|
||||
"""Process incoming ExtractedObject and store in Cassandra"""
|
||||
|
||||
obj = msg.value()
|
||||
logger.info(
|
||||
f"Storing {len(obj.values)} rows for schema {obj.schema_name} "
|
||||
f"from {obj.metadata.id}"
|
||||
)
|
||||
|
||||
# Validate collection exists before accepting writes
|
||||
if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
|
||||
error_msg = (
|
||||
f"Collection {obj.metadata.collection} does not exist. "
|
||||
f"Create it first via collection management API."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Get schema definition
|
||||
schema = self.schemas.get(obj.schema_name)
|
||||
if not schema:
|
||||
logger.warning(f"No schema found for {obj.schema_name} - skipping")
|
||||
return
|
||||
|
||||
keyspace = obj.metadata.user
|
||||
collection = obj.metadata.collection
|
||||
schema_name = obj.schema_name
|
||||
source = getattr(obj.metadata, 'source', '') or ''
|
||||
|
||||
# Ensure tables exist
|
||||
self.ensure_tables(keyspace)
|
||||
|
||||
# Register partitions if first time seeing this (collection, schema_name)
|
||||
self.register_partitions(keyspace, collection, schema_name)
|
||||
|
||||
safe_keyspace = self.sanitize_name(keyspace)
|
||||
|
||||
# Get all index names for this schema
|
||||
index_names = self.get_index_names(schema)
|
||||
|
||||
if not index_names:
|
||||
logger.warning(f"Schema {schema_name} has no indexed fields - rows won't be queryable")
|
||||
return
|
||||
|
||||
# Prepare insert statement
|
||||
insert_cql = f"""
|
||||
INSERT INTO {safe_keyspace}.rows
|
||||
(collection, schema_name, index_name, index_value, data, source)
|
||||
VALUES (%s, %s, %s, %s, %s, %s)
|
||||
"""
|
||||
|
||||
# Process each row in the batch
|
||||
rows_written = 0
|
||||
for row_index, value_map in enumerate(obj.values):
|
||||
# Convert all values to strings for the data map
|
||||
data_map = {}
|
||||
for field in schema.fields:
|
||||
raw_value = value_map.get(field.name)
|
||||
if raw_value is not None:
|
||||
data_map[field.name] = str(raw_value)
|
||||
|
||||
# Write one copy per index
|
||||
for index_name in index_names:
|
||||
index_value = self.build_index_value(value_map, index_name)
|
||||
|
||||
# Skip if index value is empty/null
|
||||
if not index_value or all(v == "" for v in index_value):
|
||||
logger.debug(
|
||||
f"Skipping index {index_name} for row {row_index} - "
|
||||
f"empty index value"
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
self.session.execute(
|
||||
insert_cql,
|
||||
(collection, schema_name, index_name, index_value, data_map, source)
|
||||
)
|
||||
rows_written += 1
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to insert row {row_index} for index {index_name}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Wrote {rows_written} index entries for {len(obj.values)} rows "
|
||||
f"({len(index_names)} indexes per row)"
|
||||
)
|
||||
|
||||
async def create_collection(self, user: str, collection: str, metadata: dict):
|
||||
"""Create/verify collection exists in Cassandra row store"""
|
||||
# Connect if not already connected
|
||||
self.connect_cassandra()
|
||||
|
||||
# Ensure tables exist
|
||||
self.ensure_tables(user)
|
||||
|
||||
logger.info(f"Collection {collection} ready for user {user}")
|
||||
|
||||
async def delete_collection(self, user: str, collection: str):
|
||||
"""Delete all data for a specific collection using partition tracking"""
|
||||
# Connect if not already connected
|
||||
self.connect_cassandra()
|
||||
|
||||
safe_keyspace = self.sanitize_name(user)
|
||||
|
||||
# Check if keyspace exists
|
||||
if user not in self.known_keyspaces:
|
||||
check_keyspace_cql = """
|
||||
SELECT keyspace_name FROM system_schema.keyspaces
|
||||
WHERE keyspace_name = %s
|
||||
"""
|
||||
result = self.session.execute(check_keyspace_cql, (safe_keyspace,))
|
||||
if not result.one():
|
||||
logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
|
||||
return
|
||||
self.known_keyspaces.add(user)
|
||||
|
||||
# Discover all partitions for this collection
|
||||
select_partitions_cql = f"""
|
||||
SELECT schema_name, index_name FROM {safe_keyspace}.row_partitions
|
||||
WHERE collection = %s
|
||||
"""
|
||||
|
||||
try:
|
||||
partitions = self.session.execute(select_partitions_cql, (collection,))
|
||||
partition_list = list(partitions)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to query partitions for collection {collection}: {e}")
|
||||
raise
|
||||
|
||||
# Delete each partition from rows table
|
||||
delete_rows_cql = f"""
|
||||
DELETE FROM {safe_keyspace}.rows
|
||||
WHERE collection = %s AND schema_name = %s AND index_name = %s
|
||||
"""
|
||||
|
||||
partitions_deleted = 0
|
||||
for partition in partition_list:
|
||||
try:
|
||||
self.session.execute(
|
||||
delete_rows_cql,
|
||||
(collection, partition.schema_name, partition.index_name)
|
||||
)
|
||||
partitions_deleted += 1
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete partition {collection}/{partition.schema_name}/"
|
||||
f"{partition.index_name}: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
# Clean up row_partitions entries
|
||||
delete_partitions_cql = f"""
|
||||
DELETE FROM {safe_keyspace}.row_partitions
|
||||
WHERE collection = %s
|
||||
"""
|
||||
|
||||
try:
|
||||
self.session.execute(delete_partitions_cql, (collection,))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clean up row_partitions for {collection}: {e}")
|
||||
raise
|
||||
|
||||
# Clear from local cache
|
||||
self.registered_partitions = {
|
||||
(col, sch) for col, sch in self.registered_partitions
|
||||
if col != collection
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Deleted collection {collection}: {partitions_deleted} partitions "
|
||||
f"from keyspace {safe_keyspace}"
|
||||
)
|
||||
|
||||
async def delete_collection_schema(self, user: str, collection: str, schema_name: str):
|
||||
"""Delete all data for a specific collection + schema combination"""
|
||||
# Connect if not already connected
|
||||
self.connect_cassandra()
|
||||
|
||||
safe_keyspace = self.sanitize_name(user)
|
||||
|
||||
# Discover partitions for this collection + schema
|
||||
select_partitions_cql = f"""
|
||||
SELECT index_name FROM {safe_keyspace}.row_partitions
|
||||
WHERE collection = %s AND schema_name = %s
|
||||
"""
|
||||
|
||||
try:
|
||||
partitions = self.session.execute(select_partitions_cql, (collection, schema_name))
|
||||
partition_list = list(partitions)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to query partitions for {collection}/{schema_name}: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
# Delete each partition from rows table
|
||||
delete_rows_cql = f"""
|
||||
DELETE FROM {safe_keyspace}.rows
|
||||
WHERE collection = %s AND schema_name = %s AND index_name = %s
|
||||
"""
|
||||
|
||||
partitions_deleted = 0
|
||||
for partition in partition_list:
|
||||
try:
|
||||
self.session.execute(
|
||||
delete_rows_cql,
|
||||
(collection, schema_name, partition.index_name)
|
||||
)
|
||||
partitions_deleted += 1
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete partition {collection}/{schema_name}/"
|
||||
f"{partition.index_name}: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
# Clean up row_partitions entries for this schema
|
||||
delete_partitions_cql = f"""
|
||||
DELETE FROM {safe_keyspace}.row_partitions
|
||||
WHERE collection = %s AND schema_name = %s
|
||||
"""
|
||||
|
||||
try:
|
||||
self.session.execute(delete_partitions_cql, (collection, schema_name))
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to clean up row_partitions for {collection}/{schema_name}: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
# Clear from local cache
|
||||
self.registered_partitions.discard((collection, schema_name))
|
||||
|
||||
logger.info(
|
||||
f"Deleted {collection}/{schema_name}: {partitions_deleted} partitions "
|
||||
f"from keyspace {safe_keyspace}"
|
||||
)
|
||||
|
||||
def close(self):
|
||||
"""Clean up Cassandra connections"""
|
||||
if self.cluster:
|
||||
self.cluster.shutdown()
|
||||
logger.info("Closed Cassandra connection")
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add command-line arguments"""
|
||||
|
||||
Consumer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
)
|
||||
FlowProcessor.add_args(parser)
|
||||
add_cassandra_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'--config-type',
|
||||
default='schema',
|
||||
help='Configuration type prefix for schemas (default: schema)'
|
||||
)
|
||||
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
"""Entry point for rows-write-cassandra command"""
|
||||
Processor.launch(default_ident, __doc__)
|
||||
@ -10,11 +10,14 @@ import argparse
|
|||
import time
|
||||
import logging
|
||||
|
||||
from .... direct.cassandra_kg import KnowledgeGraph
|
||||
from .... direct.cassandra_kg import (
|
||||
EntityCentricKnowledgeGraph, DEFAULT_GRAPH
|
||||
)
|
||||
from .... base import TriplesStoreService, CollectionConfigHandler
|
||||
from .... base import AsyncProcessor, Consumer, Producer
|
||||
from .... base import ConsumerMetrics, ProducerMetrics
|
||||
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||
from .... schema import IRI, LITERAL, BLANK, TRIPLE
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -22,6 +25,59 @@ logger = logging.getLogger(__name__)
|
|||
default_ident = "triples-write"
|
||||
|
||||
|
||||
def get_term_value(term):
|
||||
"""Extract the string value from a Term"""
|
||||
if term is None:
|
||||
return None
|
||||
if term.type == IRI:
|
||||
return term.iri
|
||||
elif term.type == LITERAL:
|
||||
return term.value
|
||||
else:
|
||||
# For blank nodes or other types, use id or value
|
||||
return term.id or term.value
|
||||
|
||||
|
||||
def get_term_otype(term):
|
||||
"""
|
||||
Get object type code from a Term for entity-centric storage.
|
||||
|
||||
Maps Term.type to otype:
|
||||
- IRI ("i") → "u" (URI)
|
||||
- BLANK ("b") → "u" (treated as URI)
|
||||
- LITERAL ("l") → "l" (Literal)
|
||||
- TRIPLE ("t") → "t" (Triple/reification)
|
||||
"""
|
||||
if term is None:
|
||||
return "u"
|
||||
if term.type == IRI or term.type == BLANK:
|
||||
return "u"
|
||||
elif term.type == LITERAL:
|
||||
return "l"
|
||||
elif term.type == TRIPLE:
|
||||
return "t"
|
||||
else:
|
||||
return "u"
|
||||
|
||||
|
||||
def get_term_dtype(term):
|
||||
"""Extract datatype from a Term (for literals)"""
|
||||
if term is None:
|
||||
return ""
|
||||
if term.type == LITERAL:
|
||||
return term.datatype or ""
|
||||
return ""
|
||||
|
||||
|
||||
def get_term_lang(term):
|
||||
"""Extract language tag from a Term (for literals)"""
|
||||
if term is None:
|
||||
return ""
|
||||
if term.type == LITERAL:
|
||||
return term.language or ""
|
||||
return ""
|
||||
|
||||
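Taken together, the helpers above extract everything the entity-centric store needs from a Term; a rough sketch with example terms (constructed as the schema module does elsewhere in this change):

# Illustrative sketch - example terms, not part of this commit.
# iri_term     = Term(type=IRI, iri="http://example.com/Alice")
# literal_term = Term(type=LITERAL, value="Alice", language="en")
#
# get_term_value(iri_term)     -> "http://example.com/Alice"
# get_term_otype(iri_term)     -> "u"    (URI)
# get_term_value(literal_term) -> "Alice"
# get_term_otype(literal_term) -> "l"    (Literal)
# get_term_dtype(literal_term) -> ""     (no datatype set)
# get_term_lang(literal_term)  -> "en"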
|
||||
class Processor(CollectionConfigHandler, TriplesStoreService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
|
@ -64,15 +120,18 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
|
|||
|
||||
self.tg = None
|
||||
|
||||
# Use factory function to select implementation
|
||||
KGClass = EntityCentricKnowledgeGraph
|
||||
|
||||
try:
|
||||
if self.cassandra_username and self.cassandra_password:
|
||||
self.tg = KnowledgeGraph(
|
||||
self.tg = KGClass(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=message.metadata.user,
|
||||
username=self.cassandra_username, password=self.cassandra_password
|
||||
)
|
||||
else:
|
||||
self.tg = KnowledgeGraph(
|
||||
self.tg = KGClass(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=message.metadata.user,
|
||||
)
|
||||
|
|
@ -84,11 +143,27 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
|
|||
self.table = user
|
||||
|
||||
for t in message.triples:
|
||||
# Extract values from Term objects
|
||||
s_val = get_term_value(t.s)
|
||||
p_val = get_term_value(t.p)
|
||||
o_val = get_term_value(t.o)
|
||||
# t.g is None for default graph, or a graph IRI
|
||||
g_val = t.g if t.g is not None else DEFAULT_GRAPH
|
||||
|
||||
# Extract object type metadata for entity-centric storage
|
||||
otype = get_term_otype(t.o)
|
||||
dtype = get_term_dtype(t.o)
|
||||
lang = get_term_lang(t.o)
|
||||
|
||||
self.tg.insert(
|
||||
message.metadata.collection,
|
||||
t.s.value,
|
||||
t.p.value,
|
||||
t.o.value
|
||||
s_val,
|
||||
p_val,
|
||||
o_val,
|
||||
g=g_val,
|
||||
otype=otype,
|
||||
dtype=dtype,
|
||||
lang=lang
|
||||
)
|
||||
|
||||
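So for a triple with an IRI subject and predicate and an English-language literal object, the call above ends up roughly as follows (values invented for illustration):

# Illustrative sketch - hypothetical triple, not part of this commit.
# self.tg.insert(
#     "default",                      # collection
#     "http://example.com/Alice",     # subject value
#     "http://example.com/name",      # predicate value
#     "Alice",                        # object value
#     g=DEFAULT_GRAPH,                # t.g was None
#     otype="l",                      # literal object
#     dtype="",                       # no datatype
#     lang="en",                      # language tag
# )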
async def create_collection(self, user: str, collection: str, metadata: dict):
|
||||
|
|
@ -98,16 +173,19 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
|
|||
if self.table is None or self.table != user:
|
||||
self.tg = None
|
||||
|
||||
# Use factory function to select implementation
|
||||
KGClass = EntityCentricKnowledgeGraph
|
||||
|
||||
try:
|
||||
if self.cassandra_username and self.cassandra_password:
|
||||
self.tg = KnowledgeGraph(
|
||||
self.tg = KGClass(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=user,
|
||||
username=self.cassandra_username,
|
||||
password=self.cassandra_password
|
||||
)
|
||||
else:
|
||||
self.tg = KnowledgeGraph(
|
||||
self.tg = KGClass(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=user,
|
||||
)
|
||||
|
|
@ -137,16 +215,19 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
|
|||
if self.table is None or self.table != user:
|
||||
self.tg = None
|
||||
|
||||
# Use factory function to select implementation
|
||||
KGClass = EntityCentricKnowledgeGraph
|
||||
|
||||
try:
|
||||
if self.cassandra_username and self.cassandra_password:
|
||||
self.tg = KnowledgeGraph(
|
||||
self.tg = KGClass(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=user,
|
||||
username=self.cassandra_username,
|
||||
password=self.cassandra_password
|
||||
)
|
||||
else:
|
||||
self.tg = KnowledgeGraph(
|
||||
self.tg = KGClass(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=user,
|
||||
)
|
||||
@ -15,12 +15,27 @@ from falkordb import FalkorDB
|
|||
from .... base import TriplesStoreService, CollectionConfigHandler
|
||||
from .... base import AsyncProcessor, Consumer, Producer
|
||||
from .... base import ConsumerMetrics, ProducerMetrics
|
||||
from .... schema import IRI, LITERAL
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "triples-write"
|
||||
|
||||
|
||||
def get_term_value(term):
|
||||
"""Extract the string value from a Term"""
|
||||
if term is None:
|
||||
return None
|
||||
if term.type == IRI:
|
||||
return term.iri
|
||||
elif term.type == LITERAL:
|
||||
return term.value
|
||||
else:
|
||||
# For blank nodes or other types, use id or value
|
||||
return term.id or term.value
|
||||
|
||||
|
||||
default_graph_url = 'falkor://falkordb:6379'
|
||||
default_database = 'falkordb'
|
||||
|
||||
|
|
@ -164,14 +179,18 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
|
|||
|
||||
for t in message.triples:
|
||||
|
||||
self.create_node(t.s.value, user, collection)
|
||||
s_val = get_term_value(t.s)
|
||||
p_val = get_term_value(t.p)
|
||||
o_val = get_term_value(t.o)
|
||||
|
||||
if t.o.is_uri:
|
||||
self.create_node(t.o.value, user, collection)
|
||||
self.relate_node(t.s.value, t.p.value, t.o.value, user, collection)
|
||||
self.create_node(s_val, user, collection)
|
||||
|
||||
if t.o.type == IRI:
|
||||
self.create_node(o_val, user, collection)
|
||||
self.relate_node(s_val, p_val, o_val, user, collection)
|
||||
else:
|
||||
self.create_literal(t.o.value, user, collection)
|
||||
self.relate_literal(t.s.value, t.p.value, t.o.value, user, collection)
|
||||
self.create_literal(o_val, user, collection)
|
||||
self.relate_literal(s_val, p_val, o_val, user, collection)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
@ -15,12 +15,27 @@ from neo4j import GraphDatabase
|
|||
from .... base import TriplesStoreService, CollectionConfigHandler
|
||||
from .... base import AsyncProcessor, Consumer, Producer
|
||||
from .... base import ConsumerMetrics, ProducerMetrics
|
||||
from .... schema import IRI, LITERAL
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "triples-write"
|
||||
|
||||
|
||||
def get_term_value(term):
|
||||
"""Extract the string value from a Term"""
|
||||
if term is None:
|
||||
return None
|
||||
if term.type == IRI:
|
||||
return term.iri
|
||||
elif term.type == LITERAL:
|
||||
return term.value
|
||||
else:
|
||||
# For blank nodes or other types, use id or value
|
||||
return term.id or term.value
|
||||
|
||||
|
||||
default_graph_host = 'bolt://memgraph:7687'
|
||||
default_username = 'memgraph'
|
||||
default_password = 'password'
|
||||
|
|
@ -204,40 +219,44 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
|
|||
|
||||
def create_triple(self, tx, t, user, collection):
|
||||
|
||||
s_val = get_term_value(t.s)
|
||||
p_val = get_term_value(t.p)
|
||||
o_val = get_term_value(t.o)
|
||||
|
||||
# Create new s node with given uri, if not exists
|
||||
result = tx.run(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
uri=t.s.value, user=user, collection=collection
|
||||
uri=s_val, user=user, collection=collection
|
||||
)
|
||||
|
||||
if t.o.is_uri:
|
||||
if t.o.type == IRI:
|
||||
|
||||
# Create new o node with given uri, if not exists
|
||||
result = tx.run(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
uri=t.o.value, user=user, collection=collection
|
||||
uri=o_val, user=user, collection=collection
|
||||
)
|
||||
|
||||
result = tx.run(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
src=t.s.value, dest=t.o.value, uri=t.p.value, user=user, collection=collection,
|
||||
src=s_val, dest=o_val, uri=p_val, user=user, collection=collection,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
|
||||
# Create new o literal with given uri, if not exists
|
||||
result = tx.run(
|
||||
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
value=t.o.value, user=user, collection=collection
|
||||
value=o_val, user=user, collection=collection
|
||||
)
|
||||
|
||||
result = tx.run(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
src=t.s.value, dest=t.o.value, uri=t.p.value, user=user, collection=collection,
|
||||
src=s_val, dest=o_val, uri=p_val, user=user, collection=collection,
|
||||
)
|
||||
|
||||
async def store_triples(self, message):
|
||||
|
|
@ -257,14 +276,18 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
|
|||
|
||||
for t in message.triples:
|
||||
|
||||
self.create_node(t.s.value, user, collection)
|
||||
s_val = get_term_value(t.s)
|
||||
p_val = get_term_value(t.p)
|
||||
o_val = get_term_value(t.o)
|
||||
|
||||
if t.o.is_uri:
|
||||
self.create_node(t.o.value, user, collection)
|
||||
self.relate_node(t.s.value, t.p.value, t.o.value, user, collection)
|
||||
self.create_node(s_val, user, collection)
|
||||
|
||||
if t.o.type == IRI:
|
||||
self.create_node(o_val, user, collection)
|
||||
self.relate_node(s_val, p_val, o_val, user, collection)
|
||||
else:
|
||||
self.create_literal(t.o.value, user, collection)
|
||||
self.relate_literal(t.s.value, t.p.value, t.o.value, user, collection)
|
||||
self.create_literal(o_val, user, collection)
|
||||
self.relate_literal(s_val, p_val, o_val, user, collection)
|
||||
|
||||
# Alternative implementation using transactions
|
||||
# with self.io.session(database=self.db) as session:
|
||||
@ -14,12 +14,27 @@ from neo4j import GraphDatabase
|
|||
from .... base import TriplesStoreService, CollectionConfigHandler
|
||||
from .... base import AsyncProcessor, Consumer, Producer
|
||||
from .... base import ConsumerMetrics, ProducerMetrics
|
||||
from .... schema import IRI, LITERAL
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "triples-write"
|
||||
|
||||
|
||||
def get_term_value(term):
|
||||
"""Extract the string value from a Term"""
|
||||
if term is None:
|
||||
return None
|
||||
if term.type == IRI:
|
||||
return term.iri
|
||||
elif term.type == LITERAL:
|
||||
return term.value
|
||||
else:
|
||||
# For blank nodes or other types, use id or value
|
||||
return term.id or term.value
|
||||
|
||||
|
||||
default_graph_host = 'bolt://neo4j:7687'
|
||||
default_username = 'neo4j'
|
||||
default_password = 'password'
|
||||
|
|
@ -212,14 +227,18 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
|
|||
|
||||
for t in message.triples:
|
||||
|
||||
self.create_node(t.s.value, user, collection)
|
||||
s_val = get_term_value(t.s)
|
||||
p_val = get_term_value(t.p)
|
||||
o_val = get_term_value(t.o)
|
||||
|
||||
if t.o.is_uri:
|
||||
self.create_node(t.o.value, user, collection)
|
||||
self.relate_node(t.s.value, t.p.value, t.o.value, user, collection)
|
||||
self.create_node(s_val, user, collection)
|
||||
|
||||
if t.o.type == IRI:
|
||||
self.create_node(o_val, user, collection)
|
||||
self.relate_node(s_val, p_val, o_val, user, collection)
|
||||
else:
|
||||
self.create_literal(t.o.value, user, collection)
|
||||
self.relate_literal(t.s.value, t.p.value, t.o.value, user, collection)
|
||||
self.create_literal(o_val, user, collection)
|
||||
self.relate_literal(s_val, p_val, o_val, user, collection)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
@ -1,6 +1,6 @@
|
|||
|
||||
from .. schema import KnowledgeResponse, Triple, Triples, EntityEmbeddings
|
||||
from .. schema import Metadata, Value, GraphEmbeddings
|
||||
from .. schema import Metadata, GraphEmbeddings
|
||||
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.auth import PlainTextAuthProvider
|
||||
@ -1,8 +1,24 @@
|
|||
|
||||
from .. schema import KnowledgeResponse, Triple, Triples, EntityEmbeddings
|
||||
from .. schema import Metadata, Value, GraphEmbeddings
|
||||
from .. schema import Metadata, Term, IRI, LITERAL, GraphEmbeddings
|
||||
|
||||
from cassandra.cluster import Cluster
|
||||
|
||||
|
||||
def term_to_tuple(term):
|
||||
"""Convert Term to (value, is_uri) tuple for database storage."""
|
||||
if term.type == IRI:
|
||||
return (term.iri, True)
|
||||
else: # LITERAL
|
||||
return (term.value, False)
|
||||
|
||||
|
||||
def tuple_to_term(value, is_uri):
|
||||
"""Convert (value, is_uri) tuple from database to Term."""
|
||||
if is_uri:
|
||||
return Term(type=IRI, iri=value)
|
||||
else:
|
||||
return Term(type=LITERAL, value=value)
|
||||
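A short sketch of the round trip these two helpers perform:

# Illustrative sketch of the round trip, not part of this commit.
# term_to_tuple(Term(type=IRI, iri="http://example.com/x"))  -> ("http://example.com/x", True)
# term_to_tuple(Term(type=LITERAL, value="42"))              -> ("42", False)
#
# tuple_to_term("http://example.com/x", True)  -> Term(type=IRI, iri="http://example.com/x")
# tuple_to_term("42", False)                   -> Term(type=LITERAL, value="42")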
from cassandra.auth import PlainTextAuthProvider
|
||||
from ssl import SSLContext, PROTOCOL_TLSv1_2
|
||||
|
||||
|
|
@ -205,8 +221,7 @@ class KnowledgeTableStore:
|
|||
if m.metadata.metadata:
|
||||
metadata = [
|
||||
(
|
||||
v.s.value, v.s.is_uri, v.p.value, v.p.is_uri,
|
||||
v.o.value, v.o.is_uri
|
||||
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
|
||||
)
|
||||
for v in m.metadata.metadata
|
||||
]
|
||||
|
|
@ -215,8 +230,7 @@ class KnowledgeTableStore:
|
|||
|
||||
triples = [
|
||||
(
|
||||
v.s.value, v.s.is_uri, v.p.value, v.p.is_uri,
|
||||
v.o.value, v.o.is_uri
|
||||
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
|
||||
)
|
||||
for v in m.triples
|
||||
]
|
||||
|
|
@ -248,8 +262,7 @@ class KnowledgeTableStore:
|
|||
if m.metadata.metadata:
|
||||
metadata = [
|
||||
(
|
||||
v.s.value, v.s.is_uri, v.p.value, v.p.is_uri,
|
||||
v.o.value, v.o.is_uri
|
||||
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
|
||||
)
|
||||
for v in m.metadata.metadata
|
||||
]
|
||||
|
|
@ -258,7 +271,7 @@ class KnowledgeTableStore:
|
|||
|
||||
entities = [
|
||||
(
|
||||
(v.entity.value, v.entity.is_uri),
|
||||
term_to_tuple(v.entity),
|
||||
v.vectors
|
||||
)
|
||||
for v in m.entities
|
||||
|
|
@ -291,8 +304,7 @@ class KnowledgeTableStore:
|
|||
if m.metadata.metadata:
|
||||
metadata = [
|
||||
(
|
||||
v.s.value, v.s.is_uri, v.p.value, v.p.is_uri,
|
||||
v.o.value, v.o.is_uri
|
||||
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
|
||||
)
|
||||
for v in m.metadata.metadata
|
||||
]
|
||||
|
|
@ -414,23 +426,26 @@ class KnowledgeTableStore:
|
|||
if row[2]:
|
||||
metadata = [
|
||||
Triple(
|
||||
s = Value(value = elt[0], is_uri = elt[1]),
|
||||
p = Value(value = elt[2], is_uri = elt[3]),
|
||||
o = Value(value = elt[4], is_uri = elt[5]),
|
||||
s = tuple_to_term(elt[0], elt[1]),
|
||||
p = tuple_to_term(elt[2], elt[3]),
|
||||
o = tuple_to_term(elt[4], elt[5]),
|
||||
)
|
||||
for elt in row[2]
|
||||
]
|
||||
else:
|
||||
metadata = []
|
||||
|
||||
triples = [
|
||||
Triple(
|
||||
s = Value(value = elt[0], is_uri = elt[1]),
|
||||
p = Value(value = elt[2], is_uri = elt[3]),
|
||||
o = Value(value = elt[4], is_uri = elt[5]),
|
||||
)
|
||||
for elt in row[3]
|
||||
]
|
||||
if row[3]:
|
||||
triples = [
|
||||
Triple(
|
||||
s = tuple_to_term(elt[0], elt[1]),
|
||||
p = tuple_to_term(elt[2], elt[3]),
|
||||
o = tuple_to_term(elt[4], elt[5]),
|
||||
)
|
||||
for elt in row[3]
|
||||
]
|
||||
else:
|
||||
triples = []
|
||||
|
||||
await receiver(
|
||||
Triples(
|
||||
|
|
@ -470,22 +485,25 @@ class KnowledgeTableStore:
|
|||
if row[2]:
|
||||
metadata = [
|
||||
Triple(
|
||||
s = Value(value = elt[0], is_uri = elt[1]),
|
||||
p = Value(value = elt[2], is_uri = elt[3]),
|
||||
o = Value(value = elt[4], is_uri = elt[5]),
|
||||
s = tuple_to_term(elt[0], elt[1]),
|
||||
p = tuple_to_term(elt[2], elt[3]),
|
||||
o = tuple_to_term(elt[4], elt[5]),
|
||||
)
|
||||
for elt in row[2]
|
||||
]
|
||||
else:
|
||||
metadata = []
|
||||
|
||||
entities = [
|
||||
EntityEmbeddings(
|
||||
entity = Value(value = ent[0][0], is_uri = ent[0][1]),
|
||||
vectors = ent[1]
|
||||
)
|
||||
for ent in row[3]
|
||||
]
|
||||
if row[3]:
|
||||
entities = [
|
||||
EntityEmbeddings(
|
||||
entity = tuple_to_term(ent[0][0], ent[0][1]),
|
||||
vectors = ent[1]
|
||||
)
|
||||
for ent in row[3]
|
||||
]
|
||||
else:
|
||||
entities = []
|
||||
|
||||
await receiver(
|
||||
GraphEmbeddings(
|
||||
@ -1,8 +1,24 @@
|
|||
|
||||
from .. schema import LibrarianRequest, LibrarianResponse
|
||||
from .. schema import DocumentMetadata, ProcessingMetadata
|
||||
from .. schema import Error, Triple, Value
|
||||
from .. schema import Error, Triple, Term, IRI, LITERAL
|
||||
from .. knowledge import hash
|
||||
|
||||
|
||||
def term_to_tuple(term):
|
||||
"""Convert Term to (value, is_uri) tuple for database storage."""
|
||||
if term.type == IRI:
|
||||
return (term.iri, True)
|
||||
else: # LITERAL
|
||||
return (term.value, False)
|
||||
|
||||
|
||||
def tuple_to_term(value, is_uri):
|
||||
"""Convert (value, is_uri) tuple from database to Term."""
|
||||
if is_uri:
|
||||
return Term(type=IRI, iri=value)
|
||||
else:
|
||||
return Term(type=LITERAL, value=value)
|
||||
from .. exceptions import RequestError
|
||||
|
||||
from cassandra.cluster import Cluster
|
||||
|
|
@ -215,8 +231,7 @@ class LibraryTableStore:
|
|||
|
||||
metadata = [
|
||||
(
|
||||
v.s.value, v.s.is_uri, v.p.value, v.p.is_uri,
|
||||
v.o.value, v.o.is_uri
|
||||
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
|
||||
)
|
||||
for v in document.metadata
|
||||
]
|
||||
|
|
@ -249,8 +264,7 @@ class LibraryTableStore:
|
|||
|
||||
metadata = [
|
||||
(
|
||||
v.s.value, v.s.is_uri, v.p.value, v.p.is_uri,
|
||||
v.o.value, v.o.is_uri
|
||||
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
|
||||
)
|
||||
for v in document.metadata
|
||||
]
|
||||
|
|
@ -331,9 +345,9 @@ class LibraryTableStore:
|
|||
comments = row[4],
|
||||
metadata = [
|
||||
Triple(
|
||||
s=Value(value=m[0], is_uri=m[1]),
|
||||
p=Value(value=m[2], is_uri=m[3]),
|
||||
o=Value(value=m[4], is_uri=m[5])
|
||||
s=tuple_to_term(m[0], m[1]),
|
||||
p=tuple_to_term(m[2], m[3]),
|
||||
o=tuple_to_term(m[4], m[5])
|
||||
)
|
||||
for m in row[5]
|
||||
],
|
||||
|
|
@ -376,9 +390,9 @@ class LibraryTableStore:
|
|||
comments = row[3],
|
||||
metadata = [
|
||||
Triple(
|
||||
s=Value(value=m[0], is_uri=m[1]),
|
||||
p=Value(value=m[2], is_uri=m[3]),
|
||||
o=Value(value=m[4], is_uri=m[5])
|
||||
s=tuple_to_term(m[0], m[1]),
|
||||
p=tuple_to_term(m[2], m[3]),
|
||||
o=tuple_to_term(m[4], m[5])
|
||||
)
|
||||
for m in row[4]
|
||||
],
|
||||
|
@ -83,7 +83,7 @@ class PromptManager:
|
|||
|
||||
def parse_json(self, text):
|
||||
json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL)
|
||||
|
||||
|
||||
if json_match:
|
||||
json_str = json_match.group(1).strip()
|
||||
else:
|
||||
|
|
@ -92,6 +92,43 @@ class PromptManager:
|
|||
|
||||
return json.loads(json_str)
|
||||
|
||||
def parse_jsonl(self, text):
|
||||
"""
|
||||
Parse a JSONL response, returning a list of valid objects.

Malformed JSON lines are skipped with a warning; empty lines and stray fence markers are ignored.
|
||||
This provides truncation resilience - partial output yields partial results.
|
||||
"""
|
||||
results = []
|
||||
|
||||
# Strip markdown code fences if present
|
||||
text = text.strip()
|
||||
if text.startswith('```'):
|
||||
# Remove opening fence (possibly with language hint)
|
||||
text = re.sub(r'^```(?:json|jsonl)?\s*\n?', '', text)
|
||||
if text.endswith('```'):
|
||||
text = text[:-3]
|
||||
|
||||
for line_num, line in enumerate(text.strip().split('\n'), 1):
|
||||
line = line.strip()
|
||||
|
||||
# Skip empty lines
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Skip any remaining fence markers
|
||||
if line.startswith('```'):
|
||||
continue
|
||||
|
||||
try:
|
||||
obj = json.loads(line)
|
||||
results.append(obj)
|
||||
except json.JSONDecodeError as e:
|
||||
# Log warning but continue - this provides truncation resilience
|
||||
logger.warning(f"JSONL parse error on line {line_num}: {e}")
|
||||
|
||||
return results
|
||||
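A sketch of the behaviour on a fenced, truncated response:

# Illustrative sketch of parse_jsonl behaviour, not part of this commit.
text = '''```jsonl
{"name": "Alice", "role": "engineer"}
{"name": "Bob", "role": "analyst"}
{"name": "Carol", "ro'''
# The opening fence is stripped, the two complete lines parse, and the
# truncated final line is skipped with a warning:
#
# parse_jsonl(text) == [
#     {"name": "Alice", "role": "engineer"},
#     {"name": "Bob", "role": "analyst"},
# ]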
|
||||
def render(self, id, input):
|
||||
|
||||
if id not in self.prompts:
|
||||
|
|
@ -121,21 +158,41 @@ class PromptManager:
|
|||
if resp_type == "text":
|
||||
return resp
|
||||
|
||||
if resp_type != "json":
|
||||
raise RuntimeError(f"Response type {resp_type} not known")
|
||||
|
||||
try:
|
||||
obj = self.parse_json(resp)
|
||||
except:
|
||||
logger.error(f"JSON parse failed: {resp}")
|
||||
raise RuntimeError("JSON parse fail")
|
||||
|
||||
if self.prompts[id].schema:
|
||||
if resp_type == "json":
|
||||
try:
|
||||
validate(instance=obj, schema=self.prompts[id].schema)
|
||||
logger.debug("Schema validation successful")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Schema validation fail: {e}")
|
||||
obj = self.parse_json(resp)
|
||||
except:
|
||||
logger.error(f"JSON parse failed: {resp}")
|
||||
raise RuntimeError("JSON parse fail")
|
||||
|
||||
return obj
|
||||
if self.prompts[id].schema:
|
||||
try:
|
||||
validate(instance=obj, schema=self.prompts[id].schema)
|
||||
logger.debug("Schema validation successful")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Schema validation fail: {e}")
|
||||
|
||||
return obj
|
||||
|
||||
if resp_type == "jsonl":
|
||||
objects = self.parse_jsonl(resp)
|
||||
|
||||
if not objects:
|
||||
logger.warning("JSONL parse returned no valid objects")
|
||||
return []
|
||||
|
||||
# Validate each object against schema if provided
|
||||
if self.prompts[id].schema:
|
||||
validated = []
|
||||
for i, obj in enumerate(objects):
|
||||
try:
|
||||
validate(instance=obj, schema=self.prompts[id].schema)
|
||||
validated.append(obj)
|
||||
except Exception as e:
|
||||
logger.warning(f"Object {i} failed schema validation: {e}")
|
||||
return validated
|
||||
|
||||
return objects
|
||||
|
||||
raise RuntimeError(f"Response type {resp_type} not known")
|
||||
|
||||
|
|