feat: workspace-based multi-tenancy, replacing user as tenancy axis (#840)

Introduces `workspace` as the isolation boundary for config, flows,
library, and knowledge data. Removes `user` as a schema-level field
throughout the code, API specs, and tests; workspace provides the
same separation more cleanly, enforced at the trusted `flow.workspace`
layer rather than through client-supplied message fields.

Design
------
- IAM tech spec (docs/tech-specs/iam.md) documents current state,
  proposed auth/access model, and migration direction.
- Data ownership model (docs/tech-specs/data-ownership-model.md)
  captures the workspace/collection/flow hierarchy.

Schema + messaging
------------------
- Drop `user` field from AgentRequest/Step, GraphRagQuery,
  DocumentRagQuery, Triples/Graph/Document/Row EmbeddingsRequest,
  Sparql/Rows/Structured QueryRequest, ToolServiceRequest.
- Keep collection/workspace routing via flow.workspace at the
  service layer.
- Translators updated to not serialise/deserialise user.

API specs
---------
- OpenAPI schemas and path examples cleaned of user fields.
- Websocket async-api messages updated.
- Removed the unused parameters/User.yaml.

Services + base
---------------
- Librarian, collection manager, knowledge, config: all operations
  scoped by workspace. Config client API takes workspace as first
  positional arg.
- `flow.workspace` is set at flow start time by the infrastructure;
  it is no longer passed through from clients.
- Tool service drops user-personalisation passthrough.
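
The "workspace as first positional arg" shape of the config client can be
sketched as below. This is purely illustrative (class, method, and key names
are hypothetical; the real client is asynchronous and backed by the config
service), but it shows why putting workspace first makes accidental
cross-workspace reads hard:

```python
# Minimal sketch of a workspace-scoped config store. All names here are
# illustrative, not the actual TrustGraph config client API.
class ConfigClient:
    def __init__(self):
        # {workspace: {type: {key: value}}}
        self._store = {}

    def put(self, workspace, type_, key, value):
        self._store.setdefault(workspace, {}).setdefault(type_, {})[key] = value

    def get(self, workspace, type_):
        # Workspace is the mandatory first argument: every read is scoped.
        return self._store.get(workspace, {}).get(type_, {})

cfg = ConfigClient()
cfg.put("research", "mcp", "weather", '{"url": "http://mcp:8000"}')
assert cfg.get("research", "mcp") == {"weather": '{"url": "http://mcp:8000"}'}
assert cfg.get("other-ws", "mcp") == {}   # isolated per workspace
```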

CLI + SDK
---------
- tg-init-workspace and workspace-aware import/export.
- All tg-* commands drop user args; accept --workspace.
- Python API/SDK (flow, socket_client, async_*, explainability,
  library) drop user kwargs from every method signature.

MCP server
----------
- All tool endpoints drop user parameters; socket_manager no longer
  keyed per user.

Flow service
------------
- Closure-based topic cleanup on flow stop: only delete topics
  whose blueprint template was parameterised AND no remaining
  live flow (across all workspaces) still resolves to that topic.
  Four scopes fall out naturally from template analysis:
    * {id} -> per-flow, deleted on stop
    * {blueprint} -> per-blueprint, kept while any flow of the
      same blueprint exists
    * {workspace} -> per-workspace, kept while any flow in the
      workspace exists
    * literal -> global, never deleted (e.g. tg.request.librarian)
  Fixes a bug where stopping a flow silently destroyed the global
  librarian exchange, wedging all library operations until manual
  restart.
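
The scope rule above can be sketched as a pure function over the topic
template. Names and the `{placeholder}` template syntax are illustrative,
not the actual flow-service code:

```python
# Classify a blueprint topic template, then decide deletion on flow stop:
# only delete if the template was parameterised AND no remaining live flow
# still resolves to the same concrete topic.

def topic_scope(template: str) -> str:
    if "{id}" in template:
        return "per-flow"          # deleted when the flow stops
    if "{blueprint}" in template:
        return "per-blueprint"     # kept while any flow of the blueprint runs
    if "{workspace}" in template:
        return "per-workspace"     # kept while any flow in the workspace runs
    return "global"                # literal topic, never deleted

def should_delete(template: str, topic: str, live_flows) -> bool:
    if topic_scope(template) == "global":
        return False               # e.g. tg.request.librarian survives
    return not any(
        template.format(id=f["id"], blueprint=f["blueprint"],
                        workspace=f["workspace"]) == topic
        for f in live_flows
    )

assert topic_scope("tg.flow.{id}.input") == "per-flow"
assert topic_scope("tg.request.librarian") == "global"
# Global topic: never deleted, even with no flows left.
assert should_delete("tg.request.librarian", "tg.request.librarian", []) is False
# Per-flow topic: deletable once no live flow resolves to it.
assert should_delete(
    "tg.flow.{id}.input", "tg.flow.f1.input",
    [{"id": "f2", "blueprint": "b", "workspace": "w"}],
) is True
```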

RabbitMQ backend
----------------
- heartbeat=60, blocked_connection_timeout=300. Catches silently
  dead connections (broker restart, orphaned channels, network
  partitions) within ~2 heartbeat windows, so the consumer
  reconnects and re-binds its queue rather than sitting forever
  on a zombie connection.

Tests
-----
- Full test refresh: unit, integration, contract, provenance.
- Dropped user-field assertions and constructor kwargs across
  ~100 test files.
- Renamed user-collection isolation tests to workspace-collection.
cybermaggedon 2026-04-21 23:23:01 +01:00 committed by GitHub
parent 9332089b3d
commit d35473f7f7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
377 changed files with 6868 additions and 5785 deletions


@@ -26,42 +26,50 @@ class Service(ToolService):
self.register_config_handler(self.on_mcp_config, types=["mcp"])
# Per-workspace MCP service registries
self.mcp_services = {}
async def on_mcp_config(self, config, version):
async def on_mcp_config(self, workspace, config, version):
logger.info(f"Got config version {version}")
logger.info(
f"Got config version {version} for workspace {workspace}"
)
if "mcp" not in config:
self.mcp_services = {}
self.mcp_services[workspace] = {}
return
self.mcp_services = {
self.mcp_services[workspace] = {
k: json.loads(v)
for k, v in config["mcp"].items()
}
async def invoke_tool(self, name, parameters):
async def invoke_tool(self, workspace, name, parameters):
try:
if name not in self.mcp_services:
raise RuntimeError(f"MCP service {name} not known")
ws_services = self.mcp_services.get(workspace, {})
if "url" not in self.mcp_services[name]:
if name not in ws_services:
raise RuntimeError(
f"MCP service {name} not known in workspace "
f"{workspace}"
)
if "url" not in ws_services[name]:
raise RuntimeError(f"MCP service {name} URL not defined")
url = self.mcp_services[name]["url"]
url = ws_services[name]["url"]
if "remote-name" in self.mcp_services[name]:
remote_name = self.mcp_services[name]["remote-name"]
if "remote-name" in ws_services[name]:
remote_name = ws_services[name]["remote-name"]
else:
remote_name = name
# Build headers with optional bearer token
headers = {}
if "auth-token" in self.mcp_services[name]:
token = self.mcp_services[name]["auth-token"]
if "auth-token" in ws_services[name]:
token = ws_services[name]["auth-token"]
headers["Authorization"] = f"Bearer {token}"
logger.info(f"Invoking {remote_name} at {url}")


@@ -108,7 +108,7 @@ class Aggregator:
)
def build_synthesis_request(self, correlation_id, original_question,
user, collection):
collection):
"""
Build the AgentRequest that triggers the synthesis phase.
"""
@@ -139,7 +139,6 @@ class Aggregator:
state="",
group=template.group if template else [],
history=history,
user=user,
collection=collection,
streaming=template.streaming if template else False,
session_id=parent_session_id,


@@ -46,25 +46,20 @@ from ..tool_filter import filter_tools_by_group_and_state, get_next_state
logger = logging.getLogger(__name__)
class UserAwareContext:
"""Wraps flow interface to inject user context for tools that need it."""
class FlowContext:
"""Wraps flow interface with orchestrator-only scratch state
(explain URIs, response handle, streaming flag). Workspace isolation
is enforced by the flow layer (flow.workspace), not by this class."""
def __init__(self, flow, user, respond=None, streaming=False):
def __init__(self, flow, respond=None, streaming=False):
self._flow = flow
self._user = user
self.respond = respond
self.streaming = streaming
self.current_explain_uri = None
self.last_sub_explain_uri = None
def __call__(self, service_name):
client = self._flow(service_name)
if service_name in (
"structured-query-request",
"row-embeddings-query-request",
):
client._current_user = self._user
return client
return self._flow(service_name)
class UsageTracker:
@@ -131,7 +126,6 @@ class PatternBase:
state="",
group=getattr(request, 'group', []),
history=[completion_step],
user=request.user,
collection=getattr(request, 'collection', 'default'),
streaming=False,
session_id=getattr(request, 'session_id', ''),
@@ -158,9 +152,9 @@ class PatternBase:
current_state=getattr(request, 'state', None),
)
def make_context(self, flow, user, respond=None, streaming=False):
"""Create a user-aware context wrapper."""
return UserAwareContext(flow, user, respond=respond, streaming=streaming)
def make_context(self, flow, respond=None, streaming=False):
"""Create a flow context wrapper."""
return FlowContext(flow, respond=respond, streaming=streaming)
def build_history(self, request):
"""Convert AgentStep history into Action objects."""
@@ -249,7 +243,7 @@ class PatternBase:
# ---- Provenance emission ------------------------------------------------
async def emit_session_triples(self, flow, session_uri, question, user,
async def emit_session_triples(self, flow, session_uri, question,
collection, respond, streaming,
parent_uri=None):
"""Emit provenance triples for a new session."""
@@ -264,7 +258,6 @@ class PatternBase:
await flow("explainability").send(Triples(
metadata=Metadata(
id=session_uri,
user=user,
collection=collection,
),
triples=triples,
@@ -281,7 +274,7 @@ class PatternBase:
async def emit_pattern_decision_triples(
self, flow, session_id, session_uri, pattern, task_type,
user, collection, respond,
collection, respond,
):
"""Emit provenance triples for a meta-router pattern decision."""
uri = agent_pattern_decision_uri(session_id)
@@ -292,7 +285,7 @@ class PatternBase:
GRAPH_RETRIEVAL,
)
await flow("explainability").send(Triples(
metadata=Metadata(id=uri, user=user, collection=collection),
metadata=Metadata(id=uri, collection=collection),
triples=triples,
))
await respond(AgentResponse(
@@ -329,7 +322,7 @@ class PatternBase:
try:
await self.processor.save_answer_content(
doc_id=thought_doc_id,
user=request.user,
workspace=flow.workspace,
content=act.thought,
title=f"Agent Thought: {act.name}",
)
@@ -360,7 +353,6 @@ class PatternBase:
await flow("explainability").send(Triples(
metadata=Metadata(
id=iteration_uri,
user=request.user,
collection=getattr(request, 'collection', 'default'),
),
triples=iter_triples,
@@ -399,7 +391,7 @@ class PatternBase:
try:
await self.processor.save_answer_content(
doc_id=observation_doc_id,
user=request.user,
workspace=flow.workspace,
content=observation_text,
title=f"Agent Observation",
)
@@ -420,7 +412,6 @@ class PatternBase:
await flow("explainability").send(Triples(
metadata=Metadata(
id=observation_entity_uri,
user=request.user,
collection=getattr(request, 'collection', 'default'),
),
triples=obs_triples,
@@ -456,7 +447,7 @@ class PatternBase:
try:
await self.processor.save_answer_content(
doc_id=answer_doc_id,
user=request.user,
workspace=flow.workspace,
content=answer_text,
title=f"Agent Answer: {request.question[:50]}...",
)
@@ -478,7 +469,6 @@ class PatternBase:
await flow("explainability").send(Triples(
metadata=Metadata(
id=final_uri,
user=request.user,
collection=getattr(request, 'collection', 'default'),
),
triples=final_triples,
@@ -496,7 +486,7 @@ class PatternBase:
# ---- Orchestrator provenance helpers ------------------------------------
async def emit_decomposition_triples(
self, flow, session_id, session_uri, goals, user, collection,
self, flow, session_id, session_uri, goals, collection,
respond, streaming,
):
"""Emit provenance for a supervisor decomposition step."""
@@ -506,7 +496,7 @@ class PatternBase:
GRAPH_RETRIEVAL,
)
await flow("explainability").send(Triples(
metadata=Metadata(id=uri, user=user, collection=collection),
metadata=Metadata(id=uri, collection=collection),
triples=triples,
))
await respond(AgentResponse(
@@ -516,7 +506,7 @@ class PatternBase:
))
async def emit_finding_triples(
self, flow, session_id, index, goal, answer_text, user, collection,
self, flow, session_id, index, goal, answer_text, collection,
respond, streaming, subagent_session_id="",
):
"""Emit provenance for a subagent finding."""
@@ -532,7 +522,7 @@ class PatternBase:
doc_id = f"urn:trustgraph:agent:{session_id}/finding/{index}/doc"
try:
await self.processor.save_answer_content(
doc_id=doc_id, user=user,
doc_id=doc_id, workspace=flow.workspace,
content=answer_text,
title=f"Finding: {goal[:60]}",
)
@@ -545,7 +535,7 @@ class PatternBase:
GRAPH_RETRIEVAL,
)
await flow("explainability").send(Triples(
metadata=Metadata(id=uri, user=user, collection=collection),
metadata=Metadata(id=uri, collection=collection),
triples=triples,
))
await respond(AgentResponse(
@@ -555,7 +545,7 @@ class PatternBase:
))
async def emit_plan_triples(
self, flow, session_id, session_uri, steps, user, collection,
self, flow, session_id, session_uri, steps, collection,
respond, streaming,
):
"""Emit provenance for a plan creation."""
@@ -565,7 +555,7 @@ class PatternBase:
GRAPH_RETRIEVAL,
)
await flow("explainability").send(Triples(
metadata=Metadata(id=uri, user=user, collection=collection),
metadata=Metadata(id=uri, collection=collection),
triples=triples,
))
await respond(AgentResponse(
@@ -575,7 +565,7 @@ class PatternBase:
))
async def emit_step_result_triples(
self, flow, session_id, index, goal, answer_text, user, collection,
self, flow, session_id, index, goal, answer_text, collection,
respond, streaming,
):
"""Emit provenance for a plan step result."""
@@ -585,7 +575,7 @@ class PatternBase:
doc_id = f"urn:trustgraph:agent:{session_id}/step/{index}/doc"
try:
await self.processor.save_answer_content(
doc_id=doc_id, user=user,
doc_id=doc_id, workspace=flow.workspace,
content=answer_text,
title=f"Step result: {goal[:60]}",
)
@@ -598,7 +588,7 @@ class PatternBase:
GRAPH_RETRIEVAL,
)
await flow("explainability").send(Triples(
metadata=Metadata(id=uri, user=user, collection=collection),
metadata=Metadata(id=uri, collection=collection),
triples=triples,
))
await respond(AgentResponse(
@@ -608,7 +598,7 @@ class PatternBase:
))
async def emit_synthesis_triples(
self, flow, session_id, previous_uris, answer_text, user, collection,
self, flow, session_id, previous_uris, answer_text, collection,
respond, streaming, termination_reason=None,
):
"""Emit provenance for a synthesis answer."""
@@ -617,7 +607,7 @@ class PatternBase:
doc_id = f"urn:trustgraph:agent:{session_id}/synthesis/doc"
try:
await self.processor.save_answer_content(
doc_id=doc_id, user=user,
doc_id=doc_id, workspace=flow.workspace,
content=answer_text,
title="Synthesis",
)
@@ -633,7 +623,7 @@ class PatternBase:
GRAPH_RETRIEVAL,
)
await flow("explainability").send(Triples(
metadata=Metadata(id=uri, user=user, collection=collection),
metadata=Metadata(id=uri, collection=collection),
triples=triples,
))
await respond(AgentResponse(
@@ -751,7 +741,6 @@ class PatternBase:
)
for h in history
],
user=request.user,
collection=collection,
streaming=streaming,
session_id=session_id,


@@ -53,7 +53,7 @@ class PlanThenExecutePattern(PatternBase):
if iteration_num == 1:
await self.emit_session_triples(
flow, session_uri, request.question,
request.user, collection, respond, streaming,
collection, respond, streaming,
)
logger.info(
@@ -109,11 +109,17 @@ class PlanThenExecutePattern(PatternBase):
think = self.make_think_callback(respond, streaming)
tools = self.filter_tools(self.processor.agent.tools, request)
agent = self.processor.agents.get(flow.workspace)
if agent is None:
raise RuntimeError(
f"No agent configuration for workspace {flow.workspace}"
)
tools = self.filter_tools(agent.tools, request)
framing = getattr(request, 'framing', '')
context = self.make_context(
flow, request.user,
flow,
respond=respond, streaming=streaming,
)
client = context("prompt-request")
@@ -147,7 +153,7 @@ class PlanThenExecutePattern(PatternBase):
step_goals = [ps.get("goal", "") for ps in plan_steps]
await self.emit_plan_triples(
flow, session_id, session_uri, step_goals,
request.user, collection, respond, streaming,
collection, respond, streaming,
)
# Build PlanStep objects
@@ -179,7 +185,6 @@ class PlanThenExecutePattern(PatternBase):
state=request.state,
group=getattr(request, 'group', []),
history=new_history,
user=request.user,
collection=collection,
streaming=streaming,
session_id=session_id,
@@ -237,9 +242,15 @@ class PlanThenExecutePattern(PatternBase):
"result": dep_result,
})
tools = self.filter_tools(self.processor.agent.tools, request)
agent = self.processor.agents.get(flow.workspace)
if agent is None:
raise RuntimeError(
f"No agent configuration for workspace {flow.workspace}"
)
tools = self.filter_tools(agent.tools, request)
context = self.make_context(
flow, request.user,
flow,
respond=respond, streaming=streaming,
)
@@ -307,7 +318,7 @@ class PlanThenExecutePattern(PatternBase):
# Emit step result provenance
await self.emit_step_result_triples(
flow, session_id, pending_idx, goal, step_result,
request.user, collection, respond, streaming,
collection, respond, streaming,
)
# Build execution step for history
@@ -327,7 +338,6 @@ class PlanThenExecutePattern(PatternBase):
state=request.state,
group=getattr(request, 'group', []),
history=new_history,
user=request.user,
collection=collection,
streaming=streaming,
session_id=session_id,
@@ -352,7 +362,7 @@ class PlanThenExecutePattern(PatternBase):
framing = getattr(request, 'framing', '')
context = self.make_context(
flow, request.user,
flow,
respond=respond, streaming=streaming,
)
client = context("prompt-request")
@@ -387,7 +397,7 @@ class PlanThenExecutePattern(PatternBase):
last_step_uri = make_step_result_uri(session_id, len(plan) - 1)
await self.emit_synthesis_triples(
flow, session_id, last_step_uri,
response_text, request.user, collection, respond, streaming,
response_text, collection, respond, streaming,
termination_reason="plan-complete",
)


@@ -61,7 +61,7 @@ class ReactPattern(PatternBase):
)
await self.emit_session_triples(
flow, session_uri, request.question,
request.user, collection, respond, streaming,
collection, respond, streaming,
parent_uri=parent_uri,
)
@@ -80,13 +80,20 @@ class ReactPattern(PatternBase):
observe = self.make_observe_callback(respond, streaming, message_id=observation_msg_id)
answer_cb = self.make_answer_callback(respond, streaming, message_id=answer_msg_id)
# Look up the per-workspace agent
agent = self.processor.agents.get(flow.workspace)
if agent is None:
raise RuntimeError(
f"No agent configuration for workspace {flow.workspace}"
)
# Filter tools
filtered_tools = self.filter_tools(
self.processor.agent.tools, request,
agent.tools, request,
)
# Create temporary agent with filtered tools and optional framing
additional_context = self.processor.agent.additional_context
additional_context = agent.additional_context
framing = getattr(request, 'framing', '')
if framing:
if additional_context:
@@ -100,7 +107,7 @@ class ReactPattern(PatternBase):
)
context = self.make_context(
flow, request.user,
flow,
respond=respond, streaming=streaming,
)


@@ -42,7 +42,7 @@ from ..tool_filter import validate_tool_config
from ..react.types import Final, Action, Tool, Argument
from . meta_router import MetaRouter
from . pattern_base import PatternBase, UserAwareContext
from . pattern_base import PatternBase, FlowContext
from . react_pattern import ReactPattern
from . plan_pattern import PlanThenExecutePattern
from . supervisor_pattern import SupervisorPattern
@@ -76,10 +76,9 @@ class Processor(AgentService):
}
)
self.agent = AgentManager(
tools={},
additional_context="",
)
# Per-workspace agent managers and meta-routers
self.agents = {}
self.meta_routers = {}
self.tool_service_clients = {}
@@ -91,9 +90,6 @@ class Processor(AgentService):
# Aggregator for supervisor fan-in
self.aggregator = Aggregator()
# Meta-router (initialised on first config load)
self.meta_router = None
self.register_config_handler(
self.on_tools_config, types=["tool", "tool-service"]
)
@@ -204,13 +200,13 @@ class Processor(AgentService):
future = self.pending_librarian_requests.pop(request_id)
future.set_result(response)
async def save_answer_content(self, doc_id, user, content, title=None,
async def save_answer_content(self, doc_id, workspace, content, title=None,
timeout=120):
request_id = str(uuid.uuid4())
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
workspace=workspace,
kind="text/plain",
title=title or "Agent Answer",
document_type="answer",
@@ -221,7 +217,7 @@ class Processor(AgentService):
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
user=user,
workspace=workspace,
)
future = asyncio.get_event_loop().create_future()
@@ -247,9 +243,12 @@ class Processor(AgentService):
def provenance_session_uri(self, session_id):
return agent_session_uri(session_id)
async def on_tools_config(self, config, version):
async def on_tools_config(self, workspace, config, version):
logger.info(f"Loading configuration version {version}")
logger.info(
f"Loading configuration version {version} "
f"for workspace {workspace}"
)
try:
tools = {}
@@ -316,7 +315,6 @@ class Processor(AgentService):
impl = functools.partial(
StructuredQueryImpl,
collection=data.get("collection"),
user=None,
)
arguments = StructuredQueryImpl.get_arguments()
elif impl_id == "row-embeddings-query":
@@ -324,7 +322,6 @@ class Processor(AgentService):
RowEmbeddingsQueryImpl,
schema_name=data.get("schema-name"),
collection=data.get("collection"),
user=None,
index_name=data.get("index-name"),
limit=int(data.get("limit", 10)),
)
@@ -408,15 +405,17 @@ class Processor(AgentService):
agent_config = config[self.config_key]
additional = agent_config.get("additional-context", None)
self.agent = AgentManager(
self.agents[workspace] = AgentManager(
tools=tools,
additional_context=additional,
)
# Re-initialise meta-router with config
self.meta_router = MetaRouter(config=config)
# Re-initialise meta-router with config for this workspace
self.meta_routers[workspace] = MetaRouter(config=config)
logger.info(f"Loaded {len(tools)} tools")
logger.info(
f"Loaded {len(tools)} tools for workspace {workspace}"
)
except Exception as e:
logger.error(
@@ -466,7 +465,7 @@ class Processor(AgentService):
await self.supervisor_pattern.emit_finding_triples(
flow, parent_session_id, finding_index,
subagent_goal, answer_text,
template.user, collection,
collection,
respond, template.streaming,
subagent_session_id=subagent_session_id,
)
@@ -486,7 +485,6 @@ class Processor(AgentService):
synthesis_request = self.aggregator.build_synthesis_request(
correlation_id,
original_question=template.question,
user=template.user,
collection=getattr(template, 'collection', 'default'),
)
@@ -515,10 +513,11 @@ class Processor(AgentService):
# If no pattern set and this is the first iteration, route
if not pattern and not request.history:
context = UserAwareContext(flow, request.user)
context = FlowContext(flow)
if self.meta_router:
pattern, task_type, framing = await self.meta_router.route(
meta_router = self.meta_routers.get(flow.workspace)
if meta_router:
pattern, task_type, framing = await meta_router.route(
request.question, context, usage=usage,
)
else:
@@ -553,7 +552,6 @@ class Processor(AgentService):
await selected.emit_pattern_decision_triples(
flow, session_id, session_uri,
pattern, getattr(request, 'task_type', ''),
request.user,
getattr(request, 'collection', 'default'),
respond,
)

View file

@@ -54,7 +54,7 @@ class SupervisorPattern(PatternBase):
if iteration_num == 1:
await self.emit_session_triples(
flow, session_uri, request.question,
request.user, collection, respond, streaming,
collection, respond, streaming,
)
logger.info(
@@ -99,10 +99,16 @@ class SupervisorPattern(PatternBase):
)
framing = getattr(request, 'framing', '')
tools = self.filter_tools(self.processor.agent.tools, request)
agent = self.processor.agents.get(flow.workspace)
if agent is None:
raise RuntimeError(
f"No agent configuration for workspace {flow.workspace}"
)
tools = self.filter_tools(agent.tools, request)
context = self.make_context(
flow, request.user,
flow,
respond=respond, streaming=streaming,
)
client = context("prompt-request")
@@ -144,7 +150,7 @@ class SupervisorPattern(PatternBase):
# Emit decomposition provenance
await self.emit_decomposition_triples(
flow, session_id, session_uri, goals,
request.user, collection, respond, streaming,
collection, respond, streaming,
)
# Fan out: emit a subagent request for each goal
@@ -155,7 +161,6 @@ class SupervisorPattern(PatternBase):
state="",
group=getattr(request, 'group', []),
history=[],
user=request.user,
collection=collection,
streaming=False, # Subagents don't stream
session_id=subagent_session,
@@ -207,7 +212,7 @@ class SupervisorPattern(PatternBase):
subagent_results = {"(no results)": "No subagent results available"}
context = self.make_context(
flow, request.user,
flow,
respond=respond, streaming=streaming,
)
client = context("prompt-request")
@@ -237,7 +242,7 @@ class SupervisorPattern(PatternBase):
]
await self.emit_synthesis_triples(
flow, session_id, finding_uris,
response_text, request.user, collection, respond, streaming,
response_text, collection, respond, streaming,
termination_reason="subagents-complete",
)


@@ -10,6 +10,7 @@ import sys
import functools
import logging
import uuid
from typing import Dict
from datetime import datetime, timezone
# Module logger
@@ -73,10 +74,8 @@ class Processor(AgentService):
}
)
self.agent = AgentManager(
tools={},
additional_context="",
)
# Per-workspace agent managers
self.agents: Dict[str, AgentManager] = {}
# Track active tool service clients for cleanup
self.tool_service_clients = {}
@@ -193,13 +192,13 @@ class Processor(AgentService):
future = self.pending_librarian_requests.pop(request_id)
future.set_result(response)
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120):
"""
Save answer content to the librarian.
Args:
doc_id: ID for the answer document
user: User ID
workspace: Workspace for isolation
content: Answer text content
title: Optional title
timeout: Request timeout in seconds
@@ -211,7 +210,7 @@ class Processor(AgentService):
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
workspace=workspace,
kind="text/plain",
title=title or "Agent Answer",
document_type="answer",
@@ -222,7 +221,7 @@ class Processor(AgentService):
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
user=user,
workspace=workspace,
)
# Create future for response
@@ -249,9 +248,12 @@ class Processor(AgentService):
self.pending_librarian_requests.pop(request_id, None)
raise RuntimeError(f"Timeout saving answer document {doc_id}")
async def on_tools_config(self, config, version):
async def on_tools_config(self, workspace, config, version):
logger.info(f"Loading configuration version {version}")
logger.info(
f"Loading configuration version {version} "
f"for workspace {workspace}"
)
try:
@@ -321,7 +323,6 @@ class Processor(AgentService):
impl = functools.partial(
StructuredQueryImpl,
collection=data.get("collection"),
user=None # User will be provided dynamically via context
)
arguments = StructuredQueryImpl.get_arguments()
elif impl_id == "row-embeddings-query":
@@ -329,7 +330,6 @@ class Processor(AgentService):
RowEmbeddingsQueryImpl,
schema_name=data.get("schema-name"),
collection=data.get("collection"),
user=None, # User will be provided dynamically via context
index_name=data.get("index-name"), # Optional filter
limit=int(data.get("limit", 10)) # Max results
)
@@ -409,13 +409,17 @@ class Processor(AgentService):
agent_config = config[self.config_key]
additional = agent_config.get("additional-context", None)
self.agent = AgentManager(
self.agents[workspace] = AgentManager(
tools=tools,
additional_context=additional
)
logger.info(f"Loaded {len(tools)} tools")
logger.info("Tool configuration reloaded.")
logger.info(
f"Loaded {len(tools)} tools for workspace {workspace}"
)
logger.info(
f"Tool configuration reloaded for workspace {workspace}."
)
except Exception as e:
@@ -460,7 +464,6 @@ class Processor(AgentService):
await flow("explainability").send(Triples(
metadata=Metadata(
id=session_uri,
user=request.user,
collection=collection,
),
triples=triples,
@@ -557,35 +560,41 @@ class Processor(AgentService):
await respond(r)
# Look up the agent for this workspace
workspace = flow.workspace
agent = self.agents.get(workspace)
if agent is None:
logger.error(
f"No agent configuration loaded for workspace "
f"{workspace}"
)
raise RuntimeError(
f"No agent configuration for workspace {workspace}"
)
# Apply tool filtering based on request groups and state
filtered_tools = filter_tools_by_group_and_state(
tools=self.agent.tools,
tools=agent.tools,
requested_groups=getattr(request, 'group', None),
current_state=getattr(request, 'state', None)
)
# Create temporary agent with filtered tools
temp_agent = AgentManager(
tools=filtered_tools,
additional_context=self.agent.additional_context
additional_context=agent.additional_context
)
logger.debug("Call React")
# Create user-aware context wrapper that preserves the flow interface
# but adds user information for tools that need it
class UserAwareContext:
def __init__(self, flow, user):
# Thin wrapper around flow — carries only explain URI state.
class _Context:
def __init__(self, flow):
self._flow = flow
self._user = user
self.last_sub_explain_uri = None
def __call__(self, service_name):
client = self._flow(service_name)
# For query clients that need user context, store it
if service_name in ("structured-query-request", "row-embeddings-query-request"):
client._current_user = self._user
return client
return self._flow(service_name)
# Callback: emit Analysis+ToolUse triples before tool executes
async def on_action(act_decision):
@@ -604,7 +613,7 @@ class Processor(AgentService):
try:
await self.save_answer_content(
doc_id=t_doc_id,
user=request.user,
workspace=flow.workspace,
content=act_decision.thought,
title=f"Agent Thought: {act_decision.name}",
)
@@ -629,7 +638,6 @@ class Processor(AgentService):
await flow("explainability").send(Triples(
metadata=Metadata(
id=iter_uri,
user=request.user,
collection=collection,
),
triples=iter_triples,
@@ -644,7 +652,7 @@ class Processor(AgentService):
explain_triples=iter_triples,
))
user_context = UserAwareContext(flow, request.user)
user_context = _Context(flow)
act = await temp_agent.react(
question = request.question,
@@ -685,7 +693,7 @@ class Processor(AgentService):
try:
await self.save_answer_content(
doc_id=answer_doc_id,
user=request.user,
workspace=flow.workspace,
content=f,
title=f"Agent Answer: {request.question[:50]}...",
)
@@ -706,7 +714,6 @@ class Processor(AgentService):
await flow("explainability").send(Triples(
metadata=Metadata(
id=final_uri,
user=request.user,
collection=collection,
),
triples=final_triples,
@@ -763,7 +770,7 @@ class Processor(AgentService):
try:
await self.save_answer_content(
doc_id=observation_doc_id,
user=request.user,
workspace=flow.workspace,
content=act.observation,
title=f"Agent Observation",
)
@@ -783,7 +790,6 @@ class Processor(AgentService):
await flow("explainability").send(Triples(
metadata=Metadata(
id=observation_entity_uri,
user=request.user,
collection=collection,
),
triples=obs_triples,
@@ -820,7 +826,6 @@ class Processor(AgentService):
)
for h in history
],
user=request.user,
collection=collection,
streaming=streaming,
session_id=session_id, # Pass session_id for provenance continuity


@@ -116,31 +116,26 @@ class McpToolImpl:
# This tool implementation knows how to query structured data using natural language
class StructuredQueryImpl:
def __init__(self, context, collection=None, user=None):
def __init__(self, context, collection=None):
self.context = context
self.collection = collection # For multi-tenant scenarios
self.user = user # User context for multi-tenancy
self.collection = collection
@staticmethod
def get_arguments():
return [
Argument(
name="question",
type="string",
type="string",
description="Natural language question about structured data (tables, databases, etc.)"
)
]
async def invoke(self, **arguments):
client = self.context("structured-query-request")
logger.debug("Structured query question...")
# Get user from client context if available, otherwise use instance user or default
user = getattr(client, '_current_user', self.user or "trustgraph")
result = await client.structured_query(
question=arguments.get("question"),
user=user,
collection=self.collection or "default"
)
@@ -159,11 +154,10 @@ class StructuredQueryImpl:
# This tool implementation knows how to query row embeddings for semantic search
class RowEmbeddingsQueryImpl:
def __init__(self, context, schema_name, collection=None, user=None, index_name=None, limit=10):
def __init__(self, context, schema_name, collection=None, index_name=None, limit=10):
self.context = context
self.schema_name = schema_name
self.collection = collection
self.user = user
self.index_name = index_name # Optional: filter to specific index
self.limit = limit # Max results to return
@@ -190,13 +184,9 @@ class RowEmbeddingsQueryImpl:
client = self.context("row-embeddings-query-request")
logger.debug("Row embeddings query...")
# Get user from client context if available
user = getattr(client, '_current_user', self.user or "trustgraph")
matches = await client.row_embeddings_query(
vector=vector,
schema_name=self.schema_name,
user=user,
collection=self.collection or "default",
index_name=self.index_name,
limit=self.limit
@@ -250,7 +240,7 @@ class ToolServiceImpl:
Initialize a tool service implementation.
Args:
context: The context function (provides user info)
context: Flow context (callable resolving service names to clients)
request_queue: Full Pulsar topic for requests
response_queue: Full Pulsar topic for responses
config_values: Dict of config values (e.g., {"collection": "customers"})
@ -325,17 +315,10 @@ class ToolServiceImpl:
logger.debug(f"Config: {self.config_values}")
logger.debug(f"Arguments: {arguments}")
# Get user from context if available
user = "trustgraph"
if hasattr(self.context, '_user'):
user = self.context._user
# Get or create the client
client = await self._get_or_create_client()
# Call the tool service
response = await client.call(
user=user,
config=self.config_values,
arguments=arguments,
)


@ -95,7 +95,7 @@ class Processor(ChunkingService):
logger.info(f"Chunking document {v.metadata.id}...")
# Get text content (fetches from librarian if needed)
text = await self.get_document_text(v)
text = await self.get_document_text(v, flow.workspace)
# Extract chunk parameters from flow (allows runtime override)
chunk_size, chunk_overlap = await self.chunk_document(
@ -144,7 +144,7 @@ class Processor(ChunkingService):
await self.librarian.save_child_document(
doc_id=chunk_doc_id,
parent_id=parent_doc_id,
user=v.metadata.user,
workspace=flow.workspace,
content=chunk_content,
document_type="chunk",
title=f"Chunk {chunk_index}",
@ -168,7 +168,6 @@ class Processor(ChunkingService):
metadata=Metadata(
id=c_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
triples=set_graph(prov_triples, GRAPH_SOURCE),
@ -179,7 +178,6 @@ class Processor(ChunkingService):
metadata=Metadata(
id=c_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
chunk=chunk_content,


@ -92,7 +92,7 @@ class Processor(ChunkingService):
logger.info(f"Chunking document {v.metadata.id}...")
# Get text content (fetches from librarian if needed)
text = await self.get_document_text(v)
text = await self.get_document_text(v, flow.workspace)
# Extract chunk parameters from flow (allows runtime override)
chunk_size, chunk_overlap = await self.chunk_document(
@ -140,7 +140,7 @@ class Processor(ChunkingService):
await self.librarian.save_child_document(
doc_id=chunk_doc_id,
parent_id=parent_doc_id,
user=v.metadata.user,
workspace=flow.workspace,
content=chunk_content,
document_type="chunk",
title=f"Chunk {chunk_index}",
@ -164,7 +164,6 @@ class Processor(ChunkingService):
metadata=Metadata(
id=c_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
triples=set_graph(prov_triples, GRAPH_SOURCE),
@ -175,7 +174,6 @@ class Processor(ChunkingService):
metadata=Metadata(
id=c_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
chunk=chunk_content,


@ -9,42 +9,8 @@ from ... tables.config import ConfigTableStore
# Module logger
logger = logging.getLogger(__name__)
class ConfigurationClass:
async def keys(self):
return await self.table_store.get_keys(self.type)
async def values(self):
vals = await self.table_store.get_values(self.type)
return {
v[0]: v[1]
for v in vals
}
async def get(self, key):
return await self.table_store.get_value(self.type, key)
async def put(self, key, value):
return await self.table_store.put_config(self.type, key, value)
async def delete(self, key):
return await self.table_store.delete_key(self.type, key)
async def has(self, key):
val = await self.table_store.get_value(self.type, key)
return val is not None
class Configuration:
# FIXME: The state is held internally. This only works if there's
# one config service. Should be more than one, and use a
# back-end state store.
# FIXME: This has state now, but does it address all of the above?
# REVIEW: Above
# FIXME: Some version vs config race conditions
def __init__(self, push, host, username, password, keyspace):
# External function to respond to update
@ -60,34 +26,17 @@ class Configuration:
async def get_version(self):
return await self.table_store.get_version()
def get(self, type):
c = ConfigurationClass()
c.table_store = self.table_store
c.type = type
return c
async def handle_get(self, v):
# for k in v.keys:
# if k.type not in self or k.key not in self[k.type]:
# return ConfigResponse(
# version = None,
# values = None,
# directory = None,
# config = None,
# error = Error(
# type = "key-error",
# message = f"Key error"
# )
# )
workspace = v.workspace
values = [
ConfigValue(
type = k.type,
key = k.key,
value = await self.table_store.get_value(k.type, k.key)
value = await self.table_store.get_value(
workspace, k.type, k.key
)
)
for k in v.keys
]
@ -96,43 +45,19 @@ class Configuration:
version = await self.get_version(),
values = values,
)
async def handle_list(self, v):
# if v.type not in self:
# return ConfigResponse(
# version = None,
# values = None,
# directory = None,
# config = None,
# error = Error(
# type = "key-error",
# message = "No such type",
# ),
# )
return ConfigResponse(
version = await self.get_version(),
directory = await self.table_store.get_keys(v.type),
directory = await self.table_store.get_keys(
v.workspace, v.type
),
)
async def handle_getvalues(self, v):
# if v.type not in self:
# return ConfigResponse(
# version = None,
# values = None,
# directory = None,
# config = None,
# error = Error(
# type = "key-error",
# message = f"Key error"
# )
# )
vals = await self.table_store.get_values(v.type)
vals = await self.table_store.get_values(v.workspace, v.type)
values = map(
lambda x: ConfigValue(
@ -146,39 +71,63 @@ class Configuration:
values = list(values),
)
async def handle_getvalues_all_ws(self, v):
"""Fetch all values of a given type across all workspaces.
Used by shared processors to load type-scoped config at
startup without enumerating workspaces separately."""
vals = await self.table_store.get_values_all_ws(v.type)
values = [
ConfigValue(
workspace = row[0],
type = v.type,
key = row[1],
value = row[2],
)
for row in vals
]
return ConfigResponse(
version = await self.get_version(),
values = values,
)
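The handler above returns a flat list of (workspace, key, value) entries. As an illustrative sketch (names are hypothetical, not the actual client API), a shared processor consuming this response might group the flat rows into per-workspace maps at startup:

```python
# Hypothetical sketch: group the flat rows returned by
# getvalues-all-ws into {workspace: {key: value}}, as a shared
# processor might do at startup. Field names mirror the handler above.

def group_by_workspace(values):
    """Turn a flat list of config value dicts into a nested
    {workspace: {key: value}} structure."""
    grouped = {}
    for v in values:
        grouped.setdefault(v["workspace"], {})[v["key"]] = v["value"]
    return grouped

rows = [
    {"workspace": "acme", "type": "flow", "key": "default", "value": "{}"},
    {"workspace": "acme", "type": "flow", "key": "alt", "value": "{}"},
    {"workspace": "globex", "type": "flow", "key": "default", "value": "{}"},
]

grouped = group_by_workspace(rows)
assert set(grouped) == {"acme", "globex"}
```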
async def handle_delete(self, v):
workspace = v.workspace
types = list(set(k.type for k in v.keys))
for k in v.keys:
await self.table_store.delete_key(k.type, k.key)
await self.table_store.delete_key(workspace, k.type, k.key)
await self.inc_version()
await self.push(types=types)
await self.push(changes={t: [workspace] for t in types})
return ConfigResponse(
)
async def handle_put(self, v):
workspace = v.workspace
types = list(set(k.type for k in v.values))
for k in v.values:
await self.table_store.put_config(k.type, k.key, k.value)
await self.table_store.put_config(
workspace, k.type, k.key, k.value
)
await self.inc_version()
await self.push(types=types)
await self.push(changes={t: [workspace] for t in types})
return ConfigResponse(
)
async def get_config(self):
async def get_config(self, workspace):
table = await self.table_store.get_all()
table = await self.table_store.get_all_for_workspace(workspace)
config = {}
@ -191,7 +140,7 @@ class Configuration:
async def handle_config(self, v):
config = await self.get_config()
config = await self.get_config(v.workspace)
return ConfigResponse(
version = await self.get_version(),
@ -200,7 +149,20 @@ class Configuration:
async def handle(self, msg):
logger.debug(f"Handling config message: {msg.operation}")
logger.debug(
f"Handling config message: {msg.operation} "
f"workspace={msg.workspace}"
)
# getvalues-all-ws spans all workspaces, so no workspace
# required; everything else is workspace-scoped.
if msg.operation != "getvalues-all-ws" and not msg.workspace:
return ConfigResponse(
error=Error(
type = "bad-request",
message = "Workspace is required"
)
)
if msg.operation == "get":
@ -214,6 +176,10 @@ class Configuration:
resp = await self.handle_getvalues(msg)
elif msg.operation == "getvalues-all-ws":
resp = await self.handle_getvalues_all_ws(msg)
elif msg.operation == "delete":
resp = await self.handle_delete(msg)
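The dispatch above enforces the guard that every operation except `getvalues-all-ws` must carry a workspace. A minimal standalone sketch of that rule (names illustrative):

```python
# Minimal sketch of the workspace guard: every operation except
# getvalues-all-ws is workspace-scoped and must name one.

CROSS_WORKSPACE_OPS = {"getvalues-all-ws"}

def workspace_error(operation, workspace):
    """Return an error message if the request is missing a required
    workspace, or None if the request is acceptable."""
    if operation in CROSS_WORKSPACE_OPS:
        return None
    if not workspace:
        return "Workspace is required"
    return None

assert workspace_error("get", "acme") is None
```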


@ -128,18 +128,21 @@ class Processor(AsyncProcessor):
await self.push() # Startup poke: empty changes = everything
await self.config_request_consumer.start()
async def push(self, types=None):
async def push(self, changes=None):
version = await self.config.get_version()
resp = ConfigPush(
version = version,
types = types or [],
changes = changes or {},
)
await self.config_push_producer.send(resp)
logger.info(f"Pushed config poke version {version}, types={resp.types}")
logger.info(
f"Pushed config poke version {version}, "
f"changes={resp.changes}"
)
async def on_config_request(self, msg, consumer, flow):
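With the push payload now a `changes={type: [workspaces]}` dict rather than a flat type list, a consumer can reload only the affected (type, workspace) pairs. A hedged sketch of that consumer-side logic (illustrative names; an empty dict is the startup poke meaning "reload everything", per the comment above):

```python
# Hedged sketch: expand a ConfigPush changes dict into the
# (type, workspace) pairs this processor subscribes to.
# An empty/absent changes dict signals a full reload.

def pairs_to_reload(changes, subscribed_types):
    """Return the (type, workspace) pairs to reload, or None to
    signal a full reload (startup poke)."""
    if not changes:
        return None
    return [
        (t, ws)
        for t, workspaces in changes.items()
        if t in subscribed_types
        for ws in workspaces
    ]

assert pairs_to_reload({}, {"prompt"}) is None
```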


@ -33,7 +33,7 @@ class KnowledgeManager:
logger.info("Deleting knowledge core...")
await self.table_store.delete_kg_core(
request.user, request.id
request.workspace, request.id
)
await respond(
@ -63,7 +63,7 @@ class KnowledgeManager:
# Remove doc table row
await self.table_store.get_triples(
request.user,
request.workspace,
request.id,
publish_triples,
)
@ -81,7 +81,7 @@ class KnowledgeManager:
# Remove doc table row
await self.table_store.get_graph_embeddings(
request.user,
request.workspace,
request.id,
publish_ge,
)
@ -100,7 +100,7 @@ class KnowledgeManager:
async def list_kg_cores(self, request, respond):
ids = await self.table_store.list_kg_cores(request.user)
ids = await self.table_store.list_kg_cores(request.workspace)
await respond(
KnowledgeResponse(
@ -114,12 +114,14 @@ class KnowledgeManager:
async def put_kg_core(self, request, respond):
workspace = request.workspace
if request.triples:
await self.table_store.add_triples(request.triples)
await self.table_store.add_triples(workspace, request.triples)
if request.graph_embeddings:
await self.table_store.add_graph_embeddings(
request.graph_embeddings
workspace, request.graph_embeddings
)
await respond(
@ -178,10 +180,15 @@ class KnowledgeManager:
if request.flow is None:
raise RuntimeError("Flow ID must be specified")
if request.flow not in self.flow_config.flows:
raise RuntimeError("Invalid flow")
workspace = request.workspace
ws_flows = self.flow_config.flows.get(workspace, {})
if request.flow not in ws_flows:
raise RuntimeError(
f"Invalid flow {request.flow} for workspace "
f"{workspace}"
)
flow = self.flow_config.flows[request.flow]
flow = ws_flows[request.flow]
if "interfaces" not in flow:
raise RuntimeError("No defined interfaces")
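The validation above reflects the new nested shape of flow config: `{workspace: {flow_id: flow}}`, so a flow id is only valid within its workspace. A small self-contained sketch of the lookup (data and names are illustrative):

```python
# Illustrative sketch of the workspace-scoped flow lookup:
# flows are keyed {workspace: {flow_id: flow}}.

flows = {
    "acme": {"ingest": {"interfaces": {}}},
}

def resolve_flow(flows, workspace, flow_id):
    """Resolve a flow id within its workspace, failing if the flow
    does not exist there (even if it exists in another workspace)."""
    ws_flows = flows.get(workspace, {})
    if flow_id not in ws_flows:
        raise RuntimeError(
            f"Invalid flow {flow_id} for workspace {workspace}"
        )
    return ws_flows[flow_id]

assert resolve_flow(flows, "acme", "ingest") == {"interfaces": {}}
```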
@ -257,7 +264,7 @@ class KnowledgeManager:
# Remove doc table row
await self.table_store.get_triples(
request.user,
request.workspace,
request.id,
publish_triples,
)
@ -272,7 +279,7 @@ class KnowledgeManager:
# Remove doc table row
await self.table_store.get_graph_embeddings(
request.user,
request.workspace,
request.id,
publish_ge,
)


@ -124,19 +124,21 @@ class Processor(AsyncProcessor):
await self.knowledge_request_consumer.start()
await self.knowledge_response_producer.start()
async def on_knowledge_config(self, config, version):
async def on_knowledge_config(self, workspace, config, version):
logger.info(f"Configuration version: {version}")
logger.info(
f"Configuration version: {version} workspace: {workspace}"
)
if "flow" in config:
self.flows = {
self.flows[workspace] = {
k: json.loads(v)
for k, v in config["flow"].items()
}
else:
self.flows = {}
self.flows[workspace] = {}
logger.debug(f"Flows: {self.flows}")
logger.debug(f"Flows for {workspace}: {self.flows[workspace]}")
async def process_request(self, v, id):


@ -200,7 +200,7 @@ class Processor(FlowProcessor):
if v.document_id:
doc_meta = await self.librarian.fetch_document_metadata(
document_id=v.document_id,
user=v.metadata.user,
workspace=flow.workspace,
)
if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf":
logger.error(
@ -215,7 +215,7 @@ class Processor(FlowProcessor):
logger.info(f"Fetching document {v.document_id} from librarian...")
content = await self.librarian.fetch_document_content(
document_id=v.document_id,
user=v.metadata.user,
workspace=flow.workspace,
)
if isinstance(content, str):
content = content.encode('utf-8')
@ -243,7 +243,7 @@ class Processor(FlowProcessor):
await self.librarian.save_child_document(
doc_id=page_doc_id,
parent_id=source_doc_id,
user=v.metadata.user,
workspace=flow.workspace,
content=page_content,
document_type="page",
title=f"Page {page_num}",
@ -265,7 +265,6 @@ class Processor(FlowProcessor):
metadata=Metadata(
id=pg_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
triples=set_graph(prov_triples, GRAPH_SOURCE),
@ -277,7 +276,6 @@ class Processor(FlowProcessor):
metadata=Metadata(
id=pg_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
document_id=page_doc_id,


@ -93,7 +93,7 @@ class Processor(FlowProcessor):
if v.document_id:
doc_meta = await self.librarian.fetch_document_metadata(
document_id=v.document_id,
user=v.metadata.user,
workspace=flow.workspace,
)
if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf":
logger.error(
@ -114,7 +114,7 @@ class Processor(FlowProcessor):
content = await self.librarian.fetch_document_content(
document_id=v.document_id,
user=v.metadata.user,
workspace=flow.workspace,
)
# Content is base64 encoded
@ -157,7 +157,7 @@ class Processor(FlowProcessor):
await self.librarian.save_child_document(
doc_id=page_doc_id,
parent_id=source_doc_id,
user=v.metadata.user,
workspace=flow.workspace,
content=page_content,
document_type="page",
title=f"Page {page_num}",
@ -179,7 +179,6 @@ class Processor(FlowProcessor):
metadata=Metadata(
id=pg_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
triples=set_graph(prov_triples, GRAPH_SOURCE),
@ -191,7 +190,6 @@ class Processor(FlowProcessor):
metadata=Metadata(
id=pg_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
document_id=page_doc_id,


@ -6,9 +6,9 @@ import re
logger = logging.getLogger(__name__)
def make_safe_collection_name(user, collection, prefix):
def make_safe_collection_name(workspace, collection, prefix):
"""
Create a safe Milvus collection name from user/collection parameters.
Create a safe Milvus collection name from workspace/collection parameters.
Milvus only allows letters, numbers, and underscores.
"""
def sanitize(s):
@ -23,10 +23,10 @@ def make_safe_collection_name(user, collection, prefix):
safe = 'default'
return safe
safe_user = sanitize(user)
safe_workspace = sanitize(workspace)
safe_collection = sanitize(collection)
return f"{prefix}_{safe_user}_{safe_collection}"
return f"{prefix}_{safe_workspace}_{safe_collection}"
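The `sanitize()` body is elided in the diff, so the following is a hedged reconstruction of the renamed naming scheme, consistent with the docstring (Milvus allows only letters, digits, and underscores, with a `'default'` fallback) but not necessarily the exact implementation:

```python
import re

# Assumed reconstruction of the workspace-based naming scheme above;
# the regex and fallback are inferred from the docstring, not copied
# from the real sanitize() body.

def make_safe_collection_name(workspace, collection, prefix):
    """Build a Milvus-safe collection name of the form
    {prefix}_{workspace}_{collection}."""
    def sanitize(s):
        # Replace anything outside [A-Za-z0-9_] with underscores
        safe = re.sub(r'[^A-Za-z0-9_]', '_', s or '')
        return safe or 'default'
    return f"{prefix}_{sanitize(workspace)}_{sanitize(collection)}"

assert make_safe_collection_name("acme-corp", "docs", "doc") == "doc_acme_corp_docs"
```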
class DocVectors:
@ -49,26 +49,26 @@ class DocVectors:
self.next_reload = time.time() + self.reload_time
logger.debug(f"Reload at {self.next_reload}")
def collection_exists(self, user, collection):
def collection_exists(self, workspace, collection):
"""
Check if any collection exists for this user/collection combination.
Check if any collection exists for this workspace/collection combination.
Since collections are dimension-specific, this checks if ANY dimension variant exists.
"""
base_name = make_safe_collection_name(user, collection, self.prefix)
base_name = make_safe_collection_name(workspace, collection, self.prefix)
prefix = f"{base_name}_"
all_collections = self.client.list_collections()
return any(coll.startswith(prefix) for coll in all_collections)
def create_collection(self, user, collection, dimension=384):
def create_collection(self, workspace, collection, dimension=384):
"""
No-op for explicit collection creation.
Collections are created lazily on first insert with actual dimension.
"""
logger.info(f"Collection creation requested for {user}/{collection} - will be created lazily on first insert")
logger.info(f"Collection creation requested for {workspace}/{collection} - will be created lazily on first insert")
def init_collection(self, dimension, user, collection):
def init_collection(self, dimension, workspace, collection):
base_name = make_safe_collection_name(user, collection, self.prefix)
base_name = make_safe_collection_name(workspace, collection, self.prefix)
collection_name = f"{base_name}_{dimension}"
pkey_field = FieldSchema(
@ -116,15 +116,15 @@ class DocVectors:
index_params=index_params
)
self.collections[(dimension, user, collection)] = collection_name
self.collections[(dimension, workspace, collection)] = collection_name
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
def insert(self, embeds, chunk_id, user, collection):
def insert(self, embeds, chunk_id, workspace, collection):
dim = len(embeds)
if (dim, user, collection) not in self.collections:
self.init_collection(dim, user, collection)
if (dim, workspace, collection) not in self.collections:
self.init_collection(dim, workspace, collection)
data = [
{
@ -134,25 +134,25 @@ class DocVectors:
]
self.client.insert(
collection_name=self.collections[(dim, user, collection)],
collection_name=self.collections[(dim, workspace, collection)],
data=data
)
def search(self, embeds, user, collection, fields=["chunk_id"], limit=10):
def search(self, embeds, workspace, collection, fields=["chunk_id"], limit=10):
dim = len(embeds)
# Check if collection exists - return empty if not
if (dim, user, collection) not in self.collections:
base_name = make_safe_collection_name(user, collection, self.prefix)
if (dim, workspace, collection) not in self.collections:
base_name = make_safe_collection_name(workspace, collection, self.prefix)
collection_name = f"{base_name}_{dim}"
if not self.client.has_collection(collection_name):
logger.info(f"Collection {collection_name} does not exist, returning empty results")
return []
# Collection exists but not in cache, add it
self.collections[(dim, user, collection)] = collection_name
self.collections[(dim, workspace, collection)] = collection_name
coll = self.collections[(dim, user, collection)]
coll = self.collections[(dim, workspace, collection)]
logger.debug("Loading...")
self.client.load_collection(
@ -181,12 +181,12 @@ class DocVectors:
return res
def delete_collection(self, user, collection):
def delete_collection(self, workspace, collection):
"""
Delete all dimension variants of the collection for the given user/collection.
Delete all dimension variants of the collection for the given workspace/collection.
Since collections are created with dimension suffixes, we need to find and delete all.
"""
base_name = make_safe_collection_name(user, collection, self.prefix)
base_name = make_safe_collection_name(workspace, collection, self.prefix)
prefix = f"{base_name}_"
# Get all collections and filter for matches
@ -199,10 +199,10 @@ class DocVectors:
for collection_name in matching_collections:
self.client.drop_collection(collection_name)
logger.info(f"Deleted Milvus collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
# Remove from our local cache
keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection]
keys_to_remove = [key for key in self.collections.keys() if key[1] == workspace and key[2] == collection]
for key in keys_to_remove:
del self.collections[key]


@ -6,9 +6,9 @@ import re
logger = logging.getLogger(__name__)
def make_safe_collection_name(user, collection, prefix):
def make_safe_collection_name(workspace, collection, prefix):
"""
Create a safe Milvus collection name from user/collection parameters.
Create a safe Milvus collection name from workspace/collection parameters.
Milvus only allows letters, numbers, and underscores.
"""
def sanitize(s):
@ -23,10 +23,10 @@ def make_safe_collection_name(user, collection, prefix):
safe = 'default'
return safe
safe_user = sanitize(user)
safe_workspace = sanitize(workspace)
safe_collection = sanitize(collection)
return f"{prefix}_{safe_user}_{safe_collection}"
return f"{prefix}_{safe_workspace}_{safe_collection}"
class EntityVectors:
@ -49,26 +49,26 @@ class EntityVectors:
self.next_reload = time.time() + self.reload_time
logger.debug(f"Reload at {self.next_reload}")
def collection_exists(self, user, collection):
def collection_exists(self, workspace, collection):
"""
Check if any collection exists for this user/collection combination.
Check if any collection exists for this workspace/collection combination.
Since collections are dimension-specific, this checks if ANY dimension variant exists.
"""
base_name = make_safe_collection_name(user, collection, self.prefix)
base_name = make_safe_collection_name(workspace, collection, self.prefix)
prefix = f"{base_name}_"
all_collections = self.client.list_collections()
return any(coll.startswith(prefix) for coll in all_collections)
def create_collection(self, user, collection, dimension=384):
def create_collection(self, workspace, collection, dimension=384):
"""
No-op for explicit collection creation.
Collections are created lazily on first insert with actual dimension.
"""
logger.info(f"Collection creation requested for {user}/{collection} - will be created lazily on first insert")
logger.info(f"Collection creation requested for {workspace}/{collection} - will be created lazily on first insert")
def init_collection(self, dimension, user, collection):
def init_collection(self, dimension, workspace, collection):
base_name = make_safe_collection_name(user, collection, self.prefix)
base_name = make_safe_collection_name(workspace, collection, self.prefix)
collection_name = f"{base_name}_{dimension}"
pkey_field = FieldSchema(
@ -122,15 +122,15 @@ class EntityVectors:
index_params=index_params
)
self.collections[(dimension, user, collection)] = collection_name
self.collections[(dimension, workspace, collection)] = collection_name
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
def insert(self, embeds, entity, user, collection, chunk_id=""):
def insert(self, embeds, entity, workspace, collection, chunk_id=""):
dim = len(embeds)
if (dim, user, collection) not in self.collections:
self.init_collection(dim, user, collection)
if (dim, workspace, collection) not in self.collections:
self.init_collection(dim, workspace, collection)
data = [
{
@ -141,25 +141,25 @@ class EntityVectors:
]
self.client.insert(
collection_name=self.collections[(dim, user, collection)],
collection_name=self.collections[(dim, workspace, collection)],
data=data
)
def search(self, embeds, user, collection, fields=["entity"], limit=10):
def search(self, embeds, workspace, collection, fields=["entity"], limit=10):
dim = len(embeds)
# Check if collection exists - return empty if not
if (dim, user, collection) not in self.collections:
base_name = make_safe_collection_name(user, collection, self.prefix)
if (dim, workspace, collection) not in self.collections:
base_name = make_safe_collection_name(workspace, collection, self.prefix)
collection_name = f"{base_name}_{dim}"
if not self.client.has_collection(collection_name):
logger.info(f"Collection {collection_name} does not exist, returning empty results")
return []
# Collection exists but not in cache, add it
self.collections[(dim, user, collection)] = collection_name
self.collections[(dim, workspace, collection)] = collection_name
coll = self.collections[(dim, user, collection)]
coll = self.collections[(dim, workspace, collection)]
logger.debug("Loading...")
self.client.load_collection(
@ -188,12 +188,12 @@ class EntityVectors:
return res
def delete_collection(self, user, collection):
def delete_collection(self, workspace, collection):
"""
Delete all dimension variants of the collection for the given user/collection.
Delete all dimension variants of the collection for the given workspace/collection.
Since collections are created with dimension suffixes, we need to find and delete all.
"""
base_name = make_safe_collection_name(user, collection, self.prefix)
base_name = make_safe_collection_name(workspace, collection, self.prefix)
prefix = f"{base_name}_"
# Get all collections and filter for matches
@ -206,10 +206,10 @@ class EntityVectors:
for collection_name in matching_collections:
self.client.drop_collection(collection_name)
logger.info(f"Deleted Milvus collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
# Remove from our local cache
keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection]
keys_to_remove = [key for key in self.collections.keys() if key[1] == workspace and key[2] == collection]
for key in keys_to_remove:
del self.collections[key]


@ -69,19 +69,26 @@ class Processor(CollectionConfigHandler, FlowProcessor):
self.register_config_handler(self.on_schema_config, types=["schema"])
self.register_config_handler(self.on_collection_config, types=["collection"])
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# Per-workspace schema storage: {workspace: {name: RowSchema}}
self.schemas: Dict[str, Dict[str, RowSchema]] = {}
async def on_schema_config(self, config, version):
async def on_schema_config(self, workspace, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
logger.info(
f"Loading schema configuration version {version} "
f"for workspace {workspace}"
)
# Clear existing schemas
self.schemas = {}
# Replace existing schemas for this workspace
ws_schemas: Dict[str, RowSchema] = {}
self.schemas[workspace] = ws_schemas
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
logger.warning(
f"No '{self.config_key}' type in configuration "
f"for {workspace}"
)
return
# Get the schemas dictionary for our type
@ -115,13 +122,19 @@ class Processor(CollectionConfigHandler, FlowProcessor):
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
ws_schemas[schema_name] = row_schema
logger.info(
f"Loaded schema: {schema_name} with "
f"{len(fields)} fields for {workspace}"
)
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
logger.info(
f"Schema configuration loaded for {workspace}: "
f"{len(ws_schemas)} schemas"
)
def get_index_names(self, schema: RowSchema) -> List[str]:
"""Get all index names for a schema."""
@ -149,23 +162,29 @@ class Processor(CollectionConfigHandler, FlowProcessor):
"""Process incoming ExtractedObject and compute embeddings"""
obj = msg.value()
workspace = flow.workspace
logger.info(
f"Computing embeddings for {len(obj.values)} rows, "
f"schema {obj.schema_name}, doc {obj.metadata.id}"
f"schema {obj.schema_name}, doc {obj.metadata.id}, "
f"workspace {workspace}"
)
# Validate collection exists before processing
if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
if not self.collection_exists(workspace, obj.metadata.collection):
logger.warning(
f"Collection {obj.metadata.collection} for user {obj.metadata.user} "
f"Collection {obj.metadata.collection} for workspace {workspace} "
f"does not exist in config. Dropping message."
)
return
# Get schema definition
schema = self.schemas.get(obj.schema_name)
# Get schema definition for this workspace
ws_schemas = self.schemas.get(workspace, {})
schema = ws_schemas.get(obj.schema_name)
if not schema:
logger.warning(f"No schema found for {obj.schema_name} - skipping")
logger.warning(
f"No schema found for {obj.schema_name} in "
f"workspace {workspace} - skipping"
)
return
# Get all index names for this schema
@ -239,13 +258,13 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.error("Exception during embedding computation", exc_info=True)
raise e
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Collection creation notification - no action needed for embedding stage"""
logger.debug(f"Row embeddings collection notification for {user}/{collection}")
logger.debug(f"Row embeddings collection notification for {workspace}/{collection}")
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Collection deletion notification - no action needed for embedding stage"""
logger.debug(f"Row embeddings collection delete notification for {user}/{collection}")
logger.debug(f"Row embeddings collection delete notification for {workspace}/{collection}")
@staticmethod
def add_args(parser):


@ -75,24 +75,36 @@ class Processor(FlowProcessor):
)
)
# Null configuration, should reload quickly
self.manager = PromptManager()
# Per-workspace prompt managers
self.managers = {}
async def on_prompt_config(self, config, version):
async def on_prompt_config(self, workspace, config, version):
logger.info(f"Loading configuration version {version}")
logger.info(
f"Loading configuration version {version} "
f"for workspace {workspace}"
)
if self.config_key not in config:
logger.warning(f"No key {self.config_key} in config")
logger.warning(
f"No key {self.config_key} in config for {workspace}"
)
return
config = config[self.config_key]
prompt_config = config[self.config_key]
try:
self.manager.load_config(config)
manager = self.managers.get(workspace)
if manager is None:
manager = PromptManager()
self.managers[workspace] = manager
logger.info("Prompt configuration reloaded")
manager.load_config(prompt_config)
logger.info(
f"Prompt configuration reloaded for {workspace}"
)
except Exception as e:
@ -107,7 +119,6 @@ class Processor(FlowProcessor):
metadata = Metadata(
id = metadata.id,
root = metadata.root,
user = metadata.user,
collection = metadata.collection,
),
triples = triples,
@ -120,7 +131,6 @@ class Processor(FlowProcessor):
metadata = Metadata(
id = metadata.id,
root = metadata.root,
user = metadata.user,
collection = metadata.collection,
),
entities = entity_contexts,
@ -170,13 +180,24 @@ class Processor(FlowProcessor):
try:
v = msg.value()
workspace = flow.workspace
# Extract chunk text
chunk_text = v.chunk.decode('utf-8')
logger.debug("Processing chunk for agent extraction")
logger.debug(
f"Processing chunk for agent extraction, "
f"workspace {workspace}"
)
prompt = self.manager.render(
manager = self.managers.get(workspace)
if manager is None:
logger.error(
f"No prompt configuration for workspace {workspace}"
)
return
prompt = manager.render(
self.template_id,
{
"text": chunk_text


@ -213,7 +213,6 @@ class Processor(FlowProcessor):
Metadata(
id=v.metadata.id,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
batch
@ -227,7 +226,6 @@ class Processor(FlowProcessor):
Metadata(
id=v.metadata.id,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
batch


@ -109,20 +109,22 @@ class Processor(FlowProcessor):
# Register config handler for ontology updates
self.register_config_handler(self.on_ontology_config, types=["ontology"])
# Shared components (not flow-specific)
self.ontology_loader = OntologyLoader()
# Per-workspace ontology loaders
self.ontology_loaders = {} # workspace -> OntologyLoader
self.text_processor = TextProcessor()
# Per-flow components (each flow gets its own embedder/vector store/selector)
self.flow_components = {} # flow_id -> {embedder, vector_store, selector}
# Per-flow components (each flow gets its own embedder/vector
# store/selector). Keyed by id(flow) — Flow objects are unique
# per (workspace, flow), so this is implicitly workspace-scoped.
self.flow_components = {}
# Configuration
self.top_k = params.get("top_k", 10)
self.similarity_threshold = params.get("similarity_threshold", 0.3)
# Track loaded ontology version
self.current_ontology_version = None
self.loaded_ontology_ids = set()
# Per-workspace ontology version tracking
self.current_ontology_versions = {} # workspace -> version
self.loaded_ontology_ids = {} # workspace -> set of ids
async def initialize_flow_components(self, flow):
"""Initialize per-flow OntoRAG components.
@@ -167,17 +169,23 @@ class Processor(FlowProcessor):
vector_store=vector_store
)
# Embed all loaded ontologies for this flow
if self.ontology_loader.get_all_ontologies():
logger.info(f"Embedding ontologies for flow {flow_id}")
for ont_id, ontology in self.ontology_loader.get_all_ontologies().items():
workspace = flow.workspace
# Embed all loaded ontologies for this workspace
loader = self.ontology_loaders.get(workspace)
if loader is not None and loader.get_all_ontologies():
logger.info(
f"Embedding ontologies for flow {flow_id} "
f"(workspace {workspace})"
)
for ont_id, ontology in loader.get_all_ontologies().items():
await ontology_embedder.embed_ontology(ontology)
logger.info(f"Embedded {ontology_embedder.get_embedded_count()} ontology elements for flow {flow_id}")
# Initialize ontology selector
ontology_selector = OntologySelector(
ontology_embedder=ontology_embedder,
ontology_loader=self.ontology_loader,
ontology_loader=loader,
top_k=self.top_k,
similarity_threshold=self.similarity_threshold
)
@@ -187,7 +195,8 @@ class Processor(FlowProcessor):
'embedder': ontology_embedder,
'vector_store': vector_store,
'selector': ontology_selector,
'dimension': dimension
'dimension': dimension,
'workspace': workspace,
}
logger.info(f"Flow {flow_id} components initialized successfully (dimension={dimension})")
@@ -197,31 +206,27 @@ class Processor(FlowProcessor):
logger.error(f"Failed to initialize flow {flow_id} components: {e}", exc_info=True)
raise
async def on_ontology_config(self, config, version):
"""
Handle ontology configuration updates from ConfigPush queue.
Parses and stores ontologies. Embedding happens per-flow on first message.
Called automatically when:
- Processor starts (gets full config history via start_of_messages=True)
- Config service pushes updates (immediate event-driven notification)
Args:
config: Full configuration map - config[type][key] = value
version: Config version number (monotonically increasing)
"""
async def on_ontology_config(self, workspace, config, version):
"""Handle ontology configuration updates for a workspace."""
try:
logger.info(f"Received ontology config update, version={version}")
logger.info(
f"Received ontology config update, "
f"version={version} workspace={workspace}"
)
# Skip if we've already processed this version
if version == self.current_ontology_version:
logger.debug(f"Already at version {version}, skipping")
# Skip if we've already processed this version for this workspace
if version == self.current_ontology_versions.get(workspace):
logger.debug(
f"Already at version {version} for {workspace}, "
f"skipping"
)
return
# Extract ontology configurations
if "ontology" not in config:
logger.warning("No 'ontology' section in config")
logger.warning(
f"No 'ontology' section in config for {workspace}"
)
return
ontology_configs = config["ontology"]
@@ -235,38 +240,65 @@ class Processor(FlowProcessor):
logger.error(f"Failed to parse ontology '{ont_id}': {e}")
continue
logger.info(f"Loaded {len(ontologies)} ontology definitions")
logger.info(
f"Loaded {len(ontologies)} ontology definitions "
f"for {workspace}"
)
# Determine what changed (for incremental updates)
# Determine what changed for this workspace
ws_loaded_ids = self.loaded_ontology_ids.get(workspace, set())
new_ids = set(ontologies.keys())
added_ids = new_ids - self.loaded_ontology_ids
removed_ids = self.loaded_ontology_ids - new_ids
updated_ids = new_ids & self.loaded_ontology_ids # May have changed content
added_ids = new_ids - ws_loaded_ids
removed_ids = ws_loaded_ids - new_ids
updated_ids = new_ids & ws_loaded_ids # May have changed content
if added_ids:
logger.info(f"New ontologies: {added_ids}")
logger.info(f"New ontologies in {workspace}: {added_ids}")
if removed_ids:
logger.info(f"Removed ontologies: {removed_ids}")
logger.info(f"Removed ontologies in {workspace}: {removed_ids}")
if updated_ids:
logger.info(f"Updated ontologies: {updated_ids}")
logger.info(f"Updated ontologies in {workspace}: {updated_ids}")
# Update ontology loader's internal state
self.ontology_loader.update_ontologies(ontologies)
# Get or create per-workspace loader
loader = self.ontology_loaders.get(workspace)
if loader is None:
loader = OntologyLoader()
self.ontology_loaders[workspace] = loader
loader.update_ontologies(ontologies)
# Clear all flow components to force re-embedding with new ontologies
# Clear flow components for this workspace to force
# re-embedding with new ontologies.
if added_ids or removed_ids or updated_ids:
logger.info("Clearing flow components to trigger re-embedding")
self.flow_components.clear()
self._clear_workspace_flow_components(workspace)
# Update tracking
self.current_ontology_version = version
self.loaded_ontology_ids = new_ids
self.current_ontology_versions[workspace] = version
self.loaded_ontology_ids[workspace] = new_ids
logger.info(f"Ontology config update complete, version={version}")
logger.info(
f"Ontology config update complete for {workspace}, "
f"version={version}"
)
except Exception as e:
logger.error(f"Failed to process ontology config: {e}", exc_info=True)
def _clear_workspace_flow_components(self, workspace):
"""Drop cached flow components belonging to the given workspace
so they're re-initialised on next message with fresh ontology
embeddings."""
to_remove = [
fid for fid, comp in self.flow_components.items()
if comp.get("workspace") == workspace
]
if to_remove:
logger.info(
f"Clearing {len(to_remove)} flow components for "
f"workspace {workspace}"
)
for fid in to_remove:
del self.flow_components[fid]
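The version check at the top of `on_ontology_config` generalises to a small per-workspace guard; this sketch (simplified, with hypothetical names) shows why a replayed push for one workspace cannot mask a pending update for another:

```python
# Sketch: per-workspace config-version guard, assuming versions are
# monotonically increasing per workspace.
class VersionedStore:
    def __init__(self):
        self.versions = {}  # workspace -> last applied version
        self.applied = []   # (workspace, version, payload) actually applied

    def apply(self, workspace, version, payload):
        if version == self.versions.get(workspace):
            return False  # already at this version for this workspace
        self.applied.append((workspace, version, payload))
        self.versions[workspace] = version
        return True

store = VersionedStore()
assert store.apply("acme", 1, {"ontology": "v1"})
assert not store.apply("acme", 1, {"ontology": "v1"})  # duplicate skipped
assert store.apply("beta", 1, {"ontology": "v1"})      # independent workspace
assert store.apply("acme", 2, {"ontology": "v2"})
```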
async def on_message(self, msg, consumer, flow):
"""Process incoming chunk message."""
v = msg.value()
@@ -624,7 +656,6 @@ class Processor(FlowProcessor):
metadata=Metadata(
id=metadata.id,
root=metadata.root,
user=metadata.user,
collection=metadata.collection,
),
triples=triples,
@@ -637,7 +668,6 @@ class Processor(FlowProcessor):
metadata=Metadata(
id=metadata.id,
root=metadata.root,
user=metadata.user,
collection=metadata.collection,
),
entities=entities,

View file

@@ -207,7 +207,6 @@ class Processor(FlowProcessor):
Metadata(
id=v.metadata.id,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
batch

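After these hunks the metadata record carries no per-user field at all; tenancy is resolved from `flow.workspace` server-side. A sketch of the slimmed record (the exact field set is assumed from the hunks above, not taken from the real schema module):

```python
from dataclasses import dataclass

# Sketch of the metadata record once the user field is dropped.
@dataclass
class Metadata:
    id: str
    root: str
    collection: str

m = Metadata(id="doc:1", root="doc", collection="default")
assert not hasattr(m, "user")  # tenancy no longer rides on the message
```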
View file

@@ -84,32 +84,39 @@ class Processor(FlowProcessor):
# Register config handler for schema updates
self.register_config_handler(self.on_schema_config, types=["schema"])
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# Per-workspace schema storage: {workspace: {name: RowSchema}}
self.schemas: Dict[str, Dict[str, RowSchema]] = {}
async def on_schema_config(self, config, version):
async def on_schema_config(self, workspace, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
logger.info(
f"Loading schema configuration version {version} "
f"for workspace {workspace}"
)
# Clear existing schemas
self.schemas = {}
# Replace existing schemas for this workspace
ws_schemas: Dict[str, RowSchema] = {}
self.schemas[workspace] = ws_schemas
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
logger.warning(
f"No '{self.config_key}' type in configuration "
f"for {workspace}"
)
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
@@ -124,21 +131,27 @@ class Processor(FlowProcessor):
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
ws_schemas[schema_name] = row_schema
logger.info(
f"Loaded schema: {schema_name} with "
f"{len(fields)} fields for {workspace}"
)
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
logger.info(
f"Schema configuration loaded for {workspace}: "
f"{len(ws_schemas)} schemas"
)
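The replace-not-merge semantics of `on_schema_config` can be shown in miniature. This is a simplified stand-in (plain dicts in place of `RowSchema`): a config push for one workspace swaps out only that workspace's schema set.

```python
# Sketch: per-workspace schema storage where a config push replaces
# the whole schema set for that workspace only.
class SchemaStore:
    def __init__(self):
        self.schemas = {}  # {workspace: {name: schema}}

    def load(self, workspace, schemas_config):
        ws_schemas = {}
        self.schemas[workspace] = ws_schemas  # replace, don't merge
        for name, definition in schemas_config.items():
            ws_schemas[name] = definition

    def for_workspace(self, workspace):
        return self.schemas.get(workspace, {})

store = SchemaStore()
store.load("acme", {"person": {"fields": ["name"]}})
store.load("beta", {"invoice": {"fields": ["total"]}})
store.load("acme", {"org": {"fields": ["name"]}})  # replaces acme's set
assert list(store.for_workspace("acme")) == ["org"]
assert list(store.for_workspace("beta")) == ["invoice"]
```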
async def extract_objects_for_schema(self, text: str, schema_name: str, schema: RowSchema, flow) -> List[Dict[str, Any]]:
"""Extract objects from text for a specific schema"""
@@ -234,18 +247,26 @@ class Processor(FlowProcessor):
"""Process incoming chunk and extract objects"""
v = msg.value()
logger.info(f"Extracting objects from chunk {v.metadata.id}...")
workspace = flow.workspace
logger.info(
f"Extracting objects from chunk {v.metadata.id} "
f"(workspace {workspace})..."
)
chunk_text = v.chunk.decode("utf-8")
# If no schemas configured, log warning and return
if not self.schemas:
logger.warning("No schemas configured - skipping extraction")
# If no schemas configured for this workspace, log and return
ws_schemas = self.schemas.get(workspace, {})
if not ws_schemas:
logger.warning(
f"No schemas configured for workspace {workspace} "
f"- skipping extraction"
)
return
try:
# Extract objects for each configured schema
for schema_name, schema in self.schemas.items():
for schema_name, schema in ws_schemas.items():
logger.debug(f"Extracting {schema_name} objects from chunk")
@@ -274,7 +295,6 @@ class Processor(FlowProcessor):
metadata=Metadata(
id=f"{v.metadata.id}:{schema_name}",
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
schema_name=schema_name,

View file

@@ -17,14 +17,18 @@ class FlowConfig:
self.config = config
self.pubsub = pubsub
# Cache for parameter type definitions to avoid repeated lookups
# Per-workspace cache for parameter type definitions
# Keyed by (workspace, type-name)
self.param_type_cache = {}
async def resolve_parameters(self, flow_blueprint, user_params):
async def resolve_parameters(
self, workspace, flow_blueprint, user_params
):
"""
Resolve parameters by merging user-provided values with defaults.
Args:
workspace: Workspace containing the parameter-type definitions
flow_blueprint: The flow blueprint definition dict
user_params: User-provided parameters dict (may be None or empty)
@@ -55,24 +59,25 @@ class FlowConfig:
# Look up the parameter type definition
param_type = param_meta.get("type")
if param_type:
cache_key = (workspace, param_type)
# Check cache first
if param_type not in self.param_type_cache:
if cache_key not in self.param_type_cache:
try:
# Fetch parameter type definition from config store
type_def = await self.config.get(
"parameter-type", param_type
workspace, "parameter-type", param_type
)
if type_def:
self.param_type_cache[param_type] = json.loads(type_def)
self.param_type_cache[cache_key] = json.loads(type_def)
else:
logger.warning(f"Parameter type '{param_type}' not found in config")
self.param_type_cache[param_type] = {}
self.param_type_cache[cache_key] = {}
except Exception as e:
logger.error(f"Error fetching parameter type '{param_type}': {e}")
self.param_type_cache[param_type] = {}
self.param_type_cache[cache_key] = {}
# Apply default from type definition (as string)
type_def = self.param_type_cache[param_type]
type_def = self.param_type_cache[cache_key]
if "default" in type_def:
default_value = type_def["default"]
# Convert to string based on type
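The `(workspace, type-name)` cache key shown above matters because the same type name can legitimately carry different defaults per tenant. A standalone sketch (the fetch callable and type names are illustrative only):

```python
# Sketch of a parameter-type cache keyed by (workspace, type_name):
# one tenant's definition can never leak into another's defaults.
class ParamTypeCache:
    def __init__(self, fetch):
        self.fetch = fetch  # callable (workspace, type_name) -> dict
        self.cache = {}     # (workspace, type_name) -> definition
        self.misses = 0

    def get(self, workspace, type_name):
        key = (workspace, type_name)
        if key not in self.cache:
            self.misses += 1
            self.cache[key] = self.fetch(workspace, type_name)
        return self.cache[key]

defs = {
    ("acme", "chunk-size"): {"default": "512"},
    ("beta", "chunk-size"): {"default": "2048"},
}
cache = ParamTypeCache(lambda ws, t: defs.get((ws, t), {}))
assert cache.get("acme", "chunk-size")["default"] == "512"
assert cache.get("beta", "chunk-size")["default"] == "2048"
assert cache.get("acme", "chunk-size")["default"] == "512"  # cached
assert cache.misses == 2
```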
@@ -94,8 +99,9 @@ class FlowConfig:
else:
# Controller has no value, try to get default from type definition
param_type = param_meta.get("type")
if param_type and param_type in self.param_type_cache:
type_def = self.param_type_cache[param_type]
cache_key = (workspace, param_type) if param_type else None
if cache_key and cache_key in self.param_type_cache:
type_def = self.param_type_cache[cache_key]
if "default" in type_def:
default_value = type_def["default"]
# Convert to string based on type
@@ -114,7 +120,9 @@ class FlowConfig:
async def handle_list_blueprints(self, msg):
names = list(await self.config.keys("flow-blueprint"))
names = list(await self.config.keys(
msg.workspace, "flow-blueprint"
))
return FlowResponse(
error = None,
@@ -126,14 +134,14 @@ class FlowConfig:
return FlowResponse(
error = None,
blueprint_definition = await self.config.get(
"flow-blueprint", msg.blueprint_name
msg.workspace, "flow-blueprint", msg.blueprint_name
),
)
async def handle_put_blueprint(self, msg):
await self.config.put(
"flow-blueprint",
msg.workspace, "flow-blueprint",
msg.blueprint_name, msg.blueprint_definition
)
@@ -145,7 +153,9 @@ class FlowConfig:
logger.debug(f"Flow config message: {msg}")
await self.config.delete("flow-blueprint", msg.blueprint_name)
await self.config.delete(
msg.workspace, "flow-blueprint", msg.blueprint_name
)
return FlowResponse(
error = None,
@@ -153,7 +163,7 @@ class FlowConfig:
async def handle_list_flows(self, msg):
names = list(await self.config.keys("flow"))
names = list(await self.config.keys(msg.workspace, "flow"))
return FlowResponse(
error = None,
@@ -162,7 +172,9 @@ class FlowConfig:
async def handle_get_flow(self, msg):
flow_data = await self.config.get("flow", msg.flow_id)
flow_data = await self.config.get(
msg.workspace, "flow", msg.flow_id
)
flow = json.loads(flow_data)
return FlowResponse(
@@ -174,37 +186,49 @@ class FlowConfig:
async def handle_start_flow(self, msg):
workspace = msg.workspace
if msg.blueprint_name is None:
raise RuntimeError("No blueprint name")
if msg.flow_id is None:
raise RuntimeError("No flow ID")
if msg.flow_id in await self.config.keys("flow"):
if msg.flow_id in await self.config.keys(workspace, "flow"):
raise RuntimeError("Flow already exists")
if msg.description is None:
raise RuntimeError("No description")
if msg.blueprint_name not in await self.config.keys("flow-blueprint"):
if msg.blueprint_name not in await self.config.keys(
workspace, "flow-blueprint"
):
raise RuntimeError("Blueprint does not exist")
cls = json.loads(
await self.config.get("flow-blueprint", msg.blueprint_name)
await self.config.get(
workspace, "flow-blueprint", msg.blueprint_name
)
)
# Resolve parameters by merging user-provided values with defaults
user_params = msg.parameters if msg.parameters else {}
parameters = await self.resolve_parameters(cls, user_params)
parameters = await self.resolve_parameters(
workspace, cls, user_params
)
# Log the resolved parameters for debugging
logger.debug(f"User provided parameters: {user_params}")
logger.debug(f"Resolved parameters (with defaults): {parameters}")
# Apply parameter substitution to template replacement function
# Apply parameter substitution to template replacement function.
# {workspace} is substituted from msg.workspace to isolate
# queue names across workspaces.
def repl_template_with_params(tmp):
result = tmp.replace(
"{workspace}", workspace
).replace(
"{blueprint}", msg.blueprint_name
).replace(
"{id}", msg.flow_id
@@ -253,7 +277,7 @@ class FlowConfig:
json.dumps(entry),
))
await self.config.put_many(updates)
await self.config.put_many(workspace, updates)
def repl_interface(i):
return {
@@ -270,7 +294,7 @@ class FlowConfig:
interfaces = {}
await self.config.put(
"flow", msg.flow_id,
workspace, "flow", msg.flow_id,
json.dumps({
"description": msg.description,
"blueprint-name": msg.blueprint_name,
@@ -283,68 +307,77 @@ class FlowConfig:
error = None,
)
async def ensure_existing_flow_topics(self):
"""Ensure topics exist for all already-running flows.
async def ensure_existing_flow_topics(self, workspaces):
"""Ensure topics exist for all already-running flows across
the given workspaces.
Called on startup to handle flows that were started before this
version of the flow service was deployed, or before a restart.
"""
flow_ids = await self.config.keys("flow")
for workspace in workspaces:
flow_ids = await self.config.keys(workspace, "flow")
for flow_id in flow_ids:
try:
flow_data = await self.config.get("flow", flow_id)
if flow_data is None:
continue
flow = json.loads(flow_data)
blueprint_name = flow.get("blueprint-name")
if blueprint_name is None:
continue
# Skip flows that are mid-shutdown
if flow.get("status") == "stopping":
continue
parameters = flow.get("parameters", {})
blueprint_data = await self.config.get(
"flow-blueprint", blueprint_name
)
if blueprint_data is None:
logger.warning(
f"Blueprint '{blueprint_name}' not found for "
f"flow '{flow_id}', skipping topic creation"
for flow_id in flow_ids:
try:
flow_data = await self.config.get(
workspace, "flow", flow_id
)
continue
if flow_data is None:
continue
cls = json.loads(blueprint_data)
flow = json.loads(flow_data)
def repl_template(tmp):
result = tmp.replace(
"{blueprint}", blueprint_name
).replace(
"{id}", flow_id
blueprint_name = flow.get("blueprint-name")
if blueprint_name is None:
continue
# Skip flows that are mid-shutdown
if flow.get("status") == "stopping":
continue
parameters = flow.get("parameters", {})
blueprint_data = await self.config.get(
workspace, "flow-blueprint", blueprint_name
)
for param_name, param_value in parameters.items():
result = result.replace(
f"{{{param_name}}}", str(param_value)
if blueprint_data is None:
logger.warning(
f"Blueprint '{blueprint_name}' not found "
f"for flow '{workspace}/{flow_id}', skipping "
f"topic creation"
)
return result
continue
topics = self._collect_flow_topics(cls, repl_template)
for topic in topics:
await self.pubsub.ensure_topic(topic)
cls = json.loads(blueprint_data)
logger.info(
f"Ensured topics for existing flow '{flow_id}'"
)
def repl_template(tmp):
result = tmp.replace(
"{workspace}", workspace
).replace(
"{blueprint}", blueprint_name
).replace(
"{id}", flow_id
)
for param_name, param_value in parameters.items():
result = result.replace(
f"{{{param_name}}}", str(param_value)
)
return result
except Exception as e:
logger.error(
f"Failed to ensure topics for flow '{flow_id}': {e}"
)
topics = self._collect_flow_topics(cls, repl_template)
for topic in topics:
await self.pubsub.ensure_topic(topic)
logger.info(
f"Ensured topics for existing flow "
f"'{workspace}/{flow_id}'"
)
except Exception as e:
logger.error(
f"Failed to ensure topics for flow "
f"'{workspace}/{flow_id}': {e}"
)
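The templating used by `repl_template` above reduces to plain string substitution over `{workspace}`, `{blueprint}`, `{id}`, and the flow parameters. A self-contained sketch (the topic template shown is illustrative, not one of the project's blueprints):

```python
# Sketch of queue-name templating: substituting {workspace} first
# is what isolates topic names across tenants.
def expand_template(template, workspace, blueprint, flow_id, parameters):
    result = (
        template
        .replace("{workspace}", workspace)
        .replace("{blueprint}", blueprint)
        .replace("{id}", flow_id)
    )
    for name, value in parameters.items():
        result = result.replace("{" + name + "}", str(value))
    return result

topic = expand_template(
    "{workspace}.{blueprint}.{id}.chunks",
    workspace="acme", blueprint="graph-rag", flow_id="f1",
    parameters={"model": "small"},
)
assert topic == "acme.graph-rag.f1.chunks"
```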
def _collect_flow_topics(self, cls, repl_template):
"""Collect unique topic identifiers from the blueprint.
@@ -393,79 +426,95 @@ class FlowConfig:
return topics
async def _live_owned_topic_closure(self, exclude_flow_id=None):
"""Union of flow-owned topics referenced by all live flows.
async def _live_owned_topic_closure(
self, exclude_workspace=None, exclude_flow_id=None,
):
"""Union of flow-owned topics referenced by all live flows,
across every workspace.
Walks every flow record currently registered in the config
service (except ``exclude_flow_id``, typically the flow being
torn down), resolves its blueprint + parameter templates, and
collects the set of flow-owned topics those templates produce.
service (except the single ``(exclude_workspace, exclude_flow_id)``
pair, typically the flow being torn down), resolves its
blueprint + parameter templates, and collects the set of
flow-owned topics those templates produce.
Used to drive closure-based topic cleanup on flow stop: a
topic may only be deleted if no remaining live flow would
still template to it. This handles all three scoping cases
transparently: ``{id}`` topics have no other references once
their flow is excluded; ``{blueprint}`` topics stay alive
while another flow of the same blueprint exists; ``{workspace}``
(when introduced) stays alive while any flow in the workspace
exists.
topic may only be deleted if no remaining live flow (in any
workspace) would still template to it. This handles all
scoping cases transparently: ``{id}`` topics have no other
references once their flow is excluded; ``{blueprint}`` topics
stay alive while another flow of the same blueprint exists;
``{workspace}`` topics stay alive while any flow in the same
workspace remains.
"""
live = set()
flow_ids = await self.config.keys("flow")
workspaces = await self.config.workspaces_for_type("flow")
for fid in flow_ids:
for ws in workspaces:
if fid == exclude_flow_id:
continue
flow_ids = await self.config.keys(ws, "flow")
try:
frec_raw = await self.config.get("flow", fid)
if frec_raw is None:
for fid in flow_ids:
if ws == exclude_workspace and fid == exclude_flow_id:
continue
frec = json.loads(frec_raw)
except Exception as e:
logger.warning(
f"Closure sweep: skipping flow {fid}: {e}"
)
continue
# Flows mid-shutdown don't keep their topics alive.
if frec.get("status") == "stopping":
continue
bp_name = frec.get("blueprint-name")
if bp_name is None:
continue
try:
bp_raw = await self.config.get("flow-blueprint", bp_name)
if bp_raw is None:
continue
bp = json.loads(bp_raw)
except Exception as e:
logger.warning(
f"Closure sweep: skipping flow {fid} "
f"(blueprint {bp_name}): {e}"
)
continue
parameters = frec.get("parameters", {})
def repl(tmp, bp_name=bp_name, fid=fid, parameters=parameters):
result = tmp.replace(
"{blueprint}", bp_name
).replace(
"{id}", fid
)
for pname, pvalue in parameters.items():
result = result.replace(
f"{{{pname}}}", str(pvalue)
try:
frec_raw = await self.config.get(ws, "flow", fid)
if frec_raw is None:
continue
frec = json.loads(frec_raw)
except Exception as e:
logger.warning(
f"Closure sweep: skipping flow {ws}/{fid}: {e}"
)
return result
continue
live.update(self._collect_owned_topics(bp, repl))
# Flows mid-shutdown don't keep their topics alive.
if frec.get("status") == "stopping":
continue
bp_name = frec.get("blueprint-name")
if bp_name is None:
continue
try:
bp_raw = await self.config.get(
ws, "flow-blueprint", bp_name
)
if bp_raw is None:
continue
bp = json.loads(bp_raw)
except Exception as e:
logger.warning(
f"Closure sweep: skipping flow {ws}/{fid} "
f"(blueprint {bp_name}): {e}"
)
continue
parameters = frec.get("parameters", {})
def repl(
tmp,
ws=ws, bp_name=bp_name, fid=fid,
parameters=parameters,
):
result = tmp.replace(
"{workspace}", ws
).replace(
"{blueprint}", bp_name
).replace(
"{id}", fid
)
for pname, pvalue in parameters.items():
result = result.replace(
f"{{{pname}}}", str(pvalue)
)
return result
live.update(self._collect_owned_topics(bp, repl))
return live
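The closure computation above boils down to a set difference: the stopping flow's owned topics minus the union of every other live flow's owned topics. A simplified sketch with stand-in flow records (the real code walks config-service records and blueprints instead):

```python
# Sketch of closure-based topic cleanup: on stop, delete only topics
# the stopping flow owns that no other live flow still templates to.
def owned_topics(flow):
    return {
        t.replace("{workspace}", flow["ws"])
         .replace("{blueprint}", flow["bp"])
         .replace("{id}", flow["id"])
        for t in flow["templates"]
    }

def topics_to_delete(stopping, live_flows):
    live = set()
    for f in live_flows:
        if (f["ws"], f["id"]) == (stopping["ws"], stopping["id"]):
            continue  # exclude the flow being torn down
        live |= owned_topics(f)
    return owned_topics(stopping) - live

templates = ["{workspace}.{blueprint}.shared", "{workspace}.{id}.private"]
f1 = {"ws": "acme", "bp": "rag", "id": "f1", "templates": templates}
f2 = {"ws": "acme", "bp": "rag", "id": "f2", "templates": templates}
# f2 still references the {blueprint}-scoped topic, so only f1's
# {id}-scoped topic is deleted.
assert topics_to_delete(f1, [f1, f2]) == {"acme.f1.private"}
# Once f1 is the last flow, both of its topics go.
assert topics_to_delete(f1, [f1]) == {"acme.rag.shared", "acme.f1.private"}
```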
@@ -501,13 +550,17 @@ class FlowConfig:
async def handle_stop_flow(self, msg):
workspace = msg.workspace
if msg.flow_id is None:
raise RuntimeError("No flow ID")
if msg.flow_id not in await self.config.keys("flow"):
if msg.flow_id not in await self.config.keys(workspace, "flow"):
raise RuntimeError("Flow ID invalid")
flow = json.loads(await self.config.get("flow", msg.flow_id))
flow = json.loads(
await self.config.get(workspace, "flow", msg.flow_id)
)
if "blueprint-name" not in flow:
raise RuntimeError("Internal error: flow has no flow blueprint")
@@ -516,11 +569,15 @@ class FlowConfig:
parameters = flow.get("parameters", {})
cls = json.loads(
await self.config.get("flow-blueprint", blueprint_name)
await self.config.get(
workspace, "flow-blueprint", blueprint_name
)
)
def repl_template(tmp):
result = tmp.replace(
"{workspace}", workspace
).replace(
"{blueprint}", blueprint_name
).replace(
"{id}", msg.flow_id
@@ -539,7 +596,7 @@ class FlowConfig:
# The config push tells processors to shut down their consumers.
flow["status"] = "stopping"
await self.config.put(
"flow", msg.flow_id, json.dumps(flow)
workspace, "flow", msg.flow_id, json.dumps(flow)
)
# Delete all processor config entries for this flow.
@@ -552,7 +609,7 @@ class FlowConfig:
deletes.append((f"processor:{processor}", variant))
await self.config.delete_many(deletes)
await self.config.delete_many(workspace, deletes)
# Phase 2: Closure-based sweep. Only delete topics that no
# other live flow still references via its blueprint templates.
@@ -560,6 +617,7 @@ class FlowConfig:
# of the same blueprint is still running, and {workspace}-scoped
# topics while any flow in that workspace remains.
live_owned = await self._live_owned_topic_closure(
exclude_workspace=workspace,
exclude_flow_id=msg.flow_id,
)
@@ -571,13 +629,13 @@ class FlowConfig:
kept = this_flow_owned - to_delete
if kept:
logger.info(
f"Flow {msg.flow_id}: keeping {len(kept)} topics "
f"still referenced by other live flows"
f"Flow {workspace}/{msg.flow_id}: keeping {len(kept)} "
f"topics still referenced by other live flows"
)
# Phase 3: Remove the flow record.
if msg.flow_id in await self.config.keys("flow"):
await self.config.delete("flow", msg.flow_id)
if msg.flow_id in await self.config.keys(workspace, "flow"):
await self.config.delete(workspace, "flow", msg.flow_id)
return FlowResponse(
error = None,
@@ -585,7 +643,18 @@ class FlowConfig:
async def handle(self, msg):
logger.debug(f"Handling flow message: {msg.operation}")
logger.debug(
f"Handling flow message: {msg.operation} "
f"workspace={msg.workspace}"
)
if not msg.workspace:
return FlowResponse(
error=Error(
type="bad-request",
message="Workspace is required",
),
)
if msg.operation == "list-blueprints":
resp = await self.handle_list_blueprints(msg)

View file

@@ -103,7 +103,12 @@ class Processor(AsyncProcessor):
await self.pubsub.ensure_topic(self.flow_request_topic)
await self.config_client.start()
await self.flow.ensure_existing_flow_topics()
# Discover workspaces with existing flow config and ensure
# their topics exist before we start accepting requests.
workspaces = await self.config_client.workspaces_for_type("flow")
await self.flow.ensure_existing_flow_topics(workspaces)
await self.flow_request_consumer.start()
async def on_flow_request(self, msg, consumer, flow):

View file

@@ -30,6 +30,7 @@ class ConfigReceiver:
self.flow_handlers = []
# Per-workspace flow tracking: {workspace: {flow_id: flow_def}}
self.flows = {}
self.config_version = 0
@@ -43,7 +44,7 @@ class ConfigReceiver:
v = msg.value()
notify_version = v.version
notify_types = set(v.types)
changes = v.changes
# Skip if we already have this version or newer
if notify_version <= self.config_version:
@@ -53,20 +54,27 @@ class ConfigReceiver:
)
return
# Gateway cares about flow config
if notify_types and "flow" not in notify_types:
# Gateway cares about flow config — check if any flow
# types changed in any workspace
flow_workspaces = changes.get("flow", [])
if changes and not flow_workspaces:
logger.debug(
f"Ignoring config notify v{notify_version}, "
f"no flow types in {notify_types}"
f"no flow changes"
)
self.config_version = notify_version
return
logger.info(
f"Config notify v{notify_version}, fetching config..."
f"Config notify v{notify_version} "
f"types={list(changes.keys())}, fetching config..."
)
await self.fetch_and_apply()
# Refresh config for each affected workspace
for workspace in flow_workspaces:
await self.fetch_and_apply_workspace(workspace)
self.config_version = notify_version
except Exception as e:
logger.error(
@@ -98,20 +106,25 @@ class ConfigReceiver:
response_metrics=config_resp_metrics,
)
async def fetch_and_apply(self, retry=False):
"""Fetch full config and apply flow changes.
async def fetch_and_apply_workspace(self, workspace, retry=False):
"""Fetch config for a single workspace and apply flow changes.
If retry=True, keeps retrying until successful."""
while True:
try:
logger.info("Fetching config from config service...")
logger.info(
f"Fetching config for workspace {workspace}..."
)
client = self._create_config_client()
try:
await client.start()
resp = await client.request(
ConfigRequest(operation="config"),
ConfigRequest(
operation="config",
workspace=workspace,
),
timeout=10,
)
finally:
@@ -137,18 +150,22 @@ class ConfigReceiver:
flows = config.get("flow", {})
ws_flows = self.flows.get(workspace, {})
wanted = list(flows.keys())
current = list(self.flows.keys())
current = list(ws_flows.keys())
for k in wanted:
if k not in current:
self.flows[k] = json.loads(flows[k])
await self.start_flow(k, self.flows[k])
ws_flows[k] = json.loads(flows[k])
await self.start_flow(workspace, k, ws_flows[k])
for k in current:
if k not in wanted:
await self.stop_flow(k, self.flows[k])
del self.flows[k]
await self.stop_flow(workspace, k, ws_flows[k])
del ws_flows[k]
self.flows[workspace] = ws_flows
return
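The wanted-vs-current loop above is a per-workspace reconcile. This sketch strips out the handler plumbing to show just the diffing, with a return value added for illustration:

```python
# Sketch of the per-workspace reconcile step: start flows the config
# service reports but the gateway isn't running, stop the reverse.
def reconcile(running, workspace, wanted):
    ws_flows = running.setdefault(workspace, {})
    started, stopped = [], []
    for flow_id, definition in wanted.items():
        if flow_id not in ws_flows:
            ws_flows[flow_id] = definition
            started.append(flow_id)
    for flow_id in list(ws_flows):
        if flow_id not in wanted:
            del ws_flows[flow_id]
            stopped.append(flow_id)
    return started, stopped

running = {}
assert reconcile(running, "acme", {"f1": {}, "f2": {}}) == (["f1", "f2"], [])
assert reconcile(running, "acme", {"f2": {}}) == ([], ["f1"])
# Another workspace's flows are untouched by acme's reconcile.
assert reconcile(running, "beta", {"g1": {}}) == (["g1"], [])
assert running == {"acme": {"f2": {}}, "beta": {"g1": {}}}
```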
@@ -164,27 +181,91 @@ class ConfigReceiver:
)
return
async def start_flow(self, id, flow):
async def fetch_all_workspaces(self, retry=False):
"""Fetch config for all workspaces at startup.
Discovers workspaces via the config service getvalues-all-ws
operation on the flow type."""
logger.info(f"Starting flow: {id}")
while True:
try:
logger.info("Discovering workspaces with flows...")
client = self._create_config_client()
try:
await client.start()
# Discover workspaces that have any flow config
resp = await client.request(
ConfigRequest(
operation="getvalues-all-ws",
type="flow",
),
timeout=10,
)
if resp.error:
raise RuntimeError(
f"Config error: {resp.error.message}"
)
workspaces = {
v.workspace for v in resp.values if v.workspace
}
# Always include the default workspace, even if
# empty, so that newly-created flows in it can be
# picked up by subsequent notifications.
workspaces.add("default")
logger.info(
f"Found workspaces with flows: {workspaces}"
)
finally:
await client.stop()
# Fetch and apply config for each workspace
for workspace in workspaces:
await self.fetch_and_apply_workspace(
workspace, retry=retry
)
return
except Exception as e:
if retry:
logger.warning(
f"Workspace fetch failed: {e}, retrying in 2s..."
)
await asyncio.sleep(2)
continue
logger.error(
f"Workspace fetch exception: {e}", exc_info=True
)
return
async def start_flow(self, workspace, id, flow):
logger.info(f"Starting flow: {workspace}/{id}")
for handler in self.flow_handlers:
try:
await handler.start_flow(id, flow)
await handler.start_flow(workspace, id, flow)
except Exception as e:
logger.error(
f"Config processing exception: {e}", exc_info=True
)
async def stop_flow(self, id, flow):
async def stop_flow(self, workspace, id, flow):
logger.info(f"Stopping flow: {id}")
logger.info(f"Stopping flow: {workspace}/{id}")
for handler in self.flow_handlers:
try:
await handler.stop_flow(id, flow)
await handler.stop_flow(workspace, id, flow)
except Exception as e:
logger.error(
f"Config processing exception: {e}", exc_info=True
@@ -218,7 +299,7 @@ class ConfigReceiver:
# Fetch current config (subscribe-then-fetch pattern)
# Retry until config service is available
await self.fetch_and_apply(retry=True)
await self.fetch_all_workspaces(retry=True)
logger.info(
"Config loader initialised, waiting for notifys..."

View file

@@ -16,7 +16,7 @@ class CoreExport:
async def process(self, data, error, ok, request):
id = request.query["id"]
user = request.query["user"]
workspace = request.query.get("workspace", "default")
response = await ok()
@@ -41,7 +41,6 @@ class CoreExport:
{
"m": {
"i": data["metadata"]["id"],
"u": data["metadata"]["user"],
"c": data["metadata"]["collection"],
},
"e": [
@@ -65,7 +64,6 @@ class CoreExport:
{
"m": {
"i": data["metadata"]["id"],
"u": data["metadata"]["user"],
"c": data["metadata"]["collection"],
},
"t": data["triples"],
@@ -78,7 +76,7 @@ class CoreExport:
await kr.process(
{
"operation": "get-kg-core",
"user": user,
"workspace": workspace,
"id": id,
},
responder

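The handler-side change in this file is uniform: `workspace` is an optional query parameter defaulting to `"default"`, where `user` used to be required. Sketched as the lookup alone:

```python
# Sketch: workspace resolution from a request's query dict, replacing
# the previously required user parameter.
def resolve_workspace(query):
    return query.get("workspace", "default")

assert resolve_workspace({"id": "kg-1"}) == "default"
assert resolve_workspace({"id": "kg-1", "workspace": "acme"}) == "acme"
```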
View file

@@ -17,7 +17,7 @@ class CoreImport:
async def process(self, data, error, ok, request):
id = request.query["id"]
user = request.query["user"]
workspace = request.query.get("workspace", "default")
kr = KnowledgeRequestor(
backend = self.backend,
@@ -43,12 +43,11 @@ class CoreImport:
msg = unpacked[1]
msg = {
"operation": "put-kg-core",
"user": user,
"workspace": workspace,
"id": id,
"triples": {
"metadata": {
"id": id,
"user": user,
"collection": "default", # Not used?
},
"triples": msg["t"],
@@ -61,12 +60,11 @@ class CoreImport:
msg = unpacked[1]
msg = {
"operation": "put-kg-core",
"user": user,
"workspace": workspace,
"id": id,
"graph-embeddings": {
"metadata": {
"id": id,
"user": user,
"collection": "default", # Not used?
},
"entities": [

View file

@@ -14,12 +14,12 @@ class DocumentStreamExport:
async def process(self, data, error, ok, request):
user = request.query.get("user")
workspace = request.query.get("workspace", "default")
document_id = request.query.get("document-id")
chunk_size = int(request.query.get("chunk-size", 1024 * 1024))
if not user or not document_id:
return await error("Missing required parameters: user, document-id")
if not document_id:
return await error("Missing required parameter: document-id")
response = await ok()
@ -45,7 +45,7 @@ class DocumentStreamExport:
await lr.process(
{
"operation": "stream-document",
"user": user,
"workspace": workspace,
"document-id": document_id,
"chunk-size": chunk_size,
},

View file

@ -48,7 +48,6 @@ class EntityContextsImport:
elt = EntityContexts(
metadata=Metadata(
id=data["metadata"]["id"],
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
entities=[

View file

@ -48,7 +48,6 @@ class GraphEmbeddingsImport:
elt = GraphEmbeddings(
metadata=Metadata(
id=data["metadata"]["id"],
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
entities=[

View file

@ -116,18 +116,20 @@ class DispatcherManager:
# Format: {"config": {"request": "...", "response": "..."}, ...}
self.queue_overrides = queue_overrides or {}
# Flows keyed by (workspace, flow_id)
self.flows = {}
# Dispatchers keyed by (workspace, flow_id, kind)
self.dispatchers = {}
self.dispatcher_lock = asyncio.Lock()
async def start_flow(self, id, flow):
logger.info(f"Starting flow {id}")
self.flows[id] = flow
async def start_flow(self, workspace, id, flow):
logger.info(f"Starting flow {workspace}/{id}")
self.flows[(workspace, id)] = flow
return
async def stop_flow(self, id, flow):
logger.info(f"Stopping flow {id}")
del self.flows[id]
async def stop_flow(self, workspace, id, flow):
logger.info(f"Stopping flow {workspace}/{id}")
del self.flows[(workspace, id)]
return
def dispatch_global_service(self):
@ -203,18 +205,20 @@ class DispatcherManager:
async def process_flow_import(self, ws, running, params):
workspace = params.get("workspace", "default")
flow = params.get("flow")
kind = params.get("kind")
if flow not in self.flows:
raise RuntimeError("Invalid flow")
flow_key = (workspace, flow)
if flow_key not in self.flows:
raise RuntimeError(f"Invalid flow {workspace}/{flow}")
if kind not in import_dispatchers:
raise RuntimeError("Invalid kind")
key = (flow, kind)
key = (workspace, flow, kind)
intf_defs = self.flows[flow]["interfaces"]
intf_defs = self.flows[flow_key]["interfaces"]
# FIXME: The -store bit, does it make sense?
if kind == "entity-contexts":
@ -242,18 +246,20 @@ class DispatcherManager:
async def process_flow_export(self, ws, running, params):
workspace = params.get("workspace", "default")
flow = params.get("flow")
kind = params.get("kind")
if flow not in self.flows:
raise RuntimeError("Invalid flow")
flow_key = (workspace, flow)
if flow_key not in self.flows:
raise RuntimeError(f"Invalid flow {workspace}/{flow}")
if kind not in export_dispatchers:
raise RuntimeError("Invalid kind")
key = (flow, kind)
key = (workspace, flow, kind)
intf_defs = self.flows[flow]["interfaces"]
intf_defs = self.flows[flow_key]["interfaces"]
# FIXME: The -store bit, does it make sense?
if kind == "entity-contexts":
@ -286,22 +292,36 @@ class DispatcherManager:
async def process_flow_service(self, data, responder, params):
# Workspace can come from URL or from request body, defaulting
# to "default". Having it in the URL allows gateway routing to
# be workspace-aware without touching the body.
workspace = params.get("workspace")
if not workspace and isinstance(data, dict):
workspace = data.get("workspace")
if not workspace:
workspace = "default"
flow = params.get("flow")
kind = params.get("kind")
return await self.invoke_flow_service(data, responder, flow, kind)
return await self.invoke_flow_service(
data, responder, workspace, flow, kind,
)
async def invoke_flow_service(self, data, responder, flow, kind):
async def invoke_flow_service(
self, data, responder, workspace, flow, kind,
):
if flow not in self.flows:
raise RuntimeError("Invalid flow")
flow_key = (workspace, flow)
if flow_key not in self.flows:
raise RuntimeError(f"Invalid flow {workspace}/{flow}")
key = (flow, kind)
key = (workspace, flow, kind)
if key not in self.dispatchers:
async with self.dispatcher_lock:
if key not in self.dispatchers:
intf_defs = self.flows[flow]["interfaces"]
intf_defs = self.flows[flow_key]["interfaces"]
if kind not in intf_defs:
raise RuntimeError("This kind not supported by flow")
@ -314,8 +334,8 @@ class DispatcherManager:
request_queue = qconfig["request"],
response_queue = qconfig["response"],
timeout = 120,
consumer = f"{self.prefix}-{flow}-{kind}-request",
subscriber = f"{self.prefix}-{flow}-{kind}-request",
consumer = f"{self.prefix}-{workspace}-{flow}-{kind}-request",
subscriber = f"{self.prefix}-{workspace}-{flow}-{kind}-request",
)
elif kind in sender_dispatchers:
dispatcher = sender_dispatchers[kind](

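The lazy dispatcher creation in `invoke_flow_service` above is a double-checked lock over a `(workspace, flow, kind)`-keyed dict. Isolated as a sketch (the factory and cache class are illustrative):

```python
import asyncio

class DispatcherCache:
    """Sketch of the lazy (workspace, flow, kind)-keyed dispatcher
    registry. Check, lock, re-check keeps the common already-created
    path lock-free while guaranteeing one dispatcher per key."""

    def __init__(self, factory):
        self.factory = factory          # hypothetical dispatcher factory
        self.dispatchers = {}
        self.lock = asyncio.Lock()

    async def get(self, workspace, flow, kind):
        key = (workspace, flow, kind)
        if key not in self.dispatchers:
            async with self.lock:
                if key not in self.dispatchers:   # re-check under lock
                    self.dispatchers[key] = self.factory(*key)
        return self.dispatchers[key]
```

The re-check inside the lock matters: two coroutines can both fail the first test, and without it the loser would overwrite the winner's dispatcher.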
View file

@ -47,7 +47,9 @@ class Mux:
raise RuntimeError("Bad message")
await self.q.put((
data["id"], data.get("flow"),
data["id"],
data.get("workspace", "default"),
data.get("flow"),
data["service"],
data["request"]
))
@ -87,8 +89,10 @@ class Mux:
# worker[0] still running, move on
break
async def start_request_task(self, ws, id, flow, svc, request, workers):
async def start_request_task(
self, ws, id, workspace, flow, svc, request, workers,
):
# Wait for outstanding requests to go below MAX_OUTSTANDING_REQUESTS
while len(workers) > MAX_OUTSTANDING_REQUESTS:
@ -106,19 +110,23 @@ class Mux:
})
worker = asyncio.create_task(
self.request_task(id, request, responder, flow, svc)
self.request_task(
id, request, responder, workspace, flow, svc,
)
)
workers.append(worker)
async def request_task(self, id, request, responder, flow, svc):
async def request_task(
self, id, request, responder, workspace, flow, svc,
):
try:
if flow:
await self.dispatcher_manager.invoke_flow_service(
request, responder, flow, svc
request, responder, workspace, flow, svc,
)
else:
@ -148,7 +156,7 @@ class Mux:
# Get next request on queue
item = await asyncio.wait_for(self.q.get(), 1)
id, flow, svc, request = item
id, workspace, flow, svc, request = item
except TimeoutError:
continue
@ -172,7 +180,7 @@ class Mux:
try:
await self.start_request_task(
self.ws, id, flow, svc, request, workers
self.ws, id, workspace, flow, svc, request, workers
)
except Exception as e:

View file

@ -53,7 +53,6 @@ class RowsImport:
elt = ExtractedObject(
metadata=Metadata(
id=data["metadata"]["id"],
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
schema_name=data["schema_name"],

View file

@ -38,7 +38,6 @@ def serialize_triples(message):
"metadata": {
"id": message.metadata.id,
"root": message.metadata.root,
"user": message.metadata.user,
"collection": message.metadata.collection,
},
"triples": serialize_subgraph(message.triples),
@ -50,7 +49,6 @@ def serialize_graph_embeddings(message):
"metadata": {
"id": message.metadata.id,
"root": message.metadata.root,
"user": message.metadata.user,
"collection": message.metadata.collection,
},
"entities": [
@ -68,7 +66,6 @@ def serialize_entity_contexts(message):
"metadata": {
"id": message.metadata.id,
"root": message.metadata.root,
"user": message.metadata.user,
"collection": message.metadata.collection,
},
"entities": [
@ -86,7 +83,6 @@ def serialize_document_embeddings(message):
"metadata": {
"id": message.metadata.id,
"root": message.metadata.root,
"user": message.metadata.user,
"collection": message.metadata.collection,
},
"chunks": [
@ -120,8 +116,8 @@ def serialize_document_metadata(message):
if message.metadata:
ret["metadata"] = serialize_subgraph(message.metadata)
if message.user:
ret["user"] = message.user
if message.workspace:
ret["workspace"] = message.workspace
if message.tags is not None:
ret["tags"] = message.tags
@ -144,8 +140,8 @@ def serialize_processing_metadata(message):
if message.flow:
ret["flow"] = message.flow
if message.user:
ret["user"] = message.user
if message.workspace:
ret["workspace"] = message.workspace
if message.collection:
ret["collection"] = message.collection
@ -164,7 +160,7 @@ def to_document_metadata(x):
title = x.get("title", None),
comments = x.get("comments", None),
metadata = to_subgraph(x["metadata"]),
user = x.get("user", None),
workspace = x.get("workspace", None),
tags = x.get("tags", None),
)
@ -175,7 +171,7 @@ def to_processing_metadata(x):
document_id = x.get("document-id", None),
time = x.get("time", None),
flow = x.get("flow", None),
user = x.get("user", None),
workspace = x.get("workspace", None),
collection = x.get("collection", None),
tags = x.get("tags", None),
)

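The optional-field handling of the new `workspace` key above round-trips as in this sketch; the field set is pared down for illustration, and the dict-based shape stands in for the real message classes:

```python
# Hypothetical, reduced (de)serializers matching the pattern above:
# optional fields are written only when truthy and read back with a
# None default, so old payloads without "workspace" still parse.
def serialize_processing_metadata(md):
    ret = {"id": md["id"]}
    if md.get("flow"):
        ret["flow"] = md["flow"]
    if md.get("workspace"):
        ret["workspace"] = md["workspace"]
    return ret

def to_processing_metadata(x):
    return {
        "id": x["id"],
        "flow": x.get("flow", None),
        "workspace": x.get("workspace", None),
    }
```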
View file

@ -49,7 +49,6 @@ class TriplesImport:
metadata=Metadata(
id=data["metadata"]["id"],
root=data["metadata"].get("root", ""),
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
triples=to_subgraph(data["triples"]),

View file

@ -3,6 +3,7 @@ Collection management for the librarian - uses config service for storage
"""
import asyncio
import dataclasses
import logging
import json
import uuid
@ -20,7 +21,6 @@ logger = logging.getLogger(__name__)
def metadata_to_dict(metadata: CollectionMetadata) -> dict:
"""Convert CollectionMetadata to dictionary for JSON serialization"""
return {
'user': metadata.user,
'collection': metadata.collection,
'name': metadata.name,
'description': metadata.description,
@ -92,38 +92,38 @@ class CollectionManager:
self.pending_config_requests[response_id + "_response"] = response
self.pending_config_requests[response_id].set()
async def ensure_collection_exists(self, user: str, collection: str):
async def ensure_collection_exists(self, workspace: str, collection: str):
"""
Ensure a collection exists, creating it if necessary
Args:
user: User ID
workspace: Workspace ID
collection: Collection ID
"""
try:
# Check if collection exists via config service
request = ConfigRequest(
operation='get',
keys=[ConfigKey(type='collection', key=f'{user}:{collection}')]
workspace=workspace,
keys=[ConfigKey(type='collection', key=collection)]
)
response = await self.send_config_request(request)
# Validate response
if not response.values or len(response.values) == 0:
raise Exception(f"Invalid response from config service when checking collection {user}/{collection}")
raise Exception(f"Invalid response from config service when checking collection {workspace}/{collection}")
# Check if collection exists (value not None means it exists)
if response.values[0].value is not None:
logger.debug(f"Collection {user}/{collection} already exists")
logger.debug(f"Collection {workspace}/{collection} already exists")
return
# Collection doesn't exist (value is None), proceed to create
# Create new collection with default metadata
logger.info(f"Auto-creating collection {user}/{collection}")
logger.info(f"Auto-creating collection {workspace}/{collection}")
metadata = CollectionMetadata(
user=user,
collection=collection,
name=collection, # Default name to collection ID
description="",
@ -132,9 +132,10 @@ class CollectionManager:
request = ConfigRequest(
operation='put',
workspace=workspace,
values=[ConfigValue(
type='collection',
key=f'{user}:{collection}',
key=collection,
value=json.dumps(metadata_to_dict(metadata))
)]
)
@ -144,7 +145,7 @@ class CollectionManager:
if response.error:
raise RuntimeError(f"Config update failed: {response.error.message}")
logger.info(f"Collection {user}/{collection} auto-created in config service")
logger.info(f"Collection {workspace}/{collection} auto-created in config service")
except Exception as e:
logger.error(f"Error ensuring collection exists: {e}")
@ -161,9 +162,10 @@ class CollectionManager:
CollectionManagementResponse with list of collections
"""
try:
# Get all collections from config service
# Get all collections in this workspace from config service
config_request = ConfigRequest(
operation='getvalues',
workspace=request.workspace,
type='collection'
)
@ -172,15 +174,19 @@ class CollectionManager:
if response.error:
raise RuntimeError(f"Config query failed: {response.error.message}")
# Parse collections and filter by user
# Every value in this workspace is a collection.
# Filter to fields the current schema knows about — older
# persisted values may carry fields that have since been
# dropped (e.g. `user` from the pre-workspace-refactor era).
known_fields = {f.name for f in dataclasses.fields(CollectionMetadata)}
collections = []
for config_value in response.values:
if ":" in config_value.key:
coll_user, coll_name = config_value.key.split(":", 1)
if coll_user == request.user:
metadata_dict = json.loads(config_value.value)
metadata = CollectionMetadata(**metadata_dict)
collections.append(metadata)
metadata_dict = json.loads(config_value.value)
metadata_dict = {
k: v for k, v in metadata_dict.items() if k in known_fields
}
metadata = CollectionMetadata(**metadata_dict)
collections.append(metadata)
# Apply tag filtering if specified
if request.tag_filter:
@ -221,7 +227,6 @@ class CollectionManager:
tags = list(request.tags) if request.tags else []
metadata = CollectionMetadata(
user=request.user,
collection=request.collection,
name=name,
description=description,
@ -231,9 +236,10 @@ class CollectionManager:
# Send put request to config service
config_request = ConfigRequest(
operation='put',
workspace=request.workspace,
values=[ConfigValue(
type='collection',
key=f'{request.user}:{request.collection}',
key=request.collection,
value=json.dumps(metadata_to_dict(metadata))
)]
)
@ -243,7 +249,7 @@ class CollectionManager:
if response.error:
raise RuntimeError(f"Config update failed: {response.error.message}")
logger.info(f"Collection {request.user}/{request.collection} updated in config service")
logger.info(f"Collection {request.workspace}/{request.collection} updated in config service")
# Config service will trigger config push automatically
# Storage services will receive update and create/update collections
@ -269,12 +275,13 @@ class CollectionManager:
CollectionManagementResponse indicating success or failure
"""
try:
logger.info(f"Deleting collection {request.user}/{request.collection}")
logger.info(f"Deleting collection {request.workspace}/{request.collection}")
# Send delete request to config service
config_request = ConfigRequest(
operation='delete',
keys=[ConfigKey(type='collection', key=f'{request.user}:{request.collection}')]
workspace=request.workspace,
keys=[ConfigKey(type='collection', key=request.collection)]
)
response = await self.send_config_request(config_request)
@ -282,7 +289,7 @@ class CollectionManager:
if response.error:
raise RuntimeError(f"Config delete failed: {response.error.message}")
logger.info(f"Collection {request.user}/{request.collection} deleted from config service")
logger.info(f"Collection {request.workspace}/{request.collection} deleted from config service")
# Config service will trigger config push automatically
# Storage services will receive update and delete collections

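The unknown-field filtering in `handle_list_collections` above can be isolated as a small helper. The `CollectionMetadata` fields here are a reduced, illustrative subset of the real dataclass:

```python
import dataclasses
import json

@dataclasses.dataclass
class CollectionMetadata:
    # Reduced field set for illustration; note there is no `user`.
    collection: str
    name: str = ""
    description: str = ""

def load_metadata(raw: str) -> CollectionMetadata:
    """Drop persisted keys the current schema no longer declares,
    so values written before the workspace refactor still load."""
    d = json.loads(raw)
    known = {f.name for f in dataclasses.fields(CollectionMetadata)}
    return CollectionMetadata(**{k: v for k, v in d.items() if k in known})
```

Without the filter, a legacy value carrying `"user"` would make the dataclass constructor raise `TypeError: unexpected keyword argument`.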
View file

@ -48,7 +48,7 @@ class Librarian:
raise RequestError("Document kind (MIME type) is required")
if await self.table_store.document_exists(
request.document_metadata.user,
request.document_metadata.workspace,
request.document_metadata.id
):
raise RuntimeError("Document already exists")
@ -78,7 +78,7 @@ class Librarian:
logger.debug("Removing document...")
if not await self.table_store.document_exists(
request.user,
request.workspace,
request.document_id,
):
raise RuntimeError("Document does not exist")
@ -89,17 +89,17 @@ class Librarian:
logger.debug(f"Cascade deleting child document {child.id}")
try:
child_object_id = await self.table_store.get_document_object_id(
child.user,
child.workspace,
child.id
)
await self.blob_store.remove(child_object_id)
await self.table_store.remove_document(child.user, child.id)
await self.table_store.remove_document(child.workspace, child.id)
except Exception as e:
logger.warning(f"Failed to delete child document {child.id}: {e}")
# Now remove the parent document
object_id = await self.table_store.get_document_object_id(
request.user,
request.workspace,
request.document_id
)
@ -108,7 +108,7 @@ class Librarian:
# Remove doc table row
await self.table_store.remove_document(
request.user,
request.workspace,
request.document_id
)
@ -120,10 +120,10 @@ class Librarian:
logger.debug("Updating document...")
# You can't update the document ID, user or kind.
# You can't update the document ID, workspace or kind.
if not await self.table_store.document_exists(
request.document_metadata.user,
request.document_metadata.workspace,
request.document_metadata.id
):
raise RuntimeError("Document does not exist")
@ -139,7 +139,7 @@ class Librarian:
logger.debug("Getting document metadata...")
doc = await self.table_store.get_document(
request.user,
request.workspace,
request.document_id
)
@ -156,7 +156,7 @@ class Librarian:
logger.debug("Getting document content...")
object_id = await self.table_store.get_document_object_id(
request.user,
request.workspace,
request.document_id
)
@ -180,18 +180,18 @@ class Librarian:
raise RuntimeError("Collection parameter is required")
if await self.table_store.processing_exists(
request.processing_metadata.user,
request.processing_metadata.workspace,
request.processing_metadata.id
):
raise RuntimeError("Processing already exists")
doc = await self.table_store.get_document(
request.processing_metadata.user,
request.processing_metadata.workspace,
request.processing_metadata.document_id
)
object_id = await self.table_store.get_document_object_id(
request.processing_metadata.user,
request.processing_metadata.workspace,
request.processing_metadata.document_id
)
@ -222,14 +222,14 @@ class Librarian:
logger.debug("Removing processing metadata...")
if not await self.table_store.processing_exists(
request.user,
request.workspace,
request.processing_id,
):
raise RuntimeError("Processing object does not exist")
# Remove doc table row
await self.table_store.remove_processing(
request.user,
request.workspace,
request.processing_id
)
@ -239,7 +239,7 @@ class Librarian:
async def list_documents(self, request):
docs = await self.table_store.list_documents(request.user)
docs = await self.table_store.list_documents(request.workspace)
# Filter out child documents and answer documents by default
include_children = getattr(request, 'include_children', False)
@ -256,7 +256,7 @@ class Librarian:
async def list_processing(self, request):
procs = await self.table_store.list_processing(request.user)
procs = await self.table_store.list_processing(request.workspace)
return LibrarianResponse(
processing_metadatas = procs,
@ -276,7 +276,7 @@ class Librarian:
raise RequestError("Document kind (MIME type) is required")
if await self.table_store.document_exists(
request.document_metadata.user,
request.document_metadata.workspace,
request.document_metadata.id
):
raise RequestError("Document already exists")
@ -312,14 +312,14 @@ class Librarian:
"kind": request.document_metadata.kind,
"title": request.document_metadata.title,
"comments": request.document_metadata.comments,
"user": request.document_metadata.user,
"workspace": request.document_metadata.workspace,
"tags": request.document_metadata.tags,
})
# Store session in Cassandra
await self.table_store.create_upload_session(
upload_id=upload_id,
user=request.document_metadata.user,
workspace=request.document_metadata.workspace,
document_id=request.document_metadata.id,
document_metadata=doc_meta_json,
s3_upload_id=s3_upload_id,
@ -352,7 +352,7 @@ class Librarian:
raise RequestError("Upload session not found or expired")
# Validate ownership
if session["user"] != request.user:
if session["workspace"] != request.workspace:
raise RequestError("Not authorized to upload to this session")
# Validate chunk index
@ -419,7 +419,7 @@ class Librarian:
raise RequestError("Upload session not found or expired")
# Validate ownership
if session["user"] != request.user:
if session["workspace"] != request.workspace:
raise RequestError("Not authorized to complete this upload")
# Verify all chunks received
@ -457,7 +457,7 @@ class Librarian:
kind=doc_meta_dict["kind"],
title=doc_meta_dict.get("title", ""),
comments=doc_meta_dict.get("comments", ""),
user=doc_meta_dict["user"],
workspace=doc_meta_dict["workspace"],
tags=doc_meta_dict.get("tags", []),
metadata=[], # Triples not supported in chunked upload yet
)
@ -488,7 +488,7 @@ class Librarian:
raise RequestError("Upload session not found or expired")
# Validate ownership
if session["user"] != request.user:
if session["workspace"] != request.workspace:
raise RequestError("Not authorized to abort this upload")
# Abort S3 multipart upload
@ -520,7 +520,7 @@ class Librarian:
)
# Validate ownership
if session["user"] != request.user:
if session["workspace"] != request.workspace:
raise RequestError("Not authorized to view this upload")
chunks_received = session["chunks_received"]
@ -548,11 +548,11 @@ class Librarian:
async def list_uploads(self, request):
"""
List all in-progress uploads for a user.
List all in-progress uploads for a workspace.
"""
logger.debug(f"Listing uploads for user {request.user}")
logger.debug(f"Listing uploads for workspace {request.workspace}")
sessions = await self.table_store.list_upload_sessions(request.user)
sessions = await self.table_store.list_upload_sessions(request.workspace)
upload_sessions = [
UploadSession(
@ -591,7 +591,7 @@ class Librarian:
# Verify parent exists
if not await self.table_store.document_exists(
request.document_metadata.user,
request.document_metadata.workspace,
request.document_metadata.parent_id
):
raise RequestError(
@ -599,7 +599,7 @@ class Librarian:
)
if await self.table_store.document_exists(
request.document_metadata.user,
request.document_metadata.workspace,
request.document_metadata.id
):
raise RequestError("Document already exists")
@ -665,7 +665,7 @@ class Librarian:
)
object_id = await self.table_store.get_document_object_id(
request.user,
request.workspace,
request.document_id
)

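The repeated upload-session ownership checks above reduce to one guard. After the refactor a session is bound to a workspace rather than a user, so the check is pure workspace membership; the dict-shaped session here is an assumption matching the diff:

```python
class RequestError(Exception):
    pass

def check_session_ownership(session, request_workspace, action="use"):
    # Sessions are keyed by workspace after the refactor; any caller
    # operating in the owning workspace may continue the upload.
    if session["workspace"] != request_workspace:
        raise RequestError(f"Not authorized to {action} this upload")
```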
View file

@ -277,18 +277,22 @@ class Processor(AsyncProcessor):
"""Forward config responses to collection manager"""
await self.collection_manager.on_config_response(message, consumer, flow)
async def on_librarian_config(self, config, version):
async def on_librarian_config(self, workspace, config, version):
logger.info(f"Configuration version: {version}")
logger.info(
f"Configuration version: {version} workspace: {workspace}"
)
if "flow" in config:
self.flows = {
self.flows[workspace] = {
k: json.loads(v)
for k, v in config["flow"].items()
}
else:
self.flows[workspace] = {}
logger.debug(f"Flows: {self.flows}")
logger.debug(f"Flows for {workspace}: {self.flows[workspace]}")
def __del__(self):
@ -345,7 +349,6 @@ class Processor(AsyncProcessor):
metadata=Metadata(
id=doc_uri,
root=document.id,
user=processing.user,
collection=processing.collection,
),
triples=all_triples,
@ -363,10 +366,15 @@ class Processor(AsyncProcessor):
logger.debug(f"Document: {document}, processing: {processing}, content length: {len(content)}")
if processing.flow not in self.flows:
raise RuntimeError("Invalid flow ID")
workspace = processing.workspace
ws_flows = self.flows.get(workspace, {})
if processing.flow not in ws_flows:
raise RuntimeError(
f"Invalid flow ID {processing.flow} for workspace "
f"{workspace}"
)
flow = self.flows[processing.flow]
flow = ws_flows[processing.flow]
if document.kind == "text/plain":
kind = "text-load"
@ -386,7 +394,6 @@ class Processor(AsyncProcessor):
metadata = Metadata(
id = document.id,
root = document.id,
user = processing.user,
collection = processing.collection
),
document_id = document.id,
@ -398,7 +405,6 @@ class Processor(AsyncProcessor):
metadata = Metadata(
id = document.id,
root = document.id,
user = processing.user,
collection = processing.collection
),
document_id = document.id,
@ -429,9 +435,9 @@ class Processor(AsyncProcessor):
"""
# Ensure collection exists when processing is added
if hasattr(request, 'processing_metadata') and request.processing_metadata:
user = request.processing_metadata.user
workspace = request.processing_metadata.workspace
collection = request.processing_metadata.collection
await self.collection_manager.ensure_collection_exists(user, collection)
await self.collection_manager.ensure_collection_exists(workspace, collection)
# Call the original add_processing method
return await self.librarian.add_processing(request)

View file

@ -50,30 +50,37 @@ class Processor(FlowProcessor):
)
)
# Per-workspace price tables
self.prices = {}
self.config_key = "token-cost"
# Load token costs from the config service
async def on_cost_config(self, config, version):
async def on_cost_config(self, workspace, config, version):
logger.info(f"Loading metering configuration version {version}")
logger.info(
f"Loading metering configuration version {version} "
f"for workspace {workspace}"
)
if self.config_key not in config:
logger.warning(f"No key {self.config_key} in config")
logger.warning(
f"No key {self.config_key} in config for {workspace}"
)
self.prices[workspace] = {}
return
config = config[self.config_key]
prices = config[self.config_key]
self.prices = {
self.prices[workspace] = {
k: json.loads(v)
for k, v in config.items()
for k, v in prices.items()
}
def get_prices(self, modelname):
def get_prices(self, workspace, modelname):
if modelname in self.prices:
model = self.prices[modelname]
ws_prices = self.prices.get(workspace, {})
if modelname in ws_prices:
model = ws_prices[modelname]
return model["input_price"], model["output_price"]
return None, None # Return None if model is not found
@ -81,6 +88,8 @@ class Processor(FlowProcessor):
v = msg.value()
workspace = flow.workspace
modelname = v.model or "unknown"
num_in = v.in_token or 0
num_out = v.out_token or 0
@ -89,7 +98,9 @@ class Processor(FlowProcessor):
__class__.token_metric.labels(model=modelname, direction="input").inc(num_in)
__class__.token_metric.labels(model=modelname, direction="output").inc(num_out)
model_input_price, model_output_price = self.get_prices(modelname)
model_input_price, model_output_price = self.get_prices(
workspace, modelname
)
if model_input_price is None:
cost_per_call = "Model not found in price list"

View file
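The per-workspace price tables above can be sketched as a standalone class. The config shape (JSON-encoded values with `input_price`/`output_price`) follows the diff; the class itself and the model name in the usage are illustrative:

```python
import json

class PriceTable:
    """Sketch of the per-workspace token-cost tables in the metering
    processor. Lookups fall back to (None, None) when either the
    workspace or the model has no prices loaded yet."""

    def __init__(self):
        self.prices = {}

    def load(self, workspace, raw_prices):
        self.prices[workspace] = {
            model: json.loads(raw) for model, raw in raw_prices.items()
        }

    def get(self, workspace, model):
        entry = self.prices.get(workspace, {}).get(model)
        if entry is None:
            return None, None           # model (or workspace) unknown
        return entry["input_price"], entry["output_price"]
```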

@ -66,24 +66,37 @@ class Processor(FlowProcessor):
self.register_config_handler(self.on_prompt_config, types=["prompt"])
# Null configuration, should reload quickly
self.manager = PromptManager()
# Per-workspace prompt managers. Populated lazily as config
# arrives for each workspace.
self.managers = {}
async def on_prompt_config(self, config, version):
async def on_prompt_config(self, workspace, config, version):
logger.info(f"Loading prompt configuration version {version}")
logger.info(
f"Loading prompt configuration version {version} "
f"for workspace {workspace}"
)
if self.config_key not in config:
logger.warning(f"No key {self.config_key} in config")
logger.warning(
f"No key {self.config_key} in config for {workspace}"
)
return
config = config[self.config_key]
prompt_config = config[self.config_key]
try:
self.manager.load_config(config)
manager = self.managers.get(workspace)
if manager is None:
manager = PromptManager()
self.managers[workspace] = manager
logger.info("Prompt configuration reloaded")
manager.load_config(prompt_config)
logger.info(
f"Prompt configuration reloaded for {workspace}"
)
except Exception as e:
@ -103,6 +116,29 @@ class Processor(FlowProcessor):
# Check if streaming is requested
streaming = getattr(v, 'streaming', False)
# Look up the prompt manager for this workspace. If none is
# loaded yet, the request can't be handled.
workspace = flow.workspace
manager = self.managers.get(workspace)
if manager is None:
logger.error(
f"No prompt configuration loaded for workspace {workspace}"
)
r = PromptResponse(
error=Error(
type="no-configuration",
message=(
f"No prompt configuration for workspace "
f"{workspace}"
),
),
text=None,
object=None,
end_of_stream=True,
)
await flow("response").send(r, properties={"id": id})
return
try:
logger.debug(f"Prompt terms: {v.terms}")
@ -149,7 +185,7 @@ class Processor(FlowProcessor):
return ""
try:
await self.manager.invoke(kind, input, llm_streaming)
await manager.invoke(kind, input, llm_streaming)
except Exception as e:
logger.error(f"Prompt streaming exception: {e}", exc_info=True)
raise e
@ -177,7 +213,7 @@ class Processor(FlowProcessor):
return None
try:
resp = await self.manager.invoke(kind, input, llm)
resp = await manager.invoke(kind, input, llm)
except Exception as e:
logger.error(f"Prompt invocation exception: {e}", exc_info=True)
raise e

View file

@ -31,7 +31,7 @@ class Processor(DocumentEmbeddingsQueryService):
self.vecstore = DocVectors(store_uri)
async def query_document_embeddings(self, msg):
async def query_document_embeddings(self, workspace, msg):
try:
@ -45,7 +45,7 @@ class Processor(DocumentEmbeddingsQueryService):
resp = self.vecstore.search(
vec,
msg.user,
workspace,
msg.collection,
limit=msg.limit
)

View file

@ -48,7 +48,7 @@ class Processor(DocumentEmbeddingsQueryService):
}
)
async def query_document_embeddings(self, msg):
async def query_document_embeddings(self, workspace, msg):
try:
@ -63,7 +63,7 @@ class Processor(DocumentEmbeddingsQueryService):
dim = len(vec)
# Use dimension suffix in index name
index_name = f"d-{msg.user}-{msg.collection}-{dim}"
index_name = f"d-{workspace}-{msg.collection}-{dim}"
# Check if index exists - return empty if not
if not self.pinecone.has_index(index_name):

View file

@ -65,7 +65,7 @@ class Processor(DocumentEmbeddingsQueryService):
"""Check if collection exists (no implicit creation)"""
return self.qdrant.collection_exists(collection)
async def query_document_embeddings(self, msg):
async def query_document_embeddings(self, workspace, msg):
try:
@ -75,7 +75,7 @@ class Processor(DocumentEmbeddingsQueryService):
# Use dimension suffix in collection name
dim = len(vec)
collection = f"d_{msg.user}_{msg.collection}_{dim}"
collection = f"d_{workspace}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):

View file

@ -37,7 +37,7 @@ class Processor(GraphEmbeddingsQueryService):
else:
return Term(type=LITERAL, value=ent)
async def query_graph_embeddings(self, msg):
async def query_graph_embeddings(self, workspace, msg):
try:
@ -51,7 +51,7 @@ class Processor(GraphEmbeddingsQueryService):
resp = self.vecstore.search(
vec,
msg.user,
workspace,
msg.collection,
limit=msg.limit * 2
)

View file

@ -55,7 +55,7 @@ class Processor(GraphEmbeddingsQueryService):
else:
return Term(type=LITERAL, value=ent)
async def query_graph_embeddings(self, msg):
async def query_graph_embeddings(self, workspace, msg):
try:
@ -70,7 +70,7 @@ class Processor(GraphEmbeddingsQueryService):
dim = len(vec)
# Use dimension suffix in index name
index_name = f"t-{msg.user}-{msg.collection}-{dim}"
index_name = f"t-{workspace}-{msg.collection}-{dim}"
# Check if index exists - return empty if not
if not self.pinecone.has_index(index_name):

View file

@ -71,7 +71,7 @@ class Processor(GraphEmbeddingsQueryService):
else:
return Term(type=LITERAL, value=ent)
async def query_graph_embeddings(self, msg):
async def query_graph_embeddings(self, workspace, msg):
try:
@ -81,7 +81,7 @@ class Processor(GraphEmbeddingsQueryService):
# Use dimension suffix in collection name
dim = len(vec)
collection = f"t_{msg.user}_{msg.collection}_{dim}"
collection = f"t_{workspace}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):

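The workspace-scoped index naming used across the vector-store processors above follows one scheme per backend. The prefixes (`d-`/`d_`) and the dimension suffix come from the diff; the helper itself is a sketch:

```python
def doc_index_name(workspace, collection, dim, backend="pinecone"):
    """Workspace-scoped document-embedding index names. Folding the
    embedding dimension into the name keeps models with different
    output sizes in separate indexes."""
    if backend == "qdrant":
        return f"d_{workspace}_{collection}_{dim}"
    return f"d-{workspace}-{collection}-{dim}"
```

Graph-embedding indexes follow the same pattern with a `t-`/`t_` prefix, as seen in the hunks above.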
View file

@ -70,7 +70,7 @@ class GraphQLSchemaBuilder:
Build the GraphQL schema with the provided query callback.
The query callback will be invoked when resolving queries, with:
- user: str
- workspace: str
- collection: str
- schema_name: str
- row_schema: RowSchema
@ -228,7 +228,7 @@ class GraphQLSchemaBuilder:
limit: Optional[int] = 100
) -> List[graphql_type]:
# Get context values
user = info.context["user"]
workspace = info.context["workspace"]
collection = info.context["collection"]
# Parse the where clause
@ -236,7 +236,7 @@ class GraphQLSchemaBuilder:
# Call the query backend
results = await query_callback(
user, collection, schema_name, row_schema,
workspace, collection, schema_name, row_schema,
filters, limit, order_by, direction
)
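The resolver change can be sketched in isolation: tenancy routing comes from the trusted GraphQL execution context, not from client-supplied fields. The helper and its argument names are illustrative, following the shapes shown in the diff rather than the actual codebase:

```python
import asyncio

async def resolve_rows(context, query_callback, schema_name, row_schema,
                       filters=None, limit=100, order_by=None, direction="ASC"):
    # Workspace and collection are read from the server-side execution
    # context, so a client cannot spoof another tenant's data by
    # putting a different value in the request message.
    workspace = context["workspace"]
    collection = context["collection"]
    return await query_callback(workspace, collection, schema_name, row_schema,
                                filters, limit, order_by, direction)
```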


@ -167,7 +167,7 @@ class QueryExplainer:
question_components, query_results, processing_metadata
)
# Generate user-friendly explanation
user_friendly_explanation = self._generate_user_friendly_explanation(
question, question_components, ontology_subsets, final_answer
)
@ -503,7 +503,7 @@ class QueryExplainer:
question_components: QuestionComponents,
ontology_subsets: List[QueryOntologySubset],
final_answer: str) -> str:
"""Generate user-friendly explanation of the process."""
explanation_parts = []
# Introduction


@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
@dataclass
class QueryRequest:
"""Query request from user."""
question: str
context: Optional[str] = None
ontology_hint: Optional[str] = None


@ -1,6 +1,6 @@
"""
Question analyzer for ontology-sensitive query system.
Decomposes user questions into semantic components.
"""
import logging


@ -1,7 +1,7 @@
"""
Row embeddings query service for Qdrant.
Input is query vectors plus user/collection/schema context.
Input is query vectors plus workspace/collection/schema context.
Output is matching row index information (index_name, index_value) for
use in subsequent Cassandra lookups.
"""
@ -70,10 +70,10 @@ class Processor(FlowProcessor):
safe_name = 'r_' + safe_name
return safe_name.lower()
def find_collection(self, user: str, collection: str, schema_name: str) -> Optional[str]:
"""Find the Qdrant collection for a given user/collection/schema"""
def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]:
"""Find the Qdrant collection for a given workspace/collection/schema"""
prefix = (
f"rows_{self.sanitize_name(user)}_"
f"rows_{self.sanitize_name(workspace)}_"
f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
)
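The prefix construction depends on a `sanitize_name` not fully shown in this hunk. A hedged reconstruction, assuming the behaviour implied by the surrounding fragments (the `'r_'` prefix line and lower-casing are visible above); this is an illustration, not the exact codebase function:

```python
import re

def sanitize_name(name: str) -> str:
    # Assumed behaviour: non-alphanumerics become underscores, an 'r_'
    # prefix is added when the result doesn't start with a letter, and
    # the whole thing is lower-cased.
    safe = re.sub(r"[^a-zA-Z0-9]", "_", name)
    if not safe or not safe[0].isalpha():
        safe = "r_" + safe
    return safe.lower()

def rows_collection_prefix(workspace: str, collection: str, schema_name: str) -> str:
    # Qdrant collections are matched by this prefix; the trailing
    # underscore precedes the per-index suffix on the real collections.
    return (
        f"rows_{sanitize_name(workspace)}_"
        f"{sanitize_name(collection)}_{sanitize_name(schema_name)}_"
    )
```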
@ -93,22 +93,22 @@ class Processor(FlowProcessor):
return None
async def query_row_embeddings(self, request: RowEmbeddingsRequest):
async def query_row_embeddings(self, workspace, request: RowEmbeddingsRequest):
"""Execute row embeddings query"""
vec = request.vector
if not vec:
return []
# Find the collection for this user/collection/schema
# Find the collection for this workspace/collection/schema
qdrant_collection = self.find_collection(
request.user, request.collection, request.schema_name
workspace, request.collection, request.schema_name
)
if not qdrant_collection:
logger.info(
f"No Qdrant collection found for "
f"{request.user}/{request.collection}/{request.schema_name}"
f"{workspace}/{request.collection}/{request.schema_name}"
)
return []
@ -163,11 +163,11 @@ class Processor(FlowProcessor):
logger.debug(
f"Handling row embeddings query for "
f"{request.user}/{request.collection}/{request.schema_name}..."
f"{flow.workspace}/{request.collection}/{request.schema_name}..."
)
# Execute query
matches = await self.query_row_embeddings(request)
matches = await self.query_row_embeddings(flow.workspace, request)
response = RowEmbeddingsResponse(
error=None,


@ -87,12 +87,12 @@ class Processor(FlowProcessor):
# Register config handler for schema updates
self.register_config_handler(self.on_schema_config, types=["schema"])
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# Per-workspace schema storage: {workspace: {name: RowSchema}}
self.schemas: Dict[str, Dict[str, RowSchema]] = {}
# GraphQL schema builder and generated schema
self.schema_builder = GraphQLSchemaBuilder()
self.graphql_schema = None
# Per-workspace GraphQL schema builders and compiled schemas
self.schema_builders: Dict[str, GraphQLSchemaBuilder] = {}
self.graphql_schemas: Dict[str, Any] = {}
# Cassandra session
self.cluster = None
@ -133,17 +133,27 @@ class Processor(FlowProcessor):
safe_name = 'r_' + safe_name
return safe_name.lower()
async def on_schema_config(self, config, version):
async def on_schema_config(self, workspace, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
logger.info(
f"Loading schema configuration version {version} "
f"for workspace {workspace}"
)
# Clear existing schemas
self.schemas = {}
self.schema_builder.clear()
# Replace existing schemas for this workspace
ws_schemas: Dict[str, RowSchema] = {}
self.schemas[workspace] = ws_schemas
builder = GraphQLSchemaBuilder()
self.schema_builders[workspace] = builder
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
logger.warning(
f"No '{self.config_key}' type in configuration "
f"for {workspace}"
)
self.graphql_schemas[workspace] = None
return
# Get the schemas dictionary for our type
@ -177,17 +187,23 @@ class Processor(FlowProcessor):
fields=fields
)
self.schemas[schema_name] = row_schema
self.schema_builder.add_schema(schema_name, row_schema)
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
ws_schemas[schema_name] = row_schema
builder.add_schema(schema_name, row_schema)
logger.info(
f"Loaded schema: {schema_name} with "
f"{len(fields)} fields for {workspace}"
)
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
logger.info(
f"Schema configuration loaded for {workspace}: "
f"{len(ws_schemas)} schemas"
)
# Regenerate GraphQL schema
self.graphql_schema = self.schema_builder.build(self.query_cassandra)
# Regenerate GraphQL schema for this workspace
self.graphql_schemas[workspace] = builder.build(self.query_cassandra)
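The per-workspace keyed storage this hunk moves to can be sketched as a small registry: one schema dict and one compiled artifact per workspace, replaced wholesale on every config update. The class name and `compile_fn` parameter are illustrative, not from the codebase:

```python
from typing import Any, Callable, Dict

class WorkspaceSchemaRegistry:
    # Minimal sketch of the keyed storage pattern above.
    def __init__(self) -> None:
        self.schemas: Dict[str, Dict[str, Any]] = {}
        self.compiled: Dict[str, Any] = {}

    def on_config(self, workspace: str, schemas: Dict[str, Any],
                  compile_fn: Callable[[Dict[str, Any]], Any]) -> None:
        # Replace rather than merge, so schemas deleted from the config
        # don't linger from a previous version.
        self.schemas[workspace] = dict(schemas)
        self.compiled[workspace] = compile_fn(schemas) if schemas else None

    def get(self, workspace: str) -> Any:
        # Unknown workspaces simply have no compiled schema yet.
        return self.compiled.get(workspace)
```

Replacing rather than merging on update is what the diff's "Replace existing schemas for this workspace" comment implies; a merge would let stale schemas survive config changes.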
def get_index_names(self, schema: RowSchema) -> List[str]:
"""Get all index names for a schema."""
@ -222,7 +238,7 @@ class Processor(FlowProcessor):
async def query_cassandra(
self,
user: str,
workspace: str,
collection: str,
schema_name: str,
row_schema: RowSchema,
@ -240,7 +256,7 @@ class Processor(FlowProcessor):
# Connect if needed
self.connect_cassandra()
safe_keyspace = self.sanitize_name(user)
safe_keyspace = self.sanitize_name(workspace)
# Try to find an index that matches the filters
index_match = self.find_matching_index(row_schema, filters)
@ -389,26 +405,30 @@ class Processor(FlowProcessor):
async def execute_graphql_query(
self,
workspace: str,
query: str,
variables: Dict[str, Any],
operation_name: Optional[str],
user: str,
collection: str
) -> Dict[str, Any]:
"""Execute a GraphQL query"""
"""Execute a GraphQL query against the workspace's schema"""
if not self.graphql_schema:
raise RuntimeError("No GraphQL schema available - no schemas loaded")
graphql_schema = self.graphql_schemas.get(workspace)
if not graphql_schema:
raise RuntimeError(
f"No GraphQL schema available for workspace {workspace} "
f"- no schemas loaded"
)
# Create context for the query
context = {
"processor": self,
"user": user,
"workspace": workspace,
"collection": collection
}
# Execute the query
result = await self.graphql_schema.execute(
result = await graphql_schema.execute(
query,
variable_values=variables,
operation_name=operation_name,
@ -454,10 +474,10 @@ class Processor(FlowProcessor):
# Execute GraphQL query
result = await self.execute_graphql_query(
workspace=flow.workspace,
query=request.query,
variables=dict(request.variables) if request.variables else {},
operation_name=request.operation_name,
user=request.user,
collection=request.collection
)


@ -30,14 +30,14 @@ class EvaluationError(Exception):
pass
async def evaluate(node, triples_client, user, collection, limit=10000):
async def evaluate(node, triples_client, workspace, collection, limit=10000):
"""
Evaluate a SPARQL algebra node.
Args:
node: rdflib CompValue algebra node
triples_client: TriplesClient instance for triple pattern queries
user: user/keyspace identifier
workspace: workspace/keyspace identifier
collection: collection identifier
limit: safety limit on results
@ -55,24 +55,24 @@ async def evaluate(node, triples_client, user, collection, limit=10000):
logger.warning(f"Unsupported algebra node: {name}")
return [{}]
return await handler(node, triples_client, user, collection, limit)
return await handler(node, triples_client, workspace, collection, limit)
# --- Node handlers ---
async def _eval_select_query(node, tc, user, collection, limit):
async def _eval_select_query(node, tc, workspace, collection, limit):
"""Evaluate a SelectQuery node."""
return await evaluate(node.p, tc, user, collection, limit)
return await evaluate(node.p, tc, workspace, collection, limit)
async def _eval_project(node, tc, user, collection, limit):
async def _eval_project(node, tc, workspace, collection, limit):
"""Evaluate a Project node (SELECT variable projection)."""
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
variables = [str(v) for v in node.PV]
return project(solutions, variables)
async def _eval_bgp(node, tc, user, collection, limit):
async def _eval_bgp(node, tc, workspace, collection, limit):
"""
Evaluate a Basic Graph Pattern.
@ -107,7 +107,7 @@ async def _eval_bgp(node, tc, user, collection, limit):
# Query the triples store
results = await _query_pattern(
tc, s_val, p_val, o_val, user, collection, limit
tc, s_val, p_val, o_val, workspace, collection, limit
)
# Map results back to variable bindings,
@ -130,17 +130,17 @@ async def _eval_bgp(node, tc, user, collection, limit):
return solutions[:limit]
async def _eval_join(node, tc, user, collection, limit):
async def _eval_join(node, tc, workspace, collection, limit):
"""Evaluate a Join node."""
left = await evaluate(node.p1, tc, user, collection, limit)
right = await evaluate(node.p2, tc, user, collection, limit)
left = await evaluate(node.p1, tc, workspace, collection, limit)
right = await evaluate(node.p2, tc, workspace, collection, limit)
return hash_join(left, right)[:limit]
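The diff calls `hash_join` without showing it. A hedged sketch of the standard SPARQL-style hash join it presumably implements: bucket one side on the shared variables, probe with the other, merge compatible bindings. This assumes each side binds a consistent variable set, as BGP evaluation produces:

```python
def hash_join(left, right):
    # Join two lists of solution mappings (variable name -> value) on
    # the variables the two sides have in common. With no shared
    # variables the join degenerates to a cross product, as SPARQL
    # semantics require.
    if not left or not right:
        return []
    keys = sorted(set(left[0]) & set(right[0]))
    index = {}
    for sol in left:
        index.setdefault(tuple(sol.get(k) for k in keys), []).append(sol)
    out = []
    for sol in right:
        for match in index.get(tuple(sol.get(k) for k in keys), []):
            merged = dict(match)
            merged.update(sol)
            out.append(merged)
    return out
```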
async def _eval_left_join(node, tc, user, collection, limit):
async def _eval_left_join(node, tc, workspace, collection, limit):
"""Evaluate a LeftJoin node (OPTIONAL)."""
left_sols = await evaluate(node.p1, tc, user, collection, limit)
right_sols = await evaluate(node.p2, tc, user, collection, limit)
left_sols = await evaluate(node.p1, tc, workspace, collection, limit)
right_sols = await evaluate(node.p2, tc, workspace, collection, limit)
filter_fn = None
if hasattr(node, "expr") and node.expr is not None:
@ -153,16 +153,16 @@ async def _eval_left_join(node, tc, user, collection, limit):
return left_join(left_sols, right_sols, filter_fn)[:limit]
async def _eval_union(node, tc, user, collection, limit):
async def _eval_union(node, tc, workspace, collection, limit):
"""Evaluate a Union node."""
left = await evaluate(node.p1, tc, user, collection, limit)
right = await evaluate(node.p2, tc, user, collection, limit)
left = await evaluate(node.p1, tc, workspace, collection, limit)
right = await evaluate(node.p2, tc, workspace, collection, limit)
return union(left, right)[:limit]
async def _eval_filter(node, tc, user, collection, limit):
async def _eval_filter(node, tc, workspace, collection, limit):
"""Evaluate a Filter node."""
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
expr = node.expr
return [
sol for sol in solutions
@ -170,22 +170,22 @@ async def _eval_filter(node, tc, user, collection, limit):
]
async def _eval_distinct(node, tc, user, collection, limit):
async def _eval_distinct(node, tc, workspace, collection, limit):
"""Evaluate a Distinct node."""
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
return distinct(solutions)
async def _eval_reduced(node, tc, user, collection, limit):
async def _eval_reduced(node, tc, workspace, collection, limit):
"""Evaluate a Reduced node (like Distinct but implementation-defined)."""
# Treat same as Distinct
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
return distinct(solutions)
async def _eval_order_by(node, tc, user, collection, limit):
async def _eval_order_by(node, tc, workspace, collection, limit):
"""Evaluate an OrderBy node."""
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
key_fns = []
for cond in node.expr:
@ -206,7 +206,7 @@ async def _eval_order_by(node, tc, user, collection, limit):
return order_by(solutions, key_fns)
async def _eval_slice(node, tc, user, collection, limit):
async def _eval_slice(node, tc, workspace, collection, limit):
"""Evaluate a Slice node (LIMIT/OFFSET)."""
# Pass tighter limit downstream if possible
inner_limit = limit
@ -214,13 +214,13 @@ async def _eval_slice(node, tc, user, collection, limit):
offset = node.start or 0
inner_limit = min(limit, offset + node.length)
solutions = await evaluate(node.p, tc, user, collection, inner_limit)
solutions = await evaluate(node.p, tc, workspace, collection, inner_limit)
return slice_solutions(solutions, node.start or 0, node.length)
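The Slice optimisation above is worth isolating: when LIMIT/OFFSET are known, the subtree only ever needs `offset + length` solutions, so a tighter limit can be pushed downstream. A sketch with an illustrative function name:

```python
def effective_inner_limit(limit, start, length):
    # Compute the limit to push into the Slice node's subtree, as in
    # _eval_slice above. length is None when no LIMIT was given.
    if length is None:
        return limit
    offset = start or 0
    return min(limit, offset + length)
```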
async def _eval_extend(node, tc, user, collection, limit):
async def _eval_extend(node, tc, workspace, collection, limit):
"""Evaluate an Extend node (BIND)."""
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
var_name = str(node.var)
expr = node.expr
@ -246,9 +246,9 @@ async def _eval_extend(node, tc, user, collection, limit):
return result
async def _eval_group(node, tc, user, collection, limit):
async def _eval_group(node, tc, workspace, collection, limit):
"""Evaluate a Group node (GROUP BY with aggregation)."""
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
# Extract grouping expressions
group_exprs = []
@ -289,9 +289,9 @@ async def _eval_group(node, tc, user, collection, limit):
return result
async def _eval_aggregate_join(node, tc, user, collection, limit):
async def _eval_aggregate_join(node, tc, workspace, collection, limit):
"""Evaluate an AggregateJoin (aggregation functions after GROUP BY)."""
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
result = []
for sol in solutions:
@ -310,7 +310,7 @@ async def _eval_aggregate_join(node, tc, user, collection, limit):
return result
async def _eval_graph(node, tc, user, collection, limit):
async def _eval_graph(node, tc, workspace, collection, limit):
"""Evaluate a Graph node (GRAPH clause)."""
term = node.term
@ -319,16 +319,16 @@ async def _eval_graph(node, tc, user, collection, limit):
# We'd need to pass graph to triples queries
# For now, evaluate inner pattern normally
logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired")
return await evaluate(node.p, tc, user, collection, limit)
return await evaluate(node.p, tc, workspace, collection, limit)
elif isinstance(term, Variable):
# GRAPH ?g { ... } — variable graph
logger.info(f"GRAPH ?{term} clause - variable graph not yet wired")
return await evaluate(node.p, tc, user, collection, limit)
return await evaluate(node.p, tc, workspace, collection, limit)
else:
return await evaluate(node.p, tc, user, collection, limit)
return await evaluate(node.p, tc, workspace, collection, limit)
async def _eval_values(node, tc, user, collection, limit):
async def _eval_values(node, tc, workspace, collection, limit):
"""Evaluate a VALUES clause (inline data)."""
variables = [str(v) for v in node.var]
solutions = []
@ -343,9 +343,9 @@ async def _eval_values(node, tc, user, collection, limit):
return solutions
async def _eval_to_multiset(node, tc, user, collection, limit):
async def _eval_to_multiset(node, tc, workspace, collection, limit):
"""Evaluate a ToMultiSet node (subquery)."""
return await evaluate(node.p, tc, user, collection, limit)
return await evaluate(node.p, tc, workspace, collection, limit)
# --- Aggregate computation ---
@ -487,7 +487,7 @@ def _resolve_term(tmpl, solution):
return rdflib_term_to_term(tmpl)
async def _query_pattern(tc, s, p, o, user, collection, limit):
async def _query_pattern(tc, s, p, o, workspace, collection, limit):
"""
Issue a streaming triple pattern query via TriplesClient.
@ -496,7 +496,7 @@ async def _query_pattern(tc, s, p, o, user, collection, limit):
results = await tc.query(
s=s, p=p, o=o,
limit=limit,
user=user,
workspace=workspace,
collection=collection,
)
return results


@ -141,7 +141,7 @@ class Processor(FlowProcessor):
solutions = await evaluate(
parsed.algebra,
triples_client,
user=request.user or "trustgraph",
workspace=flow.workspace,
collection=request.collection or "default",
limit=request.limit or 10000,
)


@ -178,34 +178,34 @@ class Processor(TriplesQueryService):
self.cassandra_password = password
self.table = None
def ensure_connection(self, user):
def ensure_connection(self, workspace):
"""Ensure we have a connection to the correct keyspace."""
if user != self.table:
if workspace != self.table:
KGClass = EntityCentricKnowledgeGraph
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
keyspace=workspace,
)
self.table = user
self.table = workspace
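The `ensure_connection` pattern above reduces to a one-slot cache: keep one live client, rebuild it only when the requested workspace/keyspace changes. In this sketch `connect_fn` stands in for `EntityCentricKnowledgeGraph` construction; the class name is illustrative:

```python
class KeyspaceCache:
    # One-slot connection cache keyed on workspace, mirroring the
    # reconnect-on-change logic in ensure_connection.
    def __init__(self, connect_fn):
        self.connect_fn = connect_fn
        self.current = None
        self.client = None

    def ensure(self, workspace):
        # Only rebuild when switching workspaces; repeated queries for
        # the same workspace reuse the existing client.
        if workspace != self.current:
            self.client = self.connect_fn(workspace)
            self.current = workspace
        return self.client
```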
async def query_triples(self, query):
async def query_triples(self, workspace, query):
try:
# ensure_connection may construct a fresh
# EntityCentricKnowledgeGraph which does sync schema
# setup against Cassandra. Push it to a worker thread
# so the event loop doesn't block on first-use per user.
await asyncio.to_thread(self.ensure_connection, query.user)
# so the event loop doesn't block on first-use per workspace.
await asyncio.to_thread(self.ensure_connection, workspace)
# Extract values from query
s_val = get_term_value(query.s)
@ -359,13 +359,13 @@ class Processor(TriplesQueryService):
logger.error(f"Exception querying triples: {e}", exc_info=True)
raise e
async def query_triples_stream(self, query):
async def query_triples_stream(self, workspace, query):
"""
Streaming query - yields (batch, is_final) tuples.
Uses Cassandra's paging to fetch results incrementally.
"""
try:
await asyncio.to_thread(self.ensure_connection, query.user)
await asyncio.to_thread(self.ensure_connection, workspace)
batch_size = query.batch_size if query.batch_size > 0 else 20
limit = query.limit if query.limit > 0 else 10000
@ -395,7 +395,7 @@ class Processor(TriplesQueryService):
else:
# For specific patterns, fall back to non-streaming
# (these typically return small result sets anyway)
async for batch, is_final in self._fallback_stream(query, batch_size):
async for batch, is_final in self._fallback_stream(workspace, query, batch_size):
yield batch, is_final
return
@ -452,9 +452,9 @@ class Processor(TriplesQueryService):
logger.error(f"Exception in streaming query: {e}", exc_info=True)
raise e
async def _fallback_stream(self, query, batch_size):
async def _fallback_stream(self, workspace, query, batch_size):
"""Fallback to non-streaming query with post-hoc batching."""
triples = await self.query_triples(query)
triples = await self.query_triples(workspace, query)
for i in range(0, len(triples), batch_size):
batch = triples[i:i + batch_size]
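The post-hoc batching in `_fallback_stream` can be sketched as a standalone generator yielding `(batch, is_final)` pairs so consumers can tell when the stream has ended, matching the diff's behaviour of yielding nothing for an empty result set. The function name is illustrative:

```python
def batched(items, batch_size):
    # Slice a fully-materialised result list into batches, flagging the
    # last one; an empty input yields nothing, as in _fallback_stream.
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size], (i + batch_size) >= len(items)
```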


@ -58,7 +58,7 @@ class Processor(TriplesQueryService):
else:
return Term(type=LITERAL, value=ent)
async def query_triples(self, query):
async def query_triples(self, workspace, query):
try:


@ -63,12 +63,11 @@ class Processor(TriplesQueryService):
else:
return Term(type=LITERAL, value=ent)
async def query_triples(self, query):
async def query_triples(self, workspace, query):
try:
# Extract collection, use default if not provided
collection = query.collection if query.collection else "default"
triples = []
@ -80,13 +79,13 @@ class Processor(TriplesQueryService):
# SPO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN $src as src "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -94,13 +93,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN $src as src "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), uri=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -112,13 +111,13 @@ class Processor(TriplesQueryService):
# SP
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN dest.value as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -127,13 +126,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN dest.uri as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -148,13 +147,13 @@ class Processor(TriplesQueryService):
# SO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -163,13 +162,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), uri=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -182,13 +181,13 @@ class Processor(TriplesQueryService):
# S
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -197,13 +196,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -221,13 +220,13 @@ class Processor(TriplesQueryService):
# PO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p), value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -236,13 +235,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], get_term_value(query.p), get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p), dest=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -255,13 +254,13 @@ class Processor(TriplesQueryService):
# P
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -270,13 +269,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], get_term_value(query.p), data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.uri as dest "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -291,13 +290,13 @@ class Processor(TriplesQueryService):
# O
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -306,13 +305,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], data["rel"], get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
uri=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -325,12 +324,12 @@ class Processor(TriplesQueryService):
# *
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -339,12 +338,12 @@ class Processor(TriplesQueryService):
triples.append((data["src"], data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
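The Cypher fragments in this file all follow one shape: `workspace` and `collection` are bound as driver parameters rather than interpolated into the query text, and only the integer `LIMIT` is concatenated. A standalone sketch of the fully-bound SPO case (helper name hypothetical, not part of the change):

```python
# Hypothetical helper illustrating the workspace-scoped Cypher pattern
# used by the triples query service. Identity fields are bound as
# parameters ($workspace, $collection); only LIMIT, an int, is
# concatenated into the string.
def spo_query(limit: int) -> str:
    return (
        "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
        "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
        "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
        "RETURN $src as src "
        "LIMIT " + str(limit)
    )

query = spo_query(20)
```

Binding the tenancy fields as parameters keeps them out of the query text entirely, so a malformed workspace or collection value cannot alter the query structure.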

View file

@ -63,14 +63,12 @@ class Processor(TriplesQueryService):
else:
return Term(type=LITERAL, value=ent)
async def query_triples(self, query):
async def query_triples(self, workspace, query):
try:
# Extract user and collection, use defaults if not provided
user = query.user if query.user else "default"
collection = query.collection if query.collection else "default"
triples = []
if query.s is not None:
@ -80,13 +78,13 @@ class Processor(TriplesQueryService):
# SPO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN $src as src "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -94,13 +92,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN $src as src "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), uri=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -112,13 +110,13 @@ class Processor(TriplesQueryService):
# SP
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN dest.value as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -127,13 +125,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN dest.uri as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -148,13 +146,13 @@ class Processor(TriplesQueryService):
# SO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -163,13 +161,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), uri=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -182,13 +180,13 @@ class Processor(TriplesQueryService):
# S
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -197,13 +195,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -221,13 +219,13 @@ class Processor(TriplesQueryService):
# PO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p), value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -236,13 +234,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], get_term_value(query.p), get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p), dest=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -255,13 +253,13 @@ class Processor(TriplesQueryService):
# P
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -270,13 +268,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], get_term_value(query.p), data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.uri as dest "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -291,13 +289,13 @@ class Processor(TriplesQueryService):
# O
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -306,13 +304,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], data["rel"], get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
uri=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -325,12 +323,12 @@ class Processor(TriplesQueryService):
# *
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -339,12 +337,12 @@ class Processor(TriplesQueryService):
triples.append((data["src"], data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -367,7 +365,7 @@ class Processor(TriplesQueryService):
logger.error(f"Exception querying triples: {e}", exc_info=True)
raise e
@staticmethod
def add_args(parser):

View file

@ -26,11 +26,11 @@ LABEL="http://www.w3.org/2000/01/rdf-schema#label"
class Query:
def __init__(
self, rag, user, collection, verbose,
self, rag, workspace, collection, verbose,
doc_limit=20, track_usage=None,
):
self.rag = rag
self.user = user
self.workspace = workspace
self.collection = collection
self.verbose = verbose
self.doc_limit = doc_limit
@ -97,7 +97,7 @@ class Query:
async def query_concept(vec):
return await self.rag.doc_embeddings_client.query(
vector=vec, limit=per_concept_limit,
user=self.user, collection=self.collection,
collection=self.collection,
)
results = await asyncio.gather(
@ -122,7 +122,7 @@ class Query:
for match in chunk_matches:
if match.chunk_id:
try:
content = await self.rag.fetch_chunk(match.chunk_id, self.user)
content = await self.rag.fetch_chunk(match.chunk_id, self.workspace)
docs.append(content)
chunk_ids.append(match.chunk_id)
except Exception as e:
@ -154,7 +154,7 @@ class DocumentRag:
logger.debug("DocumentRag initialized")
async def query(
self, query, user="trustgraph", collection="default",
self, query, workspace="default", collection="default",
doc_limit=20, streaming=False, chunk_callback=None,
explain_callback=None, save_answer_callback=None,
):
@ -163,7 +163,7 @@ class DocumentRag:
Args:
query: The query string
user: User identifier
workspace: Workspace for isolation (also scopes chunk lookup)
collection: Collection identifier
doc_limit: Max chunks to retrieve
streaming: Enable streaming LLM response
@ -210,7 +210,8 @@ class DocumentRag:
await explain_callback(q_triples, q_uri)
q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose,
rag=self, workspace=workspace, collection=collection,
verbose=self.verbose,
doc_limit=doc_limit, track_usage=track_usage,
)
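Callers of `DocumentRag.query` migrate from the removed `user` kwarg to `workspace`. A minimal stub (hypothetical; it mirrors only the new signature, not the retrieval logic) showing the new call shape:

```python
import asyncio

# Hypothetical stub mirroring the new DocumentRag.query signature:
# workspace replaces the old user kwarg, defaulting to "default".
class DocumentRagStub:
    async def query(self, query, workspace="default", collection="default",
                    doc_limit=20):
        # Real implementation retrieves chunks scoped by workspace;
        # here we just echo the resolved arguments.
        return {"query": query, "workspace": workspace,
                "collection": collection, "doc_limit": doc_limit}

result = asyncio.run(
    DocumentRagStub().query("what is X?", workspace="acme", doc_limit=5)
)
```

In the service itself the value is never caller-supplied: the processor passes `flow.workspace`, set by the infrastructure at flow start.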

View file

@ -96,19 +96,19 @@ class Processor(FlowProcessor):
await super(Processor, self).start()
await self.librarian.start()
async def fetch_chunk_content(self, chunk_id, user, timeout=120):
async def fetch_chunk_content(self, chunk_id, workspace, timeout=120):
"""Fetch chunk content from librarian. Chunks are small so
single request-response is fine."""
return await self.librarian.fetch_document_text(
document_id=chunk_id, user=user, timeout=timeout,
document_id=chunk_id, workspace=workspace, timeout=timeout,
)
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120):
"""Save answer content to the librarian."""
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
workspace=workspace,
kind="text/plain",
title=title or "DocumentRAG Answer",
document_type="answer",
@ -119,7 +119,7 @@ class Processor(FlowProcessor):
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
user=user,
workspace=workspace,
)
await self.librarian.request(request, timeout=timeout)
@ -150,14 +150,13 @@ class Processor(FlowProcessor):
doc_limit = self.doc_limit
# Real-time explainability callback - emits triples and IDs as they're generated
# Triples are stored in the user's collection with a named graph (urn:graph:retrieval)
# Triples are stored in the request's collection with a named graph (urn:graph:retrieval)
async def send_explainability(triples, explain_id):
# Send triples to explainability queue - stores in same collection with named graph
await flow("explainability").send(Triples(
metadata=Metadata(
id=explain_id,
user=v.user,
collection=v.collection, # Store in user's collection
collection=v.collection,
),
triples=triples,
))
@ -178,7 +177,7 @@ class Processor(FlowProcessor):
async def save_answer(doc_id, answer_text):
await self.save_answer_content(
doc_id=doc_id,
user=v.user,
workspace=flow.workspace,
content=answer_text,
title=f"DocumentRAG Answer: {v.query[:50]}...",
)
@ -202,7 +201,7 @@ class Processor(FlowProcessor):
# All chunks (including final one with end_of_stream=True) are sent via callback
response, usage = await self.rag.query(
v.query,
user=v.user,
workspace=flow.workspace,
collection=v.collection,
doc_limit=doc_limit,
streaming=True,
@ -227,7 +226,7 @@ class Processor(FlowProcessor):
# Non-streaming path - single response with answer and token usage
response, usage = await self.rag.query(
v.query,
user=v.user,
workspace=flow.workspace,
collection=v.collection,
doc_limit=doc_limit,
explain_callback=send_explainability,

View file

@ -75,12 +75,11 @@ def edge_id(s, p, o):
class LRUCacheWithTTL:
"""LRU cache with TTL for label caching
"""LRU cache with TTL for label caching.
CRITICAL SECURITY WARNING:
This cache is shared within a GraphRag instance but GraphRag instances
are created per-request. Cache keys MUST include user:collection prefix
to ensure data isolation between different security contexts.
GraphRag instances are created per-request, so this cache is
request-scoped. Cache keys include the collection prefix to keep
entries from different collections distinct within one request.
"""
def __init__(self, max_size=5000, ttl=300):
@ -119,12 +118,11 @@ class LRUCacheWithTTL:
class Query:
def __init__(
self, rag, user, collection, verbose,
self, rag, collection, verbose,
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
max_path_length=2, track_usage=None,
):
self.rag = rag
self.user = user
self.collection = collection
self.verbose = verbose
self.entity_limit = entity_limit
@ -194,7 +192,7 @@ class Query:
entity_tasks = [
self.rag.graph_embeddings_client.query(
vector=v, limit=per_concept_limit,
user=self.user, collection=self.collection,
collection=self.collection,
)
for v in vectors
]
@ -222,18 +220,18 @@ class Query:
async def maybe_label(self, e):
# CRITICAL SECURITY: Cache key MUST include user and collection
# to prevent data leakage between different contexts
cache_key = f"{self.user}:{self.collection}:{e}"
# The label cache lives on a per-request GraphRag instance — no
# cross-request isolation concern. The collection prefix keeps
# entries from different collections distinct within one request.
cache_key = f"{self.collection}:{e}"
# Check LRU cache first
cached_label = self.rag.label_cache.get(cache_key)
if cached_label is not None:
return cached_label
res = await self.rag.triples_client.query(
s=e, p=LABEL, o=None, limit=1,
user=self.user, collection=self.collection,
collection=self.collection,
g="",
)
@ -255,19 +253,19 @@ class Query:
self.rag.triples_client.query_stream(
s=entity, p=None, o=None,
limit=limit_per_entity,
user=self.user, collection=self.collection,
collection=self.collection,
batch_size=20, g="",
),
self.rag.triples_client.query_stream(
s=None, p=entity, o=None,
limit=limit_per_entity,
user=self.user, collection=self.collection,
collection=self.collection,
batch_size=20, g="",
),
self.rag.triples_client.query_stream(
s=None, p=None, o=entity,
limit=limit_per_entity,
user=self.user, collection=self.collection,
collection=self.collection,
batch_size=20, g="",
)
])
@ -468,7 +466,7 @@ class Query:
subgraph_tasks.append(
self.rag.triples_client.query(
s=None, p=TG_CONTAINS, o=quoted, limit=1,
user=self.user, collection=self.collection,
collection=self.collection,
g=GRAPH_SOURCE,
)
)
@ -501,7 +499,7 @@ class Query:
derivation_tasks = [
self.rag.triples_client.query(
s=uri, p=PROV_WAS_DERIVED_FROM, o=None, limit=5,
user=self.user, collection=self.collection,
collection=self.collection,
g=GRAPH_SOURCE,
)
for uri in current_uris
@ -535,7 +533,7 @@ class Query:
metadata_tasks = [
self.rag.triples_client.query(
s=uri, p=None, o=None, limit=50,
user=self.user, collection=self.collection,
collection=self.collection,
)
for uri in doc_uris
]
@ -560,11 +558,9 @@ class Query:
class GraphRag:
"""
CRITICAL SECURITY:
This class MUST be instantiated per-request to ensure proper isolation
between users and collections. The cache within this instance will only
live for the duration of a single request, preventing cross-contamination
of data between different security contexts.
Must be instantiated per-request so the label cache lives only for
the duration of a single request. Workspace isolation is enforced
by the trusted flow layer (flow.workspace), not by this class.
"""
def __init__(
@ -587,7 +583,7 @@ class GraphRag:
logger.debug("GraphRag initialized")
async def query(
self, query, user = "trustgraph", collection = "default",
self, query, collection = "default",
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
max_path_length = 2, edge_score_limit = 30, edge_limit = 25,
streaming = False,
@ -600,7 +596,6 @@ class GraphRag:
Args:
query: The query string
user: User identifier
collection: Collection identifier
entity_limit: Max entities to retrieve
triple_limit: Max triples per entity
@ -657,7 +652,7 @@ class GraphRag:
await explain_callback(q_triples, q_uri)
q = Query(
rag = self, user = user, collection = collection,
rag = self, collection = collection,
verbose = self.verbose, entity_limit = entity_limit,
triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
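The request-scoped label cache described in the docstring above can be sketched in isolation. The `max_size`/`ttl` defaults match the constructor signature shown; the LRU-eviction and per-entry-expiry semantics are assumed for illustration:

```python
import time
from collections import OrderedDict

# Illustrative sketch of a request-scoped LRU cache with TTL, as used
# for label caching. Assumed semantics: per-entry expiry after ttl
# seconds, least-recently-used eviction once max_size is exceeded.
class LRUCacheWithTTL:
    def __init__(self, max_size=5000, ttl=300):
        self.max_size = max_size
        self.ttl = ttl
        self._data = OrderedDict()  # key -> (value, expiry timestamp)

    def get(self, key):
        item = self._data.get(key)
        if item is None:
            return None
        value, expiry = item
        if time.monotonic() > expiry:
            del self._data[key]  # expired: drop and miss
            return None
        self._data.move_to_end(key)  # mark as most recently used
        return value

    def put(self, key, value):
        self._data[key] = (value, time.monotonic() + self.ttl)
        self._data.move_to_end(key)
        if len(self._data) > self.max_size:
            self._data.popitem(last=False)  # evict least recently used

cache = LRUCacheWithTTL(max_size=2, ttl=60)
cache.put("default:e1", "Label One")
cache.put("default:e2", "Label Two")
cache.put("default:e3", "Label Three")  # evicts default:e1
```

Because the owning `GraphRag` instance is created per request, keys only need the collection prefix (`f"{collection}:{entity}"`) to stay distinct within that one request.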

View file

@ -62,9 +62,9 @@ class Processor(FlowProcessor):
self.default_edge_score_limit = edge_score_limit
self.default_edge_limit = edge_limit
# CRITICAL SECURITY: NEVER share data between users or collections
# Each user/collection combination MUST have isolated data access
# Caching must NEVER allow information leakage across these boundaries
# Workspace isolation is enforced by the flow layer (flow.workspace).
# Per-request caching (see GraphRag) keeps within-request state
# scoped; no cross-request sharing here.
self.register_specification(
ConsumerSpec(
@ -170,13 +170,13 @@ class Processor(FlowProcessor):
future = self.pending_librarian_requests.pop(request_id)
future.set_result(response)
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120):
"""
Save answer content to the librarian.
Args:
doc_id: ID for the answer document
user: User ID
workspace: Workspace for isolation
content: Answer text content
title: Optional title
timeout: Request timeout in seconds
@ -188,7 +188,7 @@ class Processor(FlowProcessor):
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
workspace=workspace,
kind="text/plain",
title=title or "GraphRAG Answer",
document_type="answer",
@ -199,7 +199,7 @@ class Processor(FlowProcessor):
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
user=user,
workspace=workspace,
)
# Create future for response
@ -241,14 +241,13 @@ class Processor(FlowProcessor):
explainability_refs_emitted = []
# Real-time explainability callback - emits triples and IDs as they're generated
# Triples are stored in the user's collection with a named graph (urn:graph:retrieval)
# Triples are stored in the request's collection with a named graph (urn:graph:retrieval)
async def send_explainability(triples, explain_id):
# Send triples to explainability queue - stores in same collection with named graph
await flow("explainability").send(Triples(
metadata=Metadata(
id=explain_id,
user=v.user,
collection=v.collection, # Store in user's collection, not separate explainability collection
collection=v.collection,
),
triples=triples,
))
@ -266,9 +265,9 @@ class Processor(FlowProcessor):
explainability_refs_emitted.append(explain_id)
# CRITICAL SECURITY: Create new GraphRag instance per request
# This ensures proper isolation between users and collections
# Flow clients are request-scoped and must not be shared
# Create new GraphRag instance per request — its label cache
# is request-scoped, and flow clients must not be shared
# across requests.
rag = GraphRag(
embeddings_client=flow("embeddings-request"),
graph_embeddings_client=flow("graph-embeddings-request"),
@ -311,7 +310,7 @@ class Processor(FlowProcessor):
async def save_answer(doc_id, answer_text):
await self.save_answer_content(
doc_id=doc_id,
user=v.user,
workspace=flow.workspace,
content=answer_text,
title=f"GraphRAG Answer: {v.query[:50]}...",
)
@ -333,7 +332,7 @@ class Processor(FlowProcessor):
# Query with streaming and real-time explain
response, usage = await rag.query(
query = v.query, user = v.user, collection = v.collection,
query = v.query, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
@ -349,7 +348,7 @@ class Processor(FlowProcessor):
else:
# Non-streaming path with real-time explain
response, usage = await rag.query(
query = v.query, user = v.user, collection = v.collection,
query = v.query, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
@ -464,7 +463,7 @@ class Processor(FlowProcessor):
help=f'Max edges after LLM scoring (default: 25)'
)
# Note: Explainability triples are now stored in the user's collection
# Note: Explainability triples are now stored in the request's collection
# with the named graph urn:graph:retrieval (no separate collection needed)
def run():

View file

@ -66,32 +66,39 @@ class Processor(FlowProcessor):
# Register config handler for schema updates
self.register_config_handler(self.on_schema_config, types=["schema"])
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# Per-workspace schema storage: {workspace: {name: RowSchema}}
self.schemas: Dict[str, Dict[str, RowSchema]] = {}
logger.info("NLP Query service initialized")
async def on_schema_config(self, config, version):
async def on_schema_config(self, workspace, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
# Clear existing schemas
self.schemas = {}
logger.info(
f"Loading schema configuration version {version} "
f"for workspace {workspace}"
)
# Replace existing schemas for this workspace
ws_schemas: Dict[str, RowSchema] = {}
self.schemas[workspace] = ws_schemas
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
logger.warning(
f"No '{self.config_key}' type in configuration "
f"for {workspace}"
)
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
@ -106,29 +113,37 @@ class Processor(FlowProcessor):
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
ws_schemas[schema_name] = row_schema
logger.info(
f"Loaded schema: {schema_name} with "
f"{len(fields)} fields for {workspace}"
)
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
logger.info(
f"Schema configuration loaded for {workspace}: "
f"{len(ws_schemas)} schemas"
)
async def phase1_select_schemas(self, question: str, flow) -> List[str]:
"""Phase 1: Use prompt service to select relevant schemas for the question"""
logger.info("Starting Phase 1: Schema selection")
ws_schemas = self.schemas.get(flow.workspace, {})
# Prepare schema information for the prompt
schema_info = []
for name, schema in self.schemas.items():
for name, schema in ws_schemas.items():
schema_desc = {
"name": name,
"description": schema.description,
@ -176,12 +191,14 @@ class Processor(FlowProcessor):
async def phase2_generate_graphql(self, question: str, selected_schemas: List[str], flow) -> Dict[str, Any]:
"""Phase 2: Generate GraphQL query using selected schemas"""
logger.info(f"Starting Phase 2: GraphQL generation for schemas: {selected_schemas}")
ws_schemas = self.schemas.get(flow.workspace, {})
# Get detailed schema information for selected schemas only
selected_schema_info = []
for schema_name in selected_schemas:
if schema_name in self.schemas:
schema = self.schemas[schema_name]
if schema_name in ws_schemas:
schema = ws_schemas[schema_name]
schema_desc = {
"name": schema_name,
"description": schema.description,

View file

@ -72,21 +72,28 @@ class Processor(FlowProcessor):
# Register config handler for schema updates
self.register_config_handler(self.on_schema_config, types=["schema"])
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# Per-workspace schema storage: {workspace: {name: RowSchema}}
self.schemas: Dict[str, Dict[str, RowSchema]] = {}
logger.info("Structured Data Diagnosis service initialized")
async def on_schema_config(self, config, version):
async def on_schema_config(self, workspace, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
logger.info(
f"Loading schema configuration version {version} "
f"for workspace {workspace}"
)
# Clear existing schemas
self.schemas = {}
# Replace existing schemas for this workspace
ws_schemas: Dict[str, RowSchema] = {}
self.schemas[workspace] = ws_schemas
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
logger.warning(
f"No '{self.config_key}' type in configuration "
f"for {workspace}"
)
return
# Get the schemas dictionary for our type
@ -120,13 +127,19 @@ class Processor(FlowProcessor):
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
ws_schemas[schema_name] = row_schema
logger.info(
f"Loaded schema: {schema_name} with "
f"{len(fields)} fields for {workspace}"
)
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
logger.info(
f"Schema configuration loaded for {workspace}: "
f"{len(ws_schemas)} schemas"
)
async def on_message(self, msg, consumer, flow):
"""Handle incoming structured data diagnosis request"""
@ -216,15 +229,19 @@ class Processor(FlowProcessor):
)
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
# Get target schema
if request.schema_name not in self.schemas:
# Get target schema from this workspace's schemas
ws_schemas = self.schemas.get(flow.workspace, {})
if request.schema_name not in ws_schemas:
error = Error(
type="SchemaNotFound",
message=f"Schema '{request.schema_name}' not found in configuration"
message=(
f"Schema '{request.schema_name}' not found "
f"in configuration for workspace {flow.workspace}"
)
)
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
target_schema = self.schemas[request.schema_name]
target_schema = ws_schemas[request.schema_name]
# Generate descriptor using prompt service
descriptor = await self.generate_descriptor_with_prompt(
@ -260,26 +277,33 @@ class Processor(FlowProcessor):
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
# Step 2: Use provided schema name or auto-select first available
ws_schemas = self.schemas.get(flow.workspace, {})
schema_name = request.schema_name
if not schema_name and self.schemas:
schema_name = list(self.schemas.keys())[0]
if not schema_name and ws_schemas:
schema_name = list(ws_schemas.keys())[0]
logger.info(f"Auto-selected schema: {schema_name}")
if not schema_name:
error = Error(
type="NoSchemaAvailable",
message="No schema specified and no schemas available in configuration"
message=(
f"No schema specified and no schemas available "
f"in configuration for workspace {flow.workspace}"
)
)
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
if schema_name not in self.schemas:
if schema_name not in ws_schemas:
error = Error(
type="SchemaNotFound",
message=f"Schema '{schema_name}' not found in configuration"
message=(
f"Schema '{schema_name}' not found in "
f"configuration for workspace {flow.workspace}"
)
)
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
target_schema = self.schemas[schema_name]
target_schema = ws_schemas[schema_name]
# Step 3: Generate descriptor
descriptor = await self.generate_descriptor_with_prompt(
@ -316,8 +340,9 @@ class Processor(FlowProcessor):
logger.info("Processing schema-selection operation")
# Prepare all schemas for the prompt - match the original config format
ws_schemas = self.schemas.get(flow.workspace, {})
all_schemas = []
for schema_name, row_schema in self.schemas.items():
for schema_name, row_schema in ws_schemas.items():
schema_info = {
"name": row_schema.name,
"description": row_schema.description,


@ -111,9 +111,9 @@ class Processor(FlowProcessor):
else:
variables_as_strings[key] = str(value)
# Use user/collection values from request
# Use collection from request. Workspace isolation is
# enforced by flow.workspace at the rows-query service.
objects_request = RowsQueryRequest(
user=request.user,
collection=request.collection,
query=nlp_response.graphql_query,
variables=variables_as_strings,


@ -33,7 +33,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
# Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"])
async def store_document_embeddings(self, message):
async def store_document_embeddings(self, workspace, message):
for emb in message.chunks:
@ -45,7 +45,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
if vec:
self.vecstore.insert(
vec, chunk_id,
message.metadata.user,
workspace,
message.metadata.collection
)
@ -60,27 +60,27 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
help=f'Milvus store URI (default: {default_store_uri})'
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
Create collection via config push - collections are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write")
self.vecstore.create_collection(user, collection)
logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
self.vecstore.create_collection(workspace, collection)
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for document embeddings via config push"""
try:
self.vecstore.delete_collection(user, collection)
logger.info(f"Successfully deleted collection {user}/{collection}")
self.vecstore.delete_collection(workspace, collection)
logger.info(f"Successfully deleted collection {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():


@ -88,12 +88,12 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
"Gave up waiting for index creation"
)
async def store_document_embeddings(self, message):
async def store_document_embeddings(self, workspace, message):
# Validate collection exists in config before processing
if not self.collection_exists(message.metadata.user, message.metadata.collection):
if not self.collection_exists(workspace, message.metadata.collection):
logger.warning(
f"Collection {message.metadata.collection} for user {message.metadata.user} "
f"Collection {message.metadata.collection} for workspace {workspace} "
f"does not exist in config (likely deleted while data was in-flight). "
f"Dropping message."
)
@ -112,7 +112,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
# Create index name with dimension suffix for lazy creation
dim = len(vec)
index_name = (
f"d-{message.metadata.user}-{message.metadata.collection}-{dim}"
f"d-{workspace}-{message.metadata.collection}-{dim}"
)
# Lazily create index if it doesn't exist (but only if authorized in config)
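The dimension-suffixed index name above is what makes lazy creation work: the index cannot be created until the first vector arrives, because only then is the embedding dimension known. A hedged sketch of the pattern (the `LazyIndexes` helper and its in-memory `existing` set are illustrative; the real processor calls the vector store's create API):

```python
def index_name(workspace: str, collection: str, dim: int) -> str:
    # Dimension suffix lets embeddings of different sizes coexist
    # under one logical workspace/collection
    return f"d-{workspace}-{collection}-{dim}"


class LazyIndexes:
    def __init__(self):
        self.existing: set[str] = set()

    def ensure(self, workspace: str, collection: str, vec: list[float]) -> str:
        name = index_name(workspace, collection, len(vec))
        if name not in self.existing:
            # Real code would issue the store's create-index call here,
            # gated on the collection being authorized in config
            self.existing.add(name)
        return name


idx = LazyIndexes()
a = idx.ensure("ws1", "docs", [0.1] * 384)
b = idx.ensure("ws1", "docs", [0.1] * 768)
```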
@ -165,22 +165,22 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
help=f'Pinecone region (default: {default_region})'
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
Create collection via config push - indexes are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write")
logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for document embeddings via config push"""
try:
prefix = f"d-{user}-{collection}-"
prefix = f"d-{workspace}-{collection}-"
# Get all indexes and filter for matches
all_indexes = self.pinecone.list_indexes()
@ -195,10 +195,10 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
for index_name in matching_indexes:
self.pinecone.delete_index(index_name)
logger.info(f"Deleted Pinecone index: {index_name}")
logger.info(f"Deleted {len(matching_indexes)} index(es) for {user}/{collection}")
logger.info(f"Deleted {len(matching_indexes)} index(es) for {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():
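Because one logical collection fans out into several dimension-suffixed physical indexes, deletion works by prefix match rather than exact name. A small sketch of the selection step (function name is illustrative):

```python
def matching_for_deletion(
    all_indexes: list[str], workspace: str, collection: str
) -> list[str]:
    # One logical collection maps to several physical indexes
    # (one per embedding dimension); select them all by shared prefix.
    prefix = f"d-{workspace}-{collection}-"
    return [n for n in all_indexes if n.startswith(prefix)]


names = [
    "d-ws1-docs-384",
    "d-ws1-docs-768",
    "d-ws1-other-384",
    "d-ws2-docs-384",
]
hits = matching_for_deletion(names, "ws1", "docs")
```

The trailing `-` in the prefix matters: without it, deleting `ws1/doc` would also match `ws1/docs`.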


@ -39,12 +39,12 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
# Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"])
async def store_document_embeddings(self, message):
async def store_document_embeddings(self, workspace, message):
# Validate collection exists in config before processing
if not self.collection_exists(message.metadata.user, message.metadata.collection):
if not self.collection_exists(workspace, message.metadata.collection):
logger.warning(
f"Collection {message.metadata.collection} for user {message.metadata.user} "
f"Collection {message.metadata.collection} for workspace {workspace} "
f"does not exist in config (likely deleted while data was in-flight). "
f"Dropping message."
)
@ -63,7 +63,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"d_{message.metadata.user}_{message.metadata.collection}_{dim}"
f"d_{workspace}_{message.metadata.collection}_{dim}"
)
# Lazily create collection if it doesn't exist (but only if authorized in config)
@ -107,22 +107,22 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
help=f'Qdrant API key (default: None)'
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
Create collection via config push - collections are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write")
logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for document embeddings via config push"""
try:
prefix = f"d_{user}_{collection}_"
prefix = f"d_{workspace}_{collection}_"
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
@ -137,10 +137,10 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():


@ -47,7 +47,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
# Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"])
async def store_graph_embeddings(self, message):
async def store_graph_embeddings(self, workspace, message):
for entity in message.entities:
entity_value = get_term_value(entity.entity)
@ -57,7 +57,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
if vec:
self.vecstore.insert(
vec, entity_value,
message.metadata.user,
workspace,
message.metadata.collection,
chunk_id=entity.chunk_id or "",
)
@ -73,27 +73,27 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
help=f'Milvus store URI (default: {default_store_uri})'
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
Create collection via config push - collections are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write")
self.vecstore.create_collection(user, collection)
logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
self.vecstore.create_collection(workspace, collection)
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for graph embeddings via config push"""
try:
self.vecstore.delete_collection(user, collection)
logger.info(f"Successfully deleted collection {user}/{collection}")
self.vecstore.delete_collection(workspace, collection)
logger.info(f"Successfully deleted collection {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():


@ -102,12 +102,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
"Gave up waiting for index creation"
)
async def store_graph_embeddings(self, message):
async def store_graph_embeddings(self, workspace, message):
# Validate collection exists in config before processing
if not self.collection_exists(message.metadata.user, message.metadata.collection):
if not self.collection_exists(workspace, message.metadata.collection):
logger.warning(
f"Collection {message.metadata.collection} for user {message.metadata.user} "
f"Collection {message.metadata.collection} for workspace {workspace} "
f"does not exist in config (likely deleted while data was in-flight). "
f"Dropping message."
)
@ -126,7 +126,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
# Create index name with dimension suffix for lazy creation
dim = len(vec)
index_name = (
f"t-{message.metadata.user}-{message.metadata.collection}-{dim}"
f"t-{workspace}-{message.metadata.collection}-{dim}"
)
# Lazily create index if it doesn't exist (but only if authorized in config)
@ -183,22 +183,22 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
help=f'Pinecone region (default: {default_region})'
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
Create collection via config push - indexes are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write")
logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for graph embeddings via config push"""
try:
prefix = f"t-{user}-{collection}-"
prefix = f"t-{workspace}-{collection}-"
# Get all indexes and filter for matches
all_indexes = self.pinecone.list_indexes()
@ -213,10 +213,10 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
for index_name in matching_indexes:
self.pinecone.delete_index(index_name)
logger.info(f"Deleted Pinecone index: {index_name}")
logger.info(f"Deleted {len(matching_indexes)} index(es) for {user}/{collection}")
logger.info(f"Deleted {len(matching_indexes)} index(es) for {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():


@ -54,12 +54,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
# Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"])
async def store_graph_embeddings(self, message):
async def store_graph_embeddings(self, workspace, message):
# Validate collection exists in config before processing
if not self.collection_exists(message.metadata.user, message.metadata.collection):
if not self.collection_exists(workspace, message.metadata.collection):
logger.warning(
f"Collection {message.metadata.collection} for user {message.metadata.user} "
f"Collection {message.metadata.collection} for workspace {workspace} "
f"does not exist in config (likely deleted while data was in-flight). "
f"Dropping message."
)
@ -78,7 +78,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"t_{message.metadata.user}_{message.metadata.collection}_{dim}"
f"t_{workspace}_{message.metadata.collection}_{dim}"
)
# Lazily create collection if it doesn't exist (but only if authorized in config)
@ -126,22 +126,22 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
help=f'Qdrant API key'
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
Create collection via config push - collections are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write")
logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for graph embeddings via config push"""
try:
prefix = f"t_{user}_{collection}_"
prefix = f"t_{workspace}_{collection}_"
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
@ -156,10 +156,10 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():


@ -65,13 +65,13 @@ class Processor(FlowProcessor):
v = msg.value()
if v.triples:
await self.table_store.add_triples(v)
await self.table_store.add_triples(flow.workspace, v)
async def on_graph_embeddings(self, msg, consumer, flow):
v = msg.value()
if v.entities:
await self.table_store.add_graph_embeddings(v)
await self.table_store.add_graph_embeddings(flow.workspace, v)
@staticmethod
def add_args(parser):


@ -2,13 +2,13 @@
Row embeddings writer for Qdrant (Stage 2).
Consumes RowEmbeddings messages (which already contain computed vectors)
and writes them to Qdrant. One Qdrant collection per (user, collection, schema_name) pair.
and writes them to Qdrant. One Qdrant collection per (workspace, collection, schema_name) triple.
This follows the two-stage pattern used by graph-embeddings and document-embeddings:
Stage 1 (row-embeddings): Compute embeddings
Stage 2 (this processor): Store embeddings
Collection naming: rows_{user}_{collection}_{schema_name}_{dimension}
Collection naming: rows_{workspace}_{collection}_{schema_name}_{dimension}
Payload structure:
- index_name: The indexed field(s) this embedding represents
@ -77,10 +77,10 @@ class Processor(CollectionConfigHandler, FlowProcessor):
return safe_name.lower()
def get_collection_name(
self, user: str, collection: str, schema_name: str, dimension: int
self, workspace: str, collection: str, schema_name: str, dimension: int
) -> str:
"""Generate Qdrant collection name"""
safe_user = self.sanitize_name(user)
safe_user = self.sanitize_name(workspace)
safe_collection = self.sanitize_name(collection)
safe_schema = self.sanitize_name(schema_name)
return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}"
@ -114,18 +114,19 @@ class Processor(CollectionConfigHandler, FlowProcessor):
f"{embeddings.schema_name} from {embeddings.metadata.id}"
)
workspace = flow.workspace
# Validate collection exists in config before processing
if not self.collection_exists(
embeddings.metadata.user, embeddings.metadata.collection
workspace, embeddings.metadata.collection
):
logger.warning(
f"Collection {embeddings.metadata.collection} for user "
f"{embeddings.metadata.user} does not exist in config. "
f"Collection {embeddings.metadata.collection} for workspace "
f"{workspace} does not exist in config. "
f"Dropping message."
)
return
user = embeddings.metadata.user
collection = embeddings.metadata.collection
schema_name = embeddings.schema_name
@ -145,7 +146,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
# Create/get collection name (lazily on first vector)
if qdrant_collection is None:
qdrant_collection = self.get_collection_name(
user, collection, schema_name, dimension
workspace, collection, schema_name, dimension
)
self.ensure_collection(qdrant_collection, dimension)
@ -168,17 +169,17 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Collection creation via config push - collections created lazily on first write"""
logger.info(
f"Row embeddings collection create request for {user}/{collection} - "
f"Row embeddings collection create request for {workspace}/{collection} - "
f"will be created lazily on first write"
)
async def delete_collection(self, user: str, collection: str):
"""Delete all Qdrant collections for a given user/collection"""
async def delete_collection(self, workspace: str, collection: str):
"""Delete all Qdrant collections for a given workspace/collection"""
try:
prefix = f"rows_{self.sanitize_name(user)}_{self.sanitize_name(collection)}_"
prefix = f"rows_{self.sanitize_name(workspace)}_{self.sanitize_name(collection)}_"
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
@ -196,23 +197,23 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(
f"Deleted {len(matching_collections)} collection(s) "
f"for {user}/{collection}"
f"for {workspace}/{collection}"
)
except Exception as e:
logger.error(
f"Failed to delete collection {user}/{collection}: {e}",
f"Failed to delete collection {workspace}/{collection}: {e}",
exc_info=True
)
raise
async def delete_collection_schema(
self, user: str, collection: str, schema_name: str
self, workspace: str, collection: str, schema_name: str
):
"""Delete Qdrant collection for a specific user/collection/schema"""
"""Delete Qdrant collection for a specific workspace/collection/schema"""
try:
prefix = (
f"rows_{self.sanitize_name(user)}_"
f"rows_{self.sanitize_name(workspace)}_"
f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
)
@ -233,7 +234,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
except Exception as e:
logger.error(
f"Failed to delete collection {user}/{collection}/{schema_name}: {e}",
f"Failed to delete collection {workspace}/{collection}/{schema_name}: {e}",
exc_info=True
)
raise


@ -119,19 +119,27 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
raise
async def on_schema_config(self, config, version):
async def on_schema_config(self, workspace, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
logger.info(
f"Loading schema configuration version {version} "
f"for workspace {workspace}"
)
# Track which schemas changed so we can clear partition cache
old_schema_names = set(self.schemas.keys())
# Track which schemas changed in this workspace
old_schemas = self.schemas.get(workspace, {})
old_schema_names = set(old_schemas.keys())
# Clear existing schemas
self.schemas = {}
# Replace existing schemas for this workspace
ws_schemas: Dict[str, RowSchema] = {}
self.schemas[workspace] = ws_schemas
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
logger.warning(
f"No '{self.config_key}' type in configuration "
f"for {workspace}"
)
return
# Get the schemas dictionary for our type
@ -165,24 +173,32 @@ class Processor(CollectionConfigHandler, FlowProcessor):
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
ws_schemas[schema_name] = row_schema
logger.info(
f"Loaded schema: {schema_name} with "
f"{len(fields)} fields for {workspace}"
)
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
logger.info(
f"Schema configuration loaded for {workspace}: "
f"{len(ws_schemas)} schemas"
)
# Clear partition cache for schemas that changed
# This ensures next write will re-register partitions
new_schema_names = set(self.schemas.keys())
# Clear partition cache for schemas that changed in this workspace
new_schema_names = set(ws_schemas.keys())
changed_schemas = old_schema_names.symmetric_difference(new_schema_names)
if changed_schemas:
self.registered_partitions = {
(col, sch) for col, sch in self.registered_partitions
if sch not in changed_schemas
}
logger.info(f"Cleared partition cache for changed schemas: {changed_schemas}")
logger.info(
f"Cleared partition cache for changed schemas "
f"in {workspace}: {changed_schemas}"
)
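The `symmetric_difference` above treats both added and removed schemas as "changed", so the partition cache drops any entry whose schema appeared or disappeared in the new config. A minimal sketch of that invalidation step (helper name is illustrative):

```python
def invalidate_changed(
    registered: set[tuple[str, str]],
    old_names: set[str],
    new_names: set[str],
) -> set[tuple[str, str]]:
    # Schemas that were added OR removed both count as "changed";
    # keep only cache entries whose schema survived unchanged by name
    changed = old_names.symmetric_difference(new_names)
    return {(col, sch) for col, sch in registered if sch not in changed}


registered = {("c1", "orders"), ("c1", "customers"), ("c2", "orders")}
kept = invalidate_changed(
    registered,
    old_names={"orders", "customers"},
    new_names={"customers", "invoices"},
)
```

Dropping the cache entry forces `register_partitions` to run again on the next write for that (collection, schema) pair, re-registering partitions under the new schema definition.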
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Cassandra compatibility"""
@ -286,7 +302,10 @@ class Processor(CollectionConfigHandler, FlowProcessor):
return index_names
def register_partitions(self, keyspace: str, collection: str, schema_name: str):
def register_partitions(
self, keyspace: str, collection: str, schema_name: str,
workspace: str,
):
"""
Register partition entries for a (collection, schema_name) pair.
Called once on first row for each pair.
@ -295,9 +314,13 @@ class Processor(CollectionConfigHandler, FlowProcessor):
if cache_key in self.registered_partitions:
return
schema = self.schemas.get(schema_name)
ws_schemas = self.schemas.get(workspace, {})
schema = ws_schemas.get(schema_name)
if not schema:
logger.warning(f"Cannot register partitions - schema {schema_name} not found")
logger.warning(
f"Cannot register partitions - schema {schema_name} "
f"not found in workspace {workspace}"
)
return
safe_keyspace = self.sanitize_name(keyspace)
@ -338,13 +361,14 @@ class Processor(CollectionConfigHandler, FlowProcessor):
"""Process incoming ExtractedObject and store in Cassandra"""
obj = msg.value()
workspace = flow.workspace
logger.info(
f"Storing {len(obj.values)} rows for schema {obj.schema_name} "
f"from {obj.metadata.id}"
f"from {obj.metadata.id} (workspace {workspace})"
)
# Validate collection exists before accepting writes
if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
if not self.collection_exists(workspace, obj.metadata.collection):
error_msg = (
f"Collection {obj.metadata.collection} does not exist. "
f"Create it first via collection management API."
@ -352,13 +376,17 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.error(error_msg)
raise ValueError(error_msg)
# Get schema definition
schema = self.schemas.get(obj.schema_name)
# Get schema definition for this workspace
ws_schemas = self.schemas.get(workspace, {})
schema = ws_schemas.get(obj.schema_name)
if not schema:
logger.warning(f"No schema found for {obj.schema_name} - skipping")
logger.warning(
f"No schema found for {obj.schema_name} in "
f"workspace {workspace} - skipping"
)
return
keyspace = obj.metadata.user
keyspace = workspace
collection = obj.metadata.collection
schema_name = obj.schema_name
source = getattr(obj.metadata, 'source', '') or ''
@ -370,7 +398,8 @@ class Processor(CollectionConfigHandler, FlowProcessor):
# Register partitions if first time seeing this (collection, schema_name)
await asyncio.to_thread(
self.register_partitions, keyspace, collection, schema_name
self.register_partitions,
keyspace, collection, schema_name, workspace,
)
safe_keyspace = self.sanitize_name(keyspace)
@ -430,25 +459,25 @@ class Processor(CollectionConfigHandler, FlowProcessor):
f"({len(index_names)} indexes per row)"
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create/verify collection exists in Cassandra row store"""
# Connect if not already connected (sync, push to thread)
await asyncio.to_thread(self.connect_cassandra)
# Ensure tables exist (sync DDL, push to thread)
await asyncio.to_thread(self.ensure_tables, user)
await asyncio.to_thread(self.ensure_tables, workspace)
logger.info(f"Collection {collection} ready for user {user}")
logger.info(f"Collection {collection} ready for workspace {workspace}")
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete all data for a specific collection using partition tracking"""
# Connect if not already connected
await asyncio.to_thread(self.connect_cassandra)
safe_keyspace = self.sanitize_name(user)
safe_keyspace = self.sanitize_name(workspace)
# Check if keyspace exists
if user not in self.known_keyspaces:
if workspace not in self.known_keyspaces:
check_keyspace_cql = """
SELECT keyspace_name FROM system_schema.keyspaces
WHERE keyspace_name = %s
@ -459,7 +488,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
if not result:
logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
return
self.known_keyspaces.add(user)
self.known_keyspaces.add(workspace)
# Discover all partitions for this collection
select_partitions_cql = f"""
@ -522,12 +551,12 @@ class Processor(CollectionConfigHandler, FlowProcessor):
f"from keyspace {safe_keyspace}"
)
async def delete_collection_schema(self, user: str, collection: str, schema_name: str):
async def delete_collection_schema(self, workspace: str, collection: str, schema_name: str):
"""Delete all data for a specific collection + schema combination"""
# Connect if not already connected
await asyncio.to_thread(self.connect_cassandra)
safe_keyspace = self.sanitize_name(user)
safe_keyspace = self.sanitize_name(workspace)
# Discover partitions for this collection + schema
select_partitions_cql = f"""


@ -147,9 +147,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
# Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"])
async def store_triples(self, message):
user = message.metadata.user
async def store_triples(self, workspace, message):
# The cassandra-driver work below — connection, schema
# setup, and per-triple inserts — is all synchronous.
@ -159,7 +157,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
def _do_store():
if self.table is None or self.table != user:
if self.table is None or self.table != workspace:
self.tg = None
@ -170,21 +168,21 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=message.metadata.user,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password,
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=message.metadata.user,
keyspace=workspace,
)
except Exception as e:
logger.error(f"Exception: {e}", exc_info=True)
time.sleep(1)
raise e
self.table = user
self.table = workspace
for t in message.triples:
# Extract values from Term objects
@@ -212,12 +210,12 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
await asyncio.to_thread(_do_store)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create a collection in Cassandra triple store via config push"""
def _do_create():
# Create or reuse connection for this user's keyspace
if self.table is None or self.table != user:
# Create or reuse connection for this workspace's keyspace
if self.table is None or self.table != workspace:
self.tg = None
# Use factory function to select implementation
@@ -227,23 +225,23 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password,
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
keyspace=workspace,
)
except Exception as e:
logger.error(f"Failed to connect to Cassandra for user {user}: {e}")
logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}")
raise
self.table = user
self.table = workspace
# Create collection using the built-in method
logger.info(f"Creating collection {collection} for user {user}")
logger.info(f"Creating collection {collection} for workspace {workspace}")
if self.tg.collection_exists(collection):
logger.info(f"Collection {collection} already exists")
@@ -254,15 +252,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
try:
await asyncio.to_thread(_do_create)
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete all data for a specific collection from the unified triples table"""
def _do_delete():
# Create or reuse connection for this user's keyspace
if self.table is None or self.table != user:
# Create or reuse connection for this workspace's keyspace
if self.table is None or self.table != workspace:
self.tg = None
# Use factory function to select implementation
@@ -272,29 +270,29 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password,
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
keyspace=workspace,
)
except Exception as e:
logger.error(f"Failed to connect to Cassandra for user {user}: {e}")
logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}")
raise
self.table = user
self.table = workspace
# Delete all triples for this collection using the built-in method
self.tg.delete_collection(collection)
logger.info(f"Deleted all triples for collection {collection} from keyspace {user}")
logger.info(f"Deleted all triples for collection {collection} from keyspace {workspace}")
try:
await asyncio.to_thread(_do_delete)
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
@staticmethod

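The `self.table is None or self.table != workspace` checks repeated through the hunks above implement a single cached connection that is rebuilt whenever a message for a different workspace arrives. That pattern, isolated from the Cassandra specifics, can be sketched as a small cache keyed by the last workspace seen (`WorkspaceConnectionCache` and its `factory` argument are illustrative names, not from the codebase):

```python
class WorkspaceConnectionCache:
    """Reuse one store client per workspace, mirroring the
    `self.table != workspace` reconnect check in the diff."""

    def __init__(self, factory):
        self.factory = factory      # callable: workspace -> client
        self.workspace = None       # last workspace connected to
        self.client = None

    def get(self, workspace):
        # Rebuild the client only when the workspace changes,
        # so consecutive messages for one workspace share a connection.
        if self.client is None or self.workspace != workspace:
            self.client = self.factory(workspace)
            self.workspace = workspace
        return self.client
```

One consequence of this single-slot design, visible in the diff as well: interleaved traffic for two workspaces reconnects on every message, which is acceptable when flows are workspace-scoped but would thrash otherwise.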

@@ -59,15 +59,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
# Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"])
def create_node(self, uri, user, collection):
def create_node(self, uri, workspace, collection):
logger.debug(f"Create node {uri} for user={user}, collection={collection}")
logger.debug(f"Create node {uri} for workspace={workspace}, collection={collection}")
res = self.io.query(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
params={
"uri": uri,
"user": user,
"workspace": workspace,
"collection": collection,
},
)
@@ -77,15 +77,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=res.run_time_ms
))
def create_literal(self, value, user, collection):
def create_literal(self, value, workspace, collection):
logger.debug(f"Create literal {value} for user={user}, collection={collection}")
logger.debug(f"Create literal {value} for workspace={workspace}, collection={collection}")
res = self.io.query(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
params={
"value": value,
"user": user,
"workspace": workspace,
"collection": collection,
},
)
@@ -95,19 +95,19 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=res.run_time_ms
))
def relate_node(self, src, uri, dest, user, collection):
def relate_node(self, src, uri, dest, workspace, collection):
logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}")
logger.debug(f"Create node rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
res = self.io.query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
params={
"src": src,
"dest": dest,
"uri": uri,
"user": user,
"workspace": workspace,
"collection": collection,
},
)
@@ -117,19 +117,19 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=res.run_time_ms
))
def relate_literal(self, src, uri, dest, user, collection):
def relate_literal(self, src, uri, dest, workspace, collection):
logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}")
logger.debug(f"Create literal rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
res = self.io.query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
params={
"src": src,
"dest": dest,
"uri": uri,
"user": user,
"workspace": workspace,
"collection": collection,
},
)
@@ -139,36 +139,34 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=res.run_time_ms
))
def collection_exists(self, user, collection):
def collection_exists(self, workspace, collection):
"""Check if collection metadata node exists"""
result = self.io.query(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
"MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"RETURN c LIMIT 1",
params={"user": user, "collection": collection}
params={"workspace": workspace, "collection": collection}
)
return result.result_set is not None and len(result.result_set) > 0
def create_collection(self, user, collection):
def create_collection(self, workspace, collection):
"""Create collection metadata node"""
import datetime
self.io.query(
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) "
"MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"SET c.created_at = $created_at",
params={
"user": user,
"workspace": workspace,
"collection": collection,
"created_at": datetime.datetime.now().isoformat()
}
)
logger.info(f"Created collection metadata node for {user}/{collection}")
logger.info(f"Created collection metadata node for {workspace}/{collection}")
async def store_triples(self, message):
# Extract user and collection from metadata
user = message.metadata.user if message.metadata.user else "default"
async def store_triples(self, workspace, message):
collection = message.metadata.collection if message.metadata.collection else "default"
# Validate collection exists before accepting writes
if not self.collection_exists(user, collection):
if not self.collection_exists(workspace, collection):
error_msg = (
f"Collection {collection} does not exist. "
f"Create it first via collection management API."
@@ -182,14 +180,14 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
p_val = get_term_value(t.p)
o_val = get_term_value(t.o)
self.create_node(s_val, user, collection)
self.create_node(s_val, workspace, collection)
if t.o.type == IRI:
self.create_node(o_val, user, collection)
self.relate_node(s_val, p_val, o_val, user, collection)
self.create_node(o_val, workspace, collection)
self.relate_node(s_val, p_val, o_val, workspace, collection)
else:
self.create_literal(o_val, user, collection)
self.relate_literal(s_val, p_val, o_val, user, collection)
self.create_literal(o_val, workspace, collection)
self.relate_literal(s_val, p_val, o_val, workspace, collection)
@staticmethod
def add_args(parser):
@@ -208,58 +206,58 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
help=f'FalkorDB database (default: {default_database})'
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create collection metadata in FalkorDB via config push"""
try:
# Check if collection exists
result = self.io.query(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) RETURN c LIMIT 1",
params={"user": user, "collection": collection}
"MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) RETURN c LIMIT 1",
params={"workspace": workspace, "collection": collection}
)
if result.result_set:
logger.info(f"Collection {user}/{collection} already exists")
logger.info(f"Collection {workspace}/{collection} already exists")
else:
# Create collection metadata node
import datetime
self.io.query(
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) "
"MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"SET c.created_at = $created_at",
params={
"user": user,
"workspace": workspace,
"collection": collection,
"created_at": datetime.datetime.now().isoformat()
}
)
logger.info(f"Created collection {user}/{collection}")
logger.info(f"Created collection {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for FalkorDB triples via config push"""
try:
# Delete all nodes and literals for this user/collection
# Delete all nodes and literals for this workspace/collection
node_result = self.io.query(
"MATCH (n:Node {user: $user, collection: $collection}) DETACH DELETE n",
params={"user": user, "collection": collection}
"MATCH (n:Node {workspace: $workspace, collection: $collection}) DETACH DELETE n",
params={"workspace": workspace, "collection": collection}
)
literal_result = self.io.query(
"MATCH (n:Literal {user: $user, collection: $collection}) DETACH DELETE n",
params={"user": user, "collection": collection}
"MATCH (n:Literal {workspace: $workspace, collection: $collection}) DETACH DELETE n",
params={"workspace": workspace, "collection": collection}
)
# Delete collection metadata node
metadata_result = self.io.query(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) DELETE c",
params={"user": user, "collection": collection}
"MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) DELETE c",
params={"workspace": workspace, "collection": collection}
)
logger.info(f"Deleted {node_result.nodes_deleted} nodes, {literal_result.nodes_deleted} literals, and {metadata_result.nodes_deleted} metadata nodes for collection {user}/{collection}")
logger.info(f"Deleted {node_result.nodes_deleted} nodes, {literal_result.nodes_deleted} literals, and {metadata_result.nodes_deleted} metadata nodes for collection {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():

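The FalkorDB `store_triples` above refuses writes until the `(workspace, collection)` pair has a `CollectionMetadata` node, created via the management path. The gate logic, stripped of the Cypher, amounts to the following sketch (`CollectionRegistry` is an illustrative name; the real check is a graph query, not a set lookup):

```python
class CollectionRegistry:
    """Minimal stand-in for the collection-metadata gate in the diff:
    writes are rejected unless (workspace, collection) was created
    first via the collection management API."""

    def __init__(self):
        self._collections = set()

    def create_collection(self, workspace: str, collection: str) -> None:
        self._collections.add((workspace, collection))

    def collection_exists(self, workspace: str, collection: str) -> bool:
        return (workspace, collection) in self._collections

    def store(self, workspace: str, collection: str, triple) -> None:
        if not self.collection_exists(workspace, collection):
            raise ValueError(
                f"Collection {collection} does not exist. "
                f"Create it first via collection management API."
            )
        # a real implementation MERGEs nodes and relationships here
```

The point of the gate is that writes can no longer implicitly create a collection: ingest into an unknown collection fails fast instead of silently materialising data under a typo'd name.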

@@ -117,10 +117,10 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
# Maybe index already exists
logger.warning("Index create failure ignored")
# New indexes for user/collection filtering
# New indexes for workspace/collection filtering
try:
session.run(
"CREATE INDEX ON :Node(user)"
"CREATE INDEX ON :Node(workspace)"
)
except Exception as e:
logger.warning(f"Workspace index create failure: {e}")
@@ -136,7 +136,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
try:
session.run(
"CREATE INDEX ON :Literal(user)"
"CREATE INDEX ON :Literal(workspace)"
)
except Exception as e:
logger.warning(f"Workspace index create failure: {e}")
@@ -152,13 +152,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
logger.info("Index creation done")
def create_node(self, uri, user, collection):
def create_node(self, uri, workspace, collection):
logger.debug(f"Create node {uri} for user={user}, collection={collection}")
logger.debug(f"Create node {uri} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
uri=uri, user=user, collection=collection,
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=uri, workspace=workspace, collection=collection,
database_=self.db,
).summary
@@ -167,13 +167,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
def create_literal(self, value, user, collection):
def create_literal(self, value, workspace, collection):
logger.debug(f"Create literal {value} for user={user}, collection={collection}")
logger.debug(f"Create literal {value} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
value=value, user=user, collection=collection,
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value=value, workspace=workspace, collection=collection,
database_=self.db,
).summary
@@ -182,15 +182,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
def relate_node(self, src, uri, dest, user, collection):
def relate_node(self, src, uri, dest, workspace, collection):
logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}")
logger.debug(f"Create node rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, user=user, collection=collection,
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, workspace=workspace, collection=collection,
database_=self.db,
).summary
@@ -199,15 +199,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
def relate_literal(self, src, uri, dest, user, collection):
def relate_literal(self, src, uri, dest, workspace, collection):
logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}")
logger.debug(f"Create literal rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, user=user, collection=collection,
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, workspace=workspace, collection=collection,
database_=self.db,
).summary
@@ -216,7 +216,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
def create_triple(self, tx, t, user, collection):
def create_triple(self, tx, t, workspace, collection):
s_val = get_term_value(t.s)
p_val = get_term_value(t.p)
@@ -224,48 +224,46 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
# Create new s node with given uri, if not exists
result = tx.run(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
uri=s_val, user=user, collection=collection
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=s_val, workspace=workspace, collection=collection
)
if t.o.type == IRI:
# Create new o node with given uri, if not exists
result = tx.run(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
uri=o_val, user=user, collection=collection
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=o_val, workspace=workspace, collection=collection
)
result = tx.run(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=s_val, dest=o_val, uri=p_val, user=user, collection=collection,
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=s_val, dest=o_val, uri=p_val, workspace=workspace, collection=collection,
)
else:
# Create new o literal with given uri, if not exists
result = tx.run(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
value=o_val, user=user, collection=collection
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value=o_val, workspace=workspace, collection=collection
)
result = tx.run(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=s_val, dest=o_val, uri=p_val, user=user, collection=collection,
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=s_val, dest=o_val, uri=p_val, workspace=workspace, collection=collection,
)
async def store_triples(self, message):
async def store_triples(self, workspace, message):
# Extract user and collection from metadata
user = message.metadata.user if message.metadata.user else "default"
collection = message.metadata.collection if message.metadata.collection else "default"
# Validate collection exists before accepting writes
if not self.collection_exists(user, collection):
if not self.collection_exists(workspace, collection):
error_msg = (
f"Collection {collection} does not exist. "
f"Create it first via collection management API."
@@ -279,18 +277,18 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
p_val = get_term_value(t.p)
o_val = get_term_value(t.o)
self.create_node(s_val, user, collection)
self.create_node(s_val, workspace, collection)
if t.o.type == IRI:
self.create_node(o_val, user, collection)
self.relate_node(s_val, p_val, o_val, user, collection)
self.create_node(o_val, workspace, collection)
self.relate_node(s_val, p_val, o_val, workspace, collection)
else:
self.create_literal(o_val, user, collection)
self.relate_literal(s_val, p_val, o_val, user, collection)
self.create_literal(o_val, workspace, collection)
self.relate_literal(s_val, p_val, o_val, workspace, collection)
# Alternative implementation using transactions
# with self.io.session(database=self.db) as session:
# session.execute_write(self.create_triple, t, user, collection)
# session.execute_write(self.create_triple, t, workspace, collection)
@staticmethod
def add_args(parser):
@@ -321,72 +319,72 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
help=f'Memgraph database (default: {default_database})'
)
def _collection_exists_in_db(self, user, collection):
def _collection_exists_in_db(self, workspace, collection):
"""Check if collection metadata node exists"""
with self.io.session(database=self.db) as session:
result = session.run(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
"MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"RETURN c LIMIT 1",
user=user, collection=collection
workspace=workspace, collection=collection
)
return bool(list(result))
def _create_collection_in_db(self, user, collection):
def _create_collection_in_db(self, workspace, collection):
"""Create collection metadata node"""
import datetime
with self.io.session(database=self.db) as session:
session.run(
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) "
"MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"SET c.created_at = $created_at",
user=user, collection=collection,
workspace=workspace, collection=collection,
created_at=datetime.datetime.now().isoformat()
)
logger.info(f"Created collection metadata node for {user}/{collection}")
logger.info(f"Created collection metadata node for {workspace}/{collection}")
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create collection metadata in Memgraph via config push"""
try:
if self._collection_exists_in_db(user, collection):
logger.info(f"Collection {user}/{collection} already exists")
if self._collection_exists_in_db(workspace, collection):
logger.info(f"Collection {workspace}/{collection} already exists")
else:
self._create_collection_in_db(user, collection)
logger.info(f"Created collection {user}/{collection}")
self._create_collection_in_db(workspace, collection)
logger.info(f"Created collection {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete all data for a specific collection via config push"""
try:
with self.io.session(database=self.db) as session:
# Delete all nodes for this user and collection
# Delete all nodes for this workspace and collection
node_result = session.run(
"MATCH (n:Node {user: $user, collection: $collection}) "
"MATCH (n:Node {workspace: $workspace, collection: $collection}) "
"DETACH DELETE n",
user=user, collection=collection
workspace=workspace, collection=collection
)
nodes_deleted = node_result.consume().counters.nodes_deleted
# Delete all literals for this user and collection
# Delete all literals for this workspace and collection
literal_result = session.run(
"MATCH (n:Literal {user: $user, collection: $collection}) "
"MATCH (n:Literal {workspace: $workspace, collection: $collection}) "
"DETACH DELETE n",
user=user, collection=collection
workspace=workspace, collection=collection
)
literals_deleted = literal_result.consume().counters.nodes_deleted
# Delete collection metadata node
metadata_result = session.run(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
"MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"DELETE c",
user=user, collection=collection
workspace=workspace, collection=collection
)
metadata_deleted = metadata_result.consume().counters.nodes_deleted
# Note: Relationships are automatically deleted with DETACH DELETE
logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {user}/{collection}")
logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

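The Memgraph hunks above wrap every `CREATE INDEX` in try/except and merely log failures, since the common error on restart is "index already exists". The best-effort loop can be sketched generically; `ensure_indexes` and `FakeSession` are illustrative names, and the only assumption about the session is a `run(statement)` method as in the driver calls shown in the diff:

```python
import logging

logger = logging.getLogger(__name__)

def ensure_indexes(session, statements):
    """Run CREATE INDEX statements best-effort, continuing on failure,
    matching the diff's pattern where an 'already exists' error on
    restart is expected and safely ignored."""
    created = 0
    for stmt in statements:
        try:
            session.run(stmt)
            created += 1
        except Exception as e:
            logger.warning(f"Index create failure ignored: {e}")
    return created

class FakeSession:
    """Stand-in for a graph-driver session: raises on duplicate DDL."""
    def __init__(self):
        self.seen = []
    def run(self, stmt):
        if stmt in self.seen:
            raise RuntimeError("already exists")
        self.seen.append(stmt)
```

The trade-off, visible in the diff as well, is that a genuine failure (bad syntax, unsupported index type on this server version) is also swallowed, which is why each warning includes the exception text rather than a bare "ignored".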

@@ -80,14 +80,12 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
logger.info("Create indexes...")
# Legacy indexes for backwards compatibility
try:
session.run(
"CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
)
except Exception as e:
logger.warning(f"Index create failure: {e}")
# Maybe index already exists
logger.warning("Index create failure ignored")
try:
@@ -96,7 +94,6 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
)
except Exception as e:
logger.warning(f"Index create failure: {e}")
# Maybe index already exists
logger.warning("Index create failure ignored")
try:
@@ -105,13 +102,11 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
)
except Exception as e:
logger.warning(f"Index create failure: {e}")
# Maybe index already exists
logger.warning("Index create failure ignored")
# New compound indexes for workspace/collection filtering
try:
session.run(
"CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)",
"CREATE INDEX node_workspace_collection_uri FOR (n:Node) ON (n.workspace, n.collection, n.uri)",
)
except Exception as e:
logger.warning(f"Compound index create failure: {e}")
@@ -119,17 +114,16 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
try:
session.run(
"CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)",
"CREATE INDEX literal_workspace_collection_value FOR (n:Literal) ON (n.workspace, n.collection, n.value)",
)
except Exception as e:
logger.warning(f"Compound index create failure: {e}")
logger.warning("Index create failure ignored")
# Note: Neo4j doesn't support compound indexes on relationships in all versions
# Try to create individual indexes on relationship properties
# Neo4j doesn't support compound indexes on relationships in all versions
try:
session.run(
"CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)",
"CREATE INDEX rel_workspace FOR ()-[r:Rel]-() ON (r.workspace)",
)
except Exception as e:
logger.warning(f"Relationship index create failure: {e}")
@@ -145,13 +139,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
logger.info("Index creation done")
def create_node(self, uri, user, collection):
def create_node(self, uri, workspace, collection):
logger.debug(f"Create node {uri} for user={user}, collection={collection}")
logger.debug(f"Create node {uri} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
uri=uri, user=user, collection=collection,
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=uri, workspace=workspace, collection=collection,
database_=self.db,
).summary
@@ -160,13 +154,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
def create_literal(self, value, user, collection):
def create_literal(self, value, workspace, collection):
logger.debug(f"Create literal {value} for user={user}, collection={collection}")
logger.debug(f"Create literal {value} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
value=value, user=user, collection=collection,
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value=value, workspace=workspace, collection=collection,
database_=self.db,
).summary
@@ -175,15 +169,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
def relate_node(self, src, uri, dest, user, collection):
def relate_node(self, src, uri, dest, workspace, collection):
logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}")
logger.debug(f"Create node rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, user=user, collection=collection,
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, workspace=workspace, collection=collection,
database_=self.db,
).summary
@@ -192,15 +186,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
def relate_literal(self, src, uri, dest, user, collection):
def relate_literal(self, src, uri, dest, workspace, collection):
logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}")
logger.debug(f"Create literal rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, user=user, collection=collection,
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, workspace=workspace, collection=collection,
database_=self.db,
).summary
@@ -209,14 +203,12 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
async def store_triples(self, message):
async def store_triples(self, workspace, message):
# Extract user and collection from metadata
user = message.metadata.user if message.metadata.user else "default"
collection = message.metadata.collection if message.metadata.collection else "default"
# Validate collection exists before accepting writes
if not self.collection_exists(user, collection):
if not self.collection_exists(workspace, collection):
error_msg = (
f"Collection {collection} does not exist. "
f"Create it first via collection management API."
@ -230,14 +222,14 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
p_val = get_term_value(t.p)
o_val = get_term_value(t.o)
self.create_node(s_val, user, collection)
self.create_node(s_val, workspace, collection)
if t.o.type == IRI:
self.create_node(o_val, user, collection)
self.relate_node(s_val, p_val, o_val, user, collection)
self.create_node(o_val, workspace, collection)
self.relate_node(s_val, p_val, o_val, workspace, collection)
else:
self.create_literal(o_val, user, collection)
self.relate_literal(s_val, p_val, o_val, user, collection)
self.create_literal(o_val, workspace, collection)
self.relate_literal(s_val, p_val, o_val, workspace, collection)
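The loop above routes each triple's object by term type: IRI objects become `Node` vertices with a node-to-node relation, anything else becomes a `Literal` with a node-to-literal relation, and every write carries the same `(workspace, collection)` scope. A driver-free sketch of that dispatch (the `IRI` constant and the recorder functions are illustrative stand-ins, not the real processor API):

```python
IRI = "iri"  # stand-in for the real term-type constant

ops = []  # records (op, args) instead of hitting Neo4j

def create_node(v, ws, col): ops.append(("node", v, ws, col))
def create_literal(v, ws, col): ops.append(("literal", v, ws, col))
def relate_node(s, p, o, ws, col): ops.append(("rel-node", s, p, o, ws, col))
def relate_literal(s, p, o, ws, col): ops.append(("rel-literal", s, p, o, ws, col))

def store_triple(s, p, o, o_type, workspace, collection):
    # Subject is always a Node; object type decides the branch.
    create_node(s, workspace, collection)
    if o_type == IRI:
        create_node(o, workspace, collection)
        relate_node(s, p, o, workspace, collection)
    else:
        create_literal(o, workspace, collection)
        relate_literal(s, p, o, workspace, collection)

store_triple("urn:a", "urn:name", "Alice", "literal", "ws-1", "default")
```

Note that `workspace` arrives as a function argument from the trusted flow layer, never from the message payload.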
@staticmethod
def add_args(parser):
@ -268,75 +260,70 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
help=f'Neo4j database (default: {default_database})'
)
def _collection_exists_in_db(self, user, collection):
def _collection_exists_in_db(self, workspace, collection):
"""Check if collection metadata node exists"""
with self.io.session(database=self.db) as session:
result = session.run(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
"MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"RETURN c LIMIT 1",
user=user, collection=collection
workspace=workspace, collection=collection
)
return bool(list(result))
def _create_collection_in_db(self, user, collection):
def _create_collection_in_db(self, workspace, collection):
"""Create collection metadata node"""
import datetime
with self.io.session(database=self.db) as session:
session.run(
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) "
"MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"SET c.created_at = $created_at",
user=user, collection=collection,
workspace=workspace, collection=collection,
created_at=datetime.datetime.now().isoformat()
)
logger.info(f"Created collection metadata node for {user}/{collection}")
logger.info(f"Created collection metadata node for {workspace}/{collection}")
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create collection metadata in Neo4j via config push"""
try:
if self._collection_exists_in_db(user, collection):
logger.info(f"Collection {user}/{collection} already exists")
if self._collection_exists_in_db(workspace, collection):
logger.info(f"Collection {workspace}/{collection} already exists")
else:
self._create_collection_in_db(user, collection)
logger.info(f"Created collection {user}/{collection}")
self._create_collection_in_db(workspace, collection)
logger.info(f"Created collection {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete all data for a specific collection via config push"""
try:
with self.io.session(database=self.db) as session:
# Delete all nodes for this workspace and collection
node_result = session.run(
"MATCH (n:Node {user: $user, collection: $collection}) "
"MATCH (n:Node {workspace: $workspace, collection: $collection}) "
"DETACH DELETE n",
user=user, collection=collection
workspace=workspace, collection=collection
)
nodes_deleted = node_result.consume().counters.nodes_deleted
# Delete all literals for this workspace and collection
literal_result = session.run(
"MATCH (n:Literal {user: $user, collection: $collection}) "
"MATCH (n:Literal {workspace: $workspace, collection: $collection}) "
"DETACH DELETE n",
user=user, collection=collection
workspace=workspace, collection=collection
)
literals_deleted = literal_result.consume().counters.nodes_deleted
# Note: Relationships are automatically deleted with DETACH DELETE
# Delete collection metadata node
metadata_result = session.run(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
"MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"DELETE c",
user=user, collection=collection
workspace=workspace, collection=collection
)
metadata_deleted = metadata_result.consume().counters.nodes_deleted
logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {user}/{collection}")
logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():


@ -72,10 +72,11 @@ class ConfigTableStore:
self.cassandra.execute("""
CREATE TABLE IF NOT EXISTS config (
workspace text,
class text,
key text,
value text,
PRIMARY KEY (class, key)
PRIMARY KEY ((workspace, class), key)
);
""");
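The composite partition key `((workspace, class), key)` means identical class/key pairs in different workspaces live in separate partitions, so reads and deletes cannot cross a workspace boundary. A rough in-memory analogue of that keying (illustrative only, not the Cassandra driver):

```python
# In-memory analogue of PRIMARY KEY ((workspace, class), key):
# the partition key is the (workspace, class) pair; rows within
# a partition are addressed by key.
from collections import defaultdict

table = defaultdict(dict)  # (workspace, class) -> {key: value}

def put_config(workspace, cls, key, value):
    table[(workspace, cls)][key] = value

def get_value(workspace, cls, key):
    # A lookup can only ever see its own workspace's partition.
    return table[(workspace, cls)].get(key)

put_config("ws-a", "flow", "default", "cfg-a")
put_config("ws-b", "flow", "default", "cfg-b")
```

The same `("flow", "default")` pair resolves to different values per workspace, and an unknown workspace simply sees nothing.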
@ -124,52 +125,63 @@ class ConfigTableStore:
def prepare_statements(self):
self.put_config_stmt = self.cassandra.prepare("""
INSERT INTO config ( class, key, value )
VALUES (?, ?, ?)
""")
self.get_classes_stmt = self.cassandra.prepare("""
SELECT DISTINCT class FROM config;
INSERT INTO config ( workspace, class, key, value )
VALUES (?, ?, ?, ?)
""")
self.get_keys_stmt = self.cassandra.prepare("""
SELECT key FROM config WHERE class = ?;
SELECT key FROM config
WHERE workspace = ? AND class = ?;
""")
self.get_value_stmt = self.cassandra.prepare("""
SELECT value FROM config WHERE class = ? AND key = ?;
SELECT value FROM config
WHERE workspace = ? AND class = ? AND key = ?;
""")
self.delete_key_stmt = self.cassandra.prepare("""
DELETE FROM config
WHERE class = ? AND key = ?;
WHERE workspace = ? AND class = ? AND key = ?;
""")
self.get_all_stmt = self.cassandra.prepare("""
SELECT class AS cls, key, value FROM config;
SELECT workspace, class AS cls, key, value FROM config;
""")
self.get_all_for_workspace_stmt = self.cassandra.prepare("""
SELECT class AS cls, key, value FROM config
WHERE workspace = ?
ALLOW FILTERING;
""")
self.get_values_stmt = self.cassandra.prepare("""
SELECT key, value FROM config WHERE class = ?;
SELECT key, value FROM config
WHERE workspace = ? AND class = ?;
""")
async def put_config(self, cls, key, value):
self.get_values_all_ws_stmt = self.cassandra.prepare("""
SELECT workspace, key, value FROM config
WHERE class = ?
ALLOW FILTERING;
""")
async def put_config(self, workspace, cls, key, value):
try:
await async_execute(
self.cassandra,
self.put_config_stmt,
(cls, key, value),
(workspace, cls, key, value),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
async def get_value(self, cls, key):
async def get_value(self, workspace, cls, key):
try:
rows = await async_execute(
self.cassandra,
self.get_value_stmt,
(cls, key),
(workspace, cls, key),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -179,12 +191,12 @@ class ConfigTableStore:
return row[0]
return None
async def get_values(self, cls):
async def get_values(self, workspace, cls):
try:
rows = await async_execute(
self.cassandra,
self.get_values_stmt,
(cls,),
(workspace, cls),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -192,18 +204,20 @@ class ConfigTableStore:
return [[row[0], row[1]] for row in rows]
async def get_classes(self):
async def get_values_all_ws(self, cls):
"""Return (workspace, key, value) tuples for all workspaces
with entries of the given class."""
try:
rows = await async_execute(
self.cassandra,
self.get_classes_stmt,
(),
self.get_values_all_ws_stmt,
(cls,),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
return [row[0] for row in rows]
return [(row[0], row[1], row[2]) for row in rows]
async def get_all(self):
try:
@ -216,14 +230,27 @@ class ConfigTableStore:
logger.error("Exception occurred", exc_info=True)
raise
return [(row[0], row[1], row[2], row[3]) for row in rows]
async def get_all_for_workspace(self, workspace):
try:
rows = await async_execute(
self.cassandra,
self.get_all_for_workspace_stmt,
(workspace,),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
return [(row[0], row[1], row[2]) for row in rows]
async def get_keys(self, cls):
async def get_keys(self, workspace, cls):
try:
rows = await async_execute(
self.cassandra,
self.get_keys_stmt,
(cls,),
(workspace, cls),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -231,12 +258,12 @@ class ConfigTableStore:
return [row[0] for row in rows]
async def delete_key(self, cls, key):
async def delete_key(self, workspace, cls, key):
try:
await async_execute(
self.cassandra,
self.delete_key_stmt,
(cls, key),
(workspace, cls, key),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
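Every `ConfigTableStore` method now takes `workspace` as its first positional argument. A toy stand-in showing what callers of the new API look like (class and values hypothetical; the real store binds prepared CQL statements):

```python
import asyncio

class InMemoryConfigStore:
    """Toy stand-in for ConfigTableStore with the workspace-first API."""
    def __init__(self):
        self.rows = {}  # (workspace, cls, key) -> value

    async def put_config(self, workspace, cls, key, value):
        self.rows[(workspace, cls, key)] = value

    async def get_keys(self, workspace, cls):
        return [k for (ws, c, k) in self.rows if ws == workspace and c == cls]

    async def delete_key(self, workspace, cls, key):
        self.rows.pop((workspace, cls, key), None)

async def demo():
    store = InMemoryConfigStore()
    await store.put_config("ws-1", "prompt", "system", "v1")
    await store.put_config("ws-2", "prompt", "system", "v2")
    # Deleting in ws-1 leaves ws-2 untouched.
    await store.delete_key("ws-1", "prompt", "system")
    return await store.get_keys("ws-2", "prompt")

keys = asyncio.run(demo())
```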


@ -88,7 +88,7 @@ class KnowledgeTableStore:
self.cassandra.execute("""
CREATE TABLE IF NOT EXISTS triples (
user text,
workspace text,
document_id text,
id uuid,
time timestamp,
@ -98,7 +98,7 @@ class KnowledgeTableStore:
triples list<tuple<
text, boolean, text, boolean, text, boolean
>>,
PRIMARY KEY ((user, document_id), id)
PRIMARY KEY ((workspace, document_id), id)
);
""");
@ -106,7 +106,7 @@ class KnowledgeTableStore:
self.cassandra.execute("""
create table if not exists graph_embeddings (
user text,
workspace text,
document_id text,
id uuid,
time timestamp,
@ -119,20 +119,20 @@ class KnowledgeTableStore:
list<double>
>
>,
PRIMARY KEY ((user, document_id), id)
PRIMARY KEY ((workspace, document_id), id)
);
""");
self.cassandra.execute("""
CREATE INDEX IF NOT EXISTS graph_embeddings_user ON
graph_embeddings ( user );
CREATE INDEX IF NOT EXISTS graph_embeddings_workspace ON
graph_embeddings ( workspace );
""");
logger.debug("document_embeddings table...")
self.cassandra.execute("""
create table if not exists document_embeddings (
user text,
workspace text,
document_id text,
id uuid,
time timestamp,
@ -145,13 +145,13 @@ class KnowledgeTableStore:
list<double>
>
>,
PRIMARY KEY ((user, document_id), id)
PRIMARY KEY ((workspace, document_id), id)
);
""");
self.cassandra.execute("""
CREATE INDEX IF NOT EXISTS document_embeddings_user ON
document_embeddings ( user );
CREATE INDEX IF NOT EXISTS document_embeddings_workspace ON
document_embeddings ( workspace );
""");
logger.info("Cassandra schema OK.")
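With `(workspace, document_id)` as the partition key, kg-core listing and deletion are naturally workspace-scoped: `list_kg_cores` selects distinct document ids within one workspace, and `delete_kg_core` drops a single partition. An in-memory sketch of that access pattern (illustrative, not the Cassandra driver):

```python
# Rows keyed like the graph_embeddings table:
# partition key (workspace, document_id), clustering key id.
rows = [
    ("ws-1", "doc-a", "row-1"),
    ("ws-1", "doc-b", "row-2"),
    ("ws-2", "doc-a", "row-3"),
]

def list_kg_cores(workspace):
    """Distinct document_ids within one workspace."""
    return sorted({doc for ws, doc, _ in rows if ws == workspace})

def delete_kg_core(workspace, document_id):
    """Delete all rows in one (workspace, document_id) partition."""
    rows[:] = [r for r in rows if r[:2] != (workspace, document_id)]

cores = list_kg_cores("ws-1")  # ['doc-a', 'doc-b']
```

`doc-a` in `ws-2` is a different partition from `doc-a` in `ws-1`, so deleting one never touches the other.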
@ -161,7 +161,7 @@ class KnowledgeTableStore:
self.insert_triples_stmt = self.cassandra.prepare("""
INSERT INTO triples
(
id, user, document_id,
id, workspace, document_id,
time, metadata, triples
)
VALUES (?, ?, ?, ?, ?, ?)
@ -170,7 +170,7 @@ class KnowledgeTableStore:
self.insert_graph_embeddings_stmt = self.cassandra.prepare("""
INSERT INTO graph_embeddings
(
id, user, document_id, time, metadata, entity_embeddings
id, workspace, document_id, time, metadata, entity_embeddings
)
VALUES (?, ?, ?, ?, ?, ?)
""")
@ -178,45 +178,45 @@ class KnowledgeTableStore:
self.insert_document_embeddings_stmt = self.cassandra.prepare("""
INSERT INTO document_embeddings
(
id, user, document_id, time, metadata, chunks
id, workspace, document_id, time, metadata, chunks
)
VALUES (?, ?, ?, ?, ?, ?)
""")
self.list_cores_stmt = self.cassandra.prepare("""
SELECT DISTINCT user, document_id FROM graph_embeddings
WHERE user = ?
SELECT DISTINCT workspace, document_id FROM graph_embeddings
WHERE workspace = ?
""")
self.get_triples_stmt = self.cassandra.prepare("""
SELECT id, time, metadata, triples
FROM triples
WHERE user = ? AND document_id = ?
WHERE workspace = ? AND document_id = ?
""")
self.get_graph_embeddings_stmt = self.cassandra.prepare("""
SELECT id, time, metadata, entity_embeddings
FROM graph_embeddings
WHERE user = ? AND document_id = ?
WHERE workspace = ? AND document_id = ?
""")
self.get_document_embeddings_stmt = self.cassandra.prepare("""
SELECT id, time, metadata, chunks
FROM document_embeddings
WHERE user = ? AND document_id = ?
WHERE workspace = ? AND document_id = ?
""")
self.delete_triples_stmt = self.cassandra.prepare("""
DELETE FROM triples
WHERE user = ? AND document_id = ?
WHERE workspace = ? AND document_id = ?
""")
self.delete_graph_embeddings_stmt = self.cassandra.prepare("""
DELETE FROM graph_embeddings
WHERE user = ? AND document_id = ?
WHERE workspace = ? AND document_id = ?
""")
async def add_triples(self, m):
async def add_triples(self, workspace, m):
when = int(time.time() * 1000)
@ -232,7 +232,7 @@ class KnowledgeTableStore:
self.cassandra,
self.insert_triples_stmt,
(
uuid.uuid4(), m.metadata.user,
uuid.uuid4(), workspace,
m.metadata.root or m.metadata.id, when,
[], triples,
),
@ -241,7 +241,7 @@ class KnowledgeTableStore:
logger.error("Exception occurred", exc_info=True)
raise
async def add_graph_embeddings(self, m):
async def add_graph_embeddings(self, workspace, m):
when = int(time.time() * 1000)
@ -258,7 +258,7 @@ class KnowledgeTableStore:
self.cassandra,
self.insert_graph_embeddings_stmt,
(
uuid.uuid4(), m.metadata.user,
uuid.uuid4(), workspace,
m.metadata.root or m.metadata.id, when,
[], entities,
),
@ -267,7 +267,7 @@ class KnowledgeTableStore:
logger.error("Exception occurred", exc_info=True)
raise
async def add_document_embeddings(self, m):
async def add_document_embeddings(self, workspace, m):
when = int(time.time() * 1000)
@ -284,7 +284,7 @@ class KnowledgeTableStore:
self.cassandra,
self.insert_document_embeddings_stmt,
(
uuid.uuid4(), m.metadata.user,
uuid.uuid4(), workspace,
m.metadata.root or m.metadata.id, when,
[], chunks,
),
@ -293,7 +293,7 @@ class KnowledgeTableStore:
logger.error("Exception occurred", exc_info=True)
raise
async def list_kg_cores(self, user):
async def list_kg_cores(self, workspace):
logger.debug("List kg cores...")
@ -301,7 +301,7 @@ class KnowledgeTableStore:
rows = await async_execute(
self.cassandra,
self.list_cores_stmt,
(user,),
(workspace,),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -313,7 +313,7 @@ class KnowledgeTableStore:
return lst
async def delete_kg_core(self, user, document_id):
async def delete_kg_core(self, workspace, document_id):
logger.debug("Delete kg cores...")
@ -321,7 +321,7 @@ class KnowledgeTableStore:
await async_execute(
self.cassandra,
self.delete_triples_stmt,
(user, document_id),
(workspace, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -331,13 +331,13 @@ class KnowledgeTableStore:
await async_execute(
self.cassandra,
self.delete_graph_embeddings_stmt,
(user, document_id),
(workspace, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
async def get_triples(self, user, document_id, receiver):
async def get_triples(self, workspace, document_id, receiver):
logger.debug("Get triples...")
@ -345,7 +345,7 @@ class KnowledgeTableStore:
rows = await async_execute(
self.cassandra,
self.get_triples_stmt,
(user, document_id),
(workspace, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -369,7 +369,6 @@ class KnowledgeTableStore:
Triples(
metadata = Metadata(
id = document_id,
user = user,
collection = "default", # FIXME: What to put here?
),
triples = triples
@ -378,7 +377,7 @@ class KnowledgeTableStore:
logger.debug("Done")
async def get_graph_embeddings(self, user, document_id, receiver):
async def get_graph_embeddings(self, workspace, document_id, receiver):
logger.debug("Get GE...")
@ -386,7 +385,7 @@ class KnowledgeTableStore:
rows = await async_execute(
self.cassandra,
self.get_graph_embeddings_stmt,
(user, document_id),
(workspace, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -409,12 +408,11 @@ class KnowledgeTableStore:
GraphEmbeddings(
metadata = Metadata(
id = document_id,
user = user,
collection = "default", # FIXME: What to put here?
),
entities = entities
)
)
)
logger.debug("Done")
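With `user` gone from the schema, reconstructed messages carry only `id` and `collection` in their metadata; the workspace scopes the Cassandra read but is never serialised back into the message. A rough sketch of the shape (field names taken from the snippets above; the dataclasses themselves are illustrative):

```python
from dataclasses import dataclass, field

@dataclass
class Metadata:
    id: str = ""
    collection: str = "default"

@dataclass
class Triples:
    metadata: Metadata = field(default_factory=Metadata)
    triples: list = field(default_factory=list)

def rebuild_triples(workspace, document_id, triples):
    # workspace scopes the store read; it is not a message field.
    msg = Triples(
        metadata=Metadata(id=document_id, collection="default"),
        triples=triples,
    )
    return workspace, msg

ws, msg = rebuild_triples("ws-1", "doc-9", [("s", "p", "o")])
```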


@ -64,7 +64,7 @@ class LibraryTableStore:
self.cluster = Cluster(cassandra_host)
self.cassandra = self.cluster.connect()
logger.info("Connected.")
self.ensure_cassandra_schema()
@ -76,13 +76,13 @@ class LibraryTableStore:
logger.debug("Ensure Cassandra schema...")
logger.debug("Keyspace...")
# FIXME: Replication factor should be configurable
self.cassandra.execute(f"""
create keyspace if not exists {self.keyspace}
with replication = {{
'class' : 'SimpleStrategy',
'replication_factor' : 1
with replication = {{
'class' : 'SimpleStrategy',
'replication_factor' : 1
}};
""");
@ -93,7 +93,7 @@ class LibraryTableStore:
self.cassandra.execute("""
CREATE TABLE IF NOT EXISTS document (
id text,
user text,
workspace text,
time timestamp,
kind text,
title text,
@ -103,7 +103,9 @@ class LibraryTableStore:
>>,
tags list<text>,
object_id uuid,
PRIMARY KEY (user, id)
parent_id text,
document_type text,
PRIMARY KEY (workspace, id)
);
""");
@ -114,27 +116,6 @@ class LibraryTableStore:
ON document (object_id)
""");
# Add parent_id and document_type columns for child document support
logger.debug("document table parent_id column...")
try:
self.cassandra.execute("""
ALTER TABLE document ADD parent_id text
""");
except Exception as e:
# Column may already exist
if "already exists" not in str(e).lower() and "Invalid column name" not in str(e):
logger.debug(f"parent_id column may already exist: {e}")
try:
self.cassandra.execute("""
ALTER TABLE document ADD document_type text
""");
except Exception as e:
# Column may already exist
if "already exists" not in str(e).lower() and "Invalid column name" not in str(e):
logger.debug(f"document_type column may already exist: {e}")
logger.debug("document parent index...")
self.cassandra.execute("""
@ -150,10 +131,10 @@ class LibraryTableStore:
document_id text,
time timestamp,
flow text,
user text,
workspace text,
collection text,
tags list<text>,
PRIMARY KEY (user, id)
PRIMARY KEY (workspace, id)
);
""");
@ -162,7 +143,7 @@ class LibraryTableStore:
self.cassandra.execute("""
CREATE TABLE IF NOT EXISTS upload_session (
upload_id text PRIMARY KEY,
user text,
workspace text,
document_id text,
document_metadata text,
s3_upload_id text,
@ -176,11 +157,11 @@ class LibraryTableStore:
) WITH default_time_to_live = 86400;
""");
logger.debug("upload_session user index...")
logger.debug("upload_session workspace index...")
self.cassandra.execute("""
CREATE INDEX IF NOT EXISTS upload_session_user
ON upload_session (user)
CREATE INDEX IF NOT EXISTS upload_session_workspace
ON upload_session (workspace)
""");
logger.info("Cassandra schema OK.")
@ -190,7 +171,7 @@ class LibraryTableStore:
self.insert_document_stmt = self.cassandra.prepare("""
INSERT INTO document
(
id, user, time,
id, workspace, time,
kind, title, comments,
metadata, tags, object_id,
parent_id, document_type
@ -202,25 +183,25 @@ class LibraryTableStore:
UPDATE document
SET time = ?, title = ?, comments = ?,
metadata = ?, tags = ?
WHERE user = ? AND id = ?
WHERE workspace = ? AND id = ?
""")
self.get_document_stmt = self.cassandra.prepare("""
SELECT time, kind, title, comments, metadata, tags, object_id,
parent_id, document_type
FROM document
WHERE user = ? AND id = ?
WHERE workspace = ? AND id = ?
""")
self.delete_document_stmt = self.cassandra.prepare("""
DELETE FROM document
WHERE user = ? AND id = ?
WHERE workspace = ? AND id = ?
""")
self.test_document_exists_stmt = self.cassandra.prepare("""
SELECT id
FROM document
WHERE user = ? AND id = ?
WHERE workspace = ? AND id = ?
LIMIT 1
""")
@ -229,7 +210,7 @@ class LibraryTableStore:
id, time, kind, title, comments, metadata, tags, object_id,
parent_id, document_type
FROM document
WHERE user = ?
WHERE workspace = ?
""")
self.list_document_by_tag_stmt = self.cassandra.prepare("""
@ -237,7 +218,7 @@ class LibraryTableStore:
id, time, kind, title, comments, metadata, tags, object_id,
parent_id, document_type
FROM document
WHERE user = ? AND tags CONTAINS ?
WHERE workspace = ? AND tags CONTAINS ?
ALLOW FILTERING
""")
@ -245,7 +226,7 @@ class LibraryTableStore:
INSERT INTO processing
(
id, document_id, time,
flow, user, collection,
flow, workspace, collection,
tags
)
VALUES (?, ?, ?, ?, ?, ?, ?)
@ -253,13 +234,13 @@ class LibraryTableStore:
self.delete_processing_stmt = self.cassandra.prepare("""
DELETE FROM processing
WHERE user = ? AND id = ?
WHERE workspace = ? AND id = ?
""")
self.test_processing_exists_stmt = self.cassandra.prepare("""
SELECT id
FROM processing
WHERE user = ? AND id = ?
WHERE workspace = ? AND id = ?
LIMIT 1
""")
@ -267,14 +248,14 @@ class LibraryTableStore:
SELECT
id, document_id, time, flow, collection, tags
FROM processing
WHERE user = ?
WHERE workspace = ?
""")
# Upload session prepared statements
self.insert_upload_session_stmt = self.cassandra.prepare("""
INSERT INTO upload_session
(
upload_id, user, document_id, document_metadata,
upload_id, workspace, document_id, document_metadata,
s3_upload_id, object_id, total_size, chunk_size,
total_chunks, chunks_received, created_at, updated_at
)
@ -283,7 +264,7 @@ class LibraryTableStore:
self.get_upload_session_stmt = self.cassandra.prepare("""
SELECT
upload_id, user, document_id, document_metadata,
upload_id, workspace, document_id, document_metadata,
s3_upload_id, object_id, total_size, chunk_size,
total_chunks, chunks_received, created_at, updated_at
FROM upload_session
@ -308,25 +289,25 @@ class LibraryTableStore:
total_size, chunk_size, total_chunks,
chunks_received, created_at, updated_at
FROM upload_session
WHERE user = ?
WHERE workspace = ?
""")
# Child document queries
self.list_children_stmt = self.cassandra.prepare("""
SELECT
id, user, time, kind, title, comments, metadata, tags,
id, workspace, time, kind, title, comments, metadata, tags,
object_id, parent_id, document_type
FROM document
WHERE parent_id = ?
ALLOW FILTERING
""")
async def document_exists(self, user, id):
async def document_exists(self, workspace, id):
rows = await async_execute(
self.cassandra,
self.test_document_exists_stmt,
(user, id),
(workspace, id),
)
return bool(rows)
@ -351,7 +332,7 @@ class LibraryTableStore:
self.cassandra,
self.insert_document_stmt,
(
document.id, document.user, int(document.time * 1000),
document.id, document.workspace, int(document.time * 1000),
document.kind, document.title, document.comments,
metadata, document.tags, object_id,
parent_id, document_type
@ -381,7 +362,7 @@ class LibraryTableStore:
(
int(document.time * 1000), document.title,
document.comments, metadata, document.tags,
document.user, document.id
document.workspace, document.id
),
)
except Exception:
@ -390,7 +371,7 @@ class LibraryTableStore:
logger.debug("Update complete")
async def remove_document(self, user, document_id):
async def remove_document(self, workspace, document_id):
logger.info(f"Removing document {document_id}")
@ -398,7 +379,7 @@ class LibraryTableStore:
await async_execute(
self.cassandra,
self.delete_document_stmt,
(user, document_id),
(workspace, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -406,7 +387,7 @@ class LibraryTableStore:
logger.debug("Delete complete")
async def list_documents(self, user):
async def list_documents(self, workspace):
logger.debug("List documents...")
@ -414,7 +395,7 @@ class LibraryTableStore:
rows = await async_execute(
self.cassandra,
self.list_document_stmt,
(user,),
(workspace,),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -423,7 +404,7 @@ class LibraryTableStore:
lst = [
DocumentMetadata(
id = row[0],
user = user,
workspace = workspace,
time = int(time.mktime(row[1].timetuple())),
kind = row[2],
title = row[3],
@ -465,7 +446,7 @@ class LibraryTableStore:
lst = [
DocumentMetadata(
id = row[0],
user = row[1],
workspace = row[1],
time = int(time.mktime(row[2].timetuple())),
kind = row[3],
title = row[4],
@ -489,7 +470,7 @@ class LibraryTableStore:
return lst
async def get_document(self, user, id):
async def get_document(self, workspace, id):
logger.debug("Get document")
@ -497,7 +478,7 @@ class LibraryTableStore:
rows = await async_execute(
self.cassandra,
self.get_document_stmt,
(user, id),
(workspace, id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -506,7 +487,7 @@ class LibraryTableStore:
for row in rows:
doc = DocumentMetadata(
id = id,
user = user,
workspace = workspace,
time = int(time.mktime(row[0].timetuple())),
kind = row[1],
title = row[2],
@ -529,7 +510,7 @@ class LibraryTableStore:
raise RuntimeError("No such document row?")
async def get_document_object_id(self, user, id):
async def get_document_object_id(self, workspace, id):
logger.debug("Get document obj ID")
@ -537,7 +518,7 @@ class LibraryTableStore:
rows = await async_execute(
self.cassandra,
self.get_document_stmt,
(user, id),
(workspace, id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -549,12 +530,12 @@ class LibraryTableStore:
raise RuntimeError("No such document row?")
async def processing_exists(self, user, id):
async def processing_exists(self, workspace, id):
rows = await async_execute(
self.cassandra,
self.test_processing_exists_stmt,
(user, id),
(workspace, id),
)
return bool(rows)
@ -570,7 +551,7 @@ class LibraryTableStore:
(
processing.id, processing.document_id,
int(processing.time * 1000), processing.flow,
processing.user, processing.collection,
processing.workspace, processing.collection,
processing.tags
),
)
@ -580,7 +561,7 @@ class LibraryTableStore:
logger.debug("Add complete")
async def remove_processing(self, user, processing_id):
async def remove_processing(self, workspace, processing_id):
logger.info(f"Removing processing {processing_id}")
@ -588,7 +569,7 @@ class LibraryTableStore:
await async_execute(
self.cassandra,
self.delete_processing_stmt,
(user, processing_id),
(workspace, processing_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -596,7 +577,7 @@ class LibraryTableStore:
logger.debug("Delete complete")
async def list_processing(self, user):
async def list_processing(self, workspace):
logger.debug("List processing objects")
@ -604,7 +585,7 @@ class LibraryTableStore:
rows = await async_execute(
self.cassandra,
self.list_processing_stmt,
(user,),
(workspace,),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -616,7 +597,7 @@ class LibraryTableStore:
document_id = row[1],
time = int(time.mktime(row[2].timetuple())),
flow = row[3],
user = user,
workspace = workspace,
collection = row[4],
tags = row[5] if row[5] else [],
)
@ -632,7 +613,7 @@ class LibraryTableStore:
async def create_upload_session(
self,
upload_id,
user,
workspace,
document_id,
document_metadata,
s3_upload_id,
@ -652,7 +633,7 @@ class LibraryTableStore:
self.cassandra,
self.insert_upload_session_stmt,
(
upload_id, user, document_id, document_metadata,
upload_id, workspace, document_id, document_metadata,
s3_upload_id, object_id, total_size, chunk_size,
total_chunks, {}, now, now
),
@ -681,7 +662,7 @@ class LibraryTableStore:
for row in rows:
session = {
"upload_id": row[0],
"user": row[1],
"workspace": row[1],
"document_id": row[2],
"document_metadata": row[3],
"s3_upload_id": row[4],
@ -738,16 +719,16 @@ class LibraryTableStore:
logger.debug("Upload session deleted")
async def list_upload_sessions(self, user):
"""List all upload sessions for a user."""
async def list_upload_sessions(self, workspace):
"""List all upload sessions for a workspace."""
logger.debug(f"List upload sessions for {user}")
logger.debug(f"List upload sessions for {workspace}")
try:
rows = await async_execute(
self.cassandra,
self.list_upload_sessions_stmt,
(user,),
(workspace,),
)
except Exception:
logger.error("Exception occurred", exc_info=True)


@ -2,7 +2,6 @@
Joke Tool Service - An example dynamic tool service.
This service demonstrates the tool service integration by:
- Using the 'user' field to personalize responses
- Using config params (style) to customize joke style
- Using arguments (topic) to generate topic-specific jokes
@ -143,17 +142,16 @@ class Processor(DynamicToolService):
super(Processor, self).__init__(**params)
logger.info("Joke service initialized")
async def invoke(self, user, config, arguments):
async def invoke(self, config, arguments):
"""
Generate a joke based on the topic and style.
Args:
user: The user requesting the joke
config: Config values including 'style' (pun, dad-joke, one-liner)
arguments: Arguments including 'topic' (programming, animals, food)
Returns:
A personalized joke string
A joke string
"""
# Get style from config (default: random)
style = config.get("style", random.choice(["pun", "dad-joke", "one-liner"]))
@ -183,10 +181,9 @@ class Processor(DynamicToolService):
# Pick a random joke
joke = random.choice(jokes)
# Build the response
response = f"Hey {user}! Here's a {style} for you:\n\n{joke}"
response = f"Here's a {style} for you:\n\n{joke}"
logger.debug(f"Generated joke for user={user}, style={style}, topic={topic}")
logger.debug(f"Generated joke: style={style}, topic={topic}")
return response
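With personalisation removed, the tool contract is just `invoke(config, arguments)`. A stripped-down sketch of a conforming tool (class name, joke table, and defaults hypothetical):

```python
import asyncio

class MiniJokeTool:
    """Minimal tool following the user-free invoke(config, arguments) contract."""
    JOKES = {
        "programming": "Why do programmers prefer dark mode? Light attracts bugs.",
    }

    async def invoke(self, config, arguments):
        # Style comes from config, topic from call arguments;
        # no user identity reaches the tool.
        style = config.get("style", "one-liner")
        topic = arguments.get("topic", "programming")
        joke = self.JOKES.get(topic, "No joke for that topic.")
        return f"Here's a {style} for you:\n\n{joke}"

result = asyncio.run(
    MiniJokeTool().invoke({"style": "pun"}, {"topic": "programming"})
)
```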