IAM tech spec: Auth and access management current state and proposed changes.

Workspace support:
- Support for separate workspaces
- Addition of workspace CLI support for test purposes
- Massive test update
- Remove many 'user' references in services - workspace now provides
  the same separation
- Update API
This commit is contained in:
Cyber MacGeddon 2026-04-18 23:07:26 +01:00
parent 48da6c5f8b
commit 594deba73e
347 changed files with 6788 additions and 5540 deletions

View file

@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=2.3,<2.4",
"trustgraph-base>=2.4,<2.5",
"aiohttp",
"anthropic",
"scylla-driver",

View file

@ -26,42 +26,50 @@ class Service(ToolService):
self.register_config_handler(self.on_mcp_config, types=["mcp"])
# Per-workspace MCP service registries
self.mcp_services = {}
async def on_mcp_config(self, config, version):
async def on_mcp_config(self, workspace, config, version):
logger.info(f"Got config version {version}")
logger.info(
f"Got config version {version} for workspace {workspace}"
)
if "mcp" not in config:
self.mcp_services = {}
self.mcp_services[workspace] = {}
return
self.mcp_services = {
self.mcp_services[workspace] = {
k: json.loads(v)
for k, v in config["mcp"].items()
}
async def invoke_tool(self, name, parameters):
async def invoke_tool(self, workspace, name, parameters):
try:
if name not in self.mcp_services:
raise RuntimeError(f"MCP service {name} not known")
ws_services = self.mcp_services.get(workspace, {})
if "url" not in self.mcp_services[name]:
if name not in ws_services:
raise RuntimeError(
f"MCP service {name} not known in workspace "
f"{workspace}"
)
if "url" not in ws_services[name]:
raise RuntimeError(f"MCP service {name} URL not defined")
url = self.mcp_services[name]["url"]
url = ws_services[name]["url"]
if "remote-name" in self.mcp_services[name]:
remote_name = self.mcp_services[name]["remote-name"]
if "remote-name" in ws_services[name]:
remote_name = ws_services[name]["remote-name"]
else:
remote_name = name
# Build headers with optional bearer token
headers = {}
if "auth-token" in self.mcp_services[name]:
token = self.mcp_services[name]["auth-token"]
if "auth-token" in ws_services[name]:
token = ws_services[name]["auth-token"]
headers["Authorization"] = f"Bearer {token}"
logger.info(f"Invoking {remote_name} at {url}")

View file

@ -108,7 +108,7 @@ class Aggregator:
)
def build_synthesis_request(self, correlation_id, original_question,
user, collection):
collection):
"""
Build the AgentRequest that triggers the synthesis phase.
"""
@ -139,7 +139,6 @@ class Aggregator:
state="",
group=template.group if template else [],
history=history,
user=user,
collection=collection,
streaming=template.streaming if template else False,
session_id=parent_session_id,

View file

@ -46,25 +46,20 @@ from ..tool_filter import filter_tools_by_group_and_state, get_next_state
logger = logging.getLogger(__name__)
class UserAwareContext:
"""Wraps flow interface to inject user context for tools that need it."""
class FlowContext:
"""Wraps flow interface with orchestrator-only scratch state
(explain URIs, response handle, streaming flag). Workspace isolation
is enforced by the flow layer (flow.workspace), not by this class."""
def __init__(self, flow, user, respond=None, streaming=False):
def __init__(self, flow, respond=None, streaming=False):
self._flow = flow
self._user = user
self.respond = respond
self.streaming = streaming
self.current_explain_uri = None
self.last_sub_explain_uri = None
def __call__(self, service_name):
client = self._flow(service_name)
if service_name in (
"structured-query-request",
"row-embeddings-query-request",
):
client._current_user = self._user
return client
return self._flow(service_name)
class UsageTracker:
@ -131,7 +126,6 @@ class PatternBase:
state="",
group=getattr(request, 'group', []),
history=[completion_step],
user=request.user,
collection=getattr(request, 'collection', 'default'),
streaming=False,
session_id=getattr(request, 'session_id', ''),
@ -158,9 +152,9 @@ class PatternBase:
current_state=getattr(request, 'state', None),
)
def make_context(self, flow, user, respond=None, streaming=False):
"""Create a user-aware context wrapper."""
return UserAwareContext(flow, user, respond=respond, streaming=streaming)
def make_context(self, flow, respond=None, streaming=False):
"""Create a flow context wrapper."""
return FlowContext(flow, respond=respond, streaming=streaming)
def build_history(self, request):
"""Convert AgentStep history into Action objects."""
@ -249,7 +243,7 @@ class PatternBase:
# ---- Provenance emission ------------------------------------------------
async def emit_session_triples(self, flow, session_uri, question, user,
async def emit_session_triples(self, flow, session_uri, question,
collection, respond, streaming,
parent_uri=None):
"""Emit provenance triples for a new session."""
@ -264,7 +258,6 @@ class PatternBase:
await flow("explainability").send(Triples(
metadata=Metadata(
id=session_uri,
user=user,
collection=collection,
),
triples=triples,
@ -281,7 +274,7 @@ class PatternBase:
async def emit_pattern_decision_triples(
self, flow, session_id, session_uri, pattern, task_type,
user, collection, respond,
collection, respond,
):
"""Emit provenance triples for a meta-router pattern decision."""
uri = agent_pattern_decision_uri(session_id)
@ -292,7 +285,7 @@ class PatternBase:
GRAPH_RETRIEVAL,
)
await flow("explainability").send(Triples(
metadata=Metadata(id=uri, user=user, collection=collection),
metadata=Metadata(id=uri, collection=collection),
triples=triples,
))
await respond(AgentResponse(
@ -329,7 +322,7 @@ class PatternBase:
try:
await self.processor.save_answer_content(
doc_id=thought_doc_id,
user=request.user,
workspace=flow.workspace,
content=act.thought,
title=f"Agent Thought: {act.name}",
)
@ -360,7 +353,6 @@ class PatternBase:
await flow("explainability").send(Triples(
metadata=Metadata(
id=iteration_uri,
user=request.user,
collection=getattr(request, 'collection', 'default'),
),
triples=iter_triples,
@ -399,7 +391,7 @@ class PatternBase:
try:
await self.processor.save_answer_content(
doc_id=observation_doc_id,
user=request.user,
workspace=flow.workspace,
content=observation_text,
title=f"Agent Observation",
)
@ -420,7 +412,6 @@ class PatternBase:
await flow("explainability").send(Triples(
metadata=Metadata(
id=observation_entity_uri,
user=request.user,
collection=getattr(request, 'collection', 'default'),
),
triples=obs_triples,
@ -456,7 +447,7 @@ class PatternBase:
try:
await self.processor.save_answer_content(
doc_id=answer_doc_id,
user=request.user,
workspace=flow.workspace,
content=answer_text,
title=f"Agent Answer: {request.question[:50]}...",
)
@ -478,7 +469,6 @@ class PatternBase:
await flow("explainability").send(Triples(
metadata=Metadata(
id=final_uri,
user=request.user,
collection=getattr(request, 'collection', 'default'),
),
triples=final_triples,
@ -496,7 +486,7 @@ class PatternBase:
# ---- Orchestrator provenance helpers ------------------------------------
async def emit_decomposition_triples(
self, flow, session_id, session_uri, goals, user, collection,
self, flow, session_id, session_uri, goals, collection,
respond, streaming,
):
"""Emit provenance for a supervisor decomposition step."""
@ -506,7 +496,7 @@ class PatternBase:
GRAPH_RETRIEVAL,
)
await flow("explainability").send(Triples(
metadata=Metadata(id=uri, user=user, collection=collection),
metadata=Metadata(id=uri, collection=collection),
triples=triples,
))
await respond(AgentResponse(
@ -516,7 +506,7 @@ class PatternBase:
))
async def emit_finding_triples(
self, flow, session_id, index, goal, answer_text, user, collection,
self, flow, session_id, index, goal, answer_text, collection,
respond, streaming, subagent_session_id="",
):
"""Emit provenance for a subagent finding."""
@ -532,7 +522,7 @@ class PatternBase:
doc_id = f"urn:trustgraph:agent:{session_id}/finding/{index}/doc"
try:
await self.processor.save_answer_content(
doc_id=doc_id, user=user,
doc_id=doc_id, workspace=flow.workspace,
content=answer_text,
title=f"Finding: {goal[:60]}",
)
@ -545,7 +535,7 @@ class PatternBase:
GRAPH_RETRIEVAL,
)
await flow("explainability").send(Triples(
metadata=Metadata(id=uri, user=user, collection=collection),
metadata=Metadata(id=uri, collection=collection),
triples=triples,
))
await respond(AgentResponse(
@ -555,7 +545,7 @@ class PatternBase:
))
async def emit_plan_triples(
self, flow, session_id, session_uri, steps, user, collection,
self, flow, session_id, session_uri, steps, collection,
respond, streaming,
):
"""Emit provenance for a plan creation."""
@ -565,7 +555,7 @@ class PatternBase:
GRAPH_RETRIEVAL,
)
await flow("explainability").send(Triples(
metadata=Metadata(id=uri, user=user, collection=collection),
metadata=Metadata(id=uri, collection=collection),
triples=triples,
))
await respond(AgentResponse(
@ -575,7 +565,7 @@ class PatternBase:
))
async def emit_step_result_triples(
self, flow, session_id, index, goal, answer_text, user, collection,
self, flow, session_id, index, goal, answer_text, collection,
respond, streaming,
):
"""Emit provenance for a plan step result."""
@ -585,7 +575,7 @@ class PatternBase:
doc_id = f"urn:trustgraph:agent:{session_id}/step/{index}/doc"
try:
await self.processor.save_answer_content(
doc_id=doc_id, user=user,
doc_id=doc_id, workspace=flow.workspace,
content=answer_text,
title=f"Step result: {goal[:60]}",
)
@ -598,7 +588,7 @@ class PatternBase:
GRAPH_RETRIEVAL,
)
await flow("explainability").send(Triples(
metadata=Metadata(id=uri, user=user, collection=collection),
metadata=Metadata(id=uri, collection=collection),
triples=triples,
))
await respond(AgentResponse(
@ -608,7 +598,7 @@ class PatternBase:
))
async def emit_synthesis_triples(
self, flow, session_id, previous_uris, answer_text, user, collection,
self, flow, session_id, previous_uris, answer_text, collection,
respond, streaming, termination_reason=None,
):
"""Emit provenance for a synthesis answer."""
@ -617,7 +607,7 @@ class PatternBase:
doc_id = f"urn:trustgraph:agent:{session_id}/synthesis/doc"
try:
await self.processor.save_answer_content(
doc_id=doc_id, user=user,
doc_id=doc_id, workspace=flow.workspace,
content=answer_text,
title="Synthesis",
)
@ -633,7 +623,7 @@ class PatternBase:
GRAPH_RETRIEVAL,
)
await flow("explainability").send(Triples(
metadata=Metadata(id=uri, user=user, collection=collection),
metadata=Metadata(id=uri, collection=collection),
triples=triples,
))
await respond(AgentResponse(
@ -751,7 +741,6 @@ class PatternBase:
)
for h in history
],
user=request.user,
collection=collection,
streaming=streaming,
session_id=session_id,

View file

@ -53,7 +53,7 @@ class PlanThenExecutePattern(PatternBase):
if iteration_num == 1:
await self.emit_session_triples(
flow, session_uri, request.question,
request.user, collection, respond, streaming,
collection, respond, streaming,
)
logger.info(
@ -109,11 +109,17 @@ class PlanThenExecutePattern(PatternBase):
think = self.make_think_callback(respond, streaming)
tools = self.filter_tools(self.processor.agent.tools, request)
agent = self.processor.agents.get(flow.workspace)
if agent is None:
raise RuntimeError(
f"No agent configuration for workspace {flow.workspace}"
)
tools = self.filter_tools(agent.tools, request)
framing = getattr(request, 'framing', '')
context = self.make_context(
flow, request.user,
flow,
respond=respond, streaming=streaming,
)
client = context("prompt-request")
@ -147,7 +153,7 @@ class PlanThenExecutePattern(PatternBase):
step_goals = [ps.get("goal", "") for ps in plan_steps]
await self.emit_plan_triples(
flow, session_id, session_uri, step_goals,
request.user, collection, respond, streaming,
collection, respond, streaming,
)
# Build PlanStep objects
@ -179,7 +185,6 @@ class PlanThenExecutePattern(PatternBase):
state=request.state,
group=getattr(request, 'group', []),
history=new_history,
user=request.user,
collection=collection,
streaming=streaming,
session_id=session_id,
@ -237,9 +242,15 @@ class PlanThenExecutePattern(PatternBase):
"result": dep_result,
})
tools = self.filter_tools(self.processor.agent.tools, request)
agent = self.processor.agents.get(flow.workspace)
if agent is None:
raise RuntimeError(
f"No agent configuration for workspace {flow.workspace}"
)
tools = self.filter_tools(agent.tools, request)
context = self.make_context(
flow, request.user,
flow,
respond=respond, streaming=streaming,
)
@ -307,7 +318,7 @@ class PlanThenExecutePattern(PatternBase):
# Emit step result provenance
await self.emit_step_result_triples(
flow, session_id, pending_idx, goal, step_result,
request.user, collection, respond, streaming,
collection, respond, streaming,
)
# Build execution step for history
@ -327,7 +338,6 @@ class PlanThenExecutePattern(PatternBase):
state=request.state,
group=getattr(request, 'group', []),
history=new_history,
user=request.user,
collection=collection,
streaming=streaming,
session_id=session_id,
@ -352,7 +362,7 @@ class PlanThenExecutePattern(PatternBase):
framing = getattr(request, 'framing', '')
context = self.make_context(
flow, request.user,
flow,
respond=respond, streaming=streaming,
)
client = context("prompt-request")
@ -387,7 +397,7 @@ class PlanThenExecutePattern(PatternBase):
last_step_uri = make_step_result_uri(session_id, len(plan) - 1)
await self.emit_synthesis_triples(
flow, session_id, last_step_uri,
response_text, request.user, collection, respond, streaming,
response_text, collection, respond, streaming,
termination_reason="plan-complete",
)

View file

@ -61,7 +61,7 @@ class ReactPattern(PatternBase):
)
await self.emit_session_triples(
flow, session_uri, request.question,
request.user, collection, respond, streaming,
collection, respond, streaming,
parent_uri=parent_uri,
)
@ -80,13 +80,20 @@ class ReactPattern(PatternBase):
observe = self.make_observe_callback(respond, streaming, message_id=observation_msg_id)
answer_cb = self.make_answer_callback(respond, streaming, message_id=answer_msg_id)
# Look up the per-workspace agent
agent = self.processor.agents.get(flow.workspace)
if agent is None:
raise RuntimeError(
f"No agent configuration for workspace {flow.workspace}"
)
# Filter tools
filtered_tools = self.filter_tools(
self.processor.agent.tools, request,
agent.tools, request,
)
# Create temporary agent with filtered tools and optional framing
additional_context = self.processor.agent.additional_context
additional_context = agent.additional_context
framing = getattr(request, 'framing', '')
if framing:
if additional_context:
@ -100,7 +107,7 @@ class ReactPattern(PatternBase):
)
context = self.make_context(
flow, request.user,
flow,
respond=respond, streaming=streaming,
)

View file

@ -42,7 +42,7 @@ from ..tool_filter import validate_tool_config
from ..react.types import Final, Action, Tool, Argument
from . meta_router import MetaRouter
from . pattern_base import PatternBase, UserAwareContext
from . pattern_base import PatternBase, FlowContext
from . react_pattern import ReactPattern
from . plan_pattern import PlanThenExecutePattern
from . supervisor_pattern import SupervisorPattern
@ -76,10 +76,9 @@ class Processor(AgentService):
}
)
self.agent = AgentManager(
tools={},
additional_context="",
)
# Per-workspace agent managers and meta-routers
self.agents = {}
self.meta_routers = {}
self.tool_service_clients = {}
@ -91,9 +90,6 @@ class Processor(AgentService):
# Aggregator for supervisor fan-in
self.aggregator = Aggregator()
# Meta-router (initialised on first config load)
self.meta_router = None
self.register_config_handler(
self.on_tools_config, types=["tool", "tool-service"]
)
@ -204,13 +200,13 @@ class Processor(AgentService):
future = self.pending_librarian_requests.pop(request_id)
future.set_result(response)
async def save_answer_content(self, doc_id, user, content, title=None,
async def save_answer_content(self, doc_id, workspace, content, title=None,
timeout=120):
request_id = str(uuid.uuid4())
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
workspace=workspace,
kind="text/plain",
title=title or "Agent Answer",
document_type="answer",
@ -221,7 +217,7 @@ class Processor(AgentService):
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
user=user,
workspace=workspace,
)
future = asyncio.get_event_loop().create_future()
@ -247,9 +243,12 @@ class Processor(AgentService):
def provenance_session_uri(self, session_id):
return agent_session_uri(session_id)
async def on_tools_config(self, config, version):
async def on_tools_config(self, workspace, config, version):
logger.info(f"Loading configuration version {version}")
logger.info(
f"Loading configuration version {version} "
f"for workspace {workspace}"
)
try:
tools = {}
@ -316,7 +315,6 @@ class Processor(AgentService):
impl = functools.partial(
StructuredQueryImpl,
collection=data.get("collection"),
user=None,
)
arguments = StructuredQueryImpl.get_arguments()
elif impl_id == "row-embeddings-query":
@ -324,7 +322,6 @@ class Processor(AgentService):
RowEmbeddingsQueryImpl,
schema_name=data.get("schema-name"),
collection=data.get("collection"),
user=None,
index_name=data.get("index-name"),
limit=int(data.get("limit", 10)),
)
@ -408,15 +405,17 @@ class Processor(AgentService):
agent_config = config[self.config_key]
additional = agent_config.get("additional-context", None)
self.agent = AgentManager(
self.agents[workspace] = AgentManager(
tools=tools,
additional_context=additional,
)
# Re-initialise meta-router with config
self.meta_router = MetaRouter(config=config)
# Re-initialise meta-router with config for this workspace
self.meta_routers[workspace] = MetaRouter(config=config)
logger.info(f"Loaded {len(tools)} tools")
logger.info(
f"Loaded {len(tools)} tools for workspace {workspace}"
)
except Exception as e:
logger.error(
@ -466,7 +465,7 @@ class Processor(AgentService):
await self.supervisor_pattern.emit_finding_triples(
flow, parent_session_id, finding_index,
subagent_goal, answer_text,
template.user, collection,
collection,
respond, template.streaming,
subagent_session_id=subagent_session_id,
)
@ -486,7 +485,6 @@ class Processor(AgentService):
synthesis_request = self.aggregator.build_synthesis_request(
correlation_id,
original_question=template.question,
user=template.user,
collection=getattr(template, 'collection', 'default'),
)
@ -515,10 +513,11 @@ class Processor(AgentService):
# If no pattern set and this is the first iteration, route
if not pattern and not request.history:
context = UserAwareContext(flow, request.user)
context = FlowContext(flow)
if self.meta_router:
pattern, task_type, framing = await self.meta_router.route(
meta_router = self.meta_routers.get(flow.workspace)
if meta_router:
pattern, task_type, framing = await meta_router.route(
request.question, context, usage=usage,
)
else:
@ -553,7 +552,6 @@ class Processor(AgentService):
await selected.emit_pattern_decision_triples(
flow, session_id, session_uri,
pattern, getattr(request, 'task_type', ''),
request.user,
getattr(request, 'collection', 'default'),
respond,
)

View file

@ -54,7 +54,7 @@ class SupervisorPattern(PatternBase):
if iteration_num == 1:
await self.emit_session_triples(
flow, session_uri, request.question,
request.user, collection, respond, streaming,
collection, respond, streaming,
)
logger.info(
@ -99,10 +99,16 @@ class SupervisorPattern(PatternBase):
)
framing = getattr(request, 'framing', '')
tools = self.filter_tools(self.processor.agent.tools, request)
agent = self.processor.agents.get(flow.workspace)
if agent is None:
raise RuntimeError(
f"No agent configuration for workspace {flow.workspace}"
)
tools = self.filter_tools(agent.tools, request)
context = self.make_context(
flow, request.user,
flow,
respond=respond, streaming=streaming,
)
client = context("prompt-request")
@ -144,7 +150,7 @@ class SupervisorPattern(PatternBase):
# Emit decomposition provenance
await self.emit_decomposition_triples(
flow, session_id, session_uri, goals,
request.user, collection, respond, streaming,
collection, respond, streaming,
)
# Fan out: emit a subagent request for each goal
@ -155,7 +161,6 @@ class SupervisorPattern(PatternBase):
state="",
group=getattr(request, 'group', []),
history=[],
user=request.user,
collection=collection,
streaming=False, # Subagents don't stream
session_id=subagent_session,
@ -207,7 +212,7 @@ class SupervisorPattern(PatternBase):
subagent_results = {"(no results)": "No subagent results available"}
context = self.make_context(
flow, request.user,
flow,
respond=respond, streaming=streaming,
)
client = context("prompt-request")
@ -237,7 +242,7 @@ class SupervisorPattern(PatternBase):
]
await self.emit_synthesis_triples(
flow, session_id, finding_uris,
response_text, request.user, collection, respond, streaming,
response_text, collection, respond, streaming,
termination_reason="subagents-complete",
)

View file

@ -10,6 +10,7 @@ import sys
import functools
import logging
import uuid
from typing import Dict
from datetime import datetime, timezone
# Module logger
@ -73,10 +74,8 @@ class Processor(AgentService):
}
)
self.agent = AgentManager(
tools={},
additional_context="",
)
# Per-workspace agent managers
self.agents: Dict[str, AgentManager] = {}
# Track active tool service clients for cleanup
self.tool_service_clients = {}
@ -193,13 +192,13 @@ class Processor(AgentService):
future = self.pending_librarian_requests.pop(request_id)
future.set_result(response)
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120):
"""
Save answer content to the librarian.
Args:
doc_id: ID for the answer document
user: User ID
workspace: Workspace for isolation
content: Answer text content
title: Optional title
timeout: Request timeout in seconds
@ -211,7 +210,7 @@ class Processor(AgentService):
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
workspace=workspace,
kind="text/plain",
title=title or "Agent Answer",
document_type="answer",
@ -222,7 +221,7 @@ class Processor(AgentService):
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
user=user,
workspace=workspace,
)
# Create future for response
@ -249,9 +248,12 @@ class Processor(AgentService):
self.pending_librarian_requests.pop(request_id, None)
raise RuntimeError(f"Timeout saving answer document {doc_id}")
async def on_tools_config(self, config, version):
async def on_tools_config(self, workspace, config, version):
logger.info(f"Loading configuration version {version}")
logger.info(
f"Loading configuration version {version} "
f"for workspace {workspace}"
)
try:
@ -321,7 +323,6 @@ class Processor(AgentService):
impl = functools.partial(
StructuredQueryImpl,
collection=data.get("collection"),
user=None # User will be provided dynamically via context
)
arguments = StructuredQueryImpl.get_arguments()
elif impl_id == "row-embeddings-query":
@ -329,7 +330,6 @@ class Processor(AgentService):
RowEmbeddingsQueryImpl,
schema_name=data.get("schema-name"),
collection=data.get("collection"),
user=None, # User will be provided dynamically via context
index_name=data.get("index-name"), # Optional filter
limit=int(data.get("limit", 10)) # Max results
)
@ -409,13 +409,17 @@ class Processor(AgentService):
agent_config = config[self.config_key]
additional = agent_config.get("additional-context", None)
self.agent = AgentManager(
self.agents[workspace] = AgentManager(
tools=tools,
additional_context=additional
)
logger.info(f"Loaded {len(tools)} tools")
logger.info("Tool configuration reloaded.")
logger.info(
f"Loaded {len(tools)} tools for workspace {workspace}"
)
logger.info(
f"Tool configuration reloaded for workspace {workspace}."
)
except Exception as e:
@ -460,7 +464,6 @@ class Processor(AgentService):
await flow("explainability").send(Triples(
metadata=Metadata(
id=session_uri,
user=request.user,
collection=collection,
),
triples=triples,
@ -557,35 +560,41 @@ class Processor(AgentService):
await respond(r)
# Look up the agent for this workspace
workspace = flow.workspace
agent = self.agents.get(workspace)
if agent is None:
logger.error(
f"No agent configuration loaded for workspace "
f"{workspace}"
)
raise RuntimeError(
f"No agent configuration for workspace {workspace}"
)
# Apply tool filtering based on request groups and state
filtered_tools = filter_tools_by_group_and_state(
tools=self.agent.tools,
tools=agent.tools,
requested_groups=getattr(request, 'group', None),
current_state=getattr(request, 'state', None)
)
# Create temporary agent with filtered tools
temp_agent = AgentManager(
tools=filtered_tools,
additional_context=self.agent.additional_context
additional_context=agent.additional_context
)
logger.debug("Call React")
# Create user-aware context wrapper that preserves the flow interface
# but adds user information for tools that need it
class UserAwareContext:
def __init__(self, flow, user):
# Thin wrapper around flow — carries only explain URI state.
class _Context:
def __init__(self, flow):
self._flow = flow
self._user = user
self.last_sub_explain_uri = None
def __call__(self, service_name):
client = self._flow(service_name)
# For query clients that need user context, store it
if service_name in ("structured-query-request", "row-embeddings-query-request"):
client._current_user = self._user
return client
return self._flow(service_name)
# Callback: emit Analysis+ToolUse triples before tool executes
async def on_action(act_decision):
@ -604,7 +613,7 @@ class Processor(AgentService):
try:
await self.save_answer_content(
doc_id=t_doc_id,
user=request.user,
workspace=flow.workspace,
content=act_decision.thought,
title=f"Agent Thought: {act_decision.name}",
)
@ -629,7 +638,6 @@ class Processor(AgentService):
await flow("explainability").send(Triples(
metadata=Metadata(
id=iter_uri,
user=request.user,
collection=collection,
),
triples=iter_triples,
@ -644,7 +652,7 @@ class Processor(AgentService):
explain_triples=iter_triples,
))
user_context = UserAwareContext(flow, request.user)
user_context = _Context(flow)
act = await temp_agent.react(
question = request.question,
@ -685,7 +693,7 @@ class Processor(AgentService):
try:
await self.save_answer_content(
doc_id=answer_doc_id,
user=request.user,
workspace=flow.workspace,
content=f,
title=f"Agent Answer: {request.question[:50]}...",
)
@ -706,7 +714,6 @@ class Processor(AgentService):
await flow("explainability").send(Triples(
metadata=Metadata(
id=final_uri,
user=request.user,
collection=collection,
),
triples=final_triples,
@ -763,7 +770,7 @@ class Processor(AgentService):
try:
await self.save_answer_content(
doc_id=observation_doc_id,
user=request.user,
workspace=flow.workspace,
content=act.observation,
title=f"Agent Observation",
)
@ -783,7 +790,6 @@ class Processor(AgentService):
await flow("explainability").send(Triples(
metadata=Metadata(
id=observation_entity_uri,
user=request.user,
collection=collection,
),
triples=obs_triples,
@ -820,7 +826,6 @@ class Processor(AgentService):
)
for h in history
],
user=request.user,
collection=collection,
streaming=streaming,
session_id=session_id, # Pass session_id for provenance continuity

View file

@ -116,31 +116,26 @@ class McpToolImpl:
# This tool implementation knows how to query structured data using natural language
class StructuredQueryImpl:
def __init__(self, context, collection=None, user=None):
def __init__(self, context, collection=None):
self.context = context
self.collection = collection # For multi-tenant scenarios
self.user = user # User context for multi-tenancy
self.collection = collection
@staticmethod
def get_arguments():
return [
Argument(
name="question",
type="string",
type="string",
description="Natural language question about structured data (tables, databases, etc.)"
)
]
async def invoke(self, **arguments):
client = self.context("structured-query-request")
logger.debug("Structured query question...")
# Get user from client context if available, otherwise use instance user or default
user = getattr(client, '_current_user', self.user or "trustgraph")
result = await client.structured_query(
question=arguments.get("question"),
user=user,
collection=self.collection or "default"
)
@ -159,11 +154,10 @@ class StructuredQueryImpl:
# This tool implementation knows how to query row embeddings for semantic search
class RowEmbeddingsQueryImpl:
def __init__(self, context, schema_name, collection=None, user=None, index_name=None, limit=10):
def __init__(self, context, schema_name, collection=None, index_name=None, limit=10):
self.context = context
self.schema_name = schema_name
self.collection = collection
self.user = user
self.index_name = index_name # Optional: filter to specific index
self.limit = limit # Max results to return
@ -190,13 +184,9 @@ class RowEmbeddingsQueryImpl:
client = self.context("row-embeddings-query-request")
logger.debug("Row embeddings query...")
# Get user from client context if available
user = getattr(client, '_current_user', self.user or "trustgraph")
matches = await client.row_embeddings_query(
vector=vector,
schema_name=self.schema_name,
user=user,
collection=self.collection or "default",
index_name=self.index_name,
limit=self.limit
@ -250,7 +240,7 @@ class ToolServiceImpl:
Initialize a tool service implementation.
Args:
context: The context function (provides user info)
context: Flow context (callable resolving service names to clients)
request_queue: Full Pulsar topic for requests
response_queue: Full Pulsar topic for responses
config_values: Dict of config values (e.g., {"collection": "customers"})
@ -325,17 +315,10 @@ class ToolServiceImpl:
logger.debug(f"Config: {self.config_values}")
logger.debug(f"Arguments: {arguments}")
# Get user from context if available
user = "trustgraph"
if hasattr(self.context, '_user'):
user = self.context._user
# Get or create the client
client = await self._get_or_create_client()
# Call the tool service
response = await client.call(
user=user,
config=self.config_values,
arguments=arguments,
)

View file

@ -95,7 +95,7 @@ class Processor(ChunkingService):
logger.info(f"Chunking document {v.metadata.id}...")
# Get text content (fetches from librarian if needed)
text = await self.get_document_text(v)
text = await self.get_document_text(v, flow.workspace)
# Extract chunk parameters from flow (allows runtime override)
chunk_size, chunk_overlap = await self.chunk_document(
@ -144,7 +144,7 @@ class Processor(ChunkingService):
await self.librarian.save_child_document(
doc_id=chunk_doc_id,
parent_id=parent_doc_id,
user=v.metadata.user,
workspace=flow.workspace,
content=chunk_content,
document_type="chunk",
title=f"Chunk {chunk_index}",
@ -168,7 +168,6 @@ class Processor(ChunkingService):
metadata=Metadata(
id=c_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
triples=set_graph(prov_triples, GRAPH_SOURCE),
@ -179,7 +178,6 @@ class Processor(ChunkingService):
metadata=Metadata(
id=c_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
chunk=chunk_content,

View file

@ -92,7 +92,7 @@ class Processor(ChunkingService):
logger.info(f"Chunking document {v.metadata.id}...")
# Get text content (fetches from librarian if needed)
text = await self.get_document_text(v)
text = await self.get_document_text(v, flow.workspace)
# Extract chunk parameters from flow (allows runtime override)
chunk_size, chunk_overlap = await self.chunk_document(
@ -140,7 +140,7 @@ class Processor(ChunkingService):
await self.librarian.save_child_document(
doc_id=chunk_doc_id,
parent_id=parent_doc_id,
user=v.metadata.user,
workspace=flow.workspace,
content=chunk_content,
document_type="chunk",
title=f"Chunk {chunk_index}",
@ -164,7 +164,6 @@ class Processor(ChunkingService):
metadata=Metadata(
id=c_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
triples=set_graph(prov_triples, GRAPH_SOURCE),
@ -175,7 +174,6 @@ class Processor(ChunkingService):
metadata=Metadata(
id=c_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
chunk=chunk_content,

View file

@ -9,42 +9,8 @@ from ... tables.config import ConfigTableStore
# Module logger
logger = logging.getLogger(__name__)
class ConfigurationClass:
async def keys(self):
return await self.table_store.get_keys(self.type)
async def values(self):
vals = await self.table_store.get_values(self.type)
return {
v[0]: v[1]
for v in vals
}
async def get(self, key):
return await self.table_store.get_value(self.type, key)
async def put(self, key, value):
return await self.table_store.put_config(self.type, key, value)
async def delete(self, key):
return await self.table_store.delete_key(self.type, key)
async def has(self, key):
val = await self.table_store.get_value(self.type, key)
return val is not None
class Configuration:
# FIXME: The state is held internally. This only works if there's
# one config service. Should be more than one, and use a
# back-end state store.
# FIXME: This has state now, but does it address all of the above?
# REVIEW: Above
# FIXME: Some version vs config race conditions
def __init__(self, push, host, username, password, keyspace):
# External function to respond to update
@ -60,34 +26,17 @@ class Configuration:
async def get_version(self):
return await self.table_store.get_version()
def get(self, type):
c = ConfigurationClass()
c.table_store = self.table_store
c.type = type
return c
async def handle_get(self, v):
# for k in v.keys:
# if k.type not in self or k.key not in self[k.type]:
# return ConfigResponse(
# version = None,
# values = None,
# directory = None,
# config = None,
# error = Error(
# type = "key-error",
# message = f"Key error"
# )
# )
workspace = v.workspace
values = [
ConfigValue(
type = k.type,
key = k.key,
value = await self.table_store.get_value(k.type, k.key)
value = await self.table_store.get_value(
workspace, k.type, k.key
)
)
for k in v.keys
]
@ -96,43 +45,19 @@ class Configuration:
version = await self.get_version(),
values = values,
)
async def handle_list(self, v):
# if v.type not in self:
# return ConfigResponse(
# version = None,
# values = None,
# directory = None,
# config = None,
# error = Error(
# type = "key-error",
# message = "No such type",
# ),
# )
return ConfigResponse(
version = await self.get_version(),
directory = await self.table_store.get_keys(v.type),
directory = await self.table_store.get_keys(
v.workspace, v.type
),
)
async def handle_getvalues(self, v):
# if v.type not in self:
# return ConfigResponse(
# version = None,
# values = None,
# directory = None,
# config = None,
# error = Error(
# type = "key-error",
# message = f"Key error"
# )
# )
vals = await self.table_store.get_values(v.type)
vals = await self.table_store.get_values(v.workspace, v.type)
values = map(
lambda x: ConfigValue(
@ -146,39 +71,63 @@ class Configuration:
values = list(values),
)
async def handle_getvalues_all_ws(self, v):
"""Fetch all values of a given type across all workspaces.
Used by shared processors to load type-scoped config at
startup without enumerating workspaces separately."""
vals = await self.table_store.get_values_all_ws(v.type)
values = [
ConfigValue(
workspace = row[0],
type = v.type,
key = row[1],
value = row[2],
)
for row in vals
]
return ConfigResponse(
version = await self.get_version(),
values = values,
)
async def handle_delete(self, v):
workspace = v.workspace
types = list(set(k.type for k in v.keys))
for k in v.keys:
await self.table_store.delete_key(k.type, k.key)
await self.table_store.delete_key(workspace, k.type, k.key)
await self.inc_version()
await self.push(types=types)
await self.push(changes={t: [workspace] for t in types})
return ConfigResponse(
)
async def handle_put(self, v):
workspace = v.workspace
types = list(set(k.type for k in v.values))
for k in v.values:
await self.table_store.put_config(k.type, k.key, k.value)
await self.table_store.put_config(
workspace, k.type, k.key, k.value
)
await self.inc_version()
await self.push(types=types)
await self.push(changes={t: [workspace] for t in types})
return ConfigResponse(
)
async def get_config(self):
async def get_config(self, workspace):
table = await self.table_store.get_all()
table = await self.table_store.get_all_for_workspace(workspace)
config = {}
@ -191,7 +140,7 @@ class Configuration:
async def handle_config(self, v):
config = await self.get_config()
config = await self.get_config(v.workspace)
return ConfigResponse(
version = await self.get_version(),
@ -200,7 +149,20 @@ class Configuration:
async def handle(self, msg):
logger.debug(f"Handling config message: {msg.operation}")
logger.debug(
f"Handling config message: {msg.operation} "
f"workspace={msg.workspace}"
)
# getvalues-all-ws spans all workspaces, so no workspace
# required; everything else is workspace-scoped.
if msg.operation != "getvalues-all-ws" and not msg.workspace:
return ConfigResponse(
error=Error(
type = "bad-request",
message = "Workspace is required"
)
)
if msg.operation == "get":
@ -214,6 +176,10 @@ class Configuration:
resp = await self.handle_getvalues(msg)
elif msg.operation == "getvalues-all-ws":
resp = await self.handle_getvalues_all_ws(msg)
elif msg.operation == "delete":
resp = await self.handle_delete(msg)

View file

@ -128,18 +128,21 @@ class Processor(AsyncProcessor):
await self.push() # Startup poke: empty changes = everything
await self.config_request_consumer.start()
async def push(self, types=None):
async def push(self, changes=None):
version = await self.config.get_version()
resp = ConfigPush(
version = version,
types = types or [],
changes = changes or {},
)
await self.config_push_producer.send(resp)
logger.info(f"Pushed config poke version {version}, types={resp.types}")
logger.info(
f"Pushed config poke version {version}, "
f"changes={resp.changes}"
)
async def on_config_request(self, msg, consumer, flow):

View file

@ -33,7 +33,7 @@ class KnowledgeManager:
logger.info("Deleting knowledge core...")
await self.table_store.delete_kg_core(
request.user, request.id
request.workspace, request.id
)
await respond(
@ -63,7 +63,7 @@ class KnowledgeManager:
# Read this core's triples from the table store, handing them to publish_triples
await self.table_store.get_triples(
request.user,
request.workspace,
request.id,
publish_triples,
)
@ -81,7 +81,7 @@ class KnowledgeManager:
# Read this core's graph embeddings from the table store, handing them to publish_ge
await self.table_store.get_graph_embeddings(
request.user,
request.workspace,
request.id,
publish_ge,
)
@ -100,7 +100,7 @@ class KnowledgeManager:
async def list_kg_cores(self, request, respond):
ids = await self.table_store.list_kg_cores(request.user)
ids = await self.table_store.list_kg_cores(request.workspace)
await respond(
KnowledgeResponse(
@ -114,12 +114,14 @@ class KnowledgeManager:
async def put_kg_core(self, request, respond):
workspace = request.workspace
if request.triples:
await self.table_store.add_triples(request.triples)
await self.table_store.add_triples(workspace, request.triples)
if request.graph_embeddings:
await self.table_store.add_graph_embeddings(
request.graph_embeddings
workspace, request.graph_embeddings
)
await respond(
@ -178,10 +180,15 @@ class KnowledgeManager:
if request.flow is None:
raise RuntimeError("Flow ID must be specified")
if request.flow not in self.flow_config.flows:
raise RuntimeError("Invalid flow")
workspace = request.workspace
ws_flows = self.flow_config.flows.get(workspace, {})
if request.flow not in ws_flows:
raise RuntimeError(
f"Invalid flow {request.flow} for workspace "
f"{workspace}"
)
flow = self.flow_config.flows[request.flow]
flow = ws_flows[request.flow]
if "interfaces" not in flow:
raise RuntimeError("No defined interfaces")
@ -257,7 +264,7 @@ class KnowledgeManager:
# Read this core's triples from the table store, handing them to publish_triples
await self.table_store.get_triples(
request.user,
request.workspace,
request.id,
publish_triples,
)
@ -272,7 +279,7 @@ class KnowledgeManager:
# Read this core's graph embeddings from the table store, handing them to publish_ge
await self.table_store.get_graph_embeddings(
request.user,
request.workspace,
request.id,
publish_ge,
)

View file

@ -124,19 +124,21 @@ class Processor(AsyncProcessor):
await self.knowledge_request_consumer.start()
await self.knowledge_response_producer.start()
async def on_knowledge_config(self, config, version):
async def on_knowledge_config(self, workspace, config, version):
logger.info(f"Configuration version: {version}")
logger.info(
f"Configuration version: {version} workspace: {workspace}"
)
if "flow" in config:
self.flows = {
self.flows[workspace] = {
k: json.loads(v)
for k, v in config["flow"].items()
}
else:
self.flows = {}
self.flows[workspace] = {}
logger.debug(f"Flows: {self.flows}")
logger.debug(f"Flows for {workspace}: {self.flows[workspace]}")
async def process_request(self, v, id):

View file

@ -200,7 +200,7 @@ class Processor(FlowProcessor):
if v.document_id:
doc_meta = await self.librarian.fetch_document_metadata(
document_id=v.document_id,
user=v.metadata.user,
workspace=flow.workspace,
)
if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf":
logger.error(
@ -215,7 +215,7 @@ class Processor(FlowProcessor):
logger.info(f"Fetching document {v.document_id} from librarian...")
content = await self.librarian.fetch_document_content(
document_id=v.document_id,
user=v.metadata.user,
workspace=flow.workspace,
)
if isinstance(content, str):
content = content.encode('utf-8')
@ -243,7 +243,7 @@ class Processor(FlowProcessor):
await self.librarian.save_child_document(
doc_id=page_doc_id,
parent_id=source_doc_id,
user=v.metadata.user,
workspace=flow.workspace,
content=page_content,
document_type="page",
title=f"Page {page_num}",
@ -265,7 +265,6 @@ class Processor(FlowProcessor):
metadata=Metadata(
id=pg_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
triples=set_graph(prov_triples, GRAPH_SOURCE),
@ -277,7 +276,6 @@ class Processor(FlowProcessor):
metadata=Metadata(
id=pg_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
document_id=page_doc_id,

View file

@ -93,7 +93,7 @@ class Processor(FlowProcessor):
if v.document_id:
doc_meta = await self.librarian.fetch_document_metadata(
document_id=v.document_id,
user=v.metadata.user,
workspace=flow.workspace,
)
if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf":
logger.error(
@ -114,7 +114,7 @@ class Processor(FlowProcessor):
content = await self.librarian.fetch_document_content(
document_id=v.document_id,
user=v.metadata.user,
workspace=flow.workspace,
)
# Content is base64 encoded
@ -157,7 +157,7 @@ class Processor(FlowProcessor):
await self.librarian.save_child_document(
doc_id=page_doc_id,
parent_id=source_doc_id,
user=v.metadata.user,
workspace=flow.workspace,
content=page_content,
document_type="page",
title=f"Page {page_num}",
@ -179,7 +179,6 @@ class Processor(FlowProcessor):
metadata=Metadata(
id=pg_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
triples=set_graph(prov_triples, GRAPH_SOURCE),
@ -191,7 +190,6 @@ class Processor(FlowProcessor):
metadata=Metadata(
id=pg_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
document_id=page_doc_id,

View file

@ -6,9 +6,9 @@ import re
logger = logging.getLogger(__name__)
def make_safe_collection_name(user, collection, prefix):
def make_safe_collection_name(workspace, collection, prefix):
"""
Create a safe Milvus collection name from user/collection parameters.
Create a safe Milvus collection name from workspace/collection parameters.
Milvus only allows letters, numbers, and underscores.
"""
def sanitize(s):
@ -23,10 +23,10 @@ def make_safe_collection_name(user, collection, prefix):
safe = 'default'
return safe
safe_user = sanitize(user)
safe_workspace = sanitize(workspace)
safe_collection = sanitize(collection)
return f"{prefix}_{safe_user}_{safe_collection}"
return f"{prefix}_{safe_workspace}_{safe_collection}"
class DocVectors:
@ -49,26 +49,26 @@ class DocVectors:
self.next_reload = time.time() + self.reload_time
logger.debug(f"Reload at {self.next_reload}")
def collection_exists(self, user, collection):
def collection_exists(self, workspace, collection):
"""
Check if any collection exists for this user/collection combination.
Check if any collection exists for this workspace/collection combination.
Since collections are dimension-specific, this checks if ANY dimension variant exists.
"""
base_name = make_safe_collection_name(user, collection, self.prefix)
base_name = make_safe_collection_name(workspace, collection, self.prefix)
prefix = f"{base_name}_"
all_collections = self.client.list_collections()
return any(coll.startswith(prefix) for coll in all_collections)
def create_collection(self, user, collection, dimension=384):
def create_collection(self, workspace, collection, dimension=384):
"""
No-op for explicit collection creation.
Collections are created lazily on first insert with actual dimension.
"""
logger.info(f"Collection creation requested for {user}/{collection} - will be created lazily on first insert")
logger.info(f"Collection creation requested for {workspace}/{collection} - will be created lazily on first insert")
def init_collection(self, dimension, user, collection):
def init_collection(self, dimension, workspace, collection):
base_name = make_safe_collection_name(user, collection, self.prefix)
base_name = make_safe_collection_name(workspace, collection, self.prefix)
collection_name = f"{base_name}_{dimension}"
pkey_field = FieldSchema(
@ -116,15 +116,15 @@ class DocVectors:
index_params=index_params
)
self.collections[(dimension, user, collection)] = collection_name
self.collections[(dimension, workspace, collection)] = collection_name
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
def insert(self, embeds, chunk_id, user, collection):
def insert(self, embeds, chunk_id, workspace, collection):
dim = len(embeds)
if (dim, user, collection) not in self.collections:
self.init_collection(dim, user, collection)
if (dim, workspace, collection) not in self.collections:
self.init_collection(dim, workspace, collection)
data = [
{
@ -134,25 +134,25 @@ class DocVectors:
]
self.client.insert(
collection_name=self.collections[(dim, user, collection)],
collection_name=self.collections[(dim, workspace, collection)],
data=data
)
def search(self, embeds, user, collection, fields=["chunk_id"], limit=10):
def search(self, embeds, workspace, collection, fields=["chunk_id"], limit=10):
dim = len(embeds)
# Check if collection exists - return empty if not
if (dim, user, collection) not in self.collections:
base_name = make_safe_collection_name(user, collection, self.prefix)
if (dim, workspace, collection) not in self.collections:
base_name = make_safe_collection_name(workspace, collection, self.prefix)
collection_name = f"{base_name}_{dim}"
if not self.client.has_collection(collection_name):
logger.info(f"Collection {collection_name} does not exist, returning empty results")
return []
# Collection exists but not in cache, add it
self.collections[(dim, user, collection)] = collection_name
self.collections[(dim, workspace, collection)] = collection_name
coll = self.collections[(dim, user, collection)]
coll = self.collections[(dim, workspace, collection)]
logger.debug("Loading...")
self.client.load_collection(
@ -181,12 +181,12 @@ class DocVectors:
return res
def delete_collection(self, user, collection):
def delete_collection(self, workspace, collection):
"""
Delete all dimension variants of the collection for the given user/collection.
Delete all dimension variants of the collection for the given workspace/collection.
Since collections are created with dimension suffixes, we need to find and delete all.
"""
base_name = make_safe_collection_name(user, collection, self.prefix)
base_name = make_safe_collection_name(workspace, collection, self.prefix)
prefix = f"{base_name}_"
# Get all collections and filter for matches
@ -199,10 +199,10 @@ class DocVectors:
for collection_name in matching_collections:
self.client.drop_collection(collection_name)
logger.info(f"Deleted Milvus collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
# Remove from our local cache
keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection]
keys_to_remove = [key for key in self.collections.keys() if key[1] == workspace and key[2] == collection]
for key in keys_to_remove:
del self.collections[key]

View file

@ -6,9 +6,9 @@ import re
logger = logging.getLogger(__name__)
def make_safe_collection_name(user, collection, prefix):
def make_safe_collection_name(workspace, collection, prefix):
"""
Create a safe Milvus collection name from user/collection parameters.
Create a safe Milvus collection name from workspace/collection parameters.
Milvus only allows letters, numbers, and underscores.
"""
def sanitize(s):
@ -23,10 +23,10 @@ def make_safe_collection_name(user, collection, prefix):
safe = 'default'
return safe
safe_user = sanitize(user)
safe_workspace = sanitize(workspace)
safe_collection = sanitize(collection)
return f"{prefix}_{safe_user}_{safe_collection}"
return f"{prefix}_{safe_workspace}_{safe_collection}"
class EntityVectors:
@ -49,26 +49,26 @@ class EntityVectors:
self.next_reload = time.time() + self.reload_time
logger.debug(f"Reload at {self.next_reload}")
def collection_exists(self, user, collection):
def collection_exists(self, workspace, collection):
"""
Check if any collection exists for this user/collection combination.
Check if any collection exists for this workspace/collection combination.
Since collections are dimension-specific, this checks if ANY dimension variant exists.
"""
base_name = make_safe_collection_name(user, collection, self.prefix)
base_name = make_safe_collection_name(workspace, collection, self.prefix)
prefix = f"{base_name}_"
all_collections = self.client.list_collections()
return any(coll.startswith(prefix) for coll in all_collections)
def create_collection(self, user, collection, dimension=384):
def create_collection(self, workspace, collection, dimension=384):
"""
No-op for explicit collection creation.
Collections are created lazily on first insert with actual dimension.
"""
logger.info(f"Collection creation requested for {user}/{collection} - will be created lazily on first insert")
logger.info(f"Collection creation requested for {workspace}/{collection} - will be created lazily on first insert")
def init_collection(self, dimension, user, collection):
def init_collection(self, dimension, workspace, collection):
base_name = make_safe_collection_name(user, collection, self.prefix)
base_name = make_safe_collection_name(workspace, collection, self.prefix)
collection_name = f"{base_name}_{dimension}"
pkey_field = FieldSchema(
@ -122,15 +122,15 @@ class EntityVectors:
index_params=index_params
)
self.collections[(dimension, user, collection)] = collection_name
self.collections[(dimension, workspace, collection)] = collection_name
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
def insert(self, embeds, entity, user, collection, chunk_id=""):
def insert(self, embeds, entity, workspace, collection, chunk_id=""):
dim = len(embeds)
if (dim, user, collection) not in self.collections:
self.init_collection(dim, user, collection)
if (dim, workspace, collection) not in self.collections:
self.init_collection(dim, workspace, collection)
data = [
{
@ -141,25 +141,25 @@ class EntityVectors:
]
self.client.insert(
collection_name=self.collections[(dim, user, collection)],
collection_name=self.collections[(dim, workspace, collection)],
data=data
)
def search(self, embeds, user, collection, fields=["entity"], limit=10):
def search(self, embeds, workspace, collection, fields=["entity"], limit=10):
dim = len(embeds)
# Check if collection exists - return empty if not
if (dim, user, collection) not in self.collections:
base_name = make_safe_collection_name(user, collection, self.prefix)
if (dim, workspace, collection) not in self.collections:
base_name = make_safe_collection_name(workspace, collection, self.prefix)
collection_name = f"{base_name}_{dim}"
if not self.client.has_collection(collection_name):
logger.info(f"Collection {collection_name} does not exist, returning empty results")
return []
# Collection exists but not in cache, add it
self.collections[(dim, user, collection)] = collection_name
self.collections[(dim, workspace, collection)] = collection_name
coll = self.collections[(dim, user, collection)]
coll = self.collections[(dim, workspace, collection)]
logger.debug("Loading...")
self.client.load_collection(
@ -188,12 +188,12 @@ class EntityVectors:
return res
def delete_collection(self, user, collection):
def delete_collection(self, workspace, collection):
"""
Delete all dimension variants of the collection for the given user/collection.
Delete all dimension variants of the collection for the given workspace/collection.
Since collections are created with dimension suffixes, we need to find and delete all.
"""
base_name = make_safe_collection_name(user, collection, self.prefix)
base_name = make_safe_collection_name(workspace, collection, self.prefix)
prefix = f"{base_name}_"
# Get all collections and filter for matches
@ -206,10 +206,10 @@ class EntityVectors:
for collection_name in matching_collections:
self.client.drop_collection(collection_name)
logger.info(f"Deleted Milvus collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
# Remove from our local cache
keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection]
keys_to_remove = [key for key in self.collections.keys() if key[1] == workspace and key[2] == collection]
for key in keys_to_remove:
del self.collections[key]

View file

@ -69,19 +69,26 @@ class Processor(CollectionConfigHandler, FlowProcessor):
self.register_config_handler(self.on_schema_config, types=["schema"])
self.register_config_handler(self.on_collection_config, types=["collection"])
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# Per-workspace schema storage: {workspace: {name: RowSchema}}
self.schemas: Dict[str, Dict[str, RowSchema]] = {}
async def on_schema_config(self, config, version):
async def on_schema_config(self, workspace, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
logger.info(
f"Loading schema configuration version {version} "
f"for workspace {workspace}"
)
# Clear existing schemas
self.schemas = {}
# Replace existing schemas for this workspace
ws_schemas: Dict[str, RowSchema] = {}
self.schemas[workspace] = ws_schemas
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
logger.warning(
f"No '{self.config_key}' type in configuration "
f"for {workspace}"
)
return
# Get the schemas dictionary for our type
@ -115,13 +122,19 @@ class Processor(CollectionConfigHandler, FlowProcessor):
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
ws_schemas[schema_name] = row_schema
logger.info(
f"Loaded schema: {schema_name} with "
f"{len(fields)} fields for {workspace}"
)
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
logger.info(
f"Schema configuration loaded for {workspace}: "
f"{len(ws_schemas)} schemas"
)
def get_index_names(self, schema: RowSchema) -> List[str]:
"""Get all index names for a schema."""
@ -149,23 +162,29 @@ class Processor(CollectionConfigHandler, FlowProcessor):
"""Process incoming ExtractedObject and compute embeddings"""
obj = msg.value()
workspace = flow.workspace
logger.info(
f"Computing embeddings for {len(obj.values)} rows, "
f"schema {obj.schema_name}, doc {obj.metadata.id}"
f"schema {obj.schema_name}, doc {obj.metadata.id}, "
f"workspace {workspace}"
)
# Validate collection exists before processing
if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
if not self.collection_exists(workspace, obj.metadata.collection):
logger.warning(
f"Collection {obj.metadata.collection} for user {obj.metadata.user} "
f"Collection {obj.metadata.collection} for workspace {workspace} "
f"does not exist in config. Dropping message."
)
return
# Get schema definition
schema = self.schemas.get(obj.schema_name)
# Get schema definition for this workspace
ws_schemas = self.schemas.get(workspace, {})
schema = ws_schemas.get(obj.schema_name)
if not schema:
logger.warning(f"No schema found for {obj.schema_name} - skipping")
logger.warning(
f"No schema found for {obj.schema_name} in "
f"workspace {workspace} - skipping"
)
return
# Get all index names for this schema
@ -239,13 +258,13 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.error("Exception during embedding computation", exc_info=True)
raise e
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Collection creation notification - no action needed for embedding stage"""
logger.debug(f"Row embeddings collection notification for {user}/{collection}")
logger.debug(f"Row embeddings collection notification for {workspace}/{collection}")
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Collection deletion notification - no action needed for embedding stage"""
logger.debug(f"Row embeddings collection delete notification for {user}/{collection}")
logger.debug(f"Row embeddings collection delete notification for {workspace}/{collection}")
@staticmethod
def add_args(parser):

View file

@ -75,24 +75,36 @@ class Processor(FlowProcessor):
)
)
# Null configuration, should reload quickly
self.manager = PromptManager()
# Per-workspace prompt managers
self.managers = {}
async def on_prompt_config(self, config, version):
async def on_prompt_config(self, workspace, config, version):
logger.info(f"Loading configuration version {version}")
logger.info(
f"Loading configuration version {version} "
f"for workspace {workspace}"
)
if self.config_key not in config:
logger.warning(f"No key {self.config_key} in config")
logger.warning(
f"No key {self.config_key} in config for {workspace}"
)
return
config = config[self.config_key]
prompt_config = config[self.config_key]
try:
self.manager.load_config(config)
manager = self.managers.get(workspace)
if manager is None:
manager = PromptManager()
self.managers[workspace] = manager
logger.info("Prompt configuration reloaded")
manager.load_config(prompt_config)
logger.info(
f"Prompt configuration reloaded for {workspace}"
)
except Exception as e:
@ -107,7 +119,6 @@ class Processor(FlowProcessor):
metadata = Metadata(
id = metadata.id,
root = metadata.root,
user = metadata.user,
collection = metadata.collection,
),
triples = triples,
@ -120,7 +131,6 @@ class Processor(FlowProcessor):
metadata = Metadata(
id = metadata.id,
root = metadata.root,
user = metadata.user,
collection = metadata.collection,
),
entities = entity_contexts,
@ -170,13 +180,24 @@ class Processor(FlowProcessor):
try:
v = msg.value()
workspace = flow.workspace
# Extract chunk text
chunk_text = v.chunk.decode('utf-8')
logger.debug("Processing chunk for agent extraction")
logger.debug(
f"Processing chunk for agent extraction, "
f"workspace {workspace}"
)
prompt = self.manager.render(
manager = self.managers.get(workspace)
if manager is None:
logger.error(
f"No prompt configuration for workspace {workspace}"
)
return
prompt = manager.render(
self.template_id,
{
"text": chunk_text

View file

@ -213,7 +213,6 @@ class Processor(FlowProcessor):
Metadata(
id=v.metadata.id,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
batch
@ -227,7 +226,6 @@ class Processor(FlowProcessor):
Metadata(
id=v.metadata.id,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
batch

View file

@ -109,20 +109,22 @@ class Processor(FlowProcessor):
# Register config handler for ontology updates
self.register_config_handler(self.on_ontology_config, types=["ontology"])
# Shared components (not flow-specific)
self.ontology_loader = OntologyLoader()
# Per-workspace ontology loaders
self.ontology_loaders = {} # workspace -> OntologyLoader
self.text_processor = TextProcessor()
# Per-flow components (each flow gets its own embedder/vector store/selector)
self.flow_components = {} # flow_id -> {embedder, vector_store, selector}
# Per-flow components (each flow gets its own embedder/vector
# store/selector). Keyed by id(flow) — Flow objects are unique
# per (workspace, flow), so this is implicitly workspace-scoped.
self.flow_components = {}
# Configuration
self.top_k = params.get("top_k", 10)
self.similarity_threshold = params.get("similarity_threshold", 0.3)
# Track loaded ontology version
self.current_ontology_version = None
self.loaded_ontology_ids = set()
# Per-workspace ontology version tracking
self.current_ontology_versions = {} # workspace -> version
self.loaded_ontology_ids = {} # workspace -> set of ids
async def initialize_flow_components(self, flow):
"""Initialize per-flow OntoRAG components.
@ -167,17 +169,23 @@ class Processor(FlowProcessor):
vector_store=vector_store
)
# Embed all loaded ontologies for this flow
if self.ontology_loader.get_all_ontologies():
logger.info(f"Embedding ontologies for flow {flow_id}")
for ont_id, ontology in self.ontology_loader.get_all_ontologies().items():
workspace = flow.workspace
# Embed all loaded ontologies for this workspace
loader = self.ontology_loaders.get(workspace)
if loader is not None and loader.get_all_ontologies():
logger.info(
f"Embedding ontologies for flow {flow_id} "
f"(workspace {workspace})"
)
for ont_id, ontology in loader.get_all_ontologies().items():
await ontology_embedder.embed_ontology(ontology)
logger.info(f"Embedded {ontology_embedder.get_embedded_count()} ontology elements for flow {flow_id}")
# Initialize ontology selector
ontology_selector = OntologySelector(
ontology_embedder=ontology_embedder,
ontology_loader=self.ontology_loader,
ontology_loader=loader,
top_k=self.top_k,
similarity_threshold=self.similarity_threshold
)
@ -187,7 +195,8 @@ class Processor(FlowProcessor):
'embedder': ontology_embedder,
'vector_store': vector_store,
'selector': ontology_selector,
'dimension': dimension
'dimension': dimension,
'workspace': workspace,
}
logger.info(f"Flow {flow_id} components initialized successfully (dimension={dimension})")
@ -197,31 +206,27 @@ class Processor(FlowProcessor):
logger.error(f"Failed to initialize flow {flow_id} components: {e}", exc_info=True)
raise
async def on_ontology_config(self, config, version):
"""
Handle ontology configuration updates from ConfigPush queue.
Parses and stores ontologies. Embedding happens per-flow on first message.
Called automatically when:
- Processor starts (gets full config history via start_of_messages=True)
- Config service pushes updates (immediate event-driven notification)
Args:
config: Full configuration map - config[type][key] = value
version: Config version number (monotonically increasing)
"""
async def on_ontology_config(self, workspace, config, version):
"""Handle ontology configuration updates for a workspace."""
try:
logger.info(f"Received ontology config update, version={version}")
logger.info(
f"Received ontology config update, "
f"version={version} workspace={workspace}"
)
# Skip if we've already processed this version
if version == self.current_ontology_version:
logger.debug(f"Already at version {version}, skipping")
# Skip if we've already processed this version for this workspace
if version == self.current_ontology_versions.get(workspace):
logger.debug(
f"Already at version {version} for {workspace}, "
f"skipping"
)
return
# Extract ontology configurations
if "ontology" not in config:
logger.warning("No 'ontology' section in config")
logger.warning(
f"No 'ontology' section in config for {workspace}"
)
return
ontology_configs = config["ontology"]
@ -235,38 +240,65 @@ class Processor(FlowProcessor):
logger.error(f"Failed to parse ontology '{ont_id}': {e}")
continue
logger.info(f"Loaded {len(ontologies)} ontology definitions")
logger.info(
f"Loaded {len(ontologies)} ontology definitions "
f"for {workspace}"
)
# Determine what changed (for incremental updates)
# Determine what changed for this workspace
ws_loaded_ids = self.loaded_ontology_ids.get(workspace, set())
new_ids = set(ontologies.keys())
added_ids = new_ids - self.loaded_ontology_ids
removed_ids = self.loaded_ontology_ids - new_ids
updated_ids = new_ids & self.loaded_ontology_ids # May have changed content
added_ids = new_ids - ws_loaded_ids
removed_ids = ws_loaded_ids - new_ids
updated_ids = new_ids & ws_loaded_ids # May have changed content
if added_ids:
logger.info(f"New ontologies: {added_ids}")
logger.info(f"New ontologies in {workspace}: {added_ids}")
if removed_ids:
logger.info(f"Removed ontologies: {removed_ids}")
logger.info(f"Removed ontologies in {workspace}: {removed_ids}")
if updated_ids:
logger.info(f"Updated ontologies: {updated_ids}")
logger.info(f"Updated ontologies in {workspace}: {updated_ids}")
# Update ontology loader's internal state
self.ontology_loader.update_ontologies(ontologies)
# Get or create per-workspace loader
loader = self.ontology_loaders.get(workspace)
if loader is None:
loader = OntologyLoader()
self.ontology_loaders[workspace] = loader
loader.update_ontologies(ontologies)
# Clear all flow components to force re-embedding with new ontologies
# Clear flow components for this workspace to force
# re-embedding with new ontologies.
if added_ids or removed_ids or updated_ids:
logger.info("Clearing flow components to trigger re-embedding")
self.flow_components.clear()
self._clear_workspace_flow_components(workspace)
# Update tracking
self.current_ontology_version = version
self.loaded_ontology_ids = new_ids
self.current_ontology_versions[workspace] = version
self.loaded_ontology_ids[workspace] = new_ids
logger.info(f"Ontology config update complete, version={version}")
logger.info(
f"Ontology config update complete for {workspace}, "
f"version={version}"
)
except Exception as e:
logger.error(f"Failed to process ontology config: {e}", exc_info=True)
def _clear_workspace_flow_components(self, workspace):
"""Drop cached flow components belonging to the given workspace
so they're re-initialised on next message with fresh ontology
embeddings."""
to_remove = [
fid for fid, comp in self.flow_components.items()
if comp.get("workspace") == workspace
]
if to_remove:
logger.info(
f"Clearing {len(to_remove)} flow components for "
f"workspace {workspace}"
)
for fid in to_remove:
del self.flow_components[fid]
async def on_message(self, msg, consumer, flow):
"""Process incoming chunk message."""
v = msg.value()
@ -624,7 +656,6 @@ class Processor(FlowProcessor):
metadata=Metadata(
id=metadata.id,
root=metadata.root,
user=metadata.user,
collection=metadata.collection,
),
triples=triples,
@ -637,7 +668,6 @@ class Processor(FlowProcessor):
metadata=Metadata(
id=metadata.id,
root=metadata.root,
user=metadata.user,
collection=metadata.collection,
),
entities=entities,

View file

@ -207,7 +207,6 @@ class Processor(FlowProcessor):
Metadata(
id=v.metadata.id,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
batch

View file

@ -84,32 +84,39 @@ class Processor(FlowProcessor):
# Register config handler for schema updates
self.register_config_handler(self.on_schema_config, types=["schema"])
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# Per-workspace schema storage: {workspace: {name: RowSchema}}
self.schemas: Dict[str, Dict[str, RowSchema]] = {}
async def on_schema_config(self, config, version):
async def on_schema_config(self, workspace, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
logger.info(
f"Loading schema configuration version {version} "
f"for workspace {workspace}"
)
# Clear existing schemas
self.schemas = {}
# Replace existing schemas for this workspace
ws_schemas: Dict[str, RowSchema] = {}
self.schemas[workspace] = ws_schemas
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
logger.warning(
f"No '{self.config_key}' type in configuration "
f"for {workspace}"
)
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
@ -124,21 +131,27 @@ class Processor(FlowProcessor):
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
ws_schemas[schema_name] = row_schema
logger.info(
f"Loaded schema: {schema_name} with "
f"{len(fields)} fields for {workspace}"
)
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
logger.info(
f"Schema configuration loaded for {workspace}: "
f"{len(ws_schemas)} schemas"
)
async def extract_objects_for_schema(self, text: str, schema_name: str, schema: RowSchema, flow) -> List[Dict[str, Any]]:
"""Extract objects from text for a specific schema"""
@ -234,18 +247,26 @@ class Processor(FlowProcessor):
"""Process incoming chunk and extract objects"""
v = msg.value()
logger.info(f"Extracting objects from chunk {v.metadata.id}...")
workspace = flow.workspace
logger.info(
f"Extracting objects from chunk {v.metadata.id} "
f"(workspace {workspace})..."
)
chunk_text = v.chunk.decode("utf-8")
# If no schemas configured, log warning and return
if not self.schemas:
logger.warning("No schemas configured - skipping extraction")
# If no schemas configured for this workspace, log and return
ws_schemas = self.schemas.get(workspace, {})
if not ws_schemas:
logger.warning(
f"No schemas configured for workspace {workspace} "
f"- skipping extraction"
)
return
try:
# Extract objects for each configured schema
for schema_name, schema in self.schemas.items():
for schema_name, schema in ws_schemas.items():
logger.debug(f"Extracting {schema_name} objects from chunk")
@ -274,7 +295,6 @@ class Processor(FlowProcessor):
metadata=Metadata(
id=f"{v.metadata.id}:{schema_name}",
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
schema_name=schema_name,

View file

@ -17,14 +17,18 @@ class FlowConfig:
self.config = config
self.pubsub = pubsub
# Cache for parameter type definitions to avoid repeated lookups
# Per-workspace cache for parameter type definitions
# Keyed by (workspace, type-name)
self.param_type_cache = {}
async def resolve_parameters(self, flow_blueprint, user_params):
async def resolve_parameters(
self, workspace, flow_blueprint, user_params
):
"""
Resolve parameters by merging user-provided values with defaults.
Args:
workspace: Workspace containing the parameter-type definitions
flow_blueprint: The flow blueprint definition dict
user_params: User-provided parameters dict (may be None or empty)
@ -55,24 +59,25 @@ class FlowConfig:
# Look up the parameter type definition
param_type = param_meta.get("type")
if param_type:
cache_key = (workspace, param_type)
# Check cache first
if param_type not in self.param_type_cache:
if cache_key not in self.param_type_cache:
try:
# Fetch parameter type definition from config store
type_def = await self.config.get(
"parameter-type", param_type
workspace, "parameter-type", param_type
)
if type_def:
self.param_type_cache[param_type] = json.loads(type_def)
self.param_type_cache[cache_key] = json.loads(type_def)
else:
logger.warning(f"Parameter type '{param_type}' not found in config")
self.param_type_cache[param_type] = {}
self.param_type_cache[cache_key] = {}
except Exception as e:
logger.error(f"Error fetching parameter type '{param_type}': {e}")
self.param_type_cache[param_type] = {}
self.param_type_cache[cache_key] = {}
# Apply default from type definition (as string)
type_def = self.param_type_cache[param_type]
type_def = self.param_type_cache[cache_key]
if "default" in type_def:
default_value = type_def["default"]
# Convert to string based on type
@ -94,8 +99,9 @@ class FlowConfig:
else:
# Controller has no value, try to get default from type definition
param_type = param_meta.get("type")
if param_type and param_type in self.param_type_cache:
type_def = self.param_type_cache[param_type]
cache_key = (workspace, param_type) if param_type else None
if cache_key and cache_key in self.param_type_cache:
type_def = self.param_type_cache[cache_key]
if "default" in type_def:
default_value = type_def["default"]
# Convert to string based on type
@ -114,7 +120,9 @@ class FlowConfig:
async def handle_list_blueprints(self, msg):
names = list(await self.config.keys("flow-blueprint"))
names = list(await self.config.keys(
msg.workspace, "flow-blueprint"
))
return FlowResponse(
error = None,
@ -126,14 +134,14 @@ class FlowConfig:
return FlowResponse(
error = None,
blueprint_definition = await self.config.get(
"flow-blueprint", msg.blueprint_name
msg.workspace, "flow-blueprint", msg.blueprint_name
),
)
async def handle_put_blueprint(self, msg):
await self.config.put(
"flow-blueprint",
msg.workspace, "flow-blueprint",
msg.blueprint_name, msg.blueprint_definition
)
@ -145,7 +153,9 @@ class FlowConfig:
logger.debug(f"Flow config message: {msg}")
await self.config.delete("flow-blueprint", msg.blueprint_name)
await self.config.delete(
msg.workspace, "flow-blueprint", msg.blueprint_name
)
return FlowResponse(
error = None,
@ -153,7 +163,7 @@ class FlowConfig:
async def handle_list_flows(self, msg):
names = list(await self.config.keys("flow"))
names = list(await self.config.keys(msg.workspace, "flow"))
return FlowResponse(
error = None,
@ -162,7 +172,9 @@ class FlowConfig:
async def handle_get_flow(self, msg):
flow_data = await self.config.get("flow", msg.flow_id)
flow_data = await self.config.get(
msg.workspace, "flow", msg.flow_id
)
flow = json.loads(flow_data)
return FlowResponse(
@ -174,37 +186,49 @@ class FlowConfig:
async def handle_start_flow(self, msg):
workspace = msg.workspace
if msg.blueprint_name is None:
raise RuntimeError("No blueprint name")
if msg.flow_id is None:
raise RuntimeError("No flow ID")
if msg.flow_id in await self.config.keys("flow"):
if msg.flow_id in await self.config.keys(workspace, "flow"):
raise RuntimeError("Flow already exists")
if msg.description is None:
raise RuntimeError("No description")
if msg.blueprint_name not in await self.config.keys("flow-blueprint"):
if msg.blueprint_name not in await self.config.keys(
workspace, "flow-blueprint"
):
raise RuntimeError("Blueprint does not exist")
cls = json.loads(
await self.config.get("flow-blueprint", msg.blueprint_name)
await self.config.get(
workspace, "flow-blueprint", msg.blueprint_name
)
)
# Resolve parameters by merging user-provided values with defaults
user_params = msg.parameters if msg.parameters else {}
parameters = await self.resolve_parameters(cls, user_params)
parameters = await self.resolve_parameters(
workspace, cls, user_params
)
# Log the resolved parameters for debugging
logger.debug(f"User provided parameters: {user_params}")
logger.debug(f"Resolved parameters (with defaults): {parameters}")
# Apply parameter substitution to template replacement function
# Apply parameter substitution to template replacement function.
# {workspace} is substituted from msg.workspace to isolate
# queue names across workspaces.
def repl_template_with_params(tmp):
result = tmp.replace(
"{workspace}", workspace
).replace(
"{blueprint}", msg.blueprint_name
).replace(
"{id}", msg.flow_id
@ -253,7 +277,7 @@ class FlowConfig:
json.dumps(entry),
))
await self.config.put_many(updates)
await self.config.put_many(workspace, updates)
def repl_interface(i):
return {
@ -270,7 +294,7 @@ class FlowConfig:
interfaces = {}
await self.config.put(
"flow", msg.flow_id,
workspace, "flow", msg.flow_id,
json.dumps({
"description": msg.description,
"blueprint-name": msg.blueprint_name,
@ -283,68 +307,77 @@ class FlowConfig:
error = None,
)
async def ensure_existing_flow_topics(self):
"""Ensure topics exist for all already-running flows.
async def ensure_existing_flow_topics(self, workspaces):
"""Ensure topics exist for all already-running flows across
the given workspaces.
Called on startup to handle flows that were started before this
version of the flow service was deployed, or before a restart.
"""
flow_ids = await self.config.keys("flow")
for workspace in workspaces:
flow_ids = await self.config.keys(workspace, "flow")
for flow_id in flow_ids:
try:
flow_data = await self.config.get("flow", flow_id)
if flow_data is None:
continue
flow = json.loads(flow_data)
blueprint_name = flow.get("blueprint-name")
if blueprint_name is None:
continue
# Skip flows that are mid-shutdown
if flow.get("status") == "stopping":
continue
parameters = flow.get("parameters", {})
blueprint_data = await self.config.get(
"flow-blueprint", blueprint_name
)
if blueprint_data is None:
logger.warning(
f"Blueprint '{blueprint_name}' not found for "
f"flow '{flow_id}', skipping topic creation"
for flow_id in flow_ids:
try:
flow_data = await self.config.get(
workspace, "flow", flow_id
)
continue
if flow_data is None:
continue
cls = json.loads(blueprint_data)
flow = json.loads(flow_data)
def repl_template(tmp):
result = tmp.replace(
"{blueprint}", blueprint_name
).replace(
"{id}", flow_id
blueprint_name = flow.get("blueprint-name")
if blueprint_name is None:
continue
# Skip flows that are mid-shutdown
if flow.get("status") == "stopping":
continue
parameters = flow.get("parameters", {})
blueprint_data = await self.config.get(
workspace, "flow-blueprint", blueprint_name
)
for param_name, param_value in parameters.items():
result = result.replace(
f"{{{param_name}}}", str(param_value)
if blueprint_data is None:
logger.warning(
f"Blueprint '{blueprint_name}' not found "
f"for flow '{workspace}/{flow_id}', skipping "
f"topic creation"
)
return result
continue
topics = self._collect_flow_topics(cls, repl_template)
for topic in topics:
await self.pubsub.ensure_topic(topic)
cls = json.loads(blueprint_data)
logger.info(
f"Ensured topics for existing flow '{flow_id}'"
)
def repl_template(tmp):
result = tmp.replace(
"{workspace}", workspace
).replace(
"{blueprint}", blueprint_name
).replace(
"{id}", flow_id
)
for param_name, param_value in parameters.items():
result = result.replace(
f"{{{param_name}}}", str(param_value)
)
return result
except Exception as e:
logger.error(
f"Failed to ensure topics for flow '{flow_id}': {e}"
)
topics = self._collect_flow_topics(cls, repl_template)
for topic in topics:
await self.pubsub.ensure_topic(topic)
logger.info(
f"Ensured topics for existing flow "
f"'{workspace}/{flow_id}'"
)
except Exception as e:
logger.error(
f"Failed to ensure topics for flow "
f"'{workspace}/{flow_id}': {e}"
)
def _collect_flow_topics(self, cls, repl_template):
"""Collect unique topic identifiers from the blueprint.
@ -395,13 +428,17 @@ class FlowConfig:
async def handle_stop_flow(self, msg):
workspace = msg.workspace
if msg.flow_id is None:
raise RuntimeError("No flow ID")
if msg.flow_id not in await self.config.keys("flow"):
if msg.flow_id not in await self.config.keys(workspace, "flow"):
raise RuntimeError("Flow ID invalid")
flow = json.loads(await self.config.get("flow", msg.flow_id))
flow = json.loads(
await self.config.get(workspace, "flow", msg.flow_id)
)
if "blueprint-name" not in flow:
raise RuntimeError("Internal error: flow has no flow blueprint")
@ -410,11 +447,15 @@ class FlowConfig:
parameters = flow.get("parameters", {})
cls = json.loads(
await self.config.get("flow-blueprint", blueprint_name)
await self.config.get(
workspace, "flow-blueprint", blueprint_name
)
)
def repl_template(tmp):
result = tmp.replace(
"{workspace}", workspace
).replace(
"{blueprint}", blueprint_name
).replace(
"{id}", msg.flow_id
@ -431,7 +472,7 @@ class FlowConfig:
# The config push tells processors to shut down their consumers.
flow["status"] = "stopping"
await self.config.put(
"flow", msg.flow_id, json.dumps(flow)
workspace, "flow", msg.flow_id, json.dumps(flow)
)
# Delete all processor config entries for this flow.
@ -444,13 +485,13 @@ class FlowConfig:
deletes.append((f"processor:{processor}", variant))
await self.config.delete_many(deletes)
await self.config.delete_many(workspace, deletes)
# Phase 2: Delete topics with retries, then remove the flow record.
await self._delete_topics(topics)
if msg.flow_id in await self.config.keys("flow"):
await self.config.delete("flow", msg.flow_id)
if msg.flow_id in await self.config.keys(workspace, "flow"):
await self.config.delete(workspace, "flow", msg.flow_id)
return FlowResponse(
error = None,
@ -458,7 +499,18 @@ class FlowConfig:
async def handle(self, msg):
logger.debug(f"Handling flow message: {msg.operation}")
logger.debug(
f"Handling flow message: {msg.operation} "
f"workspace={msg.workspace}"
)
if not msg.workspace:
return FlowResponse(
error=Error(
type="bad-request",
message="Workspace is required",
),
)
if msg.operation == "list-blueprints":
resp = await self.handle_list_blueprints(msg)

View file

@ -103,7 +103,12 @@ class Processor(AsyncProcessor):
await self.pubsub.ensure_topic(self.flow_request_topic)
await self.config_client.start()
await self.flow.ensure_existing_flow_topics()
# Discover workspaces with existing flow config and ensure
# their topics exist before we start accepting requests.
workspaces = await self.config_client.workspaces_for_type("flow")
await self.flow.ensure_existing_flow_topics(workspaces)
await self.flow_request_consumer.start()
async def on_flow_request(self, msg, consumer, flow):

View file

@ -30,6 +30,7 @@ class ConfigReceiver:
self.flow_handlers = []
# Per-workspace flow tracking: {workspace: {flow_id: flow_def}}
self.flows = {}
self.config_version = 0
@ -43,7 +44,7 @@ class ConfigReceiver:
v = msg.value()
notify_version = v.version
notify_types = set(v.types)
changes = v.changes
# Skip if we already have this version or newer
if notify_version <= self.config_version:
@ -53,20 +54,27 @@ class ConfigReceiver:
)
return
# Gateway cares about flow config
if notify_types and "flow" not in notify_types:
# Gateway cares about flow config — check if any flow
# types changed in any workspace
flow_workspaces = changes.get("flow", [])
if changes and not flow_workspaces:
logger.debug(
f"Ignoring config notify v{notify_version}, "
f"no flow types in {notify_types}"
f"no flow changes"
)
self.config_version = notify_version
return
logger.info(
f"Config notify v{notify_version}, fetching config..."
f"Config notify v{notify_version} "
f"types={list(changes.keys())}, fetching config..."
)
await self.fetch_and_apply()
# Refresh config for each affected workspace
for workspace in flow_workspaces:
await self.fetch_and_apply_workspace(workspace)
self.config_version = notify_version
except Exception as e:
logger.error(
@ -98,20 +106,25 @@ class ConfigReceiver:
response_metrics=config_resp_metrics,
)
async def fetch_and_apply(self, retry=False):
"""Fetch full config and apply flow changes.
async def fetch_and_apply_workspace(self, workspace, retry=False):
"""Fetch config for a single workspace and apply flow changes.
If retry=True, keeps retrying until successful."""
while True:
try:
logger.info("Fetching config from config service...")
logger.info(
f"Fetching config for workspace {workspace}..."
)
client = self._create_config_client()
try:
await client.start()
resp = await client.request(
ConfigRequest(operation="config"),
ConfigRequest(
operation="config",
workspace=workspace,
),
timeout=10,
)
finally:
@ -137,18 +150,22 @@ class ConfigReceiver:
flows = config.get("flow", {})
ws_flows = self.flows.get(workspace, {})
wanted = list(flows.keys())
current = list(self.flows.keys())
current = list(ws_flows.keys())
for k in wanted:
if k not in current:
self.flows[k] = json.loads(flows[k])
await self.start_flow(k, self.flows[k])
ws_flows[k] = json.loads(flows[k])
await self.start_flow(workspace, k, ws_flows[k])
for k in current:
if k not in wanted:
await self.stop_flow(k, self.flows[k])
del self.flows[k]
await self.stop_flow(workspace, k, ws_flows[k])
del ws_flows[k]
self.flows[workspace] = ws_flows
return
@ -164,27 +181,91 @@ class ConfigReceiver:
)
return
async def start_flow(self, id, flow):
async def fetch_all_workspaces(self, retry=False):
"""Fetch config for all workspaces at startup.
Discovers workspaces via the config service getvalues-all-ws
operation on the flow type."""
logger.info(f"Starting flow: {id}")
while True:
try:
logger.info("Discovering workspaces with flows...")
client = self._create_config_client()
try:
await client.start()
# Discover workspaces that have any flow config
resp = await client.request(
ConfigRequest(
operation="getvalues-all-ws",
type="flow",
),
timeout=10,
)
if resp.error:
raise RuntimeError(
f"Config error: {resp.error.message}"
)
workspaces = {
v.workspace for v in resp.values if v.workspace
}
# Always include the default workspace, even if
# empty, so that newly-created flows in it can be
# picked up by subsequent notifications.
workspaces.add("default")
logger.info(
f"Found workspaces with flows: {workspaces}"
)
finally:
await client.stop()
# Fetch and apply config for each workspace
for workspace in workspaces:
await self.fetch_and_apply_workspace(
workspace, retry=retry
)
return
except Exception as e:
if retry:
logger.warning(
f"Workspace fetch failed: {e}, retrying in 2s..."
)
await asyncio.sleep(2)
continue
logger.error(
f"Workspace fetch exception: {e}", exc_info=True
)
return
async def start_flow(self, workspace, id, flow):
logger.info(f"Starting flow: {workspace}/{id}")
for handler in self.flow_handlers:
try:
await handler.start_flow(id, flow)
await handler.start_flow(workspace, id, flow)
except Exception as e:
logger.error(
f"Config processing exception: {e}", exc_info=True
)
async def stop_flow(self, id, flow):
async def stop_flow(self, workspace, id, flow):
logger.info(f"Stopping flow: {id}")
logger.info(f"Stopping flow: {workspace}/{id}")
for handler in self.flow_handlers:
try:
await handler.stop_flow(id, flow)
await handler.stop_flow(workspace, id, flow)
except Exception as e:
logger.error(
f"Config processing exception: {e}", exc_info=True
@ -218,7 +299,7 @@ class ConfigReceiver:
# Fetch current config (subscribe-then-fetch pattern)
# Retry until config service is available
await self.fetch_and_apply(retry=True)
await self.fetch_all_workspaces(retry=True)
logger.info(
"Config loader initialised, waiting for notifys..."

View file

@ -16,7 +16,7 @@ class CoreExport:
async def process(self, data, error, ok, request):
id = request.query["id"]
user = request.query["user"]
workspace = request.query.get("workspace", "default")
response = await ok()
@ -41,7 +41,6 @@ class CoreExport:
{
"m": {
"i": data["metadata"]["id"],
"u": data["metadata"]["user"],
"c": data["metadata"]["collection"],
},
"e": [
@ -65,7 +64,6 @@ class CoreExport:
{
"m": {
"i": data["metadata"]["id"],
"u": data["metadata"]["user"],
"c": data["metadata"]["collection"],
},
"t": data["triples"],
@ -78,7 +76,7 @@ class CoreExport:
await kr.process(
{
"operation": "get-kg-core",
"user": user,
"workspace": workspace,
"id": id,
},
responder

View file

@ -17,7 +17,7 @@ class CoreImport:
async def process(self, data, error, ok, request):
id = request.query["id"]
user = request.query["user"]
workspace = request.query.get("workspace", "default")
kr = KnowledgeRequestor(
backend = self.backend,
@ -43,12 +43,11 @@ class CoreImport:
msg = unpacked[1]
msg = {
"operation": "put-kg-core",
"user": user,
"workspace": workspace,
"id": id,
"triples": {
"metadata": {
"id": id,
"user": user,
"collection": "default", # Not used?
},
"triples": msg["t"],
@ -61,12 +60,11 @@ class CoreImport:
msg = unpacked[1]
msg = {
"operation": "put-kg-core",
"user": user,
"workspace": workspace,
"id": id,
"graph-embeddings": {
"metadata": {
"id": id,
"user": user,
"collection": "default", # Not used?
},
"entities": [

View file

@ -14,12 +14,12 @@ class DocumentStreamExport:
async def process(self, data, error, ok, request):
user = request.query.get("user")
workspace = request.query.get("workspace", "default")
document_id = request.query.get("document-id")
chunk_size = int(request.query.get("chunk-size", 1024 * 1024))
if not user or not document_id:
return await error("Missing required parameters: user, document-id")
if not document_id:
return await error("Missing required parameter: document-id")
response = await ok()
@ -45,7 +45,7 @@ class DocumentStreamExport:
await lr.process(
{
"operation": "stream-document",
"user": user,
"workspace": workspace,
"document-id": document_id,
"chunk-size": chunk_size,
},

View file

@ -48,7 +48,6 @@ class EntityContextsImport:
elt = EntityContexts(
metadata=Metadata(
id=data["metadata"]["id"],
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
entities=[

View file

@ -48,7 +48,6 @@ class GraphEmbeddingsImport:
elt = GraphEmbeddings(
metadata=Metadata(
id=data["metadata"]["id"],
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
entities=[

View file

@ -116,18 +116,20 @@ class DispatcherManager:
# Format: {"config": {"request": "...", "response": "..."}, ...}
self.queue_overrides = queue_overrides or {}
# Flows keyed by (workspace, flow_id)
self.flows = {}
# Dispatchers keyed by (workspace, flow_id, kind)
self.dispatchers = {}
self.dispatcher_lock = asyncio.Lock()
async def start_flow(self, id, flow):
logger.info(f"Starting flow {id}")
self.flows[id] = flow
async def start_flow(self, workspace, id, flow):
logger.info(f"Starting flow {workspace}/{id}")
self.flows[(workspace, id)] = flow
return
async def stop_flow(self, id, flow):
logger.info(f"Stopping flow {id}")
del self.flows[id]
async def stop_flow(self, workspace, id, flow):
logger.info(f"Stopping flow {workspace}/{id}")
del self.flows[(workspace, id)]
return
def dispatch_global_service(self):
@ -203,18 +205,20 @@ class DispatcherManager:
async def process_flow_import(self, ws, running, params):
workspace = params.get("workspace", "default")
flow = params.get("flow")
kind = params.get("kind")
if flow not in self.flows:
raise RuntimeError("Invalid flow")
flow_key = (workspace, flow)
if flow_key not in self.flows:
raise RuntimeError(f"Invalid flow {workspace}/{flow}")
if kind not in import_dispatchers:
raise RuntimeError("Invalid kind")
key = (flow, kind)
key = (workspace, flow, kind)
intf_defs = self.flows[flow]["interfaces"]
intf_defs = self.flows[flow_key]["interfaces"]
# FIXME: The -store bit, does it make sense?
if kind == "entity-contexts":
@ -242,18 +246,20 @@ class DispatcherManager:
async def process_flow_export(self, ws, running, params):
workspace = params.get("workspace", "default")
flow = params.get("flow")
kind = params.get("kind")
if flow not in self.flows:
raise RuntimeError("Invalid flow")
flow_key = (workspace, flow)
if flow_key not in self.flows:
raise RuntimeError(f"Invalid flow {workspace}/{flow}")
if kind not in export_dispatchers:
raise RuntimeError("Invalid kind")
key = (flow, kind)
key = (workspace, flow, kind)
intf_defs = self.flows[flow]["interfaces"]
intf_defs = self.flows[flow_key]["interfaces"]
# FIXME: The -store bit, does it make sense?
if kind == "entity-contexts":
@ -286,22 +292,36 @@ class DispatcherManager:
async def process_flow_service(self, data, responder, params):
# Workspace can come from URL or from request body, defaulting
# to "default". Having it in the URL allows gateway routing to
# be workspace-aware without touching the body.
workspace = params.get("workspace")
if not workspace and isinstance(data, dict):
workspace = data.get("workspace")
if not workspace:
workspace = "default"
flow = params.get("flow")
kind = params.get("kind")
return await self.invoke_flow_service(data, responder, flow, kind)
return await self.invoke_flow_service(
data, responder, workspace, flow, kind,
)
async def invoke_flow_service(self, data, responder, flow, kind):
async def invoke_flow_service(
self, data, responder, workspace, flow, kind,
):
if flow not in self.flows:
raise RuntimeError("Invalid flow")
flow_key = (workspace, flow)
if flow_key not in self.flows:
raise RuntimeError(f"Invalid flow {workspace}/{flow}")
key = (flow, kind)
key = (workspace, flow, kind)
if key not in self.dispatchers:
async with self.dispatcher_lock:
if key not in self.dispatchers:
intf_defs = self.flows[flow]["interfaces"]
intf_defs = self.flows[flow_key]["interfaces"]
if kind not in intf_defs:
raise RuntimeError("This kind not supported by flow")
@ -314,8 +334,8 @@ class DispatcherManager:
request_queue = qconfig["request"],
response_queue = qconfig["response"],
timeout = 120,
consumer = f"{self.prefix}-{flow}-{kind}-request",
subscriber = f"{self.prefix}-{flow}-{kind}-request",
consumer = f"{self.prefix}-{workspace}-{flow}-{kind}-request",
subscriber = f"{self.prefix}-{workspace}-{flow}-{kind}-request",
)
elif kind in sender_dispatchers:
dispatcher = sender_dispatchers[kind](

View file

@ -47,7 +47,9 @@ class Mux:
raise RuntimeError("Bad message")
await self.q.put((
data["id"], data.get("flow"),
data["id"],
data.get("workspace", "default"),
data.get("flow"),
data["service"],
data["request"]
))
@ -87,8 +89,10 @@ class Mux:
# worker[0] still running, move on
break
async def start_request_task(self, ws, id, flow, svc, request, workers):
async def start_request_task(
self, ws, id, workspace, flow, svc, request, workers,
):
# Wait for outstanding requests to go below MAX_OUTSTANDING_REQUESTS
while len(workers) > MAX_OUTSTANDING_REQUESTS:
@ -106,19 +110,23 @@ class Mux:
})
worker = asyncio.create_task(
self.request_task(id, request, responder, flow, svc)
self.request_task(
id, request, responder, workspace, flow, svc,
)
)
workers.append(worker)
async def request_task(self, id, request, responder, flow, svc):
async def request_task(
self, id, request, responder, workspace, flow, svc,
):
try:
if flow:
await self.dispatcher_manager.invoke_flow_service(
request, responder, flow, svc
request, responder, workspace, flow, svc,
)
else:
@ -148,7 +156,7 @@ class Mux:
# Get next request on queue
item = await asyncio.wait_for(self.q.get(), 1)
id, flow, svc, request = item
id, workspace, flow, svc, request = item
except TimeoutError:
continue
@ -172,7 +180,7 @@ class Mux:
try:
await self.start_request_task(
self.ws, id, flow, svc, request, workers
self.ws, id, workspace, flow, svc, request, workers
)
except Exception as e:

View file

@ -53,7 +53,6 @@ class RowsImport:
elt = ExtractedObject(
metadata=Metadata(
id=data["metadata"]["id"],
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
schema_name=data["schema_name"],

View file

@ -38,7 +38,6 @@ def serialize_triples(message):
"metadata": {
"id": message.metadata.id,
"root": message.metadata.root,
"user": message.metadata.user,
"collection": message.metadata.collection,
},
"triples": serialize_subgraph(message.triples),
@ -50,7 +49,6 @@ def serialize_graph_embeddings(message):
"metadata": {
"id": message.metadata.id,
"root": message.metadata.root,
"user": message.metadata.user,
"collection": message.metadata.collection,
},
"entities": [
@ -68,7 +66,6 @@ def serialize_entity_contexts(message):
"metadata": {
"id": message.metadata.id,
"root": message.metadata.root,
"user": message.metadata.user,
"collection": message.metadata.collection,
},
"entities": [
@ -86,7 +83,6 @@ def serialize_document_embeddings(message):
"metadata": {
"id": message.metadata.id,
"root": message.metadata.root,
"user": message.metadata.user,
"collection": message.metadata.collection,
},
"chunks": [
@ -120,8 +116,8 @@ def serialize_document_metadata(message):
if message.metadata:
ret["metadata"] = serialize_subgraph(message.metadata)
if message.user:
ret["user"] = message.user
if message.workspace:
ret["workspace"] = message.workspace
if message.tags is not None:
ret["tags"] = message.tags
@ -144,8 +140,8 @@ def serialize_processing_metadata(message):
if message.flow:
ret["flow"] = message.flow
if message.user:
ret["user"] = message.user
if message.workspace:
ret["workspace"] = message.workspace
if message.collection:
ret["collection"] = message.collection
@ -164,7 +160,7 @@ def to_document_metadata(x):
title = x.get("title", None),
comments = x.get("comments", None),
metadata = to_subgraph(x["metadata"]),
user = x.get("user", None),
workspace = x.get("workspace", None),
tags = x.get("tags", None),
)
@ -175,7 +171,7 @@ def to_processing_metadata(x):
document_id = x.get("document-id", None),
time = x.get("time", None),
flow = x.get("flow", None),
user = x.get("user", None),
workspace = x.get("workspace", None),
collection = x.get("collection", None),
tags = x.get("tags", None),
)

View file

@ -49,7 +49,6 @@ class TriplesImport:
metadata=Metadata(
id=data["metadata"]["id"],
root=data["metadata"].get("root", ""),
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
triples=to_subgraph(data["triples"]),

View file

@ -3,6 +3,7 @@ Collection management for the librarian - uses config service for storage
"""
import asyncio
import dataclasses
import logging
import json
import uuid
@ -20,7 +21,6 @@ logger = logging.getLogger(__name__)
def metadata_to_dict(metadata: CollectionMetadata) -> dict:
"""Convert CollectionMetadata to dictionary for JSON serialization"""
return {
'user': metadata.user,
'collection': metadata.collection,
'name': metadata.name,
'description': metadata.description,
@ -92,38 +92,38 @@ class CollectionManager:
self.pending_config_requests[response_id + "_response"] = response
self.pending_config_requests[response_id].set()
async def ensure_collection_exists(self, user: str, collection: str):
async def ensure_collection_exists(self, workspace: str, collection: str):
"""
Ensure a collection exists, creating it if necessary
Args:
user: User ID
workspace: Workspace ID
collection: Collection ID
"""
try:
# Check if collection exists via config service
request = ConfigRequest(
operation='get',
keys=[ConfigKey(type='collection', key=f'{user}:{collection}')]
workspace=workspace,
keys=[ConfigKey(type='collection', key=collection)]
)
response = await self.send_config_request(request)
# Validate response
if not response.values or len(response.values) == 0:
raise Exception(f"Invalid response from config service when checking collection {user}/{collection}")
raise Exception(f"Invalid response from config service when checking collection {workspace}/{collection}")
# Check if collection exists (value not None means it exists)
if response.values[0].value is not None:
logger.debug(f"Collection {user}/{collection} already exists")
logger.debug(f"Collection {workspace}/{collection} already exists")
return
# Collection doesn't exist (value is None), proceed to create
# Create new collection with default metadata
logger.info(f"Auto-creating collection {user}/{collection}")
logger.info(f"Auto-creating collection {workspace}/{collection}")
metadata = CollectionMetadata(
user=user,
collection=collection,
name=collection, # Default name to collection ID
description="",
@ -132,9 +132,10 @@ class CollectionManager:
request = ConfigRequest(
operation='put',
workspace=workspace,
values=[ConfigValue(
type='collection',
key=f'{user}:{collection}',
key=collection,
value=json.dumps(metadata_to_dict(metadata))
)]
)
@ -144,7 +145,7 @@ class CollectionManager:
if response.error:
raise RuntimeError(f"Config update failed: {response.error.message}")
logger.info(f"Collection {user}/{collection} auto-created in config service")
logger.info(f"Collection {workspace}/{collection} auto-created in config service")
except Exception as e:
logger.error(f"Error ensuring collection exists: {e}")
@ -161,9 +162,10 @@ class CollectionManager:
CollectionManagementResponse with list of collections
"""
try:
# Get all collections from config service
# Get all collections in this workspace from config service
config_request = ConfigRequest(
operation='getvalues',
workspace=request.workspace,
type='collection'
)
@ -172,15 +174,19 @@ class CollectionManager:
if response.error:
raise RuntimeError(f"Config query failed: {response.error.message}")
# Parse collections and filter by user
# Every value in this workspace is a collection.
# Filter to fields the current schema knows about — older
# persisted values may carry fields that have since been
# dropped (e.g. `user` from the pre-workspace-refactor era).
known_fields = {f.name for f in dataclasses.fields(CollectionMetadata)}
collections = []
for config_value in response.values:
if ":" in config_value.key:
coll_user, coll_name = config_value.key.split(":", 1)
if coll_user == request.user:
metadata_dict = json.loads(config_value.value)
metadata = CollectionMetadata(**metadata_dict)
collections.append(metadata)
metadata_dict = json.loads(config_value.value)
metadata_dict = {
k: v for k, v in metadata_dict.items() if k in known_fields
}
metadata = CollectionMetadata(**metadata_dict)
collections.append(metadata)
# Apply tag filtering if specified
if request.tag_filter:
@ -221,7 +227,6 @@ class CollectionManager:
tags = list(request.tags) if request.tags else []
metadata = CollectionMetadata(
user=request.user,
collection=request.collection,
name=name,
description=description,
@ -231,9 +236,10 @@ class CollectionManager:
# Send put request to config service
config_request = ConfigRequest(
operation='put',
workspace=request.workspace,
values=[ConfigValue(
type='collection',
key=f'{request.user}:{request.collection}',
key=request.collection,
value=json.dumps(metadata_to_dict(metadata))
)]
)
@ -243,7 +249,7 @@ class CollectionManager:
if response.error:
raise RuntimeError(f"Config update failed: {response.error.message}")
logger.info(f"Collection {request.user}/{request.collection} updated in config service")
logger.info(f"Collection {request.workspace}/{request.collection} updated in config service")
# Config service will trigger config push automatically
# Storage services will receive update and create/update collections
@ -269,12 +275,13 @@ class CollectionManager:
CollectionManagementResponse indicating success or failure
"""
try:
logger.info(f"Deleting collection {request.user}/{request.collection}")
logger.info(f"Deleting collection {request.workspace}/{request.collection}")
# Send delete request to config service
config_request = ConfigRequest(
operation='delete',
keys=[ConfigKey(type='collection', key=f'{request.user}:{request.collection}')]
workspace=request.workspace,
keys=[ConfigKey(type='collection', key=request.collection)]
)
response = await self.send_config_request(config_request)
@ -282,7 +289,7 @@ class CollectionManager:
if response.error:
raise RuntimeError(f"Config delete failed: {response.error.message}")
logger.info(f"Collection {request.user}/{request.collection} deleted from config service")
logger.info(f"Collection {request.workspace}/{request.collection} deleted from config service")
# Config service will trigger config push automatically
# Storage services will receive update and delete collections

View file

@ -48,7 +48,7 @@ class Librarian:
raise RequestError("Document kind (MIME type) is required")
if await self.table_store.document_exists(
request.document_metadata.user,
request.document_metadata.workspace,
request.document_metadata.id
):
raise RuntimeError("Document already exists")
@ -78,7 +78,7 @@ class Librarian:
logger.debug("Removing document...")
if not await self.table_store.document_exists(
request.user,
request.workspace,
request.document_id,
):
raise RuntimeError("Document does not exist")
@ -89,17 +89,17 @@ class Librarian:
logger.debug(f"Cascade deleting child document {child.id}")
try:
child_object_id = await self.table_store.get_document_object_id(
child.user,
child.workspace,
child.id
)
await self.blob_store.remove(child_object_id)
await self.table_store.remove_document(child.user, child.id)
await self.table_store.remove_document(child.workspace, child.id)
except Exception as e:
logger.warning(f"Failed to delete child document {child.id}: {e}")
# Now remove the parent document
object_id = await self.table_store.get_document_object_id(
request.user,
request.workspace,
request.document_id
)
@ -108,7 +108,7 @@ class Librarian:
# Remove doc table row
await self.table_store.remove_document(
request.user,
request.workspace,
request.document_id
)
@ -120,10 +120,10 @@ class Librarian:
logger.debug("Updating document...")
# You can't update the document ID, user or kind.
# You can't update the document ID, workspace or kind.
if not await self.table_store.document_exists(
request.document_metadata.user,
request.document_metadata.workspace,
request.document_metadata.id
):
raise RuntimeError("Document does not exist")
@ -139,7 +139,7 @@ class Librarian:
logger.debug("Getting document metadata...")
doc = await self.table_store.get_document(
request.user,
request.workspace,
request.document_id
)
@ -156,7 +156,7 @@ class Librarian:
logger.debug("Getting document content...")
object_id = await self.table_store.get_document_object_id(
request.user,
request.workspace,
request.document_id
)
@ -180,18 +180,18 @@ class Librarian:
raise RuntimeError("Collection parameter is required")
if await self.table_store.processing_exists(
request.processing_metadata.user,
request.processing_metadata.workspace,
request.processing_metadata.id
):
raise RuntimeError("Processing already exists")
doc = await self.table_store.get_document(
request.processing_metadata.user,
request.processing_metadata.workspace,
request.processing_metadata.document_id
)
object_id = await self.table_store.get_document_object_id(
request.processing_metadata.user,
request.processing_metadata.workspace,
request.processing_metadata.document_id
)
@ -222,14 +222,14 @@ class Librarian:
logger.debug("Removing processing metadata...")
if not await self.table_store.processing_exists(
request.user,
request.workspace,
request.processing_id,
):
raise RuntimeError("Processing object does not exist")
# Remove processing table row
await self.table_store.remove_processing(
request.user,
request.workspace,
request.processing_id
)
@ -239,7 +239,7 @@ class Librarian:
async def list_documents(self, request):
docs = await self.table_store.list_documents(request.user)
docs = await self.table_store.list_documents(request.workspace)
# Filter out child documents and answer documents by default
include_children = getattr(request, 'include_children', False)
@ -256,7 +256,7 @@ class Librarian:
async def list_processing(self, request):
procs = await self.table_store.list_processing(request.user)
procs = await self.table_store.list_processing(request.workspace)
return LibrarianResponse(
processing_metadatas = procs,
@ -276,7 +276,7 @@ class Librarian:
raise RequestError("Document kind (MIME type) is required")
if await self.table_store.document_exists(
request.document_metadata.user,
request.document_metadata.workspace,
request.document_metadata.id
):
raise RequestError("Document already exists")
@ -312,14 +312,14 @@ class Librarian:
"kind": request.document_metadata.kind,
"title": request.document_metadata.title,
"comments": request.document_metadata.comments,
"user": request.document_metadata.user,
"workspace": request.document_metadata.workspace,
"tags": request.document_metadata.tags,
})
# Store session in Cassandra
await self.table_store.create_upload_session(
upload_id=upload_id,
user=request.document_metadata.user,
workspace=request.document_metadata.workspace,
document_id=request.document_metadata.id,
document_metadata=doc_meta_json,
s3_upload_id=s3_upload_id,
@ -352,7 +352,7 @@ class Librarian:
raise RequestError("Upload session not found or expired")
# Validate ownership
if session["user"] != request.user:
if session["workspace"] != request.workspace:
raise RequestError("Not authorized to upload to this session")
# Validate chunk index
@ -419,7 +419,7 @@ class Librarian:
raise RequestError("Upload session not found or expired")
# Validate ownership
if session["user"] != request.user:
if session["workspace"] != request.workspace:
raise RequestError("Not authorized to complete this upload")
# Verify all chunks received
@ -457,7 +457,7 @@ class Librarian:
kind=doc_meta_dict["kind"],
title=doc_meta_dict.get("title", ""),
comments=doc_meta_dict.get("comments", ""),
user=doc_meta_dict["user"],
workspace=doc_meta_dict["workspace"],
tags=doc_meta_dict.get("tags", []),
metadata=[], # Triples not supported in chunked upload yet
)
@ -488,7 +488,7 @@ class Librarian:
raise RequestError("Upload session not found or expired")
# Validate ownership
if session["user"] != request.user:
if session["workspace"] != request.workspace:
raise RequestError("Not authorized to abort this upload")
# Abort S3 multipart upload
@ -520,7 +520,7 @@ class Librarian:
)
# Validate ownership
if session["user"] != request.user:
if session["workspace"] != request.workspace:
raise RequestError("Not authorized to view this upload")
chunks_received = session["chunks_received"]
@ -548,11 +548,11 @@ class Librarian:
async def list_uploads(self, request):
"""
List all in-progress uploads for a user.
List all in-progress uploads for a workspace.
"""
logger.debug(f"Listing uploads for user {request.user}")
logger.debug(f"Listing uploads for workspace {request.workspace}")
sessions = await self.table_store.list_upload_sessions(request.user)
sessions = await self.table_store.list_upload_sessions(request.workspace)
upload_sessions = [
UploadSession(
@ -591,7 +591,7 @@ class Librarian:
# Verify parent exists
if not await self.table_store.document_exists(
request.document_metadata.user,
request.document_metadata.workspace,
request.document_metadata.parent_id
):
raise RequestError(
@ -599,7 +599,7 @@ class Librarian:
)
if await self.table_store.document_exists(
request.document_metadata.user,
request.document_metadata.workspace,
request.document_metadata.id
):
raise RequestError("Document already exists")
@ -665,7 +665,7 @@ class Librarian:
)
object_id = await self.table_store.get_document_object_id(
request.user,
request.workspace,
request.document_id
)

View file

@ -277,18 +277,22 @@ class Processor(AsyncProcessor):
"""Forward config responses to collection manager"""
await self.collection_manager.on_config_response(message, consumer, flow)
async def on_librarian_config(self, config, version):
async def on_librarian_config(self, workspace, config, version):
logger.info(f"Configuration version: {version}")
logger.info(
f"Configuration version: {version} workspace: {workspace}"
)
if "flow" in config:
self.flows = {
self.flows[workspace] = {
k: json.loads(v)
for k, v in config["flow"].items()
}
else:
self.flows[workspace] = {}
logger.debug(f"Flows: {self.flows}")
logger.debug(f"Flows for {workspace}: {self.flows[workspace]}")
def __del__(self):
@ -345,7 +349,6 @@ class Processor(AsyncProcessor):
metadata=Metadata(
id=doc_uri,
root=document.id,
user=processing.user,
collection=processing.collection,
),
triples=all_triples,
@ -363,10 +366,15 @@ class Processor(AsyncProcessor):
logger.debug(f"Document: {document}, processing: {processing}, content length: {len(content)}")
if processing.flow not in self.flows:
raise RuntimeError("Invalid flow ID")
workspace = processing.workspace
ws_flows = self.flows.get(workspace, {})
if processing.flow not in ws_flows:
raise RuntimeError(
f"Invalid flow ID {processing.flow} for workspace "
f"{workspace}"
)
flow = self.flows[processing.flow]
flow = ws_flows[processing.flow]
if document.kind == "text/plain":
kind = "text-load"
@ -386,7 +394,6 @@ class Processor(AsyncProcessor):
metadata = Metadata(
id = document.id,
root = document.id,
user = processing.user,
collection = processing.collection
),
document_id = document.id,
@ -398,7 +405,6 @@ class Processor(AsyncProcessor):
metadata = Metadata(
id = document.id,
root = document.id,
user = processing.user,
collection = processing.collection
),
document_id = document.id,
@ -429,9 +435,9 @@ class Processor(AsyncProcessor):
"""
# Ensure collection exists when processing is added
if hasattr(request, 'processing_metadata') and request.processing_metadata:
user = request.processing_metadata.user
workspace = request.processing_metadata.workspace
collection = request.processing_metadata.collection
await self.collection_manager.ensure_collection_exists(user, collection)
await self.collection_manager.ensure_collection_exists(workspace, collection)
# Call the original add_processing method
return await self.librarian.add_processing(request)

View file

@ -50,30 +50,37 @@ class Processor(FlowProcessor):
)
)
# Per-workspace price tables
self.prices = {}
self.config_key = "token-cost"
# Load token costs from the config service
async def on_cost_config(self, config, version):
async def on_cost_config(self, workspace, config, version):
logger.info(f"Loading metering configuration version {version}")
logger.info(
f"Loading metering configuration version {version} "
f"for workspace {workspace}"
)
if self.config_key not in config:
logger.warning(f"No key {self.config_key} in config")
logger.warning(
f"No key {self.config_key} in config for {workspace}"
)
self.prices[workspace] = {}
return
config = config[self.config_key]
prices = config[self.config_key]
self.prices = {
self.prices[workspace] = {
k: json.loads(v)
for k, v in config.items()
for k, v in prices.items()
}
def get_prices(self, modelname):
def get_prices(self, workspace, modelname):
if modelname in self.prices:
model = self.prices[modelname]
ws_prices = self.prices.get(workspace, {})
if modelname in ws_prices:
model = ws_prices[modelname]
return model["input_price"], model["output_price"]
return None, None # Return None if model is not found
@ -81,6 +88,8 @@ class Processor(FlowProcessor):
v = msg.value()
workspace = flow.workspace
modelname = v.model or "unknown"
num_in = v.in_token or 0
num_out = v.out_token or 0
@ -89,7 +98,9 @@ class Processor(FlowProcessor):
__class__.token_metric.labels(model=modelname, direction="input").inc(num_in)
__class__.token_metric.labels(model=modelname, direction="output").inc(num_out)
model_input_price, model_output_price = self.get_prices(modelname)
model_input_price, model_output_price = self.get_prices(
workspace, modelname
)
if model_input_price == None:
cost_per_call = f"Model Not Found in Price list"

View file

@ -66,24 +66,37 @@ class Processor(FlowProcessor):
self.register_config_handler(self.on_prompt_config, types=["prompt"])
# Null configuration, should reload quickly
self.manager = PromptManager()
# Per-workspace prompt managers. Populated lazily as config
# arrives for each workspace.
self.managers = {}
async def on_prompt_config(self, config, version):
async def on_prompt_config(self, workspace, config, version):
logger.info(f"Loading prompt configuration version {version}")
logger.info(
f"Loading prompt configuration version {version} "
f"for workspace {workspace}"
)
if self.config_key not in config:
logger.warning(f"No key {self.config_key} in config")
logger.warning(
f"No key {self.config_key} in config for {workspace}"
)
return
config = config[self.config_key]
prompt_config = config[self.config_key]
try:
self.manager.load_config(config)
manager = self.managers.get(workspace)
if manager is None:
manager = PromptManager()
self.managers[workspace] = manager
logger.info("Prompt configuration reloaded")
manager.load_config(prompt_config)
logger.info(
f"Prompt configuration reloaded for {workspace}"
)
except Exception as e:
@ -103,6 +116,29 @@ class Processor(FlowProcessor):
# Check if streaming is requested
streaming = getattr(v, 'streaming', False)
# Look up the prompt manager for this workspace. If none is
# loaded yet, the request can't be handled.
workspace = flow.workspace
manager = self.managers.get(workspace)
if manager is None:
logger.error(
f"No prompt configuration loaded for workspace {workspace}"
)
r = PromptResponse(
error=Error(
type="no-configuration",
message=(
f"No prompt configuration for workspace "
f"{workspace}"
),
),
text=None,
object=None,
end_of_stream=True,
)
await flow("response").send(r, properties={"id": id})
return
try:
logger.debug(f"Prompt terms: {v.terms}")
@ -149,7 +185,7 @@ class Processor(FlowProcessor):
return ""
try:
await self.manager.invoke(kind, input, llm_streaming)
await manager.invoke(kind, input, llm_streaming)
except Exception as e:
logger.error(f"Prompt streaming exception: {e}", exc_info=True)
raise e
@ -177,7 +213,7 @@ class Processor(FlowProcessor):
return None
try:
resp = await self.manager.invoke(kind, input, llm)
resp = await manager.invoke(kind, input, llm)
except Exception as e:
logger.error(f"Prompt invocation exception: {e}", exc_info=True)
raise e

View file

@ -31,7 +31,7 @@ class Processor(DocumentEmbeddingsQueryService):
self.vecstore = DocVectors(store_uri)
async def query_document_embeddings(self, msg):
async def query_document_embeddings(self, workspace, msg):
try:
@ -45,7 +45,7 @@ class Processor(DocumentEmbeddingsQueryService):
resp = self.vecstore.search(
vec,
msg.user,
workspace,
msg.collection,
limit=msg.limit
)

View file

@ -48,7 +48,7 @@ class Processor(DocumentEmbeddingsQueryService):
}
)
async def query_document_embeddings(self, msg):
async def query_document_embeddings(self, workspace, msg):
try:
@ -63,7 +63,7 @@ class Processor(DocumentEmbeddingsQueryService):
dim = len(vec)
# Use dimension suffix in index name
index_name = f"d-{msg.user}-{msg.collection}-{dim}"
index_name = f"d-{workspace}-{msg.collection}-{dim}"
# Check if index exists - return empty if not
if not self.pinecone.has_index(index_name):

View file

@ -65,7 +65,7 @@ class Processor(DocumentEmbeddingsQueryService):
"""Check if collection exists (no implicit creation)"""
return self.qdrant.collection_exists(collection)
async def query_document_embeddings(self, msg):
async def query_document_embeddings(self, workspace, msg):
try:
@ -75,7 +75,7 @@ class Processor(DocumentEmbeddingsQueryService):
# Use dimension suffix in collection name
dim = len(vec)
collection = f"d_{msg.user}_{msg.collection}_{dim}"
collection = f"d_{workspace}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):

View file

@ -37,7 +37,7 @@ class Processor(GraphEmbeddingsQueryService):
else:
return Term(type=LITERAL, value=ent)
async def query_graph_embeddings(self, msg):
async def query_graph_embeddings(self, workspace, msg):
try:
@ -51,7 +51,7 @@ class Processor(GraphEmbeddingsQueryService):
resp = self.vecstore.search(
vec,
msg.user,
workspace,
msg.collection,
limit=msg.limit * 2
)

View file

@ -55,7 +55,7 @@ class Processor(GraphEmbeddingsQueryService):
else:
return Term(type=LITERAL, value=ent)
async def query_graph_embeddings(self, msg):
async def query_graph_embeddings(self, workspace, msg):
try:
@ -70,7 +70,7 @@ class Processor(GraphEmbeddingsQueryService):
dim = len(vec)
# Use dimension suffix in index name
index_name = f"t-{msg.user}-{msg.collection}-{dim}"
index_name = f"t-{workspace}-{msg.collection}-{dim}"
# Check if index exists - return empty if not
if not self.pinecone.has_index(index_name):

View file

@ -71,7 +71,7 @@ class Processor(GraphEmbeddingsQueryService):
else:
return Term(type=LITERAL, value=ent)
async def query_graph_embeddings(self, msg):
async def query_graph_embeddings(self, workspace, msg):
try:
@ -81,7 +81,7 @@ class Processor(GraphEmbeddingsQueryService):
# Use dimension suffix in collection name
dim = len(vec)
collection = f"t_{msg.user}_{msg.collection}_{dim}"
collection = f"t_{workspace}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):

View file

@ -70,7 +70,7 @@ class GraphQLSchemaBuilder:
Build the GraphQL schema with the provided query callback.
The query callback will be invoked when resolving queries, with:
- user: str
- workspace: str
- collection: str
- schema_name: str
- row_schema: RowSchema
@ -228,7 +228,7 @@ class GraphQLSchemaBuilder:
limit: Optional[int] = 100
) -> List[graphql_type]:
# Get context values
user = info.context["user"]
workspace = info.context["workspace"]
collection = info.context["collection"]
# Parse the where clause
@ -236,7 +236,7 @@ class GraphQLSchemaBuilder:
# Call the query backend
results = await query_callback(
user, collection, schema_name, row_schema,
workspace, collection, schema_name, row_schema,
filters, limit, order_by, direction
)

View file

@ -167,7 +167,7 @@ class QueryExplainer:
question_components, query_results, processing_metadata
)
# Generate user-friendly explanation
# Generate user-friendly explanation
user_friendly_explanation = self._generate_user_friendly_explanation(
question, question_components, ontology_subsets, final_answer
)
@ -503,7 +503,7 @@ class QueryExplainer:
question_components: QuestionComponents,
ontology_subsets: List[QueryOntologySubset],
final_answer: str) -> str:
"""Generate user-friendly explanation of the process."""
"""Generate user-friendly explanation of the process."""
explanation_parts = []
# Introduction

View file

@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
@dataclass
class QueryRequest:
"""Query request from user."""
"""Query request from user."""
question: str
context: Optional[str] = None
ontology_hint: Optional[str] = None

View file

@ -1,6 +1,6 @@
"""
Question analyzer for ontology-sensitive query system.
Decomposes user questions into semantic components.
Decomposes user questions into semantic components.
"""
import logging

View file

@ -1,7 +1,7 @@
"""
Row embeddings query service for Qdrant.
Input is query vectors plus user/collection/schema context.
Input is query vectors plus workspace/collection/schema context.
Output is matching row index information (index_name, index_value) for
use in subsequent Cassandra lookups.
"""
@ -70,10 +70,10 @@ class Processor(FlowProcessor):
safe_name = 'r_' + safe_name
return safe_name.lower()
def find_collection(self, user: str, collection: str, schema_name: str) -> Optional[str]:
"""Find the Qdrant collection for a given user/collection/schema"""
def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]:
"""Find the Qdrant collection for a given workspace/collection/schema"""
prefix = (
f"rows_{self.sanitize_name(user)}_"
f"rows_{self.sanitize_name(workspace)}_"
f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
)
@ -93,22 +93,22 @@ class Processor(FlowProcessor):
return None
async def query_row_embeddings(self, request: RowEmbeddingsRequest):
async def query_row_embeddings(self, workspace, request: RowEmbeddingsRequest):
"""Execute row embeddings query"""
vec = request.vector
if not vec:
return []
# Find the collection for this user/collection/schema
# Find the collection for this workspace/collection/schema
qdrant_collection = self.find_collection(
request.user, request.collection, request.schema_name
workspace, request.collection, request.schema_name
)
if not qdrant_collection:
logger.info(
f"No Qdrant collection found for "
f"{request.user}/{request.collection}/{request.schema_name}"
f"{workspace}/{request.collection}/{request.schema_name}"
)
return []
@ -163,11 +163,11 @@ class Processor(FlowProcessor):
logger.debug(
f"Handling row embeddings query for "
f"{request.user}/{request.collection}/{request.schema_name}..."
f"{flow.workspace}/{request.collection}/{request.schema_name}..."
)
# Execute query
matches = await self.query_row_embeddings(request)
matches = await self.query_row_embeddings(flow.workspace, request)
response = RowEmbeddingsResponse(
error=None,

View file

@ -87,12 +87,12 @@ class Processor(FlowProcessor):
# Register config handler for schema updates
self.register_config_handler(self.on_schema_config, types=["schema"])
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# Per-workspace schema storage: {workspace: {name: RowSchema}}
self.schemas: Dict[str, Dict[str, RowSchema]] = {}
# GraphQL schema builder and generated schema
self.schema_builder = GraphQLSchemaBuilder()
self.graphql_schema = None
# Per-workspace GraphQL schema builders and compiled schemas
self.schema_builders: Dict[str, GraphQLSchemaBuilder] = {}
self.graphql_schemas: Dict[str, Any] = {}
# Cassandra session
self.cluster = None
@ -133,17 +133,27 @@ class Processor(FlowProcessor):
safe_name = 'r_' + safe_name
return safe_name.lower()
async def on_schema_config(self, config, version):
async def on_schema_config(self, workspace, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
logger.info(
f"Loading schema configuration version {version} "
f"for workspace {workspace}"
)
# Clear existing schemas
self.schemas = {}
self.schema_builder.clear()
# Replace existing schemas for this workspace
ws_schemas: Dict[str, RowSchema] = {}
self.schemas[workspace] = ws_schemas
builder = GraphQLSchemaBuilder()
self.schema_builders[workspace] = builder
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
logger.warning(
f"No '{self.config_key}' type in configuration "
f"for {workspace}"
)
self.graphql_schemas[workspace] = None
return
# Get the schemas dictionary for our type
@ -177,17 +187,23 @@ class Processor(FlowProcessor):
fields=fields
)
self.schemas[schema_name] = row_schema
self.schema_builder.add_schema(schema_name, row_schema)
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
ws_schemas[schema_name] = row_schema
builder.add_schema(schema_name, row_schema)
logger.info(
f"Loaded schema: {schema_name} with "
f"{len(fields)} fields for {workspace}"
)
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
logger.info(
f"Schema configuration loaded for {workspace}: "
f"{len(ws_schemas)} schemas"
)
# Regenerate GraphQL schema
self.graphql_schema = self.schema_builder.build(self.query_cassandra)
# Regenerate GraphQL schema for this workspace
self.graphql_schemas[workspace] = builder.build(self.query_cassandra)
def get_index_names(self, schema: RowSchema) -> List[str]:
"""Get all index names for a schema."""
@ -222,7 +238,7 @@ class Processor(FlowProcessor):
async def query_cassandra(
self,
user: str,
workspace: str,
collection: str,
schema_name: str,
row_schema: RowSchema,
@ -240,7 +256,7 @@ class Processor(FlowProcessor):
# Connect if needed
self.connect_cassandra()
safe_keyspace = self.sanitize_name(user)
safe_keyspace = self.sanitize_name(workspace)
# Try to find an index that matches the filters
index_match = self.find_matching_index(row_schema, filters)
@ -389,26 +405,30 @@ class Processor(FlowProcessor):
async def execute_graphql_query(
self,
workspace: str,
query: str,
variables: Dict[str, Any],
operation_name: Optional[str],
user: str,
collection: str
) -> Dict[str, Any]:
"""Execute a GraphQL query"""
"""Execute a GraphQL query against the workspace's schema"""
if not self.graphql_schema:
raise RuntimeError("No GraphQL schema available - no schemas loaded")
graphql_schema = self.graphql_schemas.get(workspace)
if not graphql_schema:
raise RuntimeError(
f"No GraphQL schema available for workspace {workspace} "
f"- no schemas loaded"
)
# Create context for the query
context = {
"processor": self,
"user": user,
"workspace": workspace,
"collection": collection
}
# Execute the query
result = await self.graphql_schema.execute(
result = await graphql_schema.execute(
query,
variable_values=variables,
operation_name=operation_name,
@ -454,10 +474,10 @@ class Processor(FlowProcessor):
# Execute GraphQL query
result = await self.execute_graphql_query(
workspace=flow.workspace,
query=request.query,
variables=dict(request.variables) if request.variables else {},
operation_name=request.operation_name,
user=request.user,
collection=request.collection
)

View file

@ -30,14 +30,14 @@ class EvaluationError(Exception):
pass
async def evaluate(node, triples_client, user, collection, limit=10000):
async def evaluate(node, triples_client, workspace, collection, limit=10000):
"""
Evaluate a SPARQL algebra node.
Args:
node: rdflib CompValue algebra node
triples_client: TriplesClient instance for triple pattern queries
user: user/keyspace identifier
workspace: workspace/keyspace identifier
collection: collection identifier
limit: safety limit on results
@ -55,24 +55,24 @@ async def evaluate(node, triples_client, user, collection, limit=10000):
logger.warning(f"Unsupported algebra node: {name}")
return [{}]
return await handler(node, triples_client, user, collection, limit)
return await handler(node, triples_client, workspace, collection, limit)
# --- Node handlers ---
async def _eval_select_query(node, tc, user, collection, limit):
async def _eval_select_query(node, tc, workspace, collection, limit):
"""Evaluate a SelectQuery node."""
return await evaluate(node.p, tc, user, collection, limit)
return await evaluate(node.p, tc, workspace, collection, limit)
async def _eval_project(node, tc, user, collection, limit):
async def _eval_project(node, tc, workspace, collection, limit):
"""Evaluate a Project node (SELECT variable projection)."""
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
variables = [str(v) for v in node.PV]
return project(solutions, variables)
async def _eval_bgp(node, tc, user, collection, limit):
async def _eval_bgp(node, tc, workspace, collection, limit):
"""
Evaluate a Basic Graph Pattern.
@ -107,7 +107,7 @@ async def _eval_bgp(node, tc, user, collection, limit):
# Query the triples store
results = await _query_pattern(
tc, s_val, p_val, o_val, user, collection, limit
tc, s_val, p_val, o_val, workspace, collection, limit
)
# Map results back to variable bindings,
@ -130,17 +130,17 @@ async def _eval_bgp(node, tc, user, collection, limit):
return solutions[:limit]
async def _eval_join(node, tc, user, collection, limit):
async def _eval_join(node, tc, workspace, collection, limit):
"""Evaluate a Join node."""
left = await evaluate(node.p1, tc, user, collection, limit)
right = await evaluate(node.p2, tc, user, collection, limit)
left = await evaluate(node.p1, tc, workspace, collection, limit)
right = await evaluate(node.p2, tc, workspace, collection, limit)
return hash_join(left, right)[:limit]
async def _eval_left_join(node, tc, user, collection, limit):
async def _eval_left_join(node, tc, workspace, collection, limit):
"""Evaluate a LeftJoin node (OPTIONAL)."""
left_sols = await evaluate(node.p1, tc, user, collection, limit)
right_sols = await evaluate(node.p2, tc, user, collection, limit)
left_sols = await evaluate(node.p1, tc, workspace, collection, limit)
right_sols = await evaluate(node.p2, tc, workspace, collection, limit)
filter_fn = None
if hasattr(node, "expr") and node.expr is not None:
@ -153,16 +153,16 @@ async def _eval_left_join(node, tc, user, collection, limit):
return left_join(left_sols, right_sols, filter_fn)[:limit]
async def _eval_union(node, tc, user, collection, limit):
async def _eval_union(node, tc, workspace, collection, limit):
"""Evaluate a Union node."""
left = await evaluate(node.p1, tc, user, collection, limit)
right = await evaluate(node.p2, tc, user, collection, limit)
left = await evaluate(node.p1, tc, workspace, collection, limit)
right = await evaluate(node.p2, tc, workspace, collection, limit)
return union(left, right)[:limit]
async def _eval_filter(node, tc, user, collection, limit):
async def _eval_filter(node, tc, workspace, collection, limit):
"""Evaluate a Filter node."""
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
expr = node.expr
return [
sol for sol in solutions
@ -170,22 +170,22 @@ async def _eval_filter(node, tc, user, collection, limit):
]
async def _eval_distinct(node, tc, user, collection, limit):
async def _eval_distinct(node, tc, workspace, collection, limit):
"""Evaluate a Distinct node."""
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
return distinct(solutions)
async def _eval_reduced(node, tc, user, collection, limit):
async def _eval_reduced(node, tc, workspace, collection, limit):
"""Evaluate a Reduced node (like Distinct but implementation-defined)."""
# Treat same as Distinct
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
return distinct(solutions)
async def _eval_order_by(node, tc, user, collection, limit):
async def _eval_order_by(node, tc, workspace, collection, limit):
"""Evaluate an OrderBy node."""
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
key_fns = []
for cond in node.expr:
@ -206,7 +206,7 @@ async def _eval_order_by(node, tc, user, collection, limit):
return order_by(solutions, key_fns)
async def _eval_slice(node, tc, user, collection, limit):
async def _eval_slice(node, tc, workspace, collection, limit):
"""Evaluate a Slice node (LIMIT/OFFSET)."""
# Pass tighter limit downstream if possible
inner_limit = limit
@ -214,13 +214,13 @@ async def _eval_slice(node, tc, user, collection, limit):
offset = node.start or 0
inner_limit = min(limit, offset + node.length)
solutions = await evaluate(node.p, tc, user, collection, inner_limit)
solutions = await evaluate(node.p, tc, workspace, collection, inner_limit)
return slice_solutions(solutions, node.start or 0, node.length)
async def _eval_extend(node, tc, user, collection, limit):
async def _eval_extend(node, tc, workspace, collection, limit):
"""Evaluate an Extend node (BIND)."""
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
var_name = str(node.var)
expr = node.expr
@ -246,9 +246,9 @@ async def _eval_extend(node, tc, user, collection, limit):
return result
async def _eval_group(node, tc, user, collection, limit):
async def _eval_group(node, tc, workspace, collection, limit):
"""Evaluate a Group node (GROUP BY with aggregation)."""
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
# Extract grouping expressions
group_exprs = []
@ -289,9 +289,9 @@ async def _eval_group(node, tc, user, collection, limit):
return result
async def _eval_aggregate_join(node, tc, user, collection, limit):
async def _eval_aggregate_join(node, tc, workspace, collection, limit):
"""Evaluate an AggregateJoin (aggregation functions after GROUP BY)."""
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
result = []
for sol in solutions:
@ -310,7 +310,7 @@ async def _eval_aggregate_join(node, tc, user, collection, limit):
return result
async def _eval_graph(node, tc, user, collection, limit):
async def _eval_graph(node, tc, workspace, collection, limit):
"""Evaluate a Graph node (GRAPH clause)."""
term = node.term
@ -319,16 +319,16 @@ async def _eval_graph(node, tc, user, collection, limit):
# We'd need to pass graph to triples queries
# For now, evaluate inner pattern normally
logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired")
return await evaluate(node.p, tc, user, collection, limit)
return await evaluate(node.p, tc, workspace, collection, limit)
elif isinstance(term, Variable):
# GRAPH ?g { ... } — variable graph
logger.info(f"GRAPH ?{term} clause - variable graph not yet wired")
return await evaluate(node.p, tc, user, collection, limit)
return await evaluate(node.p, tc, workspace, collection, limit)
else:
return await evaluate(node.p, tc, user, collection, limit)
return await evaluate(node.p, tc, workspace, collection, limit)
async def _eval_values(node, tc, user, collection, limit):
async def _eval_values(node, tc, workspace, collection, limit):
"""Evaluate a VALUES clause (inline data)."""
variables = [str(v) for v in node.var]
solutions = []
@ -343,9 +343,9 @@ async def _eval_values(node, tc, user, collection, limit):
return solutions
async def _eval_to_multiset(node, tc, user, collection, limit):
async def _eval_to_multiset(node, tc, workspace, collection, limit):
"""Evaluate a ToMultiSet node (subquery)."""
return await evaluate(node.p, tc, user, collection, limit)
return await evaluate(node.p, tc, workspace, collection, limit)
# --- Aggregate computation ---
@ -487,7 +487,7 @@ def _resolve_term(tmpl, solution):
return rdflib_term_to_term(tmpl)
async def _query_pattern(tc, s, p, o, user, collection, limit):
async def _query_pattern(tc, s, p, o, workspace, collection, limit):
"""
Issue a streaming triple pattern query via TriplesClient.
@ -496,7 +496,7 @@ async def _query_pattern(tc, s, p, o, user, collection, limit):
results = await tc.query(
s=s, p=p, o=o,
limit=limit,
user=user,
workspace=workspace,
collection=collection,
)
return results

View file

@ -141,7 +141,7 @@ class Processor(FlowProcessor):
solutions = await evaluate(
parsed.algebra,
triples_client,
user=request.user or "trustgraph",
workspace=flow.workspace,
collection=request.collection or "default",
limit=request.limit or 10000,
)

View file

@ -178,34 +178,34 @@ class Processor(TriplesQueryService):
self.cassandra_password = password
self.table = None
def ensure_connection(self, user):
def ensure_connection(self, workspace):
"""Ensure we have a connection to the correct keyspace."""
if user != self.table:
if workspace != self.table:
KGClass = EntityCentricKnowledgeGraph
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
keyspace=workspace,
)
self.table = user
self.table = workspace
async def query_triples(self, query):
async def query_triples(self, workspace, query):
try:
# ensure_connection may construct a fresh
# EntityCentricKnowledgeGraph which does sync schema
# setup against Cassandra. Push it to a worker thread
# so the event loop doesn't block on first-use per user.
await asyncio.to_thread(self.ensure_connection, query.user)
# so the event loop doesn't block on first-use per workspace.
await asyncio.to_thread(self.ensure_connection, workspace)
# Extract values from query
s_val = get_term_value(query.s)
@ -359,13 +359,13 @@ class Processor(TriplesQueryService):
logger.error(f"Exception querying triples: {e}", exc_info=True)
raise e
async def query_triples_stream(self, query):
async def query_triples_stream(self, workspace, query):
"""
Streaming query - yields (batch, is_final) tuples.
Uses Cassandra's paging to fetch results incrementally.
"""
try:
await asyncio.to_thread(self.ensure_connection, query.user)
await asyncio.to_thread(self.ensure_connection, workspace)
batch_size = query.batch_size if query.batch_size > 0 else 20
limit = query.limit if query.limit > 0 else 10000
@ -395,7 +395,7 @@ class Processor(TriplesQueryService):
else:
# For specific patterns, fall back to non-streaming
# (these typically return small result sets anyway)
async for batch, is_final in self._fallback_stream(query, batch_size):
async for batch, is_final in self._fallback_stream(workspace, query, batch_size):
yield batch, is_final
return
@ -452,9 +452,9 @@ class Processor(TriplesQueryService):
logger.error(f"Exception in streaming query: {e}", exc_info=True)
raise e
async def _fallback_stream(self, query, batch_size):
async def _fallback_stream(self, workspace, query, batch_size):
"""Fallback to non-streaming query with post-hoc batching."""
triples = await self.query_triples(query)
triples = await self.query_triples(workspace, query)
for i in range(0, len(triples), batch_size):
batch = triples[i:i + batch_size]

View file

@ -58,7 +58,7 @@ class Processor(TriplesQueryService):
else:
return Term(type=LITERAL, value=ent)
async def query_triples(self, query):
async def query_triples(self, workspace, query):
try:

View file

@ -63,12 +63,11 @@ class Processor(TriplesQueryService):
else:
return Term(type=LITERAL, value=ent)
async def query_triples(self, query):
async def query_triples(self, workspace, query):
try:
# Extract user and collection, use defaults if not provided
user = query.user if query.user else "default"
workspace = workspace
collection = query.collection if query.collection else "default"
triples = []
@ -80,13 +79,13 @@ class Processor(TriplesQueryService):
# SPO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN $src as src "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -94,13 +93,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN $src as src "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), uri=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -112,13 +111,13 @@ class Processor(TriplesQueryService):
# SP
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN dest.value as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -127,13 +126,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN dest.uri as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -148,13 +147,13 @@ class Processor(TriplesQueryService):
# SO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -163,13 +162,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), uri=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -182,13 +181,13 @@ class Processor(TriplesQueryService):
# S
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -197,13 +196,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -221,13 +220,13 @@ class Processor(TriplesQueryService):
# PO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p), value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -236,13 +235,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], get_term_value(query.p), get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p), dest=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -255,13 +254,13 @@ class Processor(TriplesQueryService):
# P
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -270,13 +269,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], get_term_value(query.p), data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.uri as dest "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -291,13 +290,13 @@ class Processor(TriplesQueryService):
# O
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -306,13 +305,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], data["rel"], get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
uri=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -325,12 +324,12 @@ class Processor(TriplesQueryService):
# *
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -339,12 +338,12 @@ class Processor(TriplesQueryService):
triples.append((data["src"], data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)

View file

@ -63,14 +63,12 @@ class Processor(TriplesQueryService):
else:
return Term(type=LITERAL, value=ent)
async def query_triples(self, query):
async def query_triples(self, workspace, query):
try:
# Extract user and collection, use defaults if not provided
user = query.user if query.user else "default"
collection = query.collection if query.collection else "default"
triples = []
if query.s is not None:
@ -80,13 +78,13 @@ class Processor(TriplesQueryService):
# SPO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN $src as src "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -94,13 +92,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN $src as src "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), uri=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -112,13 +110,13 @@ class Processor(TriplesQueryService):
# SP
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN dest.value as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -127,13 +125,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN dest.uri as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -148,13 +146,13 @@ class Processor(TriplesQueryService):
# SO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -163,13 +161,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), uri=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -182,13 +180,13 @@ class Processor(TriplesQueryService):
# S
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -197,13 +195,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -221,13 +219,13 @@ class Processor(TriplesQueryService):
# PO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p), value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -236,13 +234,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], get_term_value(query.p), get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p), dest=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -255,13 +253,13 @@ class Processor(TriplesQueryService):
# P
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -270,13 +268,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], get_term_value(query.p), data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.uri as dest "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -291,13 +289,13 @@ class Processor(TriplesQueryService):
# O
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -306,13 +304,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], data["rel"], get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
uri=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -325,12 +323,12 @@ class Processor(TriplesQueryService):
# *
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -339,12 +337,12 @@ class Processor(TriplesQueryService):
triples.append((data["src"], data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -367,7 +365,7 @@ class Processor(TriplesQueryService):
logger.error(f"Exception querying triples: {e}", exc_info=True)
raise e
@staticmethod
def add_args(parser):

View file

@ -26,11 +26,11 @@ LABEL="http://www.w3.org/2000/01/rdf-schema#label"
class Query:
def __init__(
self, rag, user, collection, verbose,
self, rag, workspace, collection, verbose,
doc_limit=20, track_usage=None,
):
self.rag = rag
self.user = user
self.workspace = workspace
self.collection = collection
self.verbose = verbose
self.doc_limit = doc_limit
@ -97,7 +97,7 @@ class Query:
async def query_concept(vec):
return await self.rag.doc_embeddings_client.query(
vector=vec, limit=per_concept_limit,
user=self.user, collection=self.collection,
collection=self.collection,
)
results = await asyncio.gather(
@ -122,7 +122,7 @@ class Query:
for match in chunk_matches:
if match.chunk_id:
try:
content = await self.rag.fetch_chunk(match.chunk_id, self.user)
content = await self.rag.fetch_chunk(match.chunk_id, self.workspace)
docs.append(content)
chunk_ids.append(match.chunk_id)
except Exception as e:
@ -154,7 +154,7 @@ class DocumentRag:
logger.debug("DocumentRag initialized")
async def query(
self, query, user="trustgraph", collection="default",
self, query, workspace="default", collection="default",
doc_limit=20, streaming=False, chunk_callback=None,
explain_callback=None, save_answer_callback=None,
):
@ -163,7 +163,7 @@ class DocumentRag:
Args:
query: The query string
user: User identifier
workspace: Workspace for isolation (also scopes chunk lookup)
collection: Collection identifier
doc_limit: Max chunks to retrieve
streaming: Enable streaming LLM response
@ -210,7 +210,8 @@ class DocumentRag:
await explain_callback(q_triples, q_uri)
q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose,
rag=self, workspace=workspace, collection=collection,
verbose=self.verbose,
doc_limit=doc_limit, track_usage=track_usage,
)

View file

@ -96,19 +96,19 @@ class Processor(FlowProcessor):
await super(Processor, self).start()
await self.librarian.start()
async def fetch_chunk_content(self, chunk_id, user, timeout=120):
async def fetch_chunk_content(self, chunk_id, workspace, timeout=120):
"""Fetch chunk content from librarian. Chunks are small so
single request-response is fine."""
return await self.librarian.fetch_document_text(
document_id=chunk_id, user=user, timeout=timeout,
document_id=chunk_id, workspace=workspace, timeout=timeout,
)
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120):
"""Save answer content to the librarian."""
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
workspace=workspace,
kind="text/plain",
title=title or "DocumentRAG Answer",
document_type="answer",
@ -119,7 +119,7 @@ class Processor(FlowProcessor):
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
user=user,
workspace=workspace,
)
await self.librarian.request(request, timeout=timeout)
@ -150,14 +150,13 @@ class Processor(FlowProcessor):
doc_limit = self.doc_limit
# Real-time explainability callback - emits triples and IDs as they're generated
# Triples are stored in the user's collection with a named graph (urn:graph:retrieval)
# Triples are stored in the request's collection with a named graph (urn:graph:retrieval)
async def send_explainability(triples, explain_id):
# Send triples to explainability queue - stores in same collection with named graph
await flow("explainability").send(Triples(
metadata=Metadata(
id=explain_id,
user=v.user,
collection=v.collection, # Store in user's collection
collection=v.collection,
),
triples=triples,
))
@ -178,7 +177,7 @@ class Processor(FlowProcessor):
async def save_answer(doc_id, answer_text):
await self.save_answer_content(
doc_id=doc_id,
user=v.user,
workspace=flow.workspace,
content=answer_text,
title=f"DocumentRAG Answer: {v.query[:50]}...",
)
@ -202,7 +201,7 @@ class Processor(FlowProcessor):
# All chunks (including final one with end_of_stream=True) are sent via callback
response, usage = await self.rag.query(
v.query,
user=v.user,
workspace=flow.workspace,
collection=v.collection,
doc_limit=doc_limit,
streaming=True,
@ -227,7 +226,7 @@ class Processor(FlowProcessor):
# Non-streaming path - single response with answer and token usage
response, usage = await self.rag.query(
v.query,
user=v.user,
workspace=flow.workspace,
collection=v.collection,
doc_limit=doc_limit,
explain_callback=send_explainability,

View file

@ -75,12 +75,11 @@ def edge_id(s, p, o):
class LRUCacheWithTTL:
"""LRU cache with TTL for label caching
"""LRU cache with TTL for label caching.
CRITICAL SECURITY WARNING:
This cache is shared within a GraphRag instance but GraphRag instances
are created per-request. Cache keys MUST include user:collection prefix
to ensure data isolation between different security contexts.
GraphRag instances are created per-request, so this cache is
request-scoped. Cache keys include the collection prefix to keep
entries from different collections distinct within one request.
"""
def __init__(self, max_size=5000, ttl=300):
@ -119,12 +118,11 @@ class LRUCacheWithTTL:
class Query:
def __init__(
self, rag, user, collection, verbose,
self, rag, collection, verbose,
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
max_path_length=2, track_usage=None,
):
self.rag = rag
self.user = user
self.collection = collection
self.verbose = verbose
self.entity_limit = entity_limit
@ -194,7 +192,7 @@ class Query:
entity_tasks = [
self.rag.graph_embeddings_client.query(
vector=v, limit=per_concept_limit,
user=self.user, collection=self.collection,
collection=self.collection,
)
for v in vectors
]
@ -222,18 +220,18 @@ class Query:
async def maybe_label(self, e):
# CRITICAL SECURITY: Cache key MUST include user and collection
# to prevent data leakage between different contexts
cache_key = f"{self.user}:{self.collection}:{e}"
# The label cache lives on a per-request GraphRag instance — no
# cross-request isolation concern. The collection prefix keeps
# entries from different collections distinct within one request.
cache_key = f"{self.collection}:{e}"
# Check the LRU cache first, using the collection-prefixed key
cached_label = self.rag.label_cache.get(cache_key)
if cached_label is not None:
return cached_label
res = await self.rag.triples_client.query(
s=e, p=LABEL, o=None, limit=1,
user=self.user, collection=self.collection,
collection=self.collection,
g="",
)
@ -255,19 +253,19 @@ class Query:
self.rag.triples_client.query_stream(
s=entity, p=None, o=None,
limit=limit_per_entity,
user=self.user, collection=self.collection,
collection=self.collection,
batch_size=20, g="",
),
self.rag.triples_client.query_stream(
s=None, p=entity, o=None,
limit=limit_per_entity,
user=self.user, collection=self.collection,
collection=self.collection,
batch_size=20, g="",
),
self.rag.triples_client.query_stream(
s=None, p=None, o=entity,
limit=limit_per_entity,
user=self.user, collection=self.collection,
collection=self.collection,
batch_size=20, g="",
)
])
@ -468,7 +466,7 @@ class Query:
subgraph_tasks.append(
self.rag.triples_client.query(
s=None, p=TG_CONTAINS, o=quoted, limit=1,
user=self.user, collection=self.collection,
collection=self.collection,
g=GRAPH_SOURCE,
)
)
@ -501,7 +499,7 @@ class Query:
derivation_tasks = [
self.rag.triples_client.query(
s=uri, p=PROV_WAS_DERIVED_FROM, o=None, limit=5,
user=self.user, collection=self.collection,
collection=self.collection,
g=GRAPH_SOURCE,
)
for uri in current_uris
@ -535,7 +533,7 @@ class Query:
metadata_tasks = [
self.rag.triples_client.query(
s=uri, p=None, o=None, limit=50,
user=self.user, collection=self.collection,
collection=self.collection,
)
for uri in doc_uris
]
@ -560,11 +558,9 @@ class Query:
class GraphRag:
"""
CRITICAL SECURITY:
This class MUST be instantiated per-request to ensure proper isolation
between users and collections. The cache within this instance will only
live for the duration of a single request, preventing cross-contamination
of data between different security contexts.
Must be instantiated per-request so the label cache lives only for
the duration of a single request. Workspace isolation is enforced
by the trusted flow layer (flow.workspace), not by this class.
"""
def __init__(
@ -587,7 +583,7 @@ class GraphRag:
logger.debug("GraphRag initialized")
async def query(
self, query, user = "trustgraph", collection = "default",
self, query, collection = "default",
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
max_path_length = 2, edge_score_limit = 30, edge_limit = 25,
streaming = False,
@ -600,7 +596,6 @@ class GraphRag:
Args:
query: The query string
user: User identifier
collection: Collection identifier
entity_limit: Max entities to retrieve
triple_limit: Max triples per entity
@ -657,7 +652,7 @@ class GraphRag:
await explain_callback(q_triples, q_uri)
q = Query(
rag = self, user = user, collection = collection,
rag = self, collection = collection,
verbose = self.verbose, entity_limit = entity_limit,
triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,

View file

@ -62,9 +62,9 @@ class Processor(FlowProcessor):
self.default_edge_score_limit = edge_score_limit
self.default_edge_limit = edge_limit
# CRITICAL SECURITY: NEVER share data between users or collections
# Each user/collection combination MUST have isolated data access
# Caching must NEVER allow information leakage across these boundaries
# Workspace isolation is enforced by the flow layer (flow.workspace).
# GraphRag, and the label cache it owns, is instantiated per request,
# so cached state is never shared across requests here.
self.register_specification(
ConsumerSpec(
@ -170,13 +170,13 @@ class Processor(FlowProcessor):
future = self.pending_librarian_requests.pop(request_id)
future.set_result(response)
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120):
"""
Save answer content to the librarian.
Args:
doc_id: ID for the answer document
user: User ID
workspace: Workspace for isolation
content: Answer text content
title: Optional title
timeout: Request timeout in seconds
@ -188,7 +188,7 @@ class Processor(FlowProcessor):
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
workspace=workspace,
kind="text/plain",
title=title or "GraphRAG Answer",
document_type="answer",
@ -199,7 +199,7 @@ class Processor(FlowProcessor):
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
user=user,
workspace=workspace,
)
# Create future for response
@ -241,14 +241,13 @@ class Processor(FlowProcessor):
explainability_refs_emitted = []
# Real-time explainability callback - emits triples and IDs as they're generated
# Triples are stored in the user's collection with a named graph (urn:graph:retrieval)
# Triples are stored in the request's collection with a named graph (urn:graph:retrieval)
async def send_explainability(triples, explain_id):
# Send triples to explainability queue - stores in same collection with named graph
await flow("explainability").send(Triples(
metadata=Metadata(
id=explain_id,
user=v.user,
collection=v.collection, # Store in user's collection, not separate explainability collection
collection=v.collection,
),
triples=triples,
))
@ -266,9 +265,9 @@ class Processor(FlowProcessor):
explainability_refs_emitted.append(explain_id)
# CRITICAL SECURITY: Create new GraphRag instance per request
# This ensures proper isolation between users and collections
# Flow clients are request-scoped and must not be shared
# Create new GraphRag instance per request — its label cache
# is request-scoped, and flow clients must not be shared
# across requests.
rag = GraphRag(
embeddings_client=flow("embeddings-request"),
graph_embeddings_client=flow("graph-embeddings-request"),
@ -311,7 +310,7 @@ class Processor(FlowProcessor):
async def save_answer(doc_id, answer_text):
await self.save_answer_content(
doc_id=doc_id,
user=v.user,
workspace=flow.workspace,
content=answer_text,
title=f"GraphRAG Answer: {v.query[:50]}...",
)
@ -333,7 +332,7 @@ class Processor(FlowProcessor):
# Query with streaming and real-time explain
response, usage = await rag.query(
query = v.query, user = v.user, collection = v.collection,
query = v.query, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
@ -349,7 +348,7 @@ class Processor(FlowProcessor):
else:
# Non-streaming path with real-time explain
response, usage = await rag.query(
query = v.query, user = v.user, collection = v.collection,
query = v.query, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
@ -464,7 +463,7 @@ class Processor(FlowProcessor):
help=f'Max edges after LLM scoring (default: 25)'
)
# Note: Explainability triples are now stored in the user's collection
# Note: Explainability triples are now stored in the request's collection
# with the named graph urn:graph:retrieval (no separate collection needed)
def run():

View file

@ -66,32 +66,39 @@ class Processor(FlowProcessor):
# Register config handler for schema updates
self.register_config_handler(self.on_schema_config, types=["schema"])
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# Per-workspace schema storage: {workspace: {name: RowSchema}}
self.schemas: Dict[str, Dict[str, RowSchema]] = {}
logger.info("NLP Query service initialized")
async def on_schema_config(self, config, version):
async def on_schema_config(self, workspace, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
# Clear existing schemas
self.schemas = {}
logger.info(
f"Loading schema configuration version {version} "
f"for workspace {workspace}"
)
# Replace existing schemas for this workspace
ws_schemas: Dict[str, RowSchema] = {}
self.schemas[workspace] = ws_schemas
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
logger.warning(
f"No '{self.config_key}' type in configuration "
f"for {workspace}"
)
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
@ -106,29 +113,37 @@ class Processor(FlowProcessor):
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
ws_schemas[schema_name] = row_schema
logger.info(
f"Loaded schema: {schema_name} with "
f"{len(fields)} fields for {workspace}"
)
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
logger.info(
f"Schema configuration loaded for {workspace}: "
f"{len(ws_schemas)} schemas"
)
async def phase1_select_schemas(self, question: str, flow) -> List[str]:
"""Phase 1: Use prompt service to select relevant schemas for the question"""
logger.info("Starting Phase 1: Schema selection")
ws_schemas = self.schemas.get(flow.workspace, {})
# Prepare schema information for the prompt
schema_info = []
for name, schema in self.schemas.items():
for name, schema in ws_schemas.items():
schema_desc = {
"name": name,
"description": schema.description,
@ -176,12 +191,14 @@ class Processor(FlowProcessor):
async def phase2_generate_graphql(self, question: str, selected_schemas: List[str], flow) -> Dict[str, Any]:
"""Phase 2: Generate GraphQL query using selected schemas"""
logger.info(f"Starting Phase 2: GraphQL generation for schemas: {selected_schemas}")
ws_schemas = self.schemas.get(flow.workspace, {})
# Get detailed schema information for selected schemas only
selected_schema_info = []
for schema_name in selected_schemas:
if schema_name in self.schemas:
schema = self.schemas[schema_name]
if schema_name in ws_schemas:
schema = ws_schemas[schema_name]
schema_desc = {
"name": schema_name,
"description": schema.description,

View file

@ -72,21 +72,28 @@ class Processor(FlowProcessor):
# Register config handler for schema updates
self.register_config_handler(self.on_schema_config, types=["schema"])
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# Per-workspace schema storage: {workspace: {name: RowSchema}}
self.schemas: Dict[str, Dict[str, RowSchema]] = {}
logger.info("Structured Data Diagnosis service initialized")
async def on_schema_config(self, config, version):
async def on_schema_config(self, workspace, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
logger.info(
f"Loading schema configuration version {version} "
f"for workspace {workspace}"
)
# Clear existing schemas
self.schemas = {}
# Replace existing schemas for this workspace
ws_schemas: Dict[str, RowSchema] = {}
self.schemas[workspace] = ws_schemas
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
logger.warning(
f"No '{self.config_key}' type in configuration "
f"for {workspace}"
)
return
# Get the schemas dictionary for our type
@ -120,13 +127,19 @@ class Processor(FlowProcessor):
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
ws_schemas[schema_name] = row_schema
logger.info(
f"Loaded schema: {schema_name} with "
f"{len(fields)} fields for {workspace}"
)
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
logger.info(
f"Schema configuration loaded for {workspace}: "
f"{len(ws_schemas)} schemas"
)
async def on_message(self, msg, consumer, flow):
"""Handle incoming structured data diagnosis request"""
@ -216,15 +229,19 @@ class Processor(FlowProcessor):
)
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
# Get target schema
if request.schema_name not in self.schemas:
# Get target schema from this workspace's schemas
ws_schemas = self.schemas.get(flow.workspace, {})
if request.schema_name not in ws_schemas:
error = Error(
type="SchemaNotFound",
message=f"Schema '{request.schema_name}' not found in configuration"
message=(
f"Schema '{request.schema_name}' not found "
f"in configuration for workspace {flow.workspace}"
)
)
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
target_schema = self.schemas[request.schema_name]
target_schema = ws_schemas[request.schema_name]
# Generate descriptor using prompt service
descriptor = await self.generate_descriptor_with_prompt(
@ -260,26 +277,33 @@ class Processor(FlowProcessor):
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
# Step 2: Use provided schema name or auto-select first available
ws_schemas = self.schemas.get(flow.workspace, {})
schema_name = request.schema_name
if not schema_name and self.schemas:
schema_name = list(self.schemas.keys())[0]
if not schema_name and ws_schemas:
schema_name = list(ws_schemas.keys())[0]
logger.info(f"Auto-selected schema: {schema_name}")
if not schema_name:
error = Error(
type="NoSchemaAvailable",
message="No schema specified and no schemas available in configuration"
message=(
f"No schema specified and no schemas available "
f"in configuration for workspace {flow.workspace}"
)
)
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
if schema_name not in self.schemas:
if schema_name not in ws_schemas:
error = Error(
type="SchemaNotFound",
message=f"Schema '{schema_name}' not found in configuration"
message=(
f"Schema '{schema_name}' not found in "
f"configuration for workspace {flow.workspace}"
)
)
return StructuredDataDiagnosisResponse(error=error, operation=request.operation)
target_schema = self.schemas[schema_name]
target_schema = ws_schemas[schema_name]
# Step 3: Generate descriptor
descriptor = await self.generate_descriptor_with_prompt(
@ -316,8 +340,9 @@ class Processor(FlowProcessor):
logger.info("Processing schema-selection operation")
# Prepare all schemas for the prompt - match the original config format
ws_schemas = self.schemas.get(flow.workspace, {})
all_schemas = []
for schema_name, row_schema in self.schemas.items():
for schema_name, row_schema in ws_schemas.items():
schema_info = {
"name": row_schema.name,
"description": row_schema.description,

View file

@ -111,9 +111,9 @@ class Processor(FlowProcessor):
else:
variables_as_strings[key] = str(value)
# Use user/collection values from request
# Use collection from request. Workspace isolation is
# enforced by flow.workspace at the rows-query service.
objects_request = RowsQueryRequest(
user=request.user,
collection=request.collection,
query=nlp_response.graphql_query,
variables=variables_as_strings,

View file

@ -33,7 +33,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
# Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"])
async def store_document_embeddings(self, message):
async def store_document_embeddings(self, workspace, message):
for emb in message.chunks:
@ -45,7 +45,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
if vec:
self.vecstore.insert(
vec, chunk_id,
message.metadata.user,
workspace,
message.metadata.collection
)
@ -60,27 +60,27 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
help=f'Milvus store URI (default: {default_store_uri})'
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
Create collection via config push - collections are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write")
self.vecstore.create_collection(user, collection)
logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
self.vecstore.create_collection(workspace, collection)
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for document embeddings via config push"""
try:
self.vecstore.delete_collection(user, collection)
logger.info(f"Successfully deleted collection {user}/{collection}")
self.vecstore.delete_collection(workspace, collection)
logger.info(f"Successfully deleted collection {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():

View file

@ -88,12 +88,12 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
"Gave up waiting for index creation"
)
async def store_document_embeddings(self, message):
async def store_document_embeddings(self, workspace, message):
# Validate collection exists in config before processing
if not self.collection_exists(message.metadata.user, message.metadata.collection):
if not self.collection_exists(workspace, message.metadata.collection):
logger.warning(
f"Collection {message.metadata.collection} for user {message.metadata.user} "
f"Collection {message.metadata.collection} for workspace {workspace} "
f"does not exist in config (likely deleted while data was in-flight). "
f"Dropping message."
)
@ -112,7 +112,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
# Create index name with dimension suffix for lazy creation
dim = len(vec)
index_name = (
f"d-{message.metadata.user}-{message.metadata.collection}-{dim}"
f"d-{workspace}-{message.metadata.collection}-{dim}"
)
# Lazily create index if it doesn't exist (but only if authorized in config)
@ -165,22 +165,22 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
help=f'Pinecone region, (default: {default_region}'
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
Create collection via config push - indexes are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write")
logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for document embeddings via config push"""
try:
prefix = f"d-{user}-{collection}-"
prefix = f"d-{workspace}-{collection}-"
# Get all indexes and filter for matches
all_indexes = self.pinecone.list_indexes()
@ -195,10 +195,10 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
for index_name in matching_indexes:
self.pinecone.delete_index(index_name)
logger.info(f"Deleted Pinecone index: {index_name}")
logger.info(f"Deleted {len(matching_indexes)} index(es) for {user}/{collection}")
logger.info(f"Deleted {len(matching_indexes)} index(es) for {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():

View file

@ -39,12 +39,12 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
# Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"])
async def store_document_embeddings(self, message):
async def store_document_embeddings(self, workspace, message):
# Validate collection exists in config before processing
if not self.collection_exists(message.metadata.user, message.metadata.collection):
if not self.collection_exists(workspace, message.metadata.collection):
logger.warning(
f"Collection {message.metadata.collection} for user {message.metadata.user} "
f"Collection {message.metadata.collection} for workspace {workspace} "
f"does not exist in config (likely deleted while data was in-flight). "
f"Dropping message."
)
@ -63,7 +63,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"d_{message.metadata.user}_{message.metadata.collection}_{dim}"
f"d_{workspace}_{message.metadata.collection}_{dim}"
)
# Lazily create collection if it doesn't exist (but only if authorized in config)
@ -107,22 +107,22 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
help=f'Qdrant API key (default: None)'
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
Create collection via config push - collections are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write")
logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for document embeddings via config push"""
try:
prefix = f"d_{user}_{collection}_"
prefix = f"d_{workspace}_{collection}_"
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
@ -137,10 +137,10 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():

View file

@ -47,7 +47,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
# Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"])
async def store_graph_embeddings(self, message):
async def store_graph_embeddings(self, workspace, message):
for entity in message.entities:
entity_value = get_term_value(entity.entity)
@ -57,7 +57,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
if vec:
self.vecstore.insert(
vec, entity_value,
message.metadata.user,
workspace,
message.metadata.collection,
chunk_id=entity.chunk_id or "",
)
@ -73,27 +73,27 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
help=f'Milvus store URI (default: {default_store_uri})'
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
Create collection via config push - collections are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write")
self.vecstore.create_collection(user, collection)
logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
self.vecstore.create_collection(workspace, collection)
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for graph embeddings via config push"""
try:
self.vecstore.delete_collection(user, collection)
logger.info(f"Successfully deleted collection {user}/{collection}")
self.vecstore.delete_collection(workspace, collection)
logger.info(f"Successfully deleted collection {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():

View file

@ -102,12 +102,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
"Gave up waiting for index creation"
)
async def store_graph_embeddings(self, message):
async def store_graph_embeddings(self, workspace, message):
# Validate collection exists in config before processing
if not self.collection_exists(message.metadata.user, message.metadata.collection):
if not self.collection_exists(workspace, message.metadata.collection):
logger.warning(
f"Collection {message.metadata.collection} for user {message.metadata.user} "
f"Collection {message.metadata.collection} for workspace {workspace} "
f"does not exist in config (likely deleted while data was in-flight). "
f"Dropping message."
)
@ -126,7 +126,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
# Create index name with dimension suffix for lazy creation
dim = len(vec)
index_name = (
f"t-{message.metadata.user}-{message.metadata.collection}-{dim}"
f"t-{workspace}-{message.metadata.collection}-{dim}"
)
# Lazily create index if it doesn't exist (but only if authorized in config)
@ -183,22 +183,22 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
help=f'Pinecone region, (default: {default_region}'
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
Create collection via config push - indexes are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write")
logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for graph embeddings via config push"""
try:
prefix = f"t-{user}-{collection}-"
prefix = f"t-{workspace}-{collection}-"
# Get all indexes and filter for matches
all_indexes = self.pinecone.list_indexes()
@ -213,10 +213,10 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
for index_name in matching_indexes:
self.pinecone.delete_index(index_name)
logger.info(f"Deleted Pinecone index: {index_name}")
logger.info(f"Deleted {len(matching_indexes)} index(es) for {user}/{collection}")
logger.info(f"Deleted {len(matching_indexes)} index(es) for {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():

View file

@ -54,12 +54,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
# Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"])
async def store_graph_embeddings(self, message):
async def store_graph_embeddings(self, workspace, message):
# Validate collection exists in config before processing
if not self.collection_exists(message.metadata.user, message.metadata.collection):
if not self.collection_exists(workspace, message.metadata.collection):
logger.warning(
f"Collection {message.metadata.collection} for user {message.metadata.user} "
f"Collection {message.metadata.collection} for workspace {workspace} "
f"does not exist in config (likely deleted while data was in-flight). "
f"Dropping message."
)
@ -78,7 +78,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"t_{message.metadata.user}_{message.metadata.collection}_{dim}"
f"t_{workspace}_{message.metadata.collection}_{dim}"
)
# Lazily create collection if it doesn't exist (but only if authorized in config)
@ -126,22 +126,22 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
help=f'Qdrant API key'
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
Create collection via config push - collections are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write")
logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for graph embeddings via config push"""
try:
prefix = f"t_{user}_{collection}_"
prefix = f"t_{workspace}_{collection}_"
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
@ -156,10 +156,10 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():

View file

@ -65,13 +65,13 @@ class Processor(FlowProcessor):
v = msg.value()
if v.triples:
await self.table_store.add_triples(v)
await self.table_store.add_triples(flow.workspace, v)
async def on_graph_embeddings(self, msg, consumer, flow):
v = msg.value()
if v.entities:
await self.table_store.add_graph_embeddings(v)
await self.table_store.add_graph_embeddings(flow.workspace, v)
@staticmethod
def add_args(parser):

View file

@ -2,13 +2,13 @@
Row embeddings writer for Qdrant (Stage 2).
Consumes RowEmbeddings messages (which already contain computed vectors)
and writes them to Qdrant. One Qdrant collection per (user, collection, schema_name) pair.
and writes them to Qdrant. One Qdrant collection per (workspace, collection, schema_name) pair.
This follows the two-stage pattern used by graph-embeddings and document-embeddings:
Stage 1 (row-embeddings): Compute embeddings
Stage 2 (this processor): Store embeddings
Collection naming: rows_{user}_{collection}_{schema_name}_{dimension}
Collection naming: rows_{workspace}_{collection}_{schema_name}_{dimension}
Payload structure:
- index_name: The indexed field(s) this embedding represents
@ -77,10 +77,10 @@ class Processor(CollectionConfigHandler, FlowProcessor):
return safe_name.lower()
def get_collection_name(
self, user: str, collection: str, schema_name: str, dimension: int
self, workspace: str, collection: str, schema_name: str, dimension: int
) -> str:
"""Generate Qdrant collection name"""
safe_user = self.sanitize_name(user)
safe_user = self.sanitize_name(workspace)
safe_collection = self.sanitize_name(collection)
safe_schema = self.sanitize_name(schema_name)
return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}"
@ -114,18 +114,19 @@ class Processor(CollectionConfigHandler, FlowProcessor):
f"{embeddings.schema_name} from {embeddings.metadata.id}"
)
workspace = flow.workspace
# Validate collection exists in config before processing
if not self.collection_exists(
embeddings.metadata.user, embeddings.metadata.collection
workspace, embeddings.metadata.collection
):
logger.warning(
f"Collection {embeddings.metadata.collection} for user "
f"{embeddings.metadata.user} does not exist in config. "
f"Collection {embeddings.metadata.collection} for workspace "
f"{workspace} does not exist in config. "
f"Dropping message."
)
return
user = embeddings.metadata.user
collection = embeddings.metadata.collection
schema_name = embeddings.schema_name
@ -145,7 +146,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
# Create/get collection name (lazily on first vector)
if qdrant_collection is None:
qdrant_collection = self.get_collection_name(
user, collection, schema_name, dimension
workspace, collection, schema_name, dimension
)
self.ensure_collection(qdrant_collection, dimension)
@ -168,17 +169,17 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Collection creation via config push - collections created lazily on first write"""
logger.info(
f"Row embeddings collection create request for {user}/{collection} - "
f"Row embeddings collection create request for {workspace}/{collection} - "
f"will be created lazily on first write"
)
async def delete_collection(self, user: str, collection: str):
"""Delete all Qdrant collections for a given user/collection"""
async def delete_collection(self, workspace: str, collection: str):
"""Delete all Qdrant collections for a given workspace/collection"""
try:
prefix = f"rows_{self.sanitize_name(user)}_{self.sanitize_name(collection)}_"
prefix = f"rows_{self.sanitize_name(workspace)}_{self.sanitize_name(collection)}_"
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
@ -196,23 +197,23 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(
f"Deleted {len(matching_collections)} collection(s) "
f"for {user}/{collection}"
f"for {workspace}/{collection}"
)
except Exception as e:
logger.error(
f"Failed to delete collection {user}/{collection}: {e}",
f"Failed to delete collection {workspace}/{collection}: {e}",
exc_info=True
)
raise
async def delete_collection_schema(
self, user: str, collection: str, schema_name: str
self, workspace: str, collection: str, schema_name: str
):
"""Delete Qdrant collection for a specific user/collection/schema"""
"""Delete Qdrant collection for a specific workspace/collection/schema"""
try:
prefix = (
f"rows_{self.sanitize_name(user)}_"
f"rows_{self.sanitize_name(workspace)}_"
f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
)
@ -233,7 +234,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
except Exception as e:
logger.error(
f"Failed to delete collection {user}/{collection}/{schema_name}: {e}",
f"Failed to delete collection {workspace}/{collection}/{schema_name}: {e}",
exc_info=True
)
raise

View file

@ -119,19 +119,27 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
raise
async def on_schema_config(self, config, version):
async def on_schema_config(self, workspace, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
logger.info(
f"Loading schema configuration version {version} "
f"for workspace {workspace}"
)
# Track which schemas changed so we can clear partition cache
old_schema_names = set(self.schemas.keys())
# Track which schemas changed in this workspace
old_schemas = self.schemas.get(workspace, {})
old_schema_names = set(old_schemas.keys())
# Clear existing schemas
self.schemas = {}
# Replace existing schemas for this workspace
ws_schemas: Dict[str, RowSchema] = {}
self.schemas[workspace] = ws_schemas
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
logger.warning(
f"No '{self.config_key}' type in configuration "
f"for {workspace}"
)
return
# Get the schemas dictionary for our type
@ -165,24 +173,32 @@ class Processor(CollectionConfigHandler, FlowProcessor):
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
ws_schemas[schema_name] = row_schema
logger.info(
f"Loaded schema: {schema_name} with "
f"{len(fields)} fields for {workspace}"
)
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
logger.info(
f"Schema configuration loaded for {workspace}: "
f"{len(ws_schemas)} schemas"
)
# Clear partition cache for schemas that changed
# This ensures next write will re-register partitions
new_schema_names = set(self.schemas.keys())
# Clear partition cache for schemas that changed in this workspace
new_schema_names = set(ws_schemas.keys())
changed_schemas = old_schema_names.symmetric_difference(new_schema_names)
if changed_schemas:
self.registered_partitions = {
(col, sch) for col, sch in self.registered_partitions
if sch not in changed_schemas
}
logger.info(f"Cleared partition cache for changed schemas: {changed_schemas}")
logger.info(
f"Cleared partition cache for changed schemas "
f"in {workspace}: {changed_schemas}"
)
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Cassandra compatibility"""
@ -286,7 +302,10 @@ class Processor(CollectionConfigHandler, FlowProcessor):
return index_names
def register_partitions(self, keyspace: str, collection: str, schema_name: str):
def register_partitions(
self, keyspace: str, collection: str, schema_name: str,
workspace: str,
):
"""
Register partition entries for a (collection, schema_name) pair.
Called once on first row for each pair.
@ -295,9 +314,13 @@ class Processor(CollectionConfigHandler, FlowProcessor):
if cache_key in self.registered_partitions:
return
schema = self.schemas.get(schema_name)
ws_schemas = self.schemas.get(workspace, {})
schema = ws_schemas.get(schema_name)
if not schema:
logger.warning(f"Cannot register partitions - schema {schema_name} not found")
logger.warning(
f"Cannot register partitions - schema {schema_name} "
f"not found in workspace {workspace}"
)
return
safe_keyspace = self.sanitize_name(keyspace)
@ -338,13 +361,14 @@ class Processor(CollectionConfigHandler, FlowProcessor):
"""Process incoming ExtractedObject and store in Cassandra"""
obj = msg.value()
workspace = flow.workspace
logger.info(
f"Storing {len(obj.values)} rows for schema {obj.schema_name} "
f"from {obj.metadata.id}"
f"from {obj.metadata.id} (workspace {workspace})"
)
# Validate collection exists before accepting writes
if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
if not self.collection_exists(workspace, obj.metadata.collection):
error_msg = (
f"Collection {obj.metadata.collection} does not exist. "
f"Create it first via collection management API."
@ -352,13 +376,17 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.error(error_msg)
raise ValueError(error_msg)
# Get schema definition
schema = self.schemas.get(obj.schema_name)
# Get schema definition for this workspace
ws_schemas = self.schemas.get(workspace, {})
schema = ws_schemas.get(obj.schema_name)
if not schema:
logger.warning(f"No schema found for {obj.schema_name} - skipping")
logger.warning(
f"No schema found for {obj.schema_name} in "
f"workspace {workspace} - skipping"
)
return
keyspace = obj.metadata.user
keyspace = workspace
collection = obj.metadata.collection
schema_name = obj.schema_name
source = getattr(obj.metadata, 'source', '') or ''
@ -370,7 +398,8 @@ class Processor(CollectionConfigHandler, FlowProcessor):
# Register partitions if first time seeing this (collection, schema_name)
await asyncio.to_thread(
self.register_partitions, keyspace, collection, schema_name
self.register_partitions,
keyspace, collection, schema_name, workspace,
)
safe_keyspace = self.sanitize_name(keyspace)
@ -430,25 +459,25 @@ class Processor(CollectionConfigHandler, FlowProcessor):
f"({len(index_names)} indexes per row)"
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create/verify collection exists in Cassandra row store"""
# Connect if not already connected (sync, push to thread)
await asyncio.to_thread(self.connect_cassandra)
# Ensure tables exist (sync DDL, push to thread)
await asyncio.to_thread(self.ensure_tables, user)
await asyncio.to_thread(self.ensure_tables, workspace)
logger.info(f"Collection {collection} ready for user {user}")
logger.info(f"Collection {collection} ready for workspace {workspace}")
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete all data for a specific collection using partition tracking"""
# Connect if not already connected
await asyncio.to_thread(self.connect_cassandra)
safe_keyspace = self.sanitize_name(user)
safe_keyspace = self.sanitize_name(workspace)
# Check if keyspace exists
if user not in self.known_keyspaces:
if workspace not in self.known_keyspaces:
check_keyspace_cql = """
SELECT keyspace_name FROM system_schema.keyspaces
WHERE keyspace_name = %s
@ -459,7 +488,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
if not result:
logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
return
self.known_keyspaces.add(user)
self.known_keyspaces.add(workspace)
# Discover all partitions for this collection
select_partitions_cql = f"""
@ -522,12 +551,12 @@ class Processor(CollectionConfigHandler, FlowProcessor):
f"from keyspace {safe_keyspace}"
)
async def delete_collection_schema(self, user: str, collection: str, schema_name: str):
async def delete_collection_schema(self, workspace: str, collection: str, schema_name: str):
"""Delete all data for a specific collection + schema combination"""
# Connect if not already connected
await asyncio.to_thread(self.connect_cassandra)
safe_keyspace = self.sanitize_name(user)
safe_keyspace = self.sanitize_name(workspace)
# Discover partitions for this collection + schema
select_partitions_cql = f"""

View file

@ -147,9 +147,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
# Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"])
async def store_triples(self, message):
user = message.metadata.user
async def store_triples(self, workspace, message):
# The cassandra-driver work below — connection, schema
# setup, and per-triple inserts — is all synchronous.
@ -159,7 +157,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
def _do_store():
if self.table is None or self.table != user:
if self.table is None or self.table != workspace:
self.tg = None
@ -170,21 +168,21 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=message.metadata.user,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password,
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=message.metadata.user,
keyspace=workspace,
)
except Exception as e:
logger.error(f"Exception: {e}", exc_info=True)
time.sleep(1)
raise e
self.table = user
self.table = workspace
for t in message.triples:
# Extract values from Term objects
@ -212,12 +210,12 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
await asyncio.to_thread(_do_store)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create a collection in Cassandra triple store via config push"""
def _do_create():
# Create or reuse connection for this user's keyspace
if self.table is None or self.table != user:
# Create or reuse connection for this workspace's keyspace
if self.table is None or self.table != workspace:
self.tg = None
# Use factory function to select implementation
@ -227,23 +225,23 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password,
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
keyspace=workspace,
)
except Exception as e:
logger.error(f"Failed to connect to Cassandra for user {user}: {e}")
logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}")
raise
self.table = user
self.table = workspace
# Create collection using the built-in method
logger.info(f"Creating collection {collection} for user {user}")
logger.info(f"Creating collection {collection} for workspace {workspace}")
if self.tg.collection_exists(collection):
logger.info(f"Collection {collection} already exists")
@ -254,15 +252,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
try:
await asyncio.to_thread(_do_create)
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete all data for a specific collection from the unified triples table"""
def _do_delete():
# Create or reuse connection for this user's keyspace
if self.table is None or self.table != user:
# Create or reuse connection for this workspace's keyspace
if self.table is None or self.table != workspace:
self.tg = None
# Use factory function to select implementation
@ -272,29 +270,29 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password,
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
keyspace=workspace,
)
except Exception as e:
logger.error(f"Failed to connect to Cassandra for user {user}: {e}")
logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}")
raise
self.table = user
self.table = workspace
# Delete all triples for this collection using the built-in method
self.tg.delete_collection(collection)
logger.info(f"Deleted all triples for collection {collection} from keyspace {user}")
logger.info(f"Deleted all triples for collection {collection} from keyspace {workspace}")
try:
await asyncio.to_thread(_do_delete)
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
@staticmethod

View file

@ -59,15 +59,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
# Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"])
def create_node(self, uri, user, collection):
def create_node(self, uri, workspace, collection):
logger.debug(f"Create node {uri} for user={user}, collection={collection}")
logger.debug(f"Create node {uri} for workspace={workspace}, collection={collection}")
res = self.io.query(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
params={
"uri": uri,
"user": user,
"workspace": workspace,
"collection": collection,
},
)
@ -77,15 +77,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=res.run_time_ms
))
def create_literal(self, value, user, collection):
def create_literal(self, value, workspace, collection):
logger.debug(f"Create literal {value} for user={user}, collection={collection}")
logger.debug(f"Create literal {value} for workspace={workspace}, collection={collection}")
res = self.io.query(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
params={
"value": value,
"user": user,
"workspace": workspace,
"collection": collection,
},
)
@ -95,19 +95,19 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=res.run_time_ms
))
def relate_node(self, src, uri, dest, user, collection):
def relate_node(self, src, uri, dest, workspace, collection):
logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}")
logger.debug(f"Create node rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
res = self.io.query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
params={
"src": src,
"dest": dest,
"uri": uri,
"user": user,
"workspace": workspace,
"collection": collection,
},
)
@ -117,19 +117,19 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=res.run_time_ms
))
def relate_literal(self, src, uri, dest, user, collection):
def relate_literal(self, src, uri, dest, workspace, collection):
logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}")
logger.debug(f"Create literal rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
res = self.io.query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
params={
"src": src,
"dest": dest,
"uri": uri,
"user": user,
"workspace": workspace,
"collection": collection,
},
)
@ -139,36 +139,34 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=res.run_time_ms
))
def collection_exists(self, workspace, collection):
    """Check whether a CollectionMetadata node exists for a collection.

    Runs a parameterised Cypher MATCH scoped to the given workspace and
    collection, limited to a single row.

    Args:
        workspace: Value matched against the node's ``workspace`` property.
        collection: Value matched against the node's ``collection`` property.

    Returns:
        True if the query returned at least one row, False otherwise.
    """
    result = self.io.query(
        "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
        "RETURN c LIMIT 1",
        params={"workspace": workspace, "collection": collection}
    )
    # Guard against a missing result set as well as an empty one.
    return result.result_set is not None and len(result.result_set) > 0
def create_collection(self, workspace, collection):
    """Create the CollectionMetadata node for a workspace/collection.

    MERGEs a node keyed on ``workspace`` and ``collection`` and then sets
    its ``created_at`` property to the current local time (ISO-8601).
    The SET is unconditional, so calling this for an existing node
    overwrites ``created_at``.

    NOTE(review): the same class appears to define an async
    ``create_collection(workspace, collection, metadata)`` later in the
    file; if both live in one class the later definition replaces this
    one — confirm which is intended.

    Args:
        workspace: Workspace the collection belongs to.
        collection: Collection name.
    """
    import datetime

    self.io.query(
        "MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
        "SET c.created_at = $created_at",
        params={
            "workspace": workspace,
            "collection": collection,
            "created_at": datetime.datetime.now().isoformat()
        }
    )
    logger.info(f"Created collection metadata node for {workspace}/{collection}")
async def store_triples(self, message):
# Extract user and collection from metadata
user = message.metadata.user if message.metadata.user else "default"
async def store_triples(self, workspace, message):
collection = message.metadata.collection if message.metadata.collection else "default"
# Validate collection exists before accepting writes
if not self.collection_exists(user, collection):
if not self.collection_exists(workspace, collection):
error_msg = (
f"Collection {collection} does not exist. "
f"Create it first via collection management API."
@ -182,14 +180,14 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
p_val = get_term_value(t.p)
o_val = get_term_value(t.o)
self.create_node(s_val, user, collection)
self.create_node(s_val, workspace, collection)
if t.o.type == IRI:
self.create_node(o_val, user, collection)
self.relate_node(s_val, p_val, o_val, user, collection)
self.create_node(o_val, workspace, collection)
self.relate_node(s_val, p_val, o_val, workspace, collection)
else:
self.create_literal(o_val, user, collection)
self.relate_literal(s_val, p_val, o_val, user, collection)
self.create_literal(o_val, workspace, collection)
self.relate_literal(s_val, p_val, o_val, workspace, collection)
@staticmethod
def add_args(parser):
@ -208,58 +206,58 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
help=f'FalkorDB database (default: {default_database})'
)
async def create_collection(self, workspace: str, collection: str, metadata: dict):
    """Create collection metadata in FalkorDB via config push.

    Looks for an existing CollectionMetadata node for the
    workspace/collection pair and only creates one when none is found.

    Args:
        workspace: Workspace the collection belongs to.
        collection: Collection name.
        metadata: Accepted for interface compatibility; not read here.

    Raises:
        Exception: Any error from the database calls is logged with a
            traceback and re-raised unchanged.
    """
    try:
        # Check if collection exists
        result = self.io.query(
            "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) RETURN c LIMIT 1",
            params={"workspace": workspace, "collection": collection}
        )
        if result.result_set:
            logger.info(f"Collection {workspace}/{collection} already exists")
        else:
            # Create collection metadata node
            import datetime
            self.io.query(
                "MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
                "SET c.created_at = $created_at",
                params={
                    "workspace": workspace,
                    "collection": collection,
                    "created_at": datetime.datetime.now().isoformat()
                }
            )
            logger.info(f"Created collection {workspace}/{collection}")
    except Exception as e:
        logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
        raise
async def delete_collection(self, workspace: str, collection: str):
    """Delete the collection for FalkorDB triples via config push.

    Issues three queries in order: delete all ``Node`` vertices, then all
    ``Literal`` vertices (both with DETACH DELETE so attached
    relationships go too), then the CollectionMetadata node.

    Args:
        workspace: Workspace the collection belongs to.
        collection: Collection name.

    Raises:
        Exception: Any error from the database calls is logged with a
            traceback and re-raised unchanged.
    """
    try:
        # Delete all nodes and literals for this workspace/collection
        node_result = self.io.query(
            "MATCH (n:Node {workspace: $workspace, collection: $collection}) DETACH DELETE n",
            params={"workspace": workspace, "collection": collection}
        )
        literal_result = self.io.query(
            "MATCH (n:Literal {workspace: $workspace, collection: $collection}) DETACH DELETE n",
            params={"workspace": workspace, "collection": collection}
        )
        # Delete collection metadata node
        metadata_result = self.io.query(
            "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) DELETE c",
            params={"workspace": workspace, "collection": collection}
        )
        logger.info(f"Deleted {node_result.nodes_deleted} nodes, {literal_result.nodes_deleted} literals, and {metadata_result.nodes_deleted} metadata nodes for collection {workspace}/{collection}")
    except Exception as e:
        logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
        raise
def run():

View file

@ -117,10 +117,10 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
# Maybe index already exists
logger.warning("Index create failure ignored")
# New indexes for user/collection filtering
# New indexes for workspace/collection filtering
try:
session.run(
"CREATE INDEX ON :Node(user)"
"CREATE INDEX ON :Node(workspace)"
)
except Exception as e:
logger.warning(f"User index create failure: {e}")
@ -136,7 +136,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
try:
session.run(
"CREATE INDEX ON :Literal(user)"
"CREATE INDEX ON :Literal(workspace)"
)
except Exception as e:
logger.warning(f"User index create failure: {e}")
@ -152,13 +152,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
logger.info("Index creation done")
def create_node(self, uri, user, collection):
def create_node(self, uri, workspace, collection):
logger.debug(f"Create node {uri} for user={user}, collection={collection}")
logger.debug(f"Create node {uri} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
uri=uri, user=user, collection=collection,
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=uri, workspace=workspace, collection=collection,
database_=self.db,
).summary
@ -167,13 +167,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
def create_literal(self, value, user, collection):
def create_literal(self, value, workspace, collection):
logger.debug(f"Create literal {value} for user={user}, collection={collection}")
logger.debug(f"Create literal {value} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
value=value, user=user, collection=collection,
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value=value, workspace=workspace, collection=collection,
database_=self.db,
).summary
@ -182,15 +182,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
def relate_node(self, src, uri, dest, user, collection):
def relate_node(self, src, uri, dest, workspace, collection):
logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}")
logger.debug(f"Create node rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, user=user, collection=collection,
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, workspace=workspace, collection=collection,
database_=self.db,
).summary
@ -199,15 +199,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
def relate_literal(self, src, uri, dest, user, collection):
def relate_literal(self, src, uri, dest, workspace, collection):
logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}")
logger.debug(f"Create literal rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, user=user, collection=collection,
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, workspace=workspace, collection=collection,
database_=self.db,
).summary
@ -216,7 +216,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
def create_triple(self, tx, t, user, collection):
def create_triple(self, tx, t, workspace, collection):
s_val = get_term_value(t.s)
p_val = get_term_value(t.p)
@ -224,48 +224,46 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
# Create new s node with given uri, if not exists
result = tx.run(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
uri=s_val, user=user, collection=collection
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=s_val, workspace=workspace, collection=collection
)
if t.o.type == IRI:
# Create new o node with given uri, if not exists
result = tx.run(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
uri=o_val, user=user, collection=collection
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=o_val, workspace=workspace, collection=collection
)
result = tx.run(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=s_val, dest=o_val, uri=p_val, user=user, collection=collection,
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=s_val, dest=o_val, uri=p_val, workspace=workspace, collection=collection,
)
else:
# Create new o literal with given uri, if not exists
result = tx.run(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
value=o_val, user=user, collection=collection
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value=o_val, workspace=workspace, collection=collection
)
result = tx.run(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=s_val, dest=o_val, uri=p_val, user=user, collection=collection,
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=s_val, dest=o_val, uri=p_val, workspace=workspace, collection=collection,
)
async def store_triples(self, message):
async def store_triples(self, workspace, message):
# Extract user and collection from metadata
user = message.metadata.user if message.metadata.user else "default"
collection = message.metadata.collection if message.metadata.collection else "default"
# Validate collection exists before accepting writes
if not self.collection_exists(user, collection):
if not self.collection_exists(workspace, collection):
error_msg = (
f"Collection {collection} does not exist. "
f"Create it first via collection management API."
@ -279,18 +277,18 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
p_val = get_term_value(t.p)
o_val = get_term_value(t.o)
self.create_node(s_val, user, collection)
self.create_node(s_val, workspace, collection)
if t.o.type == IRI:
self.create_node(o_val, user, collection)
self.relate_node(s_val, p_val, o_val, user, collection)
self.create_node(o_val, workspace, collection)
self.relate_node(s_val, p_val, o_val, workspace, collection)
else:
self.create_literal(o_val, user, collection)
self.relate_literal(s_val, p_val, o_val, user, collection)
self.create_literal(o_val, workspace, collection)
self.relate_literal(s_val, p_val, o_val, workspace, collection)
# Alternative implementation using transactions
# with self.io.session(database=self.db) as session:
# session.execute_write(self.create_triple, t, user, collection)
# session.execute_write(self.create_triple, t, workspace, collection)
@staticmethod
def add_args(parser):
@ -321,72 +319,72 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
help=f'Memgraph database (default: {default_database})'
)
def _collection_exists_in_db(self, workspace, collection):
    """Check whether a CollectionMetadata node exists for a collection.

    Opens a session on ``self.db`` and runs a parameterised MATCH limited
    to one row.

    Args:
        workspace: Value matched against the node's ``workspace`` property.
        collection: Value matched against the node's ``collection`` property.

    Returns:
        True if the query yielded at least one record, False otherwise.
    """
    with self.io.session(database=self.db) as session:
        result = session.run(
            "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
            "RETURN c LIMIT 1",
            workspace=workspace, collection=collection
        )
        # Consume the records while the session is still open.
        return bool(list(result))
def _create_collection_in_db(self, workspace, collection):
    """Create the CollectionMetadata node for a workspace/collection.

    MERGEs a node keyed on ``workspace`` and ``collection`` and sets its
    ``created_at`` property to the current local time (ISO-8601). The SET
    is unconditional, so an existing node has ``created_at`` overwritten.

    Args:
        workspace: Workspace the collection belongs to.
        collection: Collection name.
    """
    import datetime

    with self.io.session(database=self.db) as session:
        session.run(
            "MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
            "SET c.created_at = $created_at",
            workspace=workspace, collection=collection,
            created_at=datetime.datetime.now().isoformat()
        )
        logger.info(f"Created collection metadata node for {workspace}/{collection}")
async def create_collection(self, workspace: str, collection: str, metadata: dict):
    """Create collection metadata in Memgraph via config push.

    Creates the CollectionMetadata node only when
    ``_collection_exists_in_db`` reports that none exists yet.

    Args:
        workspace: Workspace the collection belongs to.
        collection: Collection name.
        metadata: Accepted for interface compatibility; not read here.

    Raises:
        Exception: Any error from the helpers is logged with a traceback
            and re-raised unchanged.
    """
    try:
        if self._collection_exists_in_db(workspace, collection):
            logger.info(f"Collection {workspace}/{collection} already exists")
        else:
            self._create_collection_in_db(workspace, collection)
            logger.info(f"Created collection {workspace}/{collection}")
    except Exception as e:
        logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
        raise
async def delete_collection(self, workspace: str, collection: str):
    """Delete all data for a specific collection via config push.

    Within one session, deletes the collection's ``Node`` vertices, then
    its ``Literal`` vertices, then its CollectionMetadata node, and logs
    the three deletion counts.

    Args:
        workspace: Workspace the collection belongs to.
        collection: Collection name.

    Raises:
        Exception: Any error from the database calls is logged with a
            traceback and re-raised, matching ``create_collection`` and
            the other graph back-ends in this change.
    """
    try:
        with self.io.session(database=self.db) as session:
            # Delete all nodes for this workspace and collection
            node_result = session.run(
                "MATCH (n:Node {workspace: $workspace, collection: $collection}) "
                "DETACH DELETE n",
                workspace=workspace, collection=collection
            )
            nodes_deleted = node_result.consume().counters.nodes_deleted
            # Delete all literals for this workspace and collection
            literal_result = session.run(
                "MATCH (n:Literal {workspace: $workspace, collection: $collection}) "
                "DETACH DELETE n",
                workspace=workspace, collection=collection
            )
            literals_deleted = literal_result.consume().counters.nodes_deleted
            # Delete collection metadata node
            metadata_result = session.run(
                "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
                "DELETE c",
                workspace=workspace, collection=collection
            )
            metadata_deleted = metadata_result.consume().counters.nodes_deleted
            # Note: Relationships are automatically deleted with DETACH DELETE
            logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {workspace}/{collection}")
    except Exception as e:
        # Previously the failure was only logged, so callers could not
        # tell that the delete had not happened.
        logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
        raise

View file

@ -80,14 +80,12 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
logger.info("Create indexes...")
# Legacy indexes for backwards compatibility
try:
session.run(
"CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
)
except Exception as e:
logger.warning(f"Index create failure: {e}")
# Maybe index already exists
logger.warning("Index create failure ignored")
try:
@ -96,7 +94,6 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
)
except Exception as e:
logger.warning(f"Index create failure: {e}")
# Maybe index already exists
logger.warning("Index create failure ignored")
try:
@ -105,13 +102,11 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
)
except Exception as e:
logger.warning(f"Index create failure: {e}")
# Maybe index already exists
logger.warning("Index create failure ignored")
# New compound indexes for user/collection filtering
try:
session.run(
"CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)",
"CREATE INDEX node_workspace_collection_uri FOR (n:Node) ON (n.workspace, n.collection, n.uri)",
)
except Exception as e:
logger.warning(f"Compound index create failure: {e}")
@ -119,17 +114,16 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
try:
session.run(
"CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)",
"CREATE INDEX literal_workspace_collection_value FOR (n:Literal) ON (n.workspace, n.collection, n.value)",
)
except Exception as e:
logger.warning(f"Compound index create failure: {e}")
logger.warning("Index create failure ignored")
# Note: Neo4j doesn't support compound indexes on relationships in all versions
# Try to create individual indexes on relationship properties
# Neo4j doesn't support compound indexes on relationships in all versions
try:
session.run(
"CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)",
"CREATE INDEX rel_workspace FOR ()-[r:Rel]-() ON (r.workspace)",
)
except Exception as e:
logger.warning(f"Relationship index create failure: {e}")
@ -145,13 +139,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
logger.info("Index creation done")
def create_node(self, uri, user, collection):
def create_node(self, uri, workspace, collection):
logger.debug(f"Create node {uri} for user={user}, collection={collection}")
logger.debug(f"Create node {uri} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
uri=uri, user=user, collection=collection,
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=uri, workspace=workspace, collection=collection,
database_=self.db,
).summary
@ -160,13 +154,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
def create_literal(self, value, user, collection):
def create_literal(self, value, workspace, collection):
logger.debug(f"Create literal {value} for user={user}, collection={collection}")
logger.debug(f"Create literal {value} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
value=value, user=user, collection=collection,
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value=value, workspace=workspace, collection=collection,
database_=self.db,
).summary
@ -175,15 +169,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
def relate_node(self, src, uri, dest, user, collection):
def relate_node(self, src, uri, dest, workspace, collection):
logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}")
logger.debug(f"Create node rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, user=user, collection=collection,
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, workspace=workspace, collection=collection,
database_=self.db,
).summary
@ -192,15 +186,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
def relate_literal(self, src, uri, dest, user, collection):
def relate_literal(self, src, uri, dest, workspace, collection):
logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}")
logger.debug(f"Create literal rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, user=user, collection=collection,
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, workspace=workspace, collection=collection,
database_=self.db,
).summary
@ -209,14 +203,12 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
async def store_triples(self, message):
async def store_triples(self, workspace, message):
# Extract user and collection from metadata
user = message.metadata.user if message.metadata.user else "default"
collection = message.metadata.collection if message.metadata.collection else "default"
# Validate collection exists before accepting writes
if not self.collection_exists(user, collection):
if not self.collection_exists(workspace, collection):
error_msg = (
f"Collection {collection} does not exist. "
f"Create it first via collection management API."
@ -230,14 +222,14 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
p_val = get_term_value(t.p)
o_val = get_term_value(t.o)
self.create_node(s_val, user, collection)
self.create_node(s_val, workspace, collection)
if t.o.type == IRI:
self.create_node(o_val, user, collection)
self.relate_node(s_val, p_val, o_val, user, collection)
self.create_node(o_val, workspace, collection)
self.relate_node(s_val, p_val, o_val, workspace, collection)
else:
self.create_literal(o_val, user, collection)
self.relate_literal(s_val, p_val, o_val, user, collection)
self.create_literal(o_val, workspace, collection)
self.relate_literal(s_val, p_val, o_val, workspace, collection)
@staticmethod
def add_args(parser):
@ -268,75 +260,70 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
help=f'Neo4j database (default: {default_database})'
)
def _collection_exists_in_db(self, workspace, collection):
    """Check whether a CollectionMetadata node exists for a collection.

    Opens a session on ``self.db`` and runs a parameterised MATCH limited
    to one row.

    Args:
        workspace: Value matched against the node's ``workspace`` property.
        collection: Value matched against the node's ``collection`` property.

    Returns:
        True if the query yielded at least one record, False otherwise.
    """
    with self.io.session(database=self.db) as session:
        result = session.run(
            "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
            "RETURN c LIMIT 1",
            workspace=workspace, collection=collection
        )
        # Consume the records while the session is still open.
        return bool(list(result))
def _create_collection_in_db(self, workspace, collection):
    """Create the CollectionMetadata node for a workspace/collection.

    MERGEs a node keyed on ``workspace`` and ``collection`` and sets its
    ``created_at`` property to the current local time (ISO-8601). The SET
    is unconditional, so an existing node has ``created_at`` overwritten.

    Args:
        workspace: Workspace the collection belongs to.
        collection: Collection name.
    """
    import datetime

    with self.io.session(database=self.db) as session:
        session.run(
            "MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
            "SET c.created_at = $created_at",
            workspace=workspace, collection=collection,
            created_at=datetime.datetime.now().isoformat()
        )
        logger.info(f"Created collection metadata node for {workspace}/{collection}")
async def create_collection(self, workspace: str, collection: str, metadata: dict):
    """Create collection metadata in Neo4j via config push.

    Creates the CollectionMetadata node only when
    ``_collection_exists_in_db`` reports that none exists yet.

    Args:
        workspace: Workspace the collection belongs to.
        collection: Collection name.
        metadata: Accepted for interface compatibility; not read here.

    Raises:
        Exception: Any error from the helpers is logged with a traceback
            and re-raised unchanged.
    """
    try:
        if self._collection_exists_in_db(workspace, collection):
            logger.info(f"Collection {workspace}/{collection} already exists")
        else:
            self._create_collection_in_db(workspace, collection)
            logger.info(f"Created collection {workspace}/{collection}")
    except Exception as e:
        logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
        raise
async def delete_collection(self, workspace: str, collection: str):
    """Delete all data for a specific collection via config push.

    Within one session, deletes the collection's ``Node`` vertices, then
    its ``Literal`` vertices, then its CollectionMetadata node, and logs
    the three deletion counts.

    Args:
        workspace: Workspace the collection belongs to.
        collection: Collection name.

    Raises:
        Exception: Any error from the database calls is logged with a
            traceback and re-raised unchanged.
    """
    try:
        with self.io.session(database=self.db) as session:
            # Delete all nodes for this workspace and collection
            node_result = session.run(
                "MATCH (n:Node {workspace: $workspace, collection: $collection}) "
                "DETACH DELETE n",
                workspace=workspace, collection=collection
            )
            nodes_deleted = node_result.consume().counters.nodes_deleted
            # Delete all literals for this workspace and collection
            literal_result = session.run(
                "MATCH (n:Literal {workspace: $workspace, collection: $collection}) "
                "DETACH DELETE n",
                workspace=workspace, collection=collection
            )
            literals_deleted = literal_result.consume().counters.nodes_deleted
            # Note: Relationships are automatically deleted with DETACH DELETE
            # Delete collection metadata node
            metadata_result = session.run(
                "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
                "DELETE c",
                workspace=workspace, collection=collection
            )
            metadata_deleted = metadata_result.consume().counters.nodes_deleted
            logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {workspace}/{collection}")
    except Exception as e:
        logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
        raise
def run():

View file

@ -72,10 +72,11 @@ class ConfigTableStore:
self.cassandra.execute("""
CREATE TABLE IF NOT EXISTS config (
workspace text,
class text,
key text,
value text,
PRIMARY KEY (class, key)
PRIMARY KEY ((workspace, class), key)
);
""");
@ -124,52 +125,63 @@ class ConfigTableStore:
def prepare_statements(self):
    """Prepare the CQL statements used against the ``config`` table.

    Each statement is prepared once through ``self.cassandra.prepare`` and
    stored on the instance. All per-workspace statements bind
    ``workspace`` first, followed by ``class`` and ``key`` where used.

    The table's partition key is ``(workspace, class)``, so the two
    statements that restrict on only one of those columns carry
    ``ALLOW FILTERING``.
    """

    self.put_config_stmt = self.cassandra.prepare("""
        INSERT INTO config ( workspace, class, key, value )
        VALUES (?, ?, ?, ?)
    """)

    self.get_keys_stmt = self.cassandra.prepare("""
        SELECT key FROM config
        WHERE workspace = ? AND class = ?;
    """)

    self.get_value_stmt = self.cassandra.prepare("""
        SELECT value FROM config
        WHERE workspace = ? AND class = ? AND key = ?;
    """)

    self.delete_key_stmt = self.cassandra.prepare("""
        DELETE FROM config
        WHERE workspace = ? AND class = ? AND key = ?;
    """)

    # Unrestricted read of every row in every workspace.
    self.get_all_stmt = self.cassandra.prepare("""
        SELECT workspace, class AS cls, key, value FROM config;
    """)

    # Restricts on workspace only (half of the partition key).
    self.get_all_for_workspace_stmt = self.cassandra.prepare("""
        SELECT class AS cls, key, value FROM config
        WHERE workspace = ?
        ALLOW FILTERING;
    """)

    self.get_values_stmt = self.cassandra.prepare("""
        SELECT key, value FROM config
        WHERE workspace = ? AND class = ?;
    """)

    # Restricts on class only (half of the partition key).
    self.get_values_all_ws_stmt = self.cassandra.prepare("""
        SELECT workspace, key, value FROM config
        WHERE class = ?
        ALLOW FILTERING;
    """)
async def put_config(self, workspace, cls, key, value):
    """Write one config entry for a workspace.

    Executes the prepared ``put_config_stmt`` with the bound values
    ``(workspace, cls, key, value)``.

    Args:
        workspace: Workspace the entry belongs to.
        cls: Config class of the entry.
        key: Key within the class.
        value: Value to store.

    Raises:
        Exception: Any error from the statement execution is logged
            with its traceback and re-raised unchanged.
    """
    try:
        await async_execute(
            self.cassandra,
            self.put_config_stmt,
            (workspace, cls, key, value),
        )
    except Exception:
        logger.error("Exception occurred", exc_info=True)
        raise
async def get_value(self, cls, key):
async def get_value(self, workspace, cls, key):
try:
rows = await async_execute(
self.cassandra,
self.get_value_stmt,
(cls, key),
(workspace, cls, key),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -179,12 +191,12 @@ class ConfigTableStore:
return row[0]
return None
async def get_values(self, cls):
async def get_values(self, workspace, cls):
try:
rows = await async_execute(
self.cassandra,
self.get_values_stmt,
(cls,),
(workspace, cls),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -192,18 +204,20 @@ class ConfigTableStore:
return [[row[0], row[1]] for row in rows]
async def get_values_all_ws(self, cls):
    """Return the entries of one config class across all workspaces.

    Executes the prepared ``get_values_all_ws_stmt`` bound to ``cls``.

    Args:
        cls: Config class to look up.

    Returns:
        A list of ``(workspace, key, value)`` tuples, one per row
        returned by the statement.

    Raises:
        Exception: Any error from the statement execution is logged
            with its traceback and re-raised unchanged.
    """
    try:
        rows = await async_execute(
            self.cassandra,
            self.get_values_all_ws_stmt,
            (cls,),
        )
    except Exception:
        logger.error("Exception occurred", exc_info=True)
        raise

    return [(row[0], row[1], row[2]) for row in rows]
async def get_all(self):
try:
@ -216,14 +230,27 @@ class ConfigTableStore:
logger.error("Exception occurred", exc_info=True)
raise
return [(row[0], row[1], row[2], row[3]) for row in rows]
async def get_all_for_workspace(self, workspace):
    """Fetch every config entry stored for a single workspace.

    Executes the prepared ``get_all_for_workspace_stmt`` bound to
    ``workspace``.

    Args:
        workspace: Workspace whose entries are fetched.

    Returns:
        A list of 3-tuples built from the first three columns of each
        returned row.

    Raises:
        Exception: Any error from the statement execution is logged
            with its traceback and re-raised unchanged.
    """
    params = (workspace,)

    try:
        result = await async_execute(
            self.cassandra, self.get_all_for_workspace_stmt, params,
        )
    except Exception:
        logger.error("Exception occurred", exc_info=True)
        raise

    return [(record[0], record[1], record[2]) for record in result]
async def get_keys(self, cls):
async def get_keys(self, workspace, cls):
try:
rows = await async_execute(
self.cassandra,
self.get_keys_stmt,
(cls,),
(workspace, cls),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -231,12 +258,12 @@ class ConfigTableStore:
return [row[0] for row in rows]
async def delete_key(self, cls, key):
async def delete_key(self, workspace, cls, key):
try:
await async_execute(
self.cassandra,
self.delete_key_stmt,
(cls, key),
(workspace, cls, key),
)
except Exception:
logger.error("Exception occurred", exc_info=True)

View file

@ -88,7 +88,7 @@ class KnowledgeTableStore:
self.cassandra.execute("""
CREATE TABLE IF NOT EXISTS triples (
user text,
workspace text,
document_id text,
id uuid,
time timestamp,
@ -98,7 +98,7 @@ class KnowledgeTableStore:
triples list<tuple<
text, boolean, text, boolean, text, boolean
>>,
PRIMARY KEY ((user, document_id), id)
PRIMARY KEY ((workspace, document_id), id)
);
""");
@ -106,7 +106,7 @@ class KnowledgeTableStore:
self.cassandra.execute("""
create table if not exists graph_embeddings (
user text,
workspace text,
document_id text,
id uuid,
time timestamp,
@ -119,20 +119,20 @@ class KnowledgeTableStore:
list<double>
>
>,
PRIMARY KEY ((user, document_id), id)
PRIMARY KEY ((workspace, document_id), id)
);
""");
self.cassandra.execute("""
CREATE INDEX IF NOT EXISTS graph_embeddings_user ON
graph_embeddings ( user );
CREATE INDEX IF NOT EXISTS graph_embeddings_workspace ON
graph_embeddings ( workspace );
""");
logger.debug("document_embeddings table...")
self.cassandra.execute("""
create table if not exists document_embeddings (
user text,
workspace text,
document_id text,
id uuid,
time timestamp,
@ -145,13 +145,13 @@ class KnowledgeTableStore:
list<double>
>
>,
PRIMARY KEY ((user, document_id), id)
PRIMARY KEY ((workspace, document_id), id)
);
""");
self.cassandra.execute("""
CREATE INDEX IF NOT EXISTS document_embeddings_user ON
document_embeddings ( user );
CREATE INDEX IF NOT EXISTS document_embeddings_workspace ON
document_embeddings ( workspace );
""");
logger.info("Cassandra schema OK.")
@ -161,7 +161,7 @@ class KnowledgeTableStore:
self.insert_triples_stmt = self.cassandra.prepare("""
INSERT INTO triples
(
id, user, document_id,
id, workspace, document_id,
time, metadata, triples
)
VALUES (?, ?, ?, ?, ?, ?)
@ -170,7 +170,7 @@ class KnowledgeTableStore:
self.insert_graph_embeddings_stmt = self.cassandra.prepare("""
INSERT INTO graph_embeddings
(
id, user, document_id, time, metadata, entity_embeddings
id, workspace, document_id, time, metadata, entity_embeddings
)
VALUES (?, ?, ?, ?, ?, ?)
""")
@ -178,45 +178,45 @@ class KnowledgeTableStore:
self.insert_document_embeddings_stmt = self.cassandra.prepare("""
INSERT INTO document_embeddings
(
id, user, document_id, time, metadata, chunks
id, workspace, document_id, time, metadata, chunks
)
VALUES (?, ?, ?, ?, ?, ?)
""")
self.list_cores_stmt = self.cassandra.prepare("""
SELECT DISTINCT user, document_id FROM graph_embeddings
WHERE user = ?
SELECT DISTINCT workspace, document_id FROM graph_embeddings
WHERE workspace = ?
""")
self.get_triples_stmt = self.cassandra.prepare("""
SELECT id, time, metadata, triples
FROM triples
WHERE user = ? AND document_id = ?
WHERE workspace = ? AND document_id = ?
""")
self.get_graph_embeddings_stmt = self.cassandra.prepare("""
SELECT id, time, metadata, entity_embeddings
FROM graph_embeddings
WHERE user = ? AND document_id = ?
WHERE workspace = ? AND document_id = ?
""")
self.get_document_embeddings_stmt = self.cassandra.prepare("""
SELECT id, time, metadata, chunks
FROM document_embeddings
WHERE user = ? AND document_id = ?
WHERE workspace = ? AND document_id = ?
""")
self.delete_triples_stmt = self.cassandra.prepare("""
DELETE FROM triples
WHERE user = ? AND document_id = ?
WHERE workspace = ? AND document_id = ?
""")
self.delete_graph_embeddings_stmt = self.cassandra.prepare("""
DELETE FROM graph_embeddings
WHERE user = ? AND document_id = ?
WHERE workspace = ? AND document_id = ?
""")
async def add_triples(self, m):
async def add_triples(self, workspace, m):
when = int(time.time() * 1000)
@ -232,7 +232,7 @@ class KnowledgeTableStore:
self.cassandra,
self.insert_triples_stmt,
(
uuid.uuid4(), m.metadata.user,
uuid.uuid4(), workspace,
m.metadata.root or m.metadata.id, when,
[], triples,
),
@ -241,7 +241,7 @@ class KnowledgeTableStore:
logger.error("Exception occurred", exc_info=True)
raise
async def add_graph_embeddings(self, m):
async def add_graph_embeddings(self, workspace, m):
when = int(time.time() * 1000)
@ -258,7 +258,7 @@ class KnowledgeTableStore:
self.cassandra,
self.insert_graph_embeddings_stmt,
(
uuid.uuid4(), m.metadata.user,
uuid.uuid4(), workspace,
m.metadata.root or m.metadata.id, when,
[], entities,
),
@ -267,7 +267,7 @@ class KnowledgeTableStore:
logger.error("Exception occurred", exc_info=True)
raise
async def add_document_embeddings(self, m):
async def add_document_embeddings(self, workspace, m):
when = int(time.time() * 1000)
@ -284,7 +284,7 @@ class KnowledgeTableStore:
self.cassandra,
self.insert_document_embeddings_stmt,
(
uuid.uuid4(), m.metadata.user,
uuid.uuid4(), workspace,
m.metadata.root or m.metadata.id, when,
[], chunks,
),
@ -293,7 +293,7 @@ class KnowledgeTableStore:
logger.error("Exception occurred", exc_info=True)
raise
async def list_kg_cores(self, user):
async def list_kg_cores(self, workspace):
logger.debug("List kg cores...")
@ -301,7 +301,7 @@ class KnowledgeTableStore:
rows = await async_execute(
self.cassandra,
self.list_cores_stmt,
(user,),
(workspace,),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -313,7 +313,7 @@ class KnowledgeTableStore:
return lst
async def delete_kg_core(self, user, document_id):
async def delete_kg_core(self, workspace, document_id):
logger.debug("Delete kg cores...")
@ -321,7 +321,7 @@ class KnowledgeTableStore:
await async_execute(
self.cassandra,
self.delete_triples_stmt,
(user, document_id),
(workspace, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -331,13 +331,13 @@ class KnowledgeTableStore:
await async_execute(
self.cassandra,
self.delete_graph_embeddings_stmt,
(user, document_id),
(workspace, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
async def get_triples(self, user, document_id, receiver):
async def get_triples(self, workspace, document_id, receiver):
logger.debug("Get triples...")
@ -345,7 +345,7 @@ class KnowledgeTableStore:
rows = await async_execute(
self.cassandra,
self.get_triples_stmt,
(user, document_id),
(workspace, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -369,7 +369,6 @@ class KnowledgeTableStore:
Triples(
metadata = Metadata(
id = document_id,
user = user,
collection = "default", # FIXME: What to put here?
),
triples = triples
@ -378,7 +377,7 @@ class KnowledgeTableStore:
logger.debug("Done")
async def get_graph_embeddings(self, user, document_id, receiver):
async def get_graph_embeddings(self, workspace, document_id, receiver):
logger.debug("Get GE...")
@ -386,7 +385,7 @@ class KnowledgeTableStore:
rows = await async_execute(
self.cassandra,
self.get_graph_embeddings_stmt,
(user, document_id),
(workspace, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -409,12 +408,11 @@ class KnowledgeTableStore:
GraphEmbeddings(
metadata = Metadata(
id = document_id,
user = user,
collection = "default", # FIXME: What to put here?
),
entities = entities
)
)
)
logger.debug("Done")

View file

@ -64,7 +64,7 @@ class LibraryTableStore:
self.cluster = Cluster(cassandra_host)
self.cassandra = self.cluster.connect()
logger.info("Connected.")
self.ensure_cassandra_schema()
@ -76,13 +76,13 @@ class LibraryTableStore:
logger.debug("Ensure Cassandra schema...")
logger.debug("Keyspace...")
# FIXME: Replication factor should be configurable
self.cassandra.execute(f"""
create keyspace if not exists {self.keyspace}
with replication = {{
'class' : 'SimpleStrategy',
'replication_factor' : 1
with replication = {{
'class' : 'SimpleStrategy',
'replication_factor' : 1
}};
""");
@ -93,7 +93,7 @@ class LibraryTableStore:
self.cassandra.execute("""
CREATE TABLE IF NOT EXISTS document (
id text,
user text,
workspace text,
time timestamp,
kind text,
title text,
@ -103,7 +103,9 @@ class LibraryTableStore:
>>,
tags list<text>,
object_id uuid,
PRIMARY KEY (user, id)
parent_id text,
document_type text,
PRIMARY KEY (workspace, id)
);
""");
@ -114,27 +116,6 @@ class LibraryTableStore:
ON document (object_id)
""");
# Add parent_id and document_type columns for child document support
logger.debug("document table parent_id column...")
try:
self.cassandra.execute("""
ALTER TABLE document ADD parent_id text
""");
except Exception as e:
# Column may already exist
if "already exists" not in str(e).lower() and "Invalid column name" not in str(e):
logger.debug(f"parent_id column may already exist: {e}")
try:
self.cassandra.execute("""
ALTER TABLE document ADD document_type text
""");
except Exception as e:
# Column may already exist
if "already exists" not in str(e).lower() and "Invalid column name" not in str(e):
logger.debug(f"document_type column may already exist: {e}")
logger.debug("document parent index...")
self.cassandra.execute("""
@ -150,10 +131,10 @@ class LibraryTableStore:
document_id text,
time timestamp,
flow text,
user text,
workspace text,
collection text,
tags list<text>,
PRIMARY KEY (user, id)
PRIMARY KEY (workspace, id)
);
""");
@ -162,7 +143,7 @@ class LibraryTableStore:
self.cassandra.execute("""
CREATE TABLE IF NOT EXISTS upload_session (
upload_id text PRIMARY KEY,
user text,
workspace text,
document_id text,
document_metadata text,
s3_upload_id text,
@ -176,11 +157,11 @@ class LibraryTableStore:
) WITH default_time_to_live = 86400;
""");
logger.debug("upload_session user index...")
logger.debug("upload_session workspace index...")
self.cassandra.execute("""
CREATE INDEX IF NOT EXISTS upload_session_user
ON upload_session (user)
CREATE INDEX IF NOT EXISTS upload_session_workspace
ON upload_session (workspace)
""");
logger.info("Cassandra schema OK.")
@ -190,7 +171,7 @@ class LibraryTableStore:
self.insert_document_stmt = self.cassandra.prepare("""
INSERT INTO document
(
id, user, time,
id, workspace, time,
kind, title, comments,
metadata, tags, object_id,
parent_id, document_type
@ -202,25 +183,25 @@ class LibraryTableStore:
UPDATE document
SET time = ?, title = ?, comments = ?,
metadata = ?, tags = ?
WHERE user = ? AND id = ?
WHERE workspace = ? AND id = ?
""")
self.get_document_stmt = self.cassandra.prepare("""
SELECT time, kind, title, comments, metadata, tags, object_id,
parent_id, document_type
FROM document
WHERE user = ? AND id = ?
WHERE workspace = ? AND id = ?
""")
self.delete_document_stmt = self.cassandra.prepare("""
DELETE FROM document
WHERE user = ? AND id = ?
WHERE workspace = ? AND id = ?
""")
self.test_document_exists_stmt = self.cassandra.prepare("""
SELECT id
FROM document
WHERE user = ? AND id = ?
WHERE workspace = ? AND id = ?
LIMIT 1
""")
@ -229,7 +210,7 @@ class LibraryTableStore:
id, time, kind, title, comments, metadata, tags, object_id,
parent_id, document_type
FROM document
WHERE user = ?
WHERE workspace = ?
""")
self.list_document_by_tag_stmt = self.cassandra.prepare("""
@ -237,7 +218,7 @@ class LibraryTableStore:
id, time, kind, title, comments, metadata, tags, object_id,
parent_id, document_type
FROM document
WHERE user = ? AND tags CONTAINS ?
WHERE workspace = ? AND tags CONTAINS ?
ALLOW FILTERING
""")
@ -245,7 +226,7 @@ class LibraryTableStore:
INSERT INTO processing
(
id, document_id, time,
flow, user, collection,
flow, workspace, collection,
tags
)
VALUES (?, ?, ?, ?, ?, ?, ?)
@ -253,13 +234,13 @@ class LibraryTableStore:
self.delete_processing_stmt = self.cassandra.prepare("""
DELETE FROM processing
WHERE user = ? AND id = ?
WHERE workspace = ? AND id = ?
""")
self.test_processing_exists_stmt = self.cassandra.prepare("""
SELECT id
FROM processing
WHERE user = ? AND id = ?
WHERE workspace = ? AND id = ?
LIMIT 1
""")
@ -267,14 +248,14 @@ class LibraryTableStore:
SELECT
id, document_id, time, flow, collection, tags
FROM processing
WHERE user = ?
WHERE workspace = ?
""")
# Upload session prepared statements
self.insert_upload_session_stmt = self.cassandra.prepare("""
INSERT INTO upload_session
(
upload_id, user, document_id, document_metadata,
upload_id, workspace, document_id, document_metadata,
s3_upload_id, object_id, total_size, chunk_size,
total_chunks, chunks_received, created_at, updated_at
)
@ -283,7 +264,7 @@ class LibraryTableStore:
self.get_upload_session_stmt = self.cassandra.prepare("""
SELECT
upload_id, user, document_id, document_metadata,
upload_id, workspace, document_id, document_metadata,
s3_upload_id, object_id, total_size, chunk_size,
total_chunks, chunks_received, created_at, updated_at
FROM upload_session
@ -308,25 +289,25 @@ class LibraryTableStore:
total_size, chunk_size, total_chunks,
chunks_received, created_at, updated_at
FROM upload_session
WHERE user = ?
WHERE workspace = ?
""")
# Child document queries
self.list_children_stmt = self.cassandra.prepare("""
SELECT
id, user, time, kind, title, comments, metadata, tags,
id, workspace, time, kind, title, comments, metadata, tags,
object_id, parent_id, document_type
FROM document
WHERE parent_id = ?
ALLOW FILTERING
""")
async def document_exists(self, workspace, id):
    """Report whether a document row exists in a workspace.

    Executes the prepared ``test_document_exists_stmt`` bound to
    ``(workspace, id)``.

    Args:
        workspace: Workspace to look in.
        id: Document identifier.

    Returns:
        ``True`` when the query result is truthy (at least one row),
        ``False`` otherwise.
    """
    rows = await async_execute(
        self.cassandra,
        self.test_document_exists_stmt,
        (workspace, id),
    )

    return bool(rows)
@ -351,7 +332,7 @@ class LibraryTableStore:
self.cassandra,
self.insert_document_stmt,
(
document.id, document.user, int(document.time * 1000),
document.id, document.workspace, int(document.time * 1000),
document.kind, document.title, document.comments,
metadata, document.tags, object_id,
parent_id, document_type
@ -381,7 +362,7 @@ class LibraryTableStore:
(
int(document.time * 1000), document.title,
document.comments, metadata, document.tags,
document.user, document.id
document.workspace, document.id
),
)
except Exception:
@ -390,7 +371,7 @@ class LibraryTableStore:
logger.debug("Update complete")
async def remove_document(self, user, document_id):
async def remove_document(self, workspace, document_id):
logger.info(f"Removing document {document_id}")
@ -398,7 +379,7 @@ class LibraryTableStore:
await async_execute(
self.cassandra,
self.delete_document_stmt,
(user, document_id),
(workspace, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -406,7 +387,7 @@ class LibraryTableStore:
logger.debug("Delete complete")
async def list_documents(self, user):
async def list_documents(self, workspace):
logger.debug("List documents...")
@ -414,7 +395,7 @@ class LibraryTableStore:
rows = await async_execute(
self.cassandra,
self.list_document_stmt,
(user,),
(workspace,),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -423,7 +404,7 @@ class LibraryTableStore:
lst = [
DocumentMetadata(
id = row[0],
user = user,
workspace = workspace,
time = int(time.mktime(row[1].timetuple())),
kind = row[2],
title = row[3],
@ -465,7 +446,7 @@ class LibraryTableStore:
lst = [
DocumentMetadata(
id = row[0],
user = row[1],
workspace = row[1],
time = int(time.mktime(row[2].timetuple())),
kind = row[3],
title = row[4],
@ -489,7 +470,7 @@ class LibraryTableStore:
return lst
async def get_document(self, user, id):
async def get_document(self, workspace, id):
logger.debug("Get document")
@ -497,7 +478,7 @@ class LibraryTableStore:
rows = await async_execute(
self.cassandra,
self.get_document_stmt,
(user, id),
(workspace, id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -506,7 +487,7 @@ class LibraryTableStore:
for row in rows:
doc = DocumentMetadata(
id = id,
user = user,
workspace = workspace,
time = int(time.mktime(row[0].timetuple())),
kind = row[1],
title = row[2],
@ -529,7 +510,7 @@ class LibraryTableStore:
raise RuntimeError("No such document row?")
async def get_document_object_id(self, user, id):
async def get_document_object_id(self, workspace, id):
logger.debug("Get document obj ID")
@ -537,7 +518,7 @@ class LibraryTableStore:
rows = await async_execute(
self.cassandra,
self.get_document_stmt,
(user, id),
(workspace, id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -549,12 +530,12 @@ class LibraryTableStore:
raise RuntimeError("No such document row?")
async def processing_exists(self, workspace, id):
    """Report whether a processing row exists in a workspace.

    Executes the prepared ``test_processing_exists_stmt`` bound to
    ``(workspace, id)``.

    Args:
        workspace: Workspace to look in.
        id: Processing record identifier.

    Returns:
        ``True`` when the query result is truthy (at least one row),
        ``False`` otherwise.
    """
    rows = await async_execute(
        self.cassandra,
        self.test_processing_exists_stmt,
        (workspace, id),
    )

    return bool(rows)
@ -570,7 +551,7 @@ class LibraryTableStore:
(
processing.id, processing.document_id,
int(processing.time * 1000), processing.flow,
processing.user, processing.collection,
processing.workspace, processing.collection,
processing.tags
),
)
@ -580,7 +561,7 @@ class LibraryTableStore:
logger.debug("Add complete")
async def remove_processing(self, user, processing_id):
async def remove_processing(self, workspace, processing_id):
logger.info(f"Removing processing {processing_id}")
@ -588,7 +569,7 @@ class LibraryTableStore:
await async_execute(
self.cassandra,
self.delete_processing_stmt,
(user, processing_id),
(workspace, processing_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -596,7 +577,7 @@ class LibraryTableStore:
logger.debug("Delete complete")
async def list_processing(self, user):
async def list_processing(self, workspace):
logger.debug("List processing objects")
@ -604,7 +585,7 @@ class LibraryTableStore:
rows = await async_execute(
self.cassandra,
self.list_processing_stmt,
(user,),
(workspace,),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
@ -616,7 +597,7 @@ class LibraryTableStore:
document_id = row[1],
time = int(time.mktime(row[2].timetuple())),
flow = row[3],
user = user,
workspace = workspace,
collection = row[4],
tags = row[5] if row[5] else [],
)
@ -632,7 +613,7 @@ class LibraryTableStore:
async def create_upload_session(
self,
upload_id,
user,
workspace,
document_id,
document_metadata,
s3_upload_id,
@ -652,7 +633,7 @@ class LibraryTableStore:
self.cassandra,
self.insert_upload_session_stmt,
(
upload_id, user, document_id, document_metadata,
upload_id, workspace, document_id, document_metadata,
s3_upload_id, object_id, total_size, chunk_size,
total_chunks, {}, now, now
),
@ -681,7 +662,7 @@ class LibraryTableStore:
for row in rows:
session = {
"upload_id": row[0],
"user": row[1],
"workspace": row[1],
"document_id": row[2],
"document_metadata": row[3],
"s3_upload_id": row[4],
@ -738,16 +719,16 @@ class LibraryTableStore:
logger.debug("Upload session deleted")
async def list_upload_sessions(self, user):
"""List all upload sessions for a user."""
async def list_upload_sessions(self, workspace):
"""List all upload sessions for a workspace."""
logger.debug(f"List upload sessions for {user}")
logger.debug(f"List upload sessions for {workspace}")
try:
rows = await async_execute(
self.cassandra,
self.list_upload_sessions_stmt,
(user,),
(workspace,),
)
except Exception:
logger.error("Exception occurred", exc_info=True)

View file

@ -2,7 +2,6 @@
Joke Tool Service - An example dynamic tool service.
This service demonstrates the tool service integration by:
- Using the 'user' field to personalize responses
- Using config params (style) to customize joke style
- Using arguments (topic) to generate topic-specific jokes
@ -143,17 +142,16 @@ class Processor(DynamicToolService):
super(Processor, self).__init__(**params)
logger.info("Joke service initialized")
async def invoke(self, user, config, arguments):
async def invoke(self, config, arguments):
"""
Generate a joke based on the topic and style.
Args:
user: The user requesting the joke
config: Config values including 'style' (pun, dad-joke, one-liner)
arguments: Arguments including 'topic' (programming, animals, food)
Returns:
A personalized joke string
A joke string
"""
# Get style from config (default: random)
style = config.get("style", random.choice(["pun", "dad-joke", "one-liner"]))
@ -183,10 +181,9 @@ class Processor(DynamicToolService):
# Pick a random joke
joke = random.choice(jokes)
# Personalize the response
response = f"Hey {user}! Here's a {style} for you:\n\n{joke}"
response = f"Here's a {style} for you:\n\n{joke}"
logger.debug(f"Generated joke for user={user}, style={style}, topic={topic}")
logger.debug(f"Generated joke: style={style}, topic={topic}")
return response