From b615150624c1223ffbea506c19fec4516ed50b6e Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Sat, 18 Apr 2026 22:42:24 +0100 Subject: [PATCH 01/21] fix: resolve multiple Kafka backend issues blocking message delivery (#833) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Producer: add delivery callback to surface send errors instead of silently swallowing them, and raise message.max.bytes to 10MB - Consumer: raise fetch.message.max.bytes to 10MB to match producer, tighten session/heartbeat timeouts for fast group joins, and add partition assign/revoke logging for diagnostics - Topic naming: replace colons with dots in topic names since Kafka rejects colons (flow:tg:document-load:default was producing invalid topic name tg.flow.document-load:default) - Response consumers: use auto.offset.reset=earliest instead of latest so responses published before partition assignment aren't lost - UNKNOWN_TOPIC_OR_PART: treat as timeout instead of fatal error so consumers wait for auto-created topics instead of crashing - Concurrency: cap consumer workers to 1 for Kafka since topics have 1 partition — extra consumers trigger rebalance storms that block all message delivery --- trustgraph-base/trustgraph/base/consumer.py | 11 +++ .../trustgraph/base/kafka_backend.py | 70 ++++++++++++++++--- 2 files changed, 72 insertions(+), 9 deletions(-) diff --git a/trustgraph-base/trustgraph/base/consumer.py b/trustgraph-base/trustgraph/base/consumer.py index 1b9e5999..d5c67d1b 100644 --- a/trustgraph-base/trustgraph/base/consumer.py +++ b/trustgraph-base/trustgraph/base/consumer.py @@ -54,6 +54,17 @@ class Consumer: self.running = True self.consumer_task = None + # Kafka topics are created with 1 partition, so multiple + # consumers in the same group causes rebalance storms where + # no consumer can fetch. Cap to the backend's limit. + max_concurrency = getattr(backend, 'max_consumer_concurrency', None) + if max_concurrency is not None and concurrency > max_concurrency: + logger.info( + f"Capping concurrency from {concurrency} to " + f"{max_concurrency} (backend limit)" + ) + concurrency = max_concurrency + self.concurrency = concurrency self.metrics = metrics diff --git a/trustgraph-base/trustgraph/base/kafka_backend.py b/trustgraph-base/trustgraph/base/kafka_backend.py index 8dfe8bfa..cec5f74f 100644 --- a/trustgraph-base/trustgraph/base/kafka_backend.py +++ b/trustgraph-base/trustgraph/base/kafka_backend.py @@ -83,6 +83,7 @@ class KafkaBackendProducer: self._producer = KafkaProducer({ 'bootstrap.servers': bootstrap_servers, 'acks': 'all' if durable else '1', + 'message.max.bytes': 10485760, }) def send(self, message: Any, properties: dict = {}) -> None: @@ -94,13 +95,23 @@ class KafkaBackendProducer: for k, v in properties.items() ] if properties else None + self._delivery_error = None + + def _on_delivery(err, msg): + if err: + self._delivery_error = err + self._producer.produce( topic=self._topic_name, value=json_data, headers=headers, + on_delivery=_on_delivery, ) self._producer.flush() + if self._delivery_error: + raise KafkaException(self._delivery_error) + def flush(self) -> None: self._producer.flush() @@ -126,15 +137,41 @@ class KafkaBackendConsumer: self._consumer = None def _connect(self): + import time + t0 = time.monotonic() + + def _on_assign(consumer, partitions): + elapsed = time.monotonic() - t0 + logger.info( + f"Partition assignment for {self._topic_name}: " + f"{[p.partition for p in partitions]} " + f"after {elapsed:.1f}s" + ) + + def _on_revoke(consumer, partitions): + logger.info( + f"Partition revoke for {self._topic_name}: " + f"{[p.partition for p in partitions]}" + ) + self._consumer = KafkaConsumer({ 'bootstrap.servers': self._bootstrap_servers, 'group.id': self._group_id, 'auto.offset.reset': self._auto_offset_reset, 'enable.auto.commit': False, + 'fetch.message.max.bytes': 10485760, + # Tighten group coordination timeouts for fast + # group join on single-member groups. + 'session.timeout.ms': 6000, + 'heartbeat.interval.ms': 1000, }) - self._consumer.subscribe([self._topic_name]) + self._consumer.subscribe( + [self._topic_name], + on_assign=_on_assign, + on_revoke=_on_revoke, + ) logger.info( - f"Kafka consumer connected: topic={self._topic_name}, " + f"Kafka consumer subscribed: topic={self._topic_name}, " f"group={self._group_id}" ) @@ -151,11 +188,11 @@ class KafkaBackendConsumer: if not self._is_alive(): self._connect() - # Force a partition assignment by polling briefly. - # Without this, the consumer may not be assigned partitions - # until the first real poll(), creating a race where the - # request is sent before assignment completes. - self._consumer.poll(timeout=1.0) + # Kick off group join. With auto.offset.reset=earliest + # on response/notify consumers, any messages published + # before assignment completes will be picked up once + # the consumer starts polling in receive(). + self._consumer.poll(timeout=0.5) def receive(self, timeout_millis: int = 2000) -> Message: """Receive a message. Raises TimeoutError if none available.""" @@ -172,6 +209,8 @@ class KafkaBackendConsumer: error = msg.error() if error.code() == KafkaError._PARTITION_EOF: raise TimeoutError("End of partition reached") + if error.code() == KafkaError.UNKNOWN_TOPIC_OR_PART: + raise TimeoutError("Topic not yet available") raise KafkaException(error) return KafkaMessage(msg, self._schema_cls) @@ -236,6 +275,11 @@ class KafkaBackend: if sasl_password: self._admin_config['sasl.password'] = sasl_password + # Topics are created with 1 partition, so only 1 consumer + # per group can be active. Extra consumers cause rebalance + # storms that block message delivery. + self.max_consumer_concurrency = 1 + logger.info( f"Kafka backend: {bootstrap_servers} " f"protocol={security_protocol}" @@ -270,7 +314,10 @@ class KafkaBackend: f"expected flow, request, response, or notify" ) - topic_name = f"{topicspace}.{cls}.{topic}" + # Replace any remaining colons — flow topics can have + # extra segments (e.g. flow:tg:document-load:default) + # and Kafka rejects colons in topic names. + topic_name = f"{topicspace}.{cls}.{topic}".replace(':', '.') return topic_name, cls, durable @@ -305,8 +352,13 @@ class KafkaBackend: # Per-subscriber: unique group so every instance sees # every message. Filter by correlation ID happens at # the Subscriber layer above. + # Use 'earliest' so that responses published before + # partition assignment completes are not missed. + # Each group is unique (UUID) with no committed offsets, + # so 'earliest' reads from the start of the topic. + # The correlation ID filter discards non-matching messages. group_id = f"{subscription}-{uuid.uuid4()}" - auto_offset_reset = 'latest' + auto_offset_reset = 'earliest' else: # Shared: named group, competing consumers group_id = subscription From 48da6c5f8b2171f67dcec1f52e9e3f83bf411520 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Sat, 18 Apr 2026 23:06:01 +0100 Subject: [PATCH 02/21] Test fixes for Kafka (#834) Test fixes for Kafka: - Consumer: isinstance(max_concurrency, int) instead of is not None: MagicMock won't pass the check - Kafka test: updated expected topic name to tg.request.prompt.default (colons replaced with dots) --- tests/unit/test_pubsub/test_kafka_backend.py | 4 ++-- trustgraph-base/trustgraph/base/consumer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_pubsub/test_kafka_backend.py b/tests/unit/test_pubsub/test_kafka_backend.py index 456386f0..d51b1817 100644 --- a/tests/unit/test_pubsub/test_kafka_backend.py +++ b/tests/unit/test_pubsub/test_kafka_backend.py @@ -57,9 +57,9 @@ class TestKafkaParseTopic: backend._parse_topic('unknown:tg:topic') def test_topic_with_flow_suffix(self, backend): - """Topic names with flow suffix (e.g. :default) are preserved.""" + """Topic names with flow suffix (e.g. :default) have colons replaced with dots.""" name, cls, durable = backend._parse_topic('request:tg:prompt:default') - assert name == 'tg.request.prompt:default' + assert name == 'tg.request.prompt.default' class TestKafkaRetention: diff --git a/trustgraph-base/trustgraph/base/consumer.py b/trustgraph-base/trustgraph/base/consumer.py index d5c67d1b..5c59c515 100644 --- a/trustgraph-base/trustgraph/base/consumer.py +++ b/trustgraph-base/trustgraph/base/consumer.py @@ -58,7 +58,7 @@ class Consumer: # consumers in the same group causes rebalance storms where # no consumer can fetch. Cap to the backend's limit. max_concurrency = getattr(backend, 'max_consumer_concurrency', None) - if max_concurrency is not None and concurrency > max_concurrency: + if isinstance(max_concurrency, int) and concurrency > max_concurrency: logger.info( f"Capping concurrency from {concurrency} to " f"{max_concurrency} (backend limit)" From e7efb673ef1cad3c7f26f50f141bf4e250434db8 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 21 Apr 2026 16:06:41 +0100 Subject: [PATCH 03/21] Structure the tech specs directory (#836) Tech spec some subdirectories for different languages --- docs/tech-specs/{ => ar}/__TEMPLATE.ar.md | 0 docs/tech-specs/{ => ar}/agent-explainability.ar.md | 0 docs/tech-specs/{ => ar}/architecture-principles.ar.md | 0 docs/tech-specs/{ => ar}/cassandra-consolidation.ar.md | 0 docs/tech-specs/{ => ar}/cassandra-performance-refactor.ar.md | 0 docs/tech-specs/{ => ar}/collection-management.ar.md | 0 docs/tech-specs/{ => ar}/document-embeddings-chunk-id.ar.md | 0 docs/tech-specs/{ => ar}/embeddings-batch-processing.ar.md | 0 docs/tech-specs/{ => ar}/entity-centric-graph.ar.md | 0 docs/tech-specs/{ => ar}/explainability-cli.ar.md | 0 docs/tech-specs/{ => ar}/extraction-flows.ar.md | 0 docs/tech-specs/{ => ar}/extraction-provenance-subgraph.ar.md | 0 docs/tech-specs/{ => ar}/extraction-time-provenance.ar.md | 0 docs/tech-specs/{ => ar}/flow-class-definition.ar.md | 0 docs/tech-specs/{ => ar}/flow-configurable-parameters.ar.md | 0 docs/tech-specs/{ => ar}/graph-contexts.ar.md | 0 docs/tech-specs/{ => ar}/graphql-query.ar.md | 0 docs/tech-specs/{ => ar}/graphrag-performance-optimization.ar.md | 0 docs/tech-specs/{ => ar}/import-export-graceful-shutdown.ar.md | 0 docs/tech-specs/{ => ar}/jsonl-prompt-output.ar.md | 0 docs/tech-specs/{ => ar}/large-document-loading.ar.md | 0 docs/tech-specs/{ => ar}/logging-strategy.ar.md | 0 docs/tech-specs/{ => ar}/mcp-tool-arguments.ar.md | 0 docs/tech-specs/{ => ar}/mcp-tool-bearer-token.ar.md | 0 docs/tech-specs/{ => ar}/minio-to-s3-migration.ar.md | 0 docs/tech-specs/{ => ar}/more-config-cli.ar.md | 0 docs/tech-specs/{ => ar}/multi-tenant-support.ar.md | 0 docs/tech-specs/{ => ar}/neo4j-user-collection-isolation.ar.md | 0 docs/tech-specs/{ => ar}/ontology-extract-phase-2.ar.md | 0 docs/tech-specs/{ => ar}/ontology.ar.md | 0 docs/tech-specs/{ => ar}/ontorag.ar.md | 0 docs/tech-specs/{ => ar}/openapi-spec.ar.md | 0 docs/tech-specs/{ => ar}/pubsub.ar.md | 0 docs/tech-specs/{ => ar}/python-api-refactor.ar.md | 0 docs/tech-specs/{ => ar}/query-time-explainability.ar.md | 0 docs/tech-specs/{ => ar}/rag-streaming-support.ar.md | 0 docs/tech-specs/{ => ar}/schema-refactoring-proposal.ar.md | 0 docs/tech-specs/{ => ar}/streaming-llm-responses.ar.md | 0 docs/tech-specs/{ => ar}/structured-data-2.ar.md | 0 docs/tech-specs/{ => ar}/structured-data-descriptor.ar.md | 0 docs/tech-specs/{ => ar}/structured-data-schemas.ar.md | 0 docs/tech-specs/{ => ar}/structured-data.ar.md | 0 docs/tech-specs/{ => ar}/structured-diag-service.ar.md | 0 docs/tech-specs/{ => ar}/tool-group.ar.md | 0 docs/tech-specs/{ => ar}/tool-services.ar.md | 0 docs/tech-specs/{ => ar}/universal-decoder.ar.md | 0 docs/tech-specs/{ => ar}/vector-store-lifecycle.ar.md | 0 docs/tech-specs/{ => es}/__TEMPLATE.es.md | 0 docs/tech-specs/{ => es}/agent-explainability.es.md | 0 docs/tech-specs/{ => es}/architecture-principles.es.md | 0 docs/tech-specs/{ => es}/cassandra-consolidation.es.md | 0 docs/tech-specs/{ => es}/cassandra-performance-refactor.es.md | 0 docs/tech-specs/{ => es}/collection-management.es.md | 0 docs/tech-specs/{ => es}/document-embeddings-chunk-id.es.md | 0 docs/tech-specs/{ => es}/embeddings-batch-processing.es.md | 0 docs/tech-specs/{ => es}/entity-centric-graph.es.md | 0 docs/tech-specs/{ => es}/explainability-cli.es.md | 0 docs/tech-specs/{ => es}/extraction-flows.es.md | 0 docs/tech-specs/{ => es}/extraction-provenance-subgraph.es.md | 0 docs/tech-specs/{ => es}/extraction-time-provenance.es.md | 0 docs/tech-specs/{ => es}/flow-class-definition.es.md | 0 docs/tech-specs/{ => es}/flow-configurable-parameters.es.md | 0 docs/tech-specs/{ => es}/graph-contexts.es.md | 0 docs/tech-specs/{ => es}/graphql-query.es.md | 0 docs/tech-specs/{ => es}/graphrag-performance-optimization.es.md | 0 docs/tech-specs/{ => es}/import-export-graceful-shutdown.es.md | 0 docs/tech-specs/{ => es}/jsonl-prompt-output.es.md | 0 docs/tech-specs/{ => es}/large-document-loading.es.md | 0 docs/tech-specs/{ => es}/logging-strategy.es.md | 0 docs/tech-specs/{ => es}/mcp-tool-arguments.es.md | 0 docs/tech-specs/{ => es}/mcp-tool-bearer-token.es.md | 0 docs/tech-specs/{ => es}/minio-to-s3-migration.es.md | 0 docs/tech-specs/{ => es}/more-config-cli.es.md | 0 docs/tech-specs/{ => es}/multi-tenant-support.es.md | 0 docs/tech-specs/{ => es}/neo4j-user-collection-isolation.es.md | 0 docs/tech-specs/{ => es}/ontology-extract-phase-2.es.md | 0 docs/tech-specs/{ => es}/ontology.es.md | 0 docs/tech-specs/{ => es}/ontorag.es.md | 0 docs/tech-specs/{ => es}/openapi-spec.es.md | 0 docs/tech-specs/{ => es}/pubsub.es.md | 0 docs/tech-specs/{ => es}/python-api-refactor.es.md | 0 docs/tech-specs/{ => es}/query-time-explainability.es.md | 0 docs/tech-specs/{ => es}/rag-streaming-support.es.md | 0 docs/tech-specs/{ => es}/schema-refactoring-proposal.es.md | 0 docs/tech-specs/{ => es}/streaming-llm-responses.es.md | 0 docs/tech-specs/{ => es}/structured-data-2.es.md | 0 docs/tech-specs/{ => es}/structured-data-descriptor.es.md | 0 docs/tech-specs/{ => es}/structured-data-schemas.es.md | 0 docs/tech-specs/{ => es}/structured-data.es.md | 0 docs/tech-specs/{ => es}/structured-diag-service.es.md | 0 docs/tech-specs/{ => es}/tool-group.es.md | 0 docs/tech-specs/{ => es}/tool-services.es.md | 0 docs/tech-specs/{ => es}/universal-decoder.es.md | 0 docs/tech-specs/{ => es}/vector-store-lifecycle.es.md | 0 docs/tech-specs/{ => he}/__TEMPLATE.he.md | 0 docs/tech-specs/{ => he}/agent-explainability.he.md | 0 docs/tech-specs/{ => he}/architecture-principles.he.md | 0 docs/tech-specs/{ => he}/cassandra-consolidation.he.md | 0 docs/tech-specs/{ => he}/cassandra-performance-refactor.he.md | 0 docs/tech-specs/{ => he}/collection-management.he.md | 0 docs/tech-specs/{ => he}/document-embeddings-chunk-id.he.md | 0 docs/tech-specs/{ => he}/embeddings-batch-processing.he.md | 0 docs/tech-specs/{ => he}/entity-centric-graph.he.md | 0 docs/tech-specs/{ => he}/explainability-cli.he.md | 0 docs/tech-specs/{ => he}/extraction-flows.he.md | 0 docs/tech-specs/{ => he}/extraction-provenance-subgraph.he.md | 0 docs/tech-specs/{ => he}/extraction-time-provenance.he.md | 0 docs/tech-specs/{ => he}/flow-class-definition.he.md | 0 docs/tech-specs/{ => he}/flow-configurable-parameters.he.md | 0 docs/tech-specs/{ => he}/graph-contexts.he.md | 0 docs/tech-specs/{ => he}/graphql-query.he.md | 0 docs/tech-specs/{ => he}/graphrag-performance-optimization.he.md | 0 docs/tech-specs/{ => he}/import-export-graceful-shutdown.he.md | 0 docs/tech-specs/{ => he}/jsonl-prompt-output.he.md | 0 docs/tech-specs/{ => he}/large-document-loading.he.md | 0 docs/tech-specs/{ => he}/logging-strategy.he.md | 0 docs/tech-specs/{ => he}/mcp-tool-arguments.he.md | 0 docs/tech-specs/{ => he}/mcp-tool-bearer-token.he.md | 0 docs/tech-specs/{ => he}/minio-to-s3-migration.he.md | 0 docs/tech-specs/{ => he}/more-config-cli.he.md | 0 docs/tech-specs/{ => he}/multi-tenant-support.he.md | 0 docs/tech-specs/{ => he}/neo4j-user-collection-isolation.he.md | 0 docs/tech-specs/{ => he}/ontology-extract-phase-2.he.md | 0 docs/tech-specs/{ => he}/ontology.he.md | 0 docs/tech-specs/{ => he}/ontorag.he.md | 0 docs/tech-specs/{ => he}/openapi-spec.he.md | 0 docs/tech-specs/{ => he}/pubsub.he.md | 0 docs/tech-specs/{ => he}/python-api-refactor.he.md | 0 docs/tech-specs/{ => he}/query-time-explainability.he.md | 0 docs/tech-specs/{ => he}/rag-streaming-support.he.md | 0 docs/tech-specs/{ => he}/schema-refactoring-proposal.he.md | 0 docs/tech-specs/{ => he}/streaming-llm-responses.he.md | 0 docs/tech-specs/{ => he}/structured-data-2.he.md | 0 docs/tech-specs/{ => he}/structured-data-descriptor.he.md | 0 docs/tech-specs/{ => he}/structured-data-schemas.he.md | 0 docs/tech-specs/{ => he}/structured-data.he.md | 0 docs/tech-specs/{ => he}/structured-diag-service.he.md | 0 docs/tech-specs/{ => he}/tool-group.he.md | 0 docs/tech-specs/{ => he}/tool-services.he.md | 0 docs/tech-specs/{ => he}/universal-decoder.he.md | 0 docs/tech-specs/{ => he}/vector-store-lifecycle.he.md | 0 docs/tech-specs/{ => hi}/__TEMPLATE.hi.md | 0 docs/tech-specs/{ => hi}/agent-explainability.hi.md | 0 docs/tech-specs/{ => hi}/architecture-principles.hi.md | 0 docs/tech-specs/{ => hi}/cassandra-consolidation.hi.md | 0 docs/tech-specs/{ => hi}/cassandra-performance-refactor.hi.md | 0 docs/tech-specs/{ => hi}/collection-management.hi.md | 0 docs/tech-specs/{ => hi}/document-embeddings-chunk-id.hi.md | 0 docs/tech-specs/{ => hi}/embeddings-batch-processing.hi.md | 0 docs/tech-specs/{ => hi}/entity-centric-graph.hi.md | 0 docs/tech-specs/{ => hi}/explainability-cli.hi.md | 0 docs/tech-specs/{ => hi}/extraction-flows.hi.md | 0 docs/tech-specs/{ => hi}/extraction-provenance-subgraph.hi.md | 0 docs/tech-specs/{ => hi}/extraction-time-provenance.hi.md | 0 docs/tech-specs/{ => hi}/flow-class-definition.hi.md | 0 docs/tech-specs/{ => hi}/flow-configurable-parameters.hi.md | 0 docs/tech-specs/{ => hi}/graph-contexts.hi.md | 0 docs/tech-specs/{ => hi}/graphql-query.hi.md | 0 docs/tech-specs/{ => hi}/graphrag-performance-optimization.hi.md | 0 docs/tech-specs/{ => hi}/import-export-graceful-shutdown.hi.md | 0 docs/tech-specs/{ => hi}/jsonl-prompt-output.hi.md | 0 docs/tech-specs/{ => hi}/large-document-loading.hi.md | 0 docs/tech-specs/{ => hi}/logging-strategy.hi.md | 0 docs/tech-specs/{ => hi}/mcp-tool-arguments.hi.md | 0 docs/tech-specs/{ => hi}/mcp-tool-bearer-token.hi.md | 0 docs/tech-specs/{ => hi}/minio-to-s3-migration.hi.md | 0 docs/tech-specs/{ => hi}/more-config-cli.hi.md | 0 docs/tech-specs/{ => hi}/multi-tenant-support.hi.md | 0 docs/tech-specs/{ => hi}/neo4j-user-collection-isolation.hi.md | 0 docs/tech-specs/{ => hi}/ontology-extract-phase-2.hi.md | 0 docs/tech-specs/{ => hi}/ontology.hi.md | 0 docs/tech-specs/{ => hi}/ontorag.hi.md | 0 docs/tech-specs/{ => hi}/openapi-spec.hi.md | 0 docs/tech-specs/{ => hi}/pubsub.hi.md | 0 docs/tech-specs/{ => hi}/python-api-refactor.hi.md | 0 docs/tech-specs/{ => hi}/query-time-explainability.hi.md | 0 docs/tech-specs/{ => hi}/rag-streaming-support.hi.md | 0 docs/tech-specs/{ => hi}/schema-refactoring-proposal.hi.md | 0 docs/tech-specs/{ => hi}/streaming-llm-responses.hi.md | 0 docs/tech-specs/{ => hi}/structured-data-2.hi.md | 0 docs/tech-specs/{ => hi}/structured-data-descriptor.hi.md | 0 docs/tech-specs/{ => hi}/structured-data-schemas.hi.md | 0 docs/tech-specs/{ => hi}/structured-data.hi.md | 0 docs/tech-specs/{ => hi}/structured-diag-service.hi.md | 0 docs/tech-specs/{ => hi}/tool-group.hi.md | 0 docs/tech-specs/{ => hi}/tool-services.hi.md | 0 docs/tech-specs/{ => hi}/universal-decoder.hi.md | 0 docs/tech-specs/{ => hi}/vector-store-lifecycle.hi.md | 0 docs/tech-specs/{ => pt}/__TEMPLATE.pt.md | 0 docs/tech-specs/{ => pt}/agent-explainability.pt.md | 0 docs/tech-specs/{ => pt}/architecture-principles.pt.md | 0 docs/tech-specs/{ => pt}/cassandra-consolidation.pt.md | 0 docs/tech-specs/{ => pt}/cassandra-performance-refactor.pt.md | 0 docs/tech-specs/{ => pt}/collection-management.pt.md | 0 docs/tech-specs/{ => pt}/document-embeddings-chunk-id.pt.md | 0 docs/tech-specs/{ => pt}/embeddings-batch-processing.pt.md | 0 docs/tech-specs/{ => pt}/entity-centric-graph.pt.md | 0 docs/tech-specs/{ => pt}/explainability-cli.pt.md | 0 docs/tech-specs/{ => pt}/extraction-flows.pt.md | 0 docs/tech-specs/{ => pt}/extraction-provenance-subgraph.pt.md | 0 docs/tech-specs/{ => pt}/extraction-time-provenance.pt.md | 0 docs/tech-specs/{ => pt}/flow-class-definition.pt.md | 0 docs/tech-specs/{ => pt}/flow-configurable-parameters.pt.md | 0 docs/tech-specs/{ => pt}/graph-contexts.pt.md | 0 docs/tech-specs/{ => pt}/graphql-query.pt.md | 0 docs/tech-specs/{ => pt}/graphrag-performance-optimization.pt.md | 0 docs/tech-specs/{ => pt}/import-export-graceful-shutdown.pt.md | 0 docs/tech-specs/{ => pt}/jsonl-prompt-output.pt.md | 0 docs/tech-specs/{ => pt}/large-document-loading.pt.md | 0 docs/tech-specs/{ => pt}/logging-strategy.pt.md | 0 docs/tech-specs/{ => pt}/mcp-tool-arguments.pt.md | 0 docs/tech-specs/{ => pt}/mcp-tool-bearer-token.pt.md | 0 docs/tech-specs/{ => pt}/minio-to-s3-migration.pt.md | 0 docs/tech-specs/{ => pt}/more-config-cli.pt.md | 0 docs/tech-specs/{ => pt}/multi-tenant-support.pt.md | 0 docs/tech-specs/{ => pt}/neo4j-user-collection-isolation.pt.md | 0 docs/tech-specs/{ => pt}/ontology-extract-phase-2.pt.md | 0 docs/tech-specs/{ => pt}/ontology.pt.md | 0 docs/tech-specs/{ => pt}/ontorag.pt.md | 0 docs/tech-specs/{ => pt}/openapi-spec.pt.md | 0 docs/tech-specs/{ => pt}/pubsub.pt.md | 0 docs/tech-specs/{ => pt}/python-api-refactor.pt.md | 0 docs/tech-specs/{ => pt}/query-time-explainability.pt.md | 0 docs/tech-specs/{ => pt}/rag-streaming-support.pt.md | 0 docs/tech-specs/{ => pt}/schema-refactoring-proposal.pt.md | 0 docs/tech-specs/{ => pt}/streaming-llm-responses.pt.md | 0 docs/tech-specs/{ => pt}/structured-data-2.pt.md | 0 docs/tech-specs/{ => pt}/structured-data-descriptor.pt.md | 0 docs/tech-specs/{ => pt}/structured-data-schemas.pt.md | 0 docs/tech-specs/{ => pt}/structured-data.pt.md | 0 docs/tech-specs/{ => pt}/structured-diag-service.pt.md | 0 docs/tech-specs/{ => pt}/tool-group.pt.md | 0 docs/tech-specs/{ => pt}/tool-services.pt.md | 0 docs/tech-specs/{ => pt}/universal-decoder.pt.md | 0 docs/tech-specs/{ => pt}/vector-store-lifecycle.pt.md | 0 docs/tech-specs/{ => ru}/__TEMPLATE.ru.md | 0 docs/tech-specs/{ => ru}/agent-explainability.ru.md | 0 docs/tech-specs/{ => ru}/architecture-principles.ru.md | 0 docs/tech-specs/{ => ru}/cassandra-consolidation.ru.md | 0 docs/tech-specs/{ => ru}/cassandra-performance-refactor.ru.md | 0 docs/tech-specs/{ => ru}/collection-management.ru.md | 0 docs/tech-specs/{ => ru}/document-embeddings-chunk-id.ru.md | 0 docs/tech-specs/{ => ru}/embeddings-batch-processing.ru.md | 0 docs/tech-specs/{ => ru}/entity-centric-graph.ru.md | 0 docs/tech-specs/{ => ru}/explainability-cli.ru.md | 0 docs/tech-specs/{ => ru}/extraction-flows.ru.md | 0 docs/tech-specs/{ => ru}/extraction-provenance-subgraph.ru.md | 0 docs/tech-specs/{ => ru}/extraction-time-provenance.ru.md | 0 docs/tech-specs/{ => ru}/flow-class-definition.ru.md | 0 docs/tech-specs/{ => ru}/flow-configurable-parameters.ru.md | 0 docs/tech-specs/{ => ru}/graph-contexts.ru.md | 0 docs/tech-specs/{ => ru}/graphql-query.ru.md | 0 docs/tech-specs/{ => ru}/graphrag-performance-optimization.ru.md | 0 docs/tech-specs/{ => ru}/import-export-graceful-shutdown.ru.md | 0 docs/tech-specs/{ => ru}/jsonl-prompt-output.ru.md | 0 docs/tech-specs/{ => ru}/large-document-loading.ru.md | 0 docs/tech-specs/{ => ru}/logging-strategy.ru.md | 0 docs/tech-specs/{ => ru}/mcp-tool-arguments.ru.md | 0 docs/tech-specs/{ => ru}/mcp-tool-bearer-token.ru.md | 0 docs/tech-specs/{ => ru}/minio-to-s3-migration.ru.md | 0 docs/tech-specs/{ => ru}/more-config-cli.ru.md | 0 docs/tech-specs/{ => ru}/multi-tenant-support.ru.md | 0 docs/tech-specs/{ => ru}/neo4j-user-collection-isolation.ru.md | 0 docs/tech-specs/{ => ru}/ontology-extract-phase-2.ru.md | 0 docs/tech-specs/{ => ru}/ontology.ru.md | 0 docs/tech-specs/{ => ru}/ontorag.ru.md | 0 docs/tech-specs/{ => ru}/openapi-spec.ru.md | 0 docs/tech-specs/{ => ru}/pubsub.ru.md | 0 docs/tech-specs/{ => ru}/python-api-refactor.ru.md | 0 docs/tech-specs/{ => ru}/query-time-explainability.ru.md | 0 docs/tech-specs/{ => ru}/rag-streaming-support.ru.md | 0 docs/tech-specs/{ => ru}/schema-refactoring-proposal.ru.md | 0 docs/tech-specs/{ => ru}/streaming-llm-responses.ru.md | 0 docs/tech-specs/{ => ru}/structured-data-2.ru.md | 0 docs/tech-specs/{ => ru}/structured-data-descriptor.ru.md | 0 docs/tech-specs/{ => ru}/structured-data-schemas.ru.md | 0 docs/tech-specs/{ => ru}/structured-data.ru.md | 0 docs/tech-specs/{ => ru}/structured-diag-service.ru.md | 0 docs/tech-specs/{ => ru}/tool-group.ru.md | 0 docs/tech-specs/{ => ru}/tool-services.ru.md | 0 docs/tech-specs/{ => ru}/universal-decoder.ru.md | 0 docs/tech-specs/{ => ru}/vector-store-lifecycle.ru.md | 0 docs/tech-specs/{ => sw}/__TEMPLATE.sw.md | 0 docs/tech-specs/{ => sw}/agent-explainability.sw.md | 0 docs/tech-specs/{ => sw}/architecture-principles.sw.md | 0 docs/tech-specs/{ => sw}/cassandra-consolidation.sw.md | 0 docs/tech-specs/{ => sw}/cassandra-performance-refactor.sw.md | 0 docs/tech-specs/{ => sw}/collection-management.sw.md | 0 docs/tech-specs/{ => sw}/document-embeddings-chunk-id.sw.md | 0 docs/tech-specs/{ => sw}/embeddings-batch-processing.sw.md | 0 docs/tech-specs/{ => sw}/entity-centric-graph.sw.md | 0 docs/tech-specs/{ => sw}/explainability-cli.sw.md | 0 docs/tech-specs/{ => sw}/extraction-flows.sw.md | 0 docs/tech-specs/{ => sw}/extraction-provenance-subgraph.sw.md | 0 docs/tech-specs/{ => sw}/extraction-time-provenance.sw.md | 0 docs/tech-specs/{ => sw}/flow-class-definition.sw.md | 0 docs/tech-specs/{ => sw}/flow-configurable-parameters.sw.md | 0 docs/tech-specs/{ => sw}/graph-contexts.sw.md | 0 docs/tech-specs/{ => sw}/graphql-query.sw.md | 0 docs/tech-specs/{ => sw}/graphrag-performance-optimization.sw.md | 0 docs/tech-specs/{ => sw}/import-export-graceful-shutdown.sw.md | 0 docs/tech-specs/{ => sw}/jsonl-prompt-output.sw.md | 0 docs/tech-specs/{ => sw}/large-document-loading.sw.md | 0 docs/tech-specs/{ => sw}/logging-strategy.sw.md | 0 docs/tech-specs/{ => sw}/mcp-tool-arguments.sw.md | 0 docs/tech-specs/{ => sw}/mcp-tool-bearer-token.sw.md | 0 docs/tech-specs/{ => sw}/minio-to-s3-migration.sw.md | 0 docs/tech-specs/{ => sw}/more-config-cli.sw.md | 0 docs/tech-specs/{ => sw}/multi-tenant-support.sw.md | 0 docs/tech-specs/{ => sw}/neo4j-user-collection-isolation.sw.md | 0 docs/tech-specs/{ => sw}/ontology-extract-phase-2.sw.md | 0 docs/tech-specs/{ => sw}/ontology.sw.md | 0 docs/tech-specs/{ => sw}/ontorag.sw.md | 0 docs/tech-specs/{ => sw}/openapi-spec.sw.md | 0 docs/tech-specs/{ => sw}/pubsub.sw.md | 0 docs/tech-specs/{ => sw}/python-api-refactor.sw.md | 0 docs/tech-specs/{ => sw}/query-time-explainability.sw.md | 0 docs/tech-specs/{ => sw}/rag-streaming-support.sw.md | 0 docs/tech-specs/{ => sw}/schema-refactoring-proposal.sw.md | 0 docs/tech-specs/{ => sw}/streaming-llm-responses.sw.md | 0 docs/tech-specs/{ => sw}/structured-data-2.sw.md | 0 docs/tech-specs/{ => sw}/structured-data-descriptor.sw.md | 0 docs/tech-specs/{ => sw}/structured-data-schemas.sw.md | 0 docs/tech-specs/{ => sw}/structured-data.sw.md | 0 docs/tech-specs/{ => sw}/structured-diag-service.sw.md | 0 docs/tech-specs/{ => sw}/tool-group.sw.md | 0 docs/tech-specs/{ => sw}/tool-services.sw.md | 0 docs/tech-specs/{ => sw}/universal-decoder.sw.md | 0 docs/tech-specs/{ => sw}/vector-store-lifecycle.sw.md | 0 docs/tech-specs/{ => tr}/__TEMPLATE.tr.md | 0 docs/tech-specs/{ => tr}/agent-explainability.tr.md | 0 docs/tech-specs/{ => tr}/architecture-principles.tr.md | 0 docs/tech-specs/{ => tr}/cassandra-consolidation.tr.md | 0 docs/tech-specs/{ => tr}/cassandra-performance-refactor.tr.md | 0 docs/tech-specs/{ => tr}/collection-management.tr.md | 0 docs/tech-specs/{ => tr}/document-embeddings-chunk-id.tr.md | 0 docs/tech-specs/{ => tr}/embeddings-batch-processing.tr.md | 0 docs/tech-specs/{ => tr}/entity-centric-graph.tr.md | 0 docs/tech-specs/{ => tr}/explainability-cli.tr.md | 0 docs/tech-specs/{ => tr}/extraction-flows.tr.md | 0 docs/tech-specs/{ => tr}/extraction-provenance-subgraph.tr.md | 0 docs/tech-specs/{ => tr}/extraction-time-provenance.tr.md | 0 docs/tech-specs/{ => tr}/flow-class-definition.tr.md | 0 docs/tech-specs/{ => tr}/flow-configurable-parameters.tr.md | 0 docs/tech-specs/{ => tr}/graph-contexts.tr.md | 0 docs/tech-specs/{ => tr}/graphql-query.tr.md | 0 docs/tech-specs/{ => tr}/graphrag-performance-optimization.tr.md | 0 docs/tech-specs/{ => tr}/import-export-graceful-shutdown.tr.md | 0 docs/tech-specs/{ => tr}/jsonl-prompt-output.tr.md | 0 docs/tech-specs/{ => tr}/large-document-loading.tr.md | 0 docs/tech-specs/{ => tr}/logging-strategy.tr.md | 0 docs/tech-specs/{ => tr}/mcp-tool-arguments.tr.md | 0 docs/tech-specs/{ => tr}/mcp-tool-bearer-token.tr.md | 0 docs/tech-specs/{ => tr}/minio-to-s3-migration.tr.md | 0 docs/tech-specs/{ => tr}/more-config-cli.tr.md | 0 docs/tech-specs/{ => tr}/multi-tenant-support.tr.md | 0 docs/tech-specs/{ => tr}/neo4j-user-collection-isolation.tr.md | 0 docs/tech-specs/{ => tr}/ontology-extract-phase-2.tr.md | 0 docs/tech-specs/{ => tr}/ontology.tr.md | 0 docs/tech-specs/{ => tr}/ontorag.tr.md | 0 docs/tech-specs/{ => tr}/openapi-spec.tr.md | 0 docs/tech-specs/{ => tr}/pubsub.tr.md | 0 docs/tech-specs/{ => tr}/python-api-refactor.tr.md | 0 docs/tech-specs/{ => tr}/query-time-explainability.tr.md | 0 docs/tech-specs/{ => tr}/rag-streaming-support.tr.md | 0 docs/tech-specs/{ => tr}/schema-refactoring-proposal.tr.md | 0 docs/tech-specs/{ => tr}/streaming-llm-responses.tr.md | 0 docs/tech-specs/{ => tr}/structured-data-2.tr.md | 0 docs/tech-specs/{ => tr}/structured-data-descriptor.tr.md | 0 docs/tech-specs/{ => tr}/structured-data-schemas.tr.md | 0 docs/tech-specs/{ => tr}/structured-data.tr.md | 0 docs/tech-specs/{ => tr}/structured-diag-service.tr.md | 0 docs/tech-specs/{ => tr}/tool-group.tr.md | 0 docs/tech-specs/{ => tr}/tool-services.tr.md | 0 docs/tech-specs/{ => tr}/universal-decoder.tr.md | 0 docs/tech-specs/{ => tr}/vector-store-lifecycle.tr.md | 0 docs/tech-specs/{ => zh-cn}/__TEMPLATE.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/agent-explainability.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/architecture-principles.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/cassandra-consolidation.zh-cn.md | 0 .../{ => zh-cn}/cassandra-performance-refactor.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/collection-management.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/document-embeddings-chunk-id.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/embeddings-batch-processing.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/entity-centric-graph.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/explainability-cli.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/extraction-flows.zh-cn.md | 0 .../{ => zh-cn}/extraction-provenance-subgraph.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/extraction-time-provenance.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/flow-class-definition.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/flow-configurable-parameters.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/graph-contexts.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/graphql-query.zh-cn.md | 0 .../{ => zh-cn}/graphrag-performance-optimization.zh-cn.md | 0 .../{ => zh-cn}/import-export-graceful-shutdown.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/jsonl-prompt-output.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/large-document-loading.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/logging-strategy.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/mcp-tool-arguments.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/mcp-tool-bearer-token.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/minio-to-s3-migration.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/more-config-cli.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/multi-tenant-support.zh-cn.md | 0 .../{ => zh-cn}/neo4j-user-collection-isolation.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/ontology-extract-phase-2.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/ontology.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/ontorag.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/openapi-spec.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/pubsub.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/python-api-refactor.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/query-time-explainability.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/rag-streaming-support.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/schema-refactoring-proposal.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/streaming-llm-responses.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/structured-data-2.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/structured-data-descriptor.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/structured-data-schemas.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/structured-data.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/structured-diag-service.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/tool-group.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/tool-services.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/universal-decoder.zh-cn.md | 0 docs/tech-specs/{ => zh-cn}/vector-store-lifecycle.zh-cn.md | 0 423 files changed, 0 insertions(+), 0 deletions(-) rename docs/tech-specs/{ => ar}/__TEMPLATE.ar.md (100%) rename docs/tech-specs/{ => ar}/agent-explainability.ar.md (100%) rename docs/tech-specs/{ => ar}/architecture-principles.ar.md (100%) rename docs/tech-specs/{ => ar}/cassandra-consolidation.ar.md (100%) rename docs/tech-specs/{ => ar}/cassandra-performance-refactor.ar.md (100%) rename docs/tech-specs/{ => ar}/collection-management.ar.md (100%) rename docs/tech-specs/{ => ar}/document-embeddings-chunk-id.ar.md (100%) rename docs/tech-specs/{ => ar}/embeddings-batch-processing.ar.md (100%) rename docs/tech-specs/{ => ar}/entity-centric-graph.ar.md (100%) rename docs/tech-specs/{ => ar}/explainability-cli.ar.md (100%) rename docs/tech-specs/{ => ar}/extraction-flows.ar.md (100%) rename docs/tech-specs/{ => ar}/extraction-provenance-subgraph.ar.md (100%) rename docs/tech-specs/{ => ar}/extraction-time-provenance.ar.md (100%) rename docs/tech-specs/{ => ar}/flow-class-definition.ar.md (100%) rename docs/tech-specs/{ => ar}/flow-configurable-parameters.ar.md (100%) rename docs/tech-specs/{ => ar}/graph-contexts.ar.md (100%) rename docs/tech-specs/{ => ar}/graphql-query.ar.md (100%) rename docs/tech-specs/{ => ar}/graphrag-performance-optimization.ar.md (100%) rename docs/tech-specs/{ => ar}/import-export-graceful-shutdown.ar.md (100%) rename docs/tech-specs/{ => ar}/jsonl-prompt-output.ar.md (100%) rename docs/tech-specs/{ => ar}/large-document-loading.ar.md (100%) rename docs/tech-specs/{ => ar}/logging-strategy.ar.md (100%) rename docs/tech-specs/{ => ar}/mcp-tool-arguments.ar.md (100%) rename docs/tech-specs/{ => ar}/mcp-tool-bearer-token.ar.md (100%) rename docs/tech-specs/{ => ar}/minio-to-s3-migration.ar.md (100%) rename docs/tech-specs/{ => ar}/more-config-cli.ar.md (100%) rename docs/tech-specs/{ => ar}/multi-tenant-support.ar.md (100%) rename docs/tech-specs/{ => ar}/neo4j-user-collection-isolation.ar.md (100%) rename docs/tech-specs/{ => ar}/ontology-extract-phase-2.ar.md (100%) rename docs/tech-specs/{ => ar}/ontology.ar.md (100%) rename docs/tech-specs/{ => ar}/ontorag.ar.md (100%) rename docs/tech-specs/{ => ar}/openapi-spec.ar.md (100%) rename docs/tech-specs/{ => ar}/pubsub.ar.md (100%) rename docs/tech-specs/{ => ar}/python-api-refactor.ar.md (100%) rename docs/tech-specs/{ => ar}/query-time-explainability.ar.md (100%) rename docs/tech-specs/{ => ar}/rag-streaming-support.ar.md (100%) rename docs/tech-specs/{ => ar}/schema-refactoring-proposal.ar.md (100%) rename docs/tech-specs/{ => ar}/streaming-llm-responses.ar.md (100%) rename docs/tech-specs/{ => ar}/structured-data-2.ar.md (100%) rename docs/tech-specs/{ => ar}/structured-data-descriptor.ar.md (100%) rename docs/tech-specs/{ => ar}/structured-data-schemas.ar.md (100%) rename docs/tech-specs/{ => ar}/structured-data.ar.md (100%) rename docs/tech-specs/{ => ar}/structured-diag-service.ar.md (100%) rename docs/tech-specs/{ => ar}/tool-group.ar.md (100%) rename docs/tech-specs/{ => ar}/tool-services.ar.md (100%) rename docs/tech-specs/{ => ar}/universal-decoder.ar.md (100%) rename docs/tech-specs/{ => ar}/vector-store-lifecycle.ar.md (100%) rename docs/tech-specs/{ => es}/__TEMPLATE.es.md (100%) rename docs/tech-specs/{ => es}/agent-explainability.es.md (100%) rename docs/tech-specs/{ => es}/architecture-principles.es.md (100%) rename docs/tech-specs/{ => es}/cassandra-consolidation.es.md (100%) rename docs/tech-specs/{ => es}/cassandra-performance-refactor.es.md (100%) rename docs/tech-specs/{ => es}/collection-management.es.md (100%) rename docs/tech-specs/{ => es}/document-embeddings-chunk-id.es.md (100%) rename docs/tech-specs/{ => es}/embeddings-batch-processing.es.md (100%) rename docs/tech-specs/{ => es}/entity-centric-graph.es.md (100%) rename docs/tech-specs/{ => es}/explainability-cli.es.md (100%) rename docs/tech-specs/{ => es}/extraction-flows.es.md (100%) rename docs/tech-specs/{ => es}/extraction-provenance-subgraph.es.md (100%) rename docs/tech-specs/{ => es}/extraction-time-provenance.es.md (100%) rename docs/tech-specs/{ => es}/flow-class-definition.es.md (100%) rename docs/tech-specs/{ => es}/flow-configurable-parameters.es.md (100%) rename docs/tech-specs/{ => es}/graph-contexts.es.md (100%) rename docs/tech-specs/{ => es}/graphql-query.es.md (100%) rename docs/tech-specs/{ => es}/graphrag-performance-optimization.es.md (100%) rename docs/tech-specs/{ => es}/import-export-graceful-shutdown.es.md (100%) rename docs/tech-specs/{ => es}/jsonl-prompt-output.es.md (100%) rename docs/tech-specs/{ => es}/large-document-loading.es.md (100%) rename docs/tech-specs/{ => es}/logging-strategy.es.md (100%) rename docs/tech-specs/{ => es}/mcp-tool-arguments.es.md (100%) rename docs/tech-specs/{ => es}/mcp-tool-bearer-token.es.md (100%) rename docs/tech-specs/{ => es}/minio-to-s3-migration.es.md (100%) rename docs/tech-specs/{ => es}/more-config-cli.es.md (100%) rename docs/tech-specs/{ => es}/multi-tenant-support.es.md (100%) rename docs/tech-specs/{ => es}/neo4j-user-collection-isolation.es.md (100%) rename docs/tech-specs/{ => es}/ontology-extract-phase-2.es.md (100%) rename docs/tech-specs/{ => es}/ontology.es.md (100%) rename docs/tech-specs/{ => es}/ontorag.es.md (100%) rename docs/tech-specs/{ => es}/openapi-spec.es.md (100%) rename docs/tech-specs/{ => es}/pubsub.es.md (100%) rename docs/tech-specs/{ => es}/python-api-refactor.es.md (100%) rename docs/tech-specs/{ => es}/query-time-explainability.es.md (100%) rename docs/tech-specs/{ => es}/rag-streaming-support.es.md (100%) rename docs/tech-specs/{ => es}/schema-refactoring-proposal.es.md (100%) rename docs/tech-specs/{ => es}/streaming-llm-responses.es.md (100%) rename docs/tech-specs/{ => es}/structured-data-2.es.md (100%) rename docs/tech-specs/{ => es}/structured-data-descriptor.es.md (100%) rename docs/tech-specs/{ => es}/structured-data-schemas.es.md (100%) rename docs/tech-specs/{ => es}/structured-data.es.md (100%) rename docs/tech-specs/{ => es}/structured-diag-service.es.md (100%) rename docs/tech-specs/{ => es}/tool-group.es.md (100%) rename docs/tech-specs/{ => es}/tool-services.es.md (100%) rename docs/tech-specs/{ => es}/universal-decoder.es.md (100%) rename docs/tech-specs/{ => es}/vector-store-lifecycle.es.md (100%) rename docs/tech-specs/{ => he}/__TEMPLATE.he.md (100%) rename docs/tech-specs/{ => he}/agent-explainability.he.md (100%) rename docs/tech-specs/{ => he}/architecture-principles.he.md (100%) rename docs/tech-specs/{ => he}/cassandra-consolidation.he.md (100%) rename docs/tech-specs/{ => he}/cassandra-performance-refactor.he.md (100%) rename docs/tech-specs/{ => he}/collection-management.he.md (100%) rename docs/tech-specs/{ => he}/document-embeddings-chunk-id.he.md (100%) rename docs/tech-specs/{ => he}/embeddings-batch-processing.he.md (100%) rename docs/tech-specs/{ => he}/entity-centric-graph.he.md (100%) rename docs/tech-specs/{ => he}/explainability-cli.he.md (100%) rename docs/tech-specs/{ => he}/extraction-flows.he.md (100%) rename docs/tech-specs/{ => he}/extraction-provenance-subgraph.he.md (100%) rename docs/tech-specs/{ => he}/extraction-time-provenance.he.md (100%) rename docs/tech-specs/{ => he}/flow-class-definition.he.md (100%) rename docs/tech-specs/{ => he}/flow-configurable-parameters.he.md (100%) rename docs/tech-specs/{ => he}/graph-contexts.he.md (100%) rename docs/tech-specs/{ => he}/graphql-query.he.md (100%) rename docs/tech-specs/{ => he}/graphrag-performance-optimization.he.md (100%) rename docs/tech-specs/{ => he}/import-export-graceful-shutdown.he.md (100%) rename docs/tech-specs/{ => he}/jsonl-prompt-output.he.md (100%) rename docs/tech-specs/{ => he}/large-document-loading.he.md (100%) rename docs/tech-specs/{ => he}/logging-strategy.he.md (100%) rename docs/tech-specs/{ => he}/mcp-tool-arguments.he.md (100%) rename docs/tech-specs/{ => he}/mcp-tool-bearer-token.he.md (100%) rename docs/tech-specs/{ => he}/minio-to-s3-migration.he.md (100%) rename docs/tech-specs/{ => he}/more-config-cli.he.md (100%) rename docs/tech-specs/{ => he}/multi-tenant-support.he.md (100%) rename docs/tech-specs/{ => he}/neo4j-user-collection-isolation.he.md (100%) rename docs/tech-specs/{ => he}/ontology-extract-phase-2.he.md (100%) rename docs/tech-specs/{ => he}/ontology.he.md (100%) rename docs/tech-specs/{ => he}/ontorag.he.md (100%) rename docs/tech-specs/{ => he}/openapi-spec.he.md (100%) rename docs/tech-specs/{ => he}/pubsub.he.md (100%) rename docs/tech-specs/{ => he}/python-api-refactor.he.md (100%) rename docs/tech-specs/{ => he}/query-time-explainability.he.md (100%) rename docs/tech-specs/{ => he}/rag-streaming-support.he.md (100%) rename docs/tech-specs/{ => he}/schema-refactoring-proposal.he.md (100%) rename docs/tech-specs/{ => he}/streaming-llm-responses.he.md (100%) rename docs/tech-specs/{ => he}/structured-data-2.he.md (100%) rename docs/tech-specs/{ => he}/structured-data-descriptor.he.md (100%) rename docs/tech-specs/{ => he}/structured-data-schemas.he.md (100%) rename docs/tech-specs/{ => he}/structured-data.he.md (100%) rename docs/tech-specs/{ => he}/structured-diag-service.he.md (100%) rename docs/tech-specs/{ => he}/tool-group.he.md (100%) rename docs/tech-specs/{ => he}/tool-services.he.md (100%) rename docs/tech-specs/{ => he}/universal-decoder.he.md (100%) rename docs/tech-specs/{ => he}/vector-store-lifecycle.he.md (100%) rename docs/tech-specs/{ => hi}/__TEMPLATE.hi.md (100%) rename docs/tech-specs/{ => hi}/agent-explainability.hi.md (100%) rename docs/tech-specs/{ => hi}/architecture-principles.hi.md (100%) rename docs/tech-specs/{ => hi}/cassandra-consolidation.hi.md (100%) rename docs/tech-specs/{ => hi}/cassandra-performance-refactor.hi.md (100%) rename docs/tech-specs/{ => hi}/collection-management.hi.md (100%) rename docs/tech-specs/{ => hi}/document-embeddings-chunk-id.hi.md (100%) rename docs/tech-specs/{ => hi}/embeddings-batch-processing.hi.md (100%) rename docs/tech-specs/{ => hi}/entity-centric-graph.hi.md (100%) rename docs/tech-specs/{ => hi}/explainability-cli.hi.md (100%) rename docs/tech-specs/{ => hi}/extraction-flows.hi.md (100%) rename docs/tech-specs/{ => hi}/extraction-provenance-subgraph.hi.md (100%) rename docs/tech-specs/{ => hi}/extraction-time-provenance.hi.md (100%) rename docs/tech-specs/{ => hi}/flow-class-definition.hi.md (100%) rename docs/tech-specs/{ => hi}/flow-configurable-parameters.hi.md (100%) rename docs/tech-specs/{ => hi}/graph-contexts.hi.md (100%) rename docs/tech-specs/{ => hi}/graphql-query.hi.md (100%) rename docs/tech-specs/{ => hi}/graphrag-performance-optimization.hi.md (100%) rename docs/tech-specs/{ => hi}/import-export-graceful-shutdown.hi.md (100%) rename docs/tech-specs/{ => hi}/jsonl-prompt-output.hi.md (100%) rename docs/tech-specs/{ => hi}/large-document-loading.hi.md (100%) rename docs/tech-specs/{ => hi}/logging-strategy.hi.md (100%) rename docs/tech-specs/{ => hi}/mcp-tool-arguments.hi.md (100%) rename docs/tech-specs/{ => hi}/mcp-tool-bearer-token.hi.md (100%) rename docs/tech-specs/{ => hi}/minio-to-s3-migration.hi.md (100%) rename docs/tech-specs/{ => hi}/more-config-cli.hi.md (100%) rename docs/tech-specs/{ => hi}/multi-tenant-support.hi.md (100%) rename docs/tech-specs/{ => hi}/neo4j-user-collection-isolation.hi.md (100%) rename docs/tech-specs/{ => hi}/ontology-extract-phase-2.hi.md (100%) rename docs/tech-specs/{ => hi}/ontology.hi.md (100%) rename docs/tech-specs/{ => hi}/ontorag.hi.md (100%) rename docs/tech-specs/{ => hi}/openapi-spec.hi.md (100%) rename docs/tech-specs/{ => hi}/pubsub.hi.md (100%) rename docs/tech-specs/{ => hi}/python-api-refactor.hi.md (100%) rename docs/tech-specs/{ => hi}/query-time-explainability.hi.md (100%) rename docs/tech-specs/{ => hi}/rag-streaming-support.hi.md (100%) rename docs/tech-specs/{ => hi}/schema-refactoring-proposal.hi.md (100%) rename docs/tech-specs/{ => hi}/streaming-llm-responses.hi.md (100%) rename docs/tech-specs/{ => hi}/structured-data-2.hi.md (100%) rename docs/tech-specs/{ => hi}/structured-data-descriptor.hi.md (100%) rename docs/tech-specs/{ => hi}/structured-data-schemas.hi.md (100%) rename docs/tech-specs/{ => hi}/structured-data.hi.md (100%) rename docs/tech-specs/{ => hi}/structured-diag-service.hi.md (100%) rename docs/tech-specs/{ => hi}/tool-group.hi.md (100%) rename docs/tech-specs/{ => hi}/tool-services.hi.md (100%) rename docs/tech-specs/{ => hi}/universal-decoder.hi.md (100%) rename docs/tech-specs/{ => hi}/vector-store-lifecycle.hi.md (100%) rename docs/tech-specs/{ => pt}/__TEMPLATE.pt.md (100%) rename docs/tech-specs/{ => pt}/agent-explainability.pt.md (100%) rename docs/tech-specs/{ => pt}/architecture-principles.pt.md (100%) rename docs/tech-specs/{ => pt}/cassandra-consolidation.pt.md (100%) rename docs/tech-specs/{ => pt}/cassandra-performance-refactor.pt.md (100%) rename docs/tech-specs/{ => pt}/collection-management.pt.md (100%) rename docs/tech-specs/{ => pt}/document-embeddings-chunk-id.pt.md (100%) rename docs/tech-specs/{ => pt}/embeddings-batch-processing.pt.md (100%) rename docs/tech-specs/{ => pt}/entity-centric-graph.pt.md (100%) rename docs/tech-specs/{ => pt}/explainability-cli.pt.md (100%) rename docs/tech-specs/{ => pt}/extraction-flows.pt.md (100%) rename docs/tech-specs/{ => pt}/extraction-provenance-subgraph.pt.md (100%) rename docs/tech-specs/{ => pt}/extraction-time-provenance.pt.md (100%) rename docs/tech-specs/{ => pt}/flow-class-definition.pt.md (100%) rename docs/tech-specs/{ => pt}/flow-configurable-parameters.pt.md (100%) rename docs/tech-specs/{ => pt}/graph-contexts.pt.md (100%) rename docs/tech-specs/{ => pt}/graphql-query.pt.md (100%) rename docs/tech-specs/{ => pt}/graphrag-performance-optimization.pt.md (100%) rename docs/tech-specs/{ => pt}/import-export-graceful-shutdown.pt.md (100%) rename docs/tech-specs/{ => pt}/jsonl-prompt-output.pt.md (100%) rename docs/tech-specs/{ => pt}/large-document-loading.pt.md (100%) rename docs/tech-specs/{ => pt}/logging-strategy.pt.md (100%) rename docs/tech-specs/{ => pt}/mcp-tool-arguments.pt.md (100%) rename docs/tech-specs/{ => pt}/mcp-tool-bearer-token.pt.md (100%) rename docs/tech-specs/{ => pt}/minio-to-s3-migration.pt.md (100%) rename docs/tech-specs/{ => pt}/more-config-cli.pt.md (100%) rename docs/tech-specs/{ => pt}/multi-tenant-support.pt.md (100%) rename docs/tech-specs/{ => pt}/neo4j-user-collection-isolation.pt.md (100%) rename docs/tech-specs/{ => pt}/ontology-extract-phase-2.pt.md (100%) rename docs/tech-specs/{ => pt}/ontology.pt.md (100%) rename docs/tech-specs/{ => pt}/ontorag.pt.md (100%) rename docs/tech-specs/{ => pt}/openapi-spec.pt.md (100%) rename docs/tech-specs/{ => pt}/pubsub.pt.md (100%) rename docs/tech-specs/{ => pt}/python-api-refactor.pt.md (100%) rename docs/tech-specs/{ => pt}/query-time-explainability.pt.md (100%) rename docs/tech-specs/{ => pt}/rag-streaming-support.pt.md (100%) rename docs/tech-specs/{ => pt}/schema-refactoring-proposal.pt.md (100%) rename docs/tech-specs/{ => pt}/streaming-llm-responses.pt.md (100%) rename docs/tech-specs/{ => pt}/structured-data-2.pt.md (100%) rename docs/tech-specs/{ => pt}/structured-data-descriptor.pt.md (100%) rename docs/tech-specs/{ => pt}/structured-data-schemas.pt.md (100%) rename docs/tech-specs/{ => pt}/structured-data.pt.md (100%) rename docs/tech-specs/{ => pt}/structured-diag-service.pt.md (100%) rename docs/tech-specs/{ => pt}/tool-group.pt.md (100%) rename docs/tech-specs/{ => pt}/tool-services.pt.md (100%) rename docs/tech-specs/{ => pt}/universal-decoder.pt.md (100%) rename docs/tech-specs/{ => pt}/vector-store-lifecycle.pt.md (100%) rename docs/tech-specs/{ => ru}/__TEMPLATE.ru.md (100%) rename docs/tech-specs/{ => ru}/agent-explainability.ru.md (100%) rename docs/tech-specs/{ => ru}/architecture-principles.ru.md (100%) rename docs/tech-specs/{ => ru}/cassandra-consolidation.ru.md (100%) rename docs/tech-specs/{ => ru}/cassandra-performance-refactor.ru.md (100%) rename docs/tech-specs/{ => ru}/collection-management.ru.md (100%) rename docs/tech-specs/{ => ru}/document-embeddings-chunk-id.ru.md (100%) rename docs/tech-specs/{ => ru}/embeddings-batch-processing.ru.md (100%) rename docs/tech-specs/{ => ru}/entity-centric-graph.ru.md (100%) rename docs/tech-specs/{ => ru}/explainability-cli.ru.md (100%) rename docs/tech-specs/{ => ru}/extraction-flows.ru.md (100%) rename docs/tech-specs/{ => ru}/extraction-provenance-subgraph.ru.md (100%) rename docs/tech-specs/{ => ru}/extraction-time-provenance.ru.md (100%) rename docs/tech-specs/{ => ru}/flow-class-definition.ru.md (100%) rename docs/tech-specs/{ => ru}/flow-configurable-parameters.ru.md (100%) rename docs/tech-specs/{ => ru}/graph-contexts.ru.md (100%) rename docs/tech-specs/{ => ru}/graphql-query.ru.md (100%) rename docs/tech-specs/{ => ru}/graphrag-performance-optimization.ru.md (100%) rename docs/tech-specs/{ => ru}/import-export-graceful-shutdown.ru.md (100%) rename docs/tech-specs/{ => ru}/jsonl-prompt-output.ru.md (100%) rename docs/tech-specs/{ => ru}/large-document-loading.ru.md (100%) rename docs/tech-specs/{ => ru}/logging-strategy.ru.md (100%) rename docs/tech-specs/{ => ru}/mcp-tool-arguments.ru.md (100%) rename docs/tech-specs/{ => ru}/mcp-tool-bearer-token.ru.md (100%) rename docs/tech-specs/{ => ru}/minio-to-s3-migration.ru.md (100%) rename docs/tech-specs/{ => ru}/more-config-cli.ru.md (100%) rename docs/tech-specs/{ => ru}/multi-tenant-support.ru.md (100%) rename docs/tech-specs/{ => ru}/neo4j-user-collection-isolation.ru.md (100%) rename docs/tech-specs/{ => ru}/ontology-extract-phase-2.ru.md (100%) rename docs/tech-specs/{ => ru}/ontology.ru.md (100%) rename docs/tech-specs/{ => ru}/ontorag.ru.md (100%) rename docs/tech-specs/{ => ru}/openapi-spec.ru.md (100%) rename docs/tech-specs/{ => ru}/pubsub.ru.md (100%) rename docs/tech-specs/{ => ru}/python-api-refactor.ru.md (100%) rename docs/tech-specs/{ => ru}/query-time-explainability.ru.md (100%) rename docs/tech-specs/{ => ru}/rag-streaming-support.ru.md (100%) rename docs/tech-specs/{ => ru}/schema-refactoring-proposal.ru.md (100%) rename docs/tech-specs/{ => ru}/streaming-llm-responses.ru.md (100%) rename docs/tech-specs/{ => ru}/structured-data-2.ru.md (100%) rename docs/tech-specs/{ => ru}/structured-data-descriptor.ru.md (100%) rename docs/tech-specs/{ => ru}/structured-data-schemas.ru.md (100%) rename docs/tech-specs/{ => ru}/structured-data.ru.md (100%) rename docs/tech-specs/{ => ru}/structured-diag-service.ru.md (100%) rename docs/tech-specs/{ => ru}/tool-group.ru.md (100%) rename docs/tech-specs/{ => ru}/tool-services.ru.md (100%) rename docs/tech-specs/{ => ru}/universal-decoder.ru.md (100%) rename docs/tech-specs/{ => ru}/vector-store-lifecycle.ru.md (100%) rename docs/tech-specs/{ => sw}/__TEMPLATE.sw.md (100%) rename docs/tech-specs/{ => sw}/agent-explainability.sw.md (100%) rename docs/tech-specs/{ => sw}/architecture-principles.sw.md (100%) rename docs/tech-specs/{ => sw}/cassandra-consolidation.sw.md (100%) rename docs/tech-specs/{ => sw}/cassandra-performance-refactor.sw.md (100%) rename docs/tech-specs/{ => sw}/collection-management.sw.md (100%) rename docs/tech-specs/{ => sw}/document-embeddings-chunk-id.sw.md (100%) rename docs/tech-specs/{ => sw}/embeddings-batch-processing.sw.md (100%) rename docs/tech-specs/{ => sw}/entity-centric-graph.sw.md (100%) rename docs/tech-specs/{ => sw}/explainability-cli.sw.md (100%) rename docs/tech-specs/{ => sw}/extraction-flows.sw.md (100%) rename docs/tech-specs/{ => sw}/extraction-provenance-subgraph.sw.md (100%) rename docs/tech-specs/{ => sw}/extraction-time-provenance.sw.md (100%) rename docs/tech-specs/{ => sw}/flow-class-definition.sw.md (100%) rename docs/tech-specs/{ => sw}/flow-configurable-parameters.sw.md (100%) rename docs/tech-specs/{ => sw}/graph-contexts.sw.md (100%) rename docs/tech-specs/{ => sw}/graphql-query.sw.md (100%) rename docs/tech-specs/{ => sw}/graphrag-performance-optimization.sw.md (100%) rename docs/tech-specs/{ => sw}/import-export-graceful-shutdown.sw.md (100%) rename docs/tech-specs/{ => sw}/jsonl-prompt-output.sw.md (100%) rename docs/tech-specs/{ => sw}/large-document-loading.sw.md (100%) rename docs/tech-specs/{ => sw}/logging-strategy.sw.md (100%) rename docs/tech-specs/{ => sw}/mcp-tool-arguments.sw.md (100%) rename docs/tech-specs/{ => sw}/mcp-tool-bearer-token.sw.md (100%) rename docs/tech-specs/{ => sw}/minio-to-s3-migration.sw.md (100%) rename docs/tech-specs/{ => sw}/more-config-cli.sw.md (100%) rename docs/tech-specs/{ => sw}/multi-tenant-support.sw.md (100%) rename docs/tech-specs/{ => sw}/neo4j-user-collection-isolation.sw.md (100%) rename docs/tech-specs/{ => sw}/ontology-extract-phase-2.sw.md (100%) rename docs/tech-specs/{ => sw}/ontology.sw.md (100%) rename docs/tech-specs/{ => sw}/ontorag.sw.md (100%) rename docs/tech-specs/{ => sw}/openapi-spec.sw.md (100%) rename docs/tech-specs/{ => sw}/pubsub.sw.md (100%) rename docs/tech-specs/{ => sw}/python-api-refactor.sw.md (100%) rename docs/tech-specs/{ => sw}/query-time-explainability.sw.md (100%) rename docs/tech-specs/{ => sw}/rag-streaming-support.sw.md (100%) rename docs/tech-specs/{ => sw}/schema-refactoring-proposal.sw.md (100%) rename docs/tech-specs/{ => sw}/streaming-llm-responses.sw.md (100%) rename docs/tech-specs/{ => sw}/structured-data-2.sw.md (100%) rename docs/tech-specs/{ => sw}/structured-data-descriptor.sw.md (100%) rename docs/tech-specs/{ => sw}/structured-data-schemas.sw.md (100%) rename docs/tech-specs/{ => sw}/structured-data.sw.md (100%) rename docs/tech-specs/{ => sw}/structured-diag-service.sw.md (100%) rename docs/tech-specs/{ => sw}/tool-group.sw.md (100%) rename docs/tech-specs/{ => sw}/tool-services.sw.md (100%) rename docs/tech-specs/{ => sw}/universal-decoder.sw.md (100%) rename docs/tech-specs/{ => sw}/vector-store-lifecycle.sw.md (100%) rename docs/tech-specs/{ => tr}/__TEMPLATE.tr.md (100%) rename docs/tech-specs/{ => tr}/agent-explainability.tr.md (100%) rename docs/tech-specs/{ => tr}/architecture-principles.tr.md (100%) rename docs/tech-specs/{ => tr}/cassandra-consolidation.tr.md (100%) rename docs/tech-specs/{ => tr}/cassandra-performance-refactor.tr.md (100%) rename docs/tech-specs/{ => tr}/collection-management.tr.md (100%) rename docs/tech-specs/{ => tr}/document-embeddings-chunk-id.tr.md (100%) rename docs/tech-specs/{ => tr}/embeddings-batch-processing.tr.md (100%) rename docs/tech-specs/{ => tr}/entity-centric-graph.tr.md (100%) rename docs/tech-specs/{ => tr}/explainability-cli.tr.md (100%) rename docs/tech-specs/{ => tr}/extraction-flows.tr.md (100%) rename docs/tech-specs/{ => tr}/extraction-provenance-subgraph.tr.md (100%) rename docs/tech-specs/{ => tr}/extraction-time-provenance.tr.md (100%) rename docs/tech-specs/{ => tr}/flow-class-definition.tr.md (100%) rename docs/tech-specs/{ => tr}/flow-configurable-parameters.tr.md (100%) rename docs/tech-specs/{ => tr}/graph-contexts.tr.md (100%) rename docs/tech-specs/{ => tr}/graphql-query.tr.md (100%) rename docs/tech-specs/{ => tr}/graphrag-performance-optimization.tr.md (100%) rename docs/tech-specs/{ => tr}/import-export-graceful-shutdown.tr.md (100%) rename docs/tech-specs/{ => tr}/jsonl-prompt-output.tr.md (100%) rename docs/tech-specs/{ => tr}/large-document-loading.tr.md (100%) rename docs/tech-specs/{ => tr}/logging-strategy.tr.md (100%) rename docs/tech-specs/{ => tr}/mcp-tool-arguments.tr.md (100%) rename docs/tech-specs/{ => tr}/mcp-tool-bearer-token.tr.md (100%) rename docs/tech-specs/{ => tr}/minio-to-s3-migration.tr.md (100%) rename docs/tech-specs/{ => tr}/more-config-cli.tr.md (100%) rename docs/tech-specs/{ => tr}/multi-tenant-support.tr.md (100%) rename docs/tech-specs/{ => tr}/neo4j-user-collection-isolation.tr.md (100%) rename docs/tech-specs/{ => tr}/ontology-extract-phase-2.tr.md (100%) rename docs/tech-specs/{ => tr}/ontology.tr.md (100%) rename docs/tech-specs/{ => tr}/ontorag.tr.md (100%) rename docs/tech-specs/{ => tr}/openapi-spec.tr.md (100%) rename docs/tech-specs/{ => tr}/pubsub.tr.md (100%) rename docs/tech-specs/{ => tr}/python-api-refactor.tr.md (100%) rename docs/tech-specs/{ => tr}/query-time-explainability.tr.md (100%) rename docs/tech-specs/{ => tr}/rag-streaming-support.tr.md (100%) rename docs/tech-specs/{ => tr}/schema-refactoring-proposal.tr.md (100%) rename docs/tech-specs/{ => tr}/streaming-llm-responses.tr.md (100%) rename docs/tech-specs/{ => tr}/structured-data-2.tr.md (100%) rename docs/tech-specs/{ => tr}/structured-data-descriptor.tr.md (100%) rename docs/tech-specs/{ => tr}/structured-data-schemas.tr.md (100%) rename docs/tech-specs/{ => tr}/structured-data.tr.md (100%) rename docs/tech-specs/{ => tr}/structured-diag-service.tr.md (100%) rename docs/tech-specs/{ => tr}/tool-group.tr.md (100%) rename docs/tech-specs/{ => tr}/tool-services.tr.md (100%) rename docs/tech-specs/{ => tr}/universal-decoder.tr.md (100%) rename docs/tech-specs/{ => tr}/vector-store-lifecycle.tr.md (100%) rename docs/tech-specs/{ => zh-cn}/__TEMPLATE.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/agent-explainability.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/architecture-principles.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/cassandra-consolidation.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/cassandra-performance-refactor.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/collection-management.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/document-embeddings-chunk-id.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/embeddings-batch-processing.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/entity-centric-graph.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/explainability-cli.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/extraction-flows.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/extraction-provenance-subgraph.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/extraction-time-provenance.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/flow-class-definition.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/flow-configurable-parameters.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/graph-contexts.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/graphql-query.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/graphrag-performance-optimization.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/import-export-graceful-shutdown.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/jsonl-prompt-output.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/large-document-loading.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/logging-strategy.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/mcp-tool-arguments.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/mcp-tool-bearer-token.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/minio-to-s3-migration.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/more-config-cli.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/multi-tenant-support.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/neo4j-user-collection-isolation.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/ontology-extract-phase-2.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/ontology.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/ontorag.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/openapi-spec.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/pubsub.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/python-api-refactor.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/query-time-explainability.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/rag-streaming-support.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/schema-refactoring-proposal.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/streaming-llm-responses.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/structured-data-2.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/structured-data-descriptor.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/structured-data-schemas.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/structured-data.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/structured-diag-service.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/tool-group.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/tool-services.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/universal-decoder.zh-cn.md (100%) rename docs/tech-specs/{ => zh-cn}/vector-store-lifecycle.zh-cn.md (100%) diff --git a/docs/tech-specs/__TEMPLATE.ar.md b/docs/tech-specs/ar/__TEMPLATE.ar.md similarity index 100% rename from docs/tech-specs/__TEMPLATE.ar.md rename to docs/tech-specs/ar/__TEMPLATE.ar.md diff --git a/docs/tech-specs/agent-explainability.ar.md b/docs/tech-specs/ar/agent-explainability.ar.md similarity index 100% rename from docs/tech-specs/agent-explainability.ar.md rename to docs/tech-specs/ar/agent-explainability.ar.md diff --git a/docs/tech-specs/architecture-principles.ar.md b/docs/tech-specs/ar/architecture-principles.ar.md similarity index 100% rename from docs/tech-specs/architecture-principles.ar.md rename to docs/tech-specs/ar/architecture-principles.ar.md diff --git a/docs/tech-specs/cassandra-consolidation.ar.md b/docs/tech-specs/ar/cassandra-consolidation.ar.md similarity index 100% rename from docs/tech-specs/cassandra-consolidation.ar.md rename to docs/tech-specs/ar/cassandra-consolidation.ar.md diff --git a/docs/tech-specs/cassandra-performance-refactor.ar.md b/docs/tech-specs/ar/cassandra-performance-refactor.ar.md similarity index 100% rename from docs/tech-specs/cassandra-performance-refactor.ar.md rename to docs/tech-specs/ar/cassandra-performance-refactor.ar.md diff --git a/docs/tech-specs/collection-management.ar.md b/docs/tech-specs/ar/collection-management.ar.md similarity index 100% rename from docs/tech-specs/collection-management.ar.md rename to docs/tech-specs/ar/collection-management.ar.md diff --git a/docs/tech-specs/document-embeddings-chunk-id.ar.md b/docs/tech-specs/ar/document-embeddings-chunk-id.ar.md similarity index 100% rename from docs/tech-specs/document-embeddings-chunk-id.ar.md rename to docs/tech-specs/ar/document-embeddings-chunk-id.ar.md diff --git a/docs/tech-specs/embeddings-batch-processing.ar.md b/docs/tech-specs/ar/embeddings-batch-processing.ar.md similarity index 100% rename from docs/tech-specs/embeddings-batch-processing.ar.md rename to docs/tech-specs/ar/embeddings-batch-processing.ar.md diff --git a/docs/tech-specs/entity-centric-graph.ar.md b/docs/tech-specs/ar/entity-centric-graph.ar.md similarity index 100% rename from docs/tech-specs/entity-centric-graph.ar.md rename to docs/tech-specs/ar/entity-centric-graph.ar.md diff --git a/docs/tech-specs/explainability-cli.ar.md b/docs/tech-specs/ar/explainability-cli.ar.md similarity index 100% rename from docs/tech-specs/explainability-cli.ar.md rename to docs/tech-specs/ar/explainability-cli.ar.md diff --git a/docs/tech-specs/extraction-flows.ar.md b/docs/tech-specs/ar/extraction-flows.ar.md similarity index 100% rename from docs/tech-specs/extraction-flows.ar.md rename to docs/tech-specs/ar/extraction-flows.ar.md diff --git a/docs/tech-specs/extraction-provenance-subgraph.ar.md b/docs/tech-specs/ar/extraction-provenance-subgraph.ar.md similarity index 100% rename from docs/tech-specs/extraction-provenance-subgraph.ar.md rename to docs/tech-specs/ar/extraction-provenance-subgraph.ar.md diff --git a/docs/tech-specs/extraction-time-provenance.ar.md b/docs/tech-specs/ar/extraction-time-provenance.ar.md similarity index 100% rename from docs/tech-specs/extraction-time-provenance.ar.md rename to docs/tech-specs/ar/extraction-time-provenance.ar.md diff --git a/docs/tech-specs/flow-class-definition.ar.md b/docs/tech-specs/ar/flow-class-definition.ar.md similarity index 100% rename from docs/tech-specs/flow-class-definition.ar.md rename to docs/tech-specs/ar/flow-class-definition.ar.md diff --git a/docs/tech-specs/flow-configurable-parameters.ar.md b/docs/tech-specs/ar/flow-configurable-parameters.ar.md similarity index 100% rename from docs/tech-specs/flow-configurable-parameters.ar.md rename to docs/tech-specs/ar/flow-configurable-parameters.ar.md diff --git a/docs/tech-specs/graph-contexts.ar.md b/docs/tech-specs/ar/graph-contexts.ar.md similarity index 100% rename from docs/tech-specs/graph-contexts.ar.md rename to docs/tech-specs/ar/graph-contexts.ar.md diff --git a/docs/tech-specs/graphql-query.ar.md b/docs/tech-specs/ar/graphql-query.ar.md similarity index 100% rename from docs/tech-specs/graphql-query.ar.md rename to docs/tech-specs/ar/graphql-query.ar.md diff --git a/docs/tech-specs/graphrag-performance-optimization.ar.md b/docs/tech-specs/ar/graphrag-performance-optimization.ar.md similarity index 100% rename from docs/tech-specs/graphrag-performance-optimization.ar.md rename to docs/tech-specs/ar/graphrag-performance-optimization.ar.md diff --git a/docs/tech-specs/import-export-graceful-shutdown.ar.md b/docs/tech-specs/ar/import-export-graceful-shutdown.ar.md similarity index 100% rename from docs/tech-specs/import-export-graceful-shutdown.ar.md rename to docs/tech-specs/ar/import-export-graceful-shutdown.ar.md diff --git a/docs/tech-specs/jsonl-prompt-output.ar.md b/docs/tech-specs/ar/jsonl-prompt-output.ar.md similarity index 100% rename from docs/tech-specs/jsonl-prompt-output.ar.md rename to docs/tech-specs/ar/jsonl-prompt-output.ar.md diff --git a/docs/tech-specs/large-document-loading.ar.md b/docs/tech-specs/ar/large-document-loading.ar.md similarity index 100% rename from docs/tech-specs/large-document-loading.ar.md rename to docs/tech-specs/ar/large-document-loading.ar.md diff --git a/docs/tech-specs/logging-strategy.ar.md b/docs/tech-specs/ar/logging-strategy.ar.md similarity index 100% rename from docs/tech-specs/logging-strategy.ar.md rename to docs/tech-specs/ar/logging-strategy.ar.md diff --git a/docs/tech-specs/mcp-tool-arguments.ar.md b/docs/tech-specs/ar/mcp-tool-arguments.ar.md similarity index 100% rename from docs/tech-specs/mcp-tool-arguments.ar.md rename to docs/tech-specs/ar/mcp-tool-arguments.ar.md diff --git a/docs/tech-specs/mcp-tool-bearer-token.ar.md b/docs/tech-specs/ar/mcp-tool-bearer-token.ar.md similarity index 100% rename from docs/tech-specs/mcp-tool-bearer-token.ar.md rename to docs/tech-specs/ar/mcp-tool-bearer-token.ar.md diff --git a/docs/tech-specs/minio-to-s3-migration.ar.md b/docs/tech-specs/ar/minio-to-s3-migration.ar.md similarity index 100% rename from docs/tech-specs/minio-to-s3-migration.ar.md rename to docs/tech-specs/ar/minio-to-s3-migration.ar.md diff --git a/docs/tech-specs/more-config-cli.ar.md b/docs/tech-specs/ar/more-config-cli.ar.md similarity index 100% rename from docs/tech-specs/more-config-cli.ar.md rename to docs/tech-specs/ar/more-config-cli.ar.md diff --git a/docs/tech-specs/multi-tenant-support.ar.md b/docs/tech-specs/ar/multi-tenant-support.ar.md similarity index 100% rename from docs/tech-specs/multi-tenant-support.ar.md rename to docs/tech-specs/ar/multi-tenant-support.ar.md diff --git a/docs/tech-specs/neo4j-user-collection-isolation.ar.md b/docs/tech-specs/ar/neo4j-user-collection-isolation.ar.md similarity index 100% rename from docs/tech-specs/neo4j-user-collection-isolation.ar.md rename to docs/tech-specs/ar/neo4j-user-collection-isolation.ar.md diff --git a/docs/tech-specs/ontology-extract-phase-2.ar.md b/docs/tech-specs/ar/ontology-extract-phase-2.ar.md similarity index 100% rename from docs/tech-specs/ontology-extract-phase-2.ar.md rename to docs/tech-specs/ar/ontology-extract-phase-2.ar.md diff --git a/docs/tech-specs/ontology.ar.md b/docs/tech-specs/ar/ontology.ar.md similarity index 100% rename from docs/tech-specs/ontology.ar.md rename to docs/tech-specs/ar/ontology.ar.md diff --git a/docs/tech-specs/ontorag.ar.md b/docs/tech-specs/ar/ontorag.ar.md similarity index 100% rename from docs/tech-specs/ontorag.ar.md rename to docs/tech-specs/ar/ontorag.ar.md diff --git a/docs/tech-specs/openapi-spec.ar.md b/docs/tech-specs/ar/openapi-spec.ar.md similarity index 100% rename from docs/tech-specs/openapi-spec.ar.md rename to docs/tech-specs/ar/openapi-spec.ar.md diff --git a/docs/tech-specs/pubsub.ar.md b/docs/tech-specs/ar/pubsub.ar.md similarity index 100% rename from docs/tech-specs/pubsub.ar.md rename to docs/tech-specs/ar/pubsub.ar.md diff --git a/docs/tech-specs/python-api-refactor.ar.md b/docs/tech-specs/ar/python-api-refactor.ar.md similarity index 100% rename from docs/tech-specs/python-api-refactor.ar.md rename to docs/tech-specs/ar/python-api-refactor.ar.md diff --git a/docs/tech-specs/query-time-explainability.ar.md b/docs/tech-specs/ar/query-time-explainability.ar.md similarity index 100% rename from docs/tech-specs/query-time-explainability.ar.md rename to docs/tech-specs/ar/query-time-explainability.ar.md diff --git a/docs/tech-specs/rag-streaming-support.ar.md b/docs/tech-specs/ar/rag-streaming-support.ar.md similarity index 100% rename from docs/tech-specs/rag-streaming-support.ar.md rename to docs/tech-specs/ar/rag-streaming-support.ar.md diff --git a/docs/tech-specs/schema-refactoring-proposal.ar.md b/docs/tech-specs/ar/schema-refactoring-proposal.ar.md similarity index 100% rename from docs/tech-specs/schema-refactoring-proposal.ar.md rename to docs/tech-specs/ar/schema-refactoring-proposal.ar.md diff --git a/docs/tech-specs/streaming-llm-responses.ar.md b/docs/tech-specs/ar/streaming-llm-responses.ar.md similarity index 100% rename from docs/tech-specs/streaming-llm-responses.ar.md rename to docs/tech-specs/ar/streaming-llm-responses.ar.md diff --git a/docs/tech-specs/structured-data-2.ar.md b/docs/tech-specs/ar/structured-data-2.ar.md similarity index 100% rename from docs/tech-specs/structured-data-2.ar.md rename to docs/tech-specs/ar/structured-data-2.ar.md diff --git a/docs/tech-specs/structured-data-descriptor.ar.md b/docs/tech-specs/ar/structured-data-descriptor.ar.md similarity index 100% rename from docs/tech-specs/structured-data-descriptor.ar.md rename to docs/tech-specs/ar/structured-data-descriptor.ar.md diff --git a/docs/tech-specs/structured-data-schemas.ar.md b/docs/tech-specs/ar/structured-data-schemas.ar.md similarity index 100% rename from docs/tech-specs/structured-data-schemas.ar.md rename to docs/tech-specs/ar/structured-data-schemas.ar.md diff --git a/docs/tech-specs/structured-data.ar.md b/docs/tech-specs/ar/structured-data.ar.md similarity index 100% rename from docs/tech-specs/structured-data.ar.md rename to docs/tech-specs/ar/structured-data.ar.md diff --git a/docs/tech-specs/structured-diag-service.ar.md b/docs/tech-specs/ar/structured-diag-service.ar.md similarity index 100% rename from docs/tech-specs/structured-diag-service.ar.md rename to docs/tech-specs/ar/structured-diag-service.ar.md diff --git a/docs/tech-specs/tool-group.ar.md b/docs/tech-specs/ar/tool-group.ar.md similarity index 100% rename from docs/tech-specs/tool-group.ar.md rename to docs/tech-specs/ar/tool-group.ar.md diff --git a/docs/tech-specs/tool-services.ar.md b/docs/tech-specs/ar/tool-services.ar.md similarity index 100% rename from docs/tech-specs/tool-services.ar.md rename to docs/tech-specs/ar/tool-services.ar.md diff --git a/docs/tech-specs/universal-decoder.ar.md b/docs/tech-specs/ar/universal-decoder.ar.md similarity index 100% rename from docs/tech-specs/universal-decoder.ar.md rename to docs/tech-specs/ar/universal-decoder.ar.md diff --git a/docs/tech-specs/vector-store-lifecycle.ar.md b/docs/tech-specs/ar/vector-store-lifecycle.ar.md similarity index 100% rename from docs/tech-specs/vector-store-lifecycle.ar.md rename to docs/tech-specs/ar/vector-store-lifecycle.ar.md diff --git a/docs/tech-specs/__TEMPLATE.es.md b/docs/tech-specs/es/__TEMPLATE.es.md similarity index 100% rename from docs/tech-specs/__TEMPLATE.es.md rename to docs/tech-specs/es/__TEMPLATE.es.md diff --git a/docs/tech-specs/agent-explainability.es.md b/docs/tech-specs/es/agent-explainability.es.md similarity index 100% rename from docs/tech-specs/agent-explainability.es.md rename to docs/tech-specs/es/agent-explainability.es.md diff --git a/docs/tech-specs/architecture-principles.es.md b/docs/tech-specs/es/architecture-principles.es.md similarity index 100% rename from docs/tech-specs/architecture-principles.es.md rename to docs/tech-specs/es/architecture-principles.es.md diff --git a/docs/tech-specs/cassandra-consolidation.es.md b/docs/tech-specs/es/cassandra-consolidation.es.md similarity index 100% rename from docs/tech-specs/cassandra-consolidation.es.md rename to docs/tech-specs/es/cassandra-consolidation.es.md diff --git a/docs/tech-specs/cassandra-performance-refactor.es.md b/docs/tech-specs/es/cassandra-performance-refactor.es.md similarity index 100% rename from docs/tech-specs/cassandra-performance-refactor.es.md rename to docs/tech-specs/es/cassandra-performance-refactor.es.md diff --git a/docs/tech-specs/collection-management.es.md b/docs/tech-specs/es/collection-management.es.md similarity index 100% rename from docs/tech-specs/collection-management.es.md rename to docs/tech-specs/es/collection-management.es.md diff --git a/docs/tech-specs/document-embeddings-chunk-id.es.md b/docs/tech-specs/es/document-embeddings-chunk-id.es.md similarity index 100% rename from docs/tech-specs/document-embeddings-chunk-id.es.md rename to docs/tech-specs/es/document-embeddings-chunk-id.es.md diff --git a/docs/tech-specs/embeddings-batch-processing.es.md b/docs/tech-specs/es/embeddings-batch-processing.es.md similarity index 100% rename from docs/tech-specs/embeddings-batch-processing.es.md rename to docs/tech-specs/es/embeddings-batch-processing.es.md diff --git a/docs/tech-specs/entity-centric-graph.es.md b/docs/tech-specs/es/entity-centric-graph.es.md similarity index 100% rename from docs/tech-specs/entity-centric-graph.es.md rename to docs/tech-specs/es/entity-centric-graph.es.md diff --git a/docs/tech-specs/explainability-cli.es.md b/docs/tech-specs/es/explainability-cli.es.md similarity index 100% rename from docs/tech-specs/explainability-cli.es.md rename to docs/tech-specs/es/explainability-cli.es.md diff --git a/docs/tech-specs/extraction-flows.es.md b/docs/tech-specs/es/extraction-flows.es.md similarity index 100% rename from docs/tech-specs/extraction-flows.es.md rename to docs/tech-specs/es/extraction-flows.es.md diff --git a/docs/tech-specs/extraction-provenance-subgraph.es.md b/docs/tech-specs/es/extraction-provenance-subgraph.es.md similarity index 100% rename from docs/tech-specs/extraction-provenance-subgraph.es.md rename to docs/tech-specs/es/extraction-provenance-subgraph.es.md diff --git a/docs/tech-specs/extraction-time-provenance.es.md b/docs/tech-specs/es/extraction-time-provenance.es.md similarity index 100% rename from docs/tech-specs/extraction-time-provenance.es.md rename to docs/tech-specs/es/extraction-time-provenance.es.md diff --git a/docs/tech-specs/flow-class-definition.es.md b/docs/tech-specs/es/flow-class-definition.es.md similarity index 100% rename from docs/tech-specs/flow-class-definition.es.md rename to docs/tech-specs/es/flow-class-definition.es.md diff --git a/docs/tech-specs/flow-configurable-parameters.es.md b/docs/tech-specs/es/flow-configurable-parameters.es.md similarity index 100% rename from docs/tech-specs/flow-configurable-parameters.es.md rename to docs/tech-specs/es/flow-configurable-parameters.es.md diff --git a/docs/tech-specs/graph-contexts.es.md b/docs/tech-specs/es/graph-contexts.es.md similarity index 100% rename from docs/tech-specs/graph-contexts.es.md rename to docs/tech-specs/es/graph-contexts.es.md diff --git a/docs/tech-specs/graphql-query.es.md b/docs/tech-specs/es/graphql-query.es.md similarity index 100% rename from docs/tech-specs/graphql-query.es.md rename to docs/tech-specs/es/graphql-query.es.md diff --git a/docs/tech-specs/graphrag-performance-optimization.es.md b/docs/tech-specs/es/graphrag-performance-optimization.es.md similarity index 100% rename from docs/tech-specs/graphrag-performance-optimization.es.md rename to docs/tech-specs/es/graphrag-performance-optimization.es.md diff --git a/docs/tech-specs/import-export-graceful-shutdown.es.md b/docs/tech-specs/es/import-export-graceful-shutdown.es.md similarity index 100% rename from docs/tech-specs/import-export-graceful-shutdown.es.md rename to docs/tech-specs/es/import-export-graceful-shutdown.es.md diff --git a/docs/tech-specs/jsonl-prompt-output.es.md b/docs/tech-specs/es/jsonl-prompt-output.es.md similarity index 100% rename from docs/tech-specs/jsonl-prompt-output.es.md rename to docs/tech-specs/es/jsonl-prompt-output.es.md diff --git a/docs/tech-specs/large-document-loading.es.md b/docs/tech-specs/es/large-document-loading.es.md similarity index 100% rename from docs/tech-specs/large-document-loading.es.md rename to docs/tech-specs/es/large-document-loading.es.md diff --git a/docs/tech-specs/logging-strategy.es.md b/docs/tech-specs/es/logging-strategy.es.md similarity index 100% rename from docs/tech-specs/logging-strategy.es.md rename to docs/tech-specs/es/logging-strategy.es.md diff --git a/docs/tech-specs/mcp-tool-arguments.es.md b/docs/tech-specs/es/mcp-tool-arguments.es.md similarity index 100% rename from docs/tech-specs/mcp-tool-arguments.es.md rename to docs/tech-specs/es/mcp-tool-arguments.es.md diff --git a/docs/tech-specs/mcp-tool-bearer-token.es.md b/docs/tech-specs/es/mcp-tool-bearer-token.es.md similarity index 100% rename from docs/tech-specs/mcp-tool-bearer-token.es.md rename to docs/tech-specs/es/mcp-tool-bearer-token.es.md diff --git a/docs/tech-specs/minio-to-s3-migration.es.md b/docs/tech-specs/es/minio-to-s3-migration.es.md similarity index 100% rename from docs/tech-specs/minio-to-s3-migration.es.md rename to docs/tech-specs/es/minio-to-s3-migration.es.md diff --git a/docs/tech-specs/more-config-cli.es.md b/docs/tech-specs/es/more-config-cli.es.md similarity index 100% rename from docs/tech-specs/more-config-cli.es.md rename to docs/tech-specs/es/more-config-cli.es.md diff --git a/docs/tech-specs/multi-tenant-support.es.md b/docs/tech-specs/es/multi-tenant-support.es.md similarity index 100% rename from docs/tech-specs/multi-tenant-support.es.md rename to docs/tech-specs/es/multi-tenant-support.es.md diff --git a/docs/tech-specs/neo4j-user-collection-isolation.es.md b/docs/tech-specs/es/neo4j-user-collection-isolation.es.md similarity index 100% rename from docs/tech-specs/neo4j-user-collection-isolation.es.md rename to docs/tech-specs/es/neo4j-user-collection-isolation.es.md diff --git a/docs/tech-specs/ontology-extract-phase-2.es.md b/docs/tech-specs/es/ontology-extract-phase-2.es.md similarity index 100% rename from docs/tech-specs/ontology-extract-phase-2.es.md rename to docs/tech-specs/es/ontology-extract-phase-2.es.md diff --git a/docs/tech-specs/ontology.es.md b/docs/tech-specs/es/ontology.es.md similarity index 100% rename from docs/tech-specs/ontology.es.md rename to docs/tech-specs/es/ontology.es.md diff --git a/docs/tech-specs/ontorag.es.md b/docs/tech-specs/es/ontorag.es.md similarity index 100% rename from docs/tech-specs/ontorag.es.md rename to docs/tech-specs/es/ontorag.es.md diff --git a/docs/tech-specs/openapi-spec.es.md b/docs/tech-specs/es/openapi-spec.es.md similarity index 100% rename from docs/tech-specs/openapi-spec.es.md rename to docs/tech-specs/es/openapi-spec.es.md diff --git a/docs/tech-specs/pubsub.es.md b/docs/tech-specs/es/pubsub.es.md similarity index 100% rename from docs/tech-specs/pubsub.es.md rename to docs/tech-specs/es/pubsub.es.md diff --git a/docs/tech-specs/python-api-refactor.es.md b/docs/tech-specs/es/python-api-refactor.es.md similarity index 100% rename from docs/tech-specs/python-api-refactor.es.md rename to docs/tech-specs/es/python-api-refactor.es.md diff --git a/docs/tech-specs/query-time-explainability.es.md b/docs/tech-specs/es/query-time-explainability.es.md similarity index 100% rename from docs/tech-specs/query-time-explainability.es.md rename to docs/tech-specs/es/query-time-explainability.es.md diff --git a/docs/tech-specs/rag-streaming-support.es.md b/docs/tech-specs/es/rag-streaming-support.es.md similarity index 100% rename from docs/tech-specs/rag-streaming-support.es.md rename to docs/tech-specs/es/rag-streaming-support.es.md diff --git a/docs/tech-specs/schema-refactoring-proposal.es.md b/docs/tech-specs/es/schema-refactoring-proposal.es.md similarity index 100% rename from docs/tech-specs/schema-refactoring-proposal.es.md rename to docs/tech-specs/es/schema-refactoring-proposal.es.md diff --git a/docs/tech-specs/streaming-llm-responses.es.md b/docs/tech-specs/es/streaming-llm-responses.es.md similarity index 100% rename from docs/tech-specs/streaming-llm-responses.es.md rename to docs/tech-specs/es/streaming-llm-responses.es.md diff --git a/docs/tech-specs/structured-data-2.es.md b/docs/tech-specs/es/structured-data-2.es.md similarity index 100% rename from docs/tech-specs/structured-data-2.es.md rename to docs/tech-specs/es/structured-data-2.es.md diff --git a/docs/tech-specs/structured-data-descriptor.es.md b/docs/tech-specs/es/structured-data-descriptor.es.md similarity index 100% rename from docs/tech-specs/structured-data-descriptor.es.md rename to docs/tech-specs/es/structured-data-descriptor.es.md diff --git a/docs/tech-specs/structured-data-schemas.es.md b/docs/tech-specs/es/structured-data-schemas.es.md similarity index 100% rename from docs/tech-specs/structured-data-schemas.es.md rename to docs/tech-specs/es/structured-data-schemas.es.md diff --git a/docs/tech-specs/structured-data.es.md b/docs/tech-specs/es/structured-data.es.md similarity index 100% rename from docs/tech-specs/structured-data.es.md rename to docs/tech-specs/es/structured-data.es.md diff --git a/docs/tech-specs/structured-diag-service.es.md b/docs/tech-specs/es/structured-diag-service.es.md similarity index 100% rename from docs/tech-specs/structured-diag-service.es.md rename to docs/tech-specs/es/structured-diag-service.es.md diff --git a/docs/tech-specs/tool-group.es.md b/docs/tech-specs/es/tool-group.es.md similarity index 100% rename from docs/tech-specs/tool-group.es.md rename to docs/tech-specs/es/tool-group.es.md diff --git a/docs/tech-specs/tool-services.es.md b/docs/tech-specs/es/tool-services.es.md similarity index 100% rename from docs/tech-specs/tool-services.es.md rename to docs/tech-specs/es/tool-services.es.md diff --git a/docs/tech-specs/universal-decoder.es.md b/docs/tech-specs/es/universal-decoder.es.md similarity index 100% rename from docs/tech-specs/universal-decoder.es.md rename to docs/tech-specs/es/universal-decoder.es.md diff --git a/docs/tech-specs/vector-store-lifecycle.es.md b/docs/tech-specs/es/vector-store-lifecycle.es.md similarity index 100% rename from docs/tech-specs/vector-store-lifecycle.es.md rename to docs/tech-specs/es/vector-store-lifecycle.es.md diff --git a/docs/tech-specs/__TEMPLATE.he.md b/docs/tech-specs/he/__TEMPLATE.he.md similarity index 100% rename from docs/tech-specs/__TEMPLATE.he.md rename to docs/tech-specs/he/__TEMPLATE.he.md diff --git a/docs/tech-specs/agent-explainability.he.md b/docs/tech-specs/he/agent-explainability.he.md similarity index 100% rename from docs/tech-specs/agent-explainability.he.md rename to docs/tech-specs/he/agent-explainability.he.md diff --git a/docs/tech-specs/architecture-principles.he.md b/docs/tech-specs/he/architecture-principles.he.md similarity index 100% rename from docs/tech-specs/architecture-principles.he.md rename to docs/tech-specs/he/architecture-principles.he.md diff --git a/docs/tech-specs/cassandra-consolidation.he.md b/docs/tech-specs/he/cassandra-consolidation.he.md similarity index 100% rename from docs/tech-specs/cassandra-consolidation.he.md rename to docs/tech-specs/he/cassandra-consolidation.he.md diff --git a/docs/tech-specs/cassandra-performance-refactor.he.md b/docs/tech-specs/he/cassandra-performance-refactor.he.md similarity index 100% rename from docs/tech-specs/cassandra-performance-refactor.he.md rename to docs/tech-specs/he/cassandra-performance-refactor.he.md diff --git a/docs/tech-specs/collection-management.he.md b/docs/tech-specs/he/collection-management.he.md similarity index 100% rename from docs/tech-specs/collection-management.he.md rename to docs/tech-specs/he/collection-management.he.md diff --git a/docs/tech-specs/document-embeddings-chunk-id.he.md b/docs/tech-specs/he/document-embeddings-chunk-id.he.md similarity index 100% rename from docs/tech-specs/document-embeddings-chunk-id.he.md rename to docs/tech-specs/he/document-embeddings-chunk-id.he.md diff --git a/docs/tech-specs/embeddings-batch-processing.he.md b/docs/tech-specs/he/embeddings-batch-processing.he.md similarity index 100% rename from docs/tech-specs/embeddings-batch-processing.he.md rename to docs/tech-specs/he/embeddings-batch-processing.he.md diff --git a/docs/tech-specs/entity-centric-graph.he.md b/docs/tech-specs/he/entity-centric-graph.he.md similarity index 100% rename from docs/tech-specs/entity-centric-graph.he.md rename to docs/tech-specs/he/entity-centric-graph.he.md diff --git a/docs/tech-specs/explainability-cli.he.md b/docs/tech-specs/he/explainability-cli.he.md similarity index 100% rename from docs/tech-specs/explainability-cli.he.md rename to docs/tech-specs/he/explainability-cli.he.md diff --git a/docs/tech-specs/extraction-flows.he.md b/docs/tech-specs/he/extraction-flows.he.md similarity index 100% rename from docs/tech-specs/extraction-flows.he.md rename to docs/tech-specs/he/extraction-flows.he.md diff --git a/docs/tech-specs/extraction-provenance-subgraph.he.md b/docs/tech-specs/he/extraction-provenance-subgraph.he.md similarity index 100% rename from docs/tech-specs/extraction-provenance-subgraph.he.md rename to docs/tech-specs/he/extraction-provenance-subgraph.he.md diff --git a/docs/tech-specs/extraction-time-provenance.he.md b/docs/tech-specs/he/extraction-time-provenance.he.md similarity index 100% rename from docs/tech-specs/extraction-time-provenance.he.md rename to docs/tech-specs/he/extraction-time-provenance.he.md diff --git a/docs/tech-specs/flow-class-definition.he.md b/docs/tech-specs/he/flow-class-definition.he.md similarity index 100% rename from docs/tech-specs/flow-class-definition.he.md rename to docs/tech-specs/he/flow-class-definition.he.md diff --git a/docs/tech-specs/flow-configurable-parameters.he.md b/docs/tech-specs/he/flow-configurable-parameters.he.md similarity index 100% rename from docs/tech-specs/flow-configurable-parameters.he.md rename to docs/tech-specs/he/flow-configurable-parameters.he.md diff --git a/docs/tech-specs/graph-contexts.he.md b/docs/tech-specs/he/graph-contexts.he.md similarity index 100% rename from docs/tech-specs/graph-contexts.he.md rename to docs/tech-specs/he/graph-contexts.he.md diff --git a/docs/tech-specs/graphql-query.he.md b/docs/tech-specs/he/graphql-query.he.md similarity index 100% rename from docs/tech-specs/graphql-query.he.md rename to docs/tech-specs/he/graphql-query.he.md diff --git a/docs/tech-specs/graphrag-performance-optimization.he.md b/docs/tech-specs/he/graphrag-performance-optimization.he.md similarity index 100% rename from docs/tech-specs/graphrag-performance-optimization.he.md rename to docs/tech-specs/he/graphrag-performance-optimization.he.md diff --git a/docs/tech-specs/import-export-graceful-shutdown.he.md b/docs/tech-specs/he/import-export-graceful-shutdown.he.md similarity index 100% rename from docs/tech-specs/import-export-graceful-shutdown.he.md rename to docs/tech-specs/he/import-export-graceful-shutdown.he.md diff --git a/docs/tech-specs/jsonl-prompt-output.he.md b/docs/tech-specs/he/jsonl-prompt-output.he.md similarity index 100% rename from docs/tech-specs/jsonl-prompt-output.he.md rename to docs/tech-specs/he/jsonl-prompt-output.he.md diff --git a/docs/tech-specs/large-document-loading.he.md b/docs/tech-specs/he/large-document-loading.he.md similarity index 100% rename from docs/tech-specs/large-document-loading.he.md rename to docs/tech-specs/he/large-document-loading.he.md diff --git a/docs/tech-specs/logging-strategy.he.md b/docs/tech-specs/he/logging-strategy.he.md similarity index 100% rename from docs/tech-specs/logging-strategy.he.md rename to docs/tech-specs/he/logging-strategy.he.md diff --git a/docs/tech-specs/mcp-tool-arguments.he.md b/docs/tech-specs/he/mcp-tool-arguments.he.md similarity index 100% rename from docs/tech-specs/mcp-tool-arguments.he.md rename to docs/tech-specs/he/mcp-tool-arguments.he.md diff --git a/docs/tech-specs/mcp-tool-bearer-token.he.md b/docs/tech-specs/he/mcp-tool-bearer-token.he.md similarity index 100% rename from docs/tech-specs/mcp-tool-bearer-token.he.md rename to docs/tech-specs/he/mcp-tool-bearer-token.he.md diff --git a/docs/tech-specs/minio-to-s3-migration.he.md b/docs/tech-specs/he/minio-to-s3-migration.he.md similarity index 100% rename from docs/tech-specs/minio-to-s3-migration.he.md rename to docs/tech-specs/he/minio-to-s3-migration.he.md diff --git a/docs/tech-specs/more-config-cli.he.md b/docs/tech-specs/he/more-config-cli.he.md similarity index 100% rename from docs/tech-specs/more-config-cli.he.md rename to docs/tech-specs/he/more-config-cli.he.md diff --git a/docs/tech-specs/multi-tenant-support.he.md b/docs/tech-specs/he/multi-tenant-support.he.md similarity index 100% rename from docs/tech-specs/multi-tenant-support.he.md rename to docs/tech-specs/he/multi-tenant-support.he.md diff --git a/docs/tech-specs/neo4j-user-collection-isolation.he.md b/docs/tech-specs/he/neo4j-user-collection-isolation.he.md similarity index 100% rename from docs/tech-specs/neo4j-user-collection-isolation.he.md rename to docs/tech-specs/he/neo4j-user-collection-isolation.he.md diff --git a/docs/tech-specs/ontology-extract-phase-2.he.md b/docs/tech-specs/he/ontology-extract-phase-2.he.md similarity index 100% rename from docs/tech-specs/ontology-extract-phase-2.he.md rename to docs/tech-specs/he/ontology-extract-phase-2.he.md diff --git a/docs/tech-specs/ontology.he.md b/docs/tech-specs/he/ontology.he.md similarity index 100% rename from docs/tech-specs/ontology.he.md rename to docs/tech-specs/he/ontology.he.md diff --git a/docs/tech-specs/ontorag.he.md b/docs/tech-specs/he/ontorag.he.md similarity index 100% rename from docs/tech-specs/ontorag.he.md rename to docs/tech-specs/he/ontorag.he.md diff --git a/docs/tech-specs/openapi-spec.he.md b/docs/tech-specs/he/openapi-spec.he.md similarity index 100% rename from docs/tech-specs/openapi-spec.he.md rename to docs/tech-specs/he/openapi-spec.he.md diff --git a/docs/tech-specs/pubsub.he.md b/docs/tech-specs/he/pubsub.he.md similarity index 100% rename from docs/tech-specs/pubsub.he.md rename to docs/tech-specs/he/pubsub.he.md diff --git a/docs/tech-specs/python-api-refactor.he.md b/docs/tech-specs/he/python-api-refactor.he.md similarity index 100% rename from docs/tech-specs/python-api-refactor.he.md rename to docs/tech-specs/he/python-api-refactor.he.md diff --git a/docs/tech-specs/query-time-explainability.he.md b/docs/tech-specs/he/query-time-explainability.he.md similarity index 100% rename from docs/tech-specs/query-time-explainability.he.md rename to docs/tech-specs/he/query-time-explainability.he.md diff --git a/docs/tech-specs/rag-streaming-support.he.md b/docs/tech-specs/he/rag-streaming-support.he.md similarity index 100% rename from docs/tech-specs/rag-streaming-support.he.md rename to docs/tech-specs/he/rag-streaming-support.he.md diff --git a/docs/tech-specs/schema-refactoring-proposal.he.md b/docs/tech-specs/he/schema-refactoring-proposal.he.md similarity index 100% rename from docs/tech-specs/schema-refactoring-proposal.he.md rename to docs/tech-specs/he/schema-refactoring-proposal.he.md diff --git a/docs/tech-specs/streaming-llm-responses.he.md b/docs/tech-specs/he/streaming-llm-responses.he.md similarity index 100% rename from docs/tech-specs/streaming-llm-responses.he.md rename to docs/tech-specs/he/streaming-llm-responses.he.md diff --git a/docs/tech-specs/structured-data-2.he.md b/docs/tech-specs/he/structured-data-2.he.md similarity index 100% rename from docs/tech-specs/structured-data-2.he.md rename to docs/tech-specs/he/structured-data-2.he.md diff --git a/docs/tech-specs/structured-data-descriptor.he.md b/docs/tech-specs/he/structured-data-descriptor.he.md similarity index 100% rename from docs/tech-specs/structured-data-descriptor.he.md rename to docs/tech-specs/he/structured-data-descriptor.he.md diff --git a/docs/tech-specs/structured-data-schemas.he.md b/docs/tech-specs/he/structured-data-schemas.he.md similarity index 100% rename from docs/tech-specs/structured-data-schemas.he.md rename to docs/tech-specs/he/structured-data-schemas.he.md diff --git a/docs/tech-specs/structured-data.he.md b/docs/tech-specs/he/structured-data.he.md similarity index 100% rename from docs/tech-specs/structured-data.he.md rename to docs/tech-specs/he/structured-data.he.md diff --git a/docs/tech-specs/structured-diag-service.he.md b/docs/tech-specs/he/structured-diag-service.he.md similarity index 100% rename from docs/tech-specs/structured-diag-service.he.md rename to docs/tech-specs/he/structured-diag-service.he.md diff --git a/docs/tech-specs/tool-group.he.md b/docs/tech-specs/he/tool-group.he.md similarity index 100% rename from docs/tech-specs/tool-group.he.md rename to docs/tech-specs/he/tool-group.he.md diff --git a/docs/tech-specs/tool-services.he.md b/docs/tech-specs/he/tool-services.he.md similarity index 100% rename from docs/tech-specs/tool-services.he.md rename to docs/tech-specs/he/tool-services.he.md diff --git a/docs/tech-specs/universal-decoder.he.md b/docs/tech-specs/he/universal-decoder.he.md similarity index 100% rename from docs/tech-specs/universal-decoder.he.md rename to docs/tech-specs/he/universal-decoder.he.md diff --git a/docs/tech-specs/vector-store-lifecycle.he.md b/docs/tech-specs/he/vector-store-lifecycle.he.md similarity index 100% rename from docs/tech-specs/vector-store-lifecycle.he.md rename to docs/tech-specs/he/vector-store-lifecycle.he.md diff --git a/docs/tech-specs/__TEMPLATE.hi.md b/docs/tech-specs/hi/__TEMPLATE.hi.md similarity index 100% rename from docs/tech-specs/__TEMPLATE.hi.md rename to docs/tech-specs/hi/__TEMPLATE.hi.md diff --git a/docs/tech-specs/agent-explainability.hi.md b/docs/tech-specs/hi/agent-explainability.hi.md similarity index 100% rename from docs/tech-specs/agent-explainability.hi.md rename to docs/tech-specs/hi/agent-explainability.hi.md diff --git a/docs/tech-specs/architecture-principles.hi.md b/docs/tech-specs/hi/architecture-principles.hi.md similarity index 100% rename from docs/tech-specs/architecture-principles.hi.md rename to docs/tech-specs/hi/architecture-principles.hi.md diff --git a/docs/tech-specs/cassandra-consolidation.hi.md b/docs/tech-specs/hi/cassandra-consolidation.hi.md similarity index 100% rename from docs/tech-specs/cassandra-consolidation.hi.md rename to docs/tech-specs/hi/cassandra-consolidation.hi.md diff --git a/docs/tech-specs/cassandra-performance-refactor.hi.md b/docs/tech-specs/hi/cassandra-performance-refactor.hi.md similarity index 100% rename from docs/tech-specs/cassandra-performance-refactor.hi.md rename to docs/tech-specs/hi/cassandra-performance-refactor.hi.md diff --git a/docs/tech-specs/collection-management.hi.md b/docs/tech-specs/hi/collection-management.hi.md similarity index 100% rename from docs/tech-specs/collection-management.hi.md rename to docs/tech-specs/hi/collection-management.hi.md diff --git a/docs/tech-specs/document-embeddings-chunk-id.hi.md b/docs/tech-specs/hi/document-embeddings-chunk-id.hi.md similarity index 100% rename from docs/tech-specs/document-embeddings-chunk-id.hi.md rename to docs/tech-specs/hi/document-embeddings-chunk-id.hi.md diff --git a/docs/tech-specs/embeddings-batch-processing.hi.md b/docs/tech-specs/hi/embeddings-batch-processing.hi.md similarity index 100% rename from docs/tech-specs/embeddings-batch-processing.hi.md rename to docs/tech-specs/hi/embeddings-batch-processing.hi.md diff --git a/docs/tech-specs/entity-centric-graph.hi.md b/docs/tech-specs/hi/entity-centric-graph.hi.md similarity index 100% rename from docs/tech-specs/entity-centric-graph.hi.md rename to docs/tech-specs/hi/entity-centric-graph.hi.md diff --git a/docs/tech-specs/explainability-cli.hi.md b/docs/tech-specs/hi/explainability-cli.hi.md similarity index 100% rename from docs/tech-specs/explainability-cli.hi.md rename to docs/tech-specs/hi/explainability-cli.hi.md diff --git a/docs/tech-specs/extraction-flows.hi.md b/docs/tech-specs/hi/extraction-flows.hi.md similarity index 100% rename from docs/tech-specs/extraction-flows.hi.md rename to docs/tech-specs/hi/extraction-flows.hi.md diff --git a/docs/tech-specs/extraction-provenance-subgraph.hi.md b/docs/tech-specs/hi/extraction-provenance-subgraph.hi.md similarity index 100% rename from docs/tech-specs/extraction-provenance-subgraph.hi.md rename to docs/tech-specs/hi/extraction-provenance-subgraph.hi.md diff --git a/docs/tech-specs/extraction-time-provenance.hi.md b/docs/tech-specs/hi/extraction-time-provenance.hi.md similarity index 100% rename from docs/tech-specs/extraction-time-provenance.hi.md rename to docs/tech-specs/hi/extraction-time-provenance.hi.md diff --git a/docs/tech-specs/flow-class-definition.hi.md b/docs/tech-specs/hi/flow-class-definition.hi.md similarity index 100% rename from docs/tech-specs/flow-class-definition.hi.md rename to docs/tech-specs/hi/flow-class-definition.hi.md diff --git a/docs/tech-specs/flow-configurable-parameters.hi.md b/docs/tech-specs/hi/flow-configurable-parameters.hi.md similarity index 100% rename from docs/tech-specs/flow-configurable-parameters.hi.md rename to docs/tech-specs/hi/flow-configurable-parameters.hi.md diff --git a/docs/tech-specs/graph-contexts.hi.md b/docs/tech-specs/hi/graph-contexts.hi.md similarity index 100% rename from docs/tech-specs/graph-contexts.hi.md rename to docs/tech-specs/hi/graph-contexts.hi.md diff --git a/docs/tech-specs/graphql-query.hi.md b/docs/tech-specs/hi/graphql-query.hi.md similarity index 100% rename from docs/tech-specs/graphql-query.hi.md rename to docs/tech-specs/hi/graphql-query.hi.md diff --git a/docs/tech-specs/graphrag-performance-optimization.hi.md b/docs/tech-specs/hi/graphrag-performance-optimization.hi.md similarity index 100% rename from docs/tech-specs/graphrag-performance-optimization.hi.md rename to docs/tech-specs/hi/graphrag-performance-optimization.hi.md diff --git a/docs/tech-specs/import-export-graceful-shutdown.hi.md b/docs/tech-specs/hi/import-export-graceful-shutdown.hi.md similarity index 100% rename from docs/tech-specs/import-export-graceful-shutdown.hi.md rename to docs/tech-specs/hi/import-export-graceful-shutdown.hi.md diff --git a/docs/tech-specs/jsonl-prompt-output.hi.md b/docs/tech-specs/hi/jsonl-prompt-output.hi.md similarity index 100% rename from docs/tech-specs/jsonl-prompt-output.hi.md rename to docs/tech-specs/hi/jsonl-prompt-output.hi.md diff --git a/docs/tech-specs/large-document-loading.hi.md b/docs/tech-specs/hi/large-document-loading.hi.md similarity index 100% rename from docs/tech-specs/large-document-loading.hi.md rename to docs/tech-specs/hi/large-document-loading.hi.md diff --git a/docs/tech-specs/logging-strategy.hi.md b/docs/tech-specs/hi/logging-strategy.hi.md similarity index 100% rename from docs/tech-specs/logging-strategy.hi.md rename to docs/tech-specs/hi/logging-strategy.hi.md diff --git a/docs/tech-specs/mcp-tool-arguments.hi.md b/docs/tech-specs/hi/mcp-tool-arguments.hi.md similarity index 100% rename from docs/tech-specs/mcp-tool-arguments.hi.md rename to docs/tech-specs/hi/mcp-tool-arguments.hi.md diff --git a/docs/tech-specs/mcp-tool-bearer-token.hi.md b/docs/tech-specs/hi/mcp-tool-bearer-token.hi.md similarity index 100% rename from docs/tech-specs/mcp-tool-bearer-token.hi.md rename to docs/tech-specs/hi/mcp-tool-bearer-token.hi.md diff --git a/docs/tech-specs/minio-to-s3-migration.hi.md b/docs/tech-specs/hi/minio-to-s3-migration.hi.md similarity index 100% rename from docs/tech-specs/minio-to-s3-migration.hi.md rename to docs/tech-specs/hi/minio-to-s3-migration.hi.md diff --git a/docs/tech-specs/more-config-cli.hi.md b/docs/tech-specs/hi/more-config-cli.hi.md similarity index 100% rename from docs/tech-specs/more-config-cli.hi.md rename to docs/tech-specs/hi/more-config-cli.hi.md diff --git a/docs/tech-specs/multi-tenant-support.hi.md b/docs/tech-specs/hi/multi-tenant-support.hi.md similarity index 100% rename from docs/tech-specs/multi-tenant-support.hi.md rename to docs/tech-specs/hi/multi-tenant-support.hi.md diff --git a/docs/tech-specs/neo4j-user-collection-isolation.hi.md b/docs/tech-specs/hi/neo4j-user-collection-isolation.hi.md similarity index 100% rename from docs/tech-specs/neo4j-user-collection-isolation.hi.md rename to docs/tech-specs/hi/neo4j-user-collection-isolation.hi.md diff --git a/docs/tech-specs/ontology-extract-phase-2.hi.md b/docs/tech-specs/hi/ontology-extract-phase-2.hi.md similarity index 100% rename from docs/tech-specs/ontology-extract-phase-2.hi.md rename to docs/tech-specs/hi/ontology-extract-phase-2.hi.md diff --git a/docs/tech-specs/ontology.hi.md b/docs/tech-specs/hi/ontology.hi.md similarity index 100% rename from docs/tech-specs/ontology.hi.md rename to docs/tech-specs/hi/ontology.hi.md diff --git a/docs/tech-specs/ontorag.hi.md b/docs/tech-specs/hi/ontorag.hi.md similarity index 100% rename from docs/tech-specs/ontorag.hi.md rename to docs/tech-specs/hi/ontorag.hi.md diff --git a/docs/tech-specs/openapi-spec.hi.md b/docs/tech-specs/hi/openapi-spec.hi.md similarity index 100% rename from docs/tech-specs/openapi-spec.hi.md rename to docs/tech-specs/hi/openapi-spec.hi.md diff --git a/docs/tech-specs/pubsub.hi.md b/docs/tech-specs/hi/pubsub.hi.md similarity index 100% rename from docs/tech-specs/pubsub.hi.md rename to docs/tech-specs/hi/pubsub.hi.md diff --git a/docs/tech-specs/python-api-refactor.hi.md b/docs/tech-specs/hi/python-api-refactor.hi.md similarity index 100% rename from docs/tech-specs/python-api-refactor.hi.md rename to docs/tech-specs/hi/python-api-refactor.hi.md diff --git a/docs/tech-specs/query-time-explainability.hi.md b/docs/tech-specs/hi/query-time-explainability.hi.md similarity index 100% rename from docs/tech-specs/query-time-explainability.hi.md rename to docs/tech-specs/hi/query-time-explainability.hi.md diff --git a/docs/tech-specs/rag-streaming-support.hi.md b/docs/tech-specs/hi/rag-streaming-support.hi.md similarity index 100% rename from docs/tech-specs/rag-streaming-support.hi.md rename to docs/tech-specs/hi/rag-streaming-support.hi.md diff --git a/docs/tech-specs/schema-refactoring-proposal.hi.md b/docs/tech-specs/hi/schema-refactoring-proposal.hi.md similarity index 100% rename from docs/tech-specs/schema-refactoring-proposal.hi.md rename to docs/tech-specs/hi/schema-refactoring-proposal.hi.md diff --git a/docs/tech-specs/streaming-llm-responses.hi.md b/docs/tech-specs/hi/streaming-llm-responses.hi.md similarity index 100% rename from docs/tech-specs/streaming-llm-responses.hi.md rename to docs/tech-specs/hi/streaming-llm-responses.hi.md diff --git a/docs/tech-specs/structured-data-2.hi.md b/docs/tech-specs/hi/structured-data-2.hi.md similarity index 100% rename from docs/tech-specs/structured-data-2.hi.md rename to docs/tech-specs/hi/structured-data-2.hi.md diff --git a/docs/tech-specs/structured-data-descriptor.hi.md b/docs/tech-specs/hi/structured-data-descriptor.hi.md similarity index 100% rename from docs/tech-specs/structured-data-descriptor.hi.md rename to docs/tech-specs/hi/structured-data-descriptor.hi.md diff --git a/docs/tech-specs/structured-data-schemas.hi.md b/docs/tech-specs/hi/structured-data-schemas.hi.md similarity index 100% rename from docs/tech-specs/structured-data-schemas.hi.md rename to docs/tech-specs/hi/structured-data-schemas.hi.md diff --git a/docs/tech-specs/structured-data.hi.md b/docs/tech-specs/hi/structured-data.hi.md similarity index 100% rename from docs/tech-specs/structured-data.hi.md rename to docs/tech-specs/hi/structured-data.hi.md diff --git a/docs/tech-specs/structured-diag-service.hi.md b/docs/tech-specs/hi/structured-diag-service.hi.md similarity index 100% rename from docs/tech-specs/structured-diag-service.hi.md rename to docs/tech-specs/hi/structured-diag-service.hi.md diff --git a/docs/tech-specs/tool-group.hi.md b/docs/tech-specs/hi/tool-group.hi.md similarity index 100% rename from docs/tech-specs/tool-group.hi.md rename to docs/tech-specs/hi/tool-group.hi.md diff --git a/docs/tech-specs/tool-services.hi.md b/docs/tech-specs/hi/tool-services.hi.md similarity index 100% rename from docs/tech-specs/tool-services.hi.md rename to docs/tech-specs/hi/tool-services.hi.md diff --git a/docs/tech-specs/universal-decoder.hi.md b/docs/tech-specs/hi/universal-decoder.hi.md similarity index 100% rename from docs/tech-specs/universal-decoder.hi.md rename to docs/tech-specs/hi/universal-decoder.hi.md diff --git a/docs/tech-specs/vector-store-lifecycle.hi.md b/docs/tech-specs/hi/vector-store-lifecycle.hi.md similarity index 100% rename from docs/tech-specs/vector-store-lifecycle.hi.md rename to docs/tech-specs/hi/vector-store-lifecycle.hi.md diff --git a/docs/tech-specs/__TEMPLATE.pt.md b/docs/tech-specs/pt/__TEMPLATE.pt.md similarity index 100% rename from docs/tech-specs/__TEMPLATE.pt.md rename to docs/tech-specs/pt/__TEMPLATE.pt.md diff --git a/docs/tech-specs/agent-explainability.pt.md b/docs/tech-specs/pt/agent-explainability.pt.md similarity index 100% rename from docs/tech-specs/agent-explainability.pt.md rename to docs/tech-specs/pt/agent-explainability.pt.md diff --git a/docs/tech-specs/architecture-principles.pt.md b/docs/tech-specs/pt/architecture-principles.pt.md similarity index 100% rename from docs/tech-specs/architecture-principles.pt.md rename to docs/tech-specs/pt/architecture-principles.pt.md diff --git a/docs/tech-specs/cassandra-consolidation.pt.md b/docs/tech-specs/pt/cassandra-consolidation.pt.md similarity index 100% rename from docs/tech-specs/cassandra-consolidation.pt.md rename to docs/tech-specs/pt/cassandra-consolidation.pt.md diff --git a/docs/tech-specs/cassandra-performance-refactor.pt.md b/docs/tech-specs/pt/cassandra-performance-refactor.pt.md similarity index 100% rename from docs/tech-specs/cassandra-performance-refactor.pt.md rename to docs/tech-specs/pt/cassandra-performance-refactor.pt.md diff --git a/docs/tech-specs/collection-management.pt.md b/docs/tech-specs/pt/collection-management.pt.md similarity index 100% rename from docs/tech-specs/collection-management.pt.md rename to docs/tech-specs/pt/collection-management.pt.md diff --git a/docs/tech-specs/document-embeddings-chunk-id.pt.md b/docs/tech-specs/pt/document-embeddings-chunk-id.pt.md similarity index 100% rename from docs/tech-specs/document-embeddings-chunk-id.pt.md rename to docs/tech-specs/pt/document-embeddings-chunk-id.pt.md diff --git a/docs/tech-specs/embeddings-batch-processing.pt.md b/docs/tech-specs/pt/embeddings-batch-processing.pt.md similarity index 100% rename from docs/tech-specs/embeddings-batch-processing.pt.md rename to docs/tech-specs/pt/embeddings-batch-processing.pt.md diff --git a/docs/tech-specs/entity-centric-graph.pt.md b/docs/tech-specs/pt/entity-centric-graph.pt.md similarity index 100% rename from docs/tech-specs/entity-centric-graph.pt.md rename to docs/tech-specs/pt/entity-centric-graph.pt.md diff --git a/docs/tech-specs/explainability-cli.pt.md b/docs/tech-specs/pt/explainability-cli.pt.md similarity index 100% rename from docs/tech-specs/explainability-cli.pt.md rename to docs/tech-specs/pt/explainability-cli.pt.md diff --git a/docs/tech-specs/extraction-flows.pt.md b/docs/tech-specs/pt/extraction-flows.pt.md similarity index 100% rename from docs/tech-specs/extraction-flows.pt.md rename to docs/tech-specs/pt/extraction-flows.pt.md diff --git a/docs/tech-specs/extraction-provenance-subgraph.pt.md b/docs/tech-specs/pt/extraction-provenance-subgraph.pt.md similarity index 100% rename from docs/tech-specs/extraction-provenance-subgraph.pt.md rename to docs/tech-specs/pt/extraction-provenance-subgraph.pt.md diff --git a/docs/tech-specs/extraction-time-provenance.pt.md b/docs/tech-specs/pt/extraction-time-provenance.pt.md similarity index 100% rename from docs/tech-specs/extraction-time-provenance.pt.md rename to docs/tech-specs/pt/extraction-time-provenance.pt.md diff --git a/docs/tech-specs/flow-class-definition.pt.md b/docs/tech-specs/pt/flow-class-definition.pt.md similarity index 100% rename from docs/tech-specs/flow-class-definition.pt.md rename to docs/tech-specs/pt/flow-class-definition.pt.md diff --git a/docs/tech-specs/flow-configurable-parameters.pt.md b/docs/tech-specs/pt/flow-configurable-parameters.pt.md similarity index 100% rename from docs/tech-specs/flow-configurable-parameters.pt.md rename to docs/tech-specs/pt/flow-configurable-parameters.pt.md diff --git a/docs/tech-specs/graph-contexts.pt.md b/docs/tech-specs/pt/graph-contexts.pt.md similarity index 100% rename from docs/tech-specs/graph-contexts.pt.md rename to docs/tech-specs/pt/graph-contexts.pt.md diff --git a/docs/tech-specs/graphql-query.pt.md b/docs/tech-specs/pt/graphql-query.pt.md similarity index 100% rename from docs/tech-specs/graphql-query.pt.md rename to docs/tech-specs/pt/graphql-query.pt.md diff --git a/docs/tech-specs/graphrag-performance-optimization.pt.md b/docs/tech-specs/pt/graphrag-performance-optimization.pt.md similarity index 100% rename from docs/tech-specs/graphrag-performance-optimization.pt.md rename to docs/tech-specs/pt/graphrag-performance-optimization.pt.md diff --git a/docs/tech-specs/import-export-graceful-shutdown.pt.md b/docs/tech-specs/pt/import-export-graceful-shutdown.pt.md similarity index 100% rename from docs/tech-specs/import-export-graceful-shutdown.pt.md rename to docs/tech-specs/pt/import-export-graceful-shutdown.pt.md diff --git a/docs/tech-specs/jsonl-prompt-output.pt.md b/docs/tech-specs/pt/jsonl-prompt-output.pt.md similarity index 100% rename from docs/tech-specs/jsonl-prompt-output.pt.md rename to docs/tech-specs/pt/jsonl-prompt-output.pt.md diff --git a/docs/tech-specs/large-document-loading.pt.md b/docs/tech-specs/pt/large-document-loading.pt.md similarity index 100% rename from docs/tech-specs/large-document-loading.pt.md rename to docs/tech-specs/pt/large-document-loading.pt.md diff --git a/docs/tech-specs/logging-strategy.pt.md b/docs/tech-specs/pt/logging-strategy.pt.md similarity index 100% rename from docs/tech-specs/logging-strategy.pt.md rename to docs/tech-specs/pt/logging-strategy.pt.md diff --git a/docs/tech-specs/mcp-tool-arguments.pt.md b/docs/tech-specs/pt/mcp-tool-arguments.pt.md similarity index 100% rename from docs/tech-specs/mcp-tool-arguments.pt.md rename to docs/tech-specs/pt/mcp-tool-arguments.pt.md diff --git a/docs/tech-specs/mcp-tool-bearer-token.pt.md b/docs/tech-specs/pt/mcp-tool-bearer-token.pt.md similarity index 100% rename from docs/tech-specs/mcp-tool-bearer-token.pt.md rename to docs/tech-specs/pt/mcp-tool-bearer-token.pt.md diff --git a/docs/tech-specs/minio-to-s3-migration.pt.md b/docs/tech-specs/pt/minio-to-s3-migration.pt.md similarity index 100% rename from docs/tech-specs/minio-to-s3-migration.pt.md rename to docs/tech-specs/pt/minio-to-s3-migration.pt.md diff --git a/docs/tech-specs/more-config-cli.pt.md b/docs/tech-specs/pt/more-config-cli.pt.md similarity index 100% rename from docs/tech-specs/more-config-cli.pt.md rename to docs/tech-specs/pt/more-config-cli.pt.md diff --git a/docs/tech-specs/multi-tenant-support.pt.md b/docs/tech-specs/pt/multi-tenant-support.pt.md similarity index 100% rename from docs/tech-specs/multi-tenant-support.pt.md rename to docs/tech-specs/pt/multi-tenant-support.pt.md diff --git a/docs/tech-specs/neo4j-user-collection-isolation.pt.md b/docs/tech-specs/pt/neo4j-user-collection-isolation.pt.md similarity index 100% rename from docs/tech-specs/neo4j-user-collection-isolation.pt.md rename to docs/tech-specs/pt/neo4j-user-collection-isolation.pt.md diff --git a/docs/tech-specs/ontology-extract-phase-2.pt.md b/docs/tech-specs/pt/ontology-extract-phase-2.pt.md similarity index 100% rename from docs/tech-specs/ontology-extract-phase-2.pt.md rename to docs/tech-specs/pt/ontology-extract-phase-2.pt.md diff --git a/docs/tech-specs/ontology.pt.md b/docs/tech-specs/pt/ontology.pt.md similarity index 100% rename from docs/tech-specs/ontology.pt.md rename to docs/tech-specs/pt/ontology.pt.md diff --git a/docs/tech-specs/ontorag.pt.md b/docs/tech-specs/pt/ontorag.pt.md similarity index 100% rename from docs/tech-specs/ontorag.pt.md rename to docs/tech-specs/pt/ontorag.pt.md diff --git a/docs/tech-specs/openapi-spec.pt.md b/docs/tech-specs/pt/openapi-spec.pt.md similarity index 100% rename from docs/tech-specs/openapi-spec.pt.md rename to docs/tech-specs/pt/openapi-spec.pt.md diff --git a/docs/tech-specs/pubsub.pt.md b/docs/tech-specs/pt/pubsub.pt.md similarity index 100% rename from docs/tech-specs/pubsub.pt.md rename to docs/tech-specs/pt/pubsub.pt.md diff --git a/docs/tech-specs/python-api-refactor.pt.md b/docs/tech-specs/pt/python-api-refactor.pt.md similarity index 100% rename from docs/tech-specs/python-api-refactor.pt.md rename to docs/tech-specs/pt/python-api-refactor.pt.md diff --git a/docs/tech-specs/query-time-explainability.pt.md b/docs/tech-specs/pt/query-time-explainability.pt.md similarity index 100% rename from docs/tech-specs/query-time-explainability.pt.md rename to docs/tech-specs/pt/query-time-explainability.pt.md diff --git a/docs/tech-specs/rag-streaming-support.pt.md b/docs/tech-specs/pt/rag-streaming-support.pt.md similarity index 100% rename from docs/tech-specs/rag-streaming-support.pt.md rename to docs/tech-specs/pt/rag-streaming-support.pt.md diff --git a/docs/tech-specs/schema-refactoring-proposal.pt.md b/docs/tech-specs/pt/schema-refactoring-proposal.pt.md similarity index 100% rename from docs/tech-specs/schema-refactoring-proposal.pt.md rename to docs/tech-specs/pt/schema-refactoring-proposal.pt.md diff --git a/docs/tech-specs/streaming-llm-responses.pt.md b/docs/tech-specs/pt/streaming-llm-responses.pt.md similarity index 100% rename from docs/tech-specs/streaming-llm-responses.pt.md rename to docs/tech-specs/pt/streaming-llm-responses.pt.md diff --git a/docs/tech-specs/structured-data-2.pt.md b/docs/tech-specs/pt/structured-data-2.pt.md similarity index 100% rename from docs/tech-specs/structured-data-2.pt.md rename to docs/tech-specs/pt/structured-data-2.pt.md diff --git a/docs/tech-specs/structured-data-descriptor.pt.md b/docs/tech-specs/pt/structured-data-descriptor.pt.md similarity index 100% rename from docs/tech-specs/structured-data-descriptor.pt.md rename to docs/tech-specs/pt/structured-data-descriptor.pt.md diff --git a/docs/tech-specs/structured-data-schemas.pt.md b/docs/tech-specs/pt/structured-data-schemas.pt.md similarity index 100% rename from docs/tech-specs/structured-data-schemas.pt.md rename to docs/tech-specs/pt/structured-data-schemas.pt.md diff --git a/docs/tech-specs/structured-data.pt.md b/docs/tech-specs/pt/structured-data.pt.md similarity index 100% rename from docs/tech-specs/structured-data.pt.md rename to docs/tech-specs/pt/structured-data.pt.md diff --git a/docs/tech-specs/structured-diag-service.pt.md b/docs/tech-specs/pt/structured-diag-service.pt.md similarity index 100% rename from docs/tech-specs/structured-diag-service.pt.md rename to docs/tech-specs/pt/structured-diag-service.pt.md diff --git a/docs/tech-specs/tool-group.pt.md b/docs/tech-specs/pt/tool-group.pt.md similarity index 100% rename from docs/tech-specs/tool-group.pt.md rename to docs/tech-specs/pt/tool-group.pt.md diff --git a/docs/tech-specs/tool-services.pt.md b/docs/tech-specs/pt/tool-services.pt.md similarity index 100% rename from docs/tech-specs/tool-services.pt.md rename to docs/tech-specs/pt/tool-services.pt.md diff --git a/docs/tech-specs/universal-decoder.pt.md b/docs/tech-specs/pt/universal-decoder.pt.md similarity index 100% rename from docs/tech-specs/universal-decoder.pt.md rename to docs/tech-specs/pt/universal-decoder.pt.md diff --git a/docs/tech-specs/vector-store-lifecycle.pt.md b/docs/tech-specs/pt/vector-store-lifecycle.pt.md similarity index 100% rename from docs/tech-specs/vector-store-lifecycle.pt.md rename to docs/tech-specs/pt/vector-store-lifecycle.pt.md diff --git a/docs/tech-specs/__TEMPLATE.ru.md b/docs/tech-specs/ru/__TEMPLATE.ru.md similarity index 100% rename from docs/tech-specs/__TEMPLATE.ru.md rename to docs/tech-specs/ru/__TEMPLATE.ru.md diff --git a/docs/tech-specs/agent-explainability.ru.md b/docs/tech-specs/ru/agent-explainability.ru.md similarity index 100% rename from docs/tech-specs/agent-explainability.ru.md rename to docs/tech-specs/ru/agent-explainability.ru.md diff --git a/docs/tech-specs/architecture-principles.ru.md b/docs/tech-specs/ru/architecture-principles.ru.md similarity index 100% rename from docs/tech-specs/architecture-principles.ru.md rename to docs/tech-specs/ru/architecture-principles.ru.md diff --git a/docs/tech-specs/cassandra-consolidation.ru.md b/docs/tech-specs/ru/cassandra-consolidation.ru.md similarity index 100% rename from docs/tech-specs/cassandra-consolidation.ru.md rename to docs/tech-specs/ru/cassandra-consolidation.ru.md diff --git a/docs/tech-specs/cassandra-performance-refactor.ru.md b/docs/tech-specs/ru/cassandra-performance-refactor.ru.md similarity index 100% rename from docs/tech-specs/cassandra-performance-refactor.ru.md rename to docs/tech-specs/ru/cassandra-performance-refactor.ru.md diff --git a/docs/tech-specs/collection-management.ru.md b/docs/tech-specs/ru/collection-management.ru.md similarity index 100% rename from docs/tech-specs/collection-management.ru.md rename to docs/tech-specs/ru/collection-management.ru.md diff --git a/docs/tech-specs/document-embeddings-chunk-id.ru.md b/docs/tech-specs/ru/document-embeddings-chunk-id.ru.md similarity index 100% rename from docs/tech-specs/document-embeddings-chunk-id.ru.md rename to docs/tech-specs/ru/document-embeddings-chunk-id.ru.md diff --git a/docs/tech-specs/embeddings-batch-processing.ru.md b/docs/tech-specs/ru/embeddings-batch-processing.ru.md similarity index 100% rename from docs/tech-specs/embeddings-batch-processing.ru.md rename to docs/tech-specs/ru/embeddings-batch-processing.ru.md diff --git a/docs/tech-specs/entity-centric-graph.ru.md b/docs/tech-specs/ru/entity-centric-graph.ru.md similarity index 100% rename from docs/tech-specs/entity-centric-graph.ru.md rename to docs/tech-specs/ru/entity-centric-graph.ru.md diff --git a/docs/tech-specs/explainability-cli.ru.md b/docs/tech-specs/ru/explainability-cli.ru.md similarity index 100% rename from docs/tech-specs/explainability-cli.ru.md rename to docs/tech-specs/ru/explainability-cli.ru.md diff --git a/docs/tech-specs/extraction-flows.ru.md b/docs/tech-specs/ru/extraction-flows.ru.md similarity index 100% rename from docs/tech-specs/extraction-flows.ru.md rename to docs/tech-specs/ru/extraction-flows.ru.md diff --git a/docs/tech-specs/extraction-provenance-subgraph.ru.md b/docs/tech-specs/ru/extraction-provenance-subgraph.ru.md similarity index 100% rename from docs/tech-specs/extraction-provenance-subgraph.ru.md rename to docs/tech-specs/ru/extraction-provenance-subgraph.ru.md diff --git a/docs/tech-specs/extraction-time-provenance.ru.md b/docs/tech-specs/ru/extraction-time-provenance.ru.md similarity index 100% rename from docs/tech-specs/extraction-time-provenance.ru.md rename to docs/tech-specs/ru/extraction-time-provenance.ru.md diff --git a/docs/tech-specs/flow-class-definition.ru.md b/docs/tech-specs/ru/flow-class-definition.ru.md similarity index 100% rename from docs/tech-specs/flow-class-definition.ru.md rename to docs/tech-specs/ru/flow-class-definition.ru.md diff --git a/docs/tech-specs/flow-configurable-parameters.ru.md b/docs/tech-specs/ru/flow-configurable-parameters.ru.md similarity index 100% rename from docs/tech-specs/flow-configurable-parameters.ru.md rename to docs/tech-specs/ru/flow-configurable-parameters.ru.md diff --git a/docs/tech-specs/graph-contexts.ru.md b/docs/tech-specs/ru/graph-contexts.ru.md similarity index 100% rename from docs/tech-specs/graph-contexts.ru.md rename to docs/tech-specs/ru/graph-contexts.ru.md diff --git a/docs/tech-specs/graphql-query.ru.md b/docs/tech-specs/ru/graphql-query.ru.md similarity index 100% rename from docs/tech-specs/graphql-query.ru.md rename to docs/tech-specs/ru/graphql-query.ru.md diff --git a/docs/tech-specs/graphrag-performance-optimization.ru.md b/docs/tech-specs/ru/graphrag-performance-optimization.ru.md similarity index 100% rename from docs/tech-specs/graphrag-performance-optimization.ru.md rename to docs/tech-specs/ru/graphrag-performance-optimization.ru.md diff --git a/docs/tech-specs/import-export-graceful-shutdown.ru.md b/docs/tech-specs/ru/import-export-graceful-shutdown.ru.md similarity index 100% rename from docs/tech-specs/import-export-graceful-shutdown.ru.md rename to docs/tech-specs/ru/import-export-graceful-shutdown.ru.md diff --git a/docs/tech-specs/jsonl-prompt-output.ru.md b/docs/tech-specs/ru/jsonl-prompt-output.ru.md similarity index 100% rename from docs/tech-specs/jsonl-prompt-output.ru.md rename to docs/tech-specs/ru/jsonl-prompt-output.ru.md diff --git a/docs/tech-specs/large-document-loading.ru.md b/docs/tech-specs/ru/large-document-loading.ru.md similarity index 100% rename from docs/tech-specs/large-document-loading.ru.md rename to docs/tech-specs/ru/large-document-loading.ru.md diff --git a/docs/tech-specs/logging-strategy.ru.md b/docs/tech-specs/ru/logging-strategy.ru.md similarity index 100% rename from docs/tech-specs/logging-strategy.ru.md rename to docs/tech-specs/ru/logging-strategy.ru.md diff --git a/docs/tech-specs/mcp-tool-arguments.ru.md b/docs/tech-specs/ru/mcp-tool-arguments.ru.md similarity index 100% rename from docs/tech-specs/mcp-tool-arguments.ru.md rename to docs/tech-specs/ru/mcp-tool-arguments.ru.md diff --git a/docs/tech-specs/mcp-tool-bearer-token.ru.md b/docs/tech-specs/ru/mcp-tool-bearer-token.ru.md similarity index 100% rename from docs/tech-specs/mcp-tool-bearer-token.ru.md rename to docs/tech-specs/ru/mcp-tool-bearer-token.ru.md diff --git a/docs/tech-specs/minio-to-s3-migration.ru.md b/docs/tech-specs/ru/minio-to-s3-migration.ru.md similarity index 100% rename from docs/tech-specs/minio-to-s3-migration.ru.md rename to docs/tech-specs/ru/minio-to-s3-migration.ru.md diff --git a/docs/tech-specs/more-config-cli.ru.md b/docs/tech-specs/ru/more-config-cli.ru.md similarity index 100% rename from docs/tech-specs/more-config-cli.ru.md rename to docs/tech-specs/ru/more-config-cli.ru.md diff --git a/docs/tech-specs/multi-tenant-support.ru.md b/docs/tech-specs/ru/multi-tenant-support.ru.md similarity index 100% rename from docs/tech-specs/multi-tenant-support.ru.md rename to docs/tech-specs/ru/multi-tenant-support.ru.md diff --git a/docs/tech-specs/neo4j-user-collection-isolation.ru.md b/docs/tech-specs/ru/neo4j-user-collection-isolation.ru.md similarity index 100% rename from docs/tech-specs/neo4j-user-collection-isolation.ru.md rename to docs/tech-specs/ru/neo4j-user-collection-isolation.ru.md diff --git a/docs/tech-specs/ontology-extract-phase-2.ru.md b/docs/tech-specs/ru/ontology-extract-phase-2.ru.md similarity index 100% rename from docs/tech-specs/ontology-extract-phase-2.ru.md rename to docs/tech-specs/ru/ontology-extract-phase-2.ru.md diff --git a/docs/tech-specs/ontology.ru.md b/docs/tech-specs/ru/ontology.ru.md similarity index 100% rename from docs/tech-specs/ontology.ru.md rename to docs/tech-specs/ru/ontology.ru.md diff --git a/docs/tech-specs/ontorag.ru.md b/docs/tech-specs/ru/ontorag.ru.md similarity index 100% rename from docs/tech-specs/ontorag.ru.md rename to docs/tech-specs/ru/ontorag.ru.md diff --git a/docs/tech-specs/openapi-spec.ru.md b/docs/tech-specs/ru/openapi-spec.ru.md similarity index 100% rename from docs/tech-specs/openapi-spec.ru.md rename to docs/tech-specs/ru/openapi-spec.ru.md diff --git a/docs/tech-specs/pubsub.ru.md b/docs/tech-specs/ru/pubsub.ru.md similarity index 100% rename from docs/tech-specs/pubsub.ru.md rename to docs/tech-specs/ru/pubsub.ru.md diff --git a/docs/tech-specs/python-api-refactor.ru.md b/docs/tech-specs/ru/python-api-refactor.ru.md similarity index 100% rename from docs/tech-specs/python-api-refactor.ru.md rename to docs/tech-specs/ru/python-api-refactor.ru.md diff --git a/docs/tech-specs/query-time-explainability.ru.md b/docs/tech-specs/ru/query-time-explainability.ru.md similarity index 100% rename from docs/tech-specs/query-time-explainability.ru.md rename to docs/tech-specs/ru/query-time-explainability.ru.md diff --git a/docs/tech-specs/rag-streaming-support.ru.md b/docs/tech-specs/ru/rag-streaming-support.ru.md similarity index 100% rename from docs/tech-specs/rag-streaming-support.ru.md rename to docs/tech-specs/ru/rag-streaming-support.ru.md diff --git a/docs/tech-specs/schema-refactoring-proposal.ru.md b/docs/tech-specs/ru/schema-refactoring-proposal.ru.md similarity index 100% rename from docs/tech-specs/schema-refactoring-proposal.ru.md rename to docs/tech-specs/ru/schema-refactoring-proposal.ru.md diff --git a/docs/tech-specs/streaming-llm-responses.ru.md b/docs/tech-specs/ru/streaming-llm-responses.ru.md similarity index 100% rename from docs/tech-specs/streaming-llm-responses.ru.md rename to docs/tech-specs/ru/streaming-llm-responses.ru.md diff --git a/docs/tech-specs/structured-data-2.ru.md b/docs/tech-specs/ru/structured-data-2.ru.md similarity index 100% rename from docs/tech-specs/structured-data-2.ru.md rename to docs/tech-specs/ru/structured-data-2.ru.md diff --git a/docs/tech-specs/structured-data-descriptor.ru.md b/docs/tech-specs/ru/structured-data-descriptor.ru.md similarity index 100% rename from docs/tech-specs/structured-data-descriptor.ru.md rename to docs/tech-specs/ru/structured-data-descriptor.ru.md diff --git a/docs/tech-specs/structured-data-schemas.ru.md b/docs/tech-specs/ru/structured-data-schemas.ru.md similarity index 100% rename from docs/tech-specs/structured-data-schemas.ru.md rename to docs/tech-specs/ru/structured-data-schemas.ru.md diff --git a/docs/tech-specs/structured-data.ru.md b/docs/tech-specs/ru/structured-data.ru.md similarity index 100% rename from docs/tech-specs/structured-data.ru.md rename to docs/tech-specs/ru/structured-data.ru.md diff --git a/docs/tech-specs/structured-diag-service.ru.md b/docs/tech-specs/ru/structured-diag-service.ru.md similarity index 100% rename from docs/tech-specs/structured-diag-service.ru.md rename to docs/tech-specs/ru/structured-diag-service.ru.md diff --git a/docs/tech-specs/tool-group.ru.md b/docs/tech-specs/ru/tool-group.ru.md similarity index 100% rename from docs/tech-specs/tool-group.ru.md rename to docs/tech-specs/ru/tool-group.ru.md diff --git a/docs/tech-specs/tool-services.ru.md b/docs/tech-specs/ru/tool-services.ru.md similarity index 100% rename from docs/tech-specs/tool-services.ru.md rename to docs/tech-specs/ru/tool-services.ru.md diff --git a/docs/tech-specs/universal-decoder.ru.md b/docs/tech-specs/ru/universal-decoder.ru.md similarity index 100% rename from docs/tech-specs/universal-decoder.ru.md rename to docs/tech-specs/ru/universal-decoder.ru.md diff --git a/docs/tech-specs/vector-store-lifecycle.ru.md b/docs/tech-specs/ru/vector-store-lifecycle.ru.md similarity index 100% rename from docs/tech-specs/vector-store-lifecycle.ru.md rename to docs/tech-specs/ru/vector-store-lifecycle.ru.md diff --git a/docs/tech-specs/__TEMPLATE.sw.md b/docs/tech-specs/sw/__TEMPLATE.sw.md similarity index 100% rename from docs/tech-specs/__TEMPLATE.sw.md rename to docs/tech-specs/sw/__TEMPLATE.sw.md diff --git a/docs/tech-specs/agent-explainability.sw.md b/docs/tech-specs/sw/agent-explainability.sw.md similarity index 100% rename from docs/tech-specs/agent-explainability.sw.md rename to docs/tech-specs/sw/agent-explainability.sw.md diff --git a/docs/tech-specs/architecture-principles.sw.md b/docs/tech-specs/sw/architecture-principles.sw.md similarity index 100% rename from docs/tech-specs/architecture-principles.sw.md rename to docs/tech-specs/sw/architecture-principles.sw.md diff --git a/docs/tech-specs/cassandra-consolidation.sw.md b/docs/tech-specs/sw/cassandra-consolidation.sw.md similarity index 100% rename from docs/tech-specs/cassandra-consolidation.sw.md rename to docs/tech-specs/sw/cassandra-consolidation.sw.md diff --git a/docs/tech-specs/cassandra-performance-refactor.sw.md b/docs/tech-specs/sw/cassandra-performance-refactor.sw.md similarity index 100% rename from docs/tech-specs/cassandra-performance-refactor.sw.md rename to docs/tech-specs/sw/cassandra-performance-refactor.sw.md diff --git a/docs/tech-specs/collection-management.sw.md b/docs/tech-specs/sw/collection-management.sw.md similarity index 100% rename from docs/tech-specs/collection-management.sw.md rename to docs/tech-specs/sw/collection-management.sw.md diff --git a/docs/tech-specs/document-embeddings-chunk-id.sw.md b/docs/tech-specs/sw/document-embeddings-chunk-id.sw.md similarity index 100% rename from docs/tech-specs/document-embeddings-chunk-id.sw.md rename to docs/tech-specs/sw/document-embeddings-chunk-id.sw.md diff --git a/docs/tech-specs/embeddings-batch-processing.sw.md b/docs/tech-specs/sw/embeddings-batch-processing.sw.md similarity index 100% rename from docs/tech-specs/embeddings-batch-processing.sw.md rename to docs/tech-specs/sw/embeddings-batch-processing.sw.md diff --git a/docs/tech-specs/entity-centric-graph.sw.md b/docs/tech-specs/sw/entity-centric-graph.sw.md similarity index 100% rename from docs/tech-specs/entity-centric-graph.sw.md rename to docs/tech-specs/sw/entity-centric-graph.sw.md diff --git a/docs/tech-specs/explainability-cli.sw.md b/docs/tech-specs/sw/explainability-cli.sw.md similarity index 100% rename from docs/tech-specs/explainability-cli.sw.md rename to docs/tech-specs/sw/explainability-cli.sw.md diff --git a/docs/tech-specs/extraction-flows.sw.md b/docs/tech-specs/sw/extraction-flows.sw.md similarity index 100% rename from docs/tech-specs/extraction-flows.sw.md rename to docs/tech-specs/sw/extraction-flows.sw.md diff --git a/docs/tech-specs/extraction-provenance-subgraph.sw.md b/docs/tech-specs/sw/extraction-provenance-subgraph.sw.md similarity index 100% rename from docs/tech-specs/extraction-provenance-subgraph.sw.md rename to docs/tech-specs/sw/extraction-provenance-subgraph.sw.md diff --git a/docs/tech-specs/extraction-time-provenance.sw.md b/docs/tech-specs/sw/extraction-time-provenance.sw.md similarity index 100% rename from docs/tech-specs/extraction-time-provenance.sw.md rename to docs/tech-specs/sw/extraction-time-provenance.sw.md diff --git a/docs/tech-specs/flow-class-definition.sw.md b/docs/tech-specs/sw/flow-class-definition.sw.md similarity index 100% rename from docs/tech-specs/flow-class-definition.sw.md rename to docs/tech-specs/sw/flow-class-definition.sw.md diff --git a/docs/tech-specs/flow-configurable-parameters.sw.md b/docs/tech-specs/sw/flow-configurable-parameters.sw.md similarity index 100% rename from docs/tech-specs/flow-configurable-parameters.sw.md rename to docs/tech-specs/sw/flow-configurable-parameters.sw.md diff --git a/docs/tech-specs/graph-contexts.sw.md b/docs/tech-specs/sw/graph-contexts.sw.md similarity index 100% rename from docs/tech-specs/graph-contexts.sw.md rename to docs/tech-specs/sw/graph-contexts.sw.md diff --git a/docs/tech-specs/graphql-query.sw.md b/docs/tech-specs/sw/graphql-query.sw.md similarity index 100% rename from docs/tech-specs/graphql-query.sw.md rename to docs/tech-specs/sw/graphql-query.sw.md diff --git a/docs/tech-specs/graphrag-performance-optimization.sw.md b/docs/tech-specs/sw/graphrag-performance-optimization.sw.md similarity index 100% rename from docs/tech-specs/graphrag-performance-optimization.sw.md rename to docs/tech-specs/sw/graphrag-performance-optimization.sw.md diff --git a/docs/tech-specs/import-export-graceful-shutdown.sw.md b/docs/tech-specs/sw/import-export-graceful-shutdown.sw.md similarity index 100% rename from docs/tech-specs/import-export-graceful-shutdown.sw.md rename to docs/tech-specs/sw/import-export-graceful-shutdown.sw.md diff --git a/docs/tech-specs/jsonl-prompt-output.sw.md b/docs/tech-specs/sw/jsonl-prompt-output.sw.md similarity index 100% rename from docs/tech-specs/jsonl-prompt-output.sw.md rename to docs/tech-specs/sw/jsonl-prompt-output.sw.md diff --git a/docs/tech-specs/large-document-loading.sw.md b/docs/tech-specs/sw/large-document-loading.sw.md similarity index 100% rename from docs/tech-specs/large-document-loading.sw.md rename to docs/tech-specs/sw/large-document-loading.sw.md diff --git a/docs/tech-specs/logging-strategy.sw.md b/docs/tech-specs/sw/logging-strategy.sw.md similarity index 100% rename from docs/tech-specs/logging-strategy.sw.md rename to docs/tech-specs/sw/logging-strategy.sw.md diff --git a/docs/tech-specs/mcp-tool-arguments.sw.md b/docs/tech-specs/sw/mcp-tool-arguments.sw.md similarity index 100% rename from docs/tech-specs/mcp-tool-arguments.sw.md rename to docs/tech-specs/sw/mcp-tool-arguments.sw.md diff --git a/docs/tech-specs/mcp-tool-bearer-token.sw.md b/docs/tech-specs/sw/mcp-tool-bearer-token.sw.md similarity index 100% rename from docs/tech-specs/mcp-tool-bearer-token.sw.md rename to docs/tech-specs/sw/mcp-tool-bearer-token.sw.md diff --git a/docs/tech-specs/minio-to-s3-migration.sw.md b/docs/tech-specs/sw/minio-to-s3-migration.sw.md similarity index 100% rename from docs/tech-specs/minio-to-s3-migration.sw.md rename to docs/tech-specs/sw/minio-to-s3-migration.sw.md diff --git a/docs/tech-specs/more-config-cli.sw.md b/docs/tech-specs/sw/more-config-cli.sw.md similarity index 100% rename from docs/tech-specs/more-config-cli.sw.md rename to docs/tech-specs/sw/more-config-cli.sw.md diff --git a/docs/tech-specs/multi-tenant-support.sw.md b/docs/tech-specs/sw/multi-tenant-support.sw.md similarity index 100% rename from docs/tech-specs/multi-tenant-support.sw.md rename to docs/tech-specs/sw/multi-tenant-support.sw.md diff --git a/docs/tech-specs/neo4j-user-collection-isolation.sw.md b/docs/tech-specs/sw/neo4j-user-collection-isolation.sw.md similarity index 100% rename from docs/tech-specs/neo4j-user-collection-isolation.sw.md rename to docs/tech-specs/sw/neo4j-user-collection-isolation.sw.md diff --git a/docs/tech-specs/ontology-extract-phase-2.sw.md b/docs/tech-specs/sw/ontology-extract-phase-2.sw.md similarity index 100% rename from docs/tech-specs/ontology-extract-phase-2.sw.md rename to docs/tech-specs/sw/ontology-extract-phase-2.sw.md diff --git a/docs/tech-specs/ontology.sw.md b/docs/tech-specs/sw/ontology.sw.md similarity index 100% rename from docs/tech-specs/ontology.sw.md rename to docs/tech-specs/sw/ontology.sw.md diff --git a/docs/tech-specs/ontorag.sw.md b/docs/tech-specs/sw/ontorag.sw.md similarity index 100% rename from docs/tech-specs/ontorag.sw.md rename to docs/tech-specs/sw/ontorag.sw.md diff --git a/docs/tech-specs/openapi-spec.sw.md b/docs/tech-specs/sw/openapi-spec.sw.md similarity index 100% rename from docs/tech-specs/openapi-spec.sw.md rename to docs/tech-specs/sw/openapi-spec.sw.md diff --git a/docs/tech-specs/pubsub.sw.md b/docs/tech-specs/sw/pubsub.sw.md similarity index 100% rename from docs/tech-specs/pubsub.sw.md rename to docs/tech-specs/sw/pubsub.sw.md diff --git a/docs/tech-specs/python-api-refactor.sw.md b/docs/tech-specs/sw/python-api-refactor.sw.md similarity index 100% rename from docs/tech-specs/python-api-refactor.sw.md rename to docs/tech-specs/sw/python-api-refactor.sw.md diff --git a/docs/tech-specs/query-time-explainability.sw.md b/docs/tech-specs/sw/query-time-explainability.sw.md similarity index 100% rename from docs/tech-specs/query-time-explainability.sw.md rename to docs/tech-specs/sw/query-time-explainability.sw.md diff --git a/docs/tech-specs/rag-streaming-support.sw.md b/docs/tech-specs/sw/rag-streaming-support.sw.md similarity index 100% rename from docs/tech-specs/rag-streaming-support.sw.md rename to docs/tech-specs/sw/rag-streaming-support.sw.md diff --git a/docs/tech-specs/schema-refactoring-proposal.sw.md b/docs/tech-specs/sw/schema-refactoring-proposal.sw.md similarity index 100% rename from docs/tech-specs/schema-refactoring-proposal.sw.md rename to docs/tech-specs/sw/schema-refactoring-proposal.sw.md diff --git a/docs/tech-specs/streaming-llm-responses.sw.md b/docs/tech-specs/sw/streaming-llm-responses.sw.md similarity index 100% rename from docs/tech-specs/streaming-llm-responses.sw.md rename to docs/tech-specs/sw/streaming-llm-responses.sw.md diff --git a/docs/tech-specs/structured-data-2.sw.md b/docs/tech-specs/sw/structured-data-2.sw.md similarity index 100% rename from docs/tech-specs/structured-data-2.sw.md rename to docs/tech-specs/sw/structured-data-2.sw.md diff --git a/docs/tech-specs/structured-data-descriptor.sw.md b/docs/tech-specs/sw/structured-data-descriptor.sw.md similarity index 100% rename from docs/tech-specs/structured-data-descriptor.sw.md rename to docs/tech-specs/sw/structured-data-descriptor.sw.md diff --git a/docs/tech-specs/structured-data-schemas.sw.md b/docs/tech-specs/sw/structured-data-schemas.sw.md similarity index 100% rename from docs/tech-specs/structured-data-schemas.sw.md rename to docs/tech-specs/sw/structured-data-schemas.sw.md diff --git a/docs/tech-specs/structured-data.sw.md b/docs/tech-specs/sw/structured-data.sw.md similarity index 100% rename from docs/tech-specs/structured-data.sw.md rename to docs/tech-specs/sw/structured-data.sw.md diff --git a/docs/tech-specs/structured-diag-service.sw.md b/docs/tech-specs/sw/structured-diag-service.sw.md similarity index 100% rename from docs/tech-specs/structured-diag-service.sw.md rename to docs/tech-specs/sw/structured-diag-service.sw.md diff --git a/docs/tech-specs/tool-group.sw.md b/docs/tech-specs/sw/tool-group.sw.md similarity index 100% rename from docs/tech-specs/tool-group.sw.md rename to docs/tech-specs/sw/tool-group.sw.md diff --git a/docs/tech-specs/tool-services.sw.md b/docs/tech-specs/sw/tool-services.sw.md similarity index 100% rename from docs/tech-specs/tool-services.sw.md rename to docs/tech-specs/sw/tool-services.sw.md diff --git a/docs/tech-specs/universal-decoder.sw.md b/docs/tech-specs/sw/universal-decoder.sw.md similarity index 100% rename from docs/tech-specs/universal-decoder.sw.md rename to docs/tech-specs/sw/universal-decoder.sw.md diff --git a/docs/tech-specs/vector-store-lifecycle.sw.md b/docs/tech-specs/sw/vector-store-lifecycle.sw.md similarity index 100% rename from docs/tech-specs/vector-store-lifecycle.sw.md rename to docs/tech-specs/sw/vector-store-lifecycle.sw.md diff --git a/docs/tech-specs/__TEMPLATE.tr.md b/docs/tech-specs/tr/__TEMPLATE.tr.md similarity index 100% rename from docs/tech-specs/__TEMPLATE.tr.md rename to docs/tech-specs/tr/__TEMPLATE.tr.md diff --git a/docs/tech-specs/agent-explainability.tr.md b/docs/tech-specs/tr/agent-explainability.tr.md similarity index 100% rename from docs/tech-specs/agent-explainability.tr.md rename to docs/tech-specs/tr/agent-explainability.tr.md diff --git a/docs/tech-specs/architecture-principles.tr.md b/docs/tech-specs/tr/architecture-principles.tr.md similarity index 100% rename from docs/tech-specs/architecture-principles.tr.md rename to docs/tech-specs/tr/architecture-principles.tr.md diff --git a/docs/tech-specs/cassandra-consolidation.tr.md b/docs/tech-specs/tr/cassandra-consolidation.tr.md similarity index 100% rename from docs/tech-specs/cassandra-consolidation.tr.md rename to docs/tech-specs/tr/cassandra-consolidation.tr.md diff --git a/docs/tech-specs/cassandra-performance-refactor.tr.md b/docs/tech-specs/tr/cassandra-performance-refactor.tr.md similarity index 100% rename from docs/tech-specs/cassandra-performance-refactor.tr.md rename to docs/tech-specs/tr/cassandra-performance-refactor.tr.md diff --git a/docs/tech-specs/collection-management.tr.md b/docs/tech-specs/tr/collection-management.tr.md similarity index 100% rename from docs/tech-specs/collection-management.tr.md rename to docs/tech-specs/tr/collection-management.tr.md diff --git a/docs/tech-specs/document-embeddings-chunk-id.tr.md b/docs/tech-specs/tr/document-embeddings-chunk-id.tr.md similarity index 100% rename from docs/tech-specs/document-embeddings-chunk-id.tr.md rename to docs/tech-specs/tr/document-embeddings-chunk-id.tr.md diff --git a/docs/tech-specs/embeddings-batch-processing.tr.md b/docs/tech-specs/tr/embeddings-batch-processing.tr.md similarity index 100% rename from docs/tech-specs/embeddings-batch-processing.tr.md rename to docs/tech-specs/tr/embeddings-batch-processing.tr.md diff --git a/docs/tech-specs/entity-centric-graph.tr.md b/docs/tech-specs/tr/entity-centric-graph.tr.md similarity index 100% rename from docs/tech-specs/entity-centric-graph.tr.md rename to docs/tech-specs/tr/entity-centric-graph.tr.md diff --git a/docs/tech-specs/explainability-cli.tr.md b/docs/tech-specs/tr/explainability-cli.tr.md similarity index 100% rename from docs/tech-specs/explainability-cli.tr.md rename to docs/tech-specs/tr/explainability-cli.tr.md diff --git a/docs/tech-specs/extraction-flows.tr.md b/docs/tech-specs/tr/extraction-flows.tr.md similarity index 100% rename from docs/tech-specs/extraction-flows.tr.md rename to docs/tech-specs/tr/extraction-flows.tr.md diff --git a/docs/tech-specs/extraction-provenance-subgraph.tr.md b/docs/tech-specs/tr/extraction-provenance-subgraph.tr.md similarity index 100% rename from docs/tech-specs/extraction-provenance-subgraph.tr.md rename to docs/tech-specs/tr/extraction-provenance-subgraph.tr.md diff --git a/docs/tech-specs/extraction-time-provenance.tr.md b/docs/tech-specs/tr/extraction-time-provenance.tr.md similarity index 100% rename from docs/tech-specs/extraction-time-provenance.tr.md rename to docs/tech-specs/tr/extraction-time-provenance.tr.md diff --git a/docs/tech-specs/flow-class-definition.tr.md b/docs/tech-specs/tr/flow-class-definition.tr.md similarity index 100% rename from docs/tech-specs/flow-class-definition.tr.md rename to docs/tech-specs/tr/flow-class-definition.tr.md diff --git a/docs/tech-specs/flow-configurable-parameters.tr.md b/docs/tech-specs/tr/flow-configurable-parameters.tr.md similarity index 100% rename from docs/tech-specs/flow-configurable-parameters.tr.md rename to docs/tech-specs/tr/flow-configurable-parameters.tr.md diff --git a/docs/tech-specs/graph-contexts.tr.md b/docs/tech-specs/tr/graph-contexts.tr.md similarity index 100% rename from docs/tech-specs/graph-contexts.tr.md rename to docs/tech-specs/tr/graph-contexts.tr.md diff --git a/docs/tech-specs/graphql-query.tr.md b/docs/tech-specs/tr/graphql-query.tr.md similarity index 100% rename from docs/tech-specs/graphql-query.tr.md rename to docs/tech-specs/tr/graphql-query.tr.md diff --git a/docs/tech-specs/graphrag-performance-optimization.tr.md b/docs/tech-specs/tr/graphrag-performance-optimization.tr.md similarity index 100% rename from docs/tech-specs/graphrag-performance-optimization.tr.md rename to docs/tech-specs/tr/graphrag-performance-optimization.tr.md diff --git a/docs/tech-specs/import-export-graceful-shutdown.tr.md b/docs/tech-specs/tr/import-export-graceful-shutdown.tr.md similarity index 100% rename from docs/tech-specs/import-export-graceful-shutdown.tr.md rename to docs/tech-specs/tr/import-export-graceful-shutdown.tr.md diff --git a/docs/tech-specs/jsonl-prompt-output.tr.md b/docs/tech-specs/tr/jsonl-prompt-output.tr.md similarity index 100% rename from docs/tech-specs/jsonl-prompt-output.tr.md rename to docs/tech-specs/tr/jsonl-prompt-output.tr.md diff --git a/docs/tech-specs/large-document-loading.tr.md b/docs/tech-specs/tr/large-document-loading.tr.md similarity index 100% rename from docs/tech-specs/large-document-loading.tr.md rename to docs/tech-specs/tr/large-document-loading.tr.md diff --git a/docs/tech-specs/logging-strategy.tr.md b/docs/tech-specs/tr/logging-strategy.tr.md similarity index 100% rename from docs/tech-specs/logging-strategy.tr.md rename to docs/tech-specs/tr/logging-strategy.tr.md diff --git a/docs/tech-specs/mcp-tool-arguments.tr.md b/docs/tech-specs/tr/mcp-tool-arguments.tr.md similarity index 100% rename from docs/tech-specs/mcp-tool-arguments.tr.md rename to docs/tech-specs/tr/mcp-tool-arguments.tr.md diff --git a/docs/tech-specs/mcp-tool-bearer-token.tr.md b/docs/tech-specs/tr/mcp-tool-bearer-token.tr.md similarity index 100% rename from docs/tech-specs/mcp-tool-bearer-token.tr.md rename to docs/tech-specs/tr/mcp-tool-bearer-token.tr.md diff --git a/docs/tech-specs/minio-to-s3-migration.tr.md b/docs/tech-specs/tr/minio-to-s3-migration.tr.md similarity index 100% rename from docs/tech-specs/minio-to-s3-migration.tr.md rename to docs/tech-specs/tr/minio-to-s3-migration.tr.md diff --git a/docs/tech-specs/more-config-cli.tr.md b/docs/tech-specs/tr/more-config-cli.tr.md similarity index 100% rename from docs/tech-specs/more-config-cli.tr.md rename to docs/tech-specs/tr/more-config-cli.tr.md diff --git a/docs/tech-specs/multi-tenant-support.tr.md b/docs/tech-specs/tr/multi-tenant-support.tr.md similarity index 100% rename from docs/tech-specs/multi-tenant-support.tr.md rename to docs/tech-specs/tr/multi-tenant-support.tr.md diff --git a/docs/tech-specs/neo4j-user-collection-isolation.tr.md b/docs/tech-specs/tr/neo4j-user-collection-isolation.tr.md similarity index 100% rename from docs/tech-specs/neo4j-user-collection-isolation.tr.md rename to docs/tech-specs/tr/neo4j-user-collection-isolation.tr.md diff --git a/docs/tech-specs/ontology-extract-phase-2.tr.md b/docs/tech-specs/tr/ontology-extract-phase-2.tr.md similarity index 100% rename from docs/tech-specs/ontology-extract-phase-2.tr.md rename to docs/tech-specs/tr/ontology-extract-phase-2.tr.md diff --git a/docs/tech-specs/ontology.tr.md b/docs/tech-specs/tr/ontology.tr.md similarity index 100% rename from docs/tech-specs/ontology.tr.md rename to docs/tech-specs/tr/ontology.tr.md diff --git a/docs/tech-specs/ontorag.tr.md b/docs/tech-specs/tr/ontorag.tr.md similarity index 100% rename from docs/tech-specs/ontorag.tr.md rename to docs/tech-specs/tr/ontorag.tr.md diff --git a/docs/tech-specs/openapi-spec.tr.md b/docs/tech-specs/tr/openapi-spec.tr.md similarity index 100% rename from docs/tech-specs/openapi-spec.tr.md rename to docs/tech-specs/tr/openapi-spec.tr.md diff --git a/docs/tech-specs/pubsub.tr.md b/docs/tech-specs/tr/pubsub.tr.md similarity index 100% rename from docs/tech-specs/pubsub.tr.md rename to docs/tech-specs/tr/pubsub.tr.md diff --git a/docs/tech-specs/python-api-refactor.tr.md b/docs/tech-specs/tr/python-api-refactor.tr.md similarity index 100% rename from docs/tech-specs/python-api-refactor.tr.md rename to docs/tech-specs/tr/python-api-refactor.tr.md diff --git a/docs/tech-specs/query-time-explainability.tr.md b/docs/tech-specs/tr/query-time-explainability.tr.md similarity index 100% rename from docs/tech-specs/query-time-explainability.tr.md rename to docs/tech-specs/tr/query-time-explainability.tr.md diff --git a/docs/tech-specs/rag-streaming-support.tr.md b/docs/tech-specs/tr/rag-streaming-support.tr.md similarity index 100% rename from docs/tech-specs/rag-streaming-support.tr.md rename to docs/tech-specs/tr/rag-streaming-support.tr.md diff --git a/docs/tech-specs/schema-refactoring-proposal.tr.md b/docs/tech-specs/tr/schema-refactoring-proposal.tr.md similarity index 100% rename from docs/tech-specs/schema-refactoring-proposal.tr.md rename to docs/tech-specs/tr/schema-refactoring-proposal.tr.md diff --git a/docs/tech-specs/streaming-llm-responses.tr.md b/docs/tech-specs/tr/streaming-llm-responses.tr.md similarity index 100% rename from docs/tech-specs/streaming-llm-responses.tr.md rename to docs/tech-specs/tr/streaming-llm-responses.tr.md diff --git a/docs/tech-specs/structured-data-2.tr.md b/docs/tech-specs/tr/structured-data-2.tr.md similarity index 100% rename from docs/tech-specs/structured-data-2.tr.md rename to docs/tech-specs/tr/structured-data-2.tr.md diff --git a/docs/tech-specs/structured-data-descriptor.tr.md b/docs/tech-specs/tr/structured-data-descriptor.tr.md similarity index 100% rename from docs/tech-specs/structured-data-descriptor.tr.md rename to docs/tech-specs/tr/structured-data-descriptor.tr.md diff --git a/docs/tech-specs/structured-data-schemas.tr.md b/docs/tech-specs/tr/structured-data-schemas.tr.md similarity index 100% rename from docs/tech-specs/structured-data-schemas.tr.md rename to docs/tech-specs/tr/structured-data-schemas.tr.md diff --git a/docs/tech-specs/structured-data.tr.md b/docs/tech-specs/tr/structured-data.tr.md similarity index 100% rename from docs/tech-specs/structured-data.tr.md rename to docs/tech-specs/tr/structured-data.tr.md diff --git a/docs/tech-specs/structured-diag-service.tr.md b/docs/tech-specs/tr/structured-diag-service.tr.md similarity index 100% rename from docs/tech-specs/structured-diag-service.tr.md rename to docs/tech-specs/tr/structured-diag-service.tr.md diff --git a/docs/tech-specs/tool-group.tr.md b/docs/tech-specs/tr/tool-group.tr.md similarity index 100% rename from docs/tech-specs/tool-group.tr.md rename to docs/tech-specs/tr/tool-group.tr.md diff --git a/docs/tech-specs/tool-services.tr.md b/docs/tech-specs/tr/tool-services.tr.md similarity index 100% rename from docs/tech-specs/tool-services.tr.md rename to docs/tech-specs/tr/tool-services.tr.md diff --git a/docs/tech-specs/universal-decoder.tr.md b/docs/tech-specs/tr/universal-decoder.tr.md similarity index 100% rename from docs/tech-specs/universal-decoder.tr.md rename to docs/tech-specs/tr/universal-decoder.tr.md diff --git a/docs/tech-specs/vector-store-lifecycle.tr.md b/docs/tech-specs/tr/vector-store-lifecycle.tr.md similarity index 100% rename from docs/tech-specs/vector-store-lifecycle.tr.md rename to docs/tech-specs/tr/vector-store-lifecycle.tr.md diff --git a/docs/tech-specs/__TEMPLATE.zh-cn.md b/docs/tech-specs/zh-cn/__TEMPLATE.zh-cn.md similarity index 100% rename from docs/tech-specs/__TEMPLATE.zh-cn.md rename to docs/tech-specs/zh-cn/__TEMPLATE.zh-cn.md diff --git a/docs/tech-specs/agent-explainability.zh-cn.md b/docs/tech-specs/zh-cn/agent-explainability.zh-cn.md similarity index 100% rename from docs/tech-specs/agent-explainability.zh-cn.md rename to docs/tech-specs/zh-cn/agent-explainability.zh-cn.md diff --git a/docs/tech-specs/architecture-principles.zh-cn.md b/docs/tech-specs/zh-cn/architecture-principles.zh-cn.md similarity index 100% rename from docs/tech-specs/architecture-principles.zh-cn.md rename to docs/tech-specs/zh-cn/architecture-principles.zh-cn.md diff --git a/docs/tech-specs/cassandra-consolidation.zh-cn.md b/docs/tech-specs/zh-cn/cassandra-consolidation.zh-cn.md similarity index 100% rename from docs/tech-specs/cassandra-consolidation.zh-cn.md rename to docs/tech-specs/zh-cn/cassandra-consolidation.zh-cn.md diff --git a/docs/tech-specs/cassandra-performance-refactor.zh-cn.md b/docs/tech-specs/zh-cn/cassandra-performance-refactor.zh-cn.md similarity index 100% rename from docs/tech-specs/cassandra-performance-refactor.zh-cn.md rename to docs/tech-specs/zh-cn/cassandra-performance-refactor.zh-cn.md diff --git a/docs/tech-specs/collection-management.zh-cn.md b/docs/tech-specs/zh-cn/collection-management.zh-cn.md similarity index 100% rename from docs/tech-specs/collection-management.zh-cn.md rename to docs/tech-specs/zh-cn/collection-management.zh-cn.md diff --git a/docs/tech-specs/document-embeddings-chunk-id.zh-cn.md b/docs/tech-specs/zh-cn/document-embeddings-chunk-id.zh-cn.md similarity index 100% rename from docs/tech-specs/document-embeddings-chunk-id.zh-cn.md rename to docs/tech-specs/zh-cn/document-embeddings-chunk-id.zh-cn.md diff --git a/docs/tech-specs/embeddings-batch-processing.zh-cn.md b/docs/tech-specs/zh-cn/embeddings-batch-processing.zh-cn.md similarity index 100% rename from docs/tech-specs/embeddings-batch-processing.zh-cn.md rename to docs/tech-specs/zh-cn/embeddings-batch-processing.zh-cn.md diff --git a/docs/tech-specs/entity-centric-graph.zh-cn.md b/docs/tech-specs/zh-cn/entity-centric-graph.zh-cn.md similarity index 100% rename from docs/tech-specs/entity-centric-graph.zh-cn.md rename to docs/tech-specs/zh-cn/entity-centric-graph.zh-cn.md diff --git a/docs/tech-specs/explainability-cli.zh-cn.md b/docs/tech-specs/zh-cn/explainability-cli.zh-cn.md similarity index 100% rename from docs/tech-specs/explainability-cli.zh-cn.md rename to docs/tech-specs/zh-cn/explainability-cli.zh-cn.md diff --git a/docs/tech-specs/extraction-flows.zh-cn.md b/docs/tech-specs/zh-cn/extraction-flows.zh-cn.md similarity index 100% rename from docs/tech-specs/extraction-flows.zh-cn.md rename to docs/tech-specs/zh-cn/extraction-flows.zh-cn.md diff --git a/docs/tech-specs/extraction-provenance-subgraph.zh-cn.md b/docs/tech-specs/zh-cn/extraction-provenance-subgraph.zh-cn.md similarity index 100% rename from docs/tech-specs/extraction-provenance-subgraph.zh-cn.md rename to docs/tech-specs/zh-cn/extraction-provenance-subgraph.zh-cn.md diff --git a/docs/tech-specs/extraction-time-provenance.zh-cn.md b/docs/tech-specs/zh-cn/extraction-time-provenance.zh-cn.md similarity index 100% rename from docs/tech-specs/extraction-time-provenance.zh-cn.md rename to docs/tech-specs/zh-cn/extraction-time-provenance.zh-cn.md diff --git a/docs/tech-specs/flow-class-definition.zh-cn.md b/docs/tech-specs/zh-cn/flow-class-definition.zh-cn.md similarity index 100% rename from docs/tech-specs/flow-class-definition.zh-cn.md rename to docs/tech-specs/zh-cn/flow-class-definition.zh-cn.md diff --git a/docs/tech-specs/flow-configurable-parameters.zh-cn.md b/docs/tech-specs/zh-cn/flow-configurable-parameters.zh-cn.md similarity index 100% rename from docs/tech-specs/flow-configurable-parameters.zh-cn.md rename to docs/tech-specs/zh-cn/flow-configurable-parameters.zh-cn.md diff --git a/docs/tech-specs/graph-contexts.zh-cn.md b/docs/tech-specs/zh-cn/graph-contexts.zh-cn.md similarity index 100% rename from docs/tech-specs/graph-contexts.zh-cn.md rename to docs/tech-specs/zh-cn/graph-contexts.zh-cn.md diff --git a/docs/tech-specs/graphql-query.zh-cn.md b/docs/tech-specs/zh-cn/graphql-query.zh-cn.md similarity index 100% rename from docs/tech-specs/graphql-query.zh-cn.md rename to docs/tech-specs/zh-cn/graphql-query.zh-cn.md diff --git a/docs/tech-specs/graphrag-performance-optimization.zh-cn.md b/docs/tech-specs/zh-cn/graphrag-performance-optimization.zh-cn.md similarity index 100% rename from docs/tech-specs/graphrag-performance-optimization.zh-cn.md rename to docs/tech-specs/zh-cn/graphrag-performance-optimization.zh-cn.md diff --git a/docs/tech-specs/import-export-graceful-shutdown.zh-cn.md b/docs/tech-specs/zh-cn/import-export-graceful-shutdown.zh-cn.md similarity index 100% rename from docs/tech-specs/import-export-graceful-shutdown.zh-cn.md rename to docs/tech-specs/zh-cn/import-export-graceful-shutdown.zh-cn.md diff --git a/docs/tech-specs/jsonl-prompt-output.zh-cn.md b/docs/tech-specs/zh-cn/jsonl-prompt-output.zh-cn.md similarity index 100% rename from docs/tech-specs/jsonl-prompt-output.zh-cn.md rename to docs/tech-specs/zh-cn/jsonl-prompt-output.zh-cn.md diff --git a/docs/tech-specs/large-document-loading.zh-cn.md b/docs/tech-specs/zh-cn/large-document-loading.zh-cn.md similarity index 100% rename from docs/tech-specs/large-document-loading.zh-cn.md rename to docs/tech-specs/zh-cn/large-document-loading.zh-cn.md diff --git a/docs/tech-specs/logging-strategy.zh-cn.md b/docs/tech-specs/zh-cn/logging-strategy.zh-cn.md similarity index 100% rename from docs/tech-specs/logging-strategy.zh-cn.md rename to docs/tech-specs/zh-cn/logging-strategy.zh-cn.md diff --git a/docs/tech-specs/mcp-tool-arguments.zh-cn.md b/docs/tech-specs/zh-cn/mcp-tool-arguments.zh-cn.md similarity index 100% rename from docs/tech-specs/mcp-tool-arguments.zh-cn.md rename to docs/tech-specs/zh-cn/mcp-tool-arguments.zh-cn.md diff --git a/docs/tech-specs/mcp-tool-bearer-token.zh-cn.md b/docs/tech-specs/zh-cn/mcp-tool-bearer-token.zh-cn.md similarity index 100% rename from docs/tech-specs/mcp-tool-bearer-token.zh-cn.md rename to docs/tech-specs/zh-cn/mcp-tool-bearer-token.zh-cn.md diff --git a/docs/tech-specs/minio-to-s3-migration.zh-cn.md b/docs/tech-specs/zh-cn/minio-to-s3-migration.zh-cn.md similarity index 100% rename from docs/tech-specs/minio-to-s3-migration.zh-cn.md rename to docs/tech-specs/zh-cn/minio-to-s3-migration.zh-cn.md diff --git a/docs/tech-specs/more-config-cli.zh-cn.md b/docs/tech-specs/zh-cn/more-config-cli.zh-cn.md similarity index 100% rename from docs/tech-specs/more-config-cli.zh-cn.md rename to docs/tech-specs/zh-cn/more-config-cli.zh-cn.md diff --git a/docs/tech-specs/multi-tenant-support.zh-cn.md b/docs/tech-specs/zh-cn/multi-tenant-support.zh-cn.md similarity index 100% rename from docs/tech-specs/multi-tenant-support.zh-cn.md rename to docs/tech-specs/zh-cn/multi-tenant-support.zh-cn.md diff --git a/docs/tech-specs/neo4j-user-collection-isolation.zh-cn.md b/docs/tech-specs/zh-cn/neo4j-user-collection-isolation.zh-cn.md similarity index 100% rename from docs/tech-specs/neo4j-user-collection-isolation.zh-cn.md rename to docs/tech-specs/zh-cn/neo4j-user-collection-isolation.zh-cn.md diff --git a/docs/tech-specs/ontology-extract-phase-2.zh-cn.md b/docs/tech-specs/zh-cn/ontology-extract-phase-2.zh-cn.md similarity index 100% rename from docs/tech-specs/ontology-extract-phase-2.zh-cn.md rename to docs/tech-specs/zh-cn/ontology-extract-phase-2.zh-cn.md diff --git a/docs/tech-specs/ontology.zh-cn.md b/docs/tech-specs/zh-cn/ontology.zh-cn.md similarity index 100% rename from docs/tech-specs/ontology.zh-cn.md rename to docs/tech-specs/zh-cn/ontology.zh-cn.md diff --git a/docs/tech-specs/ontorag.zh-cn.md b/docs/tech-specs/zh-cn/ontorag.zh-cn.md similarity index 100% rename from docs/tech-specs/ontorag.zh-cn.md rename to docs/tech-specs/zh-cn/ontorag.zh-cn.md diff --git a/docs/tech-specs/openapi-spec.zh-cn.md b/docs/tech-specs/zh-cn/openapi-spec.zh-cn.md similarity index 100% rename from docs/tech-specs/openapi-spec.zh-cn.md rename to docs/tech-specs/zh-cn/openapi-spec.zh-cn.md diff --git a/docs/tech-specs/pubsub.zh-cn.md b/docs/tech-specs/zh-cn/pubsub.zh-cn.md similarity index 100% rename from docs/tech-specs/pubsub.zh-cn.md rename to docs/tech-specs/zh-cn/pubsub.zh-cn.md diff --git a/docs/tech-specs/python-api-refactor.zh-cn.md b/docs/tech-specs/zh-cn/python-api-refactor.zh-cn.md similarity index 100% rename from docs/tech-specs/python-api-refactor.zh-cn.md rename to docs/tech-specs/zh-cn/python-api-refactor.zh-cn.md diff --git a/docs/tech-specs/query-time-explainability.zh-cn.md b/docs/tech-specs/zh-cn/query-time-explainability.zh-cn.md similarity index 100% rename from docs/tech-specs/query-time-explainability.zh-cn.md rename to docs/tech-specs/zh-cn/query-time-explainability.zh-cn.md diff --git a/docs/tech-specs/rag-streaming-support.zh-cn.md b/docs/tech-specs/zh-cn/rag-streaming-support.zh-cn.md similarity index 100% rename from docs/tech-specs/rag-streaming-support.zh-cn.md rename to docs/tech-specs/zh-cn/rag-streaming-support.zh-cn.md diff --git a/docs/tech-specs/schema-refactoring-proposal.zh-cn.md b/docs/tech-specs/zh-cn/schema-refactoring-proposal.zh-cn.md similarity index 100% rename from docs/tech-specs/schema-refactoring-proposal.zh-cn.md rename to docs/tech-specs/zh-cn/schema-refactoring-proposal.zh-cn.md diff --git a/docs/tech-specs/streaming-llm-responses.zh-cn.md b/docs/tech-specs/zh-cn/streaming-llm-responses.zh-cn.md similarity index 100% rename from docs/tech-specs/streaming-llm-responses.zh-cn.md rename to docs/tech-specs/zh-cn/streaming-llm-responses.zh-cn.md diff --git a/docs/tech-specs/structured-data-2.zh-cn.md b/docs/tech-specs/zh-cn/structured-data-2.zh-cn.md similarity index 100% rename from docs/tech-specs/structured-data-2.zh-cn.md rename to docs/tech-specs/zh-cn/structured-data-2.zh-cn.md diff --git a/docs/tech-specs/structured-data-descriptor.zh-cn.md b/docs/tech-specs/zh-cn/structured-data-descriptor.zh-cn.md similarity index 100% rename from docs/tech-specs/structured-data-descriptor.zh-cn.md rename to docs/tech-specs/zh-cn/structured-data-descriptor.zh-cn.md diff --git a/docs/tech-specs/structured-data-schemas.zh-cn.md b/docs/tech-specs/zh-cn/structured-data-schemas.zh-cn.md similarity index 100% rename from docs/tech-specs/structured-data-schemas.zh-cn.md rename to docs/tech-specs/zh-cn/structured-data-schemas.zh-cn.md diff --git a/docs/tech-specs/structured-data.zh-cn.md b/docs/tech-specs/zh-cn/structured-data.zh-cn.md similarity index 100% rename from docs/tech-specs/structured-data.zh-cn.md rename to docs/tech-specs/zh-cn/structured-data.zh-cn.md diff --git a/docs/tech-specs/structured-diag-service.zh-cn.md b/docs/tech-specs/zh-cn/structured-diag-service.zh-cn.md similarity index 100% rename from docs/tech-specs/structured-diag-service.zh-cn.md rename to docs/tech-specs/zh-cn/structured-diag-service.zh-cn.md diff --git a/docs/tech-specs/tool-group.zh-cn.md b/docs/tech-specs/zh-cn/tool-group.zh-cn.md similarity index 100% rename from docs/tech-specs/tool-group.zh-cn.md rename to docs/tech-specs/zh-cn/tool-group.zh-cn.md diff --git a/docs/tech-specs/tool-services.zh-cn.md b/docs/tech-specs/zh-cn/tool-services.zh-cn.md similarity index 100% rename from docs/tech-specs/tool-services.zh-cn.md rename to docs/tech-specs/zh-cn/tool-services.zh-cn.md diff --git a/docs/tech-specs/universal-decoder.zh-cn.md b/docs/tech-specs/zh-cn/universal-decoder.zh-cn.md similarity index 100% rename from docs/tech-specs/universal-decoder.zh-cn.md rename to docs/tech-specs/zh-cn/universal-decoder.zh-cn.md diff --git a/docs/tech-specs/vector-store-lifecycle.zh-cn.md b/docs/tech-specs/zh-cn/vector-store-lifecycle.zh-cn.md similarity index 100% rename from docs/tech-specs/vector-store-lifecycle.zh-cn.md rename to docs/tech-specs/zh-cn/vector-store-lifecycle.zh-cn.md From 0ef49ab6ae4df8aed5c50aca172f208b7361a0f2 Mon Sep 17 00:00:00 2001 From: Het Patel <102606191+CuriousHet@users.noreply.github.com> Date: Tue, 21 Apr 2026 20:45:11 +0530 Subject: [PATCH 04/21] feat: standardize LLM rate-limiting and exception handling (#835) - HTTP 429 translates to TooManyRequests (retryable) - HTTP 503 translates to LlmError --- .../test_rate_limit_contract.py | 124 ++++++++++++++++-- .../model/text_completion/cohere/llm.py | 18 ++- .../model/text_completion/mistral/llm.py | 33 ++--- .../model/text_completion/openai/llm.py | 15 ++- .../model/text_completion/vllm/llm.py | 18 ++- 5 files changed, 170 insertions(+), 38 deletions(-) diff --git a/tests/unit/test_text_completion/test_rate_limit_contract.py b/tests/unit/test_text_completion/test_rate_limit_contract.py index c9df217b..9cf00b7c 100644 --- a/tests/unit/test_text_completion/test_rate_limit_contract.py +++ b/tests/unit/test_text_completion/test_rate_limit_contract.py @@ -10,7 +10,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock, patch from unittest import IsolatedAsyncioTestCase -from trustgraph.exceptions import TooManyRequests +from trustgraph.exceptions import TooManyRequests, LlmError class TestAzureServerless429(IsolatedAsyncioTestCase): @@ -77,6 +77,24 @@ class TestOpenAIRateLimit(IsolatedAsyncioTestCase): with pytest.raises(TooManyRequests): await proc.generate_content("sys", "prompt") + @patch('trustgraph.model.text_completion.openai.llm.OpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None) + @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None) + async def test_503_raises_llm_error(self, _llm, _async, mock_cls): + from openai import InternalServerError + from trustgraph.model.text_completion.openai.llm import Processor + mock_client = MagicMock() + mock_cls.return_value = mock_client + proc = Processor( + api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t", + ) + mock_client.chat.completions.create.side_effect = InternalServerError( + "service unavailable", response=MagicMock(), body=None + ) + + with pytest.raises(LlmError): + await proc.generate_content("sys", "prompt") + class TestClaudeRateLimit(IsolatedAsyncioTestCase): """Claude/Anthropic: anthropic.RateLimitError → TooManyRequests""" @@ -103,32 +121,120 @@ class TestClaudeRateLimit(IsolatedAsyncioTestCase): await proc.generate_content("sys", "prompt") +class TestMistralRateLimit(IsolatedAsyncioTestCase): + """Mistral: models.SDKError (429/503) → TooManyRequests/LlmError""" + + @patch('trustgraph.model.text_completion.mistral.llm.Mistral') + @patch('trustgraph.model.text_completion.mistral.llm.models') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None) + @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None) + async def test_429_raises_too_many_requests(self, _llm, _async, mock_models, mock_cls): + from trustgraph.model.text_completion.mistral.llm import Processor + mock_client = MagicMock() + mock_cls.return_value = mock_client + proc = Processor( + api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t", + ) + + # Define a mock exception class + mock_models.SDKError = type("SDKError", (Exception,), {"status_code": 429}) + mock_client.chat.complete.side_effect = mock_models.SDKError() + + with pytest.raises(TooManyRequests): + await proc.generate_content("sys", "prompt") + + @patch('trustgraph.model.text_completion.mistral.llm.Mistral') + @patch('trustgraph.model.text_completion.mistral.llm.models') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None) + @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None) + async def test_503_raises_llm_error(self, _llm, _async, mock_models, mock_cls): + from trustgraph.model.text_completion.mistral.llm import Processor + mock_client = MagicMock() + mock_cls.return_value = mock_client + proc = Processor( + api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t", + ) + + mock_models.SDKError = type("SDKError", (Exception,), {"status_code": 503}) + mock_client.chat.complete.side_effect = mock_models.SDKError() + + with pytest.raises(LlmError): + await proc.generate_content("sys", "prompt") + + class TestCohereRateLimit(IsolatedAsyncioTestCase): - """Cohere: cohere.TooManyRequestsError → TooManyRequests""" + """Cohere: cohere.errors (429/503) → TooManyRequests/LlmError""" @patch('trustgraph.model.text_completion.cohere.llm.cohere') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None) @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None) async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_cohere): from trustgraph.model.text_completion.cohere.llm import Processor - + import trustgraph.model.text_completion.cohere.llm as cohere_llm + mock_client = MagicMock() mock_cohere.Client.return_value = mock_client - proc = Processor( api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t", ) + + ErrorCls = type("TooManyRequestsError", (Exception,), {}) + with patch.object(cohere_llm, 'TooManyRequestsError', ErrorCls): + mock_client.chat.side_effect = ErrorCls() + with pytest.raises(TooManyRequests): + await proc.generate_content("sys", "prompt") - mock_cohere.TooManyRequestsError = type( - "TooManyRequestsError", (Exception,), {} - ) - mock_client.chat.side_effect = mock_cohere.TooManyRequestsError( - "rate limited" + @patch('trustgraph.model.text_completion.cohere.llm.cohere') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None) + @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None) + async def test_503_raises_llm_error(self, _llm, _async, mock_cohere): + from trustgraph.model.text_completion.cohere.llm import Processor + import trustgraph.model.text_completion.cohere.llm as cohere_llm + + mock_client = MagicMock() + mock_cohere.Client.return_value = mock_client + proc = Processor( + api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t", ) + + ErrorCls = type("ServiceUnavailableError", (Exception,), {}) + with patch.object(cohere_llm, 'ServiceUnavailableError', ErrorCls): + mock_client.chat.side_effect = ErrorCls() + with pytest.raises(LlmError): + await proc.generate_content("sys", "prompt") + + +class TestVllmRateLimit(IsolatedAsyncioTestCase): + """vLLM: HTTP 429/503 → TooManyRequests/LlmError""" + + @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None) + @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None) + async def test_429_raises_too_many_requests(self, _llm, _async, mock_session): + from trustgraph.model.text_completion.vllm.llm import Processor + proc = Processor(concurrency=1, taskgroup=AsyncMock(), id="t") + + mock_resp = AsyncMock() + mock_resp.status = 429 + mock_session.return_value.post.return_value.__aenter__.return_value = mock_resp with pytest.raises(TooManyRequests): await proc.generate_content("sys", "prompt") + @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None) + @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None) + async def test_503_raises_llm_error(self, _llm, _async, mock_session): + from trustgraph.model.text_completion.vllm.llm import Processor + proc = Processor(concurrency=1, taskgroup=AsyncMock(), id="t") + + mock_resp = AsyncMock() + mock_resp.status = 503 + mock_session.return_value.post.return_value.__aenter__.return_value = mock_resp + + with pytest.raises(LlmError): + await proc.generate_content("sys", "prompt") + class TestClientSideRateLimitTranslation: """Client base class: error type 'too-many-requests' → TooManyRequests""" diff --git a/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py b/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py index 5093e556..4190cb98 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py @@ -5,6 +5,7 @@ Input is prompt, output is response. """ import cohere +from cohere.errors import TooManyRequestsError, ServiceUnavailableError from prometheus_client import Histogram import os import logging @@ -12,7 +13,7 @@ import logging # Module logger logger = logging.getLogger(__name__) -from .... exceptions import TooManyRequests +from .... exceptions import TooManyRequests, LlmError from .... base import LlmService, LlmResult, LlmChunk default_ident = "text-completion" @@ -84,13 +85,14 @@ class Processor(LlmService): return resp - # FIXME: Wrong exception, don't know what this LLM throws - # for a rate limit - except cohere.TooManyRequestsError: - + except TooManyRequestsError: # Leave rate limit retries to the base handler raise TooManyRequests() + except ServiceUnavailableError: + # Treat 503 as a retryable LlmError + raise LlmError() + except Exception as e: # Apart from rate limits, treat all exceptions as unrecoverable @@ -152,10 +154,14 @@ class Processor(LlmService): logger.debug("Streaming complete") - except cohere.TooManyRequestsError: + except TooManyRequestsError: logger.warning("Rate limit exceeded during streaming") raise TooManyRequests() + except ServiceUnavailableError: + logger.warning("Service unavailable during streaming") + raise LlmError() + except Exception as e: logger.error(f"Cohere streaming exception ({type(e).__name__}): {e}", exc_info=True) raise e diff --git a/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py b/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py index fab41ecd..e53f6f6e 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py @@ -4,14 +4,14 @@ Simple LLM service, performs text prompt completion using Mistral. Input is prompt, output is response. """ -from mistralai import Mistral +from mistralai import Mistral, models import os import logging # Module logger logger = logging.getLogger(__name__) -from .... exceptions import TooManyRequests +from .... exceptions import TooManyRequests, LlmError from .... base import LlmService, LlmResult, LlmChunk default_ident = "text-completion" @@ -100,18 +100,14 @@ class Processor(LlmService): return resp - # FIXME: Wrong exception. The MistralAI library has retry logic - # so retry-able errors are retried transparently. It means we - # don't get rate limit events. - - # We could choose to turn off retry and handle all that here - # or subclass BackoffStrategy to keep the retry logic, but - # get the events out. - -# except Mistral.RateLimitError: - -# # Leave rate limit retries to the base handler -# raise TooManyRequests() + except models.SDKError as e: + if e.status_code == 429: + # Leave rate limit retries to the base handler + raise TooManyRequests() + elif e.status_code == 503: + # Treat 503 as a retryable LlmError + raise LlmError() + raise e except Exception as e: @@ -185,8 +181,13 @@ class Processor(LlmService): logger.debug("Streaming complete") - except Exception as e: - logger.error(f"Mistral streaming exception ({type(e).__name__}): {e}", exc_info=True) + except models.SDKError as e: + if e.status_code == 429: + logger.warning("Hit rate limit during streaming") + raise TooManyRequests() + elif e.status_code == 503: + logger.warning("Hit internal server error during streaming") + raise LlmError() raise e @staticmethod diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py index cdc8602a..0ee61521 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py @@ -4,11 +4,11 @@ Simple LLM service, performs text prompt completion using OpenAI. Input is prompt, output is response. """ -from openai import OpenAI, RateLimitError +from openai import OpenAI, RateLimitError, InternalServerError import os import logging -from .... exceptions import TooManyRequests +from .... exceptions import TooManyRequests, LlmError from .... base import LlmService, LlmResult, LlmChunk # Module logger @@ -104,13 +104,14 @@ class Processor(LlmService): return resp - # FIXME: Wrong exception, don't know what this LLM throws - # for a rate limit except RateLimitError: - # Leave rate limit retries to the base handler raise TooManyRequests() + except InternalServerError: + # Treat 503 as a retryable LlmError + raise LlmError() + except Exception as e: # Apart from rate limits, treat all exceptions as unrecoverable @@ -191,6 +192,10 @@ class Processor(LlmService): logger.warning("Hit rate limit during streaming") raise TooManyRequests() + except InternalServerError: + logger.warning("Hit internal server error during streaming") + raise LlmError() + except Exception as e: logger.error(f"OpenAI streaming exception ({type(e).__name__}): {e}", exc_info=True) raise e diff --git a/trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py b/trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py index 2dd4576e..7570fa40 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py @@ -11,7 +11,7 @@ import logging # Module logger logger = logging.getLogger(__name__) -from .... exceptions import TooManyRequests +from .... exceptions import TooManyRequests, LlmError from .... base import LlmService, LlmResult, LlmChunk default_ident = "text-completion" @@ -83,6 +83,10 @@ class Processor(LlmService): json=request, ) as response: + if response.status == 429: + raise TooManyRequests() + if response.status == 503: + raise LlmError() if response.status != 200: raise RuntimeError("Bad status: " + str(response.status)) @@ -104,7 +108,13 @@ class Processor(LlmService): return resp - # FIXME: Assuming vLLM won't produce rate limits? + except TooManyRequests: + # Leave rate limit retries to the base handler + raise TooManyRequests() + + except LlmError: + # Treat 503 as a retryable LlmError + raise LlmError() except Exception as e: @@ -150,6 +160,10 @@ class Processor(LlmService): json=request, ) as response: + if response.status == 429: + raise TooManyRequests() + if response.status == 503: + raise LlmError() if response.status != 200: raise RuntimeError("Bad status: " + str(response.status)) From 424ace44c436fd31b2dc8ecd48baf96f1f38aa3f Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 21 Apr 2026 21:30:19 +0100 Subject: [PATCH 05/21] Fix library queue lifecycle (#838) * Don't delete the global queues (librarian) when flows are deleted * 60s heartbeat timeouts on RabbitMQ --- .../trustgraph/base/rabbitmq_backend.py | 10 +- .../trustgraph/flow/service/flow.py | 135 +++++++++++++++++- 2 files changed, 140 insertions(+), 5 deletions(-) diff --git a/trustgraph-base/trustgraph/base/rabbitmq_backend.py b/trustgraph-base/trustgraph/base/rabbitmq_backend.py index 73b80cb9..a8133a44 100644 --- a/trustgraph-base/trustgraph/base/rabbitmq_backend.py +++ b/trustgraph-base/trustgraph/base/rabbitmq_backend.py @@ -326,7 +326,15 @@ class RabbitMQBackend: port=port, virtual_host=vhost, credentials=pika.PlainCredentials(username, password), - heartbeat=0, + # Heartbeats let us detect silently-dead connections + # (broker restarts, network partitions, orphaned channels) + # within ~2×interval. Consumer threads drive pika's I/O + # loop every 100ms via process_data_events() in receive(), + # so heartbeat frames get pumped automatically. Producers + # reconnect lazily on the next publish if their connection + # has been aged out by the broker. + heartbeat=60, + blocked_connection_timeout=300, ) logger.info(f"RabbitMQ backend: {host}:{port} vhost={vhost}") diff --git a/trustgraph-flow/trustgraph/flow/service/flow.py b/trustgraph-flow/trustgraph/flow/service/flow.py index b864faf9..a5e4a7e1 100644 --- a/trustgraph-flow/trustgraph/flow/service/flow.py +++ b/trustgraph-flow/trustgraph/flow/service/flow.py @@ -363,6 +363,112 @@ class FlowConfig: return topics + @staticmethod + def _topic_is_flow_owned(raw_template): + """Is a topic template owned by the flow system? + + A topic is flow-owned if its template contains at least one + variable substitution (``{id}``, ``{blueprint}``, + ``{workspace}``, ``{param}``, etc.). Pure literal templates + name topics that are created and owned by something else (a + global service, e.g. ``request:tg:librarian``) and must never + be touched by the flow service. + """ + return '{' in raw_template + + def _collect_owned_topics(self, cls, repl_template): + """Resolved set of flow-owned topics for a single flow. + + Only includes topics whose raw template was parameterised + (contains ``{...}``). Literal templates are skipped — they + refer to global topics the flow service does not own. + """ + topics = set() + + for k, v in cls["flow"].items(): + for spec_name, topic_template in v.get("topics", {}).items(): + if not self._topic_is_flow_owned(topic_template): + continue + topics.add(repl_template(topic_template)) + + return topics + + async def _live_owned_topic_closure(self, exclude_flow_id=None): + """Union of flow-owned topics referenced by all live flows. + + Walks every flow record currently registered in the config + service (except ``exclude_flow_id``, typically the flow being + torn down), resolves its blueprint + parameter templates, and + collects the set of flow-owned topics those templates produce. + + Used to drive closure-based topic cleanup on flow stop: a + topic may only be deleted if no remaining live flow would + still template to it. This handles all three scoping cases + transparently — ``{id}`` topics have no other references once + their flow is excluded; ``{blueprint}`` topics stay alive + while another flow of the same blueprint exists; ``{workspace}`` + (when introduced) stays alive while any flow in the workspace + exists. + """ + + live = set() + + flow_ids = await self.config.keys("flow") + + for fid in flow_ids: + + if fid == exclude_flow_id: + continue + + try: + frec_raw = await self.config.get("flow", fid) + if frec_raw is None: + continue + frec = json.loads(frec_raw) + except Exception as e: + logger.warning( + f"Closure sweep: skipping flow {fid}: {e}" + ) + continue + + # Flows mid-shutdown don't keep their topics alive. + if frec.get("status") == "stopping": + continue + + bp_name = frec.get("blueprint-name") + if bp_name is None: + continue + + try: + bp_raw = await self.config.get("flow-blueprint", bp_name) + if bp_raw is None: + continue + bp = json.loads(bp_raw) + except Exception as e: + logger.warning( + f"Closure sweep: skipping flow {fid} " + f"(blueprint {bp_name}): {e}" + ) + continue + + parameters = frec.get("parameters", {}) + + def repl(tmp, bp_name=bp_name, fid=fid, parameters=parameters): + result = tmp.replace( + "{blueprint}", bp_name + ).replace( + "{id}", fid + ) + for pname, pvalue in parameters.items(): + result = result.replace( + f"{{{pname}}}", str(pvalue) + ) + return result + + live.update(self._collect_owned_topics(bp, repl)) + + return live + async def _delete_topics(self, topics): """Delete topics with retries. Best-effort — logs failures but does not raise.""" @@ -424,8 +530,10 @@ class FlowConfig: result = result.replace(f"{{{param_name}}}", str(param_value)) return result - # Collect topic identifiers before removing config - topics = self._collect_flow_topics(cls, repl_template) + # Collect this flow's owned topics before any config changes. + # Global (literal-template) topics are never touched — they are + # managed by whichever service owns them, not by flow-svc. + this_flow_owned = self._collect_owned_topics(cls, repl_template) # Phase 1: Set status to "stopping" and remove processor config. # The config push tells processors to shut down their consumers. @@ -446,9 +554,28 @@ class FlowConfig: await self.config.delete_many(deletes) - # Phase 2: Delete topics with retries, then remove the flow record. - await self._delete_topics(topics) + # Phase 2: Closure-based sweep. Only delete topics that no + # other live flow still references via its blueprint templates. + # This preserves {blueprint}-scoped topics while another flow + # of the same blueprint is still running, and {workspace}-scoped + # topics while any flow in that workspace remains. + live_owned = await self._live_owned_topic_closure( + exclude_flow_id=msg.flow_id, + ) + to_delete = this_flow_owned - live_owned + + if to_delete: + await self._delete_topics(to_delete) + + kept = this_flow_owned - to_delete + if kept: + logger.info( + f"Flow {msg.flow_id}: keeping {len(kept)} topics " + f"still referenced by other live flows" + ) + + # Phase 3: Remove the flow record. if msg.flow_id in await self.config.keys("flow"): await self.config.delete("flow", msg.flow_id) From 9332089b3d4adf2260d1755084e3c772761e151b Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 21 Apr 2026 21:36:46 +0100 Subject: [PATCH 06/21] Setup for 2.4 release branch (#839) --- .github/workflows/pull-request.yaml | 2 +- trustgraph-bedrock/pyproject.toml | 2 +- trustgraph-cli/pyproject.toml | 2 +- trustgraph-embeddings-hf/pyproject.toml | 4 ++-- trustgraph-flow/pyproject.toml | 2 +- trustgraph-ocr/pyproject.toml | 2 +- trustgraph-unstructured/pyproject.toml | 2 +- trustgraph-vertexai/pyproject.toml | 2 +- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index b1ae8611..8248dfbf 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -22,7 +22,7 @@ jobs: uses: actions/checkout@v3 - name: Setup packages - run: make update-package-versions VERSION=2.3.999 + run: make update-package-versions VERSION=2.4.999 - name: Setup environment run: python3 -m venv env diff --git a/trustgraph-bedrock/pyproject.toml b/trustgraph-bedrock/pyproject.toml index 2d65461b..f0c8d571 100644 --- a/trustgraph-bedrock/pyproject.toml +++ b/trustgraph-bedrock/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.3,<2.4", + "trustgraph-base>=2.4,<2.5", "pulsar-client", "prometheus-client", "boto3", diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index 0151fef4..a60b2bba 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.3,<2.4", + "trustgraph-base>=2.4,<2.5", "requests", "pulsar-client", "aiohttp", diff --git a/trustgraph-embeddings-hf/pyproject.toml b/trustgraph-embeddings-hf/pyproject.toml index 459f6123..70489969 100644 --- a/trustgraph-embeddings-hf/pyproject.toml +++ b/trustgraph-embeddings-hf/pyproject.toml @@ -10,8 +10,8 @@ description = "HuggingFace embeddings support for TrustGraph." readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.3,<2.4", - "trustgraph-flow>=2.3,<2.4", + "trustgraph-base>=2.4,<2.5", + "trustgraph-flow>=2.4,<2.5", "torch", "urllib3", "transformers", diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index 492af385..8ba85adf 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.3,<2.4", + "trustgraph-base>=2.4,<2.5", "aiohttp", "anthropic", "scylla-driver", diff --git a/trustgraph-ocr/pyproject.toml b/trustgraph-ocr/pyproject.toml index cd1d20a1..1718258f 100644 --- a/trustgraph-ocr/pyproject.toml +++ b/trustgraph-ocr/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.3,<2.4", + "trustgraph-base>=2.4,<2.5", "pulsar-client", "prometheus-client", "boto3", diff --git a/trustgraph-unstructured/pyproject.toml b/trustgraph-unstructured/pyproject.toml index d8879329..7169fc8b 100644 --- a/trustgraph-unstructured/pyproject.toml +++ b/trustgraph-unstructured/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.3,<2.4", + "trustgraph-base>=2.4,<2.5", "pulsar-client", "prometheus-client", "python-magic", diff --git a/trustgraph-vertexai/pyproject.toml b/trustgraph-vertexai/pyproject.toml index 45958ef3..f43f154d 100644 --- a/trustgraph-vertexai/pyproject.toml +++ b/trustgraph-vertexai/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.3,<2.4", + "trustgraph-base>=2.4,<2.5", "pulsar-client", "google-genai", "google-api-core", From d35473f7f7c96cb1517fcb2877e8b2e303d74795 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 21 Apr 2026 23:23:01 +0100 Subject: [PATCH 07/21] feat: workspace-based multi-tenancy, replacing user as tenancy axis (#840) Introduces `workspace` as the isolation boundary for config, flows, library, and knowledge data. Removes `user` as a schema-level field throughout the code, API specs, and tests; workspace provides the same separation more cleanly at the trusted flow.workspace layer rather than through client-supplied message fields. Design ------ - IAM tech spec (docs/tech-specs/iam.md) documents current state, proposed auth/access model, and migration direction. - Data ownership model (docs/tech-specs/data-ownership-model.md) captures the workspace/collection/flow hierarchy. Schema + messaging ------------------ - Drop `user` field from AgentRequest/Step, GraphRagQuery, DocumentRagQuery, Triples/Graph/Document/Row EmbeddingsRequest, Sparql/Rows/Structured QueryRequest, ToolServiceRequest. - Keep collection/workspace routing via flow.workspace at the service layer. - Translators updated to not serialise/deserialise user. API specs --------- - OpenAPI schemas and path examples cleaned of user fields. - Websocket async-api messages updated. - Removed the unused parameters/User.yaml. Services + base --------------- - Librarian, collection manager, knowledge, config: all operations scoped by workspace. Config client API takes workspace as first positional arg. - `flow.workspace` set at flow start time by the infrastructure; no longer pass-through from clients. - Tool service drops user-personalisation passthrough. CLI + SDK --------- - tg-init-workspace and workspace-aware import/export. - All tg-* commands drop user args; accept --workspace. - Python API/SDK (flow, socket_client, async_*, explainability, library) drop user kwargs from every method signature. MCP server ---------- - All tool endpoints drop user parameters; socket_manager no longer keyed per user. Flow service ------------ - Closure-based topic cleanup on flow stop: only delete topics whose blueprint template was parameterised AND no remaining live flow (across all workspaces) still resolves to that topic. Three scopes fall out naturally from template analysis: * {id} -> per-flow, deleted on stop * {blueprint} -> per-blueprint, kept while any flow of the same blueprint exists * {workspace} -> per-workspace, kept while any flow in the workspace exists * literal -> global, never deleted (e.g. tg.request.librarian) Fixes a bug where stopping a flow silently destroyed the global librarian exchange, wedging all library operations until manual restart. RabbitMQ backend ---------------- - heartbeat=60, blocked_connection_timeout=300. Catches silently dead connections (broker restart, orphaned channels, network partitions) within ~2 heartbeat windows, so the consumer reconnects and re-binds its queue rather than sitting forever on a zombie connection. Tests ----- - Full test refresh: unit, integration, contract, provenance. - Dropped user-field assertions and constructor kwargs across ~100 test files. - Renamed user-collection isolation tests to workspace-collection. --- .gitignore | 1 + Makefile | 2 +- docs/tech-specs/data-ownership-model.md | 309 +++++++ docs/tech-specs/flow-class-definition.md | 36 +- docs/tech-specs/iam.md | 858 ++++++++++++++++++ specs/api/components/parameters/User.yaml | 8 - .../schemas/agent/AgentRequest.yaml | 9 - .../schemas/collection/CollectionRequest.yaml | 7 +- .../collection/CollectionResponse.yaml | 5 - .../DocumentEmbeddingsQueryRequest.yaml | 5 - .../GraphEmbeddingsQueryRequest.yaml | 5 - .../RowEmbeddingsQueryRequest.yaml | 5 - .../schemas/knowledge/KnowledgeRequest.yaml | 17 +- .../schemas/knowledge/KnowledgeResponse.yaml | 10 - .../schemas/librarian/LibrarianRequest.yaml | 5 - .../schemas/loading/DocumentLoadRequest.yaml | 5 - .../schemas/loading/TextLoadRequest.yaml | 5 - .../schemas/query/RowsQueryRequest.yaml | 5 - .../schemas/query/StructuredQueryRequest.yaml | 5 - .../schemas/query/TriplesQueryRequest.yaml | 5 - .../schemas/rag/DocumentRagRequest.yaml | 5 - .../schemas/rag/GraphRagRequest.yaml | 5 - specs/api/paths/collection-management.yaml | 21 +- specs/api/paths/document-stream.yaml | 8 - specs/api/paths/export-core.yaml | 10 - specs/api/paths/flow/agent.yaml | 4 - specs/api/paths/flow/document-embeddings.yaml | 1 - specs/api/paths/flow/document-load.yaml | 2 - specs/api/paths/flow/document-rag.yaml | 3 - specs/api/paths/flow/graph-embeddings.yaml | 1 - specs/api/paths/flow/graph-rag.yaml | 2 - specs/api/paths/flow/row-embeddings.yaml | 1 - specs/api/paths/flow/rows.yaml | 1 - specs/api/paths/flow/sparql-query.yaml | 5 - specs/api/paths/flow/structured-query.yaml | 2 - specs/api/paths/flow/text-load.yaml | 2 - specs/api/paths/flow/triples.yaml | 2 - specs/api/paths/import-core.yaml | 10 - specs/api/paths/knowledge.yaml | 12 +- .../requests/RowEmbeddingsRequest.yaml | 1 - .../messages/requests/SparqlQueryRequest.yaml | 5 - tests/contract/conftest.py | 7 +- .../test_document_embeddings_contract.py | 14 +- tests/contract/test_message_contracts.py | 6 +- tests/contract/test_orchestrator_contracts.py | 4 - .../contract/test_rows_cassandra_contracts.py | 17 +- .../test_rows_graphql_query_contracts.py | 52 +- tests/contract/test_schema_field_contracts.py | 3 +- .../test_structured_data_contracts.py | 13 +- ...test_agent_structured_query_integration.py | 24 +- .../test_cassandra_config_end_to_end.py | 18 +- .../integration/test_cassandra_integration.py | 10 +- .../test_document_rag_integration.py | 5 - ...test_document_rag_streaming_integration.py | 10 - .../integration/test_graph_rag_integration.py | 30 +- .../test_graph_rag_streaming_integration.py | 8 - .../test_import_export_graceful_shutdown.py | 1 - .../test_kg_extract_store_integration.py | 26 +- .../integration/test_nlp_query_integration.py | 18 +- .../test_object_extraction_integration.py | 60 +- .../test_prompt_streaming_integration.py | 5 +- .../test_rag_streaming_protocol.py | 8 - .../test_rows_cassandra_integration.py | 115 ++- .../test_rows_graphql_query_integration.py | 22 +- .../test_structured_query_integration.py | 2 - .../test_agent_service_non_streaming.py | 10 +- tests/unit/test_agent/test_aggregator.py | 10 +- .../test_agent/test_completion_dispatch.py | 4 +- ...est_orchestrator_provenance_integration.py | 1 - .../test_agent/test_pattern_base_subagent.py | 1 - tests/unit/test_agent/test_tool_service.py | 30 +- .../test_agent/test_tool_service_lifecycle.py | 141 +-- .../test_base/test_async_processor_config.py | 257 +++--- .../test_document_embeddings_client.py | 3 - .../unit/test_base/test_flow_base_modules.py | 7 +- .../test_base/test_flow_parameter_specs.py | 4 +- tests/unit/test_base/test_flow_processor.py | 32 +- tests/unit/test_chunking/conftest.py | 4 - .../test_chunking/test_recursive_chunker.py | 1 - .../unit/test_chunking/test_token_chunker.py | 1 - tests/unit/test_cli/test_config_commands.py | 18 +- tests/unit/test_cli/test_load_knowledge.py | 22 +- tests/unit/test_cli/test_tool_commands.py | 8 +- .../test_sync_document_embeddings_client.py | 3 - .../test_graph_rag_concurrency.py | 2 - .../unit/test_cores/test_knowledge_manager.py | 28 +- .../test_decoding/test_universal_processor.py | 6 +- .../test_milvus_collection_naming.py | 30 +- .../test_document_embeddings_processor.py | 8 +- .../test_graph_embeddings_processor.py | 7 +- .../test_row_embeddings_processor.py | 35 +- .../test_definitions_batching.py | 12 +- .../test_relationships_batching.py | 9 +- .../unit/test_gateway/test_config_receiver.py | 232 +++-- .../test_core_import_export_roundtrip.py | 32 +- .../test_gateway/test_dispatch_manager.py | 96 +- .../test_entity_contexts_import_dispatcher.py | 1 - ...test_graph_embeddings_import_dispatcher.py | 1 - .../test_rows_import_dispatcher.py | 1 - .../test_text_document_translator.py | 1 - tests/unit/test_knowledge_graph/conftest.py | 5 +- .../test_agent_extraction.py | 2 - .../test_object_extraction_logic.py | 4 +- .../test_triple_construction.py | 2 - .../test_librarian/test_chunked_upload.py | 68 +- .../test_provenance/test_dag_structure.py | 12 +- .../test_doc_embeddings_milvus_query.py | 40 +- .../test_doc_embeddings_pinecone_query.py | 28 +- .../test_doc_embeddings_qdrant_query.py | 20 +- .../test_graph_embeddings_milvus_query.py | 34 +- .../test_graph_embeddings_pinecone_query.py | 22 +- .../test_graph_embeddings_qdrant_query.py | 16 +- ...st_memgraph_workspace_collection_query.py} | 119 ++- ... test_neo4j_workspace_collection_query.py} | 127 ++- .../test_query/test_rows_cassandra_query.py | 50 +- .../test_triples_cassandra_query.py | 50 +- .../test_query/test_triples_falkordb_query.py | 27 +- .../test_query/test_triples_memgraph_query.py | 27 +- .../test_query/test_triples_neo4j_query.py | 12 +- .../test_metadata_preservation.py | 24 +- .../test_null_embedding_protection.py | 36 +- .../unit/test_retrieval/test_document_rag.py | 27 +- .../test_document_rag_service.py | 24 +- tests/unit/test_retrieval/test_graph_rag.py | 32 +- .../test_graph_rag_explain_forwarding.py | 1 - .../test_retrieval/test_graph_rag_service.py | 3 - tests/unit/test_retrieval/test_nlp_query.py | 10 +- .../test_schema_selection.py | 3 +- .../test_retrieval/test_structured_query.py | 2 - .../test_doc_embeddings_milvus_storage.py | 73 +- .../test_doc_embeddings_pinecone_storage.py | 35 +- .../test_doc_embeddings_qdrant_storage.py | 30 +- .../test_graph_embeddings_milvus_storage.py | 33 +- .../test_graph_embeddings_pinecone_storage.py | 28 +- .../test_graph_embeddings_qdrant_storage.py | 12 +- ...emgraph_workspace_collection_isolation.py} | 269 +++--- ...t_neo4j_workspace_collection_isolation.py} | 394 ++++---- .../test_row_embeddings_qdrant_storage.py | 53 +- .../test_rows_cassandra_storage.py | 121 +-- .../test_triples_cassandra_storage.py | 41 +- .../test_triples_falkordb_storage.py | 93 +- .../test_triples_memgraph_storage.py | 89 +- .../test_triples_neo4j_storage.py | 101 +-- .../test_row_embeddings_query.py | 54 +- .../test_tables/test_knowledge_table_store.py | 17 +- ...ocument_embeddings_translator_roundtrip.py | 2 - .../test_knowledge_translator_roundtrip.py | 9 +- trustgraph-base/trustgraph/api/__init__.py | 3 +- trustgraph-base/trustgraph/api/api.py | 33 +- trustgraph-base/trustgraph/api/async_flow.py | 60 +- .../trustgraph/api/async_socket_client.py | 30 +- trustgraph-base/trustgraph/api/bulk_client.py | 12 +- trustgraph-base/trustgraph/api/collection.py | 101 +-- trustgraph-base/trustgraph/api/config.py | 58 +- .../trustgraph/api/explainability.py | 89 +- trustgraph-base/trustgraph/api/flow.py | 242 +---- trustgraph-base/trustgraph/api/knowledge.py | 92 +- trustgraph-base/trustgraph/api/library.py | 168 ++-- .../trustgraph/api/socket_client.py | 37 +- trustgraph-base/trustgraph/api/types.py | 17 +- .../trustgraph/base/async_processor.py | 172 ++-- .../trustgraph/base/chunking_service.py | 5 +- .../base/collection_config_handler.py | 129 +-- .../trustgraph/base/config_client.py | 38 +- .../trustgraph/base/consumer_spec.py | 5 +- .../base/document_embeddings_client.py | 4 +- .../base/document_embeddings_query_service.py | 4 +- .../base/document_embeddings_store_service.py | 3 +- .../trustgraph/base/dynamic_tool_service.py | 11 +- trustgraph-base/trustgraph/base/flow.py | 7 +- .../trustgraph/base/flow_processor.py | 69 +- .../base/graph_embeddings_client.py | 4 +- .../base/graph_embeddings_query_service.py | 4 +- .../base/graph_embeddings_store_service.py | 3 +- .../trustgraph/base/graph_rag_client.py | 4 +- .../trustgraph/base/librarian_client.py | 24 +- .../trustgraph/base/request_response_spec.py | 5 +- .../base/row_embeddings_query_client.py | 3 +- .../base/structured_query_client.py | 3 +- .../trustgraph/base/subscriber_spec.py | 2 +- .../trustgraph/base/tool_service.py | 1 + .../trustgraph/base/tool_service_client.py | 8 +- .../trustgraph/base/triples_client.py | 7 +- .../trustgraph/base/triples_query_service.py | 12 +- .../trustgraph/base/triples_store_service.py | 5 +- .../trustgraph/clients/config_client.py | 29 + .../clients/document_embeddings_client.py | 4 +- .../trustgraph/clients/document_rag_client.py | 5 +- .../clients/graph_embeddings_client.py | 4 +- .../trustgraph/clients/graph_rag_client.py | 5 +- .../clients/row_embeddings_client.py | 4 +- .../clients/triples_query_client.py | 5 +- .../trustgraph/messaging/translators/agent.py | 2 - .../messaging/translators/collection.py | 8 +- .../messaging/translators/config.py | 19 +- .../messaging/translators/document_loading.py | 12 - .../messaging/translators/embeddings_query.py | 6 - .../trustgraph/messaging/translators/flow.py | 5 +- .../messaging/translators/knowledge.py | 12 +- .../messaging/translators/library.py | 6 +- .../messaging/translators/metadata.py | 12 +- .../messaging/translators/retrieval.py | 4 - .../messaging/translators/rows_query.py | 2 - .../messaging/translators/sparql_query.py | 2 - .../messaging/translators/structured_query.py | 6 +- .../messaging/translators/triples.py | 4 +- .../trustgraph/schema/core/metadata.py | 5 +- .../trustgraph/schema/knowledge/knowledge.py | 6 +- .../trustgraph/schema/services/agent.py | 2 - .../trustgraph/schema/services/collection.py | 13 +- .../trustgraph/schema/services/config.py | 40 +- .../trustgraph/schema/services/flow.py | 4 +- .../trustgraph/schema/services/library.py | 14 +- .../trustgraph/schema/services/query.py | 4 - .../trustgraph/schema/services/retrieval.py | 2 - .../trustgraph/schema/services/rows_query.py | 1 - .../schema/services/sparql_query.py | 1 - .../schema/services/structured_query.py | 1 - .../schema/services/tool_service.py | 2 - trustgraph-cli/pyproject.toml | 2 + .../trustgraph/cli/add_library_document.py | 30 +- .../trustgraph/cli/delete_collection.py | 35 +- .../trustgraph/cli/delete_config_item.py | 13 +- .../trustgraph/cli/delete_flow_blueprint.py | 21 +- .../trustgraph/cli/delete_kg_core.py | 33 +- .../trustgraph/cli/delete_mcp_tool.py | 24 +- trustgraph-cli/trustgraph/cli/delete_tool.py | 24 +- .../trustgraph/cli/export_workspace_config.py | 114 +++ .../trustgraph/cli/get_config_item.py | 13 +- .../trustgraph/cli/get_document_content.py | 18 +- .../trustgraph/cli/get_flow_blueprint.py | 20 +- trustgraph-cli/trustgraph/cli/get_kg_core.py | 33 +- .../trustgraph/cli/graph_to_turtle.py | 23 +- .../trustgraph/cli/import_workspace_config.py | 143 +++ .../trustgraph/cli/init_trustgraph.py | 22 +- trustgraph-cli/trustgraph/cli/invoke_agent.py | 36 +- .../cli/invoke_document_embeddings.py | 20 +- .../trustgraph/cli/invoke_document_rag.py | 45 +- .../trustgraph/cli/invoke_embeddings.py | 13 +- .../trustgraph/cli/invoke_graph_embeddings.py | 20 +- .../trustgraph/cli/invoke_graph_rag.py | 102 +-- trustgraph-cli/trustgraph/cli/invoke_llm.py | 12 +- .../trustgraph/cli/invoke_mcp_tool.py | 20 +- .../trustgraph/cli/invoke_nlp_query.py | 22 +- .../trustgraph/cli/invoke_prompt.py | 12 +- .../trustgraph/cli/invoke_row_embeddings.py | 20 +- .../trustgraph/cli/invoke_rows_query.py | 31 +- .../trustgraph/cli/invoke_sparql_query.py | 28 +- .../trustgraph/cli/invoke_structured_query.py | 32 +- .../trustgraph/cli/list_collections.py | 38 +- .../trustgraph/cli/list_config_items.py | 13 +- .../trustgraph/cli/list_explain_traces.py | 12 +- .../trustgraph/cli/load_doc_embeds.py | 16 +- trustgraph-cli/trustgraph/cli/load_kg_core.py | 36 +- .../trustgraph/cli/load_knowledge.py | 38 +- .../trustgraph/cli/load_sample_documents.py | 29 +- .../trustgraph/cli/load_structured_data.py | 63 +- trustgraph-cli/trustgraph/cli/load_turtle.py | 29 +- .../trustgraph/cli/put_config_item.py | 13 +- .../trustgraph/cli/put_flow_blueprint.py | 13 +- trustgraph-cli/trustgraph/cli/put_kg_core.py | 43 +- trustgraph-cli/trustgraph/cli/query_graph.py | 26 +- .../trustgraph/cli/remove_library_document.py | 28 +- .../trustgraph/cli/save_doc_embeds.py | 15 +- .../trustgraph/cli/set_collection.py | 25 +- trustgraph-cli/trustgraph/cli/set_mcp_tool.py | 12 +- trustgraph-cli/trustgraph/cli/set_prompt.py | 15 +- .../trustgraph/cli/set_token_costs.py | 15 +- trustgraph-cli/trustgraph/cli/set_tool.py | 12 +- trustgraph-cli/trustgraph/cli/show_config.py | 12 +- .../trustgraph/cli/show_explain_trace.py | 59 +- .../cli/show_extraction_provenance.py | 40 +- .../trustgraph/cli/show_flow_blueprints.py | 23 +- .../trustgraph/cli/show_flow_state.py | 17 +- trustgraph-cli/trustgraph/cli/show_flows.py | 28 +- trustgraph-cli/trustgraph/cli/show_graph.py | 21 +- .../trustgraph/cli/show_kg_cores.py | 21 +- .../trustgraph/cli/show_library_documents.py | 16 +- .../trustgraph/cli/show_library_processing.py | 29 +- .../trustgraph/cli/show_mcp_tools.py | 13 +- trustgraph-cli/trustgraph/cli/show_prompts.py | 13 +- .../trustgraph/cli/show_token_costs.py | 13 +- trustgraph-cli/trustgraph/cli/show_tools.py | 13 +- trustgraph-cli/trustgraph/cli/start_flow.py | 13 +- .../cli/start_library_processing.py | 38 +- trustgraph-cli/trustgraph/cli/stop_flow.py | 12 +- .../trustgraph/cli/stop_library_processing.py | 30 +- .../trustgraph/cli/unload_kg_core.py | 24 +- .../trustgraph/cli/verify_system_status.py | 31 +- .../trustgraph/agent/mcp_tool/service.py | 34 +- .../agent/orchestrator/aggregator.py | 3 +- .../agent/orchestrator/pattern_base.py | 67 +- .../agent/orchestrator/plan_pattern.py | 32 +- .../agent/orchestrator/react_pattern.py | 15 +- .../trustgraph/agent/orchestrator/service.py | 48 +- .../agent/orchestrator/supervisor_pattern.py | 19 +- .../trustgraph/agent/react/service.py | 79 +- .../trustgraph/agent/react/tools.py | 33 +- .../trustgraph/chunking/recursive/chunker.py | 6 +- .../trustgraph/chunking/token/chunker.py | 6 +- .../trustgraph/config/service/config.py | 154 ++-- .../trustgraph/config/service/service.py | 9 +- trustgraph-flow/trustgraph/cores/knowledge.py | 29 +- trustgraph-flow/trustgraph/cores/service.py | 12 +- .../decoding/mistral_ocr/processor.py | 8 +- .../trustgraph/decoding/pdf/pdf_decoder.py | 8 +- .../direct/milvus_doc_embeddings.py | 52 +- .../direct/milvus_graph_embeddings.py | 52 +- .../embeddings/row_embeddings/embeddings.py | 59 +- .../trustgraph/extract/kg/agent/extract.py | 45 +- .../extract/kg/definitions/extract.py | 2 - .../trustgraph/extract/kg/ontology/extract.py | 130 ++- .../extract/kg/relationships/extract.py | 1 - .../trustgraph/extract/kg/rows/processor.py | 64 +- .../trustgraph/flow/service/flow.py | 349 ++++--- .../trustgraph/flow/service/service.py | 7 +- .../trustgraph/gateway/config/receiver.py | 125 ++- .../gateway/dispatch/core_export.py | 6 +- .../gateway/dispatch/core_import.py | 8 +- .../gateway/dispatch/document_stream.py | 8 +- .../dispatch/entity_contexts_import.py | 1 - .../dispatch/graph_embeddings_import.py | 1 - .../trustgraph/gateway/dispatch/manager.py | 64 +- .../trustgraph/gateway/dispatch/mux.py | 24 +- .../gateway/dispatch/rows_import.py | 1 - .../trustgraph/gateway/dispatch/serialize.py | 16 +- .../gateway/dispatch/triples_import.py | 1 - .../librarian/collection_manager.py | 55 +- .../trustgraph/librarian/librarian.py | 62 +- .../trustgraph/librarian/service.py | 30 +- .../trustgraph/metering/counter.py | 33 +- .../trustgraph/prompt/template/service.py | 56 +- .../query/doc_embeddings/milvus/service.py | 4 +- .../query/doc_embeddings/pinecone/service.py | 4 +- .../query/doc_embeddings/qdrant/service.py | 4 +- .../query/graph_embeddings/milvus/service.py | 4 +- .../graph_embeddings/pinecone/service.py | 4 +- .../query/graph_embeddings/qdrant/service.py | 4 +- .../trustgraph/query/graphql/schema.py | 6 +- .../query/ontology/query_explanation.py | 4 +- .../query/ontology/query_service.py | 2 +- .../query/ontology/question_analyzer.py | 2 +- .../query/row_embeddings/qdrant/service.py | 20 +- .../query/rows/cassandra/service.py | 72 +- .../trustgraph/query/sparql/algebra.py | 86 +- .../trustgraph/query/sparql/service.py | 2 +- .../query/triples/cassandra/service.py | 26 +- .../query/triples/falkordb/service.py | 2 +- .../query/triples/memgraph/service.py | 133 ++- .../trustgraph/query/triples/neo4j/service.py | 136 ++- .../retrieval/document_rag/document_rag.py | 15 +- .../trustgraph/retrieval/document_rag/rag.py | 21 +- .../retrieval/graph_rag/graph_rag.py | 49 +- .../trustgraph/retrieval/graph_rag/rag.py | 33 +- .../trustgraph/retrieval/nlp_query/service.py | 67 +- .../retrieval/structured_diag/service.py | 67 +- .../retrieval/structured_query/service.py | 4 +- .../storage/doc_embeddings/milvus/write.py | 20 +- .../storage/doc_embeddings/pinecone/write.py | 22 +- .../storage/doc_embeddings/qdrant/write.py | 22 +- .../storage/graph_embeddings/milvus/write.py | 20 +- .../graph_embeddings/pinecone/write.py | 22 +- .../storage/graph_embeddings/qdrant/write.py | 22 +- .../trustgraph/storage/knowledge/store.py | 4 +- .../storage/row_embeddings/qdrant/write.py | 41 +- .../storage/rows/cassandra/write.py | 95 +- .../storage/triples/cassandra/write.py | 48 +- .../storage/triples/falkordb/write.py | 106 ++- .../storage/triples/memgraph/write.py | 140 ++- .../trustgraph/storage/triples/neo4j/write.py | 119 ++- trustgraph-flow/trustgraph/tables/config.py | 79 +- .../trustgraph/tables/knowledge.py | 74 +- trustgraph-flow/trustgraph/tables/library.py | 139 ++- .../trustgraph/tool_service/joke/service.py | 11 +- trustgraph-mcp/trustgraph/mcp_server/mcp.py | 175 ++-- .../trustgraph/decoding/ocr/pdf_decoder.py | 8 +- .../decoding/universal/processor.py | 11 +- 377 files changed, 6868 insertions(+), 5785 deletions(-) create mode 100644 docs/tech-specs/data-ownership-model.md create mode 100644 docs/tech-specs/iam.md delete mode 100644 specs/api/components/parameters/User.yaml rename tests/unit/test_query/{test_memgraph_user_collection_query.py => test_memgraph_workspace_collection_query.py} (76%) rename tests/unit/test_query/{test_neo4j_user_collection_query.py => test_neo4j_workspace_collection_query.py} (75%) rename tests/unit/test_storage/{test_memgraph_user_collection_isolation.py => test_memgraph_workspace_collection_isolation.py} (53%) rename tests/unit/test_storage/{test_neo4j_user_collection_isolation.py => test_neo4j_workspace_collection_isolation.py} (51%) create mode 100644 trustgraph-cli/trustgraph/cli/export_workspace_config.py create mode 100644 trustgraph-cli/trustgraph/cli/import_workspace_config.py diff --git a/.gitignore b/.gitignore index daeba074..32942156 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,5 @@ trustgraph-parquet/trustgraph/parquet_version.py trustgraph-vertexai/trustgraph/vertexai_version.py trustgraph-unstructured/trustgraph/unstructured_version.py trustgraph-mcp/trustgraph/mcp_version.py +trustgraph/trustgraph/trustgraph_version.py vertexai/ \ No newline at end of file diff --git a/Makefile b/Makefile index 85f10fdd..0f0f37b2 100644 --- a/Makefile +++ b/Makefile @@ -57,7 +57,7 @@ container-bedrock container-vertexai \ container-hf container-ocr \ container-unstructured container-mcp -some-containers: container-base container-flow +some-containers: container-base container-flow container-unstructured push: ${DOCKER} push ${CONTAINER_BASE}/trustgraph-base:${VERSION} diff --git a/docs/tech-specs/data-ownership-model.md b/docs/tech-specs/data-ownership-model.md new file mode 100644 index 00000000..ea94ec46 --- /dev/null +++ b/docs/tech-specs/data-ownership-model.md @@ -0,0 +1,309 @@ +--- +layout: default +title: "Data Ownership and Information Separation" +parent: "Tech Specs" +--- + +# Data Ownership and Information Separation + +## Purpose + +This document defines the logical ownership model for data in +TrustGraph: what the artefacts are, who owns them, and how they relate +to each other. + +The IAM spec ([iam.md](iam.md)) describes authentication and +authorisation mechanics. This spec addresses the prior question: what +are the boundaries around data, and who owns what? + +## Concepts + +### Workspace + +A workspace is the primary isolation boundary. It represents an +organisation, team, or independent operating unit. All data belongs to +exactly one workspace. Cross-workspace access is never permitted through +the API. + +A workspace owns: +- Source documents +- Flows (processing pipeline definitions) +- Knowledge cores (stored extraction output) +- Collections (organisational units for extracted knowledge) + +### Collection + +A collection is an organisational unit within a workspace. It groups +extracted knowledge produced from source documents. A workspace can +have multiple collections, allowing: + +- Processing the same documents with different parameters or models. +- Maintaining separate knowledge bases for different purposes. +- Deleting extracted knowledge without deleting source documents. + +Collections do not own source documents. A source document exists at the +workspace level and can be processed into multiple collections. + +### Source document + +A source document (PDF, text file, etc.) is raw input uploaded to the +system. Documents belong to the workspace, not to a specific collection. + +This is intentional. A document is an asset that exists independently +of how it is processed. The same PDF might be processed into multiple +collections with different chunking parameters or extraction models. +Tying a document to a single collection would force re-upload for each +collection. + +### Flow + +A flow defines a processing pipeline: which models to use, what +parameters to apply (chunk size, temperature, etc.), and how processing +services are connected. Flows belong to a workspace. + +The processing services themselves (document-decoder, chunker, +embeddings, LLM completion, etc.) are shared infrastructure — they serve +all workspaces. Each flow has its own queues, keeping data from +different workspaces and flows separate as it moves through the +pipeline. + +Different workspaces can define different flows. Workspace A might use +GPT-5.2 with a chunk size of 2000, while workspace B uses Claude with a +chunk size of 1000. + +### Prompts + +Prompts are templates that control how the LLM behaves during knowledge +extraction and query answering. They belong to a workspace, allowing +different workspaces to have different extraction strategies, response +styles, or domain-specific instructions. + +### Ontology + +An ontology defines the concepts, entities, and relationships that the +extraction pipeline looks for in source documents. Ontologies belong to +a workspace. A medical workspace might define ontologies around diseases, +symptoms, and treatments, while a legal workspace defines ontologies +around statutes, precedents, and obligations. + +### Schemas + +Schemas define structured data types for extraction. They specify what +fields to extract, their types, and how they relate. Schemas belong to +a workspace, as different workspaces extract different structured +information from their documents. + +### Tools, tool services, and MCP tools + +Tools define capabilities available to agents: what actions they can +take, what external services they can call. Tool services configure how +tools connect to backend services. MCP tools configure connections to +remote MCP servers, including authentication tokens. All belong to a +workspace. + +### Agent patterns and agent task types + +Agent patterns define agent behaviour strategies (how an agent reasons, +what steps it follows). Agent task types define the kinds of tasks +agents can perform. Both belong to a workspace, as different workspaces +may have different agent configurations. + +### Token costs + +Token cost definitions specify pricing for LLM token usage per model. +These belong to a workspace since different workspaces may use different +models or have different billing arrangements. + +### Flow blueprints + +Flow blueprints are templates for creating flows. They define the +default pipeline structure and parameters. Blueprints belong to a +workspace, allowing workspaces to define custom processing templates. + +### Parameter types + +Parameter types define the kinds of parameters that flows accept (e.g. +"llm-model", "temperature"), including their defaults and validation +rules. They belong to a workspace since workspaces that define custom +flows need to define the parameter types those flows use. + +### Interface descriptions + +Interface descriptions define the connection points of a flow — what +queues and topics it uses. They belong to a workspace since they +describe workspace-owned flows. + +### Knowledge core + +A knowledge core is a stored snapshot of extracted knowledge (triples +and graph embeddings). Knowledge cores belong to a workspace and can be +loaded into any collection within that workspace. + +Knowledge cores serve as a portable extraction output. You process +documents through a flow, the pipeline produces triples and embeddings, +and the results can be stored as a knowledge core. That core can later +be loaded into a different collection or reloaded after a collection is +cleared. + +### Extracted knowledge + +Extracted knowledge is the live, queryable content within a collection: +triples in the knowledge graph, graph embeddings, and document +embeddings. It is the product of processing source documents through a +flow into a specific collection. + +Extracted knowledge is scoped to a workspace and a collection. It +cannot exist without both. + +### Processing record + +A processing record tracks which source document was processed, through +which flow, into which collection. It links the source document +(workspace-scoped) to the extracted knowledge (workspace + collection +scoped). + +## Ownership summary + +| Artefact | Owned by | Shared across collections? | +|----------|----------|---------------------------| +| Workspaces | Global (platform) | N/A | +| User accounts | Global (platform) | N/A | +| API keys | Global (platform) | N/A | +| Source documents | Workspace | Yes | +| Flows | Workspace | N/A | +| Flow blueprints | Workspace | N/A | +| Prompts | Workspace | N/A | +| Ontologies | Workspace | N/A | +| Schemas | Workspace | N/A | +| Tools | Workspace | N/A | +| Tool services | Workspace | N/A | +| MCP tools | Workspace | N/A | +| Agent patterns | Workspace | N/A | +| Agent task types | Workspace | N/A | +| Token costs | Workspace | N/A | +| Parameter types | Workspace | N/A | +| Interface descriptions | Workspace | N/A | +| Knowledge cores | Workspace | Yes — can be loaded into any collection | +| Collections | Workspace | N/A | +| Extracted knowledge | Workspace + collection | No | +| Processing records | Workspace + collection | No | + +## Scoping summary + +### Global (system-level) + +A small number of artefacts exist outside any workspace: + +- **Workspace registry** — the list of workspaces itself +- **User accounts** — users reference a workspace but are not owned by + one +- **API keys** — belong to users, not workspaces + +These are managed by the IAM layer and exist at the platform level. + +### Workspace-owned + +All other configuration and data is workspace-owned: + +- Flow definitions and parameters +- Flow blueprints +- Prompts +- Ontologies +- Schemas +- Tools, tool services, and MCP tools +- Agent patterns and agent task types +- Token costs +- Parameter types +- Interface descriptions +- Collection definitions +- Knowledge cores +- Source documents +- Collections and their extracted knowledge + +## Relationship between artefacts + +``` +Platform (global) + | + +-- Workspaces + | | + +-- User accounts (each assigned to a workspace) + | | + +-- API keys (belong to users) + +Workspace + | + +-- Source documents (uploaded, unprocessed) + | + +-- Flows (pipeline definitions: models, parameters, queues) + | + +-- Flow blueprints (templates for creating flows) + | + +-- Prompts (LLM instruction templates) + | + +-- Ontologies (entity and relationship definitions) + | + +-- Schemas (structured data type definitions) + | + +-- Tools, tool services, MCP tools (agent capabilities) + | + +-- Agent patterns and agent task types (agent behaviour) + | + +-- Token costs (LLM pricing per model) + | + +-- Parameter types (flow parameter definitions) + | + +-- Interface descriptions (flow connection points) + | + +-- Knowledge cores (stored extraction snapshots) + | + +-- Collections + | + +-- Extracted knowledge (triples, embeddings) + | + +-- Processing records (links documents to collections) +``` + +A typical workflow: + +1. A source document is uploaded to the workspace. +2. A flow defines how to process it (which models, what parameters). +3. The document is processed through the flow into a collection. +4. Processing records track what was processed. +5. Extracted knowledge (triples, embeddings) is queryable within the + collection. +6. Optionally, the extracted knowledge is stored as a knowledge core + for later reuse. + +## Implementation notes + +The current codebase uses a `user` field in message metadata and storage +partition keys to identify the workspace. The `collection` field +identifies the collection within that workspace. The IAM spec describes +how the gateway maps authenticated credentials to a workspace identity +and sets these fields. + +For details on how each storage backend implements this scoping, see: + +- [Entity-Centric Graph](entity-centric-graph.md) — Cassandra KG schema +- [Neo4j User Collection Isolation](neo4j-user-collection-isolation.md) +- [Collection Management](collection-management.md) + +### Known inconsistencies in current implementation + +- **Pipeline intermediate tables** do not include collection in their + partition keys. Re-processing the same document into a different + collection may overwrite intermediate state. +- **Processing metadata** stores collection in the row payload but not + in the partition key, making collection-based queries inefficient. +- **Upload sessions** are keyed by upload ID, not workspace. The + gateway should validate workspace ownership before allowing + operations on upload sessions. + +## References + +- [Identity and Access Management](iam.md) +- [Collection Management](collection-management.md) +- [Entity-Centric Graph](entity-centric-graph.md) +- [Neo4j User Collection Isolation](neo4j-user-collection-isolation.md) +- [Multi-Tenant Support](multi-tenant-support.md) diff --git a/docs/tech-specs/flow-class-definition.md b/docs/tech-specs/flow-class-definition.md index 94229b72..3a81bf71 100644 --- a/docs/tech-specs/flow-class-definition.md +++ b/docs/tech-specs/flow-class-definition.md @@ -20,8 +20,8 @@ Defines shared service processors that are instantiated once per flow blueprint. ```json "class": { "service-name:{class}": { - "request": "queue-pattern:{class}", - "response": "queue-pattern:{class}", + "request": "queue-pattern:{workspace}:{class}", + "response": "queue-pattern:{workspace}:{class}", "settings": { "setting-name": "fixed-value", "parameterized-setting": "{parameter-name}" @@ -31,11 +31,11 @@ Defines shared service processors that are instantiated once per flow blueprint. ``` **Characteristics:** -- Shared across all flow instances of the same class +- Shared across all flow instances of the same class within a workspace - Typically expensive or stateless services (LLMs, embedding models) -- Use `{class}` template variable for queue naming +- Use `{workspace}` and `{class}` template variables for queue naming - Settings can be fixed values or parameterized with `{parameter-name}` syntax -- Examples: `embeddings:{class}`, `text-completion:{class}`, `graph-rag:{class}` +- Examples: `embeddings:{workspace}:{class}`, `text-completion:{workspace}:{class}` ### 2. Flow Section Defines flow-specific processors that are instantiated for each individual flow instance. Each flow gets its own isolated set of these processors. @@ -43,8 +43,8 @@ Defines flow-specific processors that are instantiated for each individual flow ```json "flow": { "processor-name:{id}": { - "input": "queue-pattern:{id}", - "output": "queue-pattern:{id}", + "input": "queue-pattern:{workspace}:{id}", + "output": "queue-pattern:{workspace}:{id}", "settings": { "setting-name": "fixed-value", "parameterized-setting": "{parameter-name}" @@ -56,9 +56,9 @@ Defines flow-specific processors that are instantiated for each individual flow **Characteristics:** - Unique instance per flow - Handle flow-specific data and state -- Use `{id}` template variable for queue naming +- Use `{workspace}` and `{id}` template variables for queue naming - Settings can be fixed values or parameterized with `{parameter-name}` syntax -- Examples: `chunker:{id}`, `pdf-decoder:{id}`, `kg-extract-relationships:{id}` +- Examples: `chunker:{workspace}:{id}`, `pdf-decoder:{workspace}:{id}` ### 3. Interfaces Section Defines the entry points and interaction contracts for the flow. These form the API surface for external systems and internal component communication. @@ -68,8 +68,8 @@ Interfaces can take two forms: **Fire-and-Forget Pattern** (single queue): ```json "interfaces": { - "document-load": "persistent://tg/flow/document-load:{id}", - "triples-store": "persistent://tg/flow/triples-store:{id}" + "document-load": "persistent://tg/flow/{workspace}:document-load:{id}", + "triples-store": "persistent://tg/flow/{workspace}:triples-store:{id}" } ``` @@ -77,8 +77,8 @@ Interfaces can take two forms: ```json "interfaces": { "embeddings": { - "request": "non-persistent://tg/request/embeddings:{class}", - "response": "non-persistent://tg/response/embeddings:{class}" + "request": "non-persistent://tg/request/{workspace}:embeddings:{class}", + "response": "non-persistent://tg/response/{workspace}:embeddings:{class}" } } ``` @@ -117,6 +117,16 @@ Additional information about the flow blueprint: ### System Variables +#### {workspace} +- Replaced with the workspace identifier +- Isolates queue names between workspaces so that two workspaces + starting the same flow do not share queues +- Must be included in all queue name patterns to ensure workspace + isolation +- Example: `ws-acme`, `ws-globex` +- All blueprint templates must include `{workspace}` in queue name + patterns + #### {id} - Replaced with the unique flow instance identifier - Creates isolated resources for each flow diff --git a/docs/tech-specs/iam.md b/docs/tech-specs/iam.md new file mode 100644 index 00000000..5de50749 --- /dev/null +++ b/docs/tech-specs/iam.md @@ -0,0 +1,858 @@ +--- +layout: default +title: "Identity and Access Management" +parent: "Tech Specs" +--- + +# Identity and Access Management + +## Problem Statement + +TrustGraph has no meaningful identity or access management. The system +relies on a single shared gateway token for authentication and an +honour-system `user` query parameter for data isolation. This creates +several problems: + +- **No user identity.** There are no user accounts, no login, and no way + to know who is making a request. The `user` field in message metadata + is a caller-supplied string with no validation — any client can claim + to be any user. + +- **No access control.** A valid gateway token grants unrestricted access + to every endpoint, every user's data, every collection, and every + administrative operation. There is no way to limit what an + authenticated caller can do. + +- **No credential isolation.** All callers share one static token. There + is no per-user credential, no token expiration, and no rotation + mechanism. Revoking access means changing the shared token, which + affects all callers. + +- **Data isolation is unenforced.** Storage backends (Cassandra, Neo4j, + Qdrant) filter queries by `user` and `collection`, but the gateway + does not prevent a caller from specifying another user's identity. + Cross-user data access is trivial. + +- **No audit trail.** There is no logging of who accessed what. Without + user identity, audit logging is impossible. + +These gaps make the system unsuitable for multi-user deployments, +multi-tenant SaaS, or any environment where access needs to be +controlled or audited. + +## Current State + +### Authentication + +The API gateway supports a single shared token configured via the +`GATEWAY_SECRET` environment variable or `--api-token` CLI argument. If +unset, authentication is disabled entirely. When enabled, every HTTP +endpoint requires an `Authorization: Bearer ` header. WebSocket +connections pass the token as a query parameter. + +Implementation: `trustgraph-flow/trustgraph/gateway/auth.py` + +```python +class Authenticator: + def __init__(self, token=None, allow_all=False): + self.token = token + self.allow_all = allow_all + + def permitted(self, token, roles): + if self.allow_all: return True + if self.token != token: return False + return True +``` + +The `roles` parameter is accepted but never evaluated. All authenticated +requests have identical privileges. + +MCP tool configurations support an optional per-tool `auth-token` for +service-to-service authentication with remote MCP servers. These are +static, system-wide tokens — not per-user credentials. See +[mcp-tool-bearer-token.md](mcp-tool-bearer-token.md) for details. + +### User identity + +The `user` field is passed explicitly by the caller as a query parameter +(e.g. `?user=trustgraph`) or set by CLI tools. It flows through the +system in the core `Metadata` dataclass: + +```python +@dataclass +class Metadata: + id: str = "" + root: str = "" + user: str = "" + collection: str = "" +``` + +There is no user registration, login, user database, or session +management. + +### Data isolation + +The `user` + `collection` pair is used at the storage layer to partition +data: + +- **Cassandra**: queries filter by `user` and `collection` columns +- **Neo4j**: queries filter by `user` and `collection` properties +- **Qdrant**: vector search filters by `user` and `collection` metadata + +| Layer | Isolation mechanism | Enforced by | +|-------|-------------------|-------------| +| Gateway | Single shared token | `Authenticator` class | +| Message metadata | `user` + `collection` fields | Caller (honour system) | +| Cassandra | Column filters on `user`, `collection` | Query layer | +| Neo4j | Property filters on `user`, `collection` | Query layer | +| Qdrant | Metadata filters on `user`, `collection` | Query layer | +| Pub/sub topics | Per-flow topic namespacing | Flow service | + +The storage-layer isolation depends on all queries correctly filtering by +`user` and `collection`. There is no gateway-level enforcement preventing +a caller from querying another user's data by passing a different `user` +parameter. + +### Configuration and secrets + +| Setting | Source | Default | Purpose | +|---------|--------|---------|---------| +| `GATEWAY_SECRET` | Env var | Empty (auth disabled) | Gateway bearer token | +| `--api-token` | CLI arg | None | Gateway bearer token (overrides env) | +| `PULSAR_API_KEY` | Env var | None | Pub/sub broker auth | +| MCP `auth-token` | Config service | None | Per-tool MCP server auth | + +No secrets are encrypted at rest. The gateway token and MCP tokens are +stored and transmitted in plaintext (aside from any transport-layer +encryption such as TLS). + +### Capabilities that do not exist + +- Per-user authentication (JWT, OAuth, SAML, API keys per user) +- User accounts or user management +- Role-based access control (RBAC) +- Attribute-based access control (ABAC) +- Per-user or per-workspace API keys +- Token expiration or rotation +- Session management +- Per-user rate limiting +- Audit logging of user actions +- Permission checks preventing cross-user data access +- Multi-workspace credential isolation + +### Key files + +| File | Purpose | +|------|---------| +| `trustgraph-flow/trustgraph/gateway/auth.py` | Authenticator class | +| `trustgraph-flow/trustgraph/gateway/service.py` | Gateway init, token config | +| `trustgraph-flow/trustgraph/gateway/endpoint/*.py` | Per-endpoint auth checks | +| `trustgraph-base/trustgraph/schema/core/metadata.py` | `Metadata` dataclass with `user` field | + +## Technical Design + +### Design principles + +- **Auth at the edge.** The gateway is the single enforcement point. + Internal services trust the gateway and do not re-authenticate. + This avoids distributing credential validation across dozens of + microservices. + +- **Identity from credentials, not from callers.** The gateway derives + user identity from authentication credentials. Callers can no longer + self-declare their identity via query parameters. + +- **Workspace isolation by default.** Every authenticated user belongs to + a workspace. All data operations are scoped to that workspace. + Cross-workspace access is not possible through the API. + +- **Extensible API contract.** The API accepts an optional workspace + parameter on every request. This allows the same protocol to support + single-workspace deployments today and multi-workspace extensions in + the future without breaking changes. + +- **Simple roles, not fine-grained permissions.** A small number of + predefined roles controls what operations a user can perform. This is + sufficient for the current API surface and avoids the complexity of + per-resource permission management. + +### Authentication + +The gateway supports two credential types. Both are carried as a Bearer +token in the `Authorization` header for HTTP requests. The gateway +distinguishes them by format. + +For WebSocket connections, credentials are not passed in the URL or +headers. Instead, the client authenticates after connecting by sending +an auth message as the first frame: + +``` +Client: opens WebSocket to /api/v1/socket +Server: accepts connection (unauthenticated state) +Client: sends {"type": "auth", "token": "tg_abc123..."} +Server: validates token + success → {"type": "auth-ok", "workspace": "acme"} + failure → {"type": "auth-failed", "error": "invalid token"} +``` + +The server rejects all non-auth messages until authentication succeeds. +The socket remains open on auth failure, allowing the client to retry +with a different token without reconnecting. The client can also send +a new auth message at any time to re-authenticate — for example, to +refresh an expiring JWT or to switch workspace. The +resolved identity (user, workspace, roles) is updated on each +successful auth. + +#### API keys + +For programmatic access: CLI tools, scripts, and integrations. + +- Opaque tokens (e.g. `tg_a1b2c3d4e5f6...`). Not JWTs — short, + simple, easy to paste into CLI tools and headers. +- Each user has one or more API keys. +- Keys are stored hashed (SHA-256 with salt) in the IAM service. The + plaintext key is returned once at creation time and cannot be + retrieved afterwards. +- Keys can be revoked individually without affecting other users. +- Keys optionally have an expiry date. Expired keys are rejected. + +On each request, the gateway resolves an API key by: + +1. Hashing the token. +2. Checking a local cache (hash → user/workspace/roles). +3. On cache miss, calling the IAM service to resolve. +4. Caching the result with a short TTL (e.g. 60 seconds). + +Revoked keys stop working when the cache entry expires. No push +invalidation is needed. + +#### JWTs (login sessions) + +For interactive access via the UI or WebSocket connections. + +- A user logs in with username and password. The gateway forwards the + request to the IAM service, which validates the credentials and + returns a signed JWT. +- The JWT carries the user ID, workspace, and roles as claims. +- The gateway validates JWTs locally using the IAM service's public + signing key — no service call needed on subsequent requests. +- Token expiry is enforced by standard JWT validation at the time the + request (or WebSocket connection) is made. +- For long-lived WebSocket connections, the JWT is validated at connect + time only. The connection remains authenticated for its lifetime. + +The IAM service manages the signing key. The gateway fetches the public +key at startup (or on first JWT encounter) and caches it. + +#### Login endpoint + +``` +POST /api/v1/auth/login +{ + "username": "alice", + "password": "..." +} +→ { + "token": "eyJ...", + "expires": "2026-04-20T19:00:00Z" +} +``` + +The gateway forwards this to the IAM service, which validates +credentials and returns a signed JWT. The gateway returns the JWT to +the caller. + +#### IAM service delegation + +The gateway stays thin. Its authentication logic is: + +1. Extract Bearer token from header (or query param for WebSocket). +2. If the token has JWT format (dotted structure), validate the + signature locally and extract claims. +3. Otherwise, treat as an API key: hash it and check the local cache. + On cache miss, call the IAM service to resolve. +4. If neither succeeds, return 401. + +All user management, key management, credential validation, and token +signing logic lives in the IAM service. The gateway is a generic +enforcement point that can be replaced without changing the IAM +service. + +#### No legacy token support + +The existing `GATEWAY_SECRET` shared token is removed. All +authentication uses API keys or JWTs. On first start, the bootstrap +process creates a default workspace and admin user with an initial API +key. + +### User identity + +A user belongs to exactly one workspace. The design supports extending +this to multi-workspace access in the future (see +[Extension points](#extension-points)). + +A user record contains: + +| Field | Type | Description | +|-------|------|-------------| +| `id` | string | Unique user identifier (UUID) | +| `name` | string | Display name | +| `email` | string | Email address (optional) | +| `workspace` | string | Workspace the user belongs to | +| `roles` | list[string] | Assigned roles (e.g. `["reader"]`) | +| `enabled` | bool | Whether the user can authenticate | +| `created` | datetime | Account creation timestamp | + +The `workspace` field maps to the existing `user` field in `Metadata`. +This means the storage-layer isolation (Cassandra, Neo4j, Qdrant +filtering by `user` + `collection`) works without changes — the gateway +sets the `user` metadata field to the authenticated user's workspace. + +### Workspaces + +A workspace is an isolated data boundary. Users belong to a workspace, +and all data operations are scoped to it. Workspaces map to the existing +`user` field in `Metadata` and the corresponding Cassandra keyspace, +Qdrant collection prefix, and Neo4j property filters. + +| Field | Type | Description | +|-------|------|-------------| +| `id` | string | Unique workspace identifier | +| `name` | string | Display name | +| `enabled` | bool | Whether the workspace is active | +| `created` | datetime | Creation timestamp | + +All data operations are scoped to a workspace. The gateway determines +the effective workspace for each request as follows: + +1. If the request includes a `workspace` parameter, validate it against + the user's assigned workspace. + - If it matches, use it. + - If it does not match, return 403. (This could be extended to + check a workspace access grant list.) +2. If no `workspace` parameter is provided, use the user's assigned + workspace. + +The gateway sets the `user` field in `Metadata` to the effective +workspace ID, replacing the caller-supplied `?user=` query parameter. + +This design ensures forward compatibility. Clients that pass a +workspace parameter will work unchanged if multi-workspace support is +added later. Requests for an unassigned workspace get a clear 403 +rather than silent misbehaviour. + +### Roles and access control + +Three roles with fixed permissions: + +| Role | Data operations | Admin operations | System | +|------|----------------|-----------------|--------| +| `reader` | Query knowledge graph, embeddings, RAG | None | None | +| `writer` | All reader operations + load documents, manage collections | None | None | +| `admin` | All writer operations | Config, flows, collection management, user management | Metrics | + +Role checks happen at the gateway before dispatching to backend +services. Each endpoint declares the minimum role required: + +| Endpoint pattern | Minimum role | +|-----------------|--------------| +| `GET /api/v1/socket` (queries) | `reader` | +| `POST /api/v1/librarian` | `writer` | +| `POST /api/v1/flow/*/import/*` | `writer` | +| `POST /api/v1/config` | `admin` | +| `GET /api/v1/flow/*` | `admin` | +| `GET /api/metrics` | `admin` | + +Roles are hierarchical: `admin` implies `writer`, which implies +`reader`. + +### IAM service + +The IAM service is a new backend service that manages all identity and +access data. It is the authority for users, workspaces, API keys, and +credentials. The gateway delegates to it. + +#### Data model + +``` +iam_workspaces ( + id text PRIMARY KEY, + name text, + enabled boolean, + created timestamp +) + +iam_users ( + id text PRIMARY KEY, + workspace text, + name text, + email text, + password_hash text, + roles set, + enabled boolean, + created timestamp +) + +iam_api_keys ( + key_hash text PRIMARY KEY, + user_id text, + name text, + expires timestamp, + created timestamp +) +``` + +A secondary index on `iam_api_keys.user_id` supports listing a user's +keys. + +#### Responsibilities + +- User CRUD (create, list, update, disable) +- Workspace CRUD (create, list, update, disable) +- API key management (create, revoke, list) +- API key resolution (hash → user/workspace/roles) +- Credential validation (username/password → signed JWT) +- JWT signing key management (initialise, rotate) +- Bootstrap (create default workspace and admin user on first start) + +#### Communication + +The IAM service communicates via the standard request/response pub/sub +pattern, the same as the config service. The gateway calls it to +resolve API keys and to handle login requests. User management +operations (create user, revoke key, etc.) also go through the IAM +service. + +### Gateway changes + +The current `Authenticator` class is replaced with a thin authentication +middleware that delegates to the IAM service: + +For HTTP requests: + +1. Extract Bearer token from the `Authorization` header. +2. If the token has JWT format (dotted structure): + - Validate signature locally using the cached public key. + - Extract user ID, workspace, and roles from claims. +3. Otherwise, treat as an API key: + - Hash the token and check the local cache. + - On cache miss, call the IAM service to resolve. + - Cache the result (user/workspace/roles) with a short TTL. +4. If neither succeeds, return 401. +5. If the user or workspace is disabled, return 403. +6. Check the user's role against the endpoint's minimum role. If + insufficient, return 403. +7. Resolve the effective workspace: + - If the request includes a `workspace` parameter, validate it + against the user's assigned workspace. Return 403 on mismatch. + - If no `workspace` parameter, use the user's assigned workspace. +8. Set the `user` field in the request context to the effective + workspace ID. This propagates through `Metadata` to all downstream + services. + +For WebSocket connections: + +1. Accept the connection in an unauthenticated state. +2. Wait for an auth message (`{"type": "auth", "token": "..."}`). +3. Validate the token using the same logic as steps 2-7 above. +4. On success, attach the resolved identity to the connection and + send `{"type": "auth-ok", ...}`. +5. On failure, send `{"type": "auth-failed", ...}` but keep the + socket open. +6. Reject all non-auth messages until authentication succeeds. +7. Accept new auth messages at any time to re-authenticate. + +### CLI changes + +CLI tools authenticate with API keys: + +- `--api-key` argument on all CLI tools, replacing `--api-token`. +- `tg-create-workspace`, `tg-list-workspaces` for workspace management. +- `tg-create-user`, `tg-list-users`, `tg-disable-user` for user + management. +- `tg-create-api-key`, `tg-list-api-keys`, `tg-revoke-api-key` for + key management. +- `--workspace` argument on tools that operate on workspace-scoped + data. +- The API key is passed as a Bearer token in the same way as the + current shared token, so the transport protocol is unchanged. + +### Audit logging + +With user identity established, the gateway logs: + +- Timestamp, user ID, workspace, endpoint, HTTP method, response status. +- Audit logs are written to the standard logging output (structured + JSON). Integration with external log aggregation (Loki, ELK) is a + deployment concern, not an application concern. + +### Config service changes + +All configuration is workspace-scoped (see +[data-ownership-model.md](data-ownership-model.md)). The config service +needs to support this. + +#### Schema change + +The config table adds workspace as a key dimension: + +``` +config ( + workspace text, + class text, + key text, + value text, + PRIMARY KEY ((workspace, class), key) +) +``` + +#### Request format + +Config requests add a `workspace` field at the request level. The +existing `(type, key)` structure is unchanged within each workspace. + +**Get:** +```json +{ + "operation": "get", + "workspace": "workspace-a", + "keys": [{"type": "prompt", "key": "rag-prompt"}] +} +``` + +**Put:** +```json +{ + "operation": "put", + "workspace": "workspace-a", + "values": [{"type": "prompt", "key": "rag-prompt", "value": "..."}] +} +``` + +**List (all keys of a type within a workspace):** +```json +{ + "operation": "list", + "workspace": "workspace-a", + "type": "prompt" +} +``` + +**Delete:** +```json +{ + "operation": "delete", + "workspace": "workspace-a", + "keys": [{"type": "prompt", "key": "rag-prompt"}] +} +``` + +The workspace is set by: + +- **Gateway** — from the authenticated user's workspace for API-facing + requests. +- **Internal services** — explicitly, based on `Metadata.user` from + the message being processed, or `_system` for operational config. + +#### System config namespace + +Processor-level operational config (logging levels, connection strings, +resource limits) is not workspace-specific. This stays in a reserved +`_system` workspace that is not associated with any user workspace. +Services read system config at startup without needing a workspace +context. + +#### Config change notifications + +The config notify mechanism pushes change notifications via pub/sub +when config is updated. A single update may affect multiple workspaces +and multiple config types. The notification message carries a dict of +changes keyed by config type, with each value being the list of +affected workspaces: + +```json +{ + "version": 42, + "changes": { + "prompt": ["workspace-a", "workspace-b"], + "schema": ["workspace-a"] + } +} +``` + +System config changes use the reserved `_system` workspace: + +```json +{ + "version": 43, + "changes": { + "logging": ["_system"] + } +} +``` + +This structure is keyed by type because handlers register by type. A +handler registered for `prompt` looks up `"prompt"` directly and gets +the list of affected workspaces — no iteration over unrelated types. + +#### Config change handlers + +The current `on_config` hook mechanism needs two modes to support shared +processing services: + +- **Workspace-scoped handlers** — notify when a config type changes in a + specific workspace. The handler looks up its registered type in the + changes dict and checks if its workspace is in the list. Used by the + gateway and by services that serve a single workspace. + +- **Global handlers** — notify when a config type changes in any + workspace. The handler looks up its registered type in the changes + dict and gets the full list of affected workspaces. Used by shared + processing services (prompt-rag, agent manager, etc.) that serve all + workspaces. Each workspace in the list tells the handler which cache + entry to update rather than reloading everything. + +#### Per-workspace config caching + +Shared services that handle messages from multiple workspaces maintain a +per-workspace config cache. When a message arrives, the service looks up +the config for the workspace identified in `Metadata.user`. If the +workspace is not yet cached, the service fetches its config on demand. +Config change notifications update the relevant cache entry. + +### Flow and queue isolation + +Flows are workspace-owned. When two workspaces start flows with the same +name and blueprint, their queues must be separate to prevent data +mixing. + +Flow blueprint templates currently use `{id}` (flow instance ID) and +`{class}` (blueprint name) as template variables in queue names. A new +`{workspace}` variable is added so queue names include the workspace: + +**Current queue names (no workspace isolation):** +``` +flow:tg:document-load:{id} → flow:tg:document-load:default +request:tg:embeddings:{class} → request:tg:embeddings:everything +``` + +**With workspace isolation:** +``` +flow:tg:{workspace}:document-load:{id} → flow:tg:ws-a:document-load:default +request:tg:{workspace}:embeddings:{class} → request:tg:ws-a:embeddings:everything +``` + +The flow service substitutes `{workspace}` from the authenticated +workspace when starting a flow, the same way it substitutes `{id}` and +`{class}` today. + +Processing services are shared infrastructure — they consume from +workspace-specific queues but are not themselves workspace-aware. The +workspace is carried in `Metadata.user` on every message, so services +know which workspace's data they are processing. + +Blueprint templates need updating to include `{workspace}` in all queue +name patterns. For migration, the flow service can inject the workspace +into queue names automatically if the template does not include +`{workspace}`, defaulting to the legacy behaviour for existing +blueprints. + +See [flow-class-definition.md](flow-class-definition.md) for the full +blueprint template specification. + +### What changes and what doesn't + +**Changes:** + +| Component | Change | +|-----------|--------| +| `gateway/auth.py` | Replace `Authenticator` with new auth middleware | +| `gateway/service.py` | Initialise IAM client, configure JWT validation | +| `gateway/endpoint/*.py` | Add role requirement per endpoint | +| Metadata propagation | Gateway sets `user` from workspace, ignores query param | +| Config service | Add workspace dimension to config schema | +| Config table | `PRIMARY KEY ((workspace, class), key)` | +| Config request/response schema | Add `workspace` field | +| Config notify messages | Include workspace ID in change notifications | +| `on_config` handlers | Support workspace-scoped and global modes | +| Shared services | Per-workspace config caching | +| Flow blueprints | Add `{workspace}` template variable to queue names | +| Flow service | Substitute `{workspace}` when starting flows | +| CLI tools | New user management commands, `--api-key` argument | +| Cassandra schema | New `iam_workspaces`, `iam_users`, `iam_api_keys` tables | + +**Does not change:** + +| Component | Reason | +|-----------|--------| +| Internal service-to-service pub/sub | Services trust the gateway | +| `Metadata` dataclass | `user` field continues to carry workspace identity | +| Storage-layer isolation | Same `user` + `collection` filtering | +| Message serialisation | No schema changes | + +### Migration + +This is a breaking change. Existing deployments must be reconfigured: + +1. `GATEWAY_SECRET` is removed. Authentication requires API keys or + JWT login tokens. +2. The `?user=` query parameter is removed. Workspace identity comes + from authentication. +3. On first start, the IAM service bootstraps a default workspace and + admin user. The initial API key is output to the service log. +4. Operators create additional workspaces and users via CLI tools. +5. Flow blueprints must be updated to include `{workspace}` in queue + name patterns. +6. Config data must be migrated to include the workspace dimension. + +## Extension points + +The design includes deliberate extension points for future capabilities. +These are not implemented but the architecture does not preclude them: + +- **Multi-workspace access.** Users could be granted access to + additional workspaces beyond their primary assignment. The workspace + validation step checks a grant list instead of a single assignment. +- **Rules-based access control.** A separate access control service + could evaluate fine-grained policies (per-collection permissions, + operation-level restrictions, time-based access). The gateway + delegates authorisation decisions to this service. +- **External identity provider integration.** SAML, LDAP, and OIDC + flows (group mapping, claims-based role assignment) could be added + to the IAM service. +- **Cross-workspace administration.** A `superadmin` role for platform + operators who manage multiple workspaces. +- **Delegated workspace provisioning.** APIs for programmatic workspace + creation and user onboarding. + +These extensions are additive — they extend the validation logic +without changing the request/response protocol. The gateway can be +replaced with an alternative implementation that supports these +capabilities while the IAM service and backend services remain +unchanged. + +## Implementation plan + +Workspace support is a prerequisite for auth — users are assigned to +workspaces, config is workspace-scoped, and flows use workspace in +queue names. Implementing workspaces first allows the structural changes +to be tested end-to-end without auth complicating debugging. + +### Phase 1: Workspace support (no auth) + +All workspace-scoped data and processing changes. The system works with +workspaces but no authentication — callers pass workspace as a +parameter, honour system. This allows full end-to-end testing: multiple +workspaces with separate flows, config, queues, and data. + +#### Config service + +- Update config client API to accept a workspace parameter on all + requests +- Update config storage schema to add workspace as a key dimension +- Update config notification API to report changes as a dict of + type → workspace list +- Update the processor base class to understand workspaces in config + notifications (workspace-scoped and global handler modes) +- Update all processors to implement workspace-aware config handling + (per-workspace config caching, on-demand fetch) + +#### Flow and queue isolation + +- Update flow blueprints to include `{workspace}` in all queue name + patterns +- Update the flow service to substitute `{workspace}` when starting + flows +- Update all built-in blueprints to include `{workspace}` + +#### CLI tools (workspace support) + +- Add `--workspace` argument to CLI tools that operate on + workspace-scoped data +- Add `tg-create-workspace`, `tg-list-workspaces` commands + +### Phase 2: Authentication and access control + +With workspaces working, add the IAM service and lock down the gateway. + +#### IAM service + +A new service handling identity and access management on behalf of the +API gateway: + +- Add workspace table support (CRUD, enable/disable) +- Add user table support (CRUD, enable/disable, workspace assignment) +- Add roles support (role assignment, role validation) +- Add API key support (create, revoke, list, hash storage) +- Add ability to initialise a JWT signing key for token grants +- Add token grant endpoint: user/password login returns a signed JWT +- Add bootstrap/initialisation mechanism: ability to set the signing + key and create the initial workspace + admin user on first start + +#### API gateway integration + +- Add IAM middleware to the API gateway replacing the current + `Authenticator` +- Add local JWT validation (public key from IAM service) +- Add API key resolution with local cache (hash → user/workspace/roles, + cache miss calls IAM service, short TTL) +- Add login endpoint forwarding to IAM service +- Add workspace resolution: validate requested workspace against user + assignment +- Add role-based endpoint access checks +- Add user management API endpoints (forwarded to IAM service) +- Add audit logging (user ID, workspace, endpoint, method, status) +- WebSocket auth via first-message protocol (auth message after + connect, socket stays open on failure, re-auth supported) + +#### CLI tools (auth support) + +- Add `tg-create-user`, `tg-list-users`, `tg-disable-user` commands +- Add `tg-create-api-key`, `tg-list-api-keys`, `tg-revoke-api-key` + commands +- Replace `--api-token` with `--api-key` on existing CLI tools + +#### Bootstrap and cutover + +- Create default workspace and admin user on first start if IAM tables + are empty +- Remove `GATEWAY_SECRET` and `?user=` query parameter support + +## Design Decisions + +### IAM data store + +IAM data is stored in dedicated Cassandra tables owned by the IAM +service, not in the config service. Reasons: + +- **Security isolation.** The config service has a broad, generic + protocol. An access control failure on the config service could + expose credentials. A dedicated IAM service with a purpose-built + protocol limits the attack surface and makes security auditing + clearer. +- **Data model fit.** IAM needs indexed lookups (API key hash → user, + list keys by user). The config service's `(workspace, type, key) → + value` model stores opaque JSON strings with no secondary indexes. +- **Scope.** IAM data is global (workspaces, users, keys). Config is + workspace-scoped. Mixing global and workspace-scoped data in the + same store adds complexity. +- **Audit.** IAM operations (key creation, revocation, login attempts) + are security events that should be logged separately from general + config changes. + +## Deferred to future design + +- **OIDC integration.** External identity provider support (SAML, LDAP, + OIDC) is left for future implementation. The extension points section + describes where this fits architecturally. +- **API key scoping.** API keys could be scoped to specific collections + within a workspace rather than granting workspace-wide access. To be + designed when the need arises. +- **tg-init-trustgraph** only initialises a single workspace. + +## References + +- [Data Ownership and Information Separation](data-ownership-model.md) +- [MCP Tool Bearer Token Specification](mcp-tool-bearer-token.md) +- [Multi-Tenant Support Specification](multi-tenant-support.md) +- [Neo4j User Collection Isolation](neo4j-user-collection-isolation.md) diff --git a/specs/api/components/parameters/User.yaml b/specs/api/components/parameters/User.yaml deleted file mode 100644 index ad0657ca..00000000 --- a/specs/api/components/parameters/User.yaml +++ /dev/null @@ -1,8 +0,0 @@ -name: user -in: query -required: false -schema: - type: string - default: trustgraph -description: User identifier -example: alice diff --git a/specs/api/components/schemas/agent/AgentRequest.yaml b/specs/api/components/schemas/agent/AgentRequest.yaml index ddf2019a..26703402 100644 --- a/specs/api/components/schemas/agent/AgentRequest.yaml +++ b/specs/api/components/schemas/agent/AgentRequest.yaml @@ -43,15 +43,6 @@ properties: type: string description: Result of the action example: "Paris is the capital of France" - user: - type: string - description: User context for this step - example: alice - user: - type: string - description: User identifier for multi-tenancy - default: trustgraph - example: alice streaming: type: boolean description: Enable streaming response delivery diff --git a/specs/api/components/schemas/collection/CollectionRequest.yaml b/specs/api/components/schemas/collection/CollectionRequest.yaml index bf3ab7d4..e1dc8338 100644 --- a/specs/api/components/schemas/collection/CollectionRequest.yaml +++ b/specs/api/components/schemas/collection/CollectionRequest.yaml @@ -14,14 +14,9 @@ properties: - delete-collection description: | Collection operation: - - `list-collections`: List collections for user + - `list-collections`: List collections in workspace - `update-collection`: Create or update collection metadata - `delete-collection`: Delete collection - user: - type: string - description: User identifier - default: trustgraph - example: alice collection: type: string description: Collection identifier (for update, delete) diff --git a/specs/api/components/schemas/collection/CollectionResponse.yaml b/specs/api/components/schemas/collection/CollectionResponse.yaml index f924cbf5..d65a7274 100644 --- a/specs/api/components/schemas/collection/CollectionResponse.yaml +++ b/specs/api/components/schemas/collection/CollectionResponse.yaml @@ -12,13 +12,8 @@ properties: items: type: object required: - - user - collection properties: - user: - type: string - description: User identifier - example: alice collection: type: string description: Collection identifier diff --git a/specs/api/components/schemas/embeddings-query/DocumentEmbeddingsQueryRequest.yaml b/specs/api/components/schemas/embeddings-query/DocumentEmbeddingsQueryRequest.yaml index f2d0aec2..b6e9dcb3 100644 --- a/specs/api/components/schemas/embeddings-query/DocumentEmbeddingsQueryRequest.yaml +++ b/specs/api/components/schemas/embeddings-query/DocumentEmbeddingsQueryRequest.yaml @@ -17,11 +17,6 @@ properties: minimum: 1 maximum: 1000 example: 20 - user: - type: string - description: User identifier - default: trustgraph - example: alice collection: type: string description: Collection to search diff --git a/specs/api/components/schemas/embeddings-query/GraphEmbeddingsQueryRequest.yaml b/specs/api/components/schemas/embeddings-query/GraphEmbeddingsQueryRequest.yaml index 6cf60bbd..212eb3e2 100644 --- a/specs/api/components/schemas/embeddings-query/GraphEmbeddingsQueryRequest.yaml +++ b/specs/api/components/schemas/embeddings-query/GraphEmbeddingsQueryRequest.yaml @@ -17,11 +17,6 @@ properties: minimum: 1 maximum: 1000 example: 20 - user: - type: string - description: User identifier - default: trustgraph - example: alice collection: type: string description: Collection to search diff --git a/specs/api/components/schemas/embeddings-query/RowEmbeddingsQueryRequest.yaml b/specs/api/components/schemas/embeddings-query/RowEmbeddingsQueryRequest.yaml index 916b4beb..51111a94 100644 --- a/specs/api/components/schemas/embeddings-query/RowEmbeddingsQueryRequest.yaml +++ b/specs/api/components/schemas/embeddings-query/RowEmbeddingsQueryRequest.yaml @@ -27,11 +27,6 @@ properties: minimum: 1 maximum: 1000 example: 20 - user: - type: string - description: User identifier - default: trustgraph - example: alice collection: type: string description: Collection to search diff --git a/specs/api/components/schemas/knowledge/KnowledgeRequest.yaml b/specs/api/components/schemas/knowledge/KnowledgeRequest.yaml index 5c40e118..8be57dd6 100644 --- a/specs/api/components/schemas/knowledge/KnowledgeRequest.yaml +++ b/specs/api/components/schemas/knowledge/KnowledgeRequest.yaml @@ -18,17 +18,12 @@ properties: - unload-kg-core description: | Knowledge core operation: - - `list-kg-cores`: List knowledge cores for user + - `list-kg-cores`: List knowledge cores in workspace - `get-kg-core`: Get knowledge core by ID - `put-kg-core`: Store triples and/or embeddings - `delete-kg-core`: Delete knowledge core by ID - `load-kg-core`: Load knowledge core into flow - `unload-kg-core`: Unload knowledge core from flow - user: - type: string - description: User identifier (for list-kg-cores, put-kg-core, delete-kg-core) - default: trustgraph - example: alice id: type: string description: Knowledge core ID (for get, put, delete, load, unload) @@ -53,17 +48,12 @@ properties: type: object required: - id - - user - collection properties: id: type: string description: Knowledge core ID example: core-123 - user: - type: string - description: User identifier - example: alice collection: type: string description: Collection identifier @@ -89,17 +79,12 @@ properties: type: object required: - id - - user - collection properties: id: type: string description: Knowledge core ID example: core-123 - user: - type: string - description: User identifier - example: alice collection: type: string description: Collection identifier diff --git a/specs/api/components/schemas/knowledge/KnowledgeResponse.yaml b/specs/api/components/schemas/knowledge/KnowledgeResponse.yaml index 229233ca..b0e4d6bb 100644 --- a/specs/api/components/schemas/knowledge/KnowledgeResponse.yaml +++ b/specs/api/components/schemas/knowledge/KnowledgeResponse.yaml @@ -15,17 +15,12 @@ properties: type: object required: - id - - user - collection properties: id: type: string description: Knowledge core ID example: core-123 - user: - type: string - description: User identifier - example: alice collection: type: string description: Collection identifier @@ -48,17 +43,12 @@ properties: type: object required: - id - - user - collection properties: id: type: string description: Knowledge core ID example: core-123 - user: - type: string - description: User identifier - example: alice collection: type: string description: Collection identifier diff --git a/specs/api/components/schemas/librarian/LibrarianRequest.yaml b/specs/api/components/schemas/librarian/LibrarianRequest.yaml index eed999f0..25dca7e2 100644 --- a/specs/api/components/schemas/librarian/LibrarianRequest.yaml +++ b/specs/api/components/schemas/librarian/LibrarianRequest.yaml @@ -62,11 +62,6 @@ properties: description: Collection identifier default: default example: default - user: - type: string - description: User identifier - default: trustgraph - example: alice document-id: type: string description: Document identifier diff --git a/specs/api/components/schemas/loading/DocumentLoadRequest.yaml b/specs/api/components/schemas/loading/DocumentLoadRequest.yaml index 45bbe428..8d9a996f 100644 --- a/specs/api/components/schemas/loading/DocumentLoadRequest.yaml +++ b/specs/api/components/schemas/loading/DocumentLoadRequest.yaml @@ -15,11 +15,6 @@ properties: type: string description: Document identifier example: doc-456 - user: - type: string - description: User identifier - default: trustgraph - example: alice collection: type: string description: Collection for document diff --git a/specs/api/components/schemas/loading/TextLoadRequest.yaml b/specs/api/components/schemas/loading/TextLoadRequest.yaml index 447308d4..57f7ecc3 100644 --- a/specs/api/components/schemas/loading/TextLoadRequest.yaml +++ b/specs/api/components/schemas/loading/TextLoadRequest.yaml @@ -14,11 +14,6 @@ properties: type: string description: Document identifier example: doc-123 - user: - type: string - description: User identifier - default: trustgraph - example: alice collection: type: string description: Collection for document diff --git a/specs/api/components/schemas/query/RowsQueryRequest.yaml b/specs/api/components/schemas/query/RowsQueryRequest.yaml index 08f03ad3..611864e8 100644 --- a/specs/api/components/schemas/query/RowsQueryRequest.yaml +++ b/specs/api/components/schemas/query/RowsQueryRequest.yaml @@ -28,11 +28,6 @@ properties: type: string description: Operation name (for multi-operation documents) example: GetPerson - user: - type: string - description: User identifier - default: trustgraph - example: alice collection: type: string description: Collection to query diff --git a/specs/api/components/schemas/query/StructuredQueryRequest.yaml b/specs/api/components/schemas/query/StructuredQueryRequest.yaml index ae564c0a..00bc75cb 100644 --- a/specs/api/components/schemas/query/StructuredQueryRequest.yaml +++ b/specs/api/components/schemas/query/StructuredQueryRequest.yaml @@ -10,11 +10,6 @@ properties: type: string description: Natural language question example: Who does Alice know that works in engineering? - user: - type: string - description: User identifier - default: trustgraph - example: alice collection: type: string description: Collection to query diff --git a/specs/api/components/schemas/query/TriplesQueryRequest.yaml b/specs/api/components/schemas/query/TriplesQueryRequest.yaml index d49e0300..0efb1452 100644 --- a/specs/api/components/schemas/query/TriplesQueryRequest.yaml +++ b/specs/api/components/schemas/query/TriplesQueryRequest.yaml @@ -18,11 +18,6 @@ properties: minimum: 1 maximum: 100000 example: 100 - user: - type: string - description: User identifier - default: trustgraph - example: alice collection: type: string description: Collection to query diff --git a/specs/api/components/schemas/rag/DocumentRagRequest.yaml b/specs/api/components/schemas/rag/DocumentRagRequest.yaml index 97a9d2ff..92bc383b 100644 --- a/specs/api/components/schemas/rag/DocumentRagRequest.yaml +++ b/specs/api/components/schemas/rag/DocumentRagRequest.yaml @@ -9,11 +9,6 @@ properties: type: string description: User query or question example: What are the key findings in the research papers? - user: - type: string - description: User identifier for multi-tenancy - default: trustgraph - example: alice collection: type: string description: Collection to search within diff --git a/specs/api/components/schemas/rag/GraphRagRequest.yaml b/specs/api/components/schemas/rag/GraphRagRequest.yaml index 733dd7c1..754dcc92 100644 --- a/specs/api/components/schemas/rag/GraphRagRequest.yaml +++ b/specs/api/components/schemas/rag/GraphRagRequest.yaml @@ -9,11 +9,6 @@ properties: type: string description: User query or question example: What connections exist between quantum physics and computer science? - user: - type: string - description: User identifier for multi-tenancy - default: trustgraph - example: alice collection: type: string description: Collection to search within diff --git a/specs/api/paths/collection-management.yaml b/specs/api/paths/collection-management.yaml index 7dffd4e0..acd7ff9f 100644 --- a/specs/api/paths/collection-management.yaml +++ b/specs/api/paths/collection-management.yaml @@ -10,11 +10,10 @@ post: Collections are organizational units for grouping: - Documents in the librarian - Knowledge cores - - User data + - Workspace data Each collection has: - - **user**: Owner identifier - - **collection**: Unique collection ID + - **collection**: Unique collection ID (within the workspace) - **name**: Human-readable display name - **description**: Purpose and contents - **tags**: Labels for filtering and organization @@ -22,7 +21,7 @@ post: ## Operations ### list-collections - List all collections for a user. Optionally filter by tags and limit results. + List all collections in the workspace. Optionally filter by tags and limit results. Returns array of collection metadata. ### update-collection @@ -30,7 +29,7 @@ post: If it exists, metadata is updated. Allows setting name, description, and tags. ### delete-collection - Delete a collection by user and collection ID. This removes the metadata but + Delete a collection by collection ID. This removes the metadata but typically does not delete the associated data (documents, knowledge cores). operationId: collectionManagementService @@ -44,22 +43,19 @@ post: $ref: '../components/schemas/collection/CollectionRequest.yaml' examples: listCollections: - summary: List all collections for user + summary: List all collections in workspace value: operation: list-collections - user: alice listCollectionsFiltered: summary: List collections filtered by tags value: operation: list-collections - user: alice tag-filter: ["research", "AI"] limit: 50 updateCollection: summary: Create/update collection value: operation: update-collection - user: alice collection: research name: Research Papers description: Academic research papers on AI and ML @@ -69,7 +65,6 @@ post: summary: Delete collection value: operation: delete-collection - user: alice collection: research responses: '200': @@ -84,13 +79,11 @@ post: value: timestamp: "2024-01-15T10:30:00Z" collections: - - user: alice - collection: research + - collection: research name: Research Papers description: Academic research papers on AI and ML tags: ["research", "AI", "academic"] - - user: alice - collection: personal + - collection: personal name: Personal Documents description: Personal notes and documents tags: ["personal"] diff --git a/specs/api/paths/document-stream.yaml b/specs/api/paths/document-stream.yaml index 5f6a11a7..67aea0e1 100644 --- a/specs/api/paths/document-stream.yaml +++ b/specs/api/paths/document-stream.yaml @@ -8,7 +8,6 @@ get: ## Parameters - - `user`: User identifier (required) - `document-id`: Document IRI to retrieve (required) - `chunk-size`: Size of each response chunk in bytes (optional, default: 1MB) @@ -16,13 +15,6 @@ get: security: - bearerAuth: [] parameters: - - name: user - in: query - required: true - schema: - type: string - description: User identifier - example: trustgraph - name: document-id in: query required: true diff --git a/specs/api/paths/export-core.yaml b/specs/api/paths/export-core.yaml index e7dc06b0..7fddd024 100644 --- a/specs/api/paths/export-core.yaml +++ b/specs/api/paths/export-core.yaml @@ -23,7 +23,6 @@ get: "m": { // Metadata "i": "core-id", // Knowledge core ID "m": [...], // Metadata triples array - "u": "user", // User "c": "collection" // Collection }, "t": [...] // Triples array @@ -36,7 +35,6 @@ get: "m": { // Metadata "i": "core-id", "m": [...], - "u": "user", "c": "collection" }, "e": [ // Entities array @@ -56,7 +54,6 @@ get: ## Query Parameters - **id**: Knowledge core ID to export - - **user**: User identifier ## Streaming @@ -86,13 +83,6 @@ get: type: string description: Knowledge core ID to export example: core-123 - - name: user - in: query - required: true - schema: - type: string - description: User identifier - example: alice responses: '200': description: Export stream diff --git a/specs/api/paths/flow/agent.yaml b/specs/api/paths/flow/agent.yaml index 2cecf89c..a38b6a82 100644 --- a/specs/api/paths/flow/agent.yaml +++ b/specs/api/paths/flow/agent.yaml @@ -69,25 +69,21 @@ post: summary: Simple question value: question: What is the capital of France? - user: alice streamingQuestion: summary: Question with streaming enabled value: question: Explain quantum computing - user: alice streaming: true conversationWithHistory: summary: Multi-turn conversation value: question: And what about its population? - user: alice history: - thought: User is asking about the capital of France action: search arguments: query: "capital of France" observation: "Paris is the capital of France" - user: alice responses: '200': description: Successful response diff --git a/specs/api/paths/flow/document-embeddings.yaml b/specs/api/paths/flow/document-embeddings.yaml index dbab2f92..ba7344fe 100644 --- a/specs/api/paths/flow/document-embeddings.yaml +++ b/specs/api/paths/flow/document-embeddings.yaml @@ -75,7 +75,6 @@ post: value: vectors: [0.023, -0.142, 0.089, 0.234, -0.067, 0.156, 0.201, -0.178] limit: 10 - user: alice collection: research largeQuery: summary: Larger result set diff --git a/specs/api/paths/flow/document-load.yaml b/specs/api/paths/flow/document-load.yaml index 09ddc09f..97ca3f3f 100644 --- a/specs/api/paths/flow/document-load.yaml +++ b/specs/api/paths/flow/document-load.yaml @@ -88,14 +88,12 @@ post: value: data: JVBERi0xLjQKJeLjz9MKMSAwIG9iago8PC9UeXBlL0NhdGFsb2cvUGFnZXMgMiAwIFI+PmVuZG9iagoyIDAgb2JqCjw8L1R5cGUvUGFnZXMvS2lkc1szIDAgUl0vQ291bnQgMT4+ZW5kb2JqCg== id: doc-789 - user: alice collection: research withMetadata: summary: Load with metadata value: data: JVBERi0xLjQKJeLjz9MK... id: doc-101112 - user: bob collection: papers metadata: - s: {v: "doc-101112", e: false} diff --git a/specs/api/paths/flow/document-rag.yaml b/specs/api/paths/flow/document-rag.yaml index f91bfc27..891868a5 100644 --- a/specs/api/paths/flow/document-rag.yaml +++ b/specs/api/paths/flow/document-rag.yaml @@ -40,7 +40,6 @@ post: - Higher = more context but slower - Lower = faster but may miss relevant info - **collection**: Target specific document collection - - **user**: Multi-tenant isolation operationId: documentRagService security: @@ -64,13 +63,11 @@ post: summary: Basic document query value: query: What are the key findings in the research papers? - user: alice collection: research streamingQuery: summary: Streaming query value: query: Summarize the main conclusions - user: alice collection: research doc-limit: 15 streaming: true diff --git a/specs/api/paths/flow/graph-embeddings.yaml b/specs/api/paths/flow/graph-embeddings.yaml index 277659de..16c1f4b3 100644 --- a/specs/api/paths/flow/graph-embeddings.yaml +++ b/specs/api/paths/flow/graph-embeddings.yaml @@ -66,7 +66,6 @@ post: value: vectors: [0.023, -0.142, 0.089, 0.234, -0.067, 0.156, 0.201, -0.178] limit: 10 - user: alice collection: research largeQuery: summary: Larger result set diff --git a/specs/api/paths/flow/graph-rag.yaml b/specs/api/paths/flow/graph-rag.yaml index b3f087c6..e9ffadea 100644 --- a/specs/api/paths/flow/graph-rag.yaml +++ b/specs/api/paths/flow/graph-rag.yaml @@ -77,13 +77,11 @@ post: summary: Basic graph query value: query: What connections exist between quantum physics and computer science? - user: alice collection: research streamingQuery: summary: Streaming query with custom limits value: query: Trace the historical development of AI from Turing to modern LLMs - user: alice collection: research entity-limit: 40 triple-limit: 25 diff --git a/specs/api/paths/flow/row-embeddings.yaml b/specs/api/paths/flow/row-embeddings.yaml index 05837c06..9f9f5c4f 100644 --- a/specs/api/paths/flow/row-embeddings.yaml +++ b/specs/api/paths/flow/row-embeddings.yaml @@ -62,7 +62,6 @@ post: vectors: [0.023, -0.142, 0.089, 0.234, -0.067, 0.156, 0.201, -0.178] schema_name: customers limit: 10 - user: alice collection: sales filteredQuery: summary: Search specific index diff --git a/specs/api/paths/flow/rows.yaml b/specs/api/paths/flow/rows.yaml index d648c9db..e8b1573a 100644 --- a/specs/api/paths/flow/rows.yaml +++ b/specs/api/paths/flow/rows.yaml @@ -89,7 +89,6 @@ post: email } } - user: alice collection: research queryWithVariables: summary: Query with variables diff --git a/specs/api/paths/flow/sparql-query.yaml b/specs/api/paths/flow/sparql-query.yaml index 2f343488..b7970bd0 100644 --- a/specs/api/paths/flow/sparql-query.yaml +++ b/specs/api/paths/flow/sparql-query.yaml @@ -61,10 +61,6 @@ post: query: type: string description: SPARQL 1.1 query string - user: - type: string - default: trustgraph - description: User/keyspace identifier collection: type: string default: default @@ -78,7 +74,6 @@ post: summary: SELECT query value: query: "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10" - user: trustgraph collection: default askQuery: summary: ASK query diff --git a/specs/api/paths/flow/structured-query.yaml b/specs/api/paths/flow/structured-query.yaml index 6d4dfe87..7e963b5b 100644 --- a/specs/api/paths/flow/structured-query.yaml +++ b/specs/api/paths/flow/structured-query.yaml @@ -79,13 +79,11 @@ post: summary: Simple relationship question value: question: Who does Alice know? - user: alice collection: research complexQuestion: summary: Complex multi-hop question value: question: What companies employ engineers that Bob collaborates with? - user: bob collection: work filterQuestion: summary: Question with implicit filters diff --git a/specs/api/paths/flow/text-load.yaml b/specs/api/paths/flow/text-load.yaml index 08bfe47b..b97d249f 100644 --- a/specs/api/paths/flow/text-load.yaml +++ b/specs/api/paths/flow/text-load.yaml @@ -87,14 +87,12 @@ post: value: text: This is the document text... id: doc-123 - user: alice collection: research withMetadata: summary: Load with RDF metadata using base64 text value: text: UXVhbnR1bSBjb21wdXRpbmcgdXNlcyBxdWFudHVtIG1lY2hhbmljcyBwcmluY2lwbGVzLi4u id: doc-456 - user: alice collection: research metadata: - s: {v: "doc-456", e: false} diff --git a/specs/api/paths/flow/triples.yaml b/specs/api/paths/flow/triples.yaml index 5557ea5a..9683d9f7 100644 --- a/specs/api/paths/flow/triples.yaml +++ b/specs/api/paths/flow/triples.yaml @@ -81,7 +81,6 @@ post: s: v: https://example.com/person/alice e: true - user: alice collection: research limit: 100 allInstancesOfType: @@ -100,7 +99,6 @@ post: p: v: https://example.com/knows e: true - user: alice limit: 200 responses: '200': diff --git a/specs/api/paths/import-core.yaml b/specs/api/paths/import-core.yaml index 38c99bf0..633f5477 100644 --- a/specs/api/paths/import-core.yaml +++ b/specs/api/paths/import-core.yaml @@ -23,7 +23,6 @@ post: "m": { // Metadata "i": "core-id", // Knowledge core ID "m": [...], // Metadata triples array - "u": "user", // User "c": "collection" // Collection }, "t": [...] // Triples array @@ -36,7 +35,6 @@ post: "m": { // Metadata "i": "core-id", "m": [...], - "u": "user", "c": "collection" }, "e": [ // Entities array @@ -51,7 +49,6 @@ post: ## Query Parameters - **id**: Knowledge core ID - - **user**: User identifier ## Streaming @@ -77,13 +74,6 @@ post: type: string description: Knowledge core ID to import example: core-123 - - name: user - in: query - required: true - schema: - type: string - description: User identifier - example: alice requestBody: required: true content: diff --git a/specs/api/paths/knowledge.yaml b/specs/api/paths/knowledge.yaml index 71bba496..0fe1976f 100644 --- a/specs/api/paths/knowledge.yaml +++ b/specs/api/paths/knowledge.yaml @@ -12,12 +12,12 @@ post: - **Graph Embeddings**: Vector embeddings for entities - **Metadata**: Descriptive information about the knowledge - Each core has an ID, user, and collection for organization. + Each core has an ID and collection for organization (within the workspace). ## Operations ### list-kg-cores - List all knowledge cores for a user. Returns array of core IDs. + List all knowledge cores in the workspace. Returns array of core IDs. ### get-kg-core Retrieve a knowledge core by ID. Returns triples and/or graph embeddings. @@ -58,7 +58,6 @@ post: summary: List knowledge cores value: operation: list-kg-cores - user: alice getKnowledgeCore: summary: Get knowledge core value: @@ -71,7 +70,6 @@ post: triples: metadata: id: core-123 - user: alice collection: default metadata: - s: {v: "https://example.com/core-123", e: true} @@ -91,7 +89,6 @@ post: graph-embeddings: metadata: id: core-123 - user: alice collection: default metadata: [] entities: @@ -106,7 +103,6 @@ post: triples: metadata: id: core-456 - user: bob collection: research metadata: [] triples: @@ -116,7 +112,6 @@ post: graph-embeddings: metadata: id: core-456 - user: bob collection: research metadata: [] entities: @@ -127,7 +122,6 @@ post: value: operation: delete-kg-core id: core-123 - user: alice loadKnowledgeCore: summary: Load core into flow value: @@ -161,7 +155,6 @@ post: triples: metadata: id: core-123 - user: alice collection: default metadata: - s: {v: "https://example.com/core-123", e: true} @@ -177,7 +170,6 @@ post: graph-embeddings: metadata: id: core-123 - user: alice collection: default metadata: [] entities: diff --git a/specs/websocket/components/messages/requests/RowEmbeddingsRequest.yaml b/specs/websocket/components/messages/requests/RowEmbeddingsRequest.yaml index 8010417d..efe6412b 100644 --- a/specs/websocket/components/messages/requests/RowEmbeddingsRequest.yaml +++ b/specs/websocket/components/messages/requests/RowEmbeddingsRequest.yaml @@ -26,5 +26,4 @@ examples: vectors: [0.023, -0.142, 0.089, 0.234] schema_name: customers limit: 10 - user: trustgraph collection: default diff --git a/specs/websocket/components/messages/requests/SparqlQueryRequest.yaml b/specs/websocket/components/messages/requests/SparqlQueryRequest.yaml index 6954b539..bc32f33a 100644 --- a/specs/websocket/components/messages/requests/SparqlQueryRequest.yaml +++ b/specs/websocket/components/messages/requests/SparqlQueryRequest.yaml @@ -24,10 +24,6 @@ properties: query: type: string description: SPARQL 1.1 query string - user: - type: string - default: trustgraph - description: User/keyspace identifier collection: type: string default: default @@ -42,5 +38,4 @@ examples: flow: my-flow request: query: "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10" - user: trustgraph collection: default diff --git a/tests/contract/conftest.py b/tests/contract/conftest.py index 4fdfe83b..e93866e7 100644 --- a/tests/contract/conftest.py +++ b/tests/contract/conftest.py @@ -72,7 +72,6 @@ def sample_message_data(): }, "DocumentRagQuery": { "query": "What is artificial intelligence?", - "user": "test_user", "collection": "test_collection", "doc_limit": 10 }, @@ -95,7 +94,6 @@ def sample_message_data(): }, "Metadata": { "id": "test-doc-123", - "user": "test_user", "collection": "test_collection" }, "Term": { @@ -130,9 +128,8 @@ def invalid_message_data(): {}, # Missing required fields ], "DocumentRagQuery": [ - {"query": None, "user": "test", "collection": "test", "doc_limit": 10}, # Invalid query - {"query": "test", "user": None, "collection": "test", "doc_limit": 10}, # Invalid user - {"query": "test", "user": "test", "collection": "test", "doc_limit": -1}, # Invalid doc_limit + {"query": None, "collection": "test", "doc_limit": 10}, # Invalid query + {"query": "test", "collection": "test", "doc_limit": -1}, # Invalid doc_limit {"query": "test"}, # Missing required fields ], "Term": [ diff --git a/tests/contract/test_document_embeddings_contract.py b/tests/contract/test_document_embeddings_contract.py index b6d14124..c56e7b93 100644 --- a/tests/contract/test_document_embeddings_contract.py +++ b/tests/contract/test_document_embeddings_contract.py @@ -18,24 +18,18 @@ class TestDocumentEmbeddingsRequestContract: def test_request_schema_fields(self): """Test that DocumentEmbeddingsRequest has expected fields""" - # Create a request request = DocumentEmbeddingsRequest( vector=[0.1, 0.2, 0.3], limit=10, - user="test_user", collection="test_collection" ) - # Verify all expected fields exist assert hasattr(request, 'vector') assert hasattr(request, 'limit') - assert hasattr(request, 'user') assert hasattr(request, 'collection') - # Verify field values assert request.vector == [0.1, 0.2, 0.3] assert request.limit == 10 - assert request.user == "test_user" assert request.collection == "test_collection" def test_request_translator_decode(self): @@ -45,7 +39,6 @@ class TestDocumentEmbeddingsRequestContract: data = { "vector": [0.1, 0.2, 0.3, 0.4], "limit": 5, - "user": "custom_user", "collection": "custom_collection" } @@ -54,7 +47,6 @@ class TestDocumentEmbeddingsRequestContract: assert isinstance(result, DocumentEmbeddingsRequest) assert result.vector == [0.1, 0.2, 0.3, 0.4] assert result.limit == 5 - assert result.user == "custom_user" assert result.collection == "custom_collection" def test_request_translator_decode_with_defaults(self): @@ -63,7 +55,7 @@ class TestDocumentEmbeddingsRequestContract: data = { "vector": [0.1, 0.2] - # No limit, user, or collection provided + # No limit or collection provided } result = translator.decode(data) @@ -71,7 +63,6 @@ class TestDocumentEmbeddingsRequestContract: assert isinstance(result, DocumentEmbeddingsRequest) assert result.vector == [0.1, 0.2] assert result.limit == 10 # Default - assert result.user == "trustgraph" # Default assert result.collection == "default" # Default def test_request_translator_encode(self): @@ -81,7 +72,6 @@ class TestDocumentEmbeddingsRequestContract: request = DocumentEmbeddingsRequest( vector=[0.5, 0.6], limit=20, - user="test_user", collection="test_collection" ) @@ -90,7 +80,6 @@ class TestDocumentEmbeddingsRequestContract: assert isinstance(result, dict) assert result["vector"] == [0.5, 0.6] assert result["limit"] == 20 - assert result["user"] == "test_user" assert result["collection"] == "test_collection" @@ -219,7 +208,6 @@ class TestDocumentEmbeddingsMessageCompatibility: request_data = { "vector": [0.1, 0.2, 0.3], "limit": 5, - "user": "test_user", "collection": "test_collection" } diff --git a/tests/contract/test_message_contracts.py b/tests/contract/test_message_contracts.py index 6b7f82e7..59db99f6 100644 --- a/tests/contract/test_message_contracts.py +++ b/tests/contract/test_message_contracts.py @@ -132,7 +132,6 @@ class TestDocumentRagMessageContracts: # Test required fields query = DocumentRagQuery(**query_data) assert hasattr(query, 'query') - assert hasattr(query, 'user') assert hasattr(query, 'collection') assert hasattr(query, 'doc_limit') @@ -154,12 +153,10 @@ class TestDocumentRagMessageContracts: # Test valid query valid_query = DocumentRagQuery( query="What is AI?", - user="test_user", collection="test_collection", doc_limit=5 ) assert valid_query.query == "What is AI?" - assert valid_query.user == "test_user" assert valid_query.collection == "test_collection" assert valid_query.doc_limit == 5 @@ -400,7 +397,6 @@ class TestMetadataMessageContracts: metadata = Metadata(**metadata_data) assert metadata.id == "test-doc-123" - assert metadata.user == "test_user" assert metadata.collection == "test_collection" def test_error_schema_contract(self): @@ -491,7 +487,7 @@ class TestSchemaEvolutionContracts: required_fields = { "TextCompletionRequest": ["system", "prompt"], "TextCompletionResponse": ["error", "response", "model"], - "DocumentRagQuery": ["query", "user", "collection"], + "DocumentRagQuery": ["query", "collection"], "DocumentRagResponse": ["error", "response"], "AgentRequest": ["question", "history"], "AgentResponse": ["error"], diff --git a/tests/contract/test_orchestrator_contracts.py b/tests/contract/test_orchestrator_contracts.py index ab168ece..d0833297 100644 --- a/tests/contract/test_orchestrator_contracts.py +++ b/tests/contract/test_orchestrator_contracts.py @@ -18,7 +18,6 @@ class TestOrchestrationFieldContracts: def test_agent_request_orchestration_fields_roundtrip(self): req = AgentRequest( question="Test question", - user="testuser", collection="default", correlation_id="corr-123", parent_session_id="parent-sess", @@ -42,7 +41,6 @@ class TestOrchestrationFieldContracts: def test_agent_request_orchestration_fields_default_empty(self): req = AgentRequest( question="Test question", - user="testuser", ) assert req.correlation_id == "" @@ -82,7 +80,6 @@ class TestSubagentCompletionStepContract: ) req = AgentRequest( question="goal", - user="testuser", correlation_id="corr-123", history=[step], ) @@ -126,7 +123,6 @@ class TestSynthesisStepContract: req = AgentRequest( question="Original question", - user="testuser", pattern="supervisor", correlation_id="", session_id="parent-sess", diff --git a/tests/contract/test_rows_cassandra_contracts.py b/tests/contract/test_rows_cassandra_contracts.py index bf85b9fb..55a06751 100644 --- a/tests/contract/test_rows_cassandra_contracts.py +++ b/tests/contract/test_rows_cassandra_contracts.py @@ -22,7 +22,6 @@ class TestRowsCassandraContracts: # Create test object with all required fields test_metadata = Metadata( id="test-doc-001", - user="test_user", collection="test_collection", ) @@ -47,7 +46,6 @@ class TestRowsCassandraContracts: # Verify metadata structure assert hasattr(test_object.metadata, 'id') - assert hasattr(test_object.metadata, 'user') assert hasattr(test_object.metadata, 'collection') # Verify types @@ -150,7 +148,6 @@ class TestRowsCassandraContracts: original = ExtractedObject( metadata=Metadata( id="serial-001", - user="test_user", collection="test_coll", ), schema_name="test_schema", @@ -168,7 +165,6 @@ class TestRowsCassandraContracts: # Verify round-trip assert decoded.metadata.id == original.metadata.id - assert decoded.metadata.user == original.metadata.user assert decoded.metadata.collection == original.metadata.collection assert decoded.schema_name == original.schema_name assert decoded.values == original.values @@ -228,8 +224,7 @@ class TestRowsCassandraContracts: # Create test object test_obj = ExtractedObject( metadata=Metadata( - id="meta-001", - user="user123", # -> keyspace + id="meta-001", # -> keyspace collection="coll456", # -> partition key ), schema_name="table789", # -> table name @@ -242,7 +237,6 @@ class TestRowsCassandraContracts: # - metadata.user -> Cassandra keyspace # - schema_name -> Cassandra table # - metadata.collection -> Part of primary key - assert test_obj.metadata.user # Required for keyspace assert test_obj.schema_name # Required for table assert test_obj.metadata.collection # Required for partition key @@ -256,7 +250,6 @@ class TestRowsCassandraContractsBatch: # Create test object with multiple values in batch test_metadata = Metadata( id="batch-doc-001", - user="test_user", collection="test_collection", ) @@ -302,7 +295,6 @@ class TestRowsCassandraContractsBatch: """Test empty batch ExtractedObject contract""" test_metadata = Metadata( id="empty-batch-001", - user="test_user", collection="test_collection", ) @@ -324,7 +316,6 @@ class TestRowsCassandraContractsBatch: """Test single-item batch (backward compatibility) contract""" test_metadata = Metadata( id="single-batch-001", - user="test_user", collection="test_collection", ) @@ -353,7 +344,6 @@ class TestRowsCassandraContractsBatch: original = ExtractedObject( metadata=Metadata( id="batch-serial-001", - user="test_user", collection="test_coll", ), schema_name="test_schema", @@ -375,7 +365,6 @@ class TestRowsCassandraContractsBatch: # Verify round-trip for batch assert decoded.metadata.id == original.metadata.id - assert decoded.metadata.user == original.metadata.user assert decoded.metadata.collection == original.metadata.collection assert decoded.schema_name == original.schema_name assert len(decoded.values) == len(original.values) @@ -425,8 +414,7 @@ class TestRowsCassandraContractsBatch: # 3. Be stored in the same keyspace (user) test_metadata = Metadata( - id="partition-test-001", - user="consistent_user", # Same keyspace + id="partition-test-001", # Same keyspace collection="consistent_collection", # Same partition ) @@ -443,7 +431,6 @@ class TestRowsCassandraContractsBatch: ) # Verify consistency contract - assert batch_object.metadata.user # Must have user for keyspace assert batch_object.metadata.collection # Must have collection for partition key # Verify unique primary keys in batch diff --git a/tests/contract/test_rows_graphql_query_contracts.py b/tests/contract/test_rows_graphql_query_contracts.py index db796306..e5baa6e2 100644 --- a/tests/contract/test_rows_graphql_query_contracts.py +++ b/tests/contract/test_rows_graphql_query_contracts.py @@ -21,29 +21,25 @@ class TestRowsGraphQLQueryContracts: """Test RowsQueryRequest schema structure and required fields""" # Create test request with all required fields test_request = RowsQueryRequest( - user="test_user", collection="test_collection", query='{ customers { id name email } }', variables={"status": "active", "limit": "10"}, operation_name="GetCustomers" ) - + # Verify all required fields are present - assert hasattr(test_request, 'user') - assert hasattr(test_request, 'collection') + assert hasattr(test_request, 'collection') assert hasattr(test_request, 'query') assert hasattr(test_request, 'variables') assert hasattr(test_request, 'operation_name') - + # Verify field types - assert isinstance(test_request.user, str) assert isinstance(test_request.collection, str) assert isinstance(test_request.query, str) assert isinstance(test_request.variables, dict) assert isinstance(test_request.operation_name, str) - + # Verify content - assert test_request.user == "test_user" assert test_request.collection == "test_collection" assert "customers" in test_request.query assert test_request.variables["status"] == "active" @@ -53,15 +49,13 @@ class TestRowsGraphQLQueryContracts: """Test RowsQueryRequest with minimal required fields""" # Create request with only essential fields minimal_request = RowsQueryRequest( - user="user", collection="collection", query='{ test }', variables={}, operation_name="" ) - + # Verify minimal request is valid - assert minimal_request.user == "user" assert minimal_request.collection == "collection" assert minimal_request.query == '{ test }' assert minimal_request.variables == {} @@ -187,22 +181,20 @@ class TestRowsGraphQLQueryContracts: """Test that request/response can be serialized/deserialized correctly""" # Create original request original_request = RowsQueryRequest( - user="serialization_test", collection="test_data", query='{ orders(limit: 5) { id total customer { name } } }', variables={"limit": "5", "status": "active"}, operation_name="GetRecentOrders" ) - + # Test request serialization using Pulsar schema request_schema = AvroSchema(RowsQueryRequest) - + # Encode and decode request encoded_request = request_schema.encode(original_request) decoded_request = request_schema.decode(encoded_request) - + # Verify request round-trip - assert decoded_request.user == original_request.user assert decoded_request.collection == original_request.collection assert decoded_request.query == original_request.query assert decoded_request.variables == original_request.variables @@ -245,7 +237,7 @@ class TestRowsGraphQLQueryContracts: """Test supported GraphQL query formats""" # Test basic query basic_query = RowsQueryRequest( - user="test", collection="test", query='{ customers { id } }', + collection="test", query='{ customers { id } }', variables={}, operation_name="" ) assert "customers" in basic_query.query @@ -254,7 +246,7 @@ class TestRowsGraphQLQueryContracts: # Test query with variables parameterized_query = RowsQueryRequest( - user="test", collection="test", + collection="test", query='query GetCustomers($status: String, $limit: Int) { customers(status: $status, limit: $limit) { id name } }', variables={"status": "active", "limit": "10"}, operation_name="GetCustomers" @@ -266,7 +258,7 @@ class TestRowsGraphQLQueryContracts: # Test complex nested query nested_query = RowsQueryRequest( - user="test", collection="test", + collection="test", query=''' { customers(limit: 10) { @@ -297,7 +289,7 @@ class TestRowsGraphQLQueryContracts: # This test verifies the current contract, though ideally we'd support all JSON types variables_test = RowsQueryRequest( - user="test", collection="test", query='{ test }', + collection="test", query='{ test }', variables={ "string_var": "test_value", "numeric_var": "123", # Numbers as strings due to Map(String()) limitation @@ -318,22 +310,18 @@ class TestRowsGraphQLQueryContracts: def test_cassandra_context_fields_contract(self): """Test that request contains necessary fields for Cassandra operations""" - # Verify request has fields needed for Cassandra keyspace/table targeting + # Verify request has fields needed for partition key targeting request = RowsQueryRequest( - user="keyspace_name", # Maps to Cassandra keyspace collection="partition_collection", # Used in partition key query='{ objects { id } }', variables={}, operation_name="" ) - - # These fields are required for proper Cassandra operations - assert request.user # Required for keyspace identification - assert request.collection # Required for partition key - + + # Required for partition key + assert request.collection + # Verify field naming follows TrustGraph patterns (matching other query services) - # This matches TriplesQueryRequest, DocumentEmbeddingsRequest patterns - assert hasattr(request, 'user') # Same as TriplesQueryRequest.user - assert hasattr(request, 'collection') # Same as TriplesQueryRequest.collection + assert hasattr(request, 'collection') def test_graphql_extensions_contract(self): """Test GraphQL extensions field format and usage""" @@ -405,7 +393,7 @@ class TestRowsGraphQLQueryContracts: # Request to execute specific operation multi_op_request = RowsQueryRequest( - user="test", collection="test", + collection="test", query=multi_op_query, variables={}, operation_name="GetCustomers" @@ -418,7 +406,7 @@ class TestRowsGraphQLQueryContracts: # Test single operation (operation_name optional) single_op_request = RowsQueryRequest( - user="test", collection="test", + collection="test", query='{ customers { id } }', variables={}, operation_name="" ) diff --git a/tests/contract/test_schema_field_contracts.py b/tests/contract/test_schema_field_contracts.py index 4b7c3da5..5be745b8 100644 --- a/tests/contract/test_schema_field_contracts.py +++ b/tests/contract/test_schema_field_contracts.py @@ -41,10 +41,11 @@ class TestSchemaFieldContracts: def test_metadata_fields(self): # NOTE: there is no `metadata` field. A previous regression # constructed Metadata(metadata=...) and crashed at runtime. + # `user` was also dropped in the workspace refactor — workspace + # now flows via flow.workspace, not via message payload. assert _field_names(Metadata) == { "id", "root", - "user", "collection", } diff --git a/tests/contract/test_structured_data_contracts.py b/tests/contract/test_structured_data_contracts.py index d8f4c5cb..8208ef1b 100644 --- a/tests/contract/test_structured_data_contracts.py +++ b/tests/contract/test_structured_data_contracts.py @@ -93,7 +93,6 @@ class TestStructuredDataSchemaContracts: # Arrange metadata = Metadata( id="structured-data-001", - user="test_user", collection="test_collection", ) @@ -118,7 +117,6 @@ class TestStructuredDataSchemaContracts: # Arrange metadata = Metadata( id="extracted-obj-001", - user="test_user", collection="test_collection", ) @@ -143,7 +141,6 @@ class TestStructuredDataSchemaContracts: # Arrange metadata = Metadata( id="extracted-batch-001", - user="test_user", collection="test_collection", ) @@ -177,7 +174,6 @@ class TestStructuredDataSchemaContracts: # Arrange metadata = Metadata( id="extracted-empty-001", - user="test_user", collection="test_collection", ) @@ -277,7 +273,6 @@ class TestStructuredEmbeddingsContracts: # Arrange metadata = Metadata( id="struct-embed-001", - user="test_user", collection="test_collection", ) @@ -308,7 +303,7 @@ class TestStructuredDataSerializationContracts: def test_structured_data_submission_serialization(self): """Test StructuredDataSubmission serialization contract""" # Arrange - metadata = Metadata(id="test", user="user", collection="col") + metadata = Metadata(id="test", collection="col") submission_data = { "metadata": metadata, "format": "json", @@ -323,7 +318,7 @@ class TestStructuredDataSerializationContracts: def test_extracted_object_serialization(self): """Test ExtractedObject serialization contract""" # Arrange - metadata = Metadata(id="test", user="user", collection="col") + metadata = Metadata(id="test", collection="col") object_data = { "metadata": metadata, "schema_name": "test_schema", @@ -373,7 +368,7 @@ class TestStructuredDataSerializationContracts: def test_extracted_object_batch_serialization(self): """Test ExtractedObject batch serialization contract""" # Arrange - metadata = Metadata(id="test", user="user", collection="col") + metadata = Metadata(id="test", collection="col") batch_object_data = { "metadata": metadata, "schema_name": "test_schema", @@ -392,7 +387,7 @@ class TestStructuredDataSerializationContracts: def test_extracted_object_empty_batch_serialization(self): """Test ExtractedObject empty batch serialization contract""" # Arrange - metadata = Metadata(id="test", user="user", collection="col") + metadata = Metadata(id="test", collection="col") empty_batch_data = { "metadata": metadata, "schema_name": "test_schema", diff --git a/tests/integration/test_agent_structured_query_integration.py b/tests/integration/test_agent_structured_query_integration.py index 2442bf10..883e5261 100644 --- a/tests/integration/test_agent_structured_query_integration.py +++ b/tests/integration/test_agent_structured_query_integration.py @@ -58,7 +58,7 @@ class TestAgentStructuredQueryIntegration: async def test_agent_structured_query_basic_integration(self, agent_processor, structured_query_tool_config): """Test basic agent integration with structured query tool""" # Arrange - Load tool configuration - await agent_processor.on_tools_config(structured_query_tool_config, "v1") + await agent_processor.on_tools_config("default", structured_query_tool_config, "v1") # Create agent request request = AgentRequest( @@ -66,7 +66,6 @@ class TestAgentStructuredQueryIntegration: state="", group=None, history=[], - user="test_user" ) msg = MagicMock() @@ -119,6 +118,7 @@ Args: { # Mock flow parameter in agent_processor.on_request flow = MagicMock() flow.side_effect = flow_context + flow.workspace = "default" # Act await agent_processor.on_request(msg, consumer, flow) @@ -146,14 +146,13 @@ Args: { async def test_agent_structured_query_error_handling(self, agent_processor, structured_query_tool_config): """Test agent handling of structured query errors""" # Arrange - await agent_processor.on_tools_config(structured_query_tool_config, "v1") + await agent_processor.on_tools_config("default", structured_query_tool_config, "v1") request = AgentRequest( question="Find data from a table that doesn't exist using structured query.", state="", group=None, history=[], - user="test_user" ) msg = MagicMock() @@ -199,6 +198,7 @@ Args: { flow = MagicMock() flow.side_effect = flow_context + flow.workspace = "default" # Act await agent_processor.on_request(msg, consumer, flow) @@ -221,14 +221,13 @@ Args: { async def test_agent_multi_step_structured_query_reasoning(self, agent_processor, structured_query_tool_config): """Test agent using structured query in multi-step reasoning""" # Arrange - await agent_processor.on_tools_config(structured_query_tool_config, "v1") + await agent_processor.on_tools_config("default", structured_query_tool_config, "v1") request = AgentRequest( question="First find all customers from California, then tell me how many orders they have made.", state="", group=None, history=[], - user="test_user" ) msg = MagicMock() @@ -279,6 +278,7 @@ Args: { flow = MagicMock() flow.side_effect = flow_context + flow.workspace = "default" # Act await agent_processor.on_request(msg, consumer, flow) @@ -313,14 +313,13 @@ Args: { } } - await agent_processor.on_tools_config(tool_config_with_collection, "v1") + await agent_processor.on_tools_config("default", tool_config_with_collection, "v1") request = AgentRequest( question="Query the sales data for recent transactions.", state="", group=None, history=[], - user="test_user" ) msg = MagicMock() @@ -371,6 +370,7 @@ Args: { flow = MagicMock() flow.side_effect = flow_context + flow.workspace = "default" # Act await agent_processor.on_request(msg, consumer, flow) @@ -394,10 +394,10 @@ Args: { async def test_agent_structured_query_tool_argument_validation(self, agent_processor, structured_query_tool_config): """Test that structured query tool arguments are properly validated""" # Arrange - await agent_processor.on_tools_config(structured_query_tool_config, "v1") + await agent_processor.on_tools_config("default", structured_query_tool_config, "v1") # Check that the tool was registered with correct arguments - tools = agent_processor.agent.tools + tools = agent_processor.agents["default"].tools assert "structured-query" in tools structured_tool = tools["structured-query"] @@ -414,14 +414,13 @@ Args: { async def test_agent_structured_query_json_formatting(self, agent_processor, structured_query_tool_config): """Test that structured query results are properly formatted for agent consumption""" # Arrange - await agent_processor.on_tools_config(structured_query_tool_config, "v1") + await agent_processor.on_tools_config("default", structured_query_tool_config, "v1") request = AgentRequest( question="Get customer information and format it nicely.", state="", group=None, history=[], - user="test_user" ) msg = MagicMock() @@ -482,6 +481,7 @@ Args: { flow = MagicMock() flow.side_effect = flow_context + flow.workspace = "default" # Act await agent_processor.on_request(msg, consumer, flow) diff --git a/tests/integration/test_cassandra_config_end_to_end.py b/tests/integration/test_cassandra_config_end_to_end.py index 6c83fb05..1e4276fe 100644 --- a/tests/integration/test_cassandra_config_end_to_end.py +++ b/tests/integration/test_cassandra_config_end_to_end.py @@ -40,14 +40,13 @@ class TestEndToEndConfigurationFlow: # Create a mock message to trigger TrustGraph creation mock_message = MagicMock() - mock_message.metadata.user = 'test_user' mock_message.metadata.collection = 'test_collection' mock_message.triples = [] # Mock collection_exists to return True with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True): # This should create TrustGraph with environment config - await processor.store_triples(mock_message) + await processor.store_triples('test_user', mock_message) # Verify Cluster was created with correct hosts mock_cluster.assert_called_once() @@ -144,13 +143,12 @@ class TestConfigurationPriorityEndToEnd: # Trigger TrustGraph creation mock_message = MagicMock() - mock_message.metadata.user = 'test_user' mock_message.metadata.collection = 'test_collection' mock_message.triples = [] # Mock collection_exists to return True with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples('test_user', mock_message) # Should use CLI parameters, not environment mock_cluster.assert_called_once() @@ -201,7 +199,6 @@ class TestConfigurationPriorityEndToEnd: # Mock query to trigger TrustGraph creation mock_query = MagicMock() - mock_query.user = 'default_user' mock_query.collection = 'default_collection' mock_query.s = None mock_query.p = None @@ -213,7 +210,7 @@ class TestConfigurationPriorityEndToEnd: mock_tg_instance.get_all.return_value = [] processor.tg = mock_tg_instance - await processor.query_triples(mock_query) + await processor.query_triples('default_user', mock_query) # Should use defaults mock_cluster.assert_called_once() @@ -244,13 +241,12 @@ class TestNoBackwardCompatibilityEndToEnd: # Trigger TrustGraph creation mock_message = MagicMock() - mock_message.metadata.user = 'legacy_user' mock_message.metadata.collection = 'legacy_collection' mock_message.triples = [] # Mock collection_exists to return True with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples('legacy_user', mock_message) # Should use defaults since old parameters are not recognized mock_cluster.assert_called_once() @@ -302,13 +298,12 @@ class TestNoBackwardCompatibilityEndToEnd: # Trigger TrustGraph creation mock_message = MagicMock() - mock_message.metadata.user = 'precedence_user' mock_message.metadata.collection = 'precedence_collection' mock_message.triples = [] # Mock collection_exists to return True with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples('precedence_user', mock_message) # Should use new parameters, not old ones mock_cluster.assert_called_once() @@ -354,13 +349,12 @@ class TestMultipleHostsHandling: # Trigger TrustGraph creation mock_message = MagicMock() - mock_message.metadata.user = 'single_user' mock_message.metadata.collection = 'single_collection' mock_message.triples = [] # Mock collection_exists to return True with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples('single_user', mock_message) # Single host should be converted to list mock_cluster.assert_called_once() diff --git a/tests/integration/test_cassandra_integration.py b/tests/integration/test_cassandra_integration.py index 2f5a4195..42273d58 100644 --- a/tests/integration/test_cassandra_integration.py +++ b/tests/integration/test_cassandra_integration.py @@ -115,7 +115,7 @@ class TestCassandraIntegration: # Create test message storage_message = Triples( - metadata=Metadata(user="testuser", collection="testcol"), + metadata=Metadata(collection="testcol"), triples=[ Triple( s=Term(type=IRI, iri="http://example.org/person1"), @@ -178,7 +178,7 @@ class TestCassandraIntegration: # Store test data for querying query_test_message = Triples( - metadata=Metadata(user="testuser", collection="testcol"), + metadata=Metadata(collection="testcol"), triples=[ Triple( s=Term(type=IRI, iri="http://example.org/alice"), @@ -212,7 +212,6 @@ class TestCassandraIntegration: p=None, # None for wildcard o=None, # None for wildcard limit=10, - user="testuser", collection="testcol" ) s_results = await query_processor.query_triples(s_query) @@ -232,7 +231,6 @@ class TestCassandraIntegration: p=Term(type=IRI, iri="http://example.org/knows"), o=None, # None for wildcard limit=10, - user="testuser", collection="testcol" ) p_results = await query_processor.query_triples(p_query) @@ -259,7 +257,7 @@ class TestCassandraIntegration: # Create multiple coroutines for concurrent storage async def store_person_data(person_id, name, age, department): message = Triples( - metadata=Metadata(user="concurrent_test", collection="people"), + metadata=Metadata(collection="people"), triples=[ Triple( s=Term(type=IRI, iri=f"http://example.org/{person_id}"), @@ -329,7 +327,7 @@ class TestCassandraIntegration: # Create a knowledge graph about a company company_graph = Triples( - metadata=Metadata(user="integration_test", collection="company"), + metadata=Metadata(collection="company"), triples=[ # People and their types Triple( diff --git a/tests/integration/test_document_rag_integration.py b/tests/integration/test_document_rag_integration.py index 8c165385..78a85acf 100644 --- a/tests/integration/test_document_rag_integration.py +++ b/tests/integration/test_document_rag_integration.py @@ -99,7 +99,6 @@ class TestDocumentRagIntegration: # Act result = await document_rag.query( query=query, - user=user, collection=collection, doc_limit=doc_limit ) @@ -110,7 +109,6 @@ class TestDocumentRagIntegration: mock_doc_embeddings_client.query.assert_called_once_with( vector=[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]], limit=doc_limit, - user=user, collection=collection ) @@ -278,14 +276,12 @@ class TestDocumentRagIntegration: # Act await document_rag.query( f"query from {user} in {collection}", - user=user, collection=collection ) # Assert mock_doc_embeddings_client.query.assert_called_once() call_args = mock_doc_embeddings_client.query.call_args - assert call_args.kwargs['user'] == user assert call_args.kwargs['collection'] == collection @pytest.mark.asyncio @@ -353,6 +349,5 @@ class TestDocumentRagIntegration: # Assert mock_doc_embeddings_client.query.assert_called_once() call_args = mock_doc_embeddings_client.query.call_args - assert call_args.kwargs['user'] == "trustgraph" assert call_args.kwargs['collection'] == "default" assert call_args.kwargs['limit'] == 20 diff --git a/tests/integration/test_document_rag_streaming_integration.py b/tests/integration/test_document_rag_streaming_integration.py index e2c032ad..49ddc3a2 100644 --- a/tests/integration/test_document_rag_streaming_integration.py +++ b/tests/integration/test_document_rag_streaming_integration.py @@ -107,7 +107,6 @@ class TestDocumentRagStreaming: # Act result = await document_rag_streaming.query( query=query, - user="test_user", collection="test_collection", doc_limit=10, streaming=True, @@ -141,7 +140,6 @@ class TestDocumentRagStreaming: # Act - Non-streaming non_streaming_result = await document_rag_streaming.query( query=query, - user=user, collection=collection, doc_limit=doc_limit, streaming=False @@ -155,7 +153,6 @@ class TestDocumentRagStreaming: streaming_result = await document_rag_streaming.query( query=query, - user=user, collection=collection, doc_limit=doc_limit, streaming=True, @@ -178,7 +175,6 @@ class TestDocumentRagStreaming: # Act result = await document_rag_streaming.query( query="test query", - user="test_user", collection="test_collection", doc_limit=5, streaming=True, @@ -200,7 +196,6 @@ class TestDocumentRagStreaming: # Arrange & Act result = await document_rag_streaming.query( query="test query", - user="test_user", collection="test_collection", doc_limit=5, streaming=True, @@ -223,7 +218,6 @@ class TestDocumentRagStreaming: # Act result = await document_rag_streaming.query( query="unknown topic", - user="test_user", collection="test_collection", doc_limit=10, streaming=True, @@ -247,7 +241,6 @@ class TestDocumentRagStreaming: with pytest.raises(Exception) as exc_info: await document_rag_streaming.query( query="test query", - user="test_user", collection="test_collection", doc_limit=5, streaming=True, @@ -272,7 +265,6 @@ class TestDocumentRagStreaming: # Act result = await document_rag_streaming.query( query="test query", - user="test_user", collection="test_collection", doc_limit=limit, streaming=True, @@ -300,7 +292,6 @@ class TestDocumentRagStreaming: # Act await document_rag_streaming.query( query="test query", - user=user, collection=collection, doc_limit=10, streaming=True, @@ -309,5 +300,4 @@ class TestDocumentRagStreaming: # Assert - Verify user/collection were passed to document embeddings client call_args = mock_doc_embeddings_client.query.call_args - assert call_args.kwargs['user'] == user assert call_args.kwargs['collection'] == collection diff --git a/tests/integration/test_graph_rag_integration.py b/tests/integration/test_graph_rag_integration.py index 9c3cdf45..696df7ec 100644 --- a/tests/integration/test_graph_rag_integration.py +++ b/tests/integration/test_graph_rag_integration.py @@ -146,7 +146,6 @@ class TestGraphRagIntegration: # Act response = await graph_rag.query( query=query, - user=user, collection=collection, entity_limit=entity_limit, triple_limit=triple_limit, @@ -163,7 +162,6 @@ class TestGraphRagIntegration: call_args = mock_graph_embeddings_client.query.call_args assert call_args.kwargs['vector'] == [[0.1, 0.2, 0.3, 0.4, 0.5]] assert call_args.kwargs['limit'] == entity_limit - assert call_args.kwargs['user'] == user assert call_args.kwargs['collection'] == collection # 3. Should query triples to build knowledge subgraph @@ -204,7 +202,6 @@ class TestGraphRagIntegration: # Act await graph_rag.query( query=query, - user="test_user", collection="test_collection", entity_limit=config["entity_limit"], triple_limit=config["triple_limit"] @@ -224,7 +221,6 @@ class TestGraphRagIntegration: with pytest.raises(Exception) as exc_info: await graph_rag.query( query="test query", - user="test_user", collection="test_collection" ) @@ -247,7 +243,6 @@ class TestGraphRagIntegration: # Act response = await graph_rag.query( query="unknown topic", - user="test_user", collection="test_collection", explain_callback=collect_provenance ) @@ -267,7 +262,6 @@ class TestGraphRagIntegration: # First query await graph_rag.query( query=query, - user="test_user", collection="test_collection" ) @@ -277,7 +271,6 @@ class TestGraphRagIntegration: # Second identical query await graph_rag.query( query=query, - user="test_user", collection="test_collection" ) @@ -289,26 +282,27 @@ class TestGraphRagIntegration: assert second_call_count >= 0 # Should complete without errors @pytest.mark.asyncio - async def test_graph_rag_multi_user_isolation(self, graph_rag, mock_graph_embeddings_client): - """Test that different users/collections are properly isolated""" + async def test_graph_rag_multi_collection_isolation(self, graph_rag, mock_graph_embeddings_client): + """Test that different collections propagate through to the embeddings query. + + Workspace isolation is enforced by flow.workspace at the service + boundary — not by parameters on GraphRag.query — so this test + verifies collection routing only. + """ # Arrange query = "test query" - user1, collection1 = "user1", "collection1" - user2, collection2 = "user2", "collection2" + collection1 = "collection1" + collection2 = "collection2" # Act - await graph_rag.query(query=query, user=user1, collection=collection1) - await graph_rag.query(query=query, user=user2, collection=collection2) + await graph_rag.query(query=query, collection=collection1) + await graph_rag.query(query=query, collection=collection2) - # Assert - Both users should have separate queries + # Assert - Each call propagated its collection assert mock_graph_embeddings_client.query.call_count == 2 - # Verify first call first_call = mock_graph_embeddings_client.query.call_args_list[0] - assert first_call.kwargs['user'] == user1 assert first_call.kwargs['collection'] == collection1 - # Verify second call second_call = mock_graph_embeddings_client.query.call_args_list[1] - assert second_call.kwargs['user'] == user2 assert second_call.kwargs['collection'] == collection2 diff --git a/tests/integration/test_graph_rag_streaming_integration.py b/tests/integration/test_graph_rag_streaming_integration.py index 95c494bb..48e26618 100644 --- a/tests/integration/test_graph_rag_streaming_integration.py +++ b/tests/integration/test_graph_rag_streaming_integration.py @@ -116,7 +116,6 @@ class TestGraphRagStreaming: # Act - query() returns response, provenance via callback response = await graph_rag_streaming.query( query=query, - user="test_user", collection="test_collection", streaming=True, chunk_callback=collector.collect, @@ -154,7 +153,6 @@ class TestGraphRagStreaming: # Act - Non-streaming non_streaming_response = await graph_rag_streaming.query( query=query, - user=user, collection=collection, streaming=False ) @@ -167,7 +165,6 @@ class TestGraphRagStreaming: streaming_response = await graph_rag_streaming.query( query=query, - user=user, collection=collection, streaming=True, chunk_callback=collect @@ -189,7 +186,6 @@ class TestGraphRagStreaming: # Act response = await graph_rag_streaming.query( query="test query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=callback @@ -209,7 +205,6 @@ class TestGraphRagStreaming: # Arrange & Act response = await graph_rag_streaming.query( query="test query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=None # No callback provided @@ -231,7 +226,6 @@ class TestGraphRagStreaming: # Act response = await graph_rag_streaming.query( query="unknown topic", - user="test_user", collection="test_collection", streaming=True, chunk_callback=callback @@ -253,7 +247,6 @@ class TestGraphRagStreaming: with pytest.raises(Exception) as exc_info: await graph_rag_streaming.query( query="test query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=callback @@ -273,7 +266,6 @@ class TestGraphRagStreaming: # Act await graph_rag_streaming.query( query="test query", - user="test_user", collection="test_collection", entity_limit=entity_limit, triple_limit=triple_limit, diff --git a/tests/integration/test_import_export_graceful_shutdown.py b/tests/integration/test_import_export_graceful_shutdown.py index a3771b80..2fcd2683 100644 --- a/tests/integration/test_import_export_graceful_shutdown.py +++ b/tests/integration/test_import_export_graceful_shutdown.py @@ -171,7 +171,6 @@ async def test_export_no_message_loss_integration(mock_backend): triples_obj = Triples( metadata=Metadata( id=f"export-msg-{i}", - user=msg_data["metadata"]["user"], collection=msg_data["metadata"]["collection"], ), triples=to_subgraph(msg_data["triples"]), diff --git a/tests/integration/test_kg_extract_store_integration.py b/tests/integration/test_kg_extract_store_integration.py index 84c0905d..48878a00 100644 --- a/tests/integration/test_kg_extract_store_integration.py +++ b/tests/integration/test_kg_extract_store_integration.py @@ -97,7 +97,6 @@ class TestKnowledgeGraphPipelineIntegration: return Chunk( metadata=Metadata( id="doc-123", - user="test_user", collection="test_collection", ), chunk=b"Machine Learning is a subset of Artificial Intelligence. Neural Networks are used in Machine Learning to process complex patterns." @@ -247,7 +246,6 @@ class TestKnowledgeGraphPipelineIntegration: # Arrange metadata = Metadata( id="test-doc", - user="test_user", collection="test_collection", ) @@ -305,7 +303,6 @@ class TestKnowledgeGraphPipelineIntegration: # Arrange metadata = Metadata( id="test-doc", - user="test_user", collection="test_collection", ) @@ -375,7 +372,6 @@ class TestKnowledgeGraphPipelineIntegration: sample_triples = Triples( metadata=Metadata( id="test-doc", - user="test_user", collection="test_collection", ), triples=[ @@ -390,11 +386,14 @@ class TestKnowledgeGraphPipelineIntegration: mock_msg = MagicMock() mock_msg.value.return_value = sample_triples + mock_flow = MagicMock() + mock_flow.workspace = "test_workspace" + # Act - await processor.on_triples(mock_msg, None, None) + await processor.on_triples(mock_msg, None, mock_flow) # Assert - mock_cassandra_store.add_triples.assert_called_once_with(sample_triples) + mock_cassandra_store.add_triples.assert_called_once_with("test_workspace", sample_triples) @pytest.mark.asyncio async def test_knowledge_store_graph_embeddings_storage(self, mock_cassandra_store): @@ -407,7 +406,6 @@ class TestKnowledgeGraphPipelineIntegration: sample_embeddings = GraphEmbeddings( metadata=Metadata( id="test-doc", - user="test_user", collection="test_collection", ), entities=[ @@ -421,11 +419,14 @@ class TestKnowledgeGraphPipelineIntegration: mock_msg = MagicMock() mock_msg.value.return_value = sample_embeddings + mock_flow = MagicMock() + mock_flow.workspace = "test_workspace" + # Act - await processor.on_graph_embeddings(mock_msg, None, None) + await processor.on_graph_embeddings(mock_msg, None, mock_flow) # Assert - mock_cassandra_store.add_graph_embeddings.assert_called_once_with(sample_embeddings) + mock_cassandra_store.add_graph_embeddings.assert_called_once_with("test_workspace", sample_embeddings) @pytest.mark.asyncio async def test_end_to_end_pipeline_coordination(self, definitions_processor, relationships_processor, @@ -553,7 +554,7 @@ class TestKnowledgeGraphPipelineIntegration: ) sample_chunk = Chunk( - metadata=Metadata(id="test", user="user", collection="collection"), + metadata=Metadata(id="test", collection="collection"), chunk=b"Test chunk" ) @@ -580,7 +581,7 @@ class TestKnowledgeGraphPipelineIntegration: # Arrange large_chunk_batch = [ Chunk( - metadata=Metadata(id=f"doc-{i}", user="user", collection="collection"), + metadata=Metadata(id=f"doc-{i}", collection="collection"), chunk=f"Document {i} contains machine learning and AI content.".encode("utf-8") ) for i in range(100) # Large batch @@ -617,7 +618,6 @@ class TestKnowledgeGraphPipelineIntegration: # Arrange original_metadata = Metadata( id="test-doc-123", - user="test_user", collection="test_collection", ) @@ -646,9 +646,7 @@ class TestKnowledgeGraphPipelineIntegration: entity_contexts_call = entity_contexts_producer.send.call_args[0][0] assert triples_call.metadata.id == "test-doc-123" - assert triples_call.metadata.user == "test_user" assert triples_call.metadata.collection == "test_collection" assert entity_contexts_call.metadata.id == "test-doc-123" - assert entity_contexts_call.metadata.user == "test_user" assert entity_contexts_call.metadata.collection == "test_collection" \ No newline at end of file diff --git a/tests/integration/test_nlp_query_integration.py b/tests/integration/test_nlp_query_integration.py index 16c4543e..08bf1e77 100644 --- a/tests/integration/test_nlp_query_integration.py +++ b/tests/integration/test_nlp_query_integration.py @@ -72,7 +72,7 @@ class TestNLPQueryServiceIntegration: ) # Set up schemas - proc.schemas = sample_schemas + proc.schemas = {"default": dict(sample_schemas)} # Mock the client method proc.client = MagicMock() @@ -94,6 +94,7 @@ class TestNLPQueryServiceIntegration: consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response @@ -173,6 +174,7 @@ class TestNLPQueryServiceIntegration: consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response @@ -229,7 +231,7 @@ class TestNLPQueryServiceIntegration: } # Act - Update configuration - await integration_processor.on_schema_config(new_schema_config, "v2") + await integration_processor.on_schema_config("default", new_schema_config, "v2") # Arrange - Test query using new schema request = QuestionToStructuredQueryRequest( @@ -243,6 +245,7 @@ class TestNLPQueryServiceIntegration: consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response @@ -272,7 +275,7 @@ class TestNLPQueryServiceIntegration: await integration_processor.on_message(msg, consumer, flow) # Assert - assert "inventory" in integration_processor.schemas + assert "inventory" in integration_processor.schemas["default"] response_call = flow_response.send.call_args response = response_call[0][0] assert response.detected_schemas == ["inventory"] @@ -293,6 +296,7 @@ class TestNLPQueryServiceIntegration: consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response @@ -334,7 +338,7 @@ class TestNLPQueryServiceIntegration: graphql_generation_template="custom-graphql-generator" ) - custom_processor.schemas = sample_schemas + custom_processor.schemas = {"default": dict(sample_schemas)} custom_processor.client = MagicMock() request = QuestionToStructuredQueryRequest( @@ -348,6 +352,7 @@ class TestNLPQueryServiceIntegration: consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response @@ -394,7 +399,7 @@ class TestNLPQueryServiceIntegration: ] + [SchemaField(name=f"field_{j}", type="string") for j in range(5)] ) - integration_processor.schemas.update(large_schema_set) + integration_processor.schemas["default"].update(large_schema_set) request = QuestionToStructuredQueryRequest( question="Show me data from table_05 and table_12", @@ -407,6 +412,7 @@ class TestNLPQueryServiceIntegration: consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response @@ -462,6 +468,7 @@ class TestNLPQueryServiceIntegration: msg.properties.return_value = {"id": f"concurrent-test-{i}"} flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response @@ -532,6 +539,7 @@ class TestNLPQueryServiceIntegration: consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response diff --git a/tests/integration/test_object_extraction_integration.py b/tests/integration/test_object_extraction_integration.py index 22ba9a3f..8d58c764 100644 --- a/tests/integration/test_object_extraction_integration.py +++ b/tests/integration/test_object_extraction_integration.py @@ -185,6 +185,7 @@ class TestObjectExtractionServiceIntegration: return AsyncMock() context.side_effect = context_router + context.workspace = "default" return context @pytest.mark.asyncio @@ -197,20 +198,21 @@ class TestObjectExtractionServiceIntegration: processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) # Act - await processor.on_schema_config(integration_config, version=1) + await processor.on_schema_config("default", integration_config, version=1) # Assert - assert len(processor.schemas) == 2 - assert "customer_records" in processor.schemas - assert "product_catalog" in processor.schemas - + ws_schemas = processor.schemas["default"] + assert len(ws_schemas) == 2 + assert "customer_records" in ws_schemas + assert "product_catalog" in ws_schemas + # Verify customer schema - customer_schema = processor.schemas["customer_records"] + customer_schema = ws_schemas["customer_records"] assert customer_schema.name == "customer_records" assert len(customer_schema.fields) == 4 - + # Verify product schema - product_schema = processor.schemas["product_catalog"] + product_schema = ws_schemas["product_catalog"] assert product_schema.name == "product_catalog" assert len(product_schema.fields) == 4 @@ -237,12 +239,11 @@ class TestObjectExtractionServiceIntegration: processor.convert_values_to_strings = convert_values_to_strings # Load configuration - await processor.on_schema_config(integration_config, version=1) + await processor.on_schema_config("default", integration_config, version=1) # Create realistic customer data chunk metadata = Metadata( id="customer-doc-001", - user="integration_test", collection="test_documents", ) @@ -304,12 +305,11 @@ class TestObjectExtractionServiceIntegration: processor.convert_values_to_strings = convert_values_to_strings # Load configuration - await processor.on_schema_config(integration_config, version=1) + await processor.on_schema_config("default", integration_config, version=1) # Create realistic product data chunk metadata = Metadata( id="product-doc-001", - user="integration_test", collection="test_documents", ) @@ -368,7 +368,7 @@ class TestObjectExtractionServiceIntegration: processor.convert_values_to_strings = convert_values_to_strings # Load configuration - await processor.on_schema_config(integration_config, version=1) + await processor.on_schema_config("default", integration_config, version=1) # Create multiple test chunks chunks_data = [ @@ -382,7 +382,6 @@ class TestObjectExtractionServiceIntegration: for chunk_id, text in chunks_data: metadata = Metadata( id=chunk_id, - user="concurrent_test", collection="test_collection", ) chunk = Chunk(metadata=metadata, chunk=text.encode('utf-8')) @@ -431,19 +430,21 @@ class TestObjectExtractionServiceIntegration: "customer_records": integration_config["schema"]["customer_records"] } } - await processor.on_schema_config(initial_config, version=1) - - assert len(processor.schemas) == 1 - assert "customer_records" in processor.schemas - assert "product_catalog" not in processor.schemas - + await processor.on_schema_config("default", initial_config, version=1) + + ws_schemas = processor.schemas["default"] + assert len(ws_schemas) == 1 + assert "customer_records" in ws_schemas + assert "product_catalog" not in ws_schemas + # Act - Reload with full configuration - await processor.on_schema_config(integration_config, version=2) - + await processor.on_schema_config("default", integration_config, version=2) + # Assert - assert len(processor.schemas) == 2 - assert "customer_records" in processor.schemas - assert "product_catalog" in processor.schemas + ws_schemas = processor.schemas["default"] + assert len(ws_schemas) == 2 + assert "customer_records" in ws_schemas + assert "product_catalog" in ws_schemas @pytest.mark.asyncio async def test_error_resilience_integration(self, integration_config): @@ -474,13 +475,14 @@ class TestObjectExtractionServiceIntegration: return AsyncMock() failing_flow.side_effect = failing_context_router + failing_flow.workspace = "default" processor.flow = failing_flow # Load configuration - await processor.on_schema_config(integration_config, version=1) + await processor.on_schema_config("default", integration_config, version=1) # Create test chunk - metadata = Metadata(id="error-test", user="test", collection="test") + metadata = Metadata(id="error-test", collection="test") chunk = Chunk(metadata=metadata, chunk=b"Some text that will fail to process") mock_msg = MagicMock() @@ -510,12 +512,11 @@ class TestObjectExtractionServiceIntegration: processor.convert_values_to_strings = convert_values_to_strings # Load configuration - await processor.on_schema_config(integration_config, version=1) + await processor.on_schema_config("default", integration_config, version=1) # Create chunk with rich metadata original_metadata = Metadata( id="metadata-test-chunk", - user="test_user", collection="test_collection", ) @@ -544,6 +545,5 @@ class TestObjectExtractionServiceIntegration: assert extracted_obj is not None # Verify metadata propagation - assert extracted_obj.metadata.user == "test_user" assert extracted_obj.metadata.collection == "test_collection" assert "metadata-test-chunk" in extracted_obj.metadata.id # Should include source reference \ No newline at end of file diff --git a/tests/integration/test_prompt_streaming_integration.py b/tests/integration/test_prompt_streaming_integration.py index a1414e2d..84a3cdec 100644 --- a/tests/integration/test_prompt_streaming_integration.py +++ b/tests/integration/test_prompt_streaming_integration.py @@ -87,6 +87,7 @@ class TestPromptStreaming: return AsyncMock() context.side_effect = context_router + context.workspace = "default" return context @pytest.fixture @@ -109,7 +110,7 @@ class TestPromptStreaming: def prompt_processor_streaming(self, mock_prompt_manager): """Create Prompt processor with streaming support""" processor = MagicMock() - processor.manager = mock_prompt_manager + processor.managers = {"default": mock_prompt_manager} processor.config_key = "prompt" # Bind the actual on_request method @@ -248,6 +249,7 @@ class TestPromptStreaming: return AsyncMock() context.side_effect = context_router + context.workspace = "default" request = PromptRequest( id="test_prompt", @@ -341,6 +343,7 @@ class TestPromptStreaming: return AsyncMock() context.side_effect = context_router + context.workspace = "default" request = PromptRequest( id="test_prompt", diff --git a/tests/integration/test_rag_streaming_protocol.py b/tests/integration/test_rag_streaming_protocol.py index 83a90412..279c81ef 100644 --- a/tests/integration/test_rag_streaming_protocol.py +++ b/tests/integration/test_rag_streaming_protocol.py @@ -84,7 +84,6 @@ class TestGraphRagStreamingProtocol: # Act await graph_rag.query( query="test query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=callback @@ -108,7 +107,6 @@ class TestGraphRagStreamingProtocol: # Act await graph_rag.query( query="test query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=collect @@ -137,7 +135,6 @@ class TestGraphRagStreamingProtocol: # Act await graph_rag.query( query="test query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=collect @@ -162,7 +159,6 @@ class TestGraphRagStreamingProtocol: # Act await graph_rag.query( query="test query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=collect @@ -188,7 +184,6 @@ class TestGraphRagStreamingProtocol: # Act await graph_rag.query( query="test query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=collect @@ -267,7 +262,6 @@ class TestDocumentRagStreamingProtocol: # Act await document_rag.query( query="test query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=callback @@ -290,7 +284,6 @@ class TestDocumentRagStreamingProtocol: # Act await document_rag.query( query="test query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=collect @@ -314,7 +307,6 @@ class TestDocumentRagStreamingProtocol: # Act await document_rag.query( query="test query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=collect diff --git a/tests/integration/test_rows_cassandra_integration.py b/tests/integration/test_rows_cassandra_integration.py index a2b8ae08..1358d420 100644 --- a/tests/integration/test_rows_cassandra_integration.py +++ b/tests/integration/test_rows_cassandra_integration.py @@ -14,6 +14,17 @@ from trustgraph.storage.rows.cassandra.write import Processor from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field + + +class _MockFlowDefault: + """Mock Flow with default workspace for testing.""" + workspace = "default" + name = "default" + id = "test-processor" + + +mock_flow_default = _MockFlowDefault() + @pytest.mark.integration class TestRowsCassandraIntegration: """Integration tests for Cassandra row storage with unified table""" @@ -125,14 +136,13 @@ class TestRowsCassandraIntegration: } } - await processor.on_schema_config(config, version=1) - assert "customer_records" in processor.schemas + await processor.on_schema_config("default", config, version=1) + assert "customer_records" in processor.schemas["default"] # Step 2: Process an ExtractedObject test_obj = ExtractedObject( metadata=Metadata( id="doc-001", - user="test_user", collection="import_2024", ), schema_name="customer_records", @@ -149,7 +159,7 @@ class TestRowsCassandraIntegration: msg = MagicMock() msg.value.return_value = test_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Verify Cassandra interactions assert mock_cluster.connect.called @@ -158,7 +168,7 @@ class TestRowsCassandraIntegration: keyspace_calls = [call for call in mock_session.execute.call_args_list if "CREATE KEYSPACE" in str(call)] assert len(keyspace_calls) == 1 - assert "test_user" in str(keyspace_calls[0]) + assert "default" in str(keyspace_calls[0]) # Verify unified table creation (rows table, not per-schema table) table_calls = [call for call in mock_session.execute.call_args_list @@ -209,12 +219,12 @@ class TestRowsCassandraIntegration: } } - await processor.on_schema_config(config, version=1) - assert len(processor.schemas) == 2 + await processor.on_schema_config("default", config, version=1) + assert len(processor.schemas["default"]) == 2 # Process objects for different schemas product_obj = ExtractedObject( - metadata=Metadata(id="p1", user="shop", collection="catalog"), + metadata=Metadata(id="p1", collection="catalog"), schema_name="products", values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}], confidence=0.9, @@ -222,7 +232,7 @@ class TestRowsCassandraIntegration: ) order_obj = ExtractedObject( - metadata=Metadata(id="o1", user="shop", collection="sales"), + metadata=Metadata(id="o1", collection="sales"), schema_name="orders", values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}], confidence=0.85, @@ -233,7 +243,7 @@ class TestRowsCassandraIntegration: for obj in [product_obj, order_obj]: msg = MagicMock() msg.value.return_value = obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # All data goes into the same unified rows table table_calls = [call for call in mock_session.execute.call_args_list @@ -256,18 +266,20 @@ class TestRowsCassandraIntegration: with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): # Schema with multiple indexed fields - processor.schemas["indexed_data"] = RowSchema( - name="indexed_data", - fields=[ - Field(name="id", type="string", size=50, primary=True), - Field(name="category", type="string", size=50, indexed=True), - Field(name="status", type="string", size=50, indexed=True), - Field(name="description", type="string", size=200) # Not indexed - ] - ) + processor.schemas["default"] = { + "indexed_data": RowSchema( + name="indexed_data", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="category", type="string", size=50, indexed=True), + Field(name="status", type="string", size=50, indexed=True), + Field(name="description", type="string", size=200) # Not indexed + ] + ) + } test_obj = ExtractedObject( - metadata=Metadata(id="t1", user="test", collection="test"), + metadata=Metadata(id="t1", collection="test"), schema_name="indexed_data", values=[{ "id": "123", @@ -282,7 +294,7 @@ class TestRowsCassandraIntegration: msg = MagicMock() msg.value.return_value = test_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Should have 3 data inserts (one per indexed field: id, category, status) rows_insert_calls = [call for call in mock_session.execute.call_args_list @@ -342,13 +354,12 @@ class TestRowsCassandraIntegration: } } - await processor.on_schema_config(config, version=1) + await processor.on_schema_config("default", config, version=1) # Process batch object with multiple values batch_obj = ExtractedObject( metadata=Metadata( id="batch-001", - user="test_user", collection="batch_import", ), schema_name="batch_customers", @@ -376,7 +387,7 @@ class TestRowsCassandraIntegration: msg = MagicMock() msg.value.return_value = batch_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Verify unified table creation table_calls = [call for call in mock_session.execute.call_args_list @@ -396,14 +407,16 @@ class TestRowsCassandraIntegration: processor, mock_cluster, mock_session = processor_with_mocks with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): - processor.schemas["empty_test"] = RowSchema( - name="empty_test", - fields=[Field(name="id", type="string", size=50, primary=True)] - ) + processor.schemas["default"] = { + "empty_test": RowSchema( + name="empty_test", + fields=[Field(name="id", type="string", size=50, primary=True)] + ) + } # Process empty batch object empty_obj = ExtractedObject( - metadata=Metadata(id="empty-1", user="test", collection="empty"), + metadata=Metadata(id="empty-1", collection="empty"), schema_name="empty_test", values=[], # Empty batch confidence=1.0, @@ -413,7 +426,7 @@ class TestRowsCassandraIntegration: msg = MagicMock() msg.value.return_value = empty_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Should not create any data insert statements for empty batch # (partition registration may still happen) @@ -428,17 +441,19 @@ class TestRowsCassandraIntegration: processor, mock_cluster, mock_session = processor_with_mocks with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): - processor.schemas["map_test"] = RowSchema( - name="map_test", - fields=[ - Field(name="id", type="string", size=50, primary=True), - Field(name="name", type="string", size=100), - Field(name="count", type="integer", size=0) - ] - ) + processor.schemas["default"] = { + "map_test": RowSchema( + name="map_test", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="name", type="string", size=100), + Field(name="count", type="integer", size=0) + ] + ) + } test_obj = ExtractedObject( - metadata=Metadata(id="t1", user="test", collection="test"), + metadata=Metadata(id="t1", collection="test"), schema_name="map_test", values=[{"id": "123", "name": "Test Item", "count": "42"}], confidence=0.9, @@ -448,7 +463,7 @@ class TestRowsCassandraIntegration: msg = MagicMock() msg.value.return_value = test_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Verify insert uses map for data rows_insert_calls = [call for call in mock_session.execute.call_args_list @@ -473,16 +488,18 @@ class TestRowsCassandraIntegration: processor, mock_cluster, mock_session = processor_with_mocks with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): - processor.schemas["partition_test"] = RowSchema( - name="partition_test", - fields=[ - Field(name="id", type="string", size=50, primary=True), - Field(name="category", type="string", size=50, indexed=True) - ] - ) + processor.schemas["default"] = { + "partition_test": RowSchema( + name="partition_test", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="category", type="string", size=50, indexed=True) + ] + ) + } test_obj = ExtractedObject( - metadata=Metadata(id="t1", user="test", collection="my_collection"), + metadata=Metadata(id="t1", collection="my_collection"), schema_name="partition_test", values=[{"id": "123", "category": "test"}], confidence=0.9, @@ -492,7 +509,7 @@ class TestRowsCassandraIntegration: msg = MagicMock() msg.value.return_value = test_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Verify partition registration partition_inserts = [call for call in mock_session.execute.call_args_list diff --git a/tests/integration/test_rows_graphql_query_integration.py b/tests/integration/test_rows_graphql_query_integration.py index a717901b..29b4464d 100644 --- a/tests/integration/test_rows_graphql_query_integration.py +++ b/tests/integration/test_rows_graphql_query_integration.py @@ -154,7 +154,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_schema_configuration_and_generation(self, processor, sample_schema_config): """Test schema configuration loading and GraphQL schema generation""" # Load schema configuration - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) # Verify schemas were loaded assert len(processor.schemas) == 2 @@ -181,7 +181,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_cassandra_connection_and_table_creation(self, processor, sample_schema_config): """Test Cassandra connection and dynamic table creation""" # Load schema configuration - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) # Connect to Cassandra processor.connect_cassandra() @@ -218,7 +218,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_data_insertion_and_graphql_query(self, processor, sample_schema_config): """Test inserting data and querying via GraphQL""" # Load schema and connect - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) processor.connect_cassandra() # Setup test data @@ -292,7 +292,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_graphql_query_with_filters(self, processor, sample_schema_config): """Test GraphQL queries with filtering on indexed fields""" # Setup (reuse previous setup) - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) processor.connect_cassandra() keyspace = "test_user" @@ -353,7 +353,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_graphql_error_handling(self, processor, sample_schema_config): """Test GraphQL error handling for invalid queries""" # Setup - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) # Test invalid field query invalid_query = ''' @@ -386,7 +386,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_message_processing_integration(self, processor, sample_schema_config): """Test full message processing workflow""" # Setup - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) processor.connect_cassandra() # Create mock message @@ -432,7 +432,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_concurrent_queries(self, processor, sample_schema_config): """Test handling multiple concurrent GraphQL queries""" # Setup - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) processor.connect_cassandra() # Create multiple query tasks @@ -476,7 +476,7 @@ class TestObjectsGraphQLQueryIntegration: } } - await processor.on_schema_config(initial_config, version=1) + await processor.on_schema_config("default", initial_config, version=1) assert len(processor.schemas) == 1 assert "simple" in processor.schemas @@ -500,7 +500,7 @@ class TestObjectsGraphQLQueryIntegration: } } - await processor.on_schema_config(updated_config, version=2) + await processor.on_schema_config("default", updated_config, version=2) # Verify updated schemas assert len(processor.schemas) == 2 @@ -518,7 +518,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_large_result_set_handling(self, processor, sample_schema_config): """Test handling of large query result sets""" # Setup - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) processor.connect_cassandra() keyspace = "large_test_user" @@ -601,7 +601,7 @@ class TestObjectsGraphQLQueryPerformance: } } - await processor.on_schema_config(schema_config, version=1) + await processor.on_schema_config("default", schema_config, version=1) # Measure query execution time start_time = time.time() diff --git a/tests/integration/test_structured_query_integration.py b/tests/integration/test_structured_query_integration.py index d5fb5672..67c85406 100644 --- a/tests/integration/test_structured_query_integration.py +++ b/tests/integration/test_structured_query_integration.py @@ -42,7 +42,6 @@ class TestStructuredQueryServiceIntegration: # Arrange - Create realistic query request request = StructuredQueryRequest( question="Show me all customers from California who have made purchases over $500", - user="trustgraph", collection="default" ) @@ -126,7 +125,6 @@ class TestStructuredQueryServiceIntegration: assert "orders" in objects_call_args.query assert objects_call_args.variables["minAmount"] == "500.0" # Converted to string assert objects_call_args.variables["state"] == "California" - assert objects_call_args.user == "trustgraph" assert objects_call_args.collection == "default" # Verify response diff --git a/tests/unit/test_agent/test_agent_service_non_streaming.py b/tests/unit/test_agent/test_agent_service_non_streaming.py index bb58e5ee..0a27b118 100644 --- a/tests/unit/test_agent/test_agent_service_non_streaming.py +++ b/tests/unit/test_agent/test_agent_service_non_streaming.py @@ -37,6 +37,9 @@ class TestAgentServiceNonStreaming: # Setup mock agent manager mock_agent_instance = AsyncMock() mock_agent_manager_class.return_value = mock_agent_instance + mock_agent_instance.tools = {} + mock_agent_instance.additional_context = "" + processor.agents["default"] = mock_agent_instance # Mock react to call think and observe callbacks async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None): @@ -50,7 +53,6 @@ class TestAgentServiceNonStreaming: msg = MagicMock() msg.value.return_value = AgentRequest( question="What is 2 + 2?", - user="trustgraph", streaming=False # Non-streaming mode ) msg.properties.return_value = {"id": "test-id"} @@ -58,6 +60,7 @@ class TestAgentServiceNonStreaming: # Setup flow mock consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" mock_producer = AsyncMock() @@ -129,6 +132,9 @@ class TestAgentServiceNonStreaming: # Setup mock agent manager mock_agent_instance = AsyncMock() mock_agent_manager_class.return_value = mock_agent_instance + mock_agent_instance.tools = {} + mock_agent_instance.additional_context = "" + processor.agents["default"] = mock_agent_instance # Mock react to return Final directly async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None): @@ -140,7 +146,6 @@ class TestAgentServiceNonStreaming: msg = MagicMock() msg.value.return_value = AgentRequest( question="What is 2 + 2?", - user="trustgraph", streaming=False # Non-streaming mode ) msg.properties.return_value = {"id": "test-id"} @@ -148,6 +153,7 @@ class TestAgentServiceNonStreaming: # Setup flow mock consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" mock_producer = AsyncMock() diff --git a/tests/unit/test_agent/test_aggregator.py b/tests/unit/test_agent/test_aggregator.py index afb19499..87a9e3bc 100644 --- a/tests/unit/test_agent/test_aggregator.py +++ b/tests/unit/test_agent/test_aggregator.py @@ -11,13 +11,12 @@ from trustgraph.schema import AgentRequest, AgentStep from trustgraph.agent.orchestrator.aggregator import Aggregator -def _make_request(question="Test question", user="testuser", +def _make_request(question="Test question", collection="default", streaming=False, session_id="parent-session", task_type="research", framing="test framing", conversation_id="conv-1"): return AgentRequest( question=question, - user=user, collection=collection, streaming=streaming, session_id=session_id, @@ -127,7 +126,6 @@ class TestBuildSynthesisRequest: req = agg.build_synthesis_request( "corr-1", original_question="Original question", - user="testuser", collection="default", ) @@ -148,7 +146,7 @@ class TestBuildSynthesisRequest: agg.record_completion("corr-1", "goal-b", "answer-b") req = agg.build_synthesis_request( - "corr-1", "question", "user", "default", + "corr-1", "question", "default", ) # Last history step should be the synthesis step @@ -168,7 +166,7 @@ class TestBuildSynthesisRequest: agg.record_completion("corr-1", "goal-a", "answer-a") agg.build_synthesis_request( - "corr-1", "question", "user", "default", + "corr-1", "question", "default", ) # Entry should be removed @@ -178,7 +176,7 @@ class TestBuildSynthesisRequest: agg = Aggregator() with pytest.raises(RuntimeError, match="No results"): agg.build_synthesis_request( - "unknown", "question", "user", "default", + "unknown", "question", "default", ) diff --git a/tests/unit/test_agent/test_completion_dispatch.py b/tests/unit/test_agent/test_completion_dispatch.py index 8c01f126..0d28d168 100644 --- a/tests/unit/test_agent/test_completion_dispatch.py +++ b/tests/unit/test_agent/test_completion_dispatch.py @@ -15,7 +15,6 @@ from trustgraph.agent.orchestrator.aggregator import Aggregator def _make_request(**kwargs): defaults = dict( question="Test question", - user="testuser", collection="default", ) defaults.update(kwargs) @@ -130,7 +129,6 @@ class TestAggregatorIntegration: synth = agg.build_synthesis_request( "corr-1", original_question="Original question", - user="testuser", collection="default", ) @@ -160,7 +158,7 @@ class TestAggregatorIntegration: agg.record_completion("corr-1", "goal", "answer") synth = agg.build_synthesis_request( - "corr-1", "question", "user", "default", + "corr-1", "question", "default", ) # correlation_id must be empty so it's not intercepted diff --git a/tests/unit/test_agent/test_orchestrator_provenance_integration.py b/tests/unit/test_agent/test_orchestrator_provenance_integration.py index 63d87ba1..7a1ec4c1 100644 --- a/tests/unit/test_agent/test_orchestrator_provenance_integration.py +++ b/tests/unit/test_agent/test_orchestrator_provenance_integration.py @@ -126,7 +126,6 @@ def make_base_request(**kwargs): state="", group=[], history=[], - user="testuser", collection="default", streaming=False, session_id="test-session-123", diff --git a/tests/unit/test_agent/test_pattern_base_subagent.py b/tests/unit/test_agent/test_pattern_base_subagent.py index 1523b592..bb176ba4 100644 --- a/tests/unit/test_agent/test_pattern_base_subagent.py +++ b/tests/unit/test_agent/test_pattern_base_subagent.py @@ -21,7 +21,6 @@ class MockProcessor: def _make_request(**kwargs): defaults = dict( question="Test question", - user="testuser", collection="default", ) defaults.update(kwargs) diff --git a/tests/unit/test_agent/test_tool_service.py b/tests/unit/test_agent/test_tool_service.py index 8bcf39ce..369a3c73 100644 --- a/tests/unit/test_agent/test_tool_service.py +++ b/tests/unit/test_agent/test_tool_service.py @@ -167,39 +167,28 @@ class TestToolServiceRequest: """Test cases for tool service request format""" def test_request_format(self): - """Test that request is properly formatted with user, config, and arguments""" - # Arrange - user = "alice" + """Test that request is properly formatted with config and arguments""" config_values = {"style": "pun", "collection": "jokes"} arguments = {"topic": "programming"} - # Act - simulate request building request = { - "user": user, "config": json.dumps(config_values), "arguments": json.dumps(arguments) } - # Assert - assert request["user"] == "alice" assert json.loads(request["config"]) == {"style": "pun", "collection": "jokes"} assert json.loads(request["arguments"]) == {"topic": "programming"} def test_request_with_empty_config(self): """Test request when no config values are provided""" - # Arrange - user = "bob" config_values = {} arguments = {"query": "test"} - # Act request = { - "user": user, "config": json.dumps(config_values) if config_values else "{}", "arguments": json.dumps(arguments) if arguments else "{}" } - # Assert assert request["config"] == "{}" assert json.loads(request["arguments"]) == {"query": "test"} @@ -386,18 +375,13 @@ class TestJokeServiceLogic: assert map_topic_to_category("random topic") == "default" assert map_topic_to_category("") == "default" - def test_joke_response_personalization(self): - """Test that joke responses include user personalization""" - # Arrange - user = "alice" + def test_joke_response_format(self): + """Test that joke response is formatted as expected""" style = "pun" joke = "Why do programmers prefer dark mode? Because light attracts bugs!" - # Act - response = f"Hey {user}! Here's a {style} for you:\n\n{joke}" + response = f"Here's a {style} for you:\n\n{joke}" - # Assert - assert "Hey alice!" in response assert "pun" in response assert joke in response @@ -439,20 +423,14 @@ class TestDynamicToolServiceBase: def test_request_parsing(self): """Test parsing of incoming request""" - # Arrange request_data = { - "user": "alice", "config": '{"style": "pun"}', "arguments": '{"topic": "programming"}' } - # Act - user = request_data.get("user", "trustgraph") config = json.loads(request_data["config"]) if request_data["config"] else {} arguments = json.loads(request_data["arguments"]) if request_data["arguments"] else {} - # Assert - assert user == "alice" assert config == {"style": "pun"} assert arguments == {"topic": "programming"} diff --git a/tests/unit/test_agent/test_tool_service_lifecycle.py b/tests/unit/test_agent/test_tool_service_lifecycle.py index 65cdb542..874ef0e6 100644 --- a/tests/unit/test_agent/test_tool_service_lifecycle.py +++ b/tests/unit/test_agent/test_tool_service_lifecycle.py @@ -1,6 +1,6 @@ """ Tests for tool service lifecycle, invoke contract, streaming responses, -multi-tenancy, and error propagation. +and error propagation. Tests the actual DynamicToolService, ToolService, and ToolServiceClient classes rather than plain dicts. @@ -31,7 +31,7 @@ class TestDynamicToolServiceInvokeContract: svc = DynamicToolService.__new__(DynamicToolService) with pytest.raises(NotImplementedError): - await svc.invoke("user", {}, {}) + await svc.invoke({}, {}) @pytest.mark.asyncio async def test_on_request_calls_invoke_with_parsed_args(self): @@ -44,8 +44,8 @@ class TestDynamicToolServiceInvokeContract: calls = [] - async def tracking_invoke(user, config, arguments): - calls.append({"user": user, "config": config, "arguments": arguments}) + async def tracking_invoke(config, arguments): + calls.append({"config": config, "arguments": arguments}) return "ok" svc.invoke = tracking_invoke @@ -56,7 +56,6 @@ class TestDynamicToolServiceInvokeContract: msg = MagicMock() msg.value.return_value = ToolServiceRequest( - user="alice", config='{"style": "pun"}', arguments='{"topic": "cats"}', ) @@ -65,39 +64,9 @@ class TestDynamicToolServiceInvokeContract: await svc.on_request(msg, MagicMock(), None) assert len(calls) == 1 - assert calls[0]["user"] == "alice" assert calls[0]["config"] == {"style": "pun"} assert calls[0]["arguments"] == {"topic": "cats"} - @pytest.mark.asyncio - async def test_on_request_empty_user_defaults_to_trustgraph(self): - """Empty user field should default to 'trustgraph'.""" - from trustgraph.base.dynamic_tool_service import DynamicToolService - - svc = DynamicToolService.__new__(DynamicToolService) - svc.id = "test-svc" - svc.producer = AsyncMock() - - received_user = None - - async def capture_invoke(user, config, arguments): - nonlocal received_user - received_user = user - return "ok" - - svc.invoke = capture_invoke - - if not hasattr(DynamicToolService, "tool_service_metric"): - DynamicToolService.tool_service_metric = MagicMock() - - msg = MagicMock() - msg.value.return_value = ToolServiceRequest(user="", config="", arguments="") - msg.properties.return_value = {"id": "req-2"} - - await svc.on_request(msg, MagicMock(), None) - - assert received_user == "trustgraph" - @pytest.mark.asyncio async def test_on_request_string_response_sent_directly(self): """String return from invoke → response field is the string.""" @@ -107,7 +76,7 @@ class TestDynamicToolServiceInvokeContract: svc.id = "test-svc" svc.producer = AsyncMock() - async def string_invoke(user, config, arguments): + async def string_invoke(config, arguments): return "hello world" svc.invoke = string_invoke @@ -116,7 +85,7 @@ class TestDynamicToolServiceInvokeContract: DynamicToolService.tool_service_metric = MagicMock() msg = MagicMock() - msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") + msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}") msg.properties.return_value = {"id": "r1"} await svc.on_request(msg, MagicMock(), None) @@ -136,7 +105,7 @@ class TestDynamicToolServiceInvokeContract: svc.id = "test-svc" svc.producer = AsyncMock() - async def dict_invoke(user, config, arguments): + async def dict_invoke(config, arguments): return {"result": 42} svc.invoke = dict_invoke @@ -145,7 +114,7 @@ class TestDynamicToolServiceInvokeContract: DynamicToolService.tool_service_metric = MagicMock() msg = MagicMock() - msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") + msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}") msg.properties.return_value = {"id": "r2"} await svc.on_request(msg, MagicMock(), None) @@ -162,13 +131,13 @@ class TestDynamicToolServiceInvokeContract: svc.id = "test-svc" svc.producer = AsyncMock() - async def failing_invoke(user, config, arguments): + async def failing_invoke(config, arguments): raise ValueError("bad input") svc.invoke = failing_invoke msg = MagicMock() - msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") + msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}") msg.properties.return_value = {"id": "r3"} await svc.on_request(msg, MagicMock(), None) @@ -188,13 +157,13 @@ class TestDynamicToolServiceInvokeContract: svc.id = "test-svc" svc.producer = AsyncMock() - async def rate_limited_invoke(user, config, arguments): + async def rate_limited_invoke(config, arguments): raise TooManyRequests("rate limited") svc.invoke = rate_limited_invoke msg = MagicMock() - msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") + msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}") msg.properties.return_value = {"id": "r4"} with pytest.raises(TooManyRequests): @@ -209,7 +178,7 @@ class TestDynamicToolServiceInvokeContract: svc.id = "test-svc" svc.producer = AsyncMock() - async def ok_invoke(user, config, arguments): + async def ok_invoke(config, arguments): return "ok" svc.invoke = ok_invoke @@ -218,7 +187,7 @@ class TestDynamicToolServiceInvokeContract: DynamicToolService.tool_service_metric = MagicMock() msg = MagicMock() - msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") + msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}") msg.properties.return_value = {"id": "unique-42"} await svc.on_request(msg, MagicMock(), None) @@ -241,7 +210,7 @@ class TestToolServiceOnRequest: svc = ToolService.__new__(ToolService) svc.id = "test-tool" - async def mock_invoke(name, params): + async def mock_invoke(workspace, name, params): return "tool result" svc.invoke_tool = mock_invoke @@ -260,6 +229,7 @@ class TestToolServiceOnRequest: flow_callable.producer = {"response": mock_response_pub} flow_callable.name = "test-flow" + flow_callable.workspace = "default" msg = MagicMock() msg.value.return_value = ToolRequest(name="my-tool", parameters='{"key": "val"}') @@ -280,7 +250,7 @@ class TestToolServiceOnRequest: svc = ToolService.__new__(ToolService) svc.id = "test-tool" - async def mock_invoke(name, params): + async def mock_invoke(workspace, name, params): return {"data": [1, 2, 3]} svc.invoke_tool = mock_invoke @@ -298,6 +268,7 @@ class TestToolServiceOnRequest: flow_callable.producer = {"response": mock_response_pub} flow_callable.name = "test-flow" + flow_callable.workspace = "default" msg = MagicMock() msg.value.return_value = ToolRequest(name="my-tool", parameters="{}") @@ -317,7 +288,7 @@ class TestToolServiceOnRequest: svc = ToolService.__new__(ToolService) svc.id = "test-tool" - async def failing_invoke(name, params): + async def failing_invoke(workspace, name, params): raise RuntimeError("tool broke") svc.invoke_tool = failing_invoke @@ -330,6 +301,7 @@ class TestToolServiceOnRequest: flow_callable.producer = {"response": mock_response_pub} flow_callable.name = "test-flow" + flow_callable.workspace = "default" msg = MagicMock() msg.value.return_value = ToolRequest(name="my-tool", parameters="{}") @@ -350,7 +322,7 @@ class TestToolServiceOnRequest: svc = ToolService.__new__(ToolService) svc.id = "test-tool" - async def rate_limited(name, params): + async def rate_limited(workspace, name, params): raise TooManyRequests("slow down") svc.invoke_tool = rate_limited @@ -362,6 +334,7 @@ class TestToolServiceOnRequest: flow = MagicMock() flow.producer = {"response": AsyncMock()} flow.name = "test-flow" + flow.workspace = "default" with pytest.raises(TooManyRequests): await svc.on_request(msg, MagicMock(), flow) @@ -376,7 +349,8 @@ class TestToolServiceOnRequest: received = {} - async def capture_invoke(name, params): + async def capture_invoke(workspace, name, params): + received["workspace"] = workspace received["name"] = name received["params"] = params return "ok" @@ -390,6 +364,7 @@ class TestToolServiceOnRequest: flow = lambda name: mock_pub flow.producer = {"response": mock_pub} flow.name = "f" + flow.workspace = "default" msg = MagicMock() msg.value.return_value = ToolRequest( @@ -421,7 +396,6 @@ class TestToolServiceClientCall: )) result = await client.call( - user="alice", config={"style": "pun"}, arguments={"topic": "cats"}, ) @@ -430,7 +404,6 @@ class TestToolServiceClientCall: req = client.request.call_args[0][0] assert isinstance(req, ToolServiceRequest) - assert req.user == "alice" assert json.loads(req.config) == {"style": "pun"} assert json.loads(req.arguments) == {"topic": "cats"} @@ -446,7 +419,7 @@ class TestToolServiceClientCall: )) with pytest.raises(RuntimeError, match="service down"): - await client.call(user="u", config={}, arguments={}) + await client.call(config={}, arguments={}) @pytest.mark.asyncio async def test_call_empty_config_sends_empty_json(self): @@ -458,7 +431,7 @@ class TestToolServiceClientCall: error=None, response="ok", )) - await client.call(user="u", config=None, arguments=None) + await client.call(config=None, arguments=None) req = client.request.call_args[0][0] assert req.config == "{}" @@ -474,7 +447,7 @@ class TestToolServiceClientCall: error=None, response="ok", )) - await client.call(user="u", config={}, arguments={}, timeout=30) + await client.call(config={}, arguments={}, timeout=30) _, kwargs = client.request.call_args assert kwargs["timeout"] == 30 @@ -509,7 +482,7 @@ class TestToolServiceClientStreaming: received.append(text) result = await client.call_streaming( - user="u", config={}, arguments={}, callback=callback, + config={}, arguments={}, callback=callback, ) assert result == "chunk1chunk2" @@ -534,7 +507,7 @@ class TestToolServiceClientStreaming: with pytest.raises(RuntimeError, match="stream failed"): await client.call_streaming( - user="u", config={}, arguments={}, + config={}, arguments={}, callback=AsyncMock(), ) @@ -564,61 +537,9 @@ class TestToolServiceClientStreaming: received.append(text) result = await client.call_streaming( - user="u", config={}, arguments={}, callback=callback, + config={}, arguments={}, callback=callback, ) # Empty response is falsy, so callback shouldn't be called for it assert result == "data" assert received == ["data"] - - -# --------------------------------------------------------------------------- -# Multi-tenancy -# --------------------------------------------------------------------------- - -class TestMultiTenancy: - - @pytest.mark.asyncio - async def test_user_propagated_to_invoke(self): - """User from request should reach the invoke method.""" - from trustgraph.base.dynamic_tool_service import DynamicToolService - - svc = DynamicToolService.__new__(DynamicToolService) - svc.id = "test" - svc.producer = AsyncMock() - - users_seen = [] - - async def tracking(user, config, arguments): - users_seen.append(user) - return "ok" - - svc.invoke = tracking - - if not hasattr(DynamicToolService, "tool_service_metric"): - DynamicToolService.tool_service_metric = MagicMock() - - for u in ["tenant-a", "tenant-b", "tenant-c"]: - msg = MagicMock() - msg.value.return_value = ToolServiceRequest( - user=u, config="{}", arguments="{}", - ) - msg.properties.return_value = {"id": f"req-{u}"} - await svc.on_request(msg, MagicMock(), None) - - assert users_seen == ["tenant-a", "tenant-b", "tenant-c"] - - @pytest.mark.asyncio - async def test_client_sends_user_in_request(self): - """ToolServiceClient.call should include user in request.""" - from trustgraph.base.tool_service_client import ToolServiceClient - - client = ToolServiceClient.__new__(ToolServiceClient) - client.request = AsyncMock(return_value=ToolServiceResponse( - error=None, response="ok", - )) - - await client.call(user="isolated-tenant", config={}, arguments={}) - - req = client.request.call_args[0][0] - assert req.user == "isolated-tenant" diff --git a/tests/unit/test_base/test_async_processor_config.py b/tests/unit/test_base/test_async_processor_config.py index f1a83fef..3dffd775 100644 --- a/tests/unit/test_base/test_async_processor_config.py +++ b/tests/unit/test_base/test_async_processor_config.py @@ -1,17 +1,14 @@ """ Tests for AsyncProcessor config notify pattern: - register_config_handler with types filtering -- on_config_notify version comparison and type matching -- fetch_config with short-lived client -- fetch_and_apply_config retry logic +- on_config_notify version comparison, type/workspace matching +- fetch_and_apply_config retry logic over per-workspace fetches """ import pytest from unittest.mock import AsyncMock, MagicMock, patch, Mock -from trustgraph.schema import Term, IRI, LITERAL -# Patch heavy dependencies before importing AsyncProcessor @pytest.fixture def processor(): """Create an AsyncProcessor with mocked dependencies.""" @@ -68,6 +65,13 @@ class TestRegisterConfigHandler: assert len(processor.config_handlers) == 2 +def _notify_msg(version, changes): + """Build a Mock config-notify message with given version and changes dict.""" + msg = Mock() + msg.value.return_value = Mock(version=version, changes=changes) + return msg + + class TestOnConfigNotify: @pytest.mark.asyncio @@ -77,9 +81,7 @@ class TestOnConfigNotify: handler = AsyncMock() processor.register_config_handler(handler, types=["prompt"]) - msg = Mock() - msg.value.return_value = Mock(version=3, types=["prompt"]) - + msg = _notify_msg(3, {"prompt": ["default"]}) await processor.on_config_notify(msg, None, None) handler.assert_not_called() @@ -91,9 +93,7 @@ class TestOnConfigNotify: handler = AsyncMock() processor.register_config_handler(handler, types=["prompt"]) - msg = Mock() - msg.value.return_value = Mock(version=5, types=["prompt"]) - + msg = _notify_msg(5, {"prompt": ["default"]}) await processor.on_config_notify(msg, None, None) handler.assert_not_called() @@ -105,9 +105,7 @@ class TestOnConfigNotify: handler = AsyncMock() processor.register_config_handler(handler, types=["prompt"]) - msg = Mock() - msg.value.return_value = Mock(version=2, types=["schema"]) - + msg = _notify_msg(2, {"schema": ["default"]}) await processor.on_config_notify(msg, None, None) handler.assert_not_called() @@ -121,40 +119,36 @@ class TestOnConfigNotify: handler = AsyncMock() processor.register_config_handler(handler, types=["prompt"]) - # Mock fetch_config - mock_config = {"prompt": {"key": "value"}} + mock_client = AsyncMock() with patch.object( - processor, 'fetch_config', + processor, '_create_config_client', return_value=mock_client + ), patch.object( + processor, '_fetch_type_workspace', new_callable=AsyncMock, - return_value=(mock_config, 2) + return_value={"key": "value"}, ): - msg = Mock() - msg.value.return_value = Mock(version=2, types=["prompt"]) - + msg = _notify_msg(2, {"prompt": ["default"]}) await processor.on_config_notify(msg, None, None) - handler.assert_called_once_with(mock_config, 2) + handler.assert_called_once_with( + "default", {"prompt": {"key": "value"}}, 2 + ) assert processor.config_version == 2 @pytest.mark.asyncio - async def test_handler_without_types_always_called(self, processor): + async def test_handler_without_types_ignored_on_notify(self, processor): + """Handlers registered without types never fire on notifications.""" processor.config_version = 1 handler = AsyncMock() - processor.register_config_handler(handler) # No types = all + processor.register_config_handler(handler) # No types - mock_config = {"anything": {}} - with patch.object( - processor, 'fetch_config', - new_callable=AsyncMock, - return_value=(mock_config, 2) - ): - msg = Mock() - msg.value.return_value = Mock(version=2, types=["whatever"]) + msg = _notify_msg(2, {"whatever": ["default"]}) + await processor.on_config_notify(msg, None, None) - await processor.on_config_notify(msg, None, None) - - handler.assert_called_once_with(mock_config, 2) + handler.assert_not_called() + # Version still advances past the notify + assert processor.config_version == 2 @pytest.mark.asyncio async def test_mixed_handlers_type_filtering(self, processor): @@ -168,156 +162,149 @@ class TestOnConfigNotify: processor.register_config_handler(schema_handler, types=["schema"]) processor.register_config_handler(all_handler) - mock_config = {"prompt": {}} + mock_client = AsyncMock() with patch.object( - processor, 'fetch_config', + processor, '_create_config_client', return_value=mock_client + ), patch.object( + processor, '_fetch_type_workspace', new_callable=AsyncMock, - return_value=(mock_config, 2) + return_value={}, ): - msg = Mock() - msg.value.return_value = Mock(version=2, types=["prompt"]) - + msg = _notify_msg(2, {"prompt": ["default"]}) await processor.on_config_notify(msg, None, None) - prompt_handler.assert_called_once() + prompt_handler.assert_called_once_with( + "default", {"prompt": {}}, 2 + ) schema_handler.assert_not_called() - all_handler.assert_called_once() + all_handler.assert_not_called() @pytest.mark.asyncio - async def test_empty_types_invokes_all(self, processor): - """Empty types list (startup signal) should invoke all handlers.""" + async def test_multi_workspace_notify_invokes_handler_per_ws( + self, processor + ): + """Notify affecting multiple workspaces invokes handler once per workspace.""" processor.config_version = 1 - h1 = AsyncMock() - h2 = AsyncMock() - processor.register_config_handler(h1, types=["prompt"]) - processor.register_config_handler(h2, types=["schema"]) + handler = AsyncMock() + processor.register_config_handler(handler, types=["prompt"]) - mock_config = {} + mock_client = AsyncMock() with patch.object( - processor, 'fetch_config', + processor, '_create_config_client', return_value=mock_client + ), patch.object( + processor, '_fetch_type_workspace', new_callable=AsyncMock, - return_value=(mock_config, 2) + return_value={}, ): - msg = Mock() - msg.value.return_value = Mock(version=2, types=[]) - + msg = _notify_msg(2, {"prompt": ["ws1", "ws2"]}) await processor.on_config_notify(msg, None, None) - h1.assert_called_once() - h2.assert_called_once() + assert handler.call_count == 2 + called_workspaces = {c.args[0] for c in handler.call_args_list} + assert called_workspaces == {"ws1", "ws2"} @pytest.mark.asyncio async def test_fetch_failure_handled(self, processor): processor.config_version = 1 handler = AsyncMock() - processor.register_config_handler(handler) + processor.register_config_handler(handler, types=["prompt"]) + mock_client = AsyncMock() with patch.object( - processor, 'fetch_config', + processor, '_create_config_client', return_value=mock_client + ), patch.object( + processor, '_fetch_type_workspace', new_callable=AsyncMock, - side_effect=RuntimeError("Connection failed") + side_effect=RuntimeError("Connection failed"), ): - msg = Mock() - msg.value.return_value = Mock(version=2, types=["prompt"]) - + msg = _notify_msg(2, {"prompt": ["default"]}) # Should not raise await processor.on_config_notify(msg, None, None) handler.assert_not_called() -class TestFetchConfig: - - @pytest.mark.asyncio - async def test_fetch_returns_config_and_version(self, processor): - mock_resp = Mock() - mock_resp.error = None - mock_resp.config = {"prompt": {"key": "val"}} - mock_resp.version = 42 - - mock_client = AsyncMock() - mock_client.request.return_value = mock_resp - - with patch.object( - processor, '_create_config_client', return_value=mock_client - ): - config, version = await processor.fetch_config() - - assert config == {"prompt": {"key": "val"}} - assert version == 42 - mock_client.stop.assert_called_once() - - @pytest.mark.asyncio - async def test_fetch_raises_on_error_response(self, processor): - mock_resp = Mock() - mock_resp.error = Mock(message="not found") - mock_resp.config = {} - mock_resp.version = 0 - - mock_client = AsyncMock() - mock_client.request.return_value = mock_resp - - with patch.object( - processor, '_create_config_client', return_value=mock_client - ): - with pytest.raises(RuntimeError, match="Config error"): - await processor.fetch_config() - - mock_client.stop.assert_called_once() - - @pytest.mark.asyncio - async def test_fetch_stops_client_on_exception(self, processor): - mock_client = AsyncMock() - mock_client.request.side_effect = TimeoutError("timeout") - - with patch.object( - processor, '_create_config_client', return_value=mock_client - ): - with pytest.raises(TimeoutError): - await processor.fetch_config() - - mock_client.stop.assert_called_once() - - class TestFetchAndApplyConfig: @pytest.mark.asyncio - async def test_applies_config_to_all_handlers(self, processor): - h1 = AsyncMock() - h2 = AsyncMock() - processor.register_config_handler(h1, types=["prompt"]) - processor.register_config_handler(h2, types=["schema"]) + async def test_applies_config_per_workspace(self, processor): + """Startup fetch invokes handler once per workspace affected.""" + h = AsyncMock() + processor.register_config_handler(h, types=["prompt"]) + + mock_client = AsyncMock() + + async def fake_fetch_all(client, config_type): + return { + "ws1": {"k": "v1"}, + "ws2": {"k": "v2"}, + }, 10 - mock_config = {"prompt": {}, "schema": {}} with patch.object( - processor, 'fetch_config', - new_callable=AsyncMock, - return_value=(mock_config, 10) + processor, '_create_config_client', return_value=mock_client + ), patch.object( + processor, '_fetch_type_all_workspaces', + new=fake_fetch_all, ): await processor.fetch_and_apply_config() - # On startup, all handlers are invoked regardless of type - h1.assert_called_once_with(mock_config, 10) - h2.assert_called_once_with(mock_config, 10) + assert h.call_count == 2 + call_map = {c.args[0]: c.args[1] for c in h.call_args_list} + assert call_map["ws1"] == {"prompt": {"k": "v1"}} + assert call_map["ws2"] == {"prompt": {"k": "v2"}} assert processor.config_version == 10 @pytest.mark.asyncio - async def test_retries_on_failure(self, processor): - call_count = 0 - mock_config = {"prompt": {}} + async def test_handler_without_types_skipped_at_startup(self, processor): + """Handlers registered without types fetch nothing at startup.""" + typed = AsyncMock() + untyped = AsyncMock() + processor.register_config_handler(typed, types=["prompt"]) + processor.register_config_handler(untyped) - async def mock_fetch(): + mock_client = AsyncMock() + + async def fake_fetch_all(client, config_type): + return {"default": {}}, 1 + + with patch.object( + processor, '_create_config_client', return_value=mock_client + ), patch.object( + processor, '_fetch_type_all_workspaces', + new=fake_fetch_all, + ): + await processor.fetch_and_apply_config() + + typed.assert_called_once() + untyped.assert_not_called() + + @pytest.mark.asyncio + async def test_retries_on_failure(self, processor): + h = AsyncMock() + processor.register_config_handler(h, types=["prompt"]) + + call_count = 0 + + async def fake_fetch_all(client, config_type): nonlocal call_count call_count += 1 if call_count < 3: raise RuntimeError("not ready") - return mock_config, 5 + return {"default": {"k": "v"}}, 5 - with patch.object(processor, 'fetch_config', side_effect=mock_fetch), \ - patch('asyncio.sleep', new_callable=AsyncMock): + mock_client = AsyncMock() + with patch.object( + processor, '_create_config_client', return_value=mock_client + ), patch.object( + processor, '_fetch_type_all_workspaces', + new=fake_fetch_all, + ), patch('asyncio.sleep', new_callable=AsyncMock): await processor.fetch_and_apply_config() assert call_count == 3 assert processor.config_version == 5 + h.assert_called_once_with( + "default", {"prompt": {"k": "v"}}, 5 + ) diff --git a/tests/unit/test_base/test_document_embeddings_client.py b/tests/unit/test_base/test_document_embeddings_client.py index 705f2bd1..ff9e67e9 100644 --- a/tests/unit/test_base/test_document_embeddings_client.py +++ b/tests/unit/test_base/test_document_embeddings_client.py @@ -33,7 +33,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): result = await client.query( vector=vector, limit=10, - user="test_user", collection="test_collection", timeout=30 ) @@ -45,7 +44,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): assert isinstance(call_args, DocumentEmbeddingsRequest) assert call_args.vector == vector assert call_args.limit == 10 - assert call_args.user == "test_user" assert call_args.collection == "test_collection" @patch('trustgraph.base.request_response_spec.RequestResponse.__init__') @@ -104,7 +102,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): client.request.assert_called_once() call_args = client.request.call_args[0][0] assert call_args.limit == 20 # Default limit - assert call_args.user == "trustgraph" # Default user assert call_args.collection == "default" # Default collection @patch('trustgraph.base.request_response_spec.RequestResponse.__init__') diff --git a/tests/unit/test_base/test_flow_base_modules.py b/tests/unit/test_base/test_flow_base_modules.py index 5bbd7a18..758edcff 100644 --- a/tests/unit/test_base/test_flow_base_modules.py +++ b/tests/unit/test_base/test_flow_base_modules.py @@ -40,10 +40,11 @@ def test_flow_initialization_calls_registered_specs(): spec_two = MagicMock() processor = MagicMock(specifications=[spec_one, spec_two]) - flow = Flow("processor-1", "flow-a", processor, {"answer": 42}) + flow = Flow("processor-1", "flow-a", "default", processor, {"answer": 42}) assert flow.id == "processor-1" assert flow.name == "flow-a" + assert flow.workspace == "default" assert flow.producer == {} assert flow.consumer == {} assert flow.parameter == {} @@ -54,7 +55,7 @@ def test_flow_initialization_calls_registered_specs(): def test_flow_start_and_stop_visit_all_consumers(): consumer_one = AsyncMock() consumer_two = AsyncMock() - flow = Flow("processor-1", "flow-a", MagicMock(specifications=[]), {}) + flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {}) flow.consumer = {"one": consumer_one, "two": consumer_two} asyncio.run(flow.start()) @@ -67,7 +68,7 @@ def test_flow_start_and_stop_visit_all_consumers(): def test_flow_call_returns_values_in_priority_order(): - flow = Flow("processor-1", "flow-a", MagicMock(specifications=[]), {}) + flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {}) flow.producer["shared"] = "producer-value" flow.consumer["consumer-only"] = "consumer-value" flow.consumer["shared"] = "consumer-value" diff --git a/tests/unit/test_base/test_flow_parameter_specs.py b/tests/unit/test_base/test_flow_parameter_specs.py index c813d66c..da7e9736 100644 --- a/tests/unit/test_base/test_flow_parameter_specs.py +++ b/tests/unit/test_base/test_flow_parameter_specs.py @@ -172,10 +172,10 @@ class TestFlowParameterSpecs(IsolatedAsyncioTestCase): flow_defn = {'config': 'test-config'} # Act - await processor.start_flow(flow_name, flow_defn) + await processor.start_flow("default", flow_name, flow_defn) # Assert - Flow should be created with access to processor specifications - mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn) + mock_flow_class.assert_called_once_with('test-processor', flow_name, "default", processor, flow_defn) # The flow should have access to the processor's specifications # (The exact mechanism depends on Flow implementation) diff --git a/tests/unit/test_base/test_flow_processor.py b/tests/unit/test_base/test_flow_processor.py index 36a05ec2..350a8b43 100644 --- a/tests/unit/test_base/test_flow_processor.py +++ b/tests/unit/test_base/test_flow_processor.py @@ -78,11 +78,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): flow_name = 'test-flow' flow_defn = {'config': 'test-config'} - await processor.start_flow(flow_name, flow_defn) + await processor.start_flow("default", flow_name, flow_defn) - assert flow_name in processor.flows + assert ("default", flow_name) in processor.flows mock_flow_class.assert_called_once_with( - 'test-processor', flow_name, processor, flow_defn + 'test-processor', flow_name, "default", processor, flow_defn ) mock_flow.start.assert_called_once() @@ -103,11 +103,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): mock_flow_class.return_value = mock_flow flow_name = 'test-flow' - await processor.start_flow(flow_name, {'config': 'test-config'}) + await processor.start_flow("default", flow_name, {'config': 'test-config'}) - await processor.stop_flow(flow_name) + await processor.stop_flow("default", flow_name) - assert flow_name not in processor.flows + assert ("default", flow_name) not in processor.flows mock_flow.stop.assert_called_once() @with_async_processor_patches @@ -120,7 +120,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): processor = FlowProcessor(**config) - await processor.stop_flow('non-existent-flow') + await processor.stop_flow("default", 'non-existent-flow') assert processor.flows == {} @@ -146,11 +146,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): } } - await processor.on_configure_flows(config_data, version=1) + await processor.on_configure_flows("default", config_data, version=1) - assert 'test-flow' in processor.flows + assert ("default", 'test-flow') in processor.flows mock_flow_class.assert_called_once_with( - 'test-processor', 'test-flow', processor, + 'test-processor', 'test-flow', "default", processor, {'config': 'test-config'} ) mock_flow.start.assert_called_once() @@ -171,7 +171,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): } } - await processor.on_configure_flows(config_data, version=1) + await processor.on_configure_flows("default", config_data, version=1) assert processor.flows == {} @@ -189,7 +189,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): 'other-data': 'some-value' } - await processor.on_configure_flows(config_data, version=1) + await processor.on_configure_flows("default", config_data, version=1) assert processor.flows == {} @@ -216,7 +216,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): } } - await processor.on_configure_flows(config_data1, version=1) + await processor.on_configure_flows("default", config_data1, version=1) config_data2 = { 'processor:test-processor': { @@ -224,12 +224,12 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): } } - await processor.on_configure_flows(config_data2, version=2) + await processor.on_configure_flows("default", config_data2, version=2) - assert 'flow1' not in processor.flows + assert ("default", 'flow1') not in processor.flows mock_flow1.stop.assert_called_once() - assert 'flow2' in processor.flows + assert ("default", 'flow2') in processor.flows mock_flow2.start.assert_called_once() @with_async_processor_patches diff --git a/tests/unit/test_chunking/conftest.py b/tests/unit/test_chunking/conftest.py index 31dab77d..c1f9ae33 100644 --- a/tests/unit/test_chunking/conftest.py +++ b/tests/unit/test_chunking/conftest.py @@ -28,7 +28,6 @@ def sample_text_document(): """Sample document with moderate length text.""" metadata = Metadata( id="test-doc-1", - user="test-user", collection="test-collection" ) text = "The quick brown fox jumps over the lazy dog. " * 20 @@ -43,7 +42,6 @@ def long_text_document(): """Long document for testing multiple chunks.""" metadata = Metadata( id="test-doc-long", - user="test-user", collection="test-collection" ) # Create a long text that will definitely be chunked @@ -59,7 +57,6 @@ def unicode_text_document(): """Document with various unicode characters.""" metadata = Metadata( id="test-doc-unicode", - user="test-user", collection="test-collection" ) text = """ @@ -84,7 +81,6 @@ def empty_text_document(): """Empty document for edge case testing.""" metadata = Metadata( id="test-doc-empty", - user="test-user", collection="test-collection" ) return TextDocument( diff --git a/tests/unit/test_chunking/test_recursive_chunker.py b/tests/unit/test_chunking/test_recursive_chunker.py index d1a5d247..74178ab4 100644 --- a/tests/unit/test_chunking/test_recursive_chunker.py +++ b/tests/unit/test_chunking/test_recursive_chunker.py @@ -185,7 +185,6 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): mock_text_doc = MagicMock() mock_text_doc.metadata = Metadata( id="test-doc-123", - user="test-user", collection="test-collection" ) mock_text_doc.text = b"This is test document content" diff --git a/tests/unit/test_chunking/test_token_chunker.py b/tests/unit/test_chunking/test_token_chunker.py index dba4ca94..568b335f 100644 --- a/tests/unit/test_chunking/test_token_chunker.py +++ b/tests/unit/test_chunking/test_token_chunker.py @@ -185,7 +185,6 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): mock_text_doc = MagicMock() mock_text_doc.metadata = Metadata( id="test-doc-456", - user="test-user", collection="test-collection" ) mock_text_doc.text = b"This is test document content for token chunking" diff --git a/tests/unit/test_cli/test_config_commands.py b/tests/unit/test_cli/test_config_commands.py index 68ae1a54..b5b74688 100644 --- a/tests/unit/test_cli/test_config_commands.py +++ b/tests/unit/test_cli/test_config_commands.py @@ -109,7 +109,8 @@ class TestListConfigItems: url='http://custom.com', config_type='prompt', format_type='json', - token=None + token=None, + workspace='default' ) def test_list_main_uses_defaults(self): @@ -128,7 +129,8 @@ class TestListConfigItems: url='http://localhost:8088/', config_type='prompt', format_type='text', - token=None + token=None, + workspace='default' ) @@ -196,7 +198,8 @@ class TestGetConfigItem: config_type='prompt', key='template-1', format_type='json', - token=None + token=None, + workspace='default' ) @@ -253,7 +256,8 @@ class TestPutConfigItem: config_type='prompt', key='new-template', value='Custom prompt: {input}', - token=None + token=None, + workspace='default' ) def test_put_main_with_stdin_arg(self): @@ -278,7 +282,8 @@ class TestPutConfigItem: config_type='prompt', key='stdin-template', value=stdin_content, - token=None + token=None, + workspace='default' ) def test_put_main_mutually_exclusive_args(self): @@ -334,7 +339,8 @@ class TestDeleteConfigItem: url='http://custom.com', config_type='prompt', key='old-template', - token=None + token=None, + workspace='default' ) diff --git a/tests/unit/test_cli/test_load_knowledge.py b/tests/unit/test_cli/test_load_knowledge.py index 63045ef9..a1588e85 100644 --- a/tests/unit/test_cli/test_load_knowledge.py +++ b/tests/unit/test_cli/test_load_knowledge.py @@ -48,7 +48,7 @@ def knowledge_loader(): return KnowledgeLoader( files=["test.ttl"], flow="test-flow", - user="test-user", + workspace="test-user", collection="test-collection", document_id="test-doc-123", url="http://test.example.com/", @@ -64,7 +64,7 @@ class TestKnowledgeLoader: loader = KnowledgeLoader( files=["file1.ttl", "file2.ttl"], flow="my-flow", - user="user1", + workspace="user1", collection="col1", document_id="doc1", url="http://example.com/", @@ -73,7 +73,7 @@ class TestKnowledgeLoader: assert loader.files == ["file1.ttl", "file2.ttl"] assert loader.flow == "my-flow" - assert loader.user == "user1" + assert loader.workspace == "user1" assert loader.collection == "col1" assert loader.document_id == "doc1" assert loader.url == "http://example.com/" @@ -126,7 +126,7 @@ ex:mary ex:knows ex:bob . loader = KnowledgeLoader( files=[f.name], flow="test-flow", - user="test-user", + workspace="test-user", collection="test-collection", document_id="test-doc", url="http://test.example.com/" @@ -151,7 +151,7 @@ ex:mary ex:knows ex:bob . loader = KnowledgeLoader( files=[temp_turtle_file], flow="test-flow", - user="test-user", + workspace="test-user", collection="test-collection", document_id="test-doc", url="http://test.example.com/", @@ -163,7 +163,8 @@ ex:mary ex:knows ex:bob . # Verify Api was created with correct parameters mock_api_class.assert_called_once_with( url="http://test.example.com/", - token="test-token" + token="test-token", + workspace="test-user" ) # Verify bulk client was obtained @@ -174,7 +175,6 @@ ex:mary ex:knows ex:bob . call_args = mock_bulk.import_triples.call_args assert call_args[1]['flow'] == "test-flow" assert call_args[1]['metadata']['id'] == "test-doc" - assert call_args[1]['metadata']['user'] == "test-user" assert call_args[1]['metadata']['collection'] == "test-collection" # Verify import_entity_contexts was called @@ -198,7 +198,7 @@ class TestCLIArgumentParsing: 'tg-load-knowledge', '-i', 'doc-123', '-f', 'my-flow', - '-U', 'my-user', + '-w', 'my-user', '-C', 'my-collection', '-u', 'http://custom.example.com/', '-t', 'my-token', @@ -216,7 +216,7 @@ class TestCLIArgumentParsing: token='my-token', flow='my-flow', files=['file1.ttl', 'file2.ttl'], - user='my-user', + workspace='my-user', collection='my-collection' ) @@ -242,7 +242,7 @@ class TestCLIArgumentParsing: # Verify defaults were used call_args = mock_loader_class.call_args[1] assert call_args['flow'] == 'default' - assert call_args['user'] == 'trustgraph' + assert call_args['workspace'] == 'default' assert call_args['collection'] == 'default' assert call_args['url'] == 'http://localhost:8088/' assert call_args['token'] is None @@ -287,7 +287,7 @@ class TestErrorHandling: loader = KnowledgeLoader( files=[temp_turtle_file], flow="test-flow", - user="test-user", + workspace="test-user", collection="test-collection", document_id="test-doc", url="http://test.example.com/" diff --git a/tests/unit/test_cli/test_tool_commands.py b/tests/unit/test_cli/test_tool_commands.py index 9c204614..72624d27 100644 --- a/tests/unit/test_cli/test_tool_commands.py +++ b/tests/unit/test_cli/test_tool_commands.py @@ -145,7 +145,8 @@ class TestSetToolStructuredQuery: group=None, state=None, applicable_states=None, - token=None + token=None, + workspace='default' ) def test_set_main_structured_query_no_arguments_needed(self): @@ -326,7 +327,8 @@ class TestSetToolRowEmbeddingsQuery: group=None, state=None, applicable_states=None, - token=None + token=None, + workspace='default' ) def test_valid_types_includes_row_embeddings_query(self): @@ -471,7 +473,7 @@ class TestShowToolsStructuredQuery: show_main() - mock_show.assert_called_once_with(url='http://custom.com', token=None) + mock_show.assert_called_once_with(url='http://custom.com', token=None, workspace='default') class TestShowToolsRowEmbeddingsQuery: diff --git a/tests/unit/test_clients/test_sync_document_embeddings_client.py b/tests/unit/test_clients/test_sync_document_embeddings_client.py index edf4ac81..6c466877 100644 --- a/tests/unit/test_clients/test_sync_document_embeddings_client.py +++ b/tests/unit/test_clients/test_sync_document_embeddings_client.py @@ -73,7 +73,6 @@ class TestSyncDocumentEmbeddingsClient: # Act result = client.request( vector=vector, - user="test_user", collection="test_collection", limit=10, timeout=300 @@ -82,7 +81,6 @@ class TestSyncDocumentEmbeddingsClient: # Assert assert result == ["chunk1", "chunk2", "chunk3"] client.call.assert_called_once_with( - user="test_user", collection="test_collection", vector=vector, limit=10, @@ -108,7 +106,6 @@ class TestSyncDocumentEmbeddingsClient: # Assert assert result == ["test_chunk"] client.call.assert_called_once_with( - user="trustgraph", collection="default", vector=vector, limit=10, diff --git a/tests/unit/test_concurrency/test_graph_rag_concurrency.py b/tests/unit/test_concurrency/test_graph_rag_concurrency.py index 8287427b..1b35a238 100644 --- a/tests/unit/test_concurrency/test_graph_rag_concurrency.py +++ b/tests/unit/test_concurrency/test_graph_rag_concurrency.py @@ -31,7 +31,6 @@ def _make_query( query = Query( rag=rag, - user="test-user", collection="test-collection", verbose=False, entity_limit=entity_limit, @@ -208,7 +207,6 @@ class TestBatchTripleQueries: assert calls[0].kwargs["p"] is None assert calls[0].kwargs["o"] is None assert calls[0].kwargs["limit"] == 15 - assert calls[0].kwargs["user"] == "test-user" assert calls[0].kwargs["collection"] == "test-collection" assert calls[0].kwargs["batch_size"] == 20 diff --git a/tests/unit/test_cores/test_knowledge_manager.py b/tests/unit/test_cores/test_knowledge_manager.py index 80c27fe8..d677b82f 100644 --- a/tests/unit/test_cores/test_knowledge_manager.py +++ b/tests/unit/test_cores/test_knowledge_manager.py @@ -28,10 +28,12 @@ def mock_flow_config(): """Mock flow configuration.""" mock_config = Mock() mock_config.flows = { - "test-flow": { - "interfaces": { - "triples-store": {"flow": "test-triples-queue"}, - "graph-embeddings-store": {"flow": "test-ge-queue"} + "test-user": { + "test-flow": { + "interfaces": { + "triples-store": {"flow": "test-triples-queue"}, + "graph-embeddings-store": {"flow": "test-ge-queue"} + } } } } @@ -43,7 +45,7 @@ def mock_flow_config(): def mock_request(): """Mock knowledge load request.""" request = Mock() - request.user = "test-user" + request.workspace = "test-user" request.id = "test-doc-id" request.collection = "test-collection" request.flow = "test-flow" @@ -71,7 +73,6 @@ def sample_triples(): return Triples( metadata=Metadata( id="test-doc-id", - user="test-user", collection="default", # This should be overridden ), triples=[ @@ -90,7 +91,6 @@ def sample_graph_embeddings(): return GraphEmbeddings( metadata=Metadata( id="test-doc-id", - user="test-user", collection="default", # This should be overridden ), entities=[ @@ -146,7 +146,6 @@ class TestKnowledgeManagerLoadCore: mock_triples_pub.send.assert_called_once() sent_triples = mock_triples_pub.send.call_args[0][1] assert sent_triples.metadata.collection == "test-collection" - assert sent_triples.metadata.user == "test-user" assert sent_triples.metadata.id == "test-doc-id" @pytest.mark.asyncio @@ -185,7 +184,6 @@ class TestKnowledgeManagerLoadCore: mock_ge_pub.send.assert_called_once() sent_ge = mock_ge_pub.send.call_args[0][1] assert sent_ge.metadata.collection == "test-collection" - assert sent_ge.metadata.user == "test-user" assert sent_ge.metadata.id == "test-doc-id" @pytest.mark.asyncio @@ -193,7 +191,7 @@ class TestKnowledgeManagerLoadCore: """Test that load_kg_core falls back to 'default' when request.collection is None.""" # Create request with None collection mock_request = Mock() - mock_request.user = "test-user" + mock_request.workspace = "test-user" mock_request.id = "test-doc-id" mock_request.collection = None # Should fall back to "default" mock_request.flow = "test-flow" @@ -269,7 +267,7 @@ class TestKnowledgeManagerLoadCore: """Test that load_kg_core validates flow configuration before processing.""" # Request with invalid flow mock_request = Mock() - mock_request.user = "test-user" + mock_request.workspace = "test-user" mock_request.id = "test-doc-id" mock_request.collection = "test-collection" mock_request.flow = "invalid-flow" # Not in mock_flow_config.flows @@ -297,7 +295,7 @@ class TestKnowledgeManagerLoadCore: # Test missing ID mock_request = Mock() - mock_request.user = "test-user" + mock_request.workspace = "test-user" mock_request.id = None # Missing mock_request.collection = "test-collection" mock_request.flow = "test-flow" @@ -323,7 +321,7 @@ class TestKnowledgeManagerOtherMethods: async def test_get_kg_core_preserves_collection_from_store(self, knowledge_manager, sample_triples): """Test that get_kg_core preserves collection field from stored data.""" mock_request = Mock() - mock_request.user = "test-user" + mock_request.workspace = "test-user" mock_request.id = "test-doc-id" mock_respond = AsyncMock() @@ -354,7 +352,7 @@ class TestKnowledgeManagerOtherMethods: async def test_list_kg_cores(self, knowledge_manager): """Test listing knowledge cores.""" mock_request = Mock() - mock_request.user = "test-user" + mock_request.workspace = "test-user" mock_respond = AsyncMock() @@ -376,7 +374,7 @@ class TestKnowledgeManagerOtherMethods: async def test_delete_kg_core(self, knowledge_manager): """Test deleting knowledge cores.""" mock_request = Mock() - mock_request.user = "test-user" + mock_request.workspace = "test-user" mock_request.id = "test-doc-id" mock_respond = AsyncMock() diff --git a/tests/unit/test_decoding/test_universal_processor.py b/tests/unit/test_decoding/test_universal_processor.py index 4daa9b68..36804860 100644 --- a/tests/unit/test_decoding/test_universal_processor.py +++ b/tests/unit/test_decoding/test_universal_processor.py @@ -237,7 +237,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): # Mock message with inline data content = b"# Document Title\nBody text content." - mock_metadata = Metadata(id="test-doc", user="testuser", + mock_metadata = Metadata(id="test-doc", collection="default") mock_document = Document( metadata=mock_metadata, @@ -294,7 +294,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): # Mock message content = b"fake pdf" - mock_metadata = Metadata(id="test-doc", user="testuser", + mock_metadata = Metadata(id="test-doc", collection="default") mock_document = Document( metadata=mock_metadata, @@ -345,7 +345,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): ] content = b"fake pdf" - mock_metadata = Metadata(id="test-doc", user="testuser", + mock_metadata = Metadata(id="test-doc", collection="default") mock_document = Document( metadata=mock_metadata, diff --git a/tests/unit/test_direct/test_milvus_collection_naming.py b/tests/unit/test_direct/test_milvus_collection_naming.py index d948caff..57c00a54 100644 --- a/tests/unit/test_direct/test_milvus_collection_naming.py +++ b/tests/unit/test_direct/test_milvus_collection_naming.py @@ -12,7 +12,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_basic(self): """Test basic collection name creation""" result = make_safe_collection_name( - user="test_user", + workspace="test_user", collection="test_collection", prefix="doc" ) @@ -21,7 +21,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_with_special_characters(self): """Test collection name creation with special characters that need sanitization""" result = make_safe_collection_name( - user="user@domain.com", + workspace="user@domain.com", collection="test-collection.v2", prefix="entity" ) @@ -30,7 +30,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_with_unicode(self): """Test collection name creation with Unicode characters""" result = make_safe_collection_name( - user="测试用户", + workspace="测试用户", collection="colección_española", prefix="doc" ) @@ -39,7 +39,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_with_spaces(self): """Test collection name creation with spaces""" result = make_safe_collection_name( - user="test user", + workspace="test user", collection="my test collection", prefix="entity" ) @@ -48,7 +48,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_with_multiple_consecutive_special_chars(self): """Test collection name creation with multiple consecutive special characters""" result = make_safe_collection_name( - user="user@@@domain!!!", + workspace="user@@@domain!!!", collection="test---collection...v2", prefix="doc" ) @@ -57,7 +57,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_with_leading_trailing_underscores(self): """Test collection name creation with leading/trailing special characters""" result = make_safe_collection_name( - user="__test_user__", + workspace="__test_user__", collection="@@test_collection##", prefix="entity" ) @@ -66,7 +66,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_empty_user(self): """Test collection name creation with empty user (should fallback to 'default')""" result = make_safe_collection_name( - user="", + workspace="", collection="test_collection", prefix="doc" ) @@ -75,7 +75,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_empty_collection(self): """Test collection name creation with empty collection (should fallback to 'default')""" result = make_safe_collection_name( - user="test_user", + workspace="test_user", collection="", prefix="doc" ) @@ -84,7 +84,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_both_empty(self): """Test collection name creation with both user and collection empty""" result = make_safe_collection_name( - user="", + workspace="", collection="", prefix="doc" ) @@ -93,7 +93,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_only_special_characters(self): """Test collection name creation with only special characters (should fallback to 'default')""" result = make_safe_collection_name( - user="@@@!!!", + workspace="@@@!!!", collection="---###", prefix="entity" ) @@ -102,7 +102,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_whitespace_only(self): """Test collection name creation with whitespace-only strings""" result = make_safe_collection_name( - user=" \n\t ", + workspace=" \n\t ", collection=" \r\n ", prefix="doc" ) @@ -111,7 +111,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_mixed_valid_invalid_chars(self): """Test collection name creation with mixed valid and invalid characters""" result = make_safe_collection_name( - user="user123@test", + workspace="user123@test", collection="coll_2023.v1", prefix="entity" ) @@ -147,7 +147,7 @@ class TestMilvusCollectionNaming: long_collection = "b" * 100 result = make_safe_collection_name( - user=long_user, + workspace=long_user, collection=long_collection, prefix="doc" ) @@ -159,7 +159,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_numeric_values(self): """Test collection name creation with numeric user/collection values""" result = make_safe_collection_name( - user="user123", + workspace="user123", collection="collection456", prefix="doc" ) @@ -168,7 +168,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_case_sensitivity(self): """Test that collection name creation preserves case""" result = make_safe_collection_name( - user="TestUser", + workspace="TestUser", collection="TestCollection", prefix="Doc" ) diff --git a/tests/unit/test_embeddings/test_document_embeddings_processor.py b/tests/unit/test_embeddings/test_document_embeddings_processor.py index 9cd93c4f..314d81c3 100644 --- a/tests/unit/test_embeddings/test_document_embeddings_processor.py +++ b/tests/unit/test_embeddings/test_document_embeddings_processor.py @@ -20,9 +20,8 @@ def processor(): ) -def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1", - user="test", collection="default"): - metadata = Metadata(id=doc_id, user=user, collection=collection) +def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1", collection="default"): + metadata = Metadata(id=doc_id, collection=collection) value = Chunk(metadata=metadata, chunk=chunk_text, document_id=doc_id) msg = MagicMock() msg.value.return_value = value @@ -127,7 +126,7 @@ class TestDocumentEmbeddingsProcessor: @pytest.mark.asyncio async def test_metadata_preserved(self, processor): """Output should carry the original metadata.""" - msg = _make_chunk_message(user="alice", collection="reports", doc_id="d1") + msg = _make_chunk_message(collection="reports", doc_id="d1") mock_request = AsyncMock(return_value=EmbeddingsResponse( error=None, vectors=[[0.0]] @@ -144,7 +143,6 @@ class TestDocumentEmbeddingsProcessor: await processor.on_message(msg, MagicMock(), flow) result = mock_output.send.call_args[0][0] - assert result.metadata.user == "alice" assert result.metadata.collection == "reports" assert result.metadata.id == "d1" diff --git a/tests/unit/test_embeddings/test_graph_embeddings_processor.py b/tests/unit/test_embeddings/test_graph_embeddings_processor.py index 5d535349..f3cec4d2 100644 --- a/tests/unit/test_embeddings/test_graph_embeddings_processor.py +++ b/tests/unit/test_embeddings/test_graph_embeddings_processor.py @@ -27,8 +27,8 @@ def _make_entity_context(name, context, chunk_id="chunk-1"): return MagicMock(entity=entity, context=context, chunk_id=chunk_id) -def _make_message(entities, doc_id="doc-1", user="test", collection="default"): - metadata = Metadata(id=doc_id, user=user, collection=collection) +def _make_message(entities, doc_id="doc-1", collection="default"): + metadata = Metadata(id=doc_id, collection=collection) value = EntityContexts(metadata=metadata, entities=entities) msg = MagicMock() msg.value.return_value = value @@ -151,7 +151,7 @@ class TestGraphEmbeddingsBatchProcessing: _make_entity_context(f"E{i}", f"ctx {i}") for i in range(5) ] - msg = _make_message(entities, doc_id="doc-42", user="alice", collection="main") + msg = _make_message(entities, doc_id="doc-42", collection="main") mock_embed = AsyncMock(return_value=[[0.0]] * 5) mock_output = AsyncMock() @@ -168,7 +168,6 @@ class TestGraphEmbeddingsBatchProcessing: for call in mock_output.send.call_args_list: result = call[0][0] assert result.metadata.id == "doc-42" - assert result.metadata.user == "alice" assert result.metadata.collection == "main" @pytest.mark.asyncio diff --git a/tests/unit/test_embeddings/test_row_embeddings_processor.py b/tests/unit/test_embeddings/test_row_embeddings_processor.py index 45a22e48..36ecd013 100644 --- a/tests/unit/test_embeddings/test_row_embeddings_processor.py +++ b/tests/unit/test_embeddings/test_row_embeddings_processor.py @@ -214,11 +214,11 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase): } } - await processor.on_schema_config(config_data, 1) + await processor.on_schema_config("default", config_data, 1) - assert 'customers' in processor.schemas - assert processor.schemas['customers'].name == 'customers' - assert len(processor.schemas['customers'].fields) == 3 + assert 'customers' in processor.schemas["default"] + assert processor.schemas["default"]['customers'].name == 'customers' + assert len(processor.schemas["default"]['customers'].fields) == 3 async def test_on_schema_config_handles_missing_type(self): """Test that missing schema type is handled gracefully""" @@ -236,9 +236,9 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase): 'other_type': {} } - await processor.on_schema_config(config_data, 1) + await processor.on_schema_config("default", config_data, 1) - assert processor.schemas == {} + assert processor.schemas.get("default", {}) == {} async def test_on_message_drops_unknown_collection(self): """Test that messages for unknown collections are dropped""" @@ -285,7 +285,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase): } processor = Processor(**config) - processor.known_collections[('test_user', 'test_collection')] = {} + processor.known_collections[('default', 'test_collection')] = {} # No schemas registered metadata = MagicMock() @@ -322,17 +322,19 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase): } processor = Processor(**config) - processor.known_collections[('test_user', 'test_collection')] = {} + processor.known_collections[('default', 'test_collection')] = {} # Set up schema - processor.schemas['customers'] = RowSchema( - name='customers', - description='Customer records', - fields=[ - Field(name='id', type='text', primary=True), - Field(name='name', type='text', indexed=True), - ] - ) + processor.schemas["default"] = { + 'customers': RowSchema( + name='customers', + description='Customer records', + fields=[ + Field(name='id', type='text', primary=True), + Field(name='name', type='text', indexed=True), + ] + ) + } metadata = MagicMock() metadata.user = 'test_user' @@ -372,6 +374,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase): return MagicMock() mock_flow = MagicMock(side_effect=flow_factory) + mock_flow.workspace = "default" await processor.on_message(mock_msg, MagicMock(), mock_flow) diff --git a/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py b/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py index cbc9a05a..09bb7988 100644 --- a/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py +++ b/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py @@ -34,11 +34,10 @@ def _make_defn(entity, definition): return {"entity": entity, "definition": definition} -def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", - user="user-1", collection="col-1", document_id=""): +def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""): chunk = Chunk( metadata=Metadata( - id=meta_id, root=root, user=user, collection=collection, + id=meta_id, root=root, collection=collection, ), chunk=text.encode("utf-8"), document_id=document_id, @@ -229,8 +228,7 @@ class TestMetadataPreservation: defs = [_make_defn("X", "def X")] flow, triples_pub, _, _ = _make_flow(defs) msg = _make_chunk_msg( - "text", meta_id="c-1", root="r-1", - user="u-1", collection="coll-1", + "text", meta_id="c-1", root="r-1", collection="coll-1", ) await proc.on_message(msg, MagicMock(), flow) @@ -238,7 +236,6 @@ class TestMetadataPreservation: for triples_msg in _sent_triples(triples_pub): assert triples_msg.metadata.id == "c-1" assert triples_msg.metadata.root == "r-1" - assert triples_msg.metadata.user == "u-1" assert triples_msg.metadata.collection == "coll-1" @pytest.mark.asyncio @@ -247,8 +244,7 @@ class TestMetadataPreservation: defs = [_make_defn("X", "def X")] flow, _, ecs_pub, _ = _make_flow(defs) msg = _make_chunk_msg( - "text", meta_id="c-2", root="r-2", - user="u-2", collection="coll-2", + "text", meta_id="c-2", root="r-2", collection="coll-2", ) await proc.on_message(msg, MagicMock(), flow) diff --git a/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py b/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py index d9861cf3..b85c9e00 100644 --- a/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py +++ b/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py @@ -38,12 +38,11 @@ def _make_rel(subject, predicate, obj, object_entity=True): } -def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", - user="user-1", collection="col-1", document_id=""): +def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""): """Build a mock message wrapping a Chunk.""" chunk = Chunk( metadata=Metadata( - id=meta_id, root=root, user=user, collection=collection, + id=meta_id, root=root, collection=collection, ), chunk=text.encode("utf-8"), document_id=document_id, @@ -189,8 +188,7 @@ class TestMetadataPreservation: rels = [_make_rel("X", "rel", "Y")] flow, pub, _ = _make_flow(rels) msg = _make_chunk_msg( - "text", meta_id="c-1", root="r-1", - user="u-1", collection="coll-1", + "text", meta_id="c-1", root="r-1", collection="coll-1", ) await proc.on_message(msg, MagicMock(), flow) @@ -198,7 +196,6 @@ class TestMetadataPreservation: for triples_msg in _sent_triples(pub): assert triples_msg.metadata.id == "c-1" assert triples_msg.metadata.root == "r-1" - assert triples_msg.metadata.user == "u-1" assert triples_msg.metadata.collection == "coll-1" diff --git a/tests/unit/test_gateway/test_config_receiver.py b/tests/unit/test_gateway/test_config_receiver.py index 90ba8d33..56e96178 100644 --- a/tests/unit/test_gateway/test_config_receiver.py +++ b/tests/unit/test_gateway/test_config_receiver.py @@ -17,6 +17,12 @@ _real_config_loader = ConfigReceiver.config_loader ConfigReceiver.config_loader = Mock() +def _notify(version, changes): + msg = Mock() + msg.value.return_value = Mock(version=version, changes=changes) + return msg + + class TestConfigReceiver: """Test cases for ConfigReceiver class""" @@ -47,98 +53,70 @@ class TestConfigReceiver: assert handler2 in config_receiver.flow_handlers @pytest.mark.asyncio - async def test_on_config_notify_new_version(self): - """Test on_config_notify triggers fetch for newer version""" + async def test_on_config_notify_new_version_fetches_per_workspace(self): + """Notify with newer version fetches each affected workspace.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) config_receiver.config_version = 1 - # Mock fetch_and_apply fetch_calls = [] - async def mock_fetch(**kwargs): - fetch_calls.append(kwargs) - config_receiver.fetch_and_apply = mock_fetch - # Create notify message with newer version - mock_msg = Mock() - mock_msg.value.return_value = Mock(version=2, types=["flow"]) + async def mock_fetch(workspace, retry=False): + fetch_calls.append(workspace) - await config_receiver.on_config_notify(mock_msg, None, None) + config_receiver.fetch_and_apply_workspace = mock_fetch - assert len(fetch_calls) == 1 + msg = _notify(2, {"flow": ["ws1", "ws2"]}) + await config_receiver.on_config_notify(msg, None, None) + + assert set(fetch_calls) == {"ws1", "ws2"} + assert config_receiver.config_version == 2 @pytest.mark.asyncio async def test_on_config_notify_old_version_ignored(self): - """Test on_config_notify ignores older versions""" + """Older-version notifies are ignored.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) config_receiver.config_version = 5 fetch_calls = [] - async def mock_fetch(**kwargs): - fetch_calls.append(kwargs) - config_receiver.fetch_and_apply = mock_fetch - # Create notify message with older version - mock_msg = Mock() - mock_msg.value.return_value = Mock(version=3, types=["flow"]) + async def mock_fetch(workspace, retry=False): + fetch_calls.append(workspace) - await config_receiver.on_config_notify(mock_msg, None, None) + config_receiver.fetch_and_apply_workspace = mock_fetch - assert len(fetch_calls) == 0 + msg = _notify(3, {"flow": ["ws1"]}) + await config_receiver.on_config_notify(msg, None, None) + + assert fetch_calls == [] @pytest.mark.asyncio async def test_on_config_notify_irrelevant_types_ignored(self): - """Test on_config_notify ignores types the gateway doesn't care about""" + """Notifies without flow changes advance version but skip fetch.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) config_receiver.config_version = 1 fetch_calls = [] - async def mock_fetch(**kwargs): - fetch_calls.append(kwargs) - config_receiver.fetch_and_apply = mock_fetch - # Create notify message with non-flow type - mock_msg = Mock() - mock_msg.value.return_value = Mock(version=2, types=["prompt"]) + async def mock_fetch(workspace, retry=False): + fetch_calls.append(workspace) - await config_receiver.on_config_notify(mock_msg, None, None) + config_receiver.fetch_and_apply_workspace = mock_fetch - # Version should be updated but no fetch - assert len(fetch_calls) == 0 + msg = _notify(2, {"prompt": ["ws1"]}) + await config_receiver.on_config_notify(msg, None, None) + + assert fetch_calls == [] assert config_receiver.config_version == 2 - @pytest.mark.asyncio - async def test_on_config_notify_flow_type_triggers_fetch(self): - """Test on_config_notify fetches for flow-related types""" - mock_backend = Mock() - config_receiver = ConfigReceiver(mock_backend) - config_receiver.config_version = 1 - - fetch_calls = [] - async def mock_fetch(**kwargs): - fetch_calls.append(kwargs) - config_receiver.fetch_and_apply = mock_fetch - - for type_name in ["flow"]: - fetch_calls.clear() - config_receiver.config_version = 1 - - mock_msg = Mock() - mock_msg.value.return_value = Mock(version=2, types=[type_name]) - - await config_receiver.on_config_notify(mock_msg, None, None) - - assert len(fetch_calls) == 1, f"Expected fetch for type {type_name}" - @pytest.mark.asyncio async def test_on_config_notify_exception_handling(self): - """Test on_config_notify handles exceptions gracefully""" + """on_config_notify swallows exceptions from message decode.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - # Create notify message that causes an exception mock_msg = Mock() mock_msg.value.side_effect = Exception("Test exception") @@ -146,19 +124,18 @@ class TestConfigReceiver: await config_receiver.on_config_notify(mock_msg, None, None) @pytest.mark.asyncio - async def test_fetch_and_apply_with_new_flows(self): - """Test fetch_and_apply starts new flows""" + async def test_fetch_and_apply_workspace_starts_new_flows(self): + """fetch_and_apply_workspace starts newly-configured flows.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - # Mock _create_config_client to return a mock client mock_resp = Mock() mock_resp.error = None mock_resp.version = 5 mock_resp.config = { "flow": { "flow1": '{"name": "test_flow_1"}', - "flow2": '{"name": "test_flow_2"}' + "flow2": '{"name": "test_flow_2"}', } } @@ -167,36 +144,39 @@ class TestConfigReceiver: config_receiver._create_config_client = Mock(return_value=mock_client) start_flow_calls = [] - async def mock_start_flow(id, flow): - start_flow_calls.append((id, flow)) + + async def mock_start_flow(workspace, id, flow): + start_flow_calls.append((workspace, id, flow)) + config_receiver.start_flow = mock_start_flow - await config_receiver.fetch_and_apply() + await config_receiver.fetch_and_apply_workspace("default") assert config_receiver.config_version == 5 - assert "flow1" in config_receiver.flows - assert "flow2" in config_receiver.flows + assert "flow1" in config_receiver.flows["default"] + assert "flow2" in config_receiver.flows["default"] assert len(start_flow_calls) == 2 + assert all(c[0] == "default" for c in start_flow_calls) @pytest.mark.asyncio - async def test_fetch_and_apply_with_removed_flows(self): - """Test fetch_and_apply stops removed flows""" + async def test_fetch_and_apply_workspace_stops_removed_flows(self): + """fetch_and_apply_workspace stops flows no longer configured.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - # Pre-populate with existing flows config_receiver.flows = { - "flow1": {"name": "test_flow_1"}, - "flow2": {"name": "test_flow_2"} + "default": { + "flow1": {"name": "test_flow_1"}, + "flow2": {"name": "test_flow_2"}, + } } - # Config now only has flow1 mock_resp = Mock() mock_resp.error = None mock_resp.version = 5 mock_resp.config = { "flow": { - "flow1": '{"name": "test_flow_1"}' + "flow1": '{"name": "test_flow_1"}', } } @@ -205,20 +185,22 @@ class TestConfigReceiver: config_receiver._create_config_client = Mock(return_value=mock_client) stop_flow_calls = [] - async def mock_stop_flow(id, flow): - stop_flow_calls.append((id, flow)) + + async def mock_stop_flow(workspace, id, flow): + stop_flow_calls.append((workspace, id, flow)) + config_receiver.stop_flow = mock_stop_flow - await config_receiver.fetch_and_apply() + await config_receiver.fetch_and_apply_workspace("default") - assert "flow1" in config_receiver.flows - assert "flow2" not in config_receiver.flows + assert "flow1" in config_receiver.flows["default"] + assert "flow2" not in config_receiver.flows["default"] assert len(stop_flow_calls) == 1 - assert stop_flow_calls[0][0] == "flow2" + assert stop_flow_calls[0][:2] == ("default", "flow2") @pytest.mark.asyncio - async def test_fetch_and_apply_with_no_flows(self): - """Test fetch_and_apply with empty config""" + async def test_fetch_and_apply_workspace_with_no_flows(self): + """Empty workspace config clears any local flow state.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) @@ -231,88 +213,100 @@ class TestConfigReceiver: mock_client.request.return_value = mock_resp config_receiver._create_config_client = Mock(return_value=mock_client) - await config_receiver.fetch_and_apply() + await config_receiver.fetch_and_apply_workspace("default") - assert config_receiver.flows == {} + assert config_receiver.flows.get("default", {}) == {} assert config_receiver.config_version == 1 @pytest.mark.asyncio async def test_start_flow_with_handlers(self): - """Test start_flow method with multiple handlers""" + """start_flow fans out to every registered flow handler.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) handler1 = Mock() - handler1.start_flow = Mock() + handler1.start_flow = AsyncMock() handler2 = Mock() - handler2.start_flow = Mock() + handler2.start_flow = AsyncMock() config_receiver.add_handler(handler1) config_receiver.add_handler(handler2) flow_data = {"name": "test_flow", "steps": []} - await config_receiver.start_flow("flow1", flow_data) + await config_receiver.start_flow("default", "flow1", flow_data) - handler1.start_flow.assert_called_once_with("flow1", flow_data) - handler2.start_flow.assert_called_once_with("flow1", flow_data) + handler1.start_flow.assert_awaited_once_with( + "default", "flow1", flow_data + ) + handler2.start_flow.assert_awaited_once_with( + "default", "flow1", flow_data + ) @pytest.mark.asyncio async def test_start_flow_with_handler_exception(self): - """Test start_flow method handles handler exceptions""" + """Handler exceptions in start_flow do not propagate.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) handler = Mock() - handler.start_flow = Mock(side_effect=Exception("Handler error")) + handler.start_flow = AsyncMock(side_effect=Exception("Handler error")) config_receiver.add_handler(handler) flow_data = {"name": "test_flow", "steps": []} # Should not raise - await config_receiver.start_flow("flow1", flow_data) + await config_receiver.start_flow("default", "flow1", flow_data) - handler.start_flow.assert_called_once_with("flow1", flow_data) + handler.start_flow.assert_awaited_once_with( + "default", "flow1", flow_data + ) @pytest.mark.asyncio async def test_stop_flow_with_handlers(self): - """Test stop_flow method with multiple handlers""" + """stop_flow fans out to every registered flow handler.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) handler1 = Mock() - handler1.stop_flow = Mock() + handler1.stop_flow = AsyncMock() handler2 = Mock() - handler2.stop_flow = Mock() + handler2.stop_flow = AsyncMock() config_receiver.add_handler(handler1) config_receiver.add_handler(handler2) flow_data = {"name": "test_flow", "steps": []} - await config_receiver.stop_flow("flow1", flow_data) + await config_receiver.stop_flow("default", "flow1", flow_data) - handler1.stop_flow.assert_called_once_with("flow1", flow_data) - handler2.stop_flow.assert_called_once_with("flow1", flow_data) + handler1.stop_flow.assert_awaited_once_with( + "default", "flow1", flow_data + ) + handler2.stop_flow.assert_awaited_once_with( + "default", "flow1", flow_data + ) @pytest.mark.asyncio async def test_stop_flow_with_handler_exception(self): - """Test stop_flow method handles handler exceptions""" + """Handler exceptions in stop_flow do not propagate.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) handler = Mock() - handler.stop_flow = Mock(side_effect=Exception("Handler error")) + handler.stop_flow = AsyncMock(side_effect=Exception("Handler error")) config_receiver.add_handler(handler) flow_data = {"name": "test_flow", "steps": []} # Should not raise - await config_receiver.stop_flow("flow1", flow_data) + await config_receiver.stop_flow("default", "flow1", flow_data) - handler.stop_flow.assert_called_once_with("flow1", flow_data) + handler.stop_flow.assert_awaited_once_with( + "default", "flow1", flow_data + ) @patch('asyncio.create_task') @pytest.mark.asyncio @@ -329,25 +323,25 @@ class TestConfigReceiver: mock_create_task.assert_called_once() @pytest.mark.asyncio - async def test_fetch_and_apply_mixed_flow_operations(self): - """Test fetch_and_apply with mixed add/remove operations""" + async def test_fetch_and_apply_workspace_mixed_flow_operations(self): + """fetch_and_apply_workspace adds, keeps and removes flows in one pass.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - # Pre-populate config_receiver.flows = { - "flow1": {"name": "test_flow_1"}, - "flow2": {"name": "test_flow_2"} + "default": { + "flow1": {"name": "test_flow_1"}, + "flow2": {"name": "test_flow_2"}, + } } - # Config removes flow1, keeps flow2, adds flow3 mock_resp = Mock() mock_resp.error = None mock_resp.version = 5 mock_resp.config = { "flow": { "flow2": '{"name": "test_flow_2"}', - "flow3": '{"name": "test_flow_3"}' + "flow3": '{"name": "test_flow_3"}', } } @@ -358,20 +352,22 @@ class TestConfigReceiver: start_calls = [] stop_calls = [] - async def mock_start_flow(id, flow): - start_calls.append((id, flow)) - async def mock_stop_flow(id, flow): - stop_calls.append((id, flow)) + async def mock_start_flow(workspace, id, flow): + start_calls.append((workspace, id, flow)) + + async def mock_stop_flow(workspace, id, flow): + stop_calls.append((workspace, id, flow)) config_receiver.start_flow = mock_start_flow config_receiver.stop_flow = mock_stop_flow - await config_receiver.fetch_and_apply() + await config_receiver.fetch_and_apply_workspace("default") - assert "flow1" not in config_receiver.flows - assert "flow2" in config_receiver.flows - assert "flow3" in config_receiver.flows + ws_flows = config_receiver.flows["default"] + assert "flow1" not in ws_flows + assert "flow2" in ws_flows + assert "flow3" in ws_flows assert len(start_calls) == 1 - assert start_calls[0][0] == "flow3" + assert start_calls[0][:2] == ("default", "flow3") assert len(stop_calls) == 1 - assert stop_calls[0][0] == "flow1" + assert stop_calls[0][:2] == ("default", "flow1") diff --git a/tests/unit/test_gateway/test_core_import_export_roundtrip.py b/tests/unit/test_gateway/test_core_import_export_roundtrip.py index 843a2b7b..cb2554ee 100644 --- a/tests/unit/test_gateway/test_core_import_export_roundtrip.py +++ b/tests/unit/test_gateway/test_core_import_export_roundtrip.py @@ -36,7 +36,6 @@ def _ge_response_dict(): "metadata": { "id": "doc-1", "root": "", - "user": "alice", "collection": "testcoll", }, "entities": [ @@ -59,7 +58,6 @@ def _triples_response_dict(): "metadata": { "id": "doc-1", "root": "", - "user": "alice", "collection": "testcoll", }, "triples": [ @@ -73,9 +71,9 @@ def _triples_response_dict(): } -def _make_request(id_="doc-1", user="alice"): +def _make_request(id_="doc-1", workspace="alice"): request = Mock() - request.query = {"id": id_, "user": user} + request.query = {"id": id_, "workspace": workspace} return request @@ -149,12 +147,8 @@ class TestCoreExportWireFormat: msg_type, payload = items[0] assert msg_type == "ge" - # Metadata envelope: only id/user/collection — no stale `m["m"]`. - assert payload["m"] == { - "i": "doc-1", - "u": "alice", - "c": "testcoll", - } + # Metadata envelope: only id/collection — no stale `m["m"]`. + assert payload["m"] == {"i": "doc-1", "c": "testcoll"} # Entities: each carries the *singular* `v` and the term envelope assert len(payload["e"]) == 2 @@ -202,11 +196,7 @@ class TestCoreExportWireFormat: msg_type, payload = items[0] assert msg_type == "t" - assert payload["m"] == { - "i": "doc-1", - "u": "alice", - "c": "testcoll", - } + assert payload["m"] == {"i": "doc-1", "c": "testcoll"} assert len(payload["t"]) == 1 @@ -240,7 +230,7 @@ class TestCoreImportWireFormat: payload = msgpack.packb(( "ge", { - "m": {"i": "doc-1", "u": "alice", "c": "testcoll"}, + "m": {"i": "doc-1", "c": "testcoll"}, "e": [ { "e": {"t": "i", "i": "http://example.org/alice"}, @@ -266,7 +256,7 @@ class TestCoreImportWireFormat: req = captured[0] assert req["operation"] == "put-kg-core" - assert req["user"] == "alice" + assert req["workspace"] == "alice" assert req["id"] == "doc-1" ge = req["graph-embeddings"] @@ -275,7 +265,6 @@ class TestCoreImportWireFormat: assert "metadata" not in ge["metadata"] assert ge["metadata"] == { "id": "doc-1", - "user": "alice", "collection": "default", } @@ -302,7 +291,7 @@ class TestCoreImportWireFormat: payload = msgpack.packb(( "t", { - "m": {"i": "doc-1", "u": "alice", "c": "testcoll"}, + "m": {"i": "doc-1", "c": "testcoll"}, "t": [ { "s": {"t": "i", "i": "http://example.org/alice"}, @@ -407,11 +396,10 @@ class TestCoreImportExportRoundTrip: original = _ge_response_dict()["graph-embeddings"] ge = req["graph-embeddings"] - # The import side overrides id/user from the URL query (intentional), + # The import side overrides id from the URL query (intentional), # so we only round-trip the entity payload itself. assert ge["metadata"]["id"] == original["metadata"]["id"] - assert ge["metadata"]["user"] == original["metadata"]["user"] - + assert len(ge["entities"]) == len(original["entities"]) for got, want in zip(ge["entities"], original["entities"]): assert got["vector"] == want["vector"] diff --git a/tests/unit/test_gateway/test_dispatch_manager.py b/tests/unit/test_gateway/test_dispatch_manager.py index 4ebcb5b9..f091a46d 100644 --- a/tests/unit/test_gateway/test_dispatch_manager.py +++ b/tests/unit/test_gateway/test_dispatch_manager.py @@ -72,10 +72,10 @@ class TestDispatcherManager: flow_data = {"name": "test_flow", "steps": []} - await manager.start_flow("flow1", flow_data) - - assert "flow1" in manager.flows - assert manager.flows["flow1"] == flow_data + await manager.start_flow("default", "flow1", flow_data) + + assert ("default", "flow1") in manager.flows + assert manager.flows[("default", "flow1")] == flow_data @pytest.mark.asyncio async def test_stop_flow(self): @@ -86,11 +86,11 @@ class TestDispatcherManager: # Pre-populate with a flow flow_data = {"name": "test_flow", "steps": []} - manager.flows["flow1"] = flow_data - - await manager.stop_flow("flow1", flow_data) - - assert "flow1" not in manager.flows + manager.flows[("default", "flow1")] = flow_data + + await manager.stop_flow("default", "flow1", flow_data) + + assert ("default", "flow1") not in manager.flows def test_dispatch_global_service_returns_wrapper(self): """Test dispatch_global_service returns DispatcherWrapper""" @@ -275,12 +275,12 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "triples-store": {"flow": "test_queue"} } } - + with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers, \ patch('uuid.uuid4') as mock_uuid: mock_uuid.return_value = "test-uuid" @@ -290,7 +290,7 @@ class TestDispatcherManager: mock_dispatcher_class.return_value = mock_dispatcher mock_dispatchers.__getitem__.return_value = mock_dispatcher_class mock_dispatchers.__contains__.return_value = True - + params = {"flow": "test_flow", "kind": "triples"} result = await manager.process_flow_import("ws", "running", params) @@ -326,12 +326,12 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "triples-store": {"flow": "test_queue"} } } - + with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers: mock_dispatchers.__contains__.return_value = False @@ -348,12 +348,12 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "triples-store": {"flow": "test_queue"} } } - + with patch('trustgraph.gateway.dispatch.manager.export_dispatchers') as mock_dispatchers, \ patch('uuid.uuid4') as mock_uuid: mock_uuid.return_value = "test-uuid" @@ -404,7 +404,7 @@ class TestDispatcherManager: params = {"flow": "test_flow", "kind": "agent"} result = await manager.process_flow_service("data", "responder", params) - manager.invoke_flow_service.assert_called_once_with("data", "responder", "test_flow", "agent") + manager.invoke_flow_service.assert_called_once_with("data", "responder", "default", "test_flow", "agent") assert result == "flow_result" @pytest.mark.asyncio @@ -415,14 +415,14 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Add flow to the flows dictionary - manager.flows["test_flow"] = {"services": {"agent": {}}} - + manager.flows[("default", "test_flow")] = {"services": {"agent": {}}} + # Pre-populate with existing dispatcher mock_dispatcher = Mock() mock_dispatcher.process = AsyncMock(return_value="cached_result") - manager.dispatchers[("test_flow", "agent")] = mock_dispatcher - - result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent") + manager.dispatchers[("default", "test_flow", "agent")] = mock_dispatcher + + result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent") mock_dispatcher.process.assert_called_once_with("data", "responder") assert result == "cached_result" @@ -435,7 +435,7 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "agent": { "request": "agent_request_queue", @@ -443,7 +443,7 @@ class TestDispatcherManager: } } } - + with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_dispatchers: mock_dispatcher_class = Mock() mock_dispatcher = Mock() @@ -452,23 +452,23 @@ class TestDispatcherManager: mock_dispatcher_class.return_value = mock_dispatcher mock_dispatchers.__getitem__.return_value = mock_dispatcher_class mock_dispatchers.__contains__.return_value = True - - result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent") - + + result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent") + # Verify dispatcher was created with correct parameters mock_dispatcher_class.assert_called_once_with( backend=mock_backend, request_queue="agent_request_queue", response_queue="agent_response_queue", timeout=120, - consumer="api-gateway-test_flow-agent-request", - subscriber="api-gateway-test_flow-agent-request" + consumer="api-gateway-default-test_flow-agent-request", + subscriber="api-gateway-default-test_flow-agent-request" ) mock_dispatcher.start.assert_called_once() mock_dispatcher.process.assert_called_once_with("data", "responder") - + # Verify dispatcher was cached - assert manager.dispatchers[("test_flow", "agent")] == mock_dispatcher + assert manager.dispatchers[("default", "test_flow", "agent")] == mock_dispatcher assert result == "new_result" @pytest.mark.asyncio @@ -479,26 +479,26 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "text-load": {"flow": "text_load_queue"} } } - + with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers, \ patch('trustgraph.gateway.dispatch.manager.sender_dispatchers') as mock_sender_dispatchers: mock_rr_dispatchers.__contains__.return_value = False mock_sender_dispatchers.__contains__.return_value = True - + mock_dispatcher_class = Mock() mock_dispatcher = Mock() mock_dispatcher.start = AsyncMock() mock_dispatcher.process = AsyncMock(return_value="sender_result") mock_dispatcher_class.return_value = mock_dispatcher mock_sender_dispatchers.__getitem__.return_value = mock_dispatcher_class - - result = await manager.invoke_flow_service("data", "responder", "test_flow", "text-load") - + + result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "text-load") + # Verify dispatcher was created with correct parameters mock_dispatcher_class.assert_called_once_with( backend=mock_backend, @@ -506,9 +506,9 @@ class TestDispatcherManager: ) mock_dispatcher.start.assert_called_once() mock_dispatcher.process.assert_called_once_with("data", "responder") - + # Verify dispatcher was cached - assert manager.dispatchers[("test_flow", "text-load")] == mock_dispatcher + assert manager.dispatchers[("default", "test_flow", "text-load")] == mock_dispatcher assert result == "sender_result" @pytest.mark.asyncio @@ -519,7 +519,7 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) with pytest.raises(RuntimeError, match="Invalid flow"): - await manager.invoke_flow_service("data", "responder", "invalid_flow", "agent") + await manager.invoke_flow_service("data", "responder", "default", "invalid_flow", "agent") @pytest.mark.asyncio async def test_invoke_flow_service_unsupported_kind_by_flow(self): @@ -529,14 +529,14 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow without agent interface - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "text-completion": {"request": "req", "response": "resp"} } } - + with pytest.raises(RuntimeError, match="This kind not supported by flow"): - await manager.invoke_flow_service("data", "responder", "test_flow", "agent") + await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent") @pytest.mark.asyncio async def test_invoke_flow_service_invalid_kind(self): @@ -546,7 +546,7 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow with interface but unsupported kind - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "invalid-kind": {"request": "req", "response": "resp"} } @@ -558,7 +558,7 @@ class TestDispatcherManager: mock_sender_dispatchers.__contains__.return_value = False with pytest.raises(RuntimeError, match="Invalid kind"): - await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind") + await manager.invoke_flow_service("data", "responder", "default", "test_flow", "invalid-kind") @pytest.mark.asyncio async def test_invoke_global_service_concurrent_calls_create_single_dispatcher(self): @@ -608,7 +608,7 @@ class TestDispatcherManager: mock_config_receiver = Mock() manager = DispatcherManager(mock_backend, mock_config_receiver) - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "agent": { "request": "agent_request_queue", @@ -630,7 +630,7 @@ class TestDispatcherManager: mock_rr_dispatchers.__contains__.return_value = True results = await asyncio.gather(*[ - manager.invoke_flow_service("data", "responder", "test_flow", "agent") + manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent") for _ in range(5) ]) @@ -638,5 +638,5 @@ class TestDispatcherManager: "Dispatcher class instantiated more than once — duplicate consumer bug" ) assert mock_dispatcher.start.call_count == 1 - assert manager.dispatchers[("test_flow", "agent")] is mock_dispatcher + assert manager.dispatchers[("default", "test_flow", "agent")] is mock_dispatcher assert all(r == "result" for r in results) \ No newline at end of file diff --git a/tests/unit/test_gateway/test_entity_contexts_import_dispatcher.py b/tests/unit/test_gateway/test_entity_contexts_import_dispatcher.py index 8eddeba9..4ecfce08 100644 --- a/tests/unit/test_gateway/test_entity_contexts_import_dispatcher.py +++ b/tests/unit/test_gateway/test_entity_contexts_import_dispatcher.py @@ -186,7 +186,6 @@ class TestEntityContextsImportMessageProcessing: assert isinstance(sent, EntityContexts) assert isinstance(sent.metadata, Metadata) assert sent.metadata.id == "doc-123" - assert sent.metadata.user == "testuser" assert sent.metadata.collection == "testcollection" assert len(sent.entities) == 2 diff --git a/tests/unit/test_gateway/test_graph_embeddings_import_dispatcher.py b/tests/unit/test_gateway/test_graph_embeddings_import_dispatcher.py index fa277178..09a2d510 100644 --- a/tests/unit/test_gateway/test_graph_embeddings_import_dispatcher.py +++ b/tests/unit/test_gateway/test_graph_embeddings_import_dispatcher.py @@ -188,7 +188,6 @@ class TestGraphEmbeddingsImportMessageProcessing: assert isinstance(sent, GraphEmbeddings) assert isinstance(sent.metadata, Metadata) assert sent.metadata.id == "doc-123" - assert sent.metadata.user == "testuser" assert sent.metadata.collection == "testcollection" assert len(sent.entities) == 2 diff --git a/tests/unit/test_gateway/test_rows_import_dispatcher.py b/tests/unit/test_gateway/test_rows_import_dispatcher.py index f029e9a2..0134dd39 100644 --- a/tests/unit/test_gateway/test_rows_import_dispatcher.py +++ b/tests/unit/test_gateway/test_rows_import_dispatcher.py @@ -235,7 +235,6 @@ class TestRowsImportMessageProcessing: # Check metadata assert sent_object.metadata.id == "obj-123" - assert sent_object.metadata.user == "testuser" assert sent_object.metadata.collection == "testcollection" @patch('trustgraph.gateway.dispatch.rows_import.Publisher') diff --git a/tests/unit/test_gateway/test_text_document_translator.py b/tests/unit/test_gateway/test_text_document_translator.py index 84eedefc..da44e798 100644 --- a/tests/unit/test_gateway/test_text_document_translator.py +++ b/tests/unit/test_gateway/test_text_document_translator.py @@ -23,7 +23,6 @@ class TestTextDocumentTranslator: ) assert msg.metadata.id == "doc-1" - assert msg.metadata.user == "alice" assert msg.metadata.collection == "research" assert msg.text == payload.encode("utf-8") diff --git a/tests/unit/test_knowledge_graph/conftest.py b/tests/unit/test_knowledge_graph/conftest.py index 8e8d9e43..d0c47784 100644 --- a/tests/unit/test_knowledge_graph/conftest.py +++ b/tests/unit/test_knowledge_graph/conftest.py @@ -29,10 +29,9 @@ class Triple: self.o = o class Metadata: - def __init__(self, id, user, collection, root=""): + def __init__(self, id, collection, root=""): self.id = id self.root = root - self.user = user self.collection = collection class Triples: @@ -108,7 +107,6 @@ def sample_triples(sample_triple): """Sample Triples batch object""" metadata = Metadata( id="test-doc-123", - user="test_user", collection="test_collection", ) @@ -123,7 +121,6 @@ def sample_chunk(): """Sample text chunk for processing""" metadata = Metadata( id="test-chunk-456", - user="test_user", collection="test_collection", ) diff --git a/tests/unit/test_knowledge_graph/test_agent_extraction.py b/tests/unit/test_knowledge_graph/test_agent_extraction.py index ec985e3b..3c40a2a2 100644 --- a/tests/unit/test_knowledge_graph/test_agent_extraction.py +++ b/tests/unit/test_knowledge_graph/test_agent_extraction.py @@ -322,7 +322,6 @@ This is not JSON at all assert isinstance(sent_triples, Triples) # Check metadata fields individually since implementation creates new Metadata object assert sent_triples.metadata.id == sample_metadata.id - assert sent_triples.metadata.user == sample_metadata.user assert sent_triples.metadata.collection == sample_metadata.collection assert len(sent_triples.triples) == 1 assert sent_triples.triples[0].s.iri == "test:subject" @@ -346,7 +345,6 @@ This is not JSON at all assert isinstance(sent_contexts, EntityContexts) # Check metadata fields individually since implementation creates new Metadata object assert sent_contexts.metadata.id == sample_metadata.id - assert sent_contexts.metadata.user == sample_metadata.user assert sent_contexts.metadata.collection == sample_metadata.collection assert len(sent_contexts.entities) == 1 assert sent_contexts.entities[0].entity.iri == "test:entity" diff --git a/tests/unit/test_knowledge_graph/test_object_extraction_logic.py b/tests/unit/test_knowledge_graph/test_object_extraction_logic.py index f82e4cc8..2d758481 100644 --- a/tests/unit/test_knowledge_graph/test_object_extraction_logic.py +++ b/tests/unit/test_knowledge_graph/test_object_extraction_logic.py @@ -311,8 +311,7 @@ class TestObjectExtractionBusinessLogic: """Test ExtractedObject creation and properties""" # Arrange metadata = Metadata( - id="test-extraction-001", - user="test_user", + id="test-extraction-001", collection="test_collection", ) @@ -337,7 +336,6 @@ class TestObjectExtractionBusinessLogic: assert extracted_obj.values[0]["customer_id"] == "CUST001" assert extracted_obj.confidence == 0.95 assert "John Doe" in extracted_obj.source_span - assert extracted_obj.metadata.user == "test_user" def test_config_parsing_error_handling(self): """Test configuration parsing with invalid JSON""" diff --git a/tests/unit/test_knowledge_graph/test_triple_construction.py b/tests/unit/test_knowledge_graph/test_triple_construction.py index e45c69aa..db13a7c1 100644 --- a/tests/unit/test_knowledge_graph/test_triple_construction.py +++ b/tests/unit/test_knowledge_graph/test_triple_construction.py @@ -371,7 +371,6 @@ class TestTripleConstructionLogic: metadata = Metadata( id="test-doc-123", - user="test_user", collection="test_collection", ) @@ -384,7 +383,6 @@ class TestTripleConstructionLogic: # Assert assert isinstance(triples_batch, Triples) assert triples_batch.metadata.id == "test-doc-123" - assert triples_batch.metadata.user == "test_user" assert triples_batch.metadata.collection == "test_collection" assert len(triples_batch.triples) == 2 diff --git a/tests/unit/test_librarian/test_chunked_upload.py b/tests/unit/test_librarian/test_chunked_upload.py index eef83e1e..7e7be480 100644 --- a/tests/unit/test_librarian/test_chunked_upload.py +++ b/tests/unit/test_librarian/test_chunked_upload.py @@ -33,12 +33,12 @@ def _make_librarian(min_chunk_size=1): def _make_doc_metadata( - doc_id="doc-1", kind="application/pdf", user="alice", title="Test Doc" + doc_id="doc-1", kind="application/pdf", workspace="alice", title="Test Doc" ): meta = MagicMock() meta.id = doc_id meta.kind = kind - meta.user = user + meta.workspace = workspace meta.title = title meta.time = 1700000000 meta.comments = "" @@ -47,27 +47,27 @@ def _make_doc_metadata( def _make_begin_request( - doc_id="doc-1", kind="application/pdf", user="alice", + doc_id="doc-1", kind="application/pdf", workspace="alice", total_size=10_000_000, chunk_size=0 ): req = MagicMock() - req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, user=user) + req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, workspace=workspace) req.total_size = total_size req.chunk_size = chunk_size return req -def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, user="alice", content=b"data"): +def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, workspace="alice", content=b"data"): req = MagicMock() req.upload_id = upload_id req.chunk_index = chunk_index - req.user = user + req.workspace = workspace req.content = base64.b64encode(content) return req def _make_session( - user="alice", total_chunks=5, chunk_size=2_000_000, + workspace="alice", total_chunks=5, chunk_size=2_000_000, total_size=10_000_000, chunks_received=None, object_id="obj-1", s3_upload_id="s3-up-1", document_metadata=None, document_id="doc-1", ): @@ -76,11 +76,11 @@ def _make_session( if document_metadata is None: document_metadata = json.dumps({ "id": document_id, "kind": "application/pdf", - "user": user, "title": "Test", "time": 1700000000, + "workspace": workspace, "title": "Test", "time": 1700000000, "comments": "", "tags": [], }) return { - "user": user, + "workspace": workspace, "total_chunks": total_chunks, "chunk_size": chunk_size, "total_size": total_size, @@ -259,10 +259,10 @@ class TestUploadChunk: @pytest.mark.asyncio async def test_rejects_wrong_user(self): lib = _make_librarian() - session = _make_session(user="alice") + session = _make_session(workspace="alice") lib.table_store.get_upload_session.return_value = session - req = _make_upload_chunk_request(user="bob") + req = _make_upload_chunk_request(workspace="bob") with pytest.raises(RequestError, match="Not authorized"): await lib.upload_chunk(req) @@ -353,7 +353,7 @@ class TestCompleteUpload: req = MagicMock() req.upload_id = "up-1" - req.user = "alice" + req.workspace = "alice" resp = await lib.complete_upload(req) @@ -375,7 +375,7 @@ class TestCompleteUpload: req = MagicMock() req.upload_id = "up-1" - req.user = "alice" + req.workspace = "alice" await lib.complete_upload(req) @@ -394,7 +394,7 @@ class TestCompleteUpload: req = MagicMock() req.upload_id = "up-1" - req.user = "alice" + req.workspace = "alice" with pytest.raises(RequestError, match="Missing chunks"): await lib.complete_upload(req) @@ -406,7 +406,7 @@ class TestCompleteUpload: req = MagicMock() req.upload_id = "up-gone" - req.user = "alice" + req.workspace = "alice" with pytest.raises(RequestError, match="not found"): await lib.complete_upload(req) @@ -414,12 +414,12 @@ class TestCompleteUpload: @pytest.mark.asyncio async def test_rejects_wrong_user(self): lib = _make_librarian() - session = _make_session(user="alice") + session = _make_session(workspace="alice") lib.table_store.get_upload_session.return_value = session req = MagicMock() req.upload_id = "up-1" - req.user = "bob" + req.workspace = "bob" with pytest.raises(RequestError, match="Not authorized"): await lib.complete_upload(req) @@ -439,7 +439,7 @@ class TestAbortUpload: req = MagicMock() req.upload_id = "up-1" - req.user = "alice" + req.workspace = "alice" resp = await lib.abort_upload(req) @@ -456,7 +456,7 @@ class TestAbortUpload: req = MagicMock() req.upload_id = "up-gone" - req.user = "alice" + req.workspace = "alice" with pytest.raises(RequestError, match="not found"): await lib.abort_upload(req) @@ -464,12 +464,12 @@ class TestAbortUpload: @pytest.mark.asyncio async def test_rejects_wrong_user(self): lib = _make_librarian() - session = _make_session(user="alice") + session = _make_session(workspace="alice") lib.table_store.get_upload_session.return_value = session req = MagicMock() req.upload_id = "up-1" - req.user = "bob" + req.workspace = "bob" with pytest.raises(RequestError, match="Not authorized"): await lib.abort_upload(req) @@ -492,7 +492,7 @@ class TestGetUploadStatus: req = MagicMock() req.upload_id = "up-1" - req.user = "alice" + req.workspace = "alice" resp = await lib.get_upload_status(req) @@ -510,7 +510,7 @@ class TestGetUploadStatus: req = MagicMock() req.upload_id = "up-expired" - req.user = "alice" + req.workspace = "alice" resp = await lib.get_upload_status(req) @@ -527,7 +527,7 @@ class TestGetUploadStatus: req = MagicMock() req.upload_id = "up-1" - req.user = "alice" + req.workspace = "alice" resp = await lib.get_upload_status(req) @@ -539,12 +539,12 @@ class TestGetUploadStatus: @pytest.mark.asyncio async def test_rejects_wrong_user(self): lib = _make_librarian() - session = _make_session(user="alice") + session = _make_session(workspace="alice") lib.table_store.get_upload_session.return_value = session req = MagicMock() req.upload_id = "up-1" - req.user = "bob" + req.workspace = "bob" with pytest.raises(RequestError, match="Not authorized"): await lib.get_upload_status(req) @@ -564,7 +564,7 @@ class TestStreamDocument: lib.blob_store.get_range = AsyncMock(return_value=b"x" * 2000) req = MagicMock() - req.user = "alice" + req.workspace = "alice" req.document_id = "doc-1" req.chunk_size = 2000 @@ -587,7 +587,7 @@ class TestStreamDocument: lib.blob_store.get_range = AsyncMock(return_value=b"x" * 500) req = MagicMock() - req.user = "alice" + req.workspace = "alice" req.document_id = "doc-1" req.chunk_size = 2000 @@ -608,7 +608,7 @@ class TestStreamDocument: lib.blob_store.get_range = AsyncMock(return_value=b"x" * 100) req = MagicMock() - req.user = "alice" + req.workspace = "alice" req.document_id = "doc-1" req.chunk_size = 2000 @@ -630,7 +630,7 @@ class TestStreamDocument: lib.blob_store.get_range = AsyncMock(return_value=b"x") req = MagicMock() - req.user = "alice" + req.workspace = "alice" req.document_id = "doc-1" req.chunk_size = 0 # Should use default 1MB @@ -649,7 +649,7 @@ class TestStreamDocument: lib.blob_store.get_range = AsyncMock(return_value=raw) req = MagicMock() - req.user = "alice" + req.workspace = "alice" req.document_id = "doc-1" req.chunk_size = 1000 @@ -666,7 +666,7 @@ class TestStreamDocument: lib.blob_store.get_size = AsyncMock(return_value=5000) req = MagicMock() - req.user = "alice" + req.workspace = "alice" req.document_id = "doc-1" req.chunk_size = 512 @@ -698,7 +698,7 @@ class TestListUploads: ] req = MagicMock() - req.user = "alice" + req.workspace = "alice" resp = await lib.list_uploads(req) @@ -713,7 +713,7 @@ class TestListUploads: lib.table_store.list_upload_sessions.return_value = [] req = MagicMock() - req.user = "alice" + req.workspace = "alice" resp = await lib.list_uploads(req) diff --git a/tests/unit/test_provenance/test_dag_structure.py b/tests/unit/test_provenance/test_dag_structure.py index 184560f0..e65ef2e3 100644 --- a/tests/unit/test_provenance/test_dag_structure.py +++ b/tests/unit/test_provenance/test_dag_structure.py @@ -239,7 +239,7 @@ def _make_processor(tools=None): agent = MagicMock() agent.tools = tools or {} agent.additional_context = "" - processor.agent = agent + processor.agents = {"default": agent} processor.aggregator = MagicMock() return processor @@ -254,6 +254,7 @@ def _make_flow(): return producers[name] flow = MagicMock(side_effect=factory) + flow.workspace = "default" return flow @@ -299,7 +300,7 @@ class TestAgentReactDagStructure: service.max_iterations = 10 service.save_answer_content = AsyncMock() service.provenance_session_uri = processor.provenance_session_uri - service.agent = processor.agent + service.agents = processor.agents service.aggregator = processor.aggregator service.react_pattern = ReactPattern(service) @@ -344,7 +345,6 @@ class TestAgentReactDagStructure: request1 = AgentRequest( question="What is 6x7?", - user="testuser", collection="default", streaming=False, session_id=session_id, @@ -433,7 +433,7 @@ class TestAgentPlanDagStructure: service.max_iterations = 10 service.save_answer_content = AsyncMock() service.provenance_session_uri = processor.provenance_session_uri - service.agent = processor.agent + service.agents = processor.agents service.aggregator = processor.aggregator service.react_pattern = ReactPattern(service) @@ -480,7 +480,6 @@ class TestAgentPlanDagStructure: # Iteration 1: planning request1 = AgentRequest( question="Test?", - user="testuser", collection="default", streaming=False, session_id=session_id, @@ -537,7 +536,7 @@ class TestAgentSupervisorDagStructure: service.max_iterations = 10 service.save_answer_content = AsyncMock() service.provenance_session_uri = processor.provenance_session_uri - service.agent = processor.agent + service.agents = processor.agents service.aggregator = processor.aggregator service.react_pattern = ReactPattern(service) @@ -563,7 +562,6 @@ class TestAgentSupervisorDagStructure: request = AgentRequest( question="Research quantum computing", - user="testuser", collection="default", streaming=False, session_id=str(uuid.uuid4()), diff --git a/tests/unit/test_query/test_doc_embeddings_milvus_query.py b/tests/unit/test_query/test_doc_embeddings_milvus_query.py index 1cddce97..56ccc398 100644 --- a/tests/unit/test_query/test_doc_embeddings_milvus_query.py +++ b/tests/unit/test_query/test_doc_embeddings_milvus_query.py @@ -31,7 +31,6 @@ class TestMilvusDocEmbeddingsQueryProcessor: def mock_query_request(self): """Create a mock query request for testing""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=10 @@ -69,7 +68,6 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_single_vector(self, processor): """Test querying document embeddings with a single vector""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=5 @@ -83,7 +81,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify search was called with correct parameters including user/collection processor.vecstore.search.assert_called_once_with( @@ -101,7 +99,6 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_longer_vector(self, processor): """Test querying document embeddings with a longer vector""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=3 @@ -115,7 +112,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify search was called once with the full vector processor.vecstore.search.assert_called_once_with( @@ -133,7 +130,6 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_with_limit(self, processor): """Test querying document embeddings respects limit parameter""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=2 @@ -148,7 +144,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify search was called with the specified limit processor.vecstore.search.assert_called_once_with( @@ -162,13 +158,12 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_empty_vectors(self, processor): """Test querying document embeddings with empty vectors list""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[], limit=5 ) - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify no search was called processor.vecstore.search.assert_not_called() @@ -180,7 +175,6 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_empty_search_results(self, processor): """Test querying document embeddings with empty search results""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=5 @@ -189,7 +183,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: # Mock empty search results processor.vecstore.search.return_value = [] - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify search was called processor.vecstore.search.assert_called_once_with( @@ -203,7 +197,6 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_unicode_documents(self, processor): """Test querying document embeddings with Unicode document content""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=5 @@ -217,7 +210,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify Unicode content is preserved in ChunkMatch objects assert len(result) == 3 @@ -230,7 +223,6 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_large_documents(self, processor): """Test querying document embeddings with large document content""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=5 @@ -244,7 +236,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify large content is preserved in ChunkMatch objects assert len(result) == 2 @@ -256,7 +248,6 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_special_characters(self, processor): """Test querying document embeddings with special characters in documents""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=5 @@ -270,7 +261,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify special characters are preserved in ChunkMatch objects assert len(result) == 3 @@ -283,13 +274,12 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_zero_limit(self, processor): """Test querying document embeddings with zero limit""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=0 ) - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify no search was called (optimization for zero limit) processor.vecstore.search.assert_not_called() @@ -301,13 +291,12 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_negative_limit(self, processor): """Test querying document embeddings with negative limit""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=-1 ) - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify no search was called (optimization for negative limit) processor.vecstore.search.assert_not_called() @@ -319,7 +308,6 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_exception_handling(self, processor): """Test exception handling during query processing""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=5 @@ -330,13 +318,12 @@ class TestMilvusDocEmbeddingsQueryProcessor: # Should raise the exception with pytest.raises(Exception, match="Milvus connection failed"): - await processor.query_document_embeddings(query) + await processor.query_document_embeddings('test_user', query) @pytest.mark.asyncio async def test_query_document_embeddings_different_vector_dimensions(self, processor): """Test querying document embeddings with different vector dimensions""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3, 0.4, 0.5], # 5D vector limit=5 @@ -349,7 +336,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify search was called with the vector processor.vecstore.search.assert_called_once() @@ -364,7 +351,6 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_multiple_results(self, processor): """Test querying document embeddings with multiple results""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=5 @@ -378,7 +364,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify results are ChunkMatch objects assert len(result) == 3 diff --git a/tests/unit/test_query/test_doc_embeddings_pinecone_query.py b/tests/unit/test_query/test_doc_embeddings_pinecone_query.py index 397bdf1b..b50a95b8 100644 --- a/tests/unit/test_query/test_doc_embeddings_pinecone_query.py +++ b/tests/unit/test_query/test_doc_embeddings_pinecone_query.py @@ -117,7 +117,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify index was accessed correctly (with dimension suffix) expected_index_name = "d-test_user-test_collection-3" # 3 dimensions @@ -160,7 +160,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_index.query.side_effect = [mock_results1, mock_results2] - chunks = await processor.query_document_embeddings(mock_query_message) + chunks = await processor.query_document_embeddings('default', mock_query_message) # Verify both queries were made assert mock_index.query.call_count == 2 @@ -191,7 +191,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify limit is passed to query mock_index.query.assert_called_once() @@ -213,7 +213,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify no query was made and empty result returned mock_index.query.assert_not_called() @@ -231,7 +231,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify no query was made and empty result returned mock_index.query.assert_not_called() @@ -259,7 +259,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_index.query.side_effect = [mock_results_2d, mock_results_4d] - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify different indexes used for different dimensions assert processor.pinecone.Index.call_count == 2 @@ -287,7 +287,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify no queries were made and empty result returned processor.pinecone.Index.assert_not_called() @@ -310,7 +310,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_results.matches = [] mock_index.query.return_value = mock_results - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify empty results assert chunks == [] @@ -334,7 +334,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify Unicode content is properly handled assert len(chunks) == 2 @@ -361,7 +361,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify large content is properly handled assert len(chunks) == 1 @@ -389,7 +389,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify all content types are properly handled assert len(chunks) == 5 @@ -413,7 +413,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_index.query.side_effect = Exception("Query failed") with pytest.raises(Exception, match="Query failed"): - await processor.query_document_embeddings(message) + await processor.query_document_embeddings('test_user', message) @pytest.mark.asyncio async def test_query_document_embeddings_index_access_failure(self, processor): @@ -427,7 +427,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: processor.pinecone.Index.side_effect = Exception("Index access failed") with pytest.raises(Exception, match="Index access failed"): - await processor.query_document_embeddings(message) + await processor.query_document_embeddings('test_user', message) @pytest.mark.asyncio async def test_query_document_embeddings_vector_accumulation(self, processor): @@ -462,7 +462,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_index.query.side_effect = [mock_results1, mock_results2, mock_results3] - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify all queries were made assert mock_index.query.call_count == 3 diff --git a/tests/unit/test_query/test_doc_embeddings_qdrant_query.py b/tests/unit/test_query/test_doc_embeddings_qdrant_query.py index 1d2f0e6d..3602ad51 100644 --- a/tests/unit/test_query/test_doc_embeddings_qdrant_query.py +++ b/tests/unit/test_query/test_doc_embeddings_qdrant_query.py @@ -101,7 +101,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'test_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('test_user', mock_message) # Assert # Verify query was called with correct parameters (with dimension suffix) @@ -158,7 +158,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'multi_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('multi_user', mock_message) # Assert # Verify query was called once @@ -212,7 +212,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'limit_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('limit_user', mock_message) # Assert # Verify query was called with exact limit (no multiplication) @@ -252,7 +252,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'empty_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('empty_user', mock_message) # Assert assert result == [] @@ -291,7 +291,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'dim_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('dim_user', mock_message) # Assert # Verify query was called once with correct collection @@ -342,7 +342,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'utf8_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('utf8_user', mock_message) # Assert assert len(result) == 2 @@ -380,7 +380,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Act & Assert with pytest.raises(Exception, match="Qdrant connection failed"): - await processor.query_document_embeddings(mock_message) + await processor.query_document_embeddings('error_user', mock_message) @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @@ -413,7 +413,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'zero_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('zero_user', mock_message) # Assert # Should still query (with limit 0) @@ -460,7 +460,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'large_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('large_user', mock_message) # Assert # Should query with full limit @@ -512,7 +512,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Act & Assert # This should raise a KeyError when trying to access payload['chunk_id'] with pytest.raises(KeyError): - await processor.query_document_embeddings(mock_message) + await processor.query_document_embeddings('payload_user', mock_message) @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') diff --git a/tests/unit/test_query/test_graph_embeddings_milvus_query.py b/tests/unit/test_query/test_graph_embeddings_milvus_query.py index f2b8be7e..7e5c4df3 100644 --- a/tests/unit/test_query/test_graph_embeddings_milvus_query.py +++ b/tests/unit/test_query/test_graph_embeddings_milvus_query.py @@ -31,7 +31,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor: def mock_query_request(self): """Create a mock query request for testing""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=10 @@ -117,7 +116,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_single_vector(self, processor): """Test querying graph embeddings with a single vector""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=5 @@ -131,7 +129,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify search was called with correct parameters including user/collection processor.vecstore.search.assert_called_once_with( @@ -154,7 +152,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_multiple_results(self, processor): """Test querying graph embeddings returns multiple results""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=5 @@ -168,7 +165,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify search was called once with the full vector processor.vecstore.search.assert_called_once_with( @@ -186,7 +183,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_with_limit(self, processor): """Test querying graph embeddings respects limit parameter""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=2 @@ -201,7 +197,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify search was called with 2*limit for better deduplication processor.vecstore.search.assert_called_once_with( @@ -215,7 +211,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_preserves_order(self, processor): """Test that query results preserve order from the vector store""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=5 @@ -229,7 +224,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify results are in the same order as returned by the store assert len(result) == 3 @@ -241,7 +236,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_results_limited(self, processor): """Test that results are properly limited when store returns more than requested""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=2 @@ -255,7 +249,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify search was called with the full vector processor.vecstore.search.assert_called_once_with( @@ -269,13 +263,12 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_empty_vectors(self, processor): """Test querying graph embeddings with empty vectors list""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[], limit=5 ) - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify no search was called processor.vecstore.search.assert_not_called() @@ -287,7 +280,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_empty_search_results(self, processor): """Test querying graph embeddings with empty search results""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=5 @@ -296,7 +288,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: # Mock empty search results processor.vecstore.search.return_value = [] - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify search was called processor.vecstore.search.assert_called_once_with( @@ -310,7 +302,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_mixed_uri_literal_results(self, processor): """Test querying graph embeddings with mixed URI and literal results""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=5 @@ -325,7 +316,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify all results are properly typed assert len(result) == 4 @@ -348,7 +339,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_exception_handling(self, processor): """Test exception handling during query processing""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=5 @@ -359,7 +349,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: # Should raise the exception with pytest.raises(Exception, match="Milvus connection failed"): - await processor.query_graph_embeddings(query) + await processor.query_graph_embeddings('test_user', query) def test_add_args_method(self): """Test that add_args properly configures argument parser""" @@ -430,13 +420,12 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_zero_limit(self, processor): """Test querying graph embeddings with zero limit""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=0 ) - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify no search was called (optimization for zero limit) processor.vecstore.search.assert_not_called() @@ -448,7 +437,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_longer_vector(self, processor): """Test querying graph embeddings with a longer vector""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], limit=5 @@ -461,7 +449,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify search was called once with the full vector processor.vecstore.search.assert_called_once() diff --git a/tests/unit/test_query/test_graph_embeddings_pinecone_query.py b/tests/unit/test_query/test_graph_embeddings_pinecone_query.py index 2c1a673a..0fc8f7c0 100644 --- a/tests/unit/test_query/test_graph_embeddings_pinecone_query.py +++ b/tests/unit/test_query/test_graph_embeddings_pinecone_query.py @@ -145,7 +145,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('test_user', message) # Verify index was accessed correctly (with dimension suffix) expected_index_name = "t-test_user-test_collection-3" # 3 dimensions @@ -185,7 +185,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_index.query.return_value = mock_results - entities = await processor.query_graph_embeddings(mock_query_message) + entities = await processor.query_graph_embeddings('default', mock_query_message) # Verify query was made once assert mock_index.query.call_count == 1 @@ -216,7 +216,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('test_user', message) # Verify limit is respected assert len(entities) == 2 @@ -233,7 +233,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('test_user', message) # Verify no query was made and empty result returned mock_index.query.assert_not_called() @@ -251,7 +251,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('test_user', message) # Verify no query was made and empty result returned mock_index.query.assert_not_called() @@ -276,7 +276,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_index.query.return_value = mock_results - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('test_user', message) # Verify correct index used for 2D vector processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2") @@ -300,7 +300,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('test_user', message) # Verify no queries were made and empty result returned processor.pinecone.Index.assert_not_called() @@ -323,7 +323,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_results.matches = [] mock_index.query.return_value = mock_results - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('test_user', message) # Verify empty results assert entities == [] @@ -352,7 +352,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_index.query.return_value = mock_results - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('test_user', message) # Should get exactly 3 unique entities (respecting limit) assert len(entities) == 3 @@ -380,7 +380,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('test_user', message) # Should only return 2 entities (respecting limit) mock_index.query.assert_called_once() @@ -400,7 +400,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_index.query.side_effect = Exception("Query failed") with pytest.raises(Exception, match="Query failed"): - await processor.query_graph_embeddings(message) + await processor.query_graph_embeddings('test_user', message) def test_add_args_method(self): """Test that add_args properly configures argument parser""" diff --git a/tests/unit/test_query/test_graph_embeddings_qdrant_query.py b/tests/unit/test_query/test_graph_embeddings_qdrant_query.py index 9362a8dd..41b6c8a4 100644 --- a/tests/unit/test_query/test_graph_embeddings_qdrant_query.py +++ b/tests/unit/test_query/test_graph_embeddings_qdrant_query.py @@ -173,7 +173,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'test_collection' # Act - result = await processor.query_graph_embeddings(mock_message) + result = await processor.query_graph_embeddings('test_user', mock_message) # Assert # Verify query was called with correct parameters (with dimension suffix) @@ -230,7 +230,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'multi_collection' # Act - result = await processor.query_graph_embeddings(mock_message) + result = await processor.query_graph_embeddings('multi_user', mock_message) # Assert # Verify query was called once @@ -283,7 +283,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'limit_collection' # Act - result = await processor.query_graph_embeddings(mock_message) + result = await processor.query_graph_embeddings('limit_user', mock_message) # Assert # Verify query was called with limit * 2 @@ -323,7 +323,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'empty_collection' # Act - result = await processor.query_graph_embeddings(mock_message) + result = await processor.query_graph_embeddings('empty_user', mock_message) # Assert assert result == [] @@ -364,7 +364,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'dim_collection' # Act - result = await processor.query_graph_embeddings(mock_message) + result = await processor.query_graph_embeddings('dim_user', mock_message) # Assert # Verify query was called once @@ -415,7 +415,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'uri_collection' # Act - result = await processor.query_graph_embeddings(mock_message) + result = await processor.query_graph_embeddings('uri_user', mock_message) # Assert assert len(result) == 3 @@ -460,7 +460,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Act & Assert with pytest.raises(Exception, match="Qdrant connection failed"): - await processor.query_graph_embeddings(mock_message) + await processor.query_graph_embeddings('error_user', mock_message) @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__') @@ -493,7 +493,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'zero_collection' # Act - result = await processor.query_graph_embeddings(mock_message) + result = await processor.query_graph_embeddings('zero_user', mock_message) # Assert # Should still query (with limit 0) diff --git a/tests/unit/test_query/test_memgraph_user_collection_query.py b/tests/unit/test_query/test_memgraph_workspace_collection_query.py similarity index 76% rename from tests/unit/test_query/test_memgraph_user_collection_query.py rename to tests/unit/test_query/test_memgraph_workspace_collection_query.py index 038fb438..d0ab242e 100644 --- a/tests/unit/test_query/test_memgraph_user_collection_query.py +++ b/tests/unit/test_query/test_memgraph_workspace_collection_query.py @@ -9,12 +9,12 @@ from trustgraph.query.triples.memgraph.service import Processor from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL -class TestMemgraphQueryUserCollectionIsolation: +class TestMemgraphQueryWorkspaceCollectionIsolation: """Test cases for Memgraph query service with user/collection isolation""" @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_spo_query_with_user_collection(self, mock_graph_db): + async def test_spo_query_with_workspace_collection(self, mock_graph_db): """Test SPO query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -22,7 +22,6 @@ class TestMemgraphQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/s"), p=Term(type=IRI, iri="http://example.com/p"), @@ -32,13 +31,13 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify SPO query for literal includes user/collection expected_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN $src as src " "LIMIT 1000" ) @@ -48,14 +47,14 @@ class TestMemgraphQueryUserCollectionIsolation: src="http://example.com/s", rel="http://example.com/p", value="test_object", - user="test_user", + workspace="test_user", collection="test_collection", database_='memgraph' ) @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_sp_query_with_user_collection(self, mock_graph_db): + async def test_sp_query_with_workspace_collection(self, mock_graph_db): """Test SP query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -63,7 +62,6 @@ class TestMemgraphQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/s"), p=Term(type=IRI, iri="http://example.com/p"), @@ -73,13 +71,13 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify SP query for literals includes user/collection expected_literal_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN dest.value as dest " "LIMIT 1000" ) @@ -88,14 +86,14 @@ class TestMemgraphQueryUserCollectionIsolation: expected_literal_query, src="http://example.com/s", rel="http://example.com/p", - user="test_user", + workspace="test_user", collection="test_collection", database_='memgraph' ) @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_so_query_with_user_collection(self, mock_graph_db): + async def test_so_query_with_workspace_collection(self, mock_graph_db): """Test SO query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -103,7 +101,6 @@ class TestMemgraphQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/s"), p=None, @@ -113,13 +110,13 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify SO query for nodes includes user/collection expected_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {uri: $uri, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel " "LIMIT 1000" ) @@ -128,14 +125,14 @@ class TestMemgraphQueryUserCollectionIsolation: expected_query, src="http://example.com/s", uri="http://example.com/o", - user="test_user", + workspace="test_user", collection="test_collection", database_='memgraph' ) @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_s_only_query_with_user_collection(self, mock_graph_db): + async def test_s_only_query_with_workspace_collection(self, mock_graph_db): """Test S-only query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -143,7 +140,6 @@ class TestMemgraphQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/s"), p=None, @@ -153,13 +149,13 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify S query includes user/collection expected_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel, dest.value as dest " "LIMIT 1000" ) @@ -167,14 +163,14 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.assert_any_call( expected_query, src="http://example.com/s", - user="test_user", + workspace="test_user", collection="test_collection", database_='memgraph' ) @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_po_query_with_user_collection(self, mock_graph_db): + async def test_po_query_with_workspace_collection(self, mock_graph_db): """Test PO query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -182,7 +178,6 @@ class TestMemgraphQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=None, p=Term(type=IRI, iri="http://example.com/p"), @@ -192,13 +187,13 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify PO query for literals includes user/collection expected_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src " "LIMIT 1000" ) @@ -207,14 +202,14 @@ class TestMemgraphQueryUserCollectionIsolation: expected_query, uri="http://example.com/p", value="literal", - user="test_user", + workspace="test_user", collection="test_collection", database_='memgraph' ) @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_p_only_query_with_user_collection(self, mock_graph_db): + async def test_p_only_query_with_workspace_collection(self, mock_graph_db): """Test P-only query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -222,7 +217,6 @@ class TestMemgraphQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=None, p=Term(type=IRI, iri="http://example.com/p"), @@ -232,13 +226,13 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify P query includes user/collection expected_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, dest.value as dest " "LIMIT 1000" ) @@ -246,14 +240,14 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.assert_any_call( expected_query, uri="http://example.com/p", - user="test_user", + workspace="test_user", collection="test_collection", database_='memgraph' ) @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_o_only_query_with_user_collection(self, mock_graph_db): + async def test_o_only_query_with_workspace_collection(self, mock_graph_db): """Test O-only query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -261,7 +255,6 @@ class TestMemgraphQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=None, p=None, @@ -271,13 +264,13 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify O query for literals includes user/collection expected_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel " "LIMIT 1000" ) @@ -285,14 +278,14 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.assert_any_call( expected_query, value="test_value", - user="test_user", + workspace="test_user", collection="test_collection", database_='memgraph' ) @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_wildcard_query_with_user_collection(self, mock_graph_db): + async def test_wildcard_query_with_workspace_collection(self, mock_graph_db): """Test wildcard query (all None) includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -300,7 +293,6 @@ class TestMemgraphQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=None, p=None, @@ -310,36 +302,36 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify wildcard query for literals includes user/collection expected_literal_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.value as dest " "LIMIT 1000" ) mock_driver.execute_query.assert_any_call( expected_literal_query, - user="test_user", + workspace="test_user", collection="test_collection", database_='memgraph' ) # Verify wildcard query for nodes includes user/collection expected_node_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.uri as dest " "LIMIT 1000" ) mock_driver.execute_query.assert_any_call( expected_node_query, - user="test_user", + workspace="test_user", collection="test_collection", database_='memgraph' ) @@ -363,7 +355,7 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples('default', query) # Verify defaults were used calls = mock_driver.execute_query.call_args_list @@ -383,7 +375,6 @@ class TestMemgraphQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/s"), p=None, @@ -410,7 +401,7 @@ class TestMemgraphQueryUserCollectionIsolation: ([mock_record2], MagicMock(), MagicMock()) # Node query ] - result = await processor.query_triples(query) + result = await processor.query_triples("test_user", query) # Verify results are proper Triple objects assert len(result) == 2 diff --git a/tests/unit/test_query/test_neo4j_user_collection_query.py b/tests/unit/test_query/test_neo4j_workspace_collection_query.py similarity index 75% rename from tests/unit/test_query/test_neo4j_user_collection_query.py rename to tests/unit/test_query/test_neo4j_workspace_collection_query.py index d9cf1eb4..029de617 100644 --- a/tests/unit/test_query/test_neo4j_user_collection_query.py +++ b/tests/unit/test_query/test_neo4j_workspace_collection_query.py @@ -9,12 +9,12 @@ from trustgraph.query.triples.neo4j.service import Processor from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL -class TestNeo4jQueryUserCollectionIsolation: +class TestNeo4jQueryWorkspaceCollectionIsolation: """Test cases for Neo4j query service with user/collection isolation""" @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_spo_query_with_user_collection(self, mock_graph_db): + async def test_spo_query_with_workspace_collection(self, mock_graph_db): """Test SPO query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -22,7 +22,6 @@ class TestNeo4jQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/s"), p=Term(type=IRI, iri="http://example.com/p"), @@ -32,13 +31,13 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify SPO query for literal includes user/collection expected_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN $src as src " "LIMIT 10" ) @@ -48,14 +47,14 @@ class TestNeo4jQueryUserCollectionIsolation: src="http://example.com/s", rel="http://example.com/p", value="test_object", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_sp_query_with_user_collection(self, mock_graph_db): + async def test_sp_query_with_workspace_collection(self, mock_graph_db): """Test SP query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -63,7 +62,6 @@ class TestNeo4jQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/s"), p=Term(type=IRI, iri="http://example.com/p"), @@ -73,13 +71,13 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify SP query for literals includes user/collection expected_literal_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN dest.value as dest " "LIMIT 10" ) @@ -88,16 +86,16 @@ class TestNeo4jQueryUserCollectionIsolation: expected_literal_query, src="http://example.com/s", rel="http://example.com/p", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) # Verify SP query for nodes includes user/collection expected_node_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN dest.uri as dest " "LIMIT 10" ) @@ -106,14 +104,14 @@ class TestNeo4jQueryUserCollectionIsolation: expected_node_query, src="http://example.com/s", rel="http://example.com/p", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_so_query_with_user_collection(self, mock_graph_db): + async def test_so_query_with_workspace_collection(self, mock_graph_db): """Test SO query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -121,7 +119,6 @@ class TestNeo4jQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/s"), p=None, @@ -131,13 +128,13 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify SO query for nodes includes user/collection expected_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {uri: $uri, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel " "LIMIT 10" ) @@ -146,14 +143,14 @@ class TestNeo4jQueryUserCollectionIsolation: expected_query, src="http://example.com/s", uri="http://example.com/o", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_s_only_query_with_user_collection(self, mock_graph_db): + async def test_s_only_query_with_workspace_collection(self, mock_graph_db): """Test S-only query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -161,7 +158,6 @@ class TestNeo4jQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/s"), p=None, @@ -171,13 +167,13 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify S query includes user/collection expected_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel, dest.value as dest " "LIMIT 10" ) @@ -185,14 +181,14 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.assert_any_call( expected_query, src="http://example.com/s", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_po_query_with_user_collection(self, mock_graph_db): + async def test_po_query_with_workspace_collection(self, mock_graph_db): """Test PO query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -200,7 +196,6 @@ class TestNeo4jQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=None, p=Term(type=IRI, iri="http://example.com/p"), @@ -210,13 +205,13 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify PO query for literals includes user/collection expected_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src " "LIMIT 10" ) @@ -225,14 +220,14 @@ class TestNeo4jQueryUserCollectionIsolation: expected_query, uri="http://example.com/p", value="literal", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_p_only_query_with_user_collection(self, mock_graph_db): + async def test_p_only_query_with_workspace_collection(self, mock_graph_db): """Test P-only query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -240,7 +235,6 @@ class TestNeo4jQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=None, p=Term(type=IRI, iri="http://example.com/p"), @@ -250,13 +244,13 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify P query includes user/collection expected_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, dest.value as dest " "LIMIT 10" ) @@ -264,14 +258,14 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.assert_any_call( expected_query, uri="http://example.com/p", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_o_only_query_with_user_collection(self, mock_graph_db): + async def test_o_only_query_with_workspace_collection(self, mock_graph_db): """Test O-only query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -279,7 +273,6 @@ class TestNeo4jQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=None, p=None, @@ -289,13 +282,13 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify O query for literals includes user/collection expected_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel " "LIMIT 10" ) @@ -303,14 +296,14 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.assert_any_call( expected_query, value="test_value", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_wildcard_query_with_user_collection(self, mock_graph_db): + async def test_wildcard_query_with_workspace_collection(self, mock_graph_db): """Test wildcard query (all None) includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -318,7 +311,6 @@ class TestNeo4jQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=None, p=None, @@ -328,36 +320,36 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify wildcard query for literals includes user/collection expected_literal_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.value as dest " "LIMIT 10" ) mock_driver.execute_query.assert_any_call( expected_literal_query, - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) # Verify wildcard query for nodes includes user/collection expected_node_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.uri as dest " "LIMIT 10" ) mock_driver.execute_query.assert_any_call( expected_node_query, - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) @@ -381,7 +373,7 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples('default', query) # Verify defaults were used calls = mock_driver.execute_query.call_args_list @@ -401,7 +393,6 @@ class TestNeo4jQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/s"), p=None, @@ -428,7 +419,7 @@ class TestNeo4jQueryUserCollectionIsolation: ([mock_record2], MagicMock(), MagicMock()) # Node query ] - result = await processor.query_triples(query) + result = await processor.query_triples("test_user", query) # Verify results are proper Triple objects assert len(result) == 2 diff --git a/tests/unit/test_query/test_rows_cassandra_query.py b/tests/unit/test_query/test_rows_cassandra_query.py index c0d399c3..bb6bbe84 100644 --- a/tests/unit/test_query/test_rows_cassandra_query.py +++ b/tests/unit/test_query/test_rows_cassandra_query.py @@ -91,11 +91,10 @@ class TestRowsGraphQLQueryLogic: """Test parsing of schema configuration""" processor = MagicMock() processor.schemas = {} + processor.schema_builders = {} + processor.graphql_schemas = {} processor.config_key = "schema" - processor.schema_builder = MagicMock() - processor.schema_builder.clear = MagicMock() - processor.schema_builder.add_schema = MagicMock() - processor.schema_builder.build = MagicMock(return_value=MagicMock()) + processor.query_cassandra = MagicMock() processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) # Create test config @@ -129,11 +128,11 @@ class TestRowsGraphQLQueryLogic: } # Process config - await processor.on_schema_config(schema_config, version=1) + await processor.on_schema_config("default", schema_config, version=1) # Verify schema was loaded - assert "customer" in processor.schemas - schema = processor.schemas["customer"] + assert "customer" in processor.schemas["default"] + schema = processor.schemas["default"]["customer"] assert schema.name == "customer" assert len(schema.fields) == 3 @@ -147,39 +146,40 @@ class TestRowsGraphQLQueryLogic: status_field = next(f for f in schema.fields if f.name == "status") assert status_field.enum_values == ["active", "inactive"] - # Verify schema builder was called - processor.schema_builder.add_schema.assert_called_once() - processor.schema_builder.build.assert_called_once() + # Verify per-workspace schema builder was created and graphql schema built + assert "default" in processor.schema_builders + assert "default" in processor.graphql_schemas @pytest.mark.asyncio async def test_graphql_context_handling(self): """Test GraphQL execution context setup""" processor = MagicMock() - processor.graphql_schema = AsyncMock() + graphql_schema = AsyncMock() + processor.graphql_schemas = {"default": graphql_schema} processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor) # Mock schema execution mock_result = MagicMock() mock_result.data = {"customers": [{"id": "1", "name": "Test"}]} mock_result.errors = None - processor.graphql_schema.execute.return_value = mock_result + graphql_schema.execute.return_value = mock_result result = await processor.execute_graphql_query( + workspace="default", query='{ customers { id name } }', variables={}, operation_name=None, - user="test_user", collection="test_collection" ) # Verify schema.execute was called with correct context - processor.graphql_schema.execute.assert_called_once() - call_args = processor.graphql_schema.execute.call_args + graphql_schema.execute.assert_called_once() + call_args = graphql_schema.execute.call_args # Verify context was passed context = call_args[1]['context_value'] assert context["processor"] == processor - assert context["user"] == "test_user" + assert context["workspace"] == "default" assert context["collection"] == "test_collection" # Verify result structure @@ -190,7 +190,8 @@ class TestRowsGraphQLQueryLogic: async def test_error_handling_graphql_errors(self): """Test GraphQL error handling and conversion""" processor = MagicMock() - processor.graphql_schema = AsyncMock() + graphql_schema = AsyncMock() + processor.graphql_schemas = {"default": graphql_schema} processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor) # Create a simple object to simulate GraphQL error @@ -212,13 +213,13 @@ class TestRowsGraphQLQueryLogic: mock_result = MagicMock() mock_result.data = None mock_result.errors = [mock_error] - processor.graphql_schema.execute.return_value = mock_result + graphql_schema.execute.return_value = mock_result result = await processor.execute_graphql_query( + workspace="default", query='{ customers { invalid_field } }', variables={}, operation_name=None, - user="test_user", collection="test_collection" ) @@ -248,7 +249,6 @@ class TestRowsGraphQLQueryLogic: # Create mock message mock_msg = MagicMock() mock_request = RowsQueryRequest( - user="test_user", collection="test_collection", query='{ customers { id name } }', variables={}, @@ -259,6 +259,7 @@ class TestRowsGraphQLQueryLogic: # Mock flow mock_flow = MagicMock() + mock_flow.workspace = "default" mock_response_flow = AsyncMock() mock_flow.return_value = mock_response_flow @@ -267,10 +268,10 @@ class TestRowsGraphQLQueryLogic: # Verify query was executed processor.execute_graphql_query.assert_called_once_with( + workspace="default", query='{ customers { id name } }', variables={}, operation_name=None, - user="test_user", collection="test_collection" ) @@ -297,7 +298,6 @@ class TestRowsGraphQLQueryLogic: # Create mock message mock_msg = MagicMock() mock_request = RowsQueryRequest( - user="test_user", collection="test_collection", query='{ invalid_query }', variables={}, @@ -357,7 +357,7 @@ class TestUnifiedTableQueries: # Query with filter on indexed field results = await processor.query_cassandra( - user="test_user", + workspace="test_workspace", collection="test_collection", schema_name="products", row_schema=schema, @@ -374,7 +374,7 @@ class TestUnifiedTableQueries: query = call_args[0][1] params = call_args[0][2] - assert "SELECT data, source FROM test_user.rows" in query + assert "SELECT data, source FROM test_workspace.rows" in query assert "collection = %s" in query assert "schema_name = %s" in query assert "index_name = %s" in query @@ -421,7 +421,7 @@ class TestUnifiedTableQueries: # Query with filter on non-indexed field results = await processor.query_cassandra( - user="test_user", + workspace="test_workspace", collection="test_collection", schema_name="products", row_schema=schema, diff --git a/tests/unit/test_query/test_triples_cassandra_query.py b/tests/unit/test_query/test_triples_cassandra_query.py index b620df7e..09681214 100644 --- a/tests/unit/test_query/test_triples_cassandra_query.py +++ b/tests/unit/test_query/test_triples_cassandra_query.py @@ -95,7 +95,6 @@ class TestCassandraQueryProcessor: # Create query request with all SPO values query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=LITERAL, value='test_subject'), p=Term(type=LITERAL, value='test_predicate'), @@ -103,7 +102,7 @@ class TestCassandraQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify KnowledgeGraph was created with correct parameters mock_kg_class.assert_called_once_with( @@ -170,7 +169,6 @@ class TestCassandraQueryProcessor: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=LITERAL, value='test_subject'), p=Term(type=LITERAL, value='test_predicate'), @@ -178,7 +176,7 @@ class TestCassandraQueryProcessor: limit=50 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50) assert len(result) == 1 @@ -207,7 +205,6 @@ class TestCassandraQueryProcessor: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=LITERAL, value='test_subject'), p=None, @@ -215,7 +212,7 @@ class TestCassandraQueryProcessor: limit=25 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25) assert len(result) == 1 @@ -244,7 +241,6 @@ class TestCassandraQueryProcessor: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=Term(type=LITERAL, value='test_predicate'), @@ -252,7 +248,7 @@ class TestCassandraQueryProcessor: limit=10 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10) assert len(result) == 1 @@ -281,7 +277,6 @@ class TestCassandraQueryProcessor: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=None, @@ -289,7 +284,7 @@ class TestCassandraQueryProcessor: limit=75 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75) assert len(result) == 1 @@ -319,7 +314,6 @@ class TestCassandraQueryProcessor: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=None, @@ -327,7 +321,7 @@ class TestCassandraQueryProcessor: limit=1000 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000) assert len(result) == 1 @@ -425,7 +419,6 @@ class TestCassandraQueryProcessor: ) query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=LITERAL, value='test_subject'), p=Term(type=LITERAL, value='test_predicate'), @@ -433,7 +426,7 @@ class TestCassandraQueryProcessor: limit=100 ) - await processor.query_triples(query) + await processor.query_triples('test_user', query) # Verify KnowledgeGraph was created with authentication mock_kg_class.assert_called_once_with( @@ -463,7 +456,6 @@ class TestCassandraQueryProcessor: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=LITERAL, value='test_subject'), p=Term(type=LITERAL, value='test_predicate'), @@ -472,11 +464,11 @@ class TestCassandraQueryProcessor: ) # First query should create TrustGraph - await processor.query_triples(query) + await processor.query_triples('test_user', query) assert mock_kg_class.call_count == 1 # Second query with same table should reuse TrustGraph - await processor.query_triples(query) + await processor.query_triples('test_user', query) assert mock_kg_class.call_count == 1 # Should not increase @pytest.mark.asyncio @@ -504,7 +496,6 @@ class TestCassandraQueryProcessor: # First query query1 = TriplesQueryRequest( - user='user1', collection='collection1', s=Term(type=LITERAL, value='test_subject'), p=None, @@ -512,12 +503,11 @@ class TestCassandraQueryProcessor: limit=100 ) - await processor.query_triples(query1) + await processor.query_triples('user1', query1) assert processor.table == 'user1' # Second query with different table query2 = TriplesQueryRequest( - user='user2', collection='collection2', s=Term(type=LITERAL, value='test_subject'), p=None, @@ -525,7 +515,7 @@ class TestCassandraQueryProcessor: limit=100 ) - await processor.query_triples(query2) + await processor.query_triples('user2', query2) assert processor.table == 'user2' # Verify TrustGraph was created twice @@ -544,7 +534,6 @@ class TestCassandraQueryProcessor: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=LITERAL, value='test_subject'), p=Term(type=LITERAL, value='test_predicate'), @@ -553,7 +542,7 @@ class TestCassandraQueryProcessor: ) with pytest.raises(Exception, match="Query failed"): - await processor.query_triples(query) + await processor.query_triples('test_user', query) @pytest.mark.asyncio @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') @@ -582,7 +571,6 @@ class TestCassandraQueryProcessor: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=LITERAL, value='test_subject'), p=Term(type=LITERAL, value='test_predicate'), @@ -590,7 +578,7 @@ class TestCassandraQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) assert len(result) == 2 assert result[0].o.value == 'object1' @@ -621,7 +609,6 @@ class TestCassandraQueryPerformanceOptimizations: # PO query pattern (predicate + object, find subjects) query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=Term(type=LITERAL, value='test_predicate'), @@ -629,7 +616,7 @@ class TestCassandraQueryPerformanceOptimizations: limit=50 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify get_po was called (should use optimized po_table) mock_tg_instance.get_po.assert_called_once_with( @@ -662,7 +649,6 @@ class TestCassandraQueryPerformanceOptimizations: # OS query pattern (object + subject, find predicates) query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=LITERAL, value='test_subject'), p=None, @@ -670,7 +656,7 @@ class TestCassandraQueryPerformanceOptimizations: limit=25 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify get_os was called (should use optimized subject_table with clustering) mock_tg_instance.get_os.assert_called_once_with( @@ -721,7 +707,6 @@ class TestCassandraQueryPerformanceOptimizations: mock_tg_instance.reset_mock() query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=LITERAL, value=s) if s else None, p=Term(type=LITERAL, value=p) if p else None, @@ -729,7 +714,7 @@ class TestCassandraQueryPerformanceOptimizations: limit=10 ) - await processor.query_triples(query) + await processor.query_triples('test_user', query) # Verify the correct method was called method = getattr(mock_tg_instance, expected_method) @@ -780,7 +765,6 @@ class TestCassandraQueryPerformanceOptimizations: # This is the query pattern that was slow with ALLOW FILTERING query = TriplesQueryRequest( - user='large_dataset_user', collection='massive_collection', s=None, p=Term(type=IRI, iri='http://www.w3.org/1999/02/22-rdf-syntax-ns#type'), @@ -788,7 +772,7 @@ class TestCassandraQueryPerformanceOptimizations: limit=1000 ) - result = await processor.query_triples(query) + result = await processor.query_triples('large_dataset_user', query) # Verify optimized get_po was used (no ALLOW FILTERING needed!) mock_tg_instance.get_po.assert_called_once_with( diff --git a/tests/unit/test_query/test_triples_falkordb_query.py b/tests/unit/test_query/test_triples_falkordb_query.py index d5c047d7..3d7270c6 100644 --- a/tests/unit/test_query/test_triples_falkordb_query.py +++ b/tests/unit/test_query/test_triples_falkordb_query.py @@ -123,7 +123,6 @@ class TestFalkorDBQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=Term(type=IRI, iri="http://example.com/predicate"), @@ -131,7 +130,7 @@ class TestFalkorDBQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_graph.query.call_count == 2 @@ -164,7 +163,6 @@ class TestFalkorDBQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=Term(type=IRI, iri="http://example.com/predicate"), @@ -172,7 +170,7 @@ class TestFalkorDBQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_graph.query.call_count == 2 @@ -209,7 +207,6 @@ class TestFalkorDBQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=None, @@ -217,7 +214,7 @@ class TestFalkorDBQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_graph.query.call_count == 2 @@ -254,7 +251,6 @@ class TestFalkorDBQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=None, @@ -262,7 +258,7 @@ class TestFalkorDBQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_graph.query.call_count == 2 @@ -299,7 +295,6 @@ class TestFalkorDBQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=Term(type=IRI, iri="http://example.com/predicate"), @@ -307,7 +302,7 @@ class TestFalkorDBQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_graph.query.call_count == 2 @@ -344,7 +339,6 @@ class TestFalkorDBQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=Term(type=IRI, iri="http://example.com/predicate"), @@ -352,7 +346,7 @@ class TestFalkorDBQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_graph.query.call_count == 2 @@ -389,7 +383,6 @@ class TestFalkorDBQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=None, @@ -397,7 +390,7 @@ class TestFalkorDBQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_graph.query.call_count == 2 @@ -434,7 +427,6 @@ class TestFalkorDBQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=None, @@ -442,7 +434,7 @@ class TestFalkorDBQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_graph.query.call_count == 2 @@ -474,7 +466,6 @@ class TestFalkorDBQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=None, @@ -484,7 +475,7 @@ class TestFalkorDBQueryProcessor: # Should raise the exception with pytest.raises(Exception, match="Database connection failed"): - await processor.query_triples(query) + await processor.query_triples('test_user', query) def test_add_args_method(self): """Test that add_args properly configures argument parser""" diff --git a/tests/unit/test_query/test_triples_memgraph_query.py b/tests/unit/test_query/test_triples_memgraph_query.py index f4222af1..a21d9008 100644 --- a/tests/unit/test_query/test_triples_memgraph_query.py +++ b/tests/unit/test_query/test_triples_memgraph_query.py @@ -122,7 +122,6 @@ class TestMemgraphQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=Term(type=IRI, iri="http://example.com/predicate"), @@ -130,7 +129,7 @@ class TestMemgraphQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -164,7 +163,6 @@ class TestMemgraphQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=Term(type=IRI, iri="http://example.com/predicate"), @@ -172,7 +170,7 @@ class TestMemgraphQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -210,7 +208,6 @@ class TestMemgraphQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=None, @@ -218,7 +215,7 @@ class TestMemgraphQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -256,7 +253,6 @@ class TestMemgraphQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=None, @@ -264,7 +260,7 @@ class TestMemgraphQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -302,7 +298,6 @@ class TestMemgraphQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=Term(type=IRI, iri="http://example.com/predicate"), @@ -310,7 +305,7 @@ class TestMemgraphQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -348,7 +343,6 @@ class TestMemgraphQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=Term(type=IRI, iri="http://example.com/predicate"), @@ -356,7 +350,7 @@ class TestMemgraphQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -394,7 +388,6 @@ class TestMemgraphQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=None, @@ -402,7 +395,7 @@ class TestMemgraphQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -440,7 +433,6 @@ class TestMemgraphQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=None, @@ -448,7 +440,7 @@ class TestMemgraphQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -478,7 +470,6 @@ class TestMemgraphQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=None, @@ -488,7 +479,7 @@ class TestMemgraphQueryProcessor: # Should raise the exception with pytest.raises(Exception, match="Database connection failed"): - await processor.query_triples(query) + await processor.query_triples('test_user', query) def test_add_args_method(self): """Test that add_args properly configures argument parser""" diff --git a/tests/unit/test_query/test_triples_neo4j_query.py b/tests/unit/test_query/test_triples_neo4j_query.py index e379ed21..3751a858 100644 --- a/tests/unit/test_query/test_triples_neo4j_query.py +++ b/tests/unit/test_query/test_triples_neo4j_query.py @@ -122,7 +122,6 @@ class TestNeo4jQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=Term(type=IRI, iri="http://example.com/predicate"), @@ -130,7 +129,7 @@ class TestNeo4jQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -164,7 +163,6 @@ class TestNeo4jQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=Term(type=IRI, iri="http://example.com/predicate"), @@ -172,7 +170,7 @@ class TestNeo4jQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -210,7 +208,6 @@ class TestNeo4jQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=None, @@ -218,7 +215,7 @@ class TestNeo4jQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -248,7 +245,6 @@ class TestNeo4jQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=None, @@ -258,7 +254,7 @@ class TestNeo4jQueryProcessor: # Should raise the exception with pytest.raises(Exception, match="Database connection failed"): - await processor.query_triples(query) + await processor.query_triples('test_user', query) def test_add_args_method(self): """Test that add_args properly configures argument parser""" diff --git a/tests/unit/test_reliability/test_metadata_preservation.py b/tests/unit/test_reliability/test_metadata_preservation.py index aded7253..2170c763 100644 --- a/tests/unit/test_reliability/test_metadata_preservation.py +++ b/tests/unit/test_reliability/test_metadata_preservation.py @@ -30,7 +30,7 @@ class TestDocumentMetadataTranslator: "title": "Test Document", "comments": "No comments", "metadata": [], - "user": "alice", + "workspace": "alice", "tags": ["finance", "q4"], "parent-id": "doc-100", "document-type": "page", @@ -40,14 +40,14 @@ class TestDocumentMetadataTranslator: assert obj.time == 1710000000 assert obj.kind == "application/pdf" assert obj.title == "Test Document" - assert obj.user == "alice" + assert obj.workspace == "alice" assert obj.tags == ["finance", "q4"] assert obj.parent_id == "doc-100" assert obj.document_type == "page" wire = self.tx.encode(obj) assert wire["id"] == "doc-123" - assert wire["user"] == "alice" + assert wire["workspace"] == "alice" assert wire["parent-id"] == "doc-100" assert wire["document-type"] == "page" @@ -80,10 +80,10 @@ class TestDocumentMetadataTranslator: def test_falsy_fields_omitted_from_wire(self): """Empty string fields should be omitted from wire format.""" - obj = DocumentMetadata(id="", time=0, user="") + obj = DocumentMetadata(id="", time=0, workspace="") wire = self.tx.encode(obj) assert "id" not in wire - assert "user" not in wire + assert "workspace" not in wire # --------------------------------------------------------------------------- @@ -101,7 +101,7 @@ class TestProcessingMetadataTranslator: "document-id": "doc-123", "time": 1710000000, "flow": "default", - "user": "alice", + "workspace": "alice", "collection": "my-collection", "tags": ["tag1"], } @@ -109,20 +109,20 @@ class TestProcessingMetadataTranslator: assert obj.id == "proc-1" assert obj.document_id == "doc-123" assert obj.flow == "default" - assert obj.user == "alice" + assert obj.workspace == "alice" assert obj.collection == "my-collection" assert obj.tags == ["tag1"] wire = self.tx.encode(obj) assert wire["id"] == "proc-1" assert wire["document-id"] == "doc-123" - assert wire["user"] == "alice" + assert wire["workspace"] == "alice" assert wire["collection"] == "my-collection" def test_missing_fields_use_defaults(self): obj = self.tx.decode({}) assert obj.id is None - assert obj.user is None + assert obj.workspace is None assert obj.collection is None def test_tags_none_omitted(self): @@ -135,10 +135,10 @@ class TestProcessingMetadataTranslator: wire = self.tx.encode(obj) assert wire["tags"] == [] - def test_user_and_collection_preserved(self): + def test_workspace_and_collection_preserved(self): """Core pipeline routing fields must survive round-trip.""" - data = {"user": "bob", "collection": "research"} + data = {"workspace": "bob", "collection": "research"} obj = self.tx.decode(data) wire = self.tx.encode(obj) - assert wire["user"] == "bob" + assert wire["workspace"] == "bob" assert wire["collection"] == "research" diff --git a/tests/unit/test_reliability/test_null_embedding_protection.py b/tests/unit/test_reliability/test_null_embedding_protection.py index 41a5c621..2296e961 100644 --- a/tests/unit/test_reliability/test_null_embedding_protection.py +++ b/tests/unit/test_reliability/test_null_embedding_protection.py @@ -61,7 +61,6 @@ class TestDocEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" emb = MagicMock() @@ -69,7 +68,7 @@ class TestDocEmbeddingsNullProtection: emb.vector = [] # Empty vector msg.chunks = [emb] - await proc.store_document_embeddings(msg) + await proc.store_document_embeddings("user1", msg) # No upsert should be called proc.qdrant.upsert.assert_not_called() @@ -83,7 +82,6 @@ class TestDocEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" emb = MagicMock() @@ -91,7 +89,7 @@ class TestDocEmbeddingsNullProtection: emb.vector = None # None vector msg.chunks = [emb] - await proc.store_document_embeddings(msg) + await proc.store_document_embeddings("user1", msg) proc.qdrant.upsert.assert_not_called() @pytest.mark.asyncio @@ -103,7 +101,6 @@ class TestDocEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" emb = MagicMock() @@ -111,7 +108,7 @@ class TestDocEmbeddingsNullProtection: emb.vector = [0.1, 0.2, 0.3] msg.chunks = [emb] - await proc.store_document_embeddings(msg) + await proc.store_document_embeddings("user1", msg) proc.qdrant.upsert.assert_not_called() @pytest.mark.asyncio @@ -124,7 +121,6 @@ class TestDocEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" emb = MagicMock() @@ -132,7 +128,7 @@ class TestDocEmbeddingsNullProtection: emb.vector = [0.1, 0.2, 0.3] msg.chunks = [emb] - await proc.store_document_embeddings(msg) + await proc.store_document_embeddings("user1", msg) proc.qdrant.upsert.assert_called_once() @pytest.mark.asyncio @@ -146,7 +142,6 @@ class TestDocEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "alice" msg.metadata.collection = "docs" emb = MagicMock() @@ -154,7 +149,7 @@ class TestDocEmbeddingsNullProtection: emb.vector = [0.0] * 384 # 384-dim vector msg.chunks = [emb] - await proc.store_document_embeddings(msg) + await proc.store_document_embeddings("alice", msg) call_args = proc.qdrant.upsert.call_args assert "d_alice_docs_384" in call_args[1]["collection_name"] @@ -175,7 +170,6 @@ class TestGraphEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" entity = MagicMock() @@ -183,7 +177,7 @@ class TestGraphEmbeddingsNullProtection: entity.vector = [0.1, 0.2, 0.3] msg.entities = [entity] - await proc.store_graph_embeddings(msg) + await proc.store_graph_embeddings("user1", msg) proc.qdrant.upsert.assert_not_called() @pytest.mark.asyncio @@ -195,7 +189,6 @@ class TestGraphEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" entity = MagicMock() @@ -203,7 +196,7 @@ class TestGraphEmbeddingsNullProtection: entity.vector = [0.1, 0.2, 0.3] msg.entities = [entity] - await proc.store_graph_embeddings(msg) + await proc.store_graph_embeddings("user1", msg) proc.qdrant.upsert.assert_not_called() @pytest.mark.asyncio @@ -215,7 +208,6 @@ class TestGraphEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" entity = MagicMock() @@ -223,7 +215,7 @@ class TestGraphEmbeddingsNullProtection: entity.vector = [] # Empty vector msg.entities = [entity] - await proc.store_graph_embeddings(msg) + await proc.store_graph_embeddings("user1", msg) proc.qdrant.upsert.assert_not_called() @pytest.mark.asyncio @@ -236,7 +228,6 @@ class TestGraphEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" entity = MagicMock() @@ -245,7 +236,7 @@ class TestGraphEmbeddingsNullProtection: entity.chunk_id = "c1" msg.entities = [entity] - await proc.store_graph_embeddings(msg) + await proc.store_graph_embeddings("user1", msg) proc.qdrant.upsert.assert_called_once() @pytest.mark.asyncio @@ -258,7 +249,6 @@ class TestGraphEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "alice" msg.metadata.collection = "graphs" entity = MagicMock() @@ -267,7 +257,7 @@ class TestGraphEmbeddingsNullProtection: entity.chunk_id = "" msg.entities = [entity] - await proc.store_graph_embeddings(msg) + await proc.store_graph_embeddings("alice", msg) # Collection should be created with correct dimension proc.qdrant.create_collection.assert_called_once() @@ -290,11 +280,10 @@ class TestCollectionValidation: proc.collection_exists = MagicMock(return_value=False) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "deleted-col" msg.chunks = [MagicMock()] - await proc.store_document_embeddings(msg) + await proc.store_document_embeddings("user1", msg) proc.qdrant.upsert.assert_not_called() @pytest.mark.asyncio @@ -306,9 +295,8 @@ class TestCollectionValidation: proc.collection_exists = MagicMock(return_value=False) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "deleted-col" msg.entities = [MagicMock()] - await proc.store_graph_embeddings(msg) + await proc.store_graph_embeddings("user1", msg) proc.qdrant.upsert.assert_not_called() diff --git a/tests/unit/test_retrieval/test_document_rag.py b/tests/unit/test_retrieval/test_document_rag.py index 1ff85f5a..fd140b95 100644 --- a/tests/unit/test_retrieval/test_document_rag.py +++ b/tests/unit/test_retrieval/test_document_rag.py @@ -92,14 +92,13 @@ class TestQuery: # Initialize Query with defaults query = Query( rag=mock_rag, - user="test_user", + workspace="test_workspace", collection="test_collection", verbose=False ) # Verify initialization assert query.rag == mock_rag - assert query.user == "test_user" assert query.collection == "test_collection" assert query.verbose is False assert query.doc_limit == 20 # Default value @@ -112,7 +111,7 @@ class TestQuery: # Initialize Query with custom doc_limit query = Query( rag=mock_rag, - user="custom_user", + workspace="test_workspace", collection="custom_collection", verbose=True, doc_limit=50 @@ -120,7 +119,6 @@ class TestQuery: # Verify initialization assert query.rag == mock_rag - assert query.user == "custom_user" assert query.collection == "custom_collection" assert query.verbose is True assert query.doc_limit == 50 @@ -137,7 +135,7 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", + workspace="test_workspace", collection="test_collection", verbose=False ) @@ -162,7 +160,7 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", + workspace="test_workspace", collection="test_collection", verbose=False ) @@ -184,7 +182,7 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", + workspace="test_workspace", collection="test_collection", verbose=False ) @@ -223,7 +221,7 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", + workspace="test_workspace", collection="test_collection", verbose=False, doc_limit=15 @@ -240,7 +238,6 @@ class TestQuery: mock_doc_embeddings_client.query.assert_called_once_with( vector=[0.1, 0.2, 0.3], limit=15, - user="test_user", collection="test_collection" ) @@ -286,7 +283,6 @@ class TestQuery: result = await document_rag.query( query="test query", - user="test_user", collection="test_collection", doc_limit=10 ) @@ -304,7 +300,6 @@ class TestQuery: mock_doc_embeddings_client.query.assert_called_once_with( vector=[0.1, 0.2, 0.3], limit=10, - user="test_user", collection="test_collection" ) @@ -350,7 +345,6 @@ class TestQuery: mock_doc_embeddings_client.query.assert_called_once_with( vector=[[0.1, 0.2]], limit=20, # Default doc_limit - user="trustgraph", # Default user collection="default" # Default collection ) @@ -380,7 +374,7 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", + workspace="test_workspace", collection="test_collection", verbose=True, doc_limit=5 @@ -453,7 +447,7 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", + workspace="test_workspace", collection="test_collection", verbose=False ) @@ -509,7 +503,7 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", + workspace="test_workspace", collection="test_collection", verbose=True ) @@ -558,7 +552,6 @@ class TestQuery: result = await document_rag.query( query=query_text, - user="research_user", collection="ml_knowledge", doc_limit=25 ) @@ -619,7 +612,7 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", + workspace="test_workspace", collection="test_collection", verbose=False, doc_limit=10 diff --git a/tests/unit/test_retrieval/test_document_rag_service.py b/tests/unit/test_retrieval/test_document_rag_service.py index a5d42f3a..dde3acc1 100644 --- a/tests/unit/test_retrieval/test_document_rag_service.py +++ b/tests/unit/test_retrieval/test_document_rag_service.py @@ -1,6 +1,6 @@ """ Unit test for DocumentRAG service parameter passing fix. -Tests that user and collection parameters from the message are correctly +Tests that the collection parameter from the message is correctly passed to the DocumentRag.query() method. """ @@ -16,13 +16,13 @@ class TestDocumentRagService: @patch('trustgraph.retrieval.document_rag.rag.DocumentRag') @pytest.mark.asyncio - async def test_user_and_collection_parameters_passed_to_query(self, mock_document_rag_class): + async def test_collection_parameter_passed_to_query(self, mock_document_rag_class): """ - Test that user and collection from message are passed to DocumentRag.query(). - - This is a regression test for the bug where user/collection parameters - were ignored, causing wrong collection names like 'd_trustgraph_default_384' - instead of 'd_my_user_test_coll_1_384'. + Test that collection from message is passed to DocumentRag.query(). + + This is a regression test for the bug where the collection parameter + was ignored, causing wrong collection names like 'd_trustgraph_default_384' + instead of one that reflects the requested collection. """ # Setup processor processor = Processor( @@ -30,17 +30,16 @@ class TestDocumentRagService: id="test-processor", doc_limit=10 ) - + # Setup mock DocumentRag instance mock_rag_instance = AsyncMock() mock_document_rag_class.return_value = mock_rag_instance mock_rag_instance.query.return_value = ("test response", {"in_token": None, "out_token": None, "model": None}) - - # Setup message with custom user/collection + + # Setup message with custom collection msg = MagicMock() msg.value.return_value = DocumentRagQuery( query="test query", - user="my_user", # Custom user (not default "trustgraph") collection="test_coll_1", # Custom collection (not default "default") doc_limit=5 ) @@ -64,7 +63,7 @@ class TestDocumentRagService: # Verify: DocumentRag.query was called with correct parameters mock_rag_instance.query.assert_called_once_with( "test query", - user="my_user", # Must be from message, not hardcoded default + workspace=ANY, # Workspace comes from flow.workspace (mock) collection="test_coll_1", # Must be from message, not hardcoded default doc_limit=5, explain_callback=ANY, # Explainability callback is always passed @@ -103,7 +102,6 @@ class TestDocumentRagService: msg = MagicMock() msg.value.return_value = DocumentRagQuery( query="What is a cat?", - user="trustgraph", collection="default", doc_limit=10, streaming=False # Non-streaming mode diff --git a/tests/unit/test_retrieval/test_graph_rag.py b/tests/unit/test_retrieval/test_graph_rag.py index 00a9551f..e0f41357 100644 --- a/tests/unit/test_retrieval/test_graph_rag.py +++ b/tests/unit/test_retrieval/test_graph_rag.py @@ -78,14 +78,12 @@ class TestQuery: # Initialize Query with defaults query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False ) # Verify initialization assert query.rag == mock_rag - assert query.user == "test_user" assert query.collection == "test_collection" assert query.verbose is False assert query.entity_limit == 50 # Default value @@ -101,7 +99,6 @@ class TestQuery: # Initialize Query with custom parameters query = Query( rag=mock_rag, - user="custom_user", collection="custom_collection", verbose=True, entity_limit=100, @@ -112,7 +109,6 @@ class TestQuery: # Verify initialization assert query.rag == mock_rag - assert query.user == "custom_user" assert query.collection == "custom_collection" assert query.verbose is True assert query.entity_limit == 100 @@ -133,7 +129,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False ) @@ -156,7 +151,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=True ) @@ -177,7 +171,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False ) @@ -201,7 +194,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False ) @@ -244,7 +236,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False, entity_limit=25 @@ -269,7 +260,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False ) @@ -277,7 +267,7 @@ class TestQuery: result = await query.maybe_label("entity1") assert result == "Entity One Label" - mock_cache.get.assert_called_once_with("test_user:test_collection:entity1") + mock_cache.get.assert_called_once_with("test_collection:entity1") @pytest.mark.asyncio async def test_maybe_label_with_label_lookup(self): @@ -295,7 +285,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False ) @@ -307,13 +296,12 @@ class TestQuery: p="http://www.w3.org/2000/01/rdf-schema#label", o=None, limit=1, - user="test_user", collection="test_collection", g="" ) assert result == "Human Readable Label" - cache_key = "test_user:test_collection:http://example.com/entity" + cache_key = "test_collection:http://example.com/entity" mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label") @pytest.mark.asyncio @@ -330,7 +318,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False ) @@ -342,13 +329,12 @@ class TestQuery: p="http://www.w3.org/2000/01/rdf-schema#label", o=None, limit=1, - user="test_user", collection="test_collection", g="" ) assert result == "unlabeled_entity" - cache_key = "test_user:test_collection:unlabeled_entity" + cache_key = "test_collection:unlabeled_entity" mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity") @pytest.mark.asyncio @@ -375,7 +361,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False, triple_limit=10 @@ -388,15 +373,15 @@ class TestQuery: mock_triples_client.query_stream.assert_any_call( s="entity1", p=None, o=None, limit=10, - user="test_user", collection="test_collection", batch_size=20, g="" + collection="test_collection", batch_size=20, g="" ) mock_triples_client.query_stream.assert_any_call( s=None, p="entity1", o=None, limit=10, - user="test_user", collection="test_collection", batch_size=20, g="" + collection="test_collection", batch_size=20, g="" ) mock_triples_client.query_stream.assert_any_call( s=None, p=None, o="entity1", limit=10, - user="test_user", collection="test_collection", batch_size=20, g="" + collection="test_collection", batch_size=20, g="" ) expected_subgraph = { @@ -415,7 +400,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False ) @@ -435,7 +419,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False, max_subgraph_size=2 @@ -455,7 +438,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False, max_path_length=1 @@ -493,7 +475,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False, max_subgraph_size=100 @@ -601,7 +582,6 @@ class TestQuery: try: response = await graph_rag.query( query="test query", - user="test_user", collection="test_collection", entity_limit=25, triple_limit=15, diff --git a/tests/unit/test_retrieval/test_graph_rag_explain_forwarding.py b/tests/unit/test_retrieval/test_graph_rag_explain_forwarding.py index 603bd204..5208bf7f 100644 --- a/tests/unit/test_retrieval/test_graph_rag_explain_forwarding.py +++ b/tests/unit/test_retrieval/test_graph_rag_explain_forwarding.py @@ -120,7 +120,6 @@ class TestGraphRagServiceExplainTriples: msg = MagicMock() msg.value.return_value = GraphRagQuery( query="What is quantum computing?", - user="trustgraph", collection="default", streaming=False, ) diff --git a/tests/unit/test_retrieval/test_graph_rag_service.py b/tests/unit/test_retrieval/test_graph_rag_service.py index 606aa7fe..a637a350 100644 --- a/tests/unit/test_retrieval/test_graph_rag_service.py +++ b/tests/unit/test_retrieval/test_graph_rag_service.py @@ -52,7 +52,6 @@ class TestGraphRagService: msg = MagicMock() msg.value.return_value = GraphRagQuery( query="What is a cat?", - user="trustgraph", collection="default", entity_limit=50, triple_limit=30, @@ -123,7 +122,6 @@ class TestGraphRagService: msg = MagicMock() msg.value.return_value = GraphRagQuery( query="What is a cat?", - user="trustgraph", collection="default", entity_limit=50, triple_limit=30, @@ -190,7 +188,6 @@ class TestGraphRagService: msg = MagicMock() msg.value.return_value = GraphRagQuery( query="Test query", - user="trustgraph", collection="default", streaming=False ) diff --git a/tests/unit/test_retrieval/test_nlp_query.py b/tests/unit/test_retrieval/test_nlp_query.py index 1fd35c2e..cc285aea 100644 --- a/tests/unit/test_retrieval/test_nlp_query.py +++ b/tests/unit/test_retrieval/test_nlp_query.py @@ -286,11 +286,11 @@ class TestNLPQueryProcessor: } # Act - await processor.on_schema_config(config, "v1") + await processor.on_schema_config("default", config, "v1") # Assert - assert "test_schema" in processor.schemas - schema = processor.schemas["test_schema"] + assert "test_schema" in processor.schemas["default"] + schema = processor.schemas["default"]["test_schema"] assert schema.name == "test_schema" assert schema.description == "Test schema" assert len(schema.fields) == 2 @@ -308,10 +308,10 @@ class TestNLPQueryProcessor: } # Act - await processor.on_schema_config(config, "v1") + await processor.on_schema_config("default", config, "v1") # Assert - bad schema should be ignored - assert "bad_schema" not in processor.schemas + assert "bad_schema" not in processor.schemas.get("default", {}) def test_processor_initialization(self, mock_pulsar_client): """Test processor initialization with correct specifications""" diff --git a/tests/unit/test_retrieval/test_structured_diag/test_schema_selection.py b/tests/unit/test_retrieval/test_structured_diag/test_schema_selection.py index 8ce1b97e..45ba9fda 100644 --- a/tests/unit/test_retrieval/test_structured_diag/test_schema_selection.py +++ b/tests/unit/test_retrieval/test_structured_diag/test_schema_selection.py @@ -101,7 +101,7 @@ def service(mock_schemas): taskgroup=MagicMock(), id="test-processor" ) - service.schemas = mock_schemas + service.schemas = {"default": dict(mock_schemas)} return service @@ -109,6 +109,7 @@ def service(mock_schemas): def mock_flow(): """Create mock flow with prompt service""" flow = MagicMock() + flow.workspace = "default" prompt_request_flow = AsyncMock() flow.return_value.request = prompt_request_flow return flow, prompt_request_flow diff --git a/tests/unit/test_retrieval/test_structured_query.py b/tests/unit/test_retrieval/test_structured_query.py index 9a183f45..20056c2a 100644 --- a/tests/unit/test_retrieval/test_structured_query.py +++ b/tests/unit/test_retrieval/test_structured_query.py @@ -44,7 +44,6 @@ class TestStructuredQueryProcessor: # Arrange request = StructuredQueryRequest( question="Show me all customers from New York", - user="trustgraph", collection="default" ) @@ -110,7 +109,6 @@ class TestStructuredQueryProcessor: assert isinstance(objects_call_args, RowsQueryRequest) assert objects_call_args.query == 'query { customers(where: {state: {eq: "NY"}}) { id name email } }' assert objects_call_args.variables == {"state": "NY"} - assert objects_call_args.user == "trustgraph" assert objects_call_args.collection == "default" # Verify response diff --git a/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py b/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py index f9d60541..830da334 100644 --- a/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py @@ -17,7 +17,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Create a mock message for testing""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Create test document embeddings @@ -80,7 +79,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings for a single chunk""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -89,7 +87,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: ) message.chunks = [chunk] - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify insert was called once for the single chunk with its vector processor.vecstore.insert.assert_called_once_with( @@ -99,14 +97,14 @@ class TestMilvusDocEmbeddingsStorageProcessor: @pytest.mark.asyncio async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message): """Test storing document embeddings for multiple chunks""" - await processor.store_document_embeddings(mock_message) + await processor.store_document_embeddings('test_workspace', mock_message) - # Verify insert was called once per chunk with user/collection parameters + # Verify insert was called once per chunk with workspace/collection parameters expected_calls = [ # Chunk 1 - single vector - ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'), + ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "This is the first document chunk", 'test_workspace', 'test_collection'), # Chunk 2 - single vector - ([0.7, 0.8, 0.9], "This is the second document chunk", 'test_user', 'test_collection'), + ([0.7, 0.8, 0.9], "This is the second document chunk", 'test_workspace', 'test_collection'), ] assert processor.vecstore.insert.call_count == 2 @@ -122,7 +120,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings with empty chunk (should be skipped)""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -131,7 +128,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: ) message.chunks = [chunk] - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify no insert was called for empty chunk processor.vecstore.insert.assert_not_called() @@ -141,7 +138,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings with None chunk_id""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -150,7 +146,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: ) message.chunks = [chunk] - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Note: Implementation passes through None chunk_ids (only skips empty string "") processor.vecstore.insert.assert_called_once_with( @@ -162,7 +158,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings with mix of valid and empty chunks""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' valid_chunk = ChunkEmbeddings( @@ -179,7 +174,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: ) message.chunks = [valid_chunk, empty_chunk, another_valid] - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify valid chunks were inserted, empty string chunk was skipped expected_calls = [ @@ -200,11 +195,10 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings with empty chunks list""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' message.chunks = [] - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify no insert was called processor.vecstore.insert.assert_not_called() @@ -214,7 +208,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings for chunk with no vectors""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -223,7 +216,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: ) message.chunks = [chunk] - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify no insert was called (no vectors to insert) processor.vecstore.insert.assert_not_called() @@ -233,7 +226,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings with different vector dimensions""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Each chunk has a single vector of different dimensions @@ -251,7 +243,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: ) message.chunks = [chunk1, chunk2, chunk3] - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify all vectors were inserted regardless of dimension with user/collection parameters expected_calls = [ @@ -273,7 +265,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings with Unicode content in chunk_id""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -282,7 +273,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: ) message.chunks = [chunk] - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify Unicode chunk_id was stored correctly with user/collection parameters processor.vecstore.insert.assert_called_once_with( @@ -294,7 +285,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings with long chunk_id""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Create a long chunk_id @@ -305,7 +295,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: ) message.chunks = [chunk] - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify long chunk_id was inserted with user/collection parameters processor.vecstore.insert.assert_called_once_with( @@ -317,7 +307,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings with whitespace-only chunk""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -326,7 +315,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: ) message.chunks = [chunk] - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify whitespace content was inserted (not filtered out) with user/collection parameters processor.vecstore.insert.assert_called_once_with( @@ -343,25 +332,24 @@ class TestMilvusDocEmbeddingsStorageProcessor: ('test@domain.com', 'test-collection.v1'), ] - for user, collection in test_cases: + for workspace, collection in test_cases: processor.vecstore.reset_mock() # Reset mock for each test case - + message = MagicMock() message.metadata = MagicMock() - message.metadata.user = user message.metadata.collection = collection - + chunk = ChunkEmbeddings( chunk_id="Test content", vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] - - await processor.store_document_embeddings(message) - - # Verify insert was called with the correct user/collection + + await processor.store_document_embeddings(workspace, message) + + # Verify insert was called with the correct workspace/collection processor.vecstore.insert.assert_called_once_with( - [0.1, 0.2, 0.3], "Test content", user, collection + [0.1, 0.2, 0.3], "Test content", workspace, collection ) @pytest.mark.asyncio @@ -370,7 +358,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: # Store embeddings for user1/collection1 message1 = MagicMock() message1.metadata = MagicMock() - message1.metadata.user = 'user1' message1.metadata.collection = 'collection1' chunk1 = ChunkEmbeddings( chunk_id="User1 content", @@ -381,7 +368,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: # Store embeddings for user2/collection2 message2 = MagicMock() message2.metadata = MagicMock() - message2.metadata.user = 'user2' message2.metadata.collection = 'collection2' chunk2 = ChunkEmbeddings( chunk_id="User2 content", @@ -389,8 +375,8 @@ class TestMilvusDocEmbeddingsStorageProcessor: ) message2.chunks = [chunk2] - await processor.store_document_embeddings(message1) - await processor.store_document_embeddings(message2) + await processor.store_document_embeddings('user1', message1) + await processor.store_document_embeddings('user2', message2) # Verify both calls were made with correct parameters expected_calls = [ @@ -411,18 +397,17 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings with special characters in user/collection names""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'user@domain.com' # Email-like user message.metadata.collection = 'test-collection.v1' # Collection with special chars - + chunk = ChunkEmbeddings( chunk_id="Special chars test", vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] - - await processor.store_document_embeddings(message) - - # Verify the exact user/collection strings are passed (sanitization happens in DocVectors) + + await processor.store_document_embeddings('user@domain.com', message) + + # Verify the exact workspace/collection strings are passed (sanitization happens in DocVectors) processor.vecstore.insert.assert_called_once_with( [0.1, 0.2, 0.3], "Special chars test", 'user@domain.com', 'test-collection.v1' ) diff --git a/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py b/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py index fec4f87e..011780ed 100644 --- a/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py @@ -21,7 +21,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Create a mock message for testing""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Create test document embeddings @@ -120,7 +119,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test storing document embeddings for a single chunk""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -135,7 +133,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: processor.pinecone.has_index.return_value = True with patch('uuid.uuid4', side_effect=['id1', 'id2']): - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify index name and operations (with dimension suffix) expected_index_name = "d-test_user-test_collection-3" # 3 dimensions @@ -185,7 +183,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test that writing to non-existent index creates it lazily""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -200,7 +197,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: processor.pinecone.Index.return_value = mock_index with patch('uuid.uuid4', return_value='test-id'): - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify index was created with correct dimension expected_index_name = "d-test_user-test_collection-3" # 3 dimensions @@ -217,7 +214,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test storing document embeddings with empty chunk (should be skipped)""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -229,7 +225,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify no upsert was called for empty chunk mock_index.upsert.assert_not_called() @@ -239,7 +235,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test storing document embeddings with None chunk (should be skipped)""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -251,7 +246,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify no upsert was called for None chunk mock_index.upsert.assert_not_called() @@ -261,7 +256,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test storing document embeddings with chunk that decodes to empty string""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -273,7 +267,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify no upsert was called for empty decoded chunk mock_index.upsert.assert_not_called() @@ -283,7 +277,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test storing document embeddings with different vector dimensions""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Each chunk has a single vector of different dimensions @@ -325,14 +318,13 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test storing document embeddings with empty chunks list""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' message.chunks = [] mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify no operations were performed processor.pinecone.Index.assert_not_called() @@ -343,7 +335,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test storing document embeddings for chunk with no vectors""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -355,7 +346,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify no upsert was called (no vectors to insert) mock_index.upsert.assert_not_called() @@ -365,7 +356,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test that lazy creation happens when index doesn't exist""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -380,7 +370,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: processor.pinecone.Index.return_value = mock_index with patch('uuid.uuid4', return_value='test-id'): - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify index was created processor.pinecone.create_index.assert_called_once() @@ -390,7 +380,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test that lazy creation works correctly""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -405,7 +394,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: processor.pinecone.Index.return_value = mock_index with patch('uuid.uuid4', return_value='test-id'): - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify index was created and used processor.pinecone.create_index.assert_called_once() @@ -416,7 +405,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test storing document embeddings with Unicode content""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -430,7 +418,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: processor.pinecone.has_index.return_value = True with patch('uuid.uuid4', return_value='test-id'): - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify Unicode content was properly decoded and stored call_args = mock_index.upsert.call_args @@ -442,7 +430,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test storing document embeddings with large document chunks""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Create a large document chunk @@ -458,7 +445,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: processor.pinecone.has_index.return_value = True with patch('uuid.uuid4', return_value='test-id'): - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify large content was stored call_args = mock_index.upsert.call_args diff --git a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py index 98d2dab2..ce6e6b3d 100644 --- a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py @@ -84,7 +84,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message with chunks and vectors mock_message = MagicMock() - mock_message.metadata.user = 'test_user' mock_message.metadata.collection = 'test_collection' mock_chunk = MagicMock() @@ -94,7 +93,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.chunks = [mock_chunk] # Act - await processor.store_document_embeddings(mock_message) + await processor.store_document_embeddings('test_user', mock_message) # Assert # Verify collection existence was checked (with dimension suffix) @@ -138,7 +137,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message with multiple chunks mock_message = MagicMock() - mock_message.metadata.user = 'multi_user' mock_message.metadata.collection = 'multi_collection' mock_chunk1 = MagicMock() @@ -152,7 +150,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.chunks = [mock_chunk1, mock_chunk2] # Act - await processor.store_document_embeddings(mock_message) + await processor.store_document_embeddings('multi_user', mock_message) # Assert # Should be called twice (once per chunk) @@ -198,7 +196,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message with multiple chunks, each having a single vector mock_message = MagicMock() - mock_message.metadata.user = 'vector_user' mock_message.metadata.collection = 'vector_collection' mock_chunk1 = MagicMock() @@ -216,7 +213,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.chunks = [mock_chunk1, mock_chunk2, mock_chunk3] # Act - await processor.store_document_embeddings(mock_message) + await processor.store_document_embeddings('vector_user', mock_message) # Assert # Should be called 3 times (once per chunk) @@ -255,7 +252,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message with empty chunk_id mock_message = MagicMock() - mock_message.metadata.user = 'empty_user' mock_message.metadata.collection = 'empty_collection' mock_chunk_empty = MagicMock() @@ -265,7 +261,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.chunks = [mock_chunk_empty] # Act - await processor.store_document_embeddings(mock_message) + await processor.store_document_embeddings('empty_user', mock_message) # Assert # Should not call upsert for empty chunk_ids @@ -298,7 +294,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.metadata.user = 'new_user' mock_message.metadata.collection = 'new_collection' mock_chunk = MagicMock() @@ -308,7 +303,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.chunks = [mock_chunk] # Act - await processor.store_document_embeddings(mock_message) + await processor.store_document_embeddings('new_user', mock_message) # Assert - collection should be lazily created expected_collection = 'd_new_user_new_collection_5' # 5 dimensions @@ -350,7 +345,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.metadata.user = 'error_user' mock_message.metadata.collection = 'error_collection' mock_chunk = MagicMock() @@ -361,7 +355,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Act & Assert - should propagate the creation error with pytest.raises(Exception, match="Connection error"): - await processor.store_document_embeddings(mock_message) + await processor.store_document_embeddings('error_user', mock_message) @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') @@ -388,7 +382,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Create first mock message mock_message1 = MagicMock() - mock_message1.metadata.user = 'cache_user' mock_message1.metadata.collection = 'cache_collection' mock_chunk1 = MagicMock() @@ -398,7 +391,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message1.chunks = [mock_chunk1] # First call - await processor.store_document_embeddings(mock_message1) + await processor.store_document_embeddings('cache_user', mock_message1) # Reset mock to track second call mock_qdrant_instance.reset_mock() @@ -406,7 +399,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Create second mock message with same dimensions mock_message2 = MagicMock() - mock_message2.metadata.user = 'cache_user' mock_message2.metadata.collection = 'cache_collection' mock_chunk2 = MagicMock() @@ -416,7 +408,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message2.chunks = [mock_chunk2] # Act - Second call with same collection - await processor.store_document_embeddings(mock_message2) + await processor.store_document_embeddings('cache_user', mock_message2) # Assert expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions @@ -452,7 +444,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message with chunks of different dimensions mock_message = MagicMock() - mock_message.metadata.user = 'dim_user' mock_message.metadata.collection = 'dim_collection' mock_chunk1 = MagicMock() @@ -466,7 +457,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.chunks = [mock_chunk1, mock_chunk2] # Act - await processor.store_document_embeddings(mock_message) + await processor.store_document_embeddings('dim_user', mock_message) # Assert # Should check existence of DIFFERENT collections for each dimension @@ -526,7 +517,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message with URI-style chunk_id mock_message = MagicMock() - mock_message.metadata.user = 'uri_user' mock_message.metadata.collection = 'uri_collection' mock_chunk = MagicMock() @@ -536,7 +526,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.chunks = [mock_chunk] # Act - await processor.store_document_embeddings(mock_message) + await processor.store_document_embeddings('uri_user', mock_message) # Assert # Verify the chunk_id was stored correctly diff --git a/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py b/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py index e4d60adf..7f3e7469 100644 --- a/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py +++ b/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py @@ -17,7 +17,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor: """Create a mock message for testing""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Create test entities with embeddings @@ -80,7 +79,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor: """Test storing graph embeddings for a single entity""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -89,7 +87,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: ) message.entities = [entity] - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify insert was called once with the full vector processor.vecstore.insert.assert_called_once() @@ -102,14 +100,14 @@ class TestMilvusGraphEmbeddingsStorageProcessor: @pytest.mark.asyncio async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message): """Test storing graph embeddings for multiple entities""" - await processor.store_graph_embeddings(mock_message) + await processor.store_graph_embeddings('test_workspace', mock_message) - # Verify insert was called once per entity with user/collection parameters + # Verify insert was called once per entity with workspace/collection parameters expected_calls = [ # Entity 1 - single vector - ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'), + ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_workspace', 'test_collection'), # Entity 2 - single vector - ([0.7, 0.8, 0.9], 'literal entity', 'test_user', 'test_collection'), + ([0.7, 0.8, 0.9], 'literal entity', 'test_workspace', 'test_collection'), ] assert processor.vecstore.insert.call_count == 2 @@ -125,7 +123,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor: """Test storing graph embeddings with empty entity value (should be skipped)""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -134,7 +131,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: ) message.entities = [entity] - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify no insert was called for empty entity processor.vecstore.insert.assert_not_called() @@ -144,7 +141,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor: """Test storing graph embeddings with None entity value (should be skipped)""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -153,7 +149,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: ) message.entities = [entity] - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify no insert was called for None entity processor.vecstore.insert.assert_not_called() @@ -163,7 +159,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor: """Test storing graph embeddings with mix of valid and invalid entities""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' valid_entity = EntityEmbeddings( @@ -183,7 +178,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: ) message.entities = [valid_entity, empty_entity, none_entity] - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify only valid entity was inserted with user/collection/chunk_id parameters processor.vecstore.insert.assert_called_once_with( @@ -196,11 +191,10 @@ class TestMilvusGraphEmbeddingsStorageProcessor: """Test storing graph embeddings with empty entities list""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' message.entities = [] - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify no insert was called processor.vecstore.insert.assert_not_called() @@ -210,7 +204,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor: """Test storing graph embeddings for entity with no vectors""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -219,7 +212,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: ) message.entities = [entity] - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify no insert was called (no vectors to insert) processor.vecstore.insert.assert_not_called() @@ -229,7 +222,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor: """Test storing graph embeddings with different vector dimensions""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Each entity has a single vector of different dimensions @@ -247,7 +239,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: ) message.entities = [entity1, entity2, entity3] - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify all vectors were inserted regardless of dimension expected_calls = [ @@ -267,7 +259,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor: """Test storing graph embeddings for both URI and literal entities""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' uri_entity = EntityEmbeddings( @@ -280,7 +271,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: ) message.entities = [uri_entity, literal_entity] - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify both entities were inserted expected_calls = [ diff --git a/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py b/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py index 9ff53f4e..e0e5ce26 100644 --- a/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py +++ b/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py @@ -21,7 +21,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor: """Create a mock message for testing""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Create test entity embeddings (each entity has a single vector) @@ -124,7 +123,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor: """Test storing graph embeddings for a single entity""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -139,7 +137,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: processor.pinecone.has_index.return_value = True with patch('uuid.uuid4', side_effect=['id1']): - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify index name and operations (with dimension suffix) expected_index_name = "t-test_user-test_collection-3" # 3 dimensions @@ -189,7 +187,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor: """Test that writing to non-existent index creates it lazily""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -204,7 +201,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: processor.pinecone.Index.return_value = mock_index with patch('uuid.uuid4', return_value='test-id'): - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify index was created with correct dimension expected_index_name = "t-test_user-test_collection-3" # 3 dimensions @@ -221,7 +218,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor: """Test storing graph embeddings with empty entity value (should be skipped)""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -233,7 +229,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify no upsert was called for empty entity mock_index.upsert.assert_not_called() @@ -243,7 +239,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor: """Test storing graph embeddings with None entity value (should be skipped)""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -255,7 +250,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify no upsert was called for None entity mock_index.upsert.assert_not_called() @@ -265,7 +260,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor: """Test storing graph embeddings with different vector dimensions""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Each entity has a single vector of different dimensions @@ -288,7 +282,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: processor.pinecone.has_index.return_value = True with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']): - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify different indexes were used for different dimensions index_calls = processor.pinecone.Index.call_args_list @@ -307,14 +301,13 @@ class TestPineconeGraphEmbeddingsStorageProcessor: """Test storing graph embeddings with empty entities list""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' message.entities = [] mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify no operations were performed processor.pinecone.Index.assert_not_called() @@ -325,7 +318,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor: """Test storing graph embeddings for entity with no vectors""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -337,7 +329,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify no upsert was called (no vectors to insert) mock_index.upsert.assert_not_called() @@ -347,7 +339,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor: """Test that lazy creation happens when index doesn't exist""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -362,7 +353,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: processor.pinecone.Index.return_value = mock_index with patch('uuid.uuid4', return_value='test-id'): - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify index was created processor.pinecone.create_index.assert_called_once() @@ -372,7 +363,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor: """Test that lazy creation works correctly""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -387,7 +377,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: processor.pinecone.Index.return_value = mock_index with patch('uuid.uuid4', return_value='test-id'): - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify index was created and used processor.pinecone.create_index.assert_called_once() diff --git a/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py index 3541ccd4..d636e093 100644 --- a/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py @@ -64,7 +64,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message with entities and vectors mock_message = MagicMock() - mock_message.metadata.user = 'test_user' mock_message.metadata.collection = 'test_collection' mock_entity = MagicMock() @@ -75,7 +74,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.entities = [mock_entity] # Act - await processor.store_graph_embeddings(mock_message) + await processor.store_graph_embeddings('test_user', mock_message) # Assert # Verify collection existence was checked (with dimension suffix) @@ -118,7 +117,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message with multiple entities mock_message = MagicMock() - mock_message.metadata.user = 'multi_user' mock_message.metadata.collection = 'multi_collection' mock_entity1 = MagicMock() @@ -134,7 +132,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.entities = [mock_entity1, mock_entity2] # Act - await processor.store_graph_embeddings(mock_message) + await processor.store_graph_embeddings('multi_user', mock_message) # Assert # Should be called twice (once per entity) @@ -179,7 +177,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message with three entities mock_message = MagicMock() - mock_message.metadata.user = 'vector_user' mock_message.metadata.collection = 'vector_collection' mock_entity1 = MagicMock() @@ -200,7 +197,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.entities = [mock_entity1, mock_entity2, mock_entity3] # Act - await processor.store_graph_embeddings(mock_message) + await processor.store_graph_embeddings('vector_user', mock_message) # Assert # Should be called 3 times (once per entity) @@ -238,7 +235,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message with empty entity value mock_message = MagicMock() - mock_message.metadata.user = 'empty_user' mock_message.metadata.collection = 'empty_collection' mock_entity_empty = MagicMock() @@ -253,7 +249,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.entities = [mock_entity_empty, mock_entity_none] # Act - await processor.store_graph_embeddings(mock_message) + await processor.store_graph_embeddings('empty_user', mock_message) # Assert # Should not call upsert for empty entities diff --git a/tests/unit/test_storage/test_memgraph_user_collection_isolation.py b/tests/unit/test_storage/test_memgraph_workspace_collection_isolation.py similarity index 53% rename from tests/unit/test_storage/test_memgraph_user_collection_isolation.py rename to tests/unit/test_storage/test_memgraph_workspace_collection_isolation.py index 9c330b77..ebc142f3 100644 --- a/tests/unit/test_storage/test_memgraph_user_collection_isolation.py +++ b/tests/unit/test_storage/test_memgraph_workspace_collection_isolation.py @@ -1,5 +1,5 @@ """ -Tests for Memgraph user/collection isolation in storage service +Tests for Memgraph workspace/collection isolation in storage service. """ import pytest @@ -8,47 +8,45 @@ from unittest.mock import MagicMock, patch from trustgraph.storage.triples.memgraph.write import Processor -class TestMemgraphUserCollectionIsolation: - """Test cases for Memgraph storage service with user/collection isolation""" +class TestMemgraphWorkspaceCollectionIsolation: + """Test cases for Memgraph storage service with workspace/collection isolation""" @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') - def test_storage_creates_indexes_with_user_collection(self, mock_graph_db): - """Test that storage creates both legacy and user/collection indexes""" + def test_storage_creates_indexes_with_workspace_collection(self, mock_graph_db): + """Test that storage creates both legacy and workspace/collection indexes""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - + processor = Processor(taskgroup=MagicMock()) - - # Verify all indexes were attempted (4 legacy + 4 user/collection = 8 total) + + # 4 legacy + 4 workspace/collection = 8 total assert mock_session.run.call_count == 8 - - # Check some specific index creation calls + expected_calls = [ "CREATE INDEX ON :Node", "CREATE INDEX ON :Node(uri)", "CREATE INDEX ON :Literal", "CREATE INDEX ON :Literal(value)", - "CREATE INDEX ON :Node(user)", + "CREATE INDEX ON :Node(workspace)", "CREATE INDEX ON :Node(collection)", - "CREATE INDEX ON :Literal(user)", + "CREATE INDEX ON :Literal(workspace)", "CREATE INDEX ON :Literal(collection)" ] - + for expected_call in expected_calls: mock_session.run.assert_any_call(expected_call) @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') @pytest.mark.asyncio - async def test_store_triples_with_user_collection(self, mock_graph_db): - """Test that store_triples includes user/collection in all operations""" + async def test_store_triples_with_workspace_collection(self, mock_graph_db): + """Test that store_triples includes workspace/collection in all operations""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - # Mock execute_query response mock_result = MagicMock() mock_summary = MagicMock() mock_summary.counters.nodes_created = 1 @@ -58,45 +56,39 @@ class TestMemgraphUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) - # Create mock triple with URI object + from trustgraph.schema import IRI triple = MagicMock() - triple.s.value = "http://example.com/subject" - triple.p.value = "http://example.com/predicate" - triple.o.value = "http://example.com/object" - triple.o.is_uri = True + triple.s.type = IRI + triple.s.iri = "http://example.com/subject" + triple.p.type = IRI + triple.p.iri = "http://example.com/predicate" + triple.o.type = IRI + triple.o.iri = "http://example.com/object" - # Create mock message with metadata mock_message = MagicMock() mock_message.triples = [triple] - mock_message.metadata.user = "test_user" mock_message.metadata.collection = "test_collection" - # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples("test_workspace", mock_message) - # Verify user/collection parameters were passed to all operations - # Should have: create_node (subject), create_node (object), relate_node = 3 calls + # create_node (subject), create_node (object), relate_node = 3 calls assert mock_driver.execute_query.call_count == 3 - # Check that user and collection were included in all calls - for call in mock_driver.execute_query.call_args_list: - call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1] - assert 'user' in call_kwargs - assert 'collection' in call_kwargs - assert call_kwargs['user'] == "test_user" - assert call_kwargs['collection'] == "test_collection" + for c in mock_driver.execute_query.call_args_list: + kwargs = c.kwargs + assert kwargs['workspace'] == "test_workspace" + assert kwargs['collection'] == "test_collection" @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') @pytest.mark.asyncio - async def test_store_triples_with_default_user_collection(self, mock_graph_db): - """Test that defaults are used when user/collection not provided in metadata""" + async def test_store_triples_with_default_collection(self, mock_graph_db): + """Test that default collection is used when not provided in metadata""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - # Mock execute_query response mock_result = MagicMock() mock_summary = MagicMock() mock_summary.counters.nodes_created = 1 @@ -106,157 +98,151 @@ class TestMemgraphUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) - # Create mock triple + from trustgraph.schema import IRI, LITERAL triple = MagicMock() - triple.s.value = "http://example.com/subject" - triple.p.value = "http://example.com/predicate" + triple.s.type = IRI + triple.s.iri = "http://example.com/subject" + triple.p.type = IRI + triple.p.iri = "http://example.com/predicate" + triple.o.type = LITERAL triple.o.value = "literal_value" - triple.o.is_uri = False - # Create mock message without user/collection metadata mock_message = MagicMock() mock_message.triples = [triple] - mock_message.metadata.user = None mock_message.metadata.collection = None - # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples("default", mock_message) - # Verify defaults were used - for call in mock_driver.execute_query.call_args_list: - call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1] - assert call_kwargs['user'] == "default" - assert call_kwargs['collection'] == "default" + for c in mock_driver.execute_query.call_args_list: + kwargs = c.kwargs + assert kwargs['workspace'] == "default" + assert kwargs['collection'] == "default" @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') - def test_create_node_includes_user_collection(self, mock_graph_db): - """Test that create_node includes user/collection properties""" + def test_create_node_includes_workspace_collection(self, mock_graph_db): + """Test that create_node includes workspace/collection properties""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - - # Mock execute_query response + mock_result = MagicMock() mock_summary = MagicMock() mock_summary.counters.nodes_created = 1 mock_summary.result_available_after = 10 mock_result.summary = mock_summary mock_driver.execute_query.return_value = mock_result - + processor = Processor(taskgroup=MagicMock()) - - processor.create_node("http://example.com/node", "test_user", "test_collection") - + + processor.create_node("http://example.com/node", "test_workspace", "test_collection") + mock_driver.execute_query.assert_called_with( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", uri="http://example.com/node", - user="test_user", + workspace="test_workspace", collection="test_collection", database_="memgraph" ) @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') - def test_create_literal_includes_user_collection(self, mock_graph_db): - """Test that create_literal includes user/collection properties""" + def test_create_literal_includes_workspace_collection(self, mock_graph_db): + """Test that create_literal includes workspace/collection properties""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - - # Mock execute_query response + mock_result = MagicMock() mock_summary = MagicMock() mock_summary.counters.nodes_created = 1 mock_summary.result_available_after = 10 mock_result.summary = mock_summary mock_driver.execute_query.return_value = mock_result - + processor = Processor(taskgroup=MagicMock()) - - processor.create_literal("test_value", "test_user", "test_collection") - + + processor.create_literal("test_value", "test_workspace", "test_collection") + mock_driver.execute_query.assert_called_with( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", value="test_value", - user="test_user", + workspace="test_workspace", collection="test_collection", database_="memgraph" ) @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') - def test_relate_node_includes_user_collection(self, mock_graph_db): - """Test that relate_node includes user/collection properties""" + def test_relate_node_includes_workspace_collection(self, mock_graph_db): + """Test that relate_node includes workspace/collection properties""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - - # Mock execute_query response + mock_result = MagicMock() mock_summary = MagicMock() mock_summary.counters.nodes_created = 0 mock_summary.result_available_after = 10 mock_result.summary = mock_summary mock_driver.execute_query.return_value = mock_result - + processor = Processor(taskgroup=MagicMock()) - + processor.relate_node( "http://example.com/subject", - "http://example.com/predicate", + "http://example.com/predicate", "http://example.com/object", - "test_user", + "test_workspace", "test_collection" ) - + mock_driver.execute_query.assert_called_with( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", src="http://example.com/subject", dest="http://example.com/object", uri="http://example.com/predicate", - user="test_user", + workspace="test_workspace", collection="test_collection", database_="memgraph" ) @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') - def test_relate_literal_includes_user_collection(self, mock_graph_db): - """Test that relate_literal includes user/collection properties""" + def test_relate_literal_includes_workspace_collection(self, mock_graph_db): + """Test that relate_literal includes workspace/collection properties""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - - # Mock execute_query response + mock_result = MagicMock() mock_summary = MagicMock() mock_summary.counters.nodes_created = 0 mock_summary.result_available_after = 10 mock_result.summary = mock_summary mock_driver.execute_query.return_value = mock_result - + processor = Processor(taskgroup=MagicMock()) - + processor.relate_literal( "http://example.com/subject", "http://example.com/predicate", "literal_value", - "test_user", + "test_workspace", "test_collection" ) - + mock_driver.execute_query.assert_called_with( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", src="http://example.com/subject", dest="literal_value", uri="http://example.com/predicate", - user="test_user", + workspace="test_workspace", collection="test_collection", database_="memgraph" ) @@ -264,20 +250,15 @@ class TestMemgraphUserCollectionIsolation: def test_add_args_includes_memgraph_parameters(self): """Test that add_args properly configures Memgraph-specific parameters""" from argparse import ArgumentParser - from unittest.mock import patch - + parser = ArgumentParser() - - # Mock the parent class add_args method + with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args') as mock_parent_add_args: Processor.add_args(parser) - - # Verify parent add_args was called mock_parent_add_args.assert_called_once() - - # Verify our specific arguments were added with Memgraph defaults + args = parser.parse_args([]) - + assert hasattr(args, 'graph_host') assert args.graph_host == 'bolt://memgraph:7687' assert hasattr(args, 'username') @@ -288,19 +269,18 @@ class TestMemgraphUserCollectionIsolation: assert args.database == 'memgraph' -class TestMemgraphUserCollectionRegression: - """Regression tests to ensure user/collection isolation prevents data leakage""" +class TestMemgraphWorkspaceCollectionRegression: + """Regression tests to ensure workspace/collection isolation prevents data leakage""" @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') @pytest.mark.asyncio - async def test_regression_no_cross_user_data_access(self, mock_graph_db): - """Regression test: Ensure users cannot access each other's data""" + async def test_regression_no_cross_workspace_data_access(self, mock_graph_db): + """Regression test: Ensure workspaces cannot access each other's data""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - # Mock execute_query response mock_result = MagicMock() mock_summary = MagicMock() mock_summary.counters.nodes_created = 1 @@ -310,60 +290,55 @@ class TestMemgraphUserCollectionRegression: processor = Processor(taskgroup=MagicMock()) - # Store data for user1 + from trustgraph.schema import IRI, LITERAL triple = MagicMock() - triple.s.value = "http://example.com/subject" - triple.p.value = "http://example.com/predicate" - triple.o.value = "user1_data" - triple.o.is_uri = False + triple.s.type = IRI + triple.s.iri = "http://example.com/subject" + triple.p.type = IRI + triple.p.iri = "http://example.com/predicate" + triple.o.type = LITERAL + triple.o.value = "ws1_data" - message_user1 = MagicMock() - message_user1.triples = [triple] - message_user1.metadata.user = "user1" - message_user1.metadata.collection = "collection1" + message_ws1 = MagicMock() + message_ws1.triples = [triple] + message_ws1.metadata.collection = "collection1" - # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(message_user1) + await processor.store_triples("workspace1", message_ws1) - # Verify that all storage operations included user1/collection1 parameters - for call in mock_driver.execute_query.call_args_list: - call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1] - if 'user' in call_kwargs: - assert call_kwargs['user'] == "user1" - assert call_kwargs['collection'] == "collection1" + for c in mock_driver.execute_query.call_args_list: + kwargs = c.kwargs + if 'workspace' in kwargs: + assert kwargs['workspace'] == "workspace1" + assert kwargs['collection'] == "collection1" @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') @pytest.mark.asyncio - async def test_regression_same_uri_different_users(self, mock_graph_db): - """Regression test: Same URI can exist for different users without conflict""" + async def test_regression_same_uri_different_workspaces(self, mock_graph_db): + """Regression test: Same URI can exist in different workspaces without conflict""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - - # Mock execute_query response + mock_result = MagicMock() mock_summary = MagicMock() mock_summary.counters.nodes_created = 1 mock_summary.result_available_after = 10 mock_result.summary = mock_summary mock_driver.execute_query.return_value = mock_result - + processor = Processor(taskgroup=MagicMock()) - - # Same URI for different users should create separate nodes - processor.create_node("http://example.com/same-uri", "user1", "collection1") - processor.create_node("http://example.com/same-uri", "user2", "collection2") - - # Verify both calls were made with different user/collection parameters - calls = mock_driver.execute_query.call_args_list[-2:] # Get last 2 calls - - call1_kwargs = calls[0].kwargs if hasattr(calls[0], 'kwargs') else calls[0][1] - call2_kwargs = calls[1].kwargs if hasattr(calls[1], 'kwargs') else calls[1][1] - - assert call1_kwargs['user'] == "user1" and call1_kwargs['collection'] == "collection1" - assert call2_kwargs['user'] == "user2" and call2_kwargs['collection'] == "collection2" - - # Both should have the same URI but different user/collection - assert call1_kwargs['uri'] == call2_kwargs['uri'] == "http://example.com/same-uri" \ No newline at end of file + + processor.create_node("http://example.com/same-uri", "workspace1", "collection1") + processor.create_node("http://example.com/same-uri", "workspace2", "collection2") + + calls = mock_driver.execute_query.call_args_list[-2:] + + k1 = calls[0].kwargs + k2 = calls[1].kwargs + + assert k1['workspace'] == "workspace1" and k1['collection'] == "collection1" + assert k2['workspace'] == "workspace2" and k2['collection'] == "collection2" + + assert k1['uri'] == k2['uri'] == "http://example.com/same-uri" diff --git a/tests/unit/test_storage/test_neo4j_user_collection_isolation.py b/tests/unit/test_storage/test_neo4j_workspace_collection_isolation.py similarity index 51% rename from tests/unit/test_storage/test_neo4j_user_collection_isolation.py rename to tests/unit/test_storage/test_neo4j_workspace_collection_isolation.py index dce170a7..967c144d 100644 --- a/tests/unit/test_storage/test_neo4j_user_collection_isolation.py +++ b/tests/unit/test_storage/test_neo4j_workspace_collection_isolation.py @@ -1,5 +1,5 @@ """ -Tests for Neo4j user/collection isolation in triples storage and query +Tests for Neo4j workspace/collection isolation in triples storage and query. """ import pytest @@ -11,468 +11,406 @@ from trustgraph.schema import Triples, Triple, Term, Metadata, IRI, LITERAL from trustgraph.schema import TriplesQueryRequest -class TestNeo4jUserCollectionIsolation: - """Test cases for Neo4j user/collection isolation functionality""" +class TestNeo4jWorkspaceCollectionIsolation: + """Test cases for Neo4j workspace/collection isolation functionality""" @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') - def test_storage_creates_indexes_with_user_collection(self, mock_graph_db): - """Test that storage service creates compound indexes for user/collection""" + def test_storage_creates_indexes_with_workspace_collection(self, mock_graph_db): + """Test that storage service creates compound indexes for workspace/collection""" taskgroup_mock = MagicMock() mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - + processor = StorageProcessor(taskgroup=taskgroup_mock) - - # Verify both legacy and new compound indexes are created + expected_indexes = [ "CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)", "CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)", "CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)", - "CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)", - "CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)", - "CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)", + "CREATE INDEX node_workspace_collection_uri FOR (n:Node) ON (n.workspace, n.collection, n.uri)", + "CREATE INDEX literal_workspace_collection_value FOR (n:Literal) ON (n.workspace, n.collection, n.value)", + "CREATE INDEX rel_workspace FOR ()-[r:Rel]-() ON (r.workspace)", "CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)" ] - - # Check that all expected indexes were created + for expected_query in expected_indexes: mock_session.run.assert_any_call(expected_query) @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') @pytest.mark.asyncio - async def test_store_triples_with_user_collection(self, mock_graph_db): - """Test that triples are stored with user/collection properties""" + async def test_store_triples_with_workspace_collection(self, mock_graph_db): + """Test that triples are stored with workspace/collection properties""" taskgroup_mock = MagicMock() mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - + processor = StorageProcessor(taskgroup=taskgroup_mock) - - # Create test message with user/collection metadata - metadata = Metadata( - id="test-id", - user="test_user", - collection="test_collection" - ) - + + metadata = Metadata(id="test-id", collection="test_collection") + triple = Triple( s=Term(type=IRI, iri="http://example.com/subject"), p=Term(type=IRI, iri="http://example.com/predicate"), o=Term(type=LITERAL, value="literal_value") ) - - message = Triples( - metadata=metadata, - triples=[triple] - ) - - # Mock execute_query to return summaries + + message = Triples(metadata=metadata, triples=[triple]) + mock_summary = MagicMock() mock_summary.counters.nodes_created = 1 mock_summary.result_available_after = 10 mock_driver.execute_query.return_value.summary = mock_summary - # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(message) - - # Verify nodes and relationships were created with user/collection properties + await processor.store_triples("test_workspace", message) + expected_calls = [ call( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", uri="http://example.com/subject", - user="test_user", + workspace="test_workspace", collection="test_collection", database_='neo4j' ), call( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", value="literal_value", - user="test_user", + workspace="test_workspace", collection="test_collection", database_='neo4j' ), call( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", src="http://example.com/subject", dest="literal_value", uri="http://example.com/predicate", - user="test_user", + workspace="test_workspace", collection="test_collection", database_='neo4j' ) ] - + for expected_call in expected_calls: mock_driver.execute_query.assert_any_call(*expected_call.args, **expected_call.kwargs) @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') @pytest.mark.asyncio - async def test_store_triples_with_default_user_collection(self, mock_graph_db): - """Test that default user/collection are used when not provided""" + async def test_store_triples_with_default_collection(self, mock_graph_db): + """Test that default collection is used when not provided""" taskgroup_mock = MagicMock() mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - + processor = StorageProcessor(taskgroup=taskgroup_mock) - - # Create test message without user/collection + metadata = Metadata(id="test-id") - + triple = Triple( s=Term(type=IRI, iri="http://example.com/subject"), p=Term(type=IRI, iri="http://example.com/predicate"), o=Term(type=IRI, iri="http://example.com/object") ) - - message = Triples( - metadata=metadata, - triples=[triple] - ) - - # Mock execute_query + + message = Triples(metadata=metadata, triples=[triple]) + mock_summary = MagicMock() mock_summary.counters.nodes_created = 1 mock_summary.result_available_after = 10 mock_driver.execute_query.return_value.summary = mock_summary - # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(message) - - # Verify defaults were used + await processor.store_triples("default", message) + mock_driver.execute_query.assert_any_call( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", uri="http://example.com/subject", - user="default", + workspace="default", collection="default", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_query_triples_filters_by_user_collection(self, mock_graph_db): - """Test that query service filters results by user/collection""" + async def test_query_triples_filters_by_workspace_collection(self, mock_graph_db): + """Test that query service filters results by workspace/collection""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver - + processor = QueryProcessor(taskgroup=MagicMock()) - - # Create test query + query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/subject"), p=Term(type=IRI, iri="http://example.com/predicate"), o=None ) - - # Mock query results + mock_records = [ MagicMock(data=lambda: {"dest": "http://example.com/object1"}), MagicMock(data=lambda: {"dest": "literal_value"}) ] - + mock_driver.execute_query.return_value = (mock_records, MagicMock(), MagicMock()) - - result = await processor.query_triples(query) - - # Verify queries include user/collection filters + + await processor.query_triples("test_workspace", query) + expected_literal_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN dest.value as dest" ) - - expected_node_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " - "RETURN dest.uri as dest" - ) - - # Check that queries were executed with user/collection parameters + calls = mock_driver.execute_query.call_args_list assert any( - expected_literal_query in str(call) and - "user='test_user'" in str(call) and - "collection='test_collection'" in str(call) - for call in calls + expected_literal_query in str(c) and + "workspace='test_workspace'" in str(c) and + "collection='test_collection'" in str(c) + for c in calls ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_query_triples_with_default_user_collection(self, mock_graph_db): - """Test that query service uses defaults when user/collection not provided""" + async def test_query_triples_with_default_collection(self, mock_graph_db): + """Test that query service uses default collection when not provided""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver - + processor = QueryProcessor(taskgroup=MagicMock()) - - # Create test query without user/collection - query = TriplesQueryRequest( - s=None, - p=None, - o=None - ) - - # Mock empty results + + query = TriplesQueryRequest(s=None, p=None, o=None) + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - - result = await processor.query_triples(query) - - # Verify defaults were used in queries + + await processor.query_triples("default", query) + calls = mock_driver.execute_query.call_args_list assert any( - "user='default'" in str(call) and "collection='default'" in str(call) - for call in calls + "workspace='default'" in str(c) and "collection='default'" in str(c) + for c in calls ) @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') @pytest.mark.asyncio - async def test_data_isolation_between_users(self, mock_graph_db): - """Test that data from different users is properly isolated""" + async def test_data_isolation_between_workspaces(self, mock_graph_db): + """Test that data from different workspaces is properly isolated""" taskgroup_mock = MagicMock() mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - + processor = StorageProcessor(taskgroup=taskgroup_mock) - - # Create messages for different users - message_user1 = Triples( - metadata=Metadata(user="user1", collection="coll1"), + + message_ws1 = Triples( + metadata=Metadata(collection="coll1"), triples=[ Triple( - s=Term(type=IRI, iri="http://example.com/user1/subject"), + s=Term(type=IRI, iri="http://example.com/ws1/subject"), p=Term(type=IRI, iri="http://example.com/predicate"), - o=Term(type=LITERAL, value="user1_data") + o=Term(type=LITERAL, value="ws1_data") ) ] ) - - message_user2 = Triples( - metadata=Metadata(user="user2", collection="coll2"), + + message_ws2 = Triples( + metadata=Metadata(collection="coll2"), triples=[ Triple( - s=Term(type=IRI, iri="http://example.com/user2/subject"), + s=Term(type=IRI, iri="http://example.com/ws2/subject"), p=Term(type=IRI, iri="http://example.com/predicate"), - o=Term(type=LITERAL, value="user2_data") + o=Term(type=LITERAL, value="ws2_data") ) ] ) - - # Mock execute_query + mock_summary = MagicMock() mock_summary.counters.nodes_created = 1 mock_summary.result_available_after = 10 mock_driver.execute_query.return_value.summary = mock_summary - # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - # Store data for both users - await processor.store_triples(message_user1) - await processor.store_triples(message_user2) - - # Verify user1 data was stored with user1/coll1 + await processor.store_triples("workspace1", message_ws1) + await processor.store_triples("workspace2", message_ws2) + mock_driver.execute_query.assert_any_call( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", - value="user1_data", - user="user1", + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", + value="ws1_data", + workspace="workspace1", collection="coll1", database_='neo4j' ) - - # Verify user2 data was stored with user2/coll2 + mock_driver.execute_query.assert_any_call( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", - value="user2_data", - user="user2", + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", + value="ws2_data", + workspace="workspace2", collection="coll2", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_wildcard_query_respects_user_collection(self, mock_graph_db): - """Test that wildcard queries still filter by user/collection""" + async def test_wildcard_query_respects_workspace_collection(self, mock_graph_db): + """Test that wildcard queries still filter by workspace/collection""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver - + processor = QueryProcessor(taskgroup=MagicMock()) - - # Create wildcard query (all nulls) with user/collection + query = TriplesQueryRequest( - user="test_user", collection="test_collection", - s=None, - p=None, - o=None + s=None, p=None, o=None, ) - - # Mock results + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - - result = await processor.query_triples(query) - - # Verify wildcard queries include user/collection filters + + await processor.query_triples("test_workspace", query) + wildcard_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.value as dest" ) - + calls = mock_driver.execute_query.call_args_list assert any( - wildcard_query in str(call) and - "user='test_user'" in str(call) and - "collection='test_collection'" in str(call) - for call in calls + wildcard_query in str(c) and + "workspace='test_workspace'" in str(c) and + "collection='test_collection'" in str(c) + for c in calls ) def test_add_args_includes_neo4j_parameters(self): """Test that add_args includes Neo4j-specific parameters""" from argparse import ArgumentParser - from unittest.mock import patch - + parser = ArgumentParser() - + with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args'): StorageProcessor.add_args(parser) - + args = parser.parse_args([]) - + assert hasattr(args, 'graph_host') assert hasattr(args, 'username') assert hasattr(args, 'password') assert hasattr(args, 'database') - - # Check defaults + assert args.graph_host == 'bolt://neo4j:7687' assert args.username == 'neo4j' assert args.password == 'password' assert args.database == 'neo4j' -class TestNeo4jUserCollectionRegression: - """Regression tests to ensure user/collection isolation prevents data leaks""" - +class TestNeo4jWorkspaceCollectionRegression: + """Regression tests to ensure workspace/collection isolation prevents data leaks""" + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') - @pytest.mark.asyncio - async def test_regression_no_cross_user_data_access(self, mock_graph_db): + @pytest.mark.asyncio + async def test_regression_no_cross_workspace_data_access(self, mock_graph_db): """ - Regression test: Ensure user1 cannot access user2's data - - This test guards against the bug where all users shared the same - Neo4j graph space, causing data contamination between users. + Regression test: Ensure workspace1 cannot access workspace2's data. + + Guards against a bug where all data shared the same Neo4j graph + space, causing data contamination between workspaces. """ mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver - + processor = QueryProcessor(taskgroup=MagicMock()) - - # User1 queries for all triples - query_user1 = TriplesQueryRequest( - user="user1", + + query_ws1 = TriplesQueryRequest( collection="collection1", s=None, p=None, o=None ) - - # Mock that the database has data but none matching user1/collection1 + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - - result = await processor.query_triples(query_user1) - - # Verify empty results (user1 cannot see other users' data) + + result = await processor.query_triples("workspace1", query_ws1) + assert len(result) == 0 - - # Verify the query included user/collection filters + calls = mock_driver.execute_query.call_args_list - for call in calls: - query_str = str(call) + for c in calls: + query_str = str(c) if "MATCH" in query_str: - assert "user: $user" in query_str or "user='user1'" in query_str + assert "workspace: $workspace" in query_str or "workspace='workspace1'" in query_str assert "collection: $collection" in query_str or "collection='collection1'" in query_str - + @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') @pytest.mark.asyncio - async def test_regression_same_uri_different_users(self, mock_graph_db): + async def test_regression_same_uri_different_workspaces(self, mock_graph_db): """ - Regression test: Same URI in different user contexts should create separate nodes - - This ensures that http://example.com/entity for user1 is completely separate - from http://example.com/entity for user2. + Regression test: Same URI in different workspace contexts should create separate nodes. + + Ensures http://example.com/entity in workspace1 is completely + separate from the same URI in workspace2. """ taskgroup_mock = MagicMock() mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - + processor = StorageProcessor(taskgroup=taskgroup_mock) - - # Same URI for different users + shared_uri = "http://example.com/shared_entity" - - message_user1 = Triples( - metadata=Metadata(user="user1", collection="coll1"), + + message_ws1 = Triples( + metadata=Metadata(collection="coll1"), triples=[ Triple( s=Term(type=IRI, iri=shared_uri), p=Term(type=IRI, iri="http://example.com/p"), - o=Term(type=LITERAL, value="user1_value") + o=Term(type=LITERAL, value="ws1_value") ) ] ) - - message_user2 = Triples( - metadata=Metadata(user="user2", collection="coll2"), + + message_ws2 = Triples( + metadata=Metadata(collection="coll2"), triples=[ Triple( s=Term(type=IRI, iri=shared_uri), p=Term(type=IRI, iri="http://example.com/p"), - o=Term(type=LITERAL, value="user2_value") + o=Term(type=LITERAL, value="ws2_value") ) ] ) - - # Mock execute_query + mock_summary = MagicMock() mock_summary.counters.nodes_created = 1 mock_summary.result_available_after = 10 mock_driver.execute_query.return_value.summary = mock_summary - # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(message_user1) - await processor.store_triples(message_user2) - - # Verify two separate nodes were created with same URI but different user/collection - user1_node_call = call( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + await processor.store_triples("workspace1", message_ws1) + await processor.store_triples("workspace2", message_ws2) + + ws1_node_call = call( + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", uri=shared_uri, - user="user1", + workspace="workspace1", collection="coll1", database_='neo4j' ) - - user2_node_call = call( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + + ws2_node_call = call( + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", uri=shared_uri, - user="user2", + workspace="workspace2", collection="coll2", database_='neo4j' ) - - mock_driver.execute_query.assert_has_calls([user1_node_call, user2_node_call], any_order=True) \ No newline at end of file + + mock_driver.execute_query.assert_has_calls([ws1_node_call, ws2_node_call], any_order=True) diff --git a/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py index e1c8f3b1..8754f47c 100644 --- a/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py @@ -1,3 +1,12 @@ + +def _flow_mock(workspace): + """Build a mock flow object that is callable and exposes .workspace.""" + from unittest.mock import MagicMock + f = MagicMock() + f.workspace = workspace + return f + + """ Unit tests for trustgraph.storage.row_embeddings.qdrant.write Tests the Stage 2 processor that stores pre-computed row embeddings in Qdrant. @@ -92,13 +101,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) collection_name = processor.get_collection_name( - user="test_user", + workspace="test_workspace", collection="test_collection", schema_name="customer_data", dimension=384 ) - assert collection_name == "rows_test_user_test_collection_customer_data_384" + assert collection_name == "rows_test_workspace_test_collection_customer_data_384" @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') async def test_ensure_collection_creates_new(self, mock_qdrant_client): @@ -185,11 +194,10 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - processor.known_collections[('test_user', 'test_collection')] = {} + processor.known_collections[('test_workspace', 'test_collection')] = {} # Create embeddings message metadata = MagicMock() - metadata.user = 'test_user' metadata.collection = 'test_collection' metadata.id = 'doc-123' @@ -210,14 +218,14 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): mock_msg = MagicMock() mock_msg.value.return_value = embeddings_msg - await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) + await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace')) # Verify upsert was called mock_qdrant_instance.upsert.assert_called_once() # Verify upsert parameters upsert_call_args = mock_qdrant_instance.upsert.call_args - assert upsert_call_args[1]['collection_name'] == 'rows_test_user_test_collection_customers_3' + assert upsert_call_args[1]['collection_name'] == 'rows_test_workspace_test_collection_customers_3' point = upsert_call_args[1]['points'][0] assert point.vector == [0.1, 0.2, 0.3] @@ -243,10 +251,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - processor.known_collections[('test_user', 'test_collection')] = {} + processor.known_collections[('test_workspace', 'test_collection')] = {} metadata = MagicMock() - metadata.user = 'test_user' metadata.collection = 'test_collection' metadata.id = 'doc-123' @@ -267,7 +274,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): mock_msg = MagicMock() mock_msg.value.return_value = embeddings_msg - await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) + await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace')) # Should be called once for the single embedding assert mock_qdrant_instance.upsert.call_count == 1 @@ -287,10 +294,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - processor.known_collections[('test_user', 'test_collection')] = {} + processor.known_collections[('test_workspace', 'test_collection')] = {} metadata = MagicMock() - metadata.user = 'test_user' metadata.collection = 'test_collection' metadata.id = 'doc-123' @@ -311,7 +317,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): mock_msg = MagicMock() mock_msg.value.return_value = embeddings_msg - await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) + await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace')) # Should not call upsert for empty vectors mock_qdrant_instance.upsert.assert_not_called() @@ -334,7 +340,6 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): # No collections registered metadata = MagicMock() - metadata.user = 'unknown_user' metadata.collection = 'unknown_collection' metadata.id = 'doc-123' @@ -354,7 +359,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): mock_msg = MagicMock() mock_msg.value.return_value = embeddings_msg - await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) + await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace')) # Should not call upsert for unknown collection mock_qdrant_instance.upsert.assert_not_called() @@ -368,11 +373,11 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): # Mock collections list mock_coll1 = MagicMock() - mock_coll1.name = 'rows_test_user_test_collection_schema1_384' + mock_coll1.name = 'rows_test_workspace_test_collection_schema1_384' mock_coll2 = MagicMock() - mock_coll2.name = 'rows_test_user_test_collection_schema2_384' + mock_coll2.name = 'rows_test_workspace_test_collection_schema2_384' mock_coll3 = MagicMock() - mock_coll3.name = 'rows_other_user_other_collection_schema_384' + mock_coll3.name = 'rows_other_workspace_other_collection_schema_384' mock_collections = MagicMock() mock_collections.collections = [mock_coll1, mock_coll2, mock_coll3] @@ -386,15 +391,15 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - processor.created_collections.add('rows_test_user_test_collection_schema1_384') + processor.created_collections.add('rows_test_workspace_test_collection_schema1_384') - await processor.delete_collection('test_user', 'test_collection') + await processor.delete_collection('test_workspace', 'test_collection') # Should delete only the matching collections assert mock_qdrant_instance.delete_collection.call_count == 2 # Verify the cached collection was removed - assert 'rows_test_user_test_collection_schema1_384' not in processor.created_collections + assert 'rows_test_workspace_test_collection_schema1_384' not in processor.created_collections @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') async def test_delete_collection_schema(self, mock_qdrant_client): @@ -404,9 +409,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): mock_qdrant_instance = MagicMock() mock_coll1 = MagicMock() - mock_coll1.name = 'rows_test_user_test_collection_customers_384' + mock_coll1.name = 'rows_test_workspace_test_collection_customers_384' mock_coll2 = MagicMock() - mock_coll2.name = 'rows_test_user_test_collection_orders_384' + mock_coll2.name = 'rows_test_workspace_test_collection_orders_384' mock_collections = MagicMock() mock_collections.collections = [mock_coll1, mock_coll2] @@ -422,13 +427,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) await processor.delete_collection_schema( - 'test_user', 'test_collection', 'customers' + 'test_workspace', 'test_collection', 'customers' ) # Should only delete the customers schema collection mock_qdrant_instance.delete_collection.assert_called_once() call_args = mock_qdrant_instance.delete_collection.call_args[0] - assert call_args[0] == 'rows_test_user_test_collection_customers_384' + assert call_args[0] == 'rows_test_workspace_test_collection_customers_384' if __name__ == '__main__': diff --git a/tests/unit/test_storage/test_rows_cassandra_storage.py b/tests/unit/test_storage/test_rows_cassandra_storage.py index ccf193aa..852f01a1 100644 --- a/tests/unit/test_storage/test_rows_cassandra_storage.py +++ b/tests/unit/test_storage/test_rows_cassandra_storage.py @@ -17,6 +17,17 @@ from trustgraph.storage.rows.cassandra.write import Processor from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field + + +class _MockFlowDefault: + """Mock Flow with default workspace for testing.""" + workspace = "default" + name = "default" + id = "test-processor" + + +mock_flow_default = _MockFlowDefault() + class TestRowsCassandraStorageLogic: """Test business logic for unified table implementation""" @@ -145,11 +156,11 @@ class TestRowsCassandraStorageLogic: } # Process configuration - await processor.on_schema_config(config, version=1) + await processor.on_schema_config("default", config, version=1) # Verify schema was loaded - assert "customer_records" in processor.schemas - schema = processor.schemas["customer_records"] + assert "customer_records" in processor.schemas["default"] + schema = processor.schemas["default"]["customer_records"] assert schema.name == "customer_records" assert len(schema.fields) == 3 @@ -165,16 +176,18 @@ class TestRowsCassandraStorageLogic: """Test that row processing stores data as map""" processor = MagicMock() processor.schemas = { - "test_schema": RowSchema( - name="test_schema", - description="Test", - fields=[ - Field(name="id", type="string", size=50, primary=True), - Field(name="value", type="string", size=100) - ] - ) + "default": { + "test_schema": RowSchema( + name="test_schema", + description="Test", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="value", type="string", size=100) + ] + ) + } } - processor.tables_initialized = {"test_user"} + processor.tables_initialized = {"default"} processor.registered_partitions = set() processor.session = MagicMock() processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) @@ -191,7 +204,6 @@ class TestRowsCassandraStorageLogic: test_obj = ExtractedObject( metadata=Metadata( id="test-001", - user="test_user", collection="test_collection", ), schema_name="test_schema", @@ -205,7 +217,7 @@ class TestRowsCassandraStorageLogic: msg.value.return_value = test_obj # Process object - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Verify insert was executed mock_async_execute.assert_called() @@ -214,7 +226,7 @@ class TestRowsCassandraStorageLogic: values = insert_call[0][2] # Verify using unified rows table - assert "INSERT INTO test_user.rows" in insert_cql + assert "INSERT INTO default.rows" in insert_cql # Values should be: (collection, schema_name, index_name, index_value, data, source) assert values[0] == "test_collection" # collection @@ -230,16 +242,18 @@ class TestRowsCassandraStorageLogic: """Test that row is written once per indexed field""" processor = MagicMock() processor.schemas = { - "multi_index_schema": RowSchema( - name="multi_index_schema", - fields=[ - Field(name="id", type="string", primary=True), - Field(name="category", type="string", indexed=True), - Field(name="status", type="string", indexed=True) - ] - ) + "default": { + "multi_index_schema": RowSchema( + name="multi_index_schema", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="category", type="string", indexed=True), + Field(name="status", type="string", indexed=True) + ] + ) + } } - processor.tables_initialized = {"test_user"} + processor.tables_initialized = {"default"} processor.registered_partitions = set() processor.session = MagicMock() processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) @@ -255,7 +269,6 @@ class TestRowsCassandraStorageLogic: test_obj = ExtractedObject( metadata=Metadata( id="test-001", - user="test_user", collection="test_collection", ), schema_name="multi_index_schema", @@ -267,7 +280,7 @@ class TestRowsCassandraStorageLogic: msg = MagicMock() msg.value.return_value = test_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Should have 3 inserts (one per indexed field: id, category, status) assert mock_async_execute.call_count == 3 @@ -290,15 +303,17 @@ class TestRowsCassandraStorageBatchLogic: """Test processing of batch ExtractedObjects""" processor = MagicMock() processor.schemas = { - "batch_schema": RowSchema( - name="batch_schema", - fields=[ - Field(name="id", type="string", primary=True), - Field(name="name", type="string") - ] - ) + "default": { + "batch_schema": RowSchema( + name="batch_schema", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="name", type="string") + ] + ) + } } - processor.tables_initialized = {"test_user"} + processor.tables_initialized = {"default"} processor.registered_partitions = set() processor.session = MagicMock() processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) @@ -315,7 +330,6 @@ class TestRowsCassandraStorageBatchLogic: batch_obj = ExtractedObject( metadata=Metadata( id="batch-001", - user="test_user", collection="batch_collection", ), schema_name="batch_schema", @@ -331,7 +345,7 @@ class TestRowsCassandraStorageBatchLogic: msg = MagicMock() msg.value.return_value = batch_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Should have 3 inserts (one per row, one index per row since only primary key) assert mock_async_execute.call_count == 3 @@ -349,12 +363,14 @@ class TestRowsCassandraStorageBatchLogic: """Test processing of empty batch ExtractedObjects""" processor = MagicMock() processor.schemas = { - "empty_schema": RowSchema( - name="empty_schema", - fields=[Field(name="id", type="string", primary=True)] - ) + "default": { + "empty_schema": RowSchema( + name="empty_schema", + fields=[Field(name="id", type="string", primary=True)] + ) + } } - processor.tables_initialized = {"test_user"} + processor.tables_initialized = {"default"} processor.registered_partitions = set() processor.session = MagicMock() processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) @@ -369,7 +385,6 @@ class TestRowsCassandraStorageBatchLogic: empty_batch_obj = ExtractedObject( metadata=Metadata( id="empty-001", - user="test_user", collection="empty_collection", ), schema_name="empty_schema", @@ -381,7 +396,7 @@ class TestRowsCassandraStorageBatchLogic: msg = MagicMock() msg.value.return_value = empty_batch_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Verify no insert calls for empty batch processor.session.execute.assert_not_called() @@ -446,19 +461,21 @@ class TestPartitionRegistration: processor.registered_partitions = set() processor.session = MagicMock() processor.schemas = { - "test_schema": RowSchema( - name="test_schema", - fields=[ - Field(name="id", type="string", primary=True), - Field(name="category", type="string", indexed=True) - ] - ) + "default": { + "test_schema": RowSchema( + name="test_schema", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="category", type="string", indexed=True) + ] + ) + } } processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) processor.register_partitions = Processor.register_partitions.__get__(processor, Processor) - processor.register_partitions("test_user", "test_collection", "test_schema") + processor.register_partitions("test_user", "test_collection", "test_schema", "default") # Should have 2 inserts (one per index: id, category) assert processor.session.execute.call_count == 2 @@ -473,7 +490,7 @@ class TestPartitionRegistration: processor.session = MagicMock() processor.register_partitions = Processor.register_partitions.__get__(processor, Processor) - processor.register_partitions("test_user", "test_collection", "test_schema") + processor.register_partitions("test_user", "test_collection", "test_schema", "default") # Should not execute any CQL since already registered processor.session.execute.assert_not_called() diff --git a/tests/unit/test_storage/test_triples_cassandra_storage.py b/tests/unit/test_storage/test_triples_cassandra_storage.py index 73272942..04acbb16 100644 --- a/tests/unit/test_storage/test_triples_cassandra_storage.py +++ b/tests/unit/test_storage/test_triples_cassandra_storage.py @@ -102,11 +102,10 @@ class TestCassandraStorageProcessor: # Create mock message mock_message = MagicMock() - mock_message.metadata.user = 'user1' mock_message.metadata.collection = 'collection1' mock_message.triples = [] - await processor.store_triples(mock_message) + await processor.store_triples('user1', mock_message) # Verify KnowledgeGraph was called with auth parameters mock_kg_class.assert_called_once_with( @@ -129,11 +128,10 @@ class TestCassandraStorageProcessor: # Create mock message mock_message = MagicMock() - mock_message.metadata.user = 'user2' mock_message.metadata.collection = 'collection2' mock_message.triples = [] - await processor.store_triples(mock_message) + await processor.store_triples('user2', mock_message) # Verify KnowledgeGraph was called without auth parameters mock_kg_class.assert_called_once_with( @@ -154,16 +152,15 @@ class TestCassandraStorageProcessor: # Create mock message mock_message = MagicMock() - mock_message.metadata.user = 'user1' mock_message.metadata.collection = 'collection1' mock_message.triples = [] # First call should create TrustGraph - await processor.store_triples(mock_message) + await processor.store_triples('user1', mock_message) assert mock_kg_class.call_count == 1 # Second call with same table should reuse TrustGraph - await processor.store_triples(mock_message) + await processor.store_triples('user1', mock_message) assert mock_kg_class.call_count == 1 # Should not increase @pytest.mark.asyncio @@ -205,11 +202,10 @@ class TestCassandraStorageProcessor: # Create mock message mock_message = MagicMock() - mock_message.metadata.user = 'user1' mock_message.metadata.collection = 'collection1' mock_message.triples = [triple1, triple2] - await processor.store_triples(mock_message) + await processor.store_triples('user1', mock_message) # Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters) assert mock_tg_instance.insert.call_count == 2 @@ -234,11 +230,10 @@ class TestCassandraStorageProcessor: # Create mock message with empty triples mock_message = MagicMock() - mock_message.metadata.user = 'user1' mock_message.metadata.collection = 'collection1' mock_message.triples = [] - await processor.store_triples(mock_message) + await processor.store_triples('user1', mock_message) # Verify no triples were inserted mock_tg_instance.insert.assert_not_called() @@ -255,12 +250,11 @@ class TestCassandraStorageProcessor: # Create mock message mock_message = MagicMock() - mock_message.metadata.user = 'user1' mock_message.metadata.collection = 'collection1' mock_message.triples = [] with pytest.raises(Exception, match="Connection failed"): - await processor.store_triples(mock_message) + await processor.store_triples('user1', mock_message) # Verify sleep was called before re-raising mock_sleep.assert_called_once_with(1) @@ -361,21 +355,19 @@ class TestCassandraStorageProcessor: # First message with table1 mock_message1 = MagicMock() - mock_message1.metadata.user = 'user1' mock_message1.metadata.collection = 'collection1' mock_message1.triples = [] - await processor.store_triples(mock_message1) + await processor.store_triples('user1', mock_message1) assert processor.table == 'user1' assert processor.tg == mock_tg_instance1 # Second message with different table mock_message2 = MagicMock() - mock_message2.metadata.user = 'user2' mock_message2.metadata.collection = 'collection2' mock_message2.triples = [] - await processor.store_triples(mock_message2) + await processor.store_triples('user2', mock_message2) assert processor.table == 'user2' assert processor.tg == mock_tg_instance2 @@ -407,11 +399,10 @@ class TestCassandraStorageProcessor: triple.g = None mock_message = MagicMock() - mock_message.metadata.user = 'test_user' mock_message.metadata.collection = 'test_collection' mock_message.triples = [triple] - await processor.store_triples(mock_message) + await processor.store_triples('test_workspace', mock_message) # Verify the triple was inserted with special characters preserved mock_tg_instance.insert.assert_called_once_with( @@ -440,12 +431,11 @@ class TestCassandraStorageProcessor: mock_kg_class.side_effect = Exception("Connection failed") mock_message = MagicMock() - mock_message.metadata.user = 'new_user' mock_message.metadata.collection = 'new_collection' mock_message.triples = [] with pytest.raises(Exception, match="Connection failed"): - await processor.store_triples(mock_message) + await processor.store_triples('new_user', mock_message) # Table should remain unchanged since self.table = table happens after try/except assert processor.table == ('old_user', 'old_collection') @@ -468,11 +458,10 @@ class TestCassandraPerformanceOptimizations: processor = Processor(taskgroup=taskgroup_mock) mock_message = MagicMock() - mock_message.metadata.user = 'user1' mock_message.metadata.collection = 'collection1' mock_message.triples = [] - await processor.store_triples(mock_message) + await processor.store_triples('user1', mock_message) # Verify KnowledgeGraph instance uses legacy mode assert mock_tg_instance is not None @@ -489,11 +478,10 @@ class TestCassandraPerformanceOptimizations: processor = Processor(taskgroup=taskgroup_mock) mock_message = MagicMock() - mock_message.metadata.user = 'user1' mock_message.metadata.collection = 'collection1' mock_message.triples = [] - await processor.store_triples(mock_message) + await processor.store_triples('user1', mock_message) # Verify KnowledgeGraph instance is in optimized mode assert mock_tg_instance is not None @@ -523,11 +511,10 @@ class TestCassandraPerformanceOptimizations: triple.g = None mock_message = MagicMock() - mock_message.metadata.user = 'user1' mock_message.metadata.collection = 'collection1' mock_message.triples = [triple] - await processor.store_triples(mock_message) + await processor.store_triples('user1', mock_message) # Verify insert was called for the triple (implementation details tested in KnowledgeGraph) mock_tg_instance.insert.assert_called_once_with( diff --git a/tests/unit/test_storage/test_triples_falkordb_storage.py b/tests/unit/test_storage/test_triples_falkordb_storage.py index 05dcb2e5..c5b0848e 100644 --- a/tests/unit/test_storage/test_triples_falkordb_storage.py +++ b/tests/unit/test_storage/test_triples_falkordb_storage.py @@ -17,7 +17,6 @@ class TestFalkorDBStorageProcessor: """Create a mock message for testing""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Create a test triple @@ -89,13 +88,13 @@ class TestFalkorDBStorageProcessor: processor.io.query.return_value = mock_result - processor.create_node(test_uri, 'test_user', 'test_collection') + processor.create_node(test_uri, 'test_workspace', 'test_collection') processor.io.query.assert_called_once_with( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", params={ "uri": test_uri, - "user": 'test_user', + "workspace": 'test_workspace', "collection": 'test_collection', }, ) @@ -109,13 +108,13 @@ class TestFalkorDBStorageProcessor: processor.io.query.return_value = mock_result - processor.create_literal(test_value, 'test_user', 'test_collection') + processor.create_literal(test_value, 'test_workspace', 'test_collection') processor.io.query.assert_called_once_with( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", params={ "value": test_value, - "user": 'test_user', + "workspace": 'test_workspace', "collection": 'test_collection', }, ) @@ -132,17 +131,17 @@ class TestFalkorDBStorageProcessor: processor.io.query.return_value = mock_result - processor.relate_node(src_uri, pred_uri, dest_uri, 'test_user', 'test_collection') + processor.relate_node(src_uri, pred_uri, dest_uri, 'test_workspace', 'test_collection') processor.io.query.assert_called_once_with( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", params={ "src": src_uri, "dest": dest_uri, "uri": pred_uri, - "user": 'test_user', + "workspace": 'test_workspace', "collection": 'test_collection', }, ) @@ -159,17 +158,17 @@ class TestFalkorDBStorageProcessor: processor.io.query.return_value = mock_result - processor.relate_literal(src_uri, pred_uri, literal_value, 'test_user', 'test_collection') + processor.relate_literal(src_uri, pred_uri, literal_value, 'test_workspace', 'test_collection') processor.io.query.assert_called_once_with( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", params={ "src": src_uri, "dest": literal_value, "uri": pred_uri, - "user": 'test_user', + "workspace": 'test_workspace', "collection": 'test_collection', }, ) @@ -179,7 +178,6 @@ class TestFalkorDBStorageProcessor: """Test storing triple with URI object""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' triple = Triple( @@ -200,21 +198,21 @@ class TestFalkorDBStorageProcessor: with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(message) + await processor.store_triples('test_workspace', message) # Verify queries were called in the correct order expected_calls = [ # Create subject node - (("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",), - {"params": {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection"}}), + (("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",), + {"params": {"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection"}}), # Create object node - (("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",), - {"params": {"uri": "http://example.com/object", "user": "test_user", "collection": "test_collection"}}), + (("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",), + {"params": {"uri": "http://example.com/object", "workspace": "test_workspace", "collection": "test_collection"}}), # Create relationship - (("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",), - {"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate", "user": "test_user", "collection": "test_collection"}}), + (("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",), + {"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate", "workspace": "test_workspace", "collection": "test_collection"}}), ] assert processor.io.query.call_count == 3 @@ -237,21 +235,21 @@ class TestFalkorDBStorageProcessor: with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(mock_message) - + await processor.store_triples('test_workspace', mock_message) + # Verify queries were called in the correct order expected_calls = [ # Create subject node - (("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",), - {"params": {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection"}}), + (("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",), + {"params": {"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection"}}), # Create literal object - (("MERGE (n:Literal {value: $value, user: $user, collection: $collection})",), - {"params": {"value": "literal object", "user": "test_user", "collection": "test_collection"}}), + (("MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",), + {"params": {"value": "literal object", "workspace": "test_workspace", "collection": "test_collection"}}), # Create relationship - (("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",), - {"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate", "user": "test_user", "collection": "test_collection"}}), + (("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",), + {"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate", "workspace": "test_workspace", "collection": "test_collection"}}), ] assert processor.io.query.call_count == 3 @@ -265,7 +263,6 @@ class TestFalkorDBStorageProcessor: """Test storing multiple triples""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' triple1 = Triple( @@ -291,7 +288,7 @@ class TestFalkorDBStorageProcessor: with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(message) + await processor.store_triples('test_workspace', message) # Verify total number of queries (3 per triple) assert processor.io.query.call_count == 6 @@ -313,7 +310,6 @@ class TestFalkorDBStorageProcessor: """Test storing empty triples list""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' message.triples = [] @@ -323,7 +319,7 @@ class TestFalkorDBStorageProcessor: with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(message) + await processor.store_triples('test_workspace', message) # Verify no queries were made processor.io.query.assert_not_called() @@ -333,7 +329,6 @@ class TestFalkorDBStorageProcessor: """Test storing triples with mixed URI and literal objects""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' triple1 = Triple( @@ -359,7 +354,7 @@ class TestFalkorDBStorageProcessor: with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(message) + await processor.store_triples('test_workspace', message) # Verify total number of queries (3 per triple) assert processor.io.query.call_count == 6 @@ -450,13 +445,13 @@ class TestFalkorDBStorageProcessor: processor.io.query.return_value = mock_result - processor.create_node(test_uri, 'test_user', 'test_collection') + processor.create_node(test_uri, 'test_workspace', 'test_collection') processor.io.query.assert_called_once_with( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", params={ "uri": test_uri, - "user": 'test_user', + "workspace": 'test_workspace', "collection": 'test_collection', }, ) @@ -470,13 +465,13 @@ class TestFalkorDBStorageProcessor: processor.io.query.return_value = mock_result - processor.create_literal(test_value, 'test_user', 'test_collection') + processor.create_literal(test_value, 'test_workspace', 'test_collection') processor.io.query.assert_called_once_with( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", params={ "value": test_value, - "user": 'test_user', + "workspace": 'test_workspace', "collection": 'test_collection', }, ) \ No newline at end of file diff --git a/tests/unit/test_storage/test_triples_memgraph_storage.py b/tests/unit/test_storage/test_triples_memgraph_storage.py index 162586d5..6a0c68a2 100644 --- a/tests/unit/test_storage/test_triples_memgraph_storage.py +++ b/tests/unit/test_storage/test_triples_memgraph_storage.py @@ -17,7 +17,6 @@ class TestMemgraphStorageProcessor: """Create a mock message for testing""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Create a test triple @@ -43,7 +42,7 @@ class TestMemgraphStorageProcessor: taskgroup=MagicMock(), id='test-memgraph-storage', graph_host='bolt://localhost:7687', - username='test_user', + username='test_workspace', password='test_pass', database='test_db' ) @@ -105,9 +104,9 @@ class TestMemgraphStorageProcessor: "CREATE INDEX ON :Node(uri)", "CREATE INDEX ON :Literal", "CREATE INDEX ON :Literal(value)", - "CREATE INDEX ON :Node(user)", + "CREATE INDEX ON :Node(workspace)", "CREATE INDEX ON :Node(collection)", - "CREATE INDEX ON :Literal(user)", + "CREATE INDEX ON :Literal(workspace)", "CREATE INDEX ON :Literal(collection)" ] @@ -145,12 +144,12 @@ class TestMemgraphStorageProcessor: processor.io.execute_query.return_value = mock_result - processor.create_node(test_uri, "test_user", "test_collection") + processor.create_node(test_uri, "test_workspace", "test_collection") processor.io.execute_query.assert_called_once_with( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", uri=test_uri, - user="test_user", + workspace="test_workspace", collection="test_collection", database_=processor.db ) @@ -166,12 +165,12 @@ class TestMemgraphStorageProcessor: processor.io.execute_query.return_value = mock_result - processor.create_literal(test_value, "test_user", "test_collection") + processor.create_literal(test_value, "test_workspace", "test_collection") processor.io.execute_query.assert_called_once_with( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", value=test_value, - user="test_user", + workspace="test_workspace", collection="test_collection", database_=processor.db ) @@ -190,14 +189,14 @@ class TestMemgraphStorageProcessor: processor.io.execute_query.return_value = mock_result - processor.relate_node(src_uri, pred_uri, dest_uri, "test_user", "test_collection") + processor.relate_node(src_uri, pred_uri, dest_uri, "test_workspace", "test_collection") processor.io.execute_query.assert_called_once_with( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", src=src_uri, dest=dest_uri, uri=pred_uri, - user="test_user", collection="test_collection", + workspace="test_workspace", collection="test_collection", database_=processor.db ) @@ -215,14 +214,14 @@ class TestMemgraphStorageProcessor: processor.io.execute_query.return_value = mock_result - processor.relate_literal(src_uri, pred_uri, literal_value, "test_user", "test_collection") + processor.relate_literal(src_uri, pred_uri, literal_value, "test_workspace", "test_collection") processor.io.execute_query.assert_called_once_with( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", src=src_uri, dest=literal_value, uri=pred_uri, - user="test_user", collection="test_collection", + workspace="test_workspace", collection="test_collection", database_=processor.db ) @@ -236,22 +235,22 @@ class TestMemgraphStorageProcessor: o=Term(type=IRI, iri='http://example.com/object') ) - processor.create_triple(mock_tx, triple, "test_user", "test_collection") + processor.create_triple(mock_tx, triple, "test_workspace", "test_collection") # Verify transaction calls expected_calls = [ # Create subject node - ("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", - {'uri': 'http://example.com/subject', 'user': 'test_user', 'collection': 'test_collection'}), + ("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", + {'uri': 'http://example.com/subject', 'workspace': 'test_workspace', 'collection': 'test_collection'}), # Create object node - ("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", - {'uri': 'http://example.com/object', 'user': 'test_user', 'collection': 'test_collection'}), + ("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", + {'uri': 'http://example.com/object', 'workspace': 'test_workspace', 'collection': 'test_collection'}), # Create relationship - ("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + ("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", {'src': 'http://example.com/subject', 'dest': 'http://example.com/object', 'uri': 'http://example.com/predicate', - 'user': 'test_user', 'collection': 'test_collection'}) + 'workspace': 'test_workspace', 'collection': 'test_collection'}) ] assert mock_tx.run.call_count == 3 @@ -270,22 +269,22 @@ class TestMemgraphStorageProcessor: o=Term(type=LITERAL, value='literal object') ) - processor.create_triple(mock_tx, triple, "test_user", "test_collection") + processor.create_triple(mock_tx, triple, "test_workspace", "test_collection") # Verify transaction calls expected_calls = [ # Create subject node - ("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", - {'uri': 'http://example.com/subject', 'user': 'test_user', 'collection': 'test_collection'}), + ("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", + {'uri': 'http://example.com/subject', 'workspace': 'test_workspace', 'collection': 'test_collection'}), # Create literal object - ("MERGE (n:Literal {value: $value, user: $user, collection: $collection})", - {'value': 'literal object', 'user': 'test_user', 'collection': 'test_collection'}), + ("MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", + {'value': 'literal object', 'workspace': 'test_workspace', 'collection': 'test_collection'}), # Create relationship - ("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + ("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", {'src': 'http://example.com/subject', 'dest': 'literal object', 'uri': 'http://example.com/predicate', - 'user': 'test_user', 'collection': 'test_collection'}) + 'workspace': 'test_workspace', 'collection': 'test_collection'}) ] assert mock_tx.run.call_count == 3 @@ -314,8 +313,8 @@ class TestMemgraphStorageProcessor: with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(mock_message) - + await processor.store_triples('test_workspace', mock_message) + # Verify execute_query was called for create_node, create_literal, and relate_literal # (since mock_message has a literal object) assert processor.io.execute_query.call_count == 3 @@ -323,7 +322,7 @@ class TestMemgraphStorageProcessor: # Verify user/collection parameters were included for call in processor.io.execute_query.call_args_list: call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1] - assert 'user' in call_kwargs + assert 'workspace' in call_kwargs assert 'collection' in call_kwargs @pytest.mark.asyncio @@ -343,7 +342,6 @@ class TestMemgraphStorageProcessor: # Create message with multiple triples message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' triple1 = Triple( @@ -364,7 +362,7 @@ class TestMemgraphStorageProcessor: with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(message) + await processor.store_triples('test_workspace', message) # Verify execute_query was called: # Triple1: create_node(s) + create_literal(o) + relate_literal = 3 calls @@ -375,7 +373,7 @@ class TestMemgraphStorageProcessor: # Verify user/collection parameters were included in all calls for call in processor.io.execute_query.call_args_list: call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1] - assert call_kwargs['user'] == 'test_user' + assert call_kwargs['workspace'] == 'test_workspace' assert call_kwargs['collection'] == 'test_collection' @pytest.mark.asyncio @@ -389,7 +387,6 @@ class TestMemgraphStorageProcessor: message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' message.triples = [] @@ -399,7 +396,7 @@ class TestMemgraphStorageProcessor: with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(message) + await processor.store_triples('test_workspace', message) # Verify no session calls were made (no triples to process) processor.io.session.assert_not_called() diff --git a/tests/unit/test_storage/test_triples_neo4j_storage.py b/tests/unit/test_storage/test_triples_neo4j_storage.py index a5181ed9..0dcdb55e 100644 --- a/tests/unit/test_storage/test_triples_neo4j_storage.py +++ b/tests/unit/test_storage/test_triples_neo4j_storage.py @@ -68,9 +68,9 @@ class TestNeo4jStorageProcessor: "CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)", "CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)", "CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)", - "CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)", - "CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)", - "CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)", + "CREATE INDEX node_workspace_collection_uri FOR (n:Node) ON (n.workspace, n.collection, n.uri)", + "CREATE INDEX literal_workspace_collection_value FOR (n:Literal) ON (n.workspace, n.collection, n.value)", + "CREATE INDEX rel_workspace FOR ()-[r:Rel]-() ON (r.workspace)", "CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)" ] @@ -116,12 +116,12 @@ class TestNeo4jStorageProcessor: processor = Processor(taskgroup=taskgroup_mock) # Test create_node - processor.create_node("http://example.com/node", "test_user", "test_collection") + processor.create_node("http://example.com/node", "test_workspace", "test_collection") mock_driver.execute_query.assert_called_with( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", uri="http://example.com/node", - user="test_user", + workspace="test_workspace", collection="test_collection", database_="neo4j" ) @@ -146,12 +146,12 @@ class TestNeo4jStorageProcessor: processor = Processor(taskgroup=taskgroup_mock) # Test create_literal - processor.create_literal("literal value", "test_user", "test_collection") + processor.create_literal("literal value", "test_workspace", "test_collection") mock_driver.execute_query.assert_called_with( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", value="literal value", - user="test_user", + workspace="test_workspace", collection="test_collection", database_="neo4j" ) @@ -180,18 +180,18 @@ class TestNeo4jStorageProcessor: "http://example.com/subject", "http://example.com/predicate", "http://example.com/object", - "test_user", + "test_workspace", "test_collection" ) mock_driver.execute_query.assert_called_with( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", src="http://example.com/subject", dest="http://example.com/object", uri="http://example.com/predicate", - user="test_user", + workspace="test_workspace", collection="test_collection", database_="neo4j" ) @@ -220,18 +220,18 @@ class TestNeo4jStorageProcessor: "http://example.com/subject", "http://example.com/predicate", "literal value", - "test_user", + "test_workspace", "test_collection" ) mock_driver.execute_query.assert_called_with( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", src="http://example.com/subject", dest="literal value", uri="http://example.com/predicate", - user="test_user", + workspace="test_workspace", collection="test_collection", database_="neo4j" ) @@ -268,36 +268,35 @@ class TestNeo4jStorageProcessor: # Create mock message with metadata mock_message = MagicMock() mock_message.triples = [triple] - mock_message.metadata.user = "test_user" mock_message.metadata.collection = "test_collection" # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples("test_workspace", mock_message) # Verify create_node was called for subject and object # Verify relate_node was called expected_calls = [ # Subject node creation ( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", - {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection", "database_": "neo4j"} + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", + {"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"} ), # Object node creation ( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", - {"uri": "http://example.com/object", "user": "test_user", "collection": "test_collection", "database_": "neo4j"} + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", + {"uri": "http://example.com/object", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"} ), # Relationship creation ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", { "src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate", - "user": "test_user", + "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j" } @@ -340,12 +339,11 @@ class TestNeo4jStorageProcessor: # Create mock message with metadata mock_message = MagicMock() mock_message.triples = [triple] - mock_message.metadata.user = "test_user" mock_message.metadata.collection = "test_collection" # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples("test_workspace", mock_message) # Verify create_node was called for subject # Verify create_literal was called for object @@ -353,24 +351,24 @@ class TestNeo4jStorageProcessor: expected_calls = [ # Subject node creation ( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", - {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection", "database_": "neo4j"} + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", + {"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"} ), # Literal creation ( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", - {"value": "literal value", "user": "test_user", "collection": "test_collection", "database_": "neo4j"} + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", + {"value": "literal value", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"} ), # Relationship creation ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", { "src": "http://example.com/subject", "dest": "literal value", "uri": "http://example.com/predicate", - "user": "test_user", + "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j" } @@ -421,12 +419,11 @@ class TestNeo4jStorageProcessor: # Create mock message with metadata mock_message = MagicMock() mock_message.triples = [triple1, triple2] - mock_message.metadata.user = "test_user" mock_message.metadata.collection = "test_collection" # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples("test_workspace", mock_message) # Should have processed both triples # Triple1: 2 nodes + 1 relationship = 3 calls @@ -449,12 +446,11 @@ class TestNeo4jStorageProcessor: # Create mock message with empty triples and metadata mock_message = MagicMock() mock_message.triples = [] - mock_message.metadata.user = "test_user" mock_message.metadata.collection = "test_collection" # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples("test_workspace", mock_message) # Should not have made any execute_query calls beyond index creation # Only index creation calls should have been made during initialization @@ -568,38 +564,37 @@ class TestNeo4jStorageProcessor: mock_message = MagicMock() mock_message.triples = [triple] - mock_message.metadata.user = "test_user" mock_message.metadata.collection = "test_collection" # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples("test_workspace", mock_message) # Verify the triple was processed with special characters preserved mock_driver.execute_query.assert_any_call( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", uri="http://example.com/subject with spaces", - user="test_user", + workspace="test_workspace", collection="test_collection", database_="neo4j" ) mock_driver.execute_query.assert_any_call( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", value='literal with "quotes" and unicode: ñáéíóú', - user="test_user", + workspace="test_workspace", collection="test_collection", database_="neo4j" ) mock_driver.execute_query.assert_any_call( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", src="http://example.com/subject with spaces", dest='literal with "quotes" and unicode: ñáéíóú', uri="http://example.com/predicate:with/symbols", - user="test_user", + workspace="test_workspace", collection="test_collection", database_="neo4j" ) diff --git a/tests/unit/test_structured_data/test_row_embeddings_query.py b/tests/unit/test_structured_data/test_row_embeddings_query.py index 3222ec83..51cf834f 100644 --- a/tests/unit/test_structured_data/test_row_embeddings_query.py +++ b/tests/unit/test_structured_data/test_row_embeddings_query.py @@ -24,11 +24,10 @@ def _make_processor(qdrant_client=None): return proc -def _make_request(vector=None, user="test-user", collection="test-col", +def _make_request(vector=None, collection="test-col", schema_name="customers", limit=10, index_name=None): return RowEmbeddingsRequest( vector=vector or [0.1, 0.2, 0.3], - user=user, collection=collection, schema_name=schema_name, limit=limit, @@ -36,6 +35,14 @@ def _make_request(vector=None, user="test-user", collection="test-col", ) +def _make_flow(workspace="test-workspace", pub=None): + """Make a mock flow object that is callable and has .workspace.""" + flow = MagicMock() + flow.return_value = pub if pub is not None else AsyncMock() + flow.workspace = workspace + return flow + + def _make_search_point(index_name, index_value, text, score): point = MagicMock() point.payload = { @@ -85,34 +92,33 @@ class TestFindCollection: def test_finds_matching_collection(self): proc = _make_processor() mock_coll = MagicMock() - mock_coll.name = "rows_test_user_test_col_customers_384" + mock_coll.name = "rows_test_workspace_test_col_customers_384" mock_collections = MagicMock() mock_collections.collections = [mock_coll] proc.qdrant.get_collections.return_value = mock_collections - result = proc.find_collection("test-user", "test-col", "customers") + result = proc.find_collection("test-workspace", "test-col", "customers") - # Prefix: rows_test_user_test_col_customers_ - assert result == "rows_test_user_test_col_customers_384" + assert result == "rows_test_workspace_test_col_customers_384" def test_returns_none_when_no_match(self): proc = _make_processor() mock_coll = MagicMock() - mock_coll.name = "rows_other_user_other_col_schema_768" + mock_coll.name = "rows_other_workspace_other_col_schema_768" mock_collections = MagicMock() mock_collections.collections = [mock_coll] proc.qdrant.get_collections.return_value = mock_collections - result = proc.find_collection("test-user", "test-col", "customers") + result = proc.find_collection("test-workspace", "test-col", "customers") assert result is None def test_returns_none_on_error(self): proc = _make_processor() proc.qdrant.get_collections.side_effect = Exception("connection error") - result = proc.find_collection("user", "col", "schema") + result = proc.find_collection("workspace", "col", "schema") assert result is None @@ -127,7 +133,7 @@ class TestQueryRowEmbeddings: proc = _make_processor() request = _make_request(vector=[]) - result = await proc.query_row_embeddings(request) + result = await proc.query_row_embeddings("test-workspace", request) assert result == [] @pytest.mark.asyncio @@ -136,13 +142,13 @@ class TestQueryRowEmbeddings: proc.find_collection = MagicMock(return_value=None) request = _make_request() - result = await proc.query_row_embeddings(request) + result = await proc.query_row_embeddings("test-workspace", request) assert result == [] @pytest.mark.asyncio async def test_successful_query_returns_matches(self): proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_u_c_s_384") + proc.find_collection = MagicMock(return_value="rows_w_c_s_384") points = [ _make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95), @@ -153,7 +159,7 @@ class TestQueryRowEmbeddings: proc.qdrant.query_points.return_value = mock_result request = _make_request() - result = await proc.query_row_embeddings(request) + result = await proc.query_row_embeddings("test-workspace", request) assert len(result) == 2 assert isinstance(result[0], RowIndexMatch) @@ -166,14 +172,14 @@ class TestQueryRowEmbeddings: async def test_index_name_filter_applied(self): """When index_name is specified, a Qdrant filter should be used.""" proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_u_c_s_384") + proc.find_collection = MagicMock(return_value="rows_w_c_s_384") mock_result = MagicMock() mock_result.points = [] proc.qdrant.query_points.return_value = mock_result request = _make_request(index_name="address") - await proc.query_row_embeddings(request) + await proc.query_row_embeddings("test-workspace", request) call_kwargs = proc.qdrant.query_points.call_args[1] assert call_kwargs["query_filter"] is not None @@ -182,14 +188,14 @@ class TestQueryRowEmbeddings: async def test_no_index_name_no_filter(self): """When index_name is empty, no filter should be applied.""" proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_u_c_s_384") + proc.find_collection = MagicMock(return_value="rows_w_c_s_384") mock_result = MagicMock() mock_result.points = [] proc.qdrant.query_points.return_value = mock_result request = _make_request(index_name="") - await proc.query_row_embeddings(request) + await proc.query_row_embeddings("test-workspace", request) call_kwargs = proc.qdrant.query_points.call_args[1] assert call_kwargs["query_filter"] is None @@ -198,7 +204,7 @@ class TestQueryRowEmbeddings: async def test_missing_payload_fields_default(self): """Points with missing payload fields should use defaults.""" proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_u_c_s_384") + proc.find_collection = MagicMock(return_value="rows_w_c_s_384") point = MagicMock() point.payload = {} # Empty payload @@ -209,7 +215,7 @@ class TestQueryRowEmbeddings: proc.qdrant.query_points.return_value = mock_result request = _make_request() - result = await proc.query_row_embeddings(request) + result = await proc.query_row_embeddings("test-workspace", request) assert len(result) == 1 assert result[0].index_name == "" @@ -219,13 +225,13 @@ class TestQueryRowEmbeddings: @pytest.mark.asyncio async def test_qdrant_error_propagates(self): proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_u_c_s_384") + proc.find_collection = MagicMock(return_value="rows_w_c_s_384") proc.qdrant.query_points.side_effect = Exception("qdrant down") request = _make_request() with pytest.raises(Exception, match="qdrant down"): - await proc.query_row_embeddings(request) + await proc.query_row_embeddings("test-workspace", request) # --------------------------------------------------------------------------- @@ -243,7 +249,7 @@ class TestOnMessage: ]) mock_pub = AsyncMock() - flow = lambda name: mock_pub + flow = _make_flow(pub=mock_pub) msg = MagicMock() msg.value.return_value = _make_request() @@ -264,7 +270,7 @@ class TestOnMessage: ) mock_pub = AsyncMock() - flow = lambda name: mock_pub + flow = _make_flow(pub=mock_pub) msg = MagicMock() msg.value.return_value = _make_request() @@ -284,7 +290,7 @@ class TestOnMessage: proc.query_row_embeddings = AsyncMock(return_value=[]) mock_pub = AsyncMock() - flow = lambda name: mock_pub + flow = _make_flow(pub=mock_pub) msg = MagicMock() msg.value.return_value = _make_request() diff --git a/tests/unit/test_tables/test_knowledge_table_store.py b/tests/unit/test_tables/test_knowledge_table_store.py index 4ab0ffeb..59d15b45 100644 --- a/tests/unit/test_tables/test_knowledge_table_store.py +++ b/tests/unit/test_tables/test_knowledge_table_store.py @@ -45,12 +45,9 @@ class TestGetGraphEmbeddings: with `vector=` (singular) — the schema field name. A previous version used `vectors=` and TypeError'd at runtime. """ - # Arrange — fake row matching the get_triples_stmt result shape: - # row[0..2] are unused by the method, row[3] is the entities blob fake_row = ( None, None, None, [ - # ((value, is_uri), vector) (("http://example.org/alice", True), [0.1, 0.2, 0.3]), (("http://example.org/bob", True), [0.4, 0.5, 0.6]), (("a literal entity", False), [0.7, 0.8, 0.9]), @@ -67,14 +64,8 @@ class TestGetGraphEmbeddings: async def receiver(msg): received.append(msg) - # Act - await store.get_graph_embeddings( - user="alice", - document_id="doc-1", - receiver=receiver, - ) + await store.get_graph_embeddings("alice", "doc-1", receiver) - # Assert mock_async_execute.assert_called_once_with( store.cassandra, store.get_graph_embeddings_stmt, @@ -86,7 +77,6 @@ class TestGetGraphEmbeddings: assert isinstance(ge, GraphEmbeddings) assert isinstance(ge.metadata, Metadata) assert ge.metadata.id == "doc-1" - assert ge.metadata.user == "alice" assert len(ge.entities) == 3 assert all(isinstance(e, EntityEmbeddings) for e in ge.entities) @@ -122,7 +112,7 @@ class TestGetGraphEmbeddings: async def receiver(msg): received.append(msg) - await store.get_graph_embeddings("u", "d", receiver) + await store.get_graph_embeddings("w", "d", receiver) assert len(received) == 1 assert received[0].entities == [] @@ -149,7 +139,7 @@ class TestGetGraphEmbeddings: async def receiver(msg): received.append(msg) - await store.get_graph_embeddings("u", "d", receiver) + await store.get_graph_embeddings("w", "d", receiver) assert len(received) == 2 assert received[0].entities[0].entity.iri == "http://example.org/a" @@ -194,7 +184,6 @@ class TestGetTriples: assert isinstance(triples_msg, Triples) assert isinstance(triples_msg.metadata, Metadata) assert triples_msg.metadata.id == "doc-1" - assert triples_msg.metadata.user == "alice" assert len(triples_msg.triples) == 1 t = triples_msg.triples[0] diff --git a/tests/unit/test_translators/test_document_embeddings_translator_roundtrip.py b/tests/unit/test_translators/test_document_embeddings_translator_roundtrip.py index 72f4796b..56a1583e 100644 --- a/tests/unit/test_translators/test_document_embeddings_translator_roundtrip.py +++ b/tests/unit/test_translators/test_document_embeddings_translator_roundtrip.py @@ -30,7 +30,6 @@ def sample(): metadata=Metadata( id="doc-1", root="", - user="alice", collection="testcoll", ), chunks=[ @@ -56,7 +55,6 @@ class TestDocumentEmbeddingsTranslator: assert isinstance(decoded, DocumentEmbeddings) assert isinstance(decoded.metadata, Metadata) assert decoded.metadata.id == "doc-1" - assert decoded.metadata.user == "alice" assert decoded.metadata.collection == "testcoll" assert len(decoded.chunks) == 2 diff --git a/tests/unit/test_translators/test_knowledge_translator_roundtrip.py b/tests/unit/test_translators/test_knowledge_translator_roundtrip.py index 57e7ae17..64f2e5d4 100644 --- a/tests/unit/test_translators/test_knowledge_translator_roundtrip.py +++ b/tests/unit/test_translators/test_knowledge_translator_roundtrip.py @@ -41,7 +41,7 @@ def translator(): def graph_embeddings_request(): return KnowledgeRequest( operation="put-kg-core", - user="alice", + workspace="alice", id="doc-1", flow="default", collection="testcoll", @@ -49,7 +49,6 @@ def graph_embeddings_request(): metadata=Metadata( id="doc-1", root="", - user="alice", collection="testcoll", ), entities=[ @@ -70,7 +69,6 @@ def graph_embeddings_request(): def triples_request(): return KnowledgeRequest( operation="put-kg-core", - user="alice", id="doc-1", flow="default", collection="testcoll", @@ -78,7 +76,6 @@ def triples_request(): metadata=Metadata( id="doc-1", root="", - user="alice", collection="testcoll", ), triples=[ @@ -113,7 +110,7 @@ class TestKnowledgeRequestTranslatorGraphEmbeddings: assert isinstance(decoded, KnowledgeRequest) assert decoded.operation == "put-kg-core" - assert decoded.user == "alice" + assert decoded.workspace == "alice" assert decoded.id == "doc-1" assert decoded.flow == "default" assert decoded.collection == "testcoll" @@ -123,7 +120,6 @@ class TestKnowledgeRequestTranslatorGraphEmbeddings: assert isinstance(ge, GraphEmbeddings) assert isinstance(ge.metadata, Metadata) assert ge.metadata.id == "doc-1" - assert ge.metadata.user == "alice" assert ge.metadata.collection == "testcoll" assert len(ge.entities) == 2 @@ -143,7 +139,6 @@ class TestKnowledgeRequestTranslatorTriples: assert decoded.triples is not None assert isinstance(decoded.triples.metadata, Metadata) assert decoded.triples.metadata.id == "doc-1" - assert decoded.triples.metadata.user == "alice" assert decoded.triples.metadata.collection == "testcoll" assert len(decoded.triples.triples) == 1 diff --git a/trustgraph-base/trustgraph/api/__init__.py b/trustgraph-base/trustgraph/api/__init__.py index 2f44aad0..e14c61a3 100644 --- a/trustgraph-base/trustgraph/api/__init__.py +++ b/trustgraph-base/trustgraph/api/__init__.py @@ -27,7 +27,6 @@ Quick Start: # Execute a graph RAG query response = flow.graph_rag( query="What are the main topics?", - user="trustgraph", collection="default" ) ``` @@ -38,7 +37,7 @@ For streaming and async operations: socket = api.socket() flow = socket.flow("default") - for chunk in flow.agent(question="Hello", user="trustgraph"): + for chunk in flow.agent(question="Hello"): print(chunk.content) # Async operations diff --git a/trustgraph-base/trustgraph/api/api.py b/trustgraph-base/trustgraph/api/api.py index dbdce0a8..9074bac1 100644 --- a/trustgraph-base/trustgraph/api/api.py +++ b/trustgraph-base/trustgraph/api/api.py @@ -50,7 +50,7 @@ class Api: token: Optional bearer token for authentication """ - def __init__(self, url="http://localhost:8088/", timeout=60, token: Optional[str] = None): + def __init__(self, url="http://localhost:8088/", timeout=60, token: Optional[str] = None, workspace: str = "default"): """ Initialize the TrustGraph API client. @@ -82,6 +82,7 @@ class Api: self.timeout = timeout self.token = token + self.workspace = workspace # Lazy initialization for new clients self._socket_client = None @@ -137,7 +138,7 @@ class Api: config.put([ConfigValue(type="llm", key="model", value="gpt-4")]) ``` """ - return Config(api=self) + return Config(api=self, workspace=self.workspace) def knowledge(self): """ @@ -151,10 +152,10 @@ class Api: knowledge = api.knowledge() # List available KG cores - cores = knowledge.list_kg_cores(user="trustgraph") + cores = knowledge.list_kg_cores() # Load a KG core - knowledge.load_kg_core(id="core-123", user="trustgraph") + knowledge.load_kg_core(id="core-123") ``` """ return Knowledge(api=self) @@ -191,6 +192,12 @@ class Api: if self.token: headers["Authorization"] = f"Bearer {self.token}" + # Ensure every REST request carries the workspace so services can + # scope their behaviour. Callers that already set workspace in the + # payload (e.g. Library client) take precedence. + if isinstance(request, dict) and "workspace" not in request: + request = {**request, "workspace": self.workspace} + # Invoke the API, input is passed as JSON resp = requests.post(url, json=request, timeout=self.timeout, headers=headers) @@ -227,13 +234,12 @@ class Api: document=b"Document content", id="doc-123", metadata=[], - user="trustgraph", title="My Document", comments="Test document" ) # List documents - docs = library.get_documents(user="trustgraph") + docs = library.get_documents() ``` """ return Library(self) @@ -253,11 +259,10 @@ class Api: collection = api.collection() # List collections - colls = collection.list_collections(user="trustgraph") + colls = collection.list_collections() # Update collection metadata collection.update_collection( - user="trustgraph", collection="default", name="Default Collection", description="Main data collection" @@ -286,7 +291,6 @@ class Api: # Stream agent responses for chunk in flow.agent( question="Explain quantum computing", - user="trustgraph", streaming=True ): if hasattr(chunk, 'content'): @@ -297,7 +301,10 @@ class Api: from . socket_client import SocketClient # Extract base URL (remove api/v1/ suffix) base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/") - self._socket_client = SocketClient(base_url, self.timeout, self.token) + self._socket_client = SocketClient( + base_url, self.timeout, self.token, + workspace=self.workspace, + ) return self._socket_client def bulk(self): @@ -406,7 +413,6 @@ class Api: # Stream agent responses async for chunk in flow.agent( question="Explain quantum computing", - user="trustgraph", streaming=True ): if hasattr(chunk, 'content'): @@ -417,7 +423,10 @@ class Api: from . async_socket_client import AsyncSocketClient # Extract base URL (remove api/v1/ suffix) base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/") - self._async_socket_client = AsyncSocketClient(base_url, self.timeout, self.token) + self._async_socket_client = AsyncSocketClient( + base_url, self.timeout, self.token, + workspace=self.workspace, + ) return self._async_socket_client def async_bulk(self): diff --git a/trustgraph-base/trustgraph/api/async_flow.py b/trustgraph-base/trustgraph/api/async_flow.py index 68899341..bf0b2ba1 100644 --- a/trustgraph-base/trustgraph/api/async_flow.py +++ b/trustgraph-base/trustgraph/api/async_flow.py @@ -326,9 +326,7 @@ class AsyncFlow: # Use flow services result = await flow.graph_rag( - query="What is TrustGraph?", - user="trustgraph", - collection="default" + query="What is TrustGraph?",collection="default" ) ``` """ @@ -385,7 +383,7 @@ class AsyncFlowInstance: """ return await self.flow.request(f"flow/{self.flow_id}/service/{service}", request_data) - async def agent(self, question: str, user: str, state: Optional[Dict] = None, + async def agent(self, question: str, state: Optional[Dict] = None, group: Optional[str] = None, history: Optional[List] = None, **kwargs: Any) -> Dict[str, Any]: """ Execute an agent operation (non-streaming). @@ -399,7 +397,6 @@ class AsyncFlowInstance: Args: question: User question or instruction - user: User identifier state: Optional state dictionary for conversation context group: Optional group identifier for session management history: Optional conversation history list @@ -416,14 +413,12 @@ class AsyncFlowInstance: # Execute agent result = await flow.agent( question="What is the capital of France?", - user="trustgraph" - ) + ) print(f"Answer: {result.get('response')}") ``` """ request_data = { "question": question, - "user": user, "streaming": False # REST doesn't support streaming } if state is not None: @@ -481,7 +476,7 @@ class AsyncFlowInstance: model=result.get("model"), ) - async def graph_rag(self, query: str, user: str, collection: str, + async def graph_rag(self, query: str, collection: str, max_subgraph_size: int = 1000, max_subgraph_count: int = 5, max_entity_distance: int = 3, **kwargs: Any) -> str: """ @@ -496,7 +491,6 @@ class AsyncFlowInstance: Args: query: User query text - user: User identifier collection: Collection identifier containing the knowledge graph max_subgraph_size: Maximum number of triples per subgraph (default: 1000) max_subgraph_count: Maximum number of subgraphs to retrieve (default: 5) @@ -513,9 +507,7 @@ class AsyncFlowInstance: # Query knowledge graph response = await flow.graph_rag( - query="What are the relationships between these entities?", - user="trustgraph", - collection="medical-kb", + query="What are the relationships between these entities?",collection="medical-kb", max_subgraph_count=3 ) print(response) @@ -523,7 +515,6 @@ class AsyncFlowInstance: """ request_data = { "query": query, - "user": user, "collection": collection, "max-subgraph-size": max_subgraph_size, "max-subgraph-count": max_subgraph_count, @@ -535,7 +526,7 @@ class AsyncFlowInstance: result = await self.request("graph-rag", request_data) return result.get("response", "") - async def document_rag(self, query: str, user: str, collection: str, + async def document_rag(self, query: str, collection: str, doc_limit: int = 10, **kwargs: Any) -> str: """ Execute document-based RAG query (non-streaming). @@ -549,7 +540,6 @@ class AsyncFlowInstance: Args: query: User query text - user: User identifier collection: Collection identifier containing documents doc_limit: Maximum number of document chunks to retrieve (default: 10) **kwargs: Additional service-specific parameters @@ -564,9 +554,7 @@ class AsyncFlowInstance: # Query documents response = await flow.document_rag( - query="What does the documentation say about authentication?", - user="trustgraph", - collection="docs", + query="What does the documentation say about authentication?",collection="docs", doc_limit=5 ) print(response) @@ -574,7 +562,6 @@ class AsyncFlowInstance: """ request_data = { "query": query, - "user": user, "collection": collection, "doc-limit": doc_limit, "streaming": False @@ -584,7 +571,7 @@ class AsyncFlowInstance: result = await self.request("document-rag", request_data) return result.get("response", "") - async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs: Any): + async def graph_embeddings_query(self, text: str, collection: str, limit: int = 10, **kwargs: Any): """ Query graph embeddings for semantic entity search. @@ -593,7 +580,6 @@ class AsyncFlowInstance: Args: text: Query text for semantic search - user: User identifier collection: Collection identifier containing graph embeddings limit: Maximum number of results to return (default: 10) **kwargs: Additional service-specific parameters @@ -608,9 +594,7 @@ class AsyncFlowInstance: # Find related entities results = await flow.graph_embeddings_query( - text="machine learning algorithms", - user="trustgraph", - collection="tech-kb", + text="machine learning algorithms",collection="tech-kb", limit=5 ) @@ -624,7 +608,6 @@ class AsyncFlowInstance: request_data = { "vector": vector, - "user": user, "collection": collection, "limit": limit } @@ -663,7 +646,7 @@ class AsyncFlowInstance: return await self.request("embeddings", request_data) - async def triples_query(self, s=None, p=None, o=None, user=None, collection=None, limit=100, **kwargs: Any): + async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs: Any): """ Query RDF triples using pattern matching. @@ -674,7 +657,6 @@ class AsyncFlowInstance: s: Subject pattern (None for wildcard) p: Predicate pattern (None for wildcard) o: Object pattern (None for wildcard) - user: User identifier (None for all users) collection: Collection identifier (None for all collections) limit: Maximum number of triples to return (default: 100) **kwargs: Additional service-specific parameters @@ -689,9 +671,7 @@ class AsyncFlowInstance: # Find all triples with a specific predicate results = await flow.triples_query( - p="knows", - user="trustgraph", - collection="social", + p="knows",collection="social", limit=50 ) @@ -706,15 +686,13 @@ class AsyncFlowInstance: request_data["p"] = str(p) if o is not None: request_data["o"] = str(o) - if user is not None: - request_data["user"] = user if collection is not None: request_data["collection"] = collection request_data.update(kwargs) return await self.request("triples", request_data) - async def rows_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None, + async def rows_query(self, query: str, collection: str, variables: Optional[Dict] = None, operation_name: Optional[str] = None, **kwargs: Any): """ Execute a GraphQL query on stored rows. @@ -724,7 +702,6 @@ class AsyncFlowInstance: Args: query: GraphQL query string - user: User identifier collection: Collection identifier containing rows variables: Optional GraphQL query variables operation_name: Optional operation name for multi-operation queries @@ -750,9 +727,7 @@ class AsyncFlowInstance: ''' result = await flow.rows_query( - query=query, - user="trustgraph", - collection="users", + query=query,collection="users", variables={"status": "active"} ) @@ -762,7 +737,6 @@ class AsyncFlowInstance: """ request_data = { "query": query, - "user": user, "collection": collection } if variables: @@ -774,7 +748,7 @@ class AsyncFlowInstance: return await self.request("rows", request_data) async def row_embeddings_query( - self, text: str, schema_name: str, user: str = "trustgraph", + self, text: str, schema_name: str, collection: str = "default", index_name: Optional[str] = None, limit: int = 10, **kwargs: Any ): @@ -788,7 +762,6 @@ class AsyncFlowInstance: Args: text: Query text for semantic search schema_name: Schema name to search within - user: User identifier (default: "trustgraph") collection: Collection identifier (default: "default") index_name: Optional index name to filter search to specific index limit: Maximum number of results to return (default: 10) @@ -806,9 +779,7 @@ class AsyncFlowInstance: # Search for customers by name similarity results = await flow.row_embeddings_query( text="John Smith", - schema_name="customers", - user="trustgraph", - collection="sales", + schema_name="customers",collection="sales", limit=5 ) @@ -823,7 +794,6 @@ class AsyncFlowInstance: request_data = { "vector": vector, "schema_name": schema_name, - "user": user, "collection": collection, "limit": limit } diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index 6e5064ab..e5d553ea 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -22,10 +22,14 @@ class AsyncSocketClient: Or call connect()/aclose() manually. """ - def __init__(self, url: str, timeout: int, token: Optional[str]): + def __init__( + self, url: str, timeout: int, token: Optional[str], + workspace: str = "default", + ): self.url = self._convert_to_ws_url(url) self.timeout = timeout self.token = token + self.workspace = workspace self._request_counter = 0 self._socket = None self._connect_cm = None @@ -117,6 +121,7 @@ class AsyncSocketClient: try: message = { "id": request_id, + "workspace": self.workspace, "service": service, "request": request } @@ -149,6 +154,7 @@ class AsyncSocketClient: try: message = { "id": request_id, + "workspace": self.workspace, "service": service, "request": request } @@ -251,13 +257,12 @@ class AsyncSocketFlowInstance: self.client = client self.flow_id = flow_id - async def agent(self, question: str, user: str, state: Optional[Dict[str, Any]] = None, + async def agent(self, question: str, state: Optional[Dict[str, Any]] = None, group: Optional[str] = None, history: Optional[list] = None, streaming: bool = False, **kwargs) -> Union[Dict[str, Any], AsyncIterator]: """Agent with optional streaming""" request = { "question": question, - "user": user, "streaming": streaming } if state is not None: @@ -303,13 +308,12 @@ class AsyncSocketFlowInstance: if isinstance(chunk, RAGChunk): yield chunk - async def graph_rag(self, query: str, user: str, collection: str, + async def graph_rag(self, query: str, collection: str, max_subgraph_size: int = 1000, max_subgraph_count: int = 5, max_entity_distance: int = 3, streaming: bool = False, **kwargs): """Graph RAG with optional streaming""" request = { "query": query, - "user": user, "collection": collection, "max-subgraph-size": max_subgraph_size, "max-subgraph-count": max_subgraph_count, @@ -330,12 +334,11 @@ class AsyncSocketFlowInstance: if hasattr(chunk, 'content'): yield chunk.content - async def document_rag(self, query: str, user: str, collection: str, + async def document_rag(self, query: str, collection: str, doc_limit: int = 10, streaming: bool = False, **kwargs): """Document RAG with optional streaming""" request = { "query": query, - "user": user, "collection": collection, "doc-limit": doc_limit, "streaming": streaming @@ -375,14 +378,13 @@ class AsyncSocketFlowInstance: if hasattr(chunk, 'content'): yield chunk.content - async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs): + async def graph_embeddings_query(self, text: str, collection: str, limit: int = 10, **kwargs): """Query graph embeddings for semantic search""" emb_result = await self.embeddings(texts=[text]) vector = emb_result.get("vectors", [[]])[0] request = { "vector": vector, - "user": user, "collection": collection, "limit": limit } @@ -397,7 +399,7 @@ class AsyncSocketFlowInstance: return await self.client._send_request("embeddings", self.flow_id, request) - async def triples_query(self, s=None, p=None, o=None, user=None, collection=None, limit=100, **kwargs): + async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs): """Triple pattern query""" request = {"limit": limit} if s is not None: @@ -406,20 +408,17 @@ class AsyncSocketFlowInstance: request["p"] = str(p) if o is not None: request["o"] = str(o) - if user is not None: - request["user"] = user if collection is not None: request["collection"] = collection request.update(kwargs) return await self.client._send_request("triples", self.flow_id, request) - async def rows_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None, + async def rows_query(self, query: str, collection: str, variables: Optional[Dict] = None, operation_name: Optional[str] = None, **kwargs): """GraphQL query against structured rows""" request = { "query": query, - "user": user, "collection": collection } if variables: @@ -441,7 +440,7 @@ class AsyncSocketFlowInstance: return await self.client._send_request("mcp-tool", self.flow_id, request) async def row_embeddings_query( - self, text: str, schema_name: str, user: str = "trustgraph", + self, text: str, schema_name: str, collection: str = "default", index_name: Optional[str] = None, limit: int = 10, **kwargs ): @@ -452,7 +451,6 @@ class AsyncSocketFlowInstance: request = { "vector": vector, "schema_name": schema_name, - "user": user, "collection": collection, "limit": limit } diff --git a/trustgraph-base/trustgraph/api/bulk_client.py b/trustgraph-base/trustgraph/api/bulk_client.py index 75999550..0e49fc4e 100644 --- a/trustgraph-base/trustgraph/api/bulk_client.py +++ b/trustgraph-base/trustgraph/api/bulk_client.py @@ -85,7 +85,7 @@ class BulkClient: Args: flow: Flow identifier triples: Iterator yielding Triple objects - metadata: Metadata dict with id, metadata, user, collection + metadata: Metadata dict with id, metadata, collection batch_size: Number of triples per batch (default 100) **kwargs: Additional parameters (reserved for future use) @@ -105,7 +105,7 @@ class BulkClient: bulk.import_triples( flow="default", triples=triple_generator(), - metadata={"id": "doc1", "metadata": [], "user": "user1", "collection": "default"} + metadata={"id": "doc1", "metadata": [], "collection": "default"} ) ``` """ @@ -121,7 +121,7 @@ class BulkClient: ws_url = f"{ws_url}?token={self.token}" if metadata is None: - metadata = {"id": "", "metadata": [], "user": "trustgraph", "collection": "default"} + metadata = {"id": "", "metadata": [], "collection": "default"} async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: batch = [] @@ -418,7 +418,7 @@ class BulkClient: Args: flow: Flow identifier contexts: Iterator yielding context dictionaries - metadata: Metadata dict with id, metadata, user, collection + metadata: Metadata dict with id, metadata, collection batch_size: Number of contexts per batch (default 100) **kwargs: Additional parameters (reserved for future use) @@ -435,7 +435,7 @@ class BulkClient: bulk.import_entity_contexts( flow="default", contexts=context_generator(), - metadata={"id": "doc1", "metadata": [], "user": "user1", "collection": "default"} + metadata={"id": "doc1", "metadata": [], "collection": "default"} ) ``` """ @@ -451,7 +451,7 @@ class BulkClient: ws_url = f"{ws_url}?token={self.token}" if metadata is None: - metadata = {"id": "", "metadata": [], "user": "trustgraph", "collection": "default"} + metadata = {"id": "", "metadata": [], "collection": "default"} async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: batch = [] diff --git a/trustgraph-base/trustgraph/api/collection.py b/trustgraph-base/trustgraph/api/collection.py index 414d07db..11cd2843 100644 --- a/trustgraph-base/trustgraph/api/collection.py +++ b/trustgraph-base/trustgraph/api/collection.py @@ -2,11 +2,9 @@ TrustGraph Collection Management This module provides interfaces for managing data collections in TrustGraph. -Collections provide logical grouping and isolation for documents and knowledge -graph data. +Collections provide logical grouping within a workspace. """ -import datetime import logging from . types import CollectionMetadata @@ -18,10 +16,9 @@ class Collection: """ Collection management client. - Provides methods for managing data collections, including listing, - updating metadata, and deleting collections. Collections organize - documents and knowledge graph data into logical groupings for - isolation and access control. + Provides methods for managing data collections within the configured + workspace, including listing, updating metadata, and deleting + collections. """ def __init__(self, api): @@ -45,45 +42,20 @@ class Collection: """ return self.api.request(f"collection-management", request) - def list_collections(self, user, tag_filter=None): + def list_collections(self, tag_filter=None): """ - List all collections for a user. - - Retrieves metadata for all collections owned by the specified user, - with optional filtering by tags. + List all collections in this workspace. Args: - user: User identifier - tag_filter: Optional list of tags to filter collections (default: None) + tag_filter: Optional list of tags to filter collections Returns: list[CollectionMetadata]: List of collection metadata objects - - Raises: - ProtocolException: If response format is invalid - - Example: - ```python - collection = api.collection() - - # List all collections - all_colls = collection.list_collections(user="trustgraph") - for coll in all_colls: - print(f"{coll.collection}: {coll.name}") - print(f" Description: {coll.description}") - print(f" Tags: {', '.join(coll.tags)}") - - # List collections with specific tags - research_colls = collection.list_collections( - user="trustgraph", - tag_filter=["research", "published"] - ) - ``` """ input = { "operation": "list-collections", - "user": user, + "workspace": self.api.workspace, } if tag_filter: @@ -92,7 +64,6 @@ class Collection: object = self.request(input) try: - # Handle case where collections might be None or missing if object is None or "collections" not in object: return [] @@ -102,7 +73,6 @@ class Collection: return [ CollectionMetadata( - user = v["user"], collection = v["collection"], name = v["name"], description = v["description"], @@ -114,15 +84,11 @@ class Collection: logger.error("Failed to parse collection list response", exc_info=True) raise ProtocolException(f"Response not formatted correctly") - def update_collection(self, user, collection, name=None, description=None, tags=None): + def update_collection(self, collection, name=None, description=None, tags=None): """ Update collection metadata. - Updates the name, description, and/or tags for an existing collection. - Only provided fields are updated; others remain unchanged. - Args: - user: User identifier collection: Collection identifier name: New collection name (optional) description: New collection description (optional) @@ -130,35 +96,11 @@ class Collection: Returns: CollectionMetadata: Updated collection metadata, or None if not found - - Raises: - ProtocolException: If response format is invalid - - Example: - ```python - collection_api = api.collection() - - # Update collection metadata - updated = collection_api.update_collection( - user="trustgraph", - collection="default", - name="Default Collection", - description="Main data collection for general use", - tags=["default", "production"] - ) - - # Update only specific fields - updated = collection_api.update_collection( - user="trustgraph", - collection="research", - description="Updated description" - ) - ``` """ input = { "operation": "update-collection", - "user": user, + "workspace": self.api.workspace, "collection": collection, } @@ -175,7 +117,6 @@ class Collection: if "collections" in object and object["collections"]: v = object["collections"][0] return CollectionMetadata( - user = v["user"], collection = v["collection"], name = v["name"], description = v["description"], @@ -186,37 +127,23 @@ class Collection: logger.error("Failed to parse collection update response", exc_info=True) raise ProtocolException(f"Response not formatted correctly") - def delete_collection(self, user, collection): + def delete_collection(self, collection): """ Delete a collection. - Removes a collection and all its associated data from the system. - Args: - user: User identifier collection: Collection identifier to delete Returns: dict: Empty response object - - Example: - ```python - collection_api = api.collection() - - # Delete a collection - collection_api.delete_collection( - user="trustgraph", - collection="old-collection" - ) - ``` """ input = { "operation": "delete-collection", - "user": user, + "workspace": self.api.workspace, "collection": collection, } - object = self.request(input) + self.request(input) - return {} \ No newline at end of file + return {} diff --git a/trustgraph-base/trustgraph/api/config.py b/trustgraph-base/trustgraph/api/config.py index c8c8d5bb..5f17672f 100644 --- a/trustgraph-base/trustgraph/api/config.py +++ b/trustgraph-base/trustgraph/api/config.py @@ -21,14 +21,16 @@ class Config: and list operations. """ - def __init__(self, api): + def __init__(self, api, workspace="default"): """ Initialize Config client. Args: api: Parent Api instance for making requests + workspace: Workspace to scope all config operations to """ self.api = api + self.workspace = workspace def request(self, request): """ @@ -75,9 +77,9 @@ class Config: ``` """ - # The input consists of system and prompt strings input = { "operation": "get", + "workspace": self.workspace, "keys": [ { "type": k.type, "key": k.key } for k in keys @@ -123,9 +125,9 @@ class Config: ``` """ - # The input consists of system and prompt strings input = { "operation": "put", + "workspace": self.workspace, "values": [ { "type": v.type, "key": v.key, "value": v.value } for v in values @@ -157,9 +159,9 @@ class Config: ``` """ - # The input consists of system and prompt strings input = { "operation": "delete", + "workspace": self.workspace, "keys": [ { "type": v.type, "key": v.key } for v in keys @@ -195,9 +197,9 @@ class Config: ``` """ - # The input consists of system and prompt strings input = { "operation": "list", + "workspace": self.workspace, "type": type, } @@ -235,9 +237,9 @@ class Config: ``` """ - # The input consists of system and prompt strings input = { "operation": "getvalues", + "workspace": self.workspace, "type": type, } @@ -255,6 +257,46 @@ class Config: except: raise ProtocolException(f"Response not formatted correctly") + def get_values_all_workspaces(self, type): + """ + Get all configuration values of a given type across all workspaces. + + Unlike get_values(), this is not scoped to a single workspace — + it returns every entry of the given type in the system. Each + returned ConfigValue includes its workspace field. Used by + shared processors to load type-scoped config at startup. + + Args: + type: Configuration type (e.g. "prompt", "schema") + + Returns: + list[ConfigValue]: Values across all workspaces; each has + its workspace field populated. + + Raises: + ProtocolException: If response format is invalid + """ + + input = { + "operation": "getvalues-all-ws", + "type": type, + } + + object = self.request(input) + + try: + return [ + ConfigValue( + type = v["type"], + key = v["key"], + value = v["value"], + workspace = v.get("workspace", ""), + ) + for v in object["values"] + ] + except Exception: + raise ProtocolException("Response not formatted correctly") + def all(self): """ Get complete configuration and version. @@ -279,9 +321,9 @@ class Config: ``` """ - # The input consists of system and prompt strings input = { - "operation": "config" + "operation": "config", + "workspace": self.workspace, } object = self.request(input) diff --git a/trustgraph-base/trustgraph/api/explainability.py b/trustgraph-base/trustgraph/api/explainability.py index 08d0b4e7..656ff95f 100644 --- a/trustgraph-base/trustgraph/api/explainability.py +++ b/trustgraph-base/trustgraph/api/explainability.py @@ -486,7 +486,6 @@ class ExplainabilityClient: self, uri: str, graph: Optional[str] = None, - user: Optional[str] = None, collection: Optional[str] = None ) -> Optional[ExplainEntity]: """ @@ -502,7 +501,6 @@ class ExplainabilityClient: Args: uri: The entity URI to fetch graph: Named graph to query (e.g., "urn:graph:retrieval") - user: User/keyspace identifier collection: Collection identifier Returns: @@ -515,7 +513,6 @@ class ExplainabilityClient: wire_triples = self.flow.triples_query( s=uri, g=graph, - user=user, collection=collection, limit=100 ) @@ -548,7 +545,7 @@ class ExplainabilityClient: if prev_triples: # Re-fetch and parse wire_triples = self.flow.triples_query( - s=uri, g=graph, user=user, collection=collection, limit=100 + s=uri, g=graph, collection=collection, limit=100 ) if wire_triples: triples = wire_triples_to_tuples(wire_triples) @@ -560,7 +557,6 @@ class ExplainabilityClient: self, uri: str, graph: Optional[str] = None, - user: Optional[str] = None, collection: Optional[str] = None ) -> Optional[EdgeSelection]: """ @@ -569,7 +565,6 @@ class ExplainabilityClient: Args: uri: The edge selection URI graph: Named graph to query - user: User/keyspace identifier collection: Collection identifier Returns: @@ -578,7 +573,6 @@ class ExplainabilityClient: wire_triples = self.flow.triples_query( s=uri, g=graph, - user=user, collection=collection, limit=100 ) @@ -593,7 +587,6 @@ class ExplainabilityClient: self, uri: str, graph: Optional[str] = None, - user: Optional[str] = None, collection: Optional[str] = None ) -> Optional[Focus]: """ @@ -602,20 +595,19 @@ class ExplainabilityClient: Args: uri: The Focus entity URI graph: Named graph to query - user: User/keyspace identifier collection: Collection identifier Returns: Focus with populated edge_selections, or None """ - entity = self.fetch_entity(uri, graph, user, collection) + entity = self.fetch_entity(uri, graph, collection) if not isinstance(entity, Focus): return None # Fetch each edge selection for edge_uri in entity.selected_edge_uris: - edge_sel = self.fetch_edge_selection(edge_uri, graph, user, collection) + edge_sel = self.fetch_edge_selection(edge_uri, graph, collection) if edge_sel: entity.edge_selections.append(edge_sel) @@ -624,7 +616,6 @@ class ExplainabilityClient: def resolve_label( self, uri: str, - user: Optional[str] = None, collection: Optional[str] = None ) -> str: """ @@ -632,7 +623,6 @@ class ExplainabilityClient: Args: uri: The URI to get label for - user: User/keyspace identifier collection: Collection identifier Returns: @@ -647,7 +637,6 @@ class ExplainabilityClient: wire_triples = self.flow.triples_query( s=uri, p=RDFS_LABEL, - user=user, collection=collection, limit=1 ) @@ -665,7 +654,6 @@ class ExplainabilityClient: def resolve_edge_labels( self, edge: Dict[str, str], - user: Optional[str] = None, collection: Optional[str] = None ) -> Tuple[str, str, str]: """ @@ -673,22 +661,20 @@ class ExplainabilityClient: Args: edge: Dict with "s", "p", "o" keys - user: User/keyspace identifier collection: Collection identifier Returns: Tuple of (s_label, p_label, o_label) """ - s_label = self.resolve_label(edge.get("s", ""), user, collection) - p_label = self.resolve_label(edge.get("p", ""), user, collection) - o_label = self.resolve_label(edge.get("o", ""), user, collection) + s_label = self.resolve_label(edge.get("s", ""), collection) + p_label = self.resolve_label(edge.get("p", ""), collection) + o_label = self.resolve_label(edge.get("o", ""), collection) return (s_label, p_label, o_label) def fetch_document_content( self, document_uri: str, api: Any, - user: Optional[str] = None, max_content: int = 10000 ) -> str: """ @@ -697,7 +683,6 @@ class ExplainabilityClient: Args: document_uri: The document URI in the librarian api: TrustGraph Api instance for librarian access - user: User identifier for librarian max_content: Maximum content length to return Returns: @@ -712,7 +697,7 @@ class ExplainabilityClient: for attempt in range(self.max_retries): try: library = api.library() - content_bytes = library.get_document_content(user=user, id=doc_id) + content_bytes = library.get_document_content(id=doc_id) # Decode as text try: @@ -736,7 +721,6 @@ class ExplainabilityClient: self, question_uri: str, graph: Optional[str] = None, - user: Optional[str] = None, collection: Optional[str] = None, api: Any = None, max_content: int = 10000 @@ -749,7 +733,6 @@ class ExplainabilityClient: Args: question_uri: The question entity URI graph: Named graph (default: urn:graph:retrieval) - user: User/keyspace identifier collection: Collection identifier api: TrustGraph Api instance for librarian access (optional) max_content: Maximum content length for synthesis @@ -769,7 +752,7 @@ class ExplainabilityClient: } # Fetch question - question = self.fetch_entity(question_uri, graph, user, collection) + question = self.fetch_entity(question_uri, graph, collection) if not isinstance(question, Question): return trace trace["question"] = question @@ -779,7 +762,6 @@ class ExplainabilityClient: p=PROV_WAS_DERIVED_FROM, o=question_uri, g=graph, - user=user, collection=collection, limit=10 ) @@ -790,7 +772,7 @@ class ExplainabilityClient: for t in grounding_triples ] for gnd_uri in grounding_uris: - grounding = self.fetch_entity(gnd_uri, graph, user, collection) + grounding = self.fetch_entity(gnd_uri, graph, collection) if isinstance(grounding, Grounding): trace["grounding"] = grounding break @@ -803,7 +785,6 @@ class ExplainabilityClient: p=PROV_WAS_DERIVED_FROM, o=trace["grounding"].uri, g=graph, - user=user, collection=collection, limit=10 ) @@ -814,7 +795,7 @@ class ExplainabilityClient: for t in exploration_triples ] for exp_uri in exploration_uris: - exploration = self.fetch_entity(exp_uri, graph, user, collection) + exploration = self.fetch_entity(exp_uri, graph, collection) if isinstance(exploration, Exploration): trace["exploration"] = exploration break @@ -827,7 +808,6 @@ class ExplainabilityClient: p=PROV_WAS_DERIVED_FROM, o=trace["exploration"].uri, g=graph, - user=user, collection=collection, limit=10 ) @@ -838,7 +818,7 @@ class ExplainabilityClient: for t in focus_triples ] for focus_uri in focus_uris: - focus = self.fetch_focus_with_edges(focus_uri, graph, user, collection) + focus = self.fetch_focus_with_edges(focus_uri, graph, collection) if focus: trace["focus"] = focus break @@ -851,7 +831,6 @@ class ExplainabilityClient: p=PROV_WAS_DERIVED_FROM, o=trace["focus"].uri, g=graph, - user=user, collection=collection, limit=10 ) @@ -862,7 +841,7 @@ class ExplainabilityClient: for t in synthesis_triples ] for synth_uri in synthesis_uris: - synthesis = self.fetch_entity(synth_uri, graph, user, collection) + synthesis = self.fetch_entity(synth_uri, graph, collection) if isinstance(synthesis, Synthesis): trace["synthesis"] = synthesis break @@ -873,7 +852,6 @@ class ExplainabilityClient: self, question_uri: str, graph: Optional[str] = None, - user: Optional[str] = None, collection: Optional[str] = None, api: Any = None, max_content: int = 10000 @@ -887,7 +865,6 @@ class ExplainabilityClient: Args: question_uri: The question entity URI graph: Named graph (default: urn:graph:retrieval) - user: User/keyspace identifier collection: Collection identifier api: TrustGraph Api instance for librarian access (optional) max_content: Maximum content length for synthesis @@ -906,7 +883,7 @@ class ExplainabilityClient: } # Fetch question - question = self.fetch_entity(question_uri, graph, user, collection) + question = self.fetch_entity(question_uri, graph, collection) if not isinstance(question, Question): return trace trace["question"] = question @@ -916,7 +893,6 @@ class ExplainabilityClient: p=PROV_WAS_DERIVED_FROM, o=question_uri, g=graph, - user=user, collection=collection, limit=10 ) @@ -927,7 +903,7 @@ class ExplainabilityClient: for t in grounding_triples ] for gnd_uri in grounding_uris: - grounding = self.fetch_entity(gnd_uri, graph, user, collection) + grounding = self.fetch_entity(gnd_uri, graph, collection) if isinstance(grounding, Grounding): trace["grounding"] = grounding break @@ -940,7 +916,6 @@ class ExplainabilityClient: p=PROV_WAS_DERIVED_FROM, o=trace["grounding"].uri, g=graph, - user=user, collection=collection, limit=10 ) @@ -951,7 +926,7 @@ class ExplainabilityClient: for t in exploration_triples ] for exp_uri in exploration_uris: - exploration = self.fetch_entity(exp_uri, graph, user, collection) + exploration = self.fetch_entity(exp_uri, graph, collection) if isinstance(exploration, Exploration): trace["exploration"] = exploration break @@ -964,7 +939,6 @@ class ExplainabilityClient: p=PROV_WAS_DERIVED_FROM, o=trace["exploration"].uri, g=graph, - user=user, collection=collection, limit=10 ) @@ -975,7 +949,7 @@ class ExplainabilityClient: for t in synthesis_triples ] for synth_uri in synthesis_uris: - synthesis = self.fetch_entity(synth_uri, graph, user, collection) + synthesis = self.fetch_entity(synth_uri, graph, collection) if isinstance(synthesis, Synthesis): trace["synthesis"] = synthesis break @@ -986,7 +960,6 @@ class ExplainabilityClient: self, session_uri: str, graph: Optional[str] = None, - user: Optional[str] = None, collection: Optional[str] = None, api: Any = None, max_content: int = 10000 @@ -1002,7 +975,6 @@ class ExplainabilityClient: Args: session_uri: The agent session/question URI graph: Named graph (default: urn:graph:retrieval) - user: User/keyspace identifier collection: Collection identifier api: TrustGraph Api instance for librarian access (optional) max_content: Maximum content length for conclusion @@ -1019,21 +991,21 @@ class ExplainabilityClient: } # Fetch question/session - question = self.fetch_entity(session_uri, graph, user, collection) + question = self.fetch_entity(session_uri, graph, collection) if not isinstance(question, Question): return trace trace["question"] = question # Follow the provenance chain from the question self._follow_provenance_chain( - session_uri, trace, graph, user, collection, + session_uri, trace, graph, collection, max_depth=50, ) return trace def _follow_provenance_chain( - self, current_uri, trace, graph, user, collection, + self, current_uri, trace, graph, collection, max_depth=50, ): """Recursively follow the provenance chain, handling branches.""" @@ -1044,7 +1016,7 @@ class ExplainabilityClient: derived_triples = self.flow.triples_query( p=PROV_WAS_DERIVED_FROM, o=current_uri, - g=graph, user=user, collection=collection, + g=graph, collection=collection, limit=20 ) @@ -1060,7 +1032,7 @@ class ExplainabilityClient: if not derived_uri: continue - entity = self.fetch_entity(derived_uri, graph, user, collection) + entity = self.fetch_entity(derived_uri, graph, collection) if entity is None: continue @@ -1070,7 +1042,7 @@ class ExplainabilityClient: # Continue following from this entity self._follow_provenance_chain( - derived_uri, trace, graph, user, collection, + derived_uri, trace, graph, collection, max_depth=max_depth - 1, ) @@ -1079,11 +1051,11 @@ class ExplainabilityClient: # Fetch the full sub-trace and embed it. if entity.question_type == "graph-rag": sub_trace = self.fetch_graphrag_trace( - derived_uri, graph, user, collection, + derived_uri, graph, collection, ) elif entity.question_type == "document-rag": sub_trace = self.fetch_docrag_trace( - derived_uri, graph, user, collection, + derived_uri, graph, collection, ) else: sub_trace = None @@ -1100,7 +1072,7 @@ class ExplainabilityClient: terminal = sub_trace.get("synthesis") if terminal: self._follow_provenance_chain( - terminal.uri, trace, graph, user, collection, + terminal.uri, trace, graph, collection, max_depth=max_depth - 1, ) @@ -1110,7 +1082,6 @@ class ExplainabilityClient: def list_sessions( self, graph: Optional[str] = None, - user: Optional[str] = None, collection: Optional[str] = None, limit: int = 50 ) -> List[Question]: @@ -1119,7 +1090,6 @@ class ExplainabilityClient: Args: graph: Named graph (default: urn:graph:retrieval) - user: User/keyspace identifier collection: Collection identifier limit: Maximum number of sessions to return @@ -1133,7 +1103,6 @@ class ExplainabilityClient: query_triples = self.flow.triples_query( p=TG_QUERY, g=graph, - user=user, collection=collection, limit=limit ) @@ -1142,7 +1111,7 @@ class ExplainabilityClient: for t in query_triples: question_uri = extract_term_value(t.get("s", {})) if question_uri: - entity = self.fetch_entity(question_uri, graph, user, collection) + entity = self.fetch_entity(question_uri, graph, collection) if isinstance(entity, Question): questions.append(entity) @@ -1154,7 +1123,6 @@ class ExplainabilityClient: s=q.uri, p=PROV_WAS_DERIVED_FROM, g=graph, - user=user, collection=collection, limit=1 ) @@ -1170,7 +1138,6 @@ class ExplainabilityClient: self, session_uri: str, graph: Optional[str] = None, - user: Optional[str] = None, collection: Optional[str] = None ) -> str: """ @@ -1179,7 +1146,6 @@ class ExplainabilityClient: Args: session_uri: The session/question URI graph: Named graph - user: User/keyspace identifier collection: Collection identifier Returns: @@ -1201,7 +1167,6 @@ class ExplainabilityClient: p=PROV_WAS_DERIVED_FROM, o=session_uri, g=graph, - user=user, collection=collection, limit=5 ) @@ -1212,7 +1177,7 @@ class ExplainabilityClient: ] for child_uri in all_child_uris: - entity = self.fetch_entity(child_uri, graph, user, collection) + entity = self.fetch_entity(child_uri, graph, collection) if isinstance(entity, (Analysis, Decomposition, Plan)): return "agent" if isinstance(entity, Exploration): diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index 7ee32dad..961e348b 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -115,72 +115,32 @@ class Flow: return FlowInstance(api=self, id=id) def list_blueprints(self): - """ - List all available flow blueprints. + """List blueprints in the current workspace.""" - Returns: - list[str]: List of blueprint names - - Example: - ```python - blueprints = api.flow().list_blueprints() - print(blueprints) # ['default', 'custom-flow', ...] - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "list-blueprints", + "workspace": self.api.workspace, } return self.request(request = input)["blueprint-names"] def get_blueprint(self, blueprint_name): - """ - Get a flow blueprint definition by name. + """Get a flow blueprint definition by name.""" - Args: - blueprint_name: Name of the blueprint to retrieve - - Returns: - dict: Blueprint definition as a dictionary - - Example: - ```python - blueprint = api.flow().get_blueprint("default") - print(blueprint) # Blueprint configuration - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "get-blueprint", + "workspace": self.api.workspace, "blueprint-name": blueprint_name, } return json.loads(self.request(request = input)["blueprint-definition"]) def put_blueprint(self, blueprint_name, definition): - """ - Create or update a flow blueprint. + """Create or update a flow blueprint.""" - Args: - blueprint_name: Name for the blueprint - definition: Blueprint definition dictionary - - Example: - ```python - definition = { - "services": ["text-completion", "graph-rag"], - "parameters": {"model": "gpt-4"} - } - api.flow().put_blueprint("my-blueprint", definition) - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "put-blueprint", + "workspace": self.api.workspace, "blueprint-name": blueprint_name, "blueprint-definition": json.dumps(definition), } @@ -188,96 +148,43 @@ class Flow: self.request(request = input) def delete_blueprint(self, blueprint_name): - """ - Delete a flow blueprint. + """Delete a flow blueprint.""" - Args: - blueprint_name: Name of the blueprint to delete - - Example: - ```python - api.flow().delete_blueprint("old-blueprint") - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "delete-blueprint", + "workspace": self.api.workspace, "blueprint-name": blueprint_name, } self.request(request = input) def list(self): - """ - List all active flow instances. + """List flow instances in the current workspace.""" - Returns: - list[str]: List of flow instance IDs - - Example: - ```python - flows = api.flow().list() - print(flows) # ['default', 'flow-1', 'flow-2', ...] - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "list-flows", + "workspace": self.api.workspace, } return self.request(request = input)["flow-ids"] def get(self, id): - """ - Get the definition of a running flow instance. + """Get the definition of a flow instance.""" - Args: - id: Flow instance ID - - Returns: - dict: Flow instance definition - - Example: - ```python - flow_def = api.flow().get("default") - print(flow_def) - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "get-flow", + "workspace": self.api.workspace, "flow-id": id, } return json.loads(self.request(request = input)["flow"]) def start(self, blueprint_name, id, description, parameters=None): - """ - Start a new flow instance from a blueprint. + """Start a new flow instance from a blueprint.""" - Args: - blueprint_name: Name of the blueprint to instantiate - id: Unique identifier for the flow instance - description: Human-readable description - parameters: Optional parameters dictionary - - Example: - ```python - api.flow().start( - blueprint_name="default", - id="my-flow", - description="My custom flow", - parameters={"model": "gpt-4"} - ) - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "start-flow", + "workspace": self.api.workspace, "flow-id": id, "blueprint-name": blueprint_name, "description": description, @@ -289,21 +196,11 @@ class Flow: self.request(request = input) def stop(self, id): - """ - Stop a running flow instance. + """Stop a running flow instance.""" - Args: - id: Flow instance ID to stop - - Example: - ```python - api.flow().stop("my-flow") - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "stop-flow", + "workspace": self.api.workspace, "flow-id": id, } @@ -349,6 +246,13 @@ class FlowInstance: Returns: dict: Service response """ + # Inject workspace so the gateway can route to the right + # workspace's flow. If already present, keep the caller's value. + if isinstance(request, dict) and "workspace" not in request: + request = { + "workspace": self.api.api.workspace, + **request, + } return self.api.request(path = f"{self.id}/{path}", request = request) def text_completion(self, system, prompt): @@ -392,7 +296,7 @@ class FlowInstance: model=result.get("model"), ) - def agent(self, question, user="trustgraph", state=None, group=None, history=None): + def agent(self, question,state=None, group=None, history=None): """ Execute an agent operation with reasoning and tool use capabilities. @@ -401,7 +305,6 @@ class FlowInstance: Args: question: User question or instruction - user: User identifier (default: "trustgraph") state: Optional state dictionary for stateful conversations group: Optional group identifier for multi-user contexts history: Optional conversation history as list of message dicts @@ -416,8 +319,7 @@ class FlowInstance: # Simple question answer = flow.agent( question="What is the capital of France?", - user="trustgraph" - ) + ) # With conversation history history = [ @@ -425,9 +327,7 @@ class FlowInstance: {"role": "assistant", "content": "Hi! How can I help?"} ] answer = flow.agent( - question="Tell me about Paris", - user="trustgraph", - history=history + question="Tell me about Paris",history=history ) ``` """ @@ -435,7 +335,6 @@ class FlowInstance: # The input consists of a question and optional context input = { "question": question, - "user": user, } # Only include state if it has a value @@ -455,7 +354,7 @@ class FlowInstance: )["answer"] def graph_rag( - self, query, user="trustgraph", collection="default", + self, query,collection="default", entity_limit=50, triple_limit=30, max_subgraph_size=150, max_path_length=2, edge_score_limit=30, edge_limit=25, ): @@ -467,7 +366,6 @@ class FlowInstance: Args: query: Natural language query - user: User/keyspace identifier (default: "trustgraph") collection: Collection identifier (default: "default") entity_limit: Maximum entities to retrieve (default: 50) triple_limit: Maximum triples per entity (default: 30) @@ -483,9 +381,7 @@ class FlowInstance: ```python flow = api.flow().id("default") response = flow.graph_rag( - query="Tell me about Marie Curie's discoveries", - user="trustgraph", - collection="scientists", + query="Tell me about Marie Curie's discoveries",collection="scientists", entity_limit=20, max_path_length=3 ) @@ -496,7 +392,6 @@ class FlowInstance: # The input consists of a question input = { "query": query, - "user": user, "collection": collection, "entity-limit": entity_limit, "triple-limit": triple_limit, @@ -519,7 +414,7 @@ class FlowInstance: ) def document_rag( - self, query, user="trustgraph", collection="default", + self, query,collection="default", doc_limit=10, ): """ @@ -530,7 +425,6 @@ class FlowInstance: Args: query: Natural language query - user: User/keyspace identifier (default: "trustgraph") collection: Collection identifier (default: "default") doc_limit: Maximum document chunks to retrieve (default: 10) @@ -541,9 +435,7 @@ class FlowInstance: ```python flow = api.flow().id("default") response = flow.document_rag( - query="Summarize the key findings", - user="trustgraph", - collection="research-papers", + query="Summarize the key findings",collection="research-papers", doc_limit=5 ) print(response) @@ -553,7 +445,6 @@ class FlowInstance: # The input consists of a question input = { "query": query, - "user": user, "collection": collection, "doc-limit": doc_limit, } @@ -600,7 +491,7 @@ class FlowInstance: input )["vectors"] - def graph_embeddings_query(self, text, user, collection, limit=10): + def graph_embeddings_query(self, text, collection, limit=10): """ Query knowledge graph entities using semantic similarity. @@ -609,7 +500,6 @@ class FlowInstance: Args: text: Query text for semantic search - user: User/keyspace identifier collection: Collection identifier limit: Maximum number of results (default: 10) @@ -620,9 +510,7 @@ class FlowInstance: ```python flow = api.flow().id("default") results = flow.graph_embeddings_query( - text="physicist who discovered radioactivity", - user="trustgraph", - collection="scientists", + text="physicist who discovered radioactivity",collection="scientists", limit=5 ) # results contains {"entities": [{"entity": {...}, "score": 0.95}, ...]} @@ -636,7 +524,6 @@ class FlowInstance: # Query graph embeddings for semantic search input = { "vector": vector, - "user": user, "collection": collection, "limit": limit } @@ -646,7 +533,7 @@ class FlowInstance: input ) - def document_embeddings_query(self, text, user, collection, limit=10): + def document_embeddings_query(self, text, collection, limit=10): """ Query document chunks using semantic similarity. @@ -655,7 +542,6 @@ class FlowInstance: Args: text: Query text for semantic search - user: User/keyspace identifier collection: Collection identifier limit: Maximum number of results (default: 10) @@ -666,9 +552,7 @@ class FlowInstance: ```python flow = api.flow().id("default") results = flow.document_embeddings_query( - text="machine learning algorithms", - user="trustgraph", - collection="research-papers", + text="machine learning algorithms",collection="research-papers", limit=5 ) # results contains {"chunks": [{"chunk_id": "doc1/p0/c0", "score": 0.95}, ...]} @@ -682,7 +566,6 @@ class FlowInstance: # Query document embeddings for semantic search input = { "vector": vector, - "user": user, "collection": collection, "limit": limit } @@ -805,7 +688,7 @@ class FlowInstance: def triples_query( self, s=None, p=None, o=None, - user=None, collection=None, limit=10000 + collection=None, limit=10000 ): """ Query knowledge graph triples using pattern matching. @@ -817,7 +700,6 @@ class FlowInstance: s: Subject URI (optional, use None for wildcard) p: Predicate URI (optional, use None for wildcard) o: Object URI or Literal (optional, use None for wildcard) - user: User/keyspace identifier (optional) collection: Collection identifier (optional) limit: Maximum results to return (default: 10000) @@ -835,9 +717,7 @@ class FlowInstance: # Find all triples about a specific subject triples = flow.triples_query( - s=Uri("http://example.org/person/marie-curie"), - user="trustgraph", - collection="scientists" + s=Uri("http://example.org/person/marie-curie"),collection="scientists" ) # Find all instances of a specific relationship @@ -851,10 +731,6 @@ class FlowInstance: input = { "limit": limit } - - if user: - input["user"] = user - if collection: input["collection"] = collection @@ -888,7 +764,7 @@ class FlowInstance: ] def load_document( - self, document, id=None, metadata=None, user=None, + self, document, id=None, metadata=None, collection=None, ): """ @@ -901,7 +777,6 @@ class FlowInstance: document: Document content as bytes id: Optional document identifier (auto-generated if None) metadata: Optional metadata (list of Triples or object with emit method) - user: User/keyspace identifier (optional) collection: Collection identifier (optional) Returns: @@ -918,9 +793,7 @@ class FlowInstance: with open("research.pdf", "rb") as f: result = flow.load_document( document=f.read(), - id="research-001", - user="trustgraph", - collection="papers" + id="research-001",collection="papers" ) ``` """ @@ -955,10 +828,6 @@ class FlowInstance: "metadata": triples, "data": base64.b64encode(document).decode("utf-8"), } - - if user: - input["user"] = user - if collection: input["collection"] = collection @@ -969,7 +838,7 @@ class FlowInstance: def load_text( self, text, id=None, metadata=None, charset="utf-8", - user=None, collection=None, + collection=None, ): """ Load text content for processing. @@ -982,7 +851,6 @@ class FlowInstance: id: Optional document identifier (auto-generated if None) metadata: Optional metadata (list of Triples or object with emit method) charset: Character encoding (default: "utf-8") - user: User/keyspace identifier (optional) collection: Collection identifier (optional) Returns: @@ -1000,9 +868,7 @@ class FlowInstance: result = flow.load_text( text=text_content, id="text-001", - charset="utf-8", - user="trustgraph", - collection="documents" + charset="utf-8",collection="documents" ) ``` """ @@ -1035,10 +901,6 @@ class FlowInstance: "charset": charset, "text": base64.b64encode(text).decode("utf-8"), } - - if user: - input["user"] = user - if collection: input["collection"] = collection @@ -1048,7 +910,7 @@ class FlowInstance: ) def rows_query( - self, query, user="trustgraph", collection="default", + self, query,collection="default", variables=None, operation_name=None ): """ @@ -1059,7 +921,6 @@ class FlowInstance: Args: query: GraphQL query string - user: User/keyspace identifier (default: "trustgraph") collection: Collection identifier (default: "default") variables: Optional query variables dictionary operation_name: Optional operation name for multi-operation documents @@ -1085,9 +946,7 @@ class FlowInstance: } ''' result = flow.rows_query( - query=query, - user="trustgraph", - collection="scientists" + query=query,collection="scientists" ) # Query with variables @@ -1109,7 +968,6 @@ class FlowInstance: # The input consists of a GraphQL query and optional variables input = { "query": query, - "user": user, "collection": collection, } @@ -1145,7 +1003,7 @@ class FlowInstance: return result def sparql_query( - self, query, user="trustgraph", collection="default", + self, query,collection="default", limit=10000 ): """ @@ -1153,7 +1011,6 @@ class FlowInstance: Args: query: SPARQL 1.1 query string - user: User/keyspace identifier (default: "trustgraph") collection: Collection identifier (default: "default") limit: Safety limit on results (default: 10000) @@ -1169,7 +1026,6 @@ class FlowInstance: input = { "query": query, - "user": user, "collection": collection, "limit": limit, } @@ -1213,14 +1069,13 @@ class FlowInstance: return response - def structured_query(self, question, user="trustgraph", collection="default"): + def structured_query(self, question,collection="default"): """ Execute a natural language question against structured data. Combines NLP query conversion and GraphQL execution. Args: question: Natural language question - user: Cassandra keyspace identifier (default: "trustgraph") collection: Data collection identifier (default: "default") Returns: @@ -1229,7 +1084,6 @@ class FlowInstance: input = { "question": question, - "user": user, "collection": collection } @@ -1383,7 +1237,7 @@ class FlowInstance: return response["schema-matches"] def row_embeddings_query( - self, text, schema_name, user="trustgraph", collection="default", + self, text, schema_name,collection="default", index_name=None, limit=10 ): """ @@ -1396,7 +1250,6 @@ class FlowInstance: Args: text: Query text for semantic search schema_name: Schema name to search within - user: User/keyspace identifier (default: "trustgraph") collection: Collection identifier (default: "default") index_name: Optional index name to filter search to specific index limit: Maximum number of results (default: 10) @@ -1412,9 +1265,7 @@ class FlowInstance: # Search for customers by name similarity results = flow.row_embeddings_query( text="John Smith", - schema_name="customers", - user="trustgraph", - collection="sales", + schema_name="customers",collection="sales", limit=5 ) @@ -1436,7 +1287,6 @@ class FlowInstance: input = { "vector": vector, "schema_name": schema_name, - "user": user, "collection": collection, "limit": limit } diff --git a/trustgraph-base/trustgraph/api/knowledge.py b/trustgraph-base/trustgraph/api/knowledge.py index 84f98918..c3ec2308 100644 --- a/trustgraph-base/trustgraph/api/knowledge.py +++ b/trustgraph-base/trustgraph/api/knowledge.py @@ -63,105 +63,50 @@ class Knowledge: """ return self.api.request(f"knowledge", request) - def list_kg_cores(self, user="trustgraph"): + def list_kg_cores(self): """ - List all available knowledge graph cores. - - Retrieves the IDs of all KG cores available for the specified user. - - Args: - user: User identifier (default: "trustgraph") + List all available knowledge graph cores in this workspace. Returns: list[str]: List of KG core identifiers - - Example: - ```python - knowledge = api.knowledge() - - # List available KG cores - cores = knowledge.list_kg_cores(user="trustgraph") - print(f"Available KG cores: {cores}") - ``` """ - # The input consists of system and prompt strings input = { "operation": "list-kg-cores", - "user": user, + "workspace": self.api.workspace, } return self.request(request = input)["ids"] - def delete_kg_core(self, id, user="trustgraph"): + def delete_kg_core(self, id): """ - Delete a knowledge graph core. - - Removes a KG core from storage. This does not affect currently loaded - cores in flows. + Delete a knowledge graph core in this workspace. Args: id: KG core identifier to delete - user: User identifier (default: "trustgraph") - - Example: - ```python - knowledge = api.knowledge() - - # Delete a KG core - knowledge.delete_kg_core(id="medical-kb-v1", user="trustgraph") - ``` """ - # The input consists of system and prompt strings input = { "operation": "delete-kg-core", - "user": user, + "workspace": self.api.workspace, "id": id, } self.request(request = input) - def load_kg_core(self, id, user="trustgraph", flow="default", - collection="default"): + def load_kg_core(self, id, flow="default", collection="default"): """ Load a knowledge graph core into a flow. - Makes a KG core available for use in queries and RAG operations within - the specified flow and collection. - Args: id: KG core identifier to load - user: User identifier (default: "trustgraph") flow: Flow instance to load into (default: "default") collection: Collection to associate with (default: "default") - - Example: - ```python - knowledge = api.knowledge() - - # Load a medical knowledge base into the default flow - knowledge.load_kg_core( - id="medical-kb-v1", - user="trustgraph", - flow="default", - collection="medical" - ) - - # Now the flow can use this KG core for RAG queries - flow = api.flow().id("default") - response = flow.graph_rag( - query="What are the symptoms of diabetes?", - user="trustgraph", - collection="medical" - ) - ``` """ - # The input consists of system and prompt strings input = { "operation": "load-kg-core", - "user": user, + "workspace": self.api.workspace, "id": id, "flow": flow, "collection": collection, @@ -169,35 +114,18 @@ class Knowledge: self.request(request = input) - def unload_kg_core(self, id, user="trustgraph", flow="default"): + def unload_kg_core(self, id, flow="default"): """ Unload a knowledge graph core from a flow. - Removes a KG core from active use in the specified flow, freeing - resources while keeping the core available in storage. - Args: id: KG core identifier to unload - user: User identifier (default: "trustgraph") flow: Flow instance to unload from (default: "default") - - Example: - ```python - knowledge = api.knowledge() - - # Unload a KG core when no longer needed - knowledge.unload_kg_core( - id="medical-kb-v1", - user="trustgraph", - flow="default" - ) - ``` """ - # The input consists of system and prompt strings input = { "operation": "unload-kg-core", - "user": user, + "workspace": self.api.workspace, "id": id, "flow": flow, } diff --git a/trustgraph-base/trustgraph/api/library.py b/trustgraph-base/trustgraph/api/library.py index c66598aa..8f99e601 100644 --- a/trustgraph-base/trustgraph/api/library.py +++ b/trustgraph-base/trustgraph/api/library.py @@ -94,7 +94,7 @@ class Library: return self.api.request(f"librarian", request) def add_document( - self, document, id, metadata, user, title, comments, + self, document, id, metadata, title, comments, kind="text/plain", tags=[], on_progress=None, ): """ @@ -108,7 +108,6 @@ class Library: document: Document content as bytes id: Document identifier (auto-generated if None) metadata: Document metadata as list of Triple objects or object with emit method - user: User/owner identifier title: Document title comments: Document description or comments kind: MIME type of the document (default: "text/plain") @@ -131,7 +130,6 @@ class Library: document=f.read(), id="research-001", metadata=[], - user="trustgraph", title="Research Paper", comments="Key findings in quantum computing", kind="application/pdf", @@ -147,7 +145,6 @@ class Library: document=f.read(), id="large-doc-001", metadata=[], - user="trustgraph", title="Large Document", comments="A very large document", kind="application/pdf", @@ -176,7 +173,6 @@ class Library: document=document, id=id, metadata=metadata, - user=user, title=title, comments=comments, kind=kind, @@ -213,6 +209,7 @@ class Library: input = { "operation": "add-document", + "workspace": self.api.workspace, "document-metadata": { "id": id, "time": int(time.time()), @@ -220,7 +217,7 @@ class Library: "title": title, "comments": comments, "metadata": triples, - "user": user, + "workspace": self.api.workspace, "tags": tags }, "content": base64.b64encode(document).decode("utf-8"), @@ -229,7 +226,7 @@ class Library: return self.request(input) def _add_document_chunked( - self, document, id, metadata, user, title, comments, + self, document, id, metadata, title, comments, kind, tags, on_progress=None, ): """ @@ -245,13 +242,14 @@ class Library: # Begin upload session begin_request = { "operation": "begin-upload", + "workspace": self.api.workspace, "document-metadata": { "id": id, "time": int(time.time()), "kind": kind, "title": title, "comments": comments, - "user": user, + "workspace": self.api.workspace, "tags": tags, }, "total-size": total_size, @@ -279,10 +277,10 @@ class Library: chunk_request = { "operation": "upload-chunk", + "workspace": self.api.workspace, "upload-id": upload_id, "chunk-index": chunk_index, "content": base64.b64encode(chunk_data).decode("utf-8"), - "user": user, } chunk_response = self.request(chunk_request) @@ -298,8 +296,8 @@ class Library: # Complete upload complete_request = { "operation": "complete-upload", + "workspace": self.api.workspace, "upload-id": upload_id, - "user": user, } complete_response = self.request(complete_request) @@ -314,8 +312,8 @@ class Library: try: abort_request = { "operation": "abort-upload", + "workspace": self.api.workspace, "upload-id": upload_id, - "user": user, } self.request(abort_request) logger.info(f"Aborted failed upload {upload_id}") @@ -323,15 +321,13 @@ class Library: logger.warning(f"Failed to abort upload: {abort_error}") raise - def get_documents(self, user, include_children=False): + def get_documents(self, include_children=False): """ - List all documents for a user. + List all documents in the current workspace. - Retrieves metadata for all documents owned by the specified user. By default, only returns top-level documents (not child/extracted documents). Args: - user: User identifier include_children: If True, also include child documents (default: False) Returns: @@ -345,7 +341,7 @@ class Library: library = api.library() # Get only top-level documents - docs = library.get_documents(user="trustgraph") + docs = library.get_documents() for doc in docs: print(f"{doc.id}: {doc.title} ({doc.kind})") @@ -353,13 +349,13 @@ class Library: print(f" Tags: {', '.join(doc.tags)}") # Get all documents including extracted pages - all_docs = library.get_documents(user="trustgraph", include_children=True) + all_docs = library.get_documents(include_children=True) ``` """ input = { "operation": "list-documents", - "user": user, + "workspace": self.api.workspace, "include-children": include_children, } @@ -381,7 +377,7 @@ class Library: ) for w in v["metadata"] ], - user = v["user"], + workspace = v.get("workspace", ""), tags = v["tags"], parent_id = v.get("parent-id", ""), document_type = v.get("document-type", "source"), @@ -392,14 +388,13 @@ class Library: logger.error("Failed to parse document list response", exc_info=True) raise ProtocolException(f"Response not formatted correctly") - def get_document(self, user, id): + def get_document(self, id): """ Get metadata for a specific document. Retrieves the metadata for a single document by ID. Args: - user: User identifier id: Document identifier Returns: @@ -411,7 +406,7 @@ class Library: Example: ```python library = api.library() - doc = library.get_document(user="trustgraph", id="doc-123") + doc = library.get_document(id="doc-123") print(f"Title: {doc.title}") print(f"Comments: {doc.comments}") ``` @@ -419,7 +414,7 @@ class Library: input = { "operation": "get-document", - "user": user, + "workspace": self.api.workspace, "document-id": id, } @@ -441,7 +436,7 @@ class Library: ) for w in doc["metadata"] ], - user = doc["user"], + workspace = doc.get("workspace", ""), tags = doc["tags"], parent_id = doc.get("parent-id", ""), document_type = doc.get("document-type", "source"), @@ -450,14 +445,13 @@ class Library: logger.error("Failed to parse document response", exc_info=True) raise ProtocolException(f"Response not formatted correctly") - def update_document(self, user, id, metadata): + def update_document(self, id, metadata): """ Update document metadata. Updates the metadata for an existing document in the library. Args: - user: User identifier id: Document identifier metadata: Updated DocumentMetadata object @@ -472,7 +466,7 @@ class Library: library = api.library() # Get existing document - doc = library.get_document(user="trustgraph", id="doc-123") + doc = library.get_document(id="doc-123") # Update metadata doc.title = "Updated Title" @@ -481,7 +475,6 @@ class Library: # Save changes updated_doc = library.update_document( - user="trustgraph", id="doc-123", metadata=doc ) @@ -490,8 +483,9 @@ class Library: input = { "operation": "update-document", + "workspace": self.api.workspace, "document-metadata": { - "user": user, + "workspace": self.api.workspace, "document-id": id, "time": metadata.time, "title": metadata.title, @@ -526,21 +520,20 @@ class Library: ) for w in doc["metadata"] ], - user = doc["user"], + workspace = doc.get("workspace", ""), tags = doc["tags"] ) except Exception as e: logger.error("Failed to parse document update response", exc_info=True) raise ProtocolException(f"Response not formatted correctly") - def remove_document(self, user, id): + def remove_document(self, id): """ Remove a document from the library. Deletes a document and its metadata from the library. Args: - user: User identifier id: Document identifier to remove Returns: @@ -549,13 +542,13 @@ class Library: Example: ```python library = api.library() - library.remove_document(user="trustgraph", id="doc-123") + library.remove_document(id="doc-123") ``` """ input = { "operation": "remove-document", - "user": user, + "workspace": self.api.workspace, "document-id": id, } @@ -565,7 +558,7 @@ class Library: def start_processing( self, id, document_id, flow="default", - user="trustgraph", collection="default", tags=[], + collection="default", tags=[], ): """ Start a document processing workflow. @@ -577,7 +570,6 @@ class Library: id: Unique processing job identifier document_id: ID of the document to process flow: Flow instance to use for processing (default: "default") - user: User identifier (default: "trustgraph") collection: Target collection for processed data (default: "default") tags: List of tags for the processing job (default: []) @@ -593,7 +585,6 @@ class Library: id="proc-001", document_id="doc-123", flow="default", - user="trustgraph", collection="research", tags=["automated", "extract"] ) @@ -602,12 +593,13 @@ class Library: input = { "operation": "add-processing", + "workspace": self.api.workspace, "processing-metadata": { "id": id, "document-id": document_id, "time": int(time.time()), "flow": flow, - "user": user, + "workspace": self.api.workspace, "collection": collection, "tags": tags, } @@ -618,7 +610,7 @@ class Library: return {} def stop_processing( - self, id, user="trustgraph", + self, id, ): """ Stop a running document processing job. @@ -627,7 +619,6 @@ class Library: Args: id: Processing job identifier to stop - user: User identifier (default: "trustgraph") Returns: dict: Empty response object @@ -635,29 +626,26 @@ class Library: Example: ```python library = api.library() - library.stop_processing(id="proc-001", user="trustgraph") + library.stop_processing(id="proc-001") ``` """ input = { "operation": "remove-processing", + "workspace": self.api.workspace, "processing-id": id, - "user": user, } object = self.request(input) return {} - def get_processings(self, user="trustgraph"): + def get_processings(self): """ List all active document processing jobs. Retrieves metadata for all currently running document processing workflows - for the specified user. - - Args: - user: User identifier (default: "trustgraph") + in the current workspace. Returns: list[ProcessingMetadata]: List of processing job metadata objects @@ -668,7 +656,7 @@ class Library: Example: ```python library = api.library() - jobs = library.get_processings(user="trustgraph") + jobs = library.get_processings() for job in jobs: print(f"Job {job.id}:") @@ -681,7 +669,7 @@ class Library: input = { "operation": "list-processing", - "user": user, + "workspace": self.api.workspace, } object = self.request(input) @@ -693,7 +681,7 @@ class Library: document_id = v["document-id"], time = datetime.datetime.fromtimestamp(v["time"]), flow = v["flow"], - user = v["user"], + workspace = v.get("workspace", ""), collection = v["collection"], tags = v["tags"], ) @@ -705,23 +693,20 @@ class Library: # Chunked upload management methods - def get_pending_uploads(self, user): + def get_pending_uploads(self): """ - List all pending (in-progress) uploads for a user. + List all pending (in-progress) uploads in the current workspace. Retrieves information about chunked uploads that have been started but not yet completed. - Args: - user: User identifier - Returns: list[dict]: List of pending upload information Example: ```python library = api.library() - pending = library.get_pending_uploads(user="trustgraph") + pending = library.get_pending_uploads() for upload in pending: print(f"Upload {upload['upload_id']}:") @@ -731,14 +716,14 @@ class Library: """ input = { "operation": "list-uploads", - "user": user, + "workspace": self.api.workspace, } response = self.request(input) return response.get("upload-sessions", []) - def get_upload_status(self, upload_id, user): + def get_upload_status(self, upload_id): """ Get the status of a specific upload. @@ -747,7 +732,6 @@ class Library: Args: upload_id: Upload session identifier - user: User identifier Returns: dict: Upload status information including: @@ -763,10 +747,7 @@ class Library: Example: ```python library = api.library() - status = library.get_upload_status( - upload_id="abc-123", - user="trustgraph" - ) + status = library.get_upload_status(upload_id="abc-123") if status['state'] == 'in-progress': print(f"Missing chunks: {status['missing_chunks']}") @@ -774,13 +755,13 @@ class Library: """ input = { "operation": "get-upload-status", + "workspace": self.api.workspace, "upload-id": upload_id, - "user": user, } return self.request(input) - def abort_upload(self, upload_id, user): + def abort_upload(self, upload_id): """ Abort an in-progress upload. @@ -788,7 +769,6 @@ class Library: Args: upload_id: Upload session identifier - user: User identifier Returns: dict: Empty response on success @@ -796,18 +776,18 @@ class Library: Example: ```python library = api.library() - library.abort_upload(upload_id="abc-123", user="trustgraph") + library.abort_upload(upload_id="abc-123") ``` """ input = { "operation": "abort-upload", + "workspace": self.api.workspace, "upload-id": upload_id, - "user": user, } return self.request(input) - def resume_upload(self, upload_id, document, user, on_progress=None): + def resume_upload(self, upload_id, document, on_progress=None): """ Resume an interrupted upload. @@ -817,7 +797,6 @@ class Library: Args: upload_id: Upload session identifier to resume document: Complete document content as bytes - user: User identifier on_progress: Optional callback(bytes_sent, total_bytes) for progress updates Returns: @@ -828,23 +807,19 @@ class Library: library = api.library() # Check what's missing - status = library.get_upload_status( - upload_id="abc-123", - user="trustgraph" - ) + status = library.get_upload_status(upload_id="abc-123") if status['state'] == 'in-progress': # Resume with the same document with open("large_document.pdf", "rb") as f: library.resume_upload( upload_id="abc-123", - document=f.read(), - user="trustgraph" + document=f.read() ) ``` """ # Get current status - status = self.get_upload_status(upload_id, user) + status = self.get_upload_status(upload_id) if status.get("upload-state") == "expired": raise RuntimeError("Upload session has expired, please start a new upload") @@ -867,10 +842,10 @@ class Library: chunk_request = { "operation": "upload-chunk", + "workspace": self.api.workspace, "upload-id": upload_id, "chunk-index": chunk_index, "content": base64.b64encode(chunk_data).decode("utf-8"), - "user": user, } self.request(chunk_request) @@ -886,8 +861,8 @@ class Library: # Complete upload complete_request = { "operation": "complete-upload", + "workspace": self.api.workspace, "upload-id": upload_id, - "user": user, } return self.request(complete_request) @@ -895,7 +870,7 @@ class Library: # Child document methods def add_child_document( - self, document, id, parent_id, user, title, comments, + self, document, id, parent_id, title, comments, kind="text/plain", tags=[], metadata=None, ): """ @@ -909,7 +884,6 @@ class Library: document: Document content as bytes id: Document identifier (auto-generated if None) parent_id: Parent document identifier (required) - user: User/owner identifier title: Document title comments: Document description or comments kind: MIME type of the document (default: "text/plain") @@ -931,7 +905,6 @@ class Library: document=page_text.encode('utf-8'), id="doc-123-page-1", parent_id="doc-123", - user="trustgraph", title="Page 1 of Research Paper", comments="First page extracted from PDF", kind="text/plain", @@ -964,6 +937,7 @@ class Library: input = { "operation": "add-child-document", + "workspace": self.api.workspace, "document-metadata": { "id": id, "time": int(time.time()), @@ -971,7 +945,7 @@ class Library: "title": title, "comments": comments, "metadata": triples, - "user": user, + "workspace": self.api.workspace, "tags": tags, "parent-id": parent_id, "document-type": "extracted", @@ -981,13 +955,12 @@ class Library: return self.request(input) - def list_children(self, document_id, user): + def list_children(self, document_id): """ List all child documents for a given parent document. Args: document_id: Parent document identifier - user: User identifier Returns: list[DocumentMetadata]: List of child document metadata objects @@ -995,10 +968,7 @@ class Library: Example: ```python library = api.library() - children = library.list_children( - document_id="doc-123", - user="trustgraph" - ) + children = library.list_children(document_id="doc-123") for child in children: print(f"{child.id}: {child.title}") @@ -1006,8 +976,8 @@ class Library: """ input = { "operation": "list-children", + "workspace": self.api.workspace, "document-id": document_id, - "user": user, } response = self.request(input) @@ -1028,7 +998,7 @@ class Library: ) for w in v.get("metadata", []) ], - user=v["user"], + workspace=v.get("workspace", ""), tags=v.get("tags", []), parent_id=v.get("parent-id", ""), document_type=v.get("document-type", "source"), @@ -1039,14 +1009,13 @@ class Library: logger.error("Failed to parse children response", exc_info=True) raise ProtocolException("Response not formatted correctly") - def get_document_content(self, user, id): + def get_document_content(self, id): """ Get the content of a document. Retrieves the full content of a document as bytes. Args: - user: User identifier id: Document identifier Returns: @@ -1055,10 +1024,7 @@ class Library: Example: ```python library = api.library() - content = library.get_document_content( - user="trustgraph", - id="doc-123" - ) + content = library.get_document_content(id="doc-123") # Write to file with open("output.pdf", "wb") as f: @@ -1067,7 +1033,7 @@ class Library: """ input = { "operation": "get-document-content", - "user": user, + "workspace": self.api.workspace, "document-id": id, } @@ -1076,7 +1042,7 @@ class Library: return base64.b64decode(content_b64) - def stream_document_to_file(self, user, id, file_path, chunk_size=1024*1024, on_progress=None): + def stream_document_to_file(self, id, file_path, chunk_size=1024*1024, on_progress=None): """ Stream document content to a file. @@ -1084,7 +1050,6 @@ class Library: enabling memory-efficient handling of large documents. Args: - user: User identifier id: Document identifier file_path: Path to write the document content chunk_size: Size of each chunk to download (default 1MB) @@ -1101,7 +1066,6 @@ class Library: print(f"Downloaded {received}/{total} bytes") library.stream_document_to_file( - user="trustgraph", id="large-doc-123", file_path="/tmp/document.pdf", on_progress=progress @@ -1116,7 +1080,7 @@ class Library: while True: input = { "operation": "stream-document", - "user": user, + "workspace": self.api.workspace, "document-id": id, "chunk-index": chunk_index, "chunk-size": chunk_size, diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index c590c9b4..4eade3e8 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -84,10 +84,14 @@ class SocketClient: for streaming responses. """ - def __init__(self, url: str, timeout: int, token: Optional[str]) -> None: + def __init__( + self, url: str, timeout: int, token: Optional[str], + workspace: str = "default", + ) -> None: self.url: str = self._convert_to_ws_url(url) self.timeout: int = timeout self.token: Optional[str] = token + self.workspace: str = workspace self._request_counter: int = 0 self._lock: Lock = Lock() self._loop: Optional[asyncio.AbstractEventLoop] = None @@ -251,6 +255,7 @@ class SocketClient: try: message = { "id": request_id, + "workspace": self.workspace, "service": service, "request": request } @@ -290,6 +295,7 @@ class SocketClient: try: message = { "id": request_id, + "workspace": self.workspace, "service": service, "request": request } @@ -328,6 +334,7 @@ class SocketClient: try: message = { "id": request_id, + "workspace": self.workspace, "service": service, "request": request } @@ -488,7 +495,6 @@ class SocketFlowInstance: def agent( self, question: str, - user: str, state: Optional[Dict[str, Any]] = None, group: Optional[str] = None, history: Optional[List[Dict[str, Any]]] = None, @@ -498,7 +504,6 @@ class SocketFlowInstance: """Execute an agent operation with streaming support.""" request = { "question": question, - "user": user, "streaming": streaming } if state is not None: @@ -514,7 +519,6 @@ class SocketFlowInstance: def agent_explain( self, question: str, - user: str, collection: str, state: Optional[Dict[str, Any]] = None, group: Optional[str] = None, @@ -524,7 +528,6 @@ class SocketFlowInstance: """Execute an agent operation with explainability support.""" request = { "question": question, - "user": user, "collection": collection, "streaming": True } @@ -574,7 +577,6 @@ class SocketFlowInstance: def graph_rag( self, query: str, - user: str, collection: str, entity_limit: int = 50, triple_limit: int = 30, @@ -592,7 +594,6 @@ class SocketFlowInstance: """ request = { "query": query, - "user": user, "collection": collection, "entity-limit": entity_limit, "triple-limit": triple_limit, @@ -619,7 +620,6 @@ class SocketFlowInstance: def graph_rag_explain( self, query: str, - user: str, collection: str, entity_limit: int = 50, triple_limit: int = 30, @@ -632,7 +632,6 @@ class SocketFlowInstance: """Execute graph-based RAG query with explainability support.""" request = { "query": query, - "user": user, "collection": collection, "entity-limit": entity_limit, "triple-limit": triple_limit, @@ -653,7 +652,6 @@ class SocketFlowInstance: def document_rag( self, query: str, - user: str, collection: str, doc_limit: int = 10, streaming: bool = False, @@ -666,7 +664,6 @@ class SocketFlowInstance: """ request = { "query": query, - "user": user, "collection": collection, "doc-limit": doc_limit, "streaming": streaming @@ -688,7 +685,6 @@ class SocketFlowInstance: def document_rag_explain( self, query: str, - user: str, collection: str, doc_limit: int = 10, **kwargs: Any @@ -696,7 +692,6 @@ class SocketFlowInstance: """Execute document-based RAG query with explainability support.""" request = { "query": query, - "user": user, "collection": collection, "doc-limit": doc_limit, "streaming": True, @@ -748,7 +743,6 @@ class SocketFlowInstance: def graph_embeddings_query( self, text: str, - user: str, collection: str, limit: int = 10, **kwargs: Any @@ -759,7 +753,6 @@ class SocketFlowInstance: request = { "vector": vector, - "user": user, "collection": collection, "limit": limit } @@ -770,7 +763,6 @@ class SocketFlowInstance: def document_embeddings_query( self, text: str, - user: str, collection: str, limit: int = 10, **kwargs: Any @@ -781,7 +773,6 @@ class SocketFlowInstance: request = { "vector": vector, - "user": user, "collection": collection, "limit": limit } @@ -802,7 +793,6 @@ class SocketFlowInstance: p: Optional[Union[str, Dict[str, Any]]] = None, o: Optional[Union[str, Dict[str, Any]]] = None, g: Optional[str] = None, - user: Optional[str] = None, collection: Optional[str] = None, limit: int = 100, **kwargs: Any @@ -822,8 +812,6 @@ class SocketFlowInstance: request["o"] = o_term if g is not None: request["g"] = g - if user is not None: - request["user"] = user if collection is not None: request["collection"] = collection request.update(kwargs) @@ -839,7 +827,6 @@ class SocketFlowInstance: p: Optional[Union[str, Dict[str, Any]]] = None, o: Optional[Union[str, Dict[str, Any]]] = None, g: Optional[str] = None, - user: Optional[str] = None, collection: Optional[str] = None, limit: int = 100, batch_size: int = 20, @@ -864,8 +851,6 @@ class SocketFlowInstance: request["o"] = o_term if g is not None: request["g"] = g - if user is not None: - request["user"] = user if collection is not None: request["collection"] = collection request.update(kwargs) @@ -879,7 +864,6 @@ class SocketFlowInstance: def sparql_query_stream( self, query: str, - user: str = "trustgraph", collection: str = "default", limit: int = 10000, batch_size: int = 20, @@ -888,7 +872,6 @@ class SocketFlowInstance: """Execute a SPARQL query with streaming batches.""" request = { "query": query, - "user": user, "collection": collection, "limit": limit, "streaming": True, @@ -904,7 +887,6 @@ class SocketFlowInstance: def rows_query( self, query: str, - user: str, collection: str, variables: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, @@ -913,7 +895,6 @@ class SocketFlowInstance: """Execute a GraphQL query against structured rows.""" request = { "query": query, - "user": user, "collection": collection } if variables: @@ -943,7 +924,6 @@ class SocketFlowInstance: self, text: str, schema_name: str, - user: str = "trustgraph", collection: str = "default", index_name: Optional[str] = None, limit: int = 10, @@ -956,7 +936,6 @@ class SocketFlowInstance: request = { "vector": vector, "schema_name": schema_name, - "user": user, "collection": collection, "limit": limit } diff --git a/trustgraph-base/trustgraph/api/types.py b/trustgraph-base/trustgraph/api/types.py index f5987b0e..129f807a 100644 --- a/trustgraph-base/trustgraph/api/types.py +++ b/trustgraph-base/trustgraph/api/types.py @@ -45,10 +45,13 @@ class ConfigValue: type: Configuration type/category key: Specific configuration key value: Configuration value as string + workspace: Workspace the value belongs to. Only populated for + responses to getvalues-all-ws; empty otherwise. """ type : str key : str value : str + workspace : str = "" @dataclasses.dataclass class DocumentMetadata: @@ -62,7 +65,7 @@ class DocumentMetadata: title: Document title comments: Additional comments or description metadata: List of RDF triples providing structured metadata - user: User/owner identifier + workspace: Workspace the document belongs to tags: List of tags for categorization parent_id: Parent document ID for child documents (empty for top-level docs) document_type: "source" for uploaded documents, "extracted" for derived content @@ -73,7 +76,7 @@ class DocumentMetadata: title : str comments : str metadata : List[Triple] - user : str + workspace : str tags : List[str] parent_id : str = "" document_type : str = "source" @@ -88,7 +91,7 @@ class ProcessingMetadata: document_id: ID of the document being processed time: Processing start timestamp flow: Flow instance handling the processing - user: User identifier + workspace: Workspace the processing job belongs to collection: Target collection for processed data tags: List of tags for categorization """ @@ -96,7 +99,7 @@ class ProcessingMetadata: document_id : str time : datetime.datetime flow : str - user : str + workspace : str collection : str tags : List[str] @@ -105,17 +108,15 @@ class CollectionMetadata: """ Metadata for a data collection. - Collections provide logical grouping and isolation for documents and - knowledge graph data. + Collections provide logical grouping within a workspace for documents + and knowledge graph data. Attributes: - user: User/owner identifier collection: Collection identifier name: Human-readable collection name description: Collection description tags: List of tags for categorization """ - user : str collection : str name : str description : str diff --git a/trustgraph-base/trustgraph/base/async_processor.py b/trustgraph-base/trustgraph/base/async_processor.py index 9b9328cb..a7ce4961 100644 --- a/trustgraph-base/trustgraph/base/async_processor.py +++ b/trustgraph-base/trustgraph/base/async_processor.py @@ -125,21 +125,39 @@ class AsyncProcessor: response_metrics = config_resp_metrics, ) - async def fetch_config(self): - """Fetch full config from config service using a short-lived - request/response client. Returns (config, version) or raises.""" - client = self._create_config_client() - try: - await client.start() - resp = await client.request( - ConfigRequest(operation="config"), - timeout=10, - ) - if resp.error: - raise RuntimeError(f"Config error: {resp.error.message}") - return resp.config, resp.version - finally: - await client.stop() + async def _fetch_type_workspace(self, client, workspace, config_type): + """Fetch config values of a single type within one workspace. + Returns dict of {key: value}.""" + resp = await client.request( + ConfigRequest( + operation="getvalues", + workspace=workspace, + type=config_type, + ), + timeout=10, + ) + if resp.error: + raise RuntimeError(f"Config error: {resp.error.message}") + return {v.key: v.value for v in resp.values} + + async def _fetch_type_all_workspaces(self, client, config_type): + """Fetch config values of a single type across all workspaces. + Returns dict of {workspace: {key: value}}.""" + resp = await client.request( + ConfigRequest( + operation="getvalues-all-ws", + type=config_type, + ), + timeout=10, + ) + if resp.error: + raise RuntimeError(f"Config error: {resp.error.message}") + + grouped = {} + for v in resp.values: + ws = grouped.setdefault(v.workspace, {}) + ws[v.key] = v.value + return grouped, resp.version # This is called to start dynamic behaviour. # Implements the subscribe-then-fetch pattern to avoid race conditions. @@ -155,21 +173,51 @@ class AsyncProcessor: # processed by on_config_notify, which does the version check async def fetch_and_apply_config(self): - """Fetch full config from config service and apply to all handlers. - Retries until successful — config service may not be ready yet.""" + """Startup: for each registered handler, fetch config for all its + types across all workspaces and invoke the handler once per + workspace. Retries until successful — config service may not be + ready yet.""" while self.running: try: - config, version = await self.fetch_config() + client = self._create_config_client() + try: + await client.start() - logger.info(f"Fetched config version {version}") + version = 0 - self.config_version = version + for entry in self.config_handlers: + handler_types = entry["types"] - # Apply to all handlers (startup = invoke all) - for entry in self.config_handlers: - await entry["handler"](config, version) + # Handlers registered without types get nothing + # at startup (there is no "all types" fetch). + if not handler_types: + continue + + # Group all registered types by workspace: + # {workspace: {type: {key: value}}} + per_ws = {} + for t in handler_types: + type_data, v = \ + await self._fetch_type_all_workspaces( + client, t, + ) + version = max(version, v) + for ws, kv in type_data.items(): + per_ws.setdefault(ws, {})[t] = kv + + # Call the handler once per workspace + for ws, config in per_ws.items(): + await entry["handler"](ws, config, version) + + logger.info( + f"Applied startup config version {version}" + ) + self.config_version = version + + finally: + await client.stop() return @@ -204,8 +252,9 @@ class AsyncProcessor: # Called when a config notify message arrives async def on_config_notify(self, message, consumer, flow): - notify_version = message.value().version - notify_types = set(message.value().types) + v = message.value() + notify_version = v.version + changes = v.changes # dict of type -> [workspaces] # Skip if we already have this version or newer if notify_version <= self.config_version: @@ -215,41 +264,60 @@ class AsyncProcessor: ) return - # Check if any handler cares about the affected types - if notify_types: - any_interested = False - for entry in self.config_handlers: - handler_types = entry["types"] - if handler_types is None or notify_types & handler_types: - any_interested = True - break + notify_types = set(changes.keys()) - if not any_interested: - logger.debug( - f"Ignoring config notify v{notify_version}, " - f"no handlers for types {notify_types}" - ) - self.config_version = notify_version - return + # Filter out handlers that don't care about any of the changed + # types. A handler registered without types never fires on + # notifications (nothing to scope to). + interested = [] + for entry in self.config_handlers: + handler_types = entry["types"] + if handler_types and notify_types & handler_types: + interested.append(entry) + + if not interested: + logger.debug( + f"Ignoring config notify v{notify_version}, " + f"no handlers for types {notify_types}" + ) + self.config_version = notify_version + return logger.info( - f"Config notify v{notify_version} types={list(notify_types)}, " - f"fetching config..." + f"Config notify v{notify_version} " + f"types={list(notify_types)}, fetching config..." ) - # Fetch full config using short-lived client try: - config, version = await self.fetch_config() + client = self._create_config_client() + try: + await client.start() - self.config_version = version + for entry in interested: + handler_types = entry["types"] - # Invoke handlers that care about the affected types - for entry in self.config_handlers: - handler_types = entry["types"] - if handler_types is None: - await entry["handler"](config, version) - elif not notify_types or notify_types & handler_types: - await entry["handler"](config, version) + # Build {workspace: {type: {key: value}}} for types + # this handler cares about, where the workspace was + # affected for that type. + per_ws = {} + for t in handler_types: + if t not in changes: + continue + for ws in changes[t]: + kv = await self._fetch_type_workspace( + client, ws, t, + ) + per_ws.setdefault(ws, {})[t] = kv + + for ws, config in per_ws.items(): + await entry["handler"]( + ws, config, notify_version, + ) + + finally: + await client.stop() + + self.config_version = notify_version except Exception as e: logger.error( diff --git a/trustgraph-base/trustgraph/base/chunking_service.py b/trustgraph-base/trustgraph/base/chunking_service.py index 4bd78428..3771d78e 100644 --- a/trustgraph-base/trustgraph/base/chunking_service.py +++ b/trustgraph-base/trustgraph/base/chunking_service.py @@ -48,12 +48,13 @@ class ChunkingService(FlowProcessor): await super(ChunkingService, self).start() await self.librarian.start() - async def get_document_text(self, doc): + async def get_document_text(self, doc, workspace): """ Get text content from a TextDocument, fetching from librarian if needed. Args: doc: TextDocument with either inline text or document_id + workspace: Workspace for librarian lookup (from flow.workspace) Returns: str: The document text content @@ -62,7 +63,7 @@ class ChunkingService(FlowProcessor): logger.info(f"Fetching document {doc.document_id} from librarian...") text = await self.librarian.fetch_document_text( document_id=doc.document_id, - user=doc.metadata.user, + workspace=workspace, ) logger.info(f"Fetched {len(text)} characters from librarian") return text diff --git a/trustgraph-base/trustgraph/base/collection_config_handler.py b/trustgraph-base/trustgraph/base/collection_config_handler.py index 8c1af822..4cb91c53 100644 --- a/trustgraph-base/trustgraph/base/collection_config_handler.py +++ b/trustgraph-base/trustgraph/base/collection_config_handler.py @@ -15,114 +15,139 @@ class CollectionConfigHandler: Storage services should: 1. Inherit from this class along with their service base class 2. Call register_config_handler(self.on_collection_config) in __init__ - 3. Implement create_collection(user, collection, metadata) method - 4. Implement delete_collection(user, collection) method + 3. Implement create_collection(workspace, collection, metadata) method + 4. Implement delete_collection(workspace, collection) method """ def __init__(self, **kwargs): - # Track known collections: {(user, collection): metadata_dict} + # Track known collections: {(workspace, collection): metadata_dict} self.known_collections: Dict[tuple, dict] = {} # Pass remaining kwargs up the inheritance chain super().__init__(**kwargs) - async def on_collection_config(self, config: dict, version: int): + async def on_collection_config( + self, workspace: str, config: dict, version: int + ): """ Handle config push messages and extract collection information + for a single workspace. Args: + workspace: Workspace the config applies to config: Configuration dictionary from ConfigPush message version: Configuration version number """ - logger.info(f"Processing collection configuration (version {version})") + logger.info( + f"Processing collection configuration " + f"(version {version}, workspace {workspace})" + ) - # Extract collections from config (treat missing key as empty) + # Extract collections from config (treat missing key as empty). + # Each config key IS the collection name — config is already + # partitioned by workspace, so no workspace prefix is needed + # on the key. collection_config = config.get("collection", {}) # Track which collections we've seen in this config current_collections: Set[tuple] = set() - # Process each collection in the config - for key, value_json in collection_config.items(): + for collection, value_json in collection_config.items(): try: - # Parse user:collection key - if ":" not in key: - logger.warning(f"Invalid collection key format (expected user:collection): {key}") - continue + current_collections.add((workspace, collection)) - user, collection = key.split(":", 1) - current_collections.add((user, collection)) - - # Parse metadata metadata = json.loads(value_json) - # Check if this is a new collection or updated - collection_key = (user, collection) - if collection_key not in self.known_collections: - logger.info(f"New collection detected: {user}/{collection}") - await self.create_collection(user, collection, metadata) - self.known_collections[collection_key] = metadata + key = (workspace, collection) + if key not in self.known_collections: + logger.info( + f"New collection detected: {workspace}/{collection}" + ) + await self.create_collection( + workspace, collection, metadata + ) + self.known_collections[key] = metadata else: - # Collection already exists, update metadata if changed - if self.known_collections[collection_key] != metadata: - logger.info(f"Collection metadata updated: {user}/{collection}") - # Most storage services don't need to do anything for metadata updates - # They just need to know the collection exists - self.known_collections[collection_key] = metadata + if self.known_collections[key] != metadata: + logger.info( + f"Collection metadata updated: " + f"{workspace}/{collection}" + ) + self.known_collections[key] = metadata except Exception as e: - logger.error(f"Error processing collection config for key {key}: {e}", exc_info=True) + logger.error( + f"Error processing collection config for " + f"{workspace}/{collection}: {e}", + exc_info=True, + ) - # Find collections that were deleted (in known but not in current) - deleted_collections = set(self.known_collections.keys()) - current_collections - for user, collection in deleted_collections: - logger.info(f"Collection deleted: {user}/{collection}") + # Find collections for THIS workspace that were deleted (in + # known but not in current). Only compare collections owned by + # this workspace — other workspaces' collections are not + # affected by this config update. + known_for_ws = { + (w, c) for (w, c) in self.known_collections.keys() + if w == workspace + } + deleted_collections = known_for_ws - current_collections + for ws, collection in deleted_collections: + logger.info(f"Collection deleted: {ws}/{collection}") try: - # Remove from known_collections FIRST to immediately reject new writes - # This eliminates race condition with worker threads - del self.known_collections[(user, collection)] - # Physical deletion happens after - worker threads already rejecting writes - await self.delete_collection(user, collection) + # Remove from known_collections FIRST to immediately + # reject new writes + del self.known_collections[(ws, collection)] + await self.delete_collection(ws, collection) except Exception as e: - logger.error(f"Error deleting collection {user}/{collection}: {e}", exc_info=True) - # If physical deletion failed, should we re-add to known_collections? - # For now, keep it removed - collection is logically deleted per config + logger.error( + f"Error deleting collection {ws}/{collection}: {e}", + exc_info=True, + ) - logger.debug(f"Collection config processing complete. Known collections: {len(self.known_collections)}") + logger.debug( + f"Collection config processing complete. " + f"Known collections: {len(self.known_collections)}" + ) - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection( + self, workspace: str, collection: str, metadata: dict, + ): """ Create a collection in the storage backend. Subclasses must implement this method. Args: - user: User ID + workspace: Workspace ID collection: Collection ID metadata: Collection metadata dictionary """ - raise NotImplementedError("Storage service must implement create_collection method") + raise NotImplementedError( + "Storage service must implement create_collection method" + ) - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """ Delete a collection from the storage backend. Subclasses must implement this method. Args: - user: User ID + workspace: Workspace ID collection: Collection ID """ - raise NotImplementedError("Storage service must implement delete_collection method") + raise NotImplementedError( + "Storage service must implement delete_collection method" + ) - def collection_exists(self, user: str, collection: str) -> bool: + def collection_exists(self, workspace: str, collection: str) -> bool: """ - Check if a collection is known to exist + Check if a collection is known to exist. Args: - user: User ID + workspace: Workspace ID collection: Collection ID Returns: True if collection exists, False otherwise """ - return (user, collection) in self.known_collections + return (workspace, collection) in self.known_collections diff --git a/trustgraph-base/trustgraph/base/config_client.py b/trustgraph-base/trustgraph/base/config_client.py index c9ec3f9b..504a6d58 100644 --- a/trustgraph-base/trustgraph/base/config_client.py +++ b/trustgraph-base/trustgraph/base/config_client.py @@ -18,10 +18,11 @@ class ConfigClient(RequestResponse): ) return resp - async def get(self, type, key, timeout=CONFIG_TIMEOUT): + async def get(self, workspace, type, key, timeout=CONFIG_TIMEOUT): """Get a single config value. Returns the value string or None.""" resp = await self._request( operation="get", + workspace=workspace, keys=[ConfigKey(type=type, key=key)], timeout=timeout, ) @@ -29,19 +30,21 @@ class ConfigClient(RequestResponse): return resp.values[0].value return None - async def put(self, type, key, value, timeout=CONFIG_TIMEOUT): + async def put(self, workspace, type, key, value, timeout=CONFIG_TIMEOUT): """Put a single config value.""" await self._request( operation="put", + workspace=workspace, values=[ConfigValue(type=type, key=key, value=value)], timeout=timeout, ) - async def put_many(self, values, timeout=CONFIG_TIMEOUT): - """Put multiple config values in a single request. - values is a list of (type, key, value) tuples.""" + async def put_many(self, workspace, values, timeout=CONFIG_TIMEOUT): + """Put multiple config values in a single request within a + single workspace. values is a list of (type, key, value) tuples.""" await self._request( operation="put", + workspace=workspace, values=[ ConfigValue(type=t, key=k, value=v) for t, k, v in values @@ -49,19 +52,21 @@ class ConfigClient(RequestResponse): timeout=timeout, ) - async def delete(self, type, key, timeout=CONFIG_TIMEOUT): + async def delete(self, workspace, type, key, timeout=CONFIG_TIMEOUT): """Delete a single config key.""" await self._request( operation="delete", + workspace=workspace, keys=[ConfigKey(type=type, key=key)], timeout=timeout, ) - async def delete_many(self, keys, timeout=CONFIG_TIMEOUT): - """Delete multiple config keys in a single request. - keys is a list of (type, key) tuples.""" + async def delete_many(self, workspace, keys, timeout=CONFIG_TIMEOUT): + """Delete multiple config keys in a single request within a + single workspace. keys is a list of (type, key) tuples.""" await self._request( operation="delete", + workspace=workspace, keys=[ ConfigKey(type=t, key=k) for t, k in keys @@ -69,15 +74,26 @@ class ConfigClient(RequestResponse): timeout=timeout, ) - async def keys(self, type, timeout=CONFIG_TIMEOUT): - """List all keys for a config type.""" + async def keys(self, workspace, type, timeout=CONFIG_TIMEOUT): + """List all keys for a config type within a workspace.""" resp = await self._request( operation="list", + workspace=workspace, type=type, timeout=timeout, ) return resp.directory + async def workspaces_for_type(self, type, timeout=CONFIG_TIMEOUT): + """Return the set of distinct workspaces with any config of + the given type.""" + resp = await self._request( + operation="getvalues-all-ws", + type=type, + timeout=timeout, + ) + return {v.workspace for v in resp.values if v.workspace} + class ConfigClientSpec(RequestResponseSpec): def __init__( diff --git a/trustgraph-base/trustgraph/base/consumer_spec.py b/trustgraph-base/trustgraph/base/consumer_spec.py index 023537df..af072cca 100644 --- a/trustgraph-base/trustgraph/base/consumer_spec.py +++ b/trustgraph-base/trustgraph/base/consumer_spec.py @@ -24,7 +24,10 @@ class ConsumerSpec(Spec): flow = flow, backend = processor.pubsub, topic = definition["topics"][self.name], - subscriber = processor.id + "--" + flow.name + "--" + self.name, + subscriber = ( + processor.id + "--" + flow.workspace + "--" + + flow.name + "--" + self.name + ), schema = self.schema, handler = self.handler, metrics = consumer_metrics, diff --git a/trustgraph-base/trustgraph/base/document_embeddings_client.py b/trustgraph-base/trustgraph/base/document_embeddings_client.py index dd985eab..a93cdc87 100644 --- a/trustgraph-base/trustgraph/base/document_embeddings_client.py +++ b/trustgraph-base/trustgraph/base/document_embeddings_client.py @@ -9,14 +9,12 @@ from .. knowledge import Uri, Literal logger = logging.getLogger(__name__) class DocumentEmbeddingsClient(RequestResponse): - async def query(self, vector, limit=20, user="trustgraph", - collection="default", timeout=30): + async def query(self, vector, limit=20, collection="default", timeout=30): resp = await self.request( DocumentEmbeddingsRequest( vector = vector, limit = limit, - user = user, collection = collection ), timeout=timeout diff --git a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py index d5bf8421..cd9e91b1 100644 --- a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py +++ b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py @@ -60,7 +60,9 @@ class DocumentEmbeddingsQueryService(FlowProcessor): logger.debug(f"Handling document embeddings query request {id}...") - docs = await self.query_document_embeddings(request) + docs = await self.query_document_embeddings( + flow.workspace, request, + ) logger.debug("Sending document embeddings query response...") r = DocumentEmbeddingsResponse(chunks=docs, error=None) diff --git a/trustgraph-base/trustgraph/base/document_embeddings_store_service.py b/trustgraph-base/trustgraph/base/document_embeddings_store_service.py index 0c7921db..96b7781f 100644 --- a/trustgraph-base/trustgraph/base/document_embeddings_store_service.py +++ b/trustgraph-base/trustgraph/base/document_embeddings_store_service.py @@ -41,7 +41,8 @@ class DocumentEmbeddingsStoreService(FlowProcessor): request = msg.value() - await self.store_document_embeddings(request) + # Workspace comes from the flow the message arrived on. + await self.store_document_embeddings(flow.workspace, request) except TooManyRequests as e: raise e diff --git a/trustgraph-base/trustgraph/base/dynamic_tool_service.py b/trustgraph-base/trustgraph/base/dynamic_tool_service.py index bcfb71ab..00a457de 100644 --- a/trustgraph-base/trustgraph/base/dynamic_tool_service.py +++ b/trustgraph-base/trustgraph/base/dynamic_tool_service.py @@ -2,7 +2,7 @@ Base class for dynamically pluggable tool services. Tool services are Pulsar services that can be invoked as agent tools. -They receive a ToolServiceRequest with user, config, and arguments, +They receive a ToolServiceRequest with config and arguments, and return a ToolServiceResponse with the result. Uses direct Pulsar topics (no flow configuration required): @@ -42,7 +42,6 @@ class DynamicToolService(AsyncProcessor): the tool's logic. The invoke method receives: - - user: The user context for multi-tenancy - config: Dict of config values from the tool descriptor - arguments: Dict of arguments from the LLM @@ -115,14 +114,13 @@ class DynamicToolService(AsyncProcessor): id = msg.properties().get("id", "unknown") # Parse the request - user = request.user or "trustgraph" config = json.loads(request.config) if request.config else {} arguments = json.loads(request.arguments) if request.arguments else {} - logger.debug(f"Tool service request: user={user}, config={config}, arguments={arguments}") + logger.debug(f"Tool service request: config={config}, arguments={arguments}") # Invoke the tool implementation - response = await self.invoke(user, config, arguments) + response = await self.invoke(config, arguments) # Send success response await self.producer.send( @@ -159,14 +157,13 @@ class DynamicToolService(AsyncProcessor): properties={"id": id if id else "unknown"} ) - async def invoke(self, user, config, arguments): + async def invoke(self, config, arguments): """ Invoke the tool service. Override this method in subclasses to implement the tool's logic. Args: - user: The user context for multi-tenancy config: Dict of config values from the tool descriptor arguments: Dict of arguments from the LLM diff --git a/trustgraph-base/trustgraph/base/flow.py b/trustgraph-base/trustgraph/base/flow.py index 9a515bf8..2caad938 100644 --- a/trustgraph-base/trustgraph/base/flow.py +++ b/trustgraph-base/trustgraph/base/flow.py @@ -4,15 +4,16 @@ import asyncio class Flow: """ Runtime representation of a deployed flow process. - + This class maintains internal processor states and orchestrates - lifecycles (start, stop) for inputs (consumers) and parameters + lifecycles (start, stop) for inputs (consumers) and parameters that drive data flowing across linked nodes. """ - def __init__(self, id, flow, processor, defn): + def __init__(self, id, flow, workspace, processor, defn): self.id = id self.name = flow + self.workspace = workspace self.producer = {} diff --git a/trustgraph-base/trustgraph/base/flow_processor.py b/trustgraph-base/trustgraph/base/flow_processor.py index 99cb0f53..aa7bf921 100644 --- a/trustgraph-base/trustgraph/base/flow_processor.py +++ b/trustgraph-base/trustgraph/base/flow_processor.py @@ -35,6 +35,8 @@ class FlowProcessor(AsyncProcessor): ) # Initialise flow information state + # Keyed by (workspace, flow) tuples; each workspace has its own + # set of flow variants for this processor. self.flows = {} # These can be overriden by a derived class: @@ -48,23 +50,28 @@ class FlowProcessor(AsyncProcessor): def register_specification(self, spec: Any) -> None: self.specifications.append(spec) - # Start processing for a new flow - async def start_flow(self, flow, defn): - self.flows[flow] = Flow(self.id, flow, self, defn) - await self.flows[flow].start() - logger.info(f"Started flow: {flow}") - - # Stop processing for a new flow - async def stop_flow(self, flow): - if flow in self.flows: - await self.flows[flow].stop() - del self.flows[flow] - logger.info(f"Stopped flow: {flow}") + # Start processing for a new flow within a workspace + async def start_flow(self, workspace, flow, defn): + key = (workspace, flow) + self.flows[key] = Flow(self.id, flow, workspace, self, defn) + await self.flows[key].start() + logger.info(f"Started flow: {workspace}/{flow}") - # Event handler - called for a configuration change - async def on_configure_flows(self, config, version): + # Stop processing for a flow within a workspace + async def stop_flow(self, workspace, flow): + key = (workspace, flow) + if key in self.flows: + await self.flows[key].stop() + del self.flows[key] + logger.info(f"Stopped flow: {workspace}/{flow}") - logger.info(f"Got config version {version}") + # Event handler - called for a configuration change for a single + # workspace + async def on_configure_flows(self, workspace, config, version): + + logger.info( + f"Got config version {version} for workspace {workspace}" + ) config_type = f"processor:{self.id}" @@ -76,26 +83,28 @@ class FlowProcessor(AsyncProcessor): for k, v in config[config_type].items() } else: - logger.debug("No configuration settings for me.") + logger.debug( + f"No configuration settings for me in {workspace}." + ) flow_config = {} - # Get list of flows which should be running and are currently - # running - wanted_flows = flow_config.keys() - # This takes a copy, needed because dict gets modified by stop_flow - current_flows = list(self.flows.keys()) + # Get list of flows which should be running in this workspace, + # and the list currently running in this workspace + wanted_flows = set(flow_config.keys()) + current_flows = { + f for (ws, f) in self.flows.keys() if ws == workspace + } - # Start all the flows which arent currently running - for flow in wanted_flows: - if flow not in current_flows: - await self.start_flow(flow, flow_config[flow]) + # Start all the flows which aren't currently running in this + # workspace + for flow in wanted_flows - current_flows: + await self.start_flow(workspace, flow, flow_config[flow]) - # Stop all the unwanted flows which are due to be stopped - for flow in current_flows: - if flow not in wanted_flows: - await self.stop_flow(flow) + # Stop all the unwanted flows in this workspace + for flow in current_flows - wanted_flows: + await self.stop_flow(workspace, flow) - logger.info("Handled config update") + logger.info(f"Handled config update for workspace {workspace}") # Start threads, just call parent async def start(self): diff --git a/trustgraph-base/trustgraph/base/graph_embeddings_client.py b/trustgraph-base/trustgraph/base/graph_embeddings_client.py index fe717bf1..a9348c19 100644 --- a/trustgraph-base/trustgraph/base/graph_embeddings_client.py +++ b/trustgraph-base/trustgraph/base/graph_embeddings_client.py @@ -22,14 +22,12 @@ def to_value(x: Any) -> Any: return Literal(x.value or x.iri) class GraphEmbeddingsClient(RequestResponse): - async def query(self, vector, limit=20, user="trustgraph", - collection="default", timeout=30): + async def query(self, vector, limit=20, collection="default", timeout=30): resp = await self.request( GraphEmbeddingsRequest( vector = vector, limit = limit, - user = user, collection = collection ), timeout=timeout diff --git a/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py b/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py index 55c8efa9..cbce810c 100644 --- a/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py +++ b/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py @@ -60,7 +60,9 @@ class GraphEmbeddingsQueryService(FlowProcessor): logger.debug(f"Handling graph embeddings query request {id}...") - entities = await self.query_graph_embeddings(request) + entities = await self.query_graph_embeddings( + flow.workspace, request, + ) logger.debug("Sending graph embeddings query response...") r = GraphEmbeddingsResponse(entities=entities, error=None) diff --git a/trustgraph-base/trustgraph/base/graph_embeddings_store_service.py b/trustgraph-base/trustgraph/base/graph_embeddings_store_service.py index 09bbbe6a..10cfe93c 100644 --- a/trustgraph-base/trustgraph/base/graph_embeddings_store_service.py +++ b/trustgraph-base/trustgraph/base/graph_embeddings_store_service.py @@ -41,7 +41,8 @@ class GraphEmbeddingsStoreService(FlowProcessor): request = msg.value() - await self.store_graph_embeddings(request) + # Workspace comes from the flow the message arrived on. + await self.store_graph_embeddings(flow.workspace, request) except TooManyRequests as e: raise e diff --git a/trustgraph-base/trustgraph/base/graph_rag_client.py b/trustgraph-base/trustgraph/base/graph_rag_client.py index 9db23293..e07781f9 100644 --- a/trustgraph-base/trustgraph/base/graph_rag_client.py +++ b/trustgraph-base/trustgraph/base/graph_rag_client.py @@ -3,7 +3,7 @@ from . request_response_spec import RequestResponse, RequestResponseSpec from .. schema import GraphRagQuery, GraphRagResponse class GraphRagClient(RequestResponse): - async def rag(self, query, user="trustgraph", collection="default", + async def rag(self, query, collection="default", chunk_callback=None, explain_callback=None, parent_uri="", timeout=600): @@ -12,7 +12,6 @@ class GraphRagClient(RequestResponse): Args: query: The question to ask - user: User identifier collection: Collection identifier chunk_callback: Optional async callback(text, end_of_stream) for text chunks explain_callback: Optional async callback(explain_id, explain_graph, explain_triples) for explain notifications @@ -49,7 +48,6 @@ class GraphRagClient(RequestResponse): await self.request( GraphRagQuery( query = query, - user = user, collection = collection, parent_uri = parent_uri, ), diff --git a/trustgraph-base/trustgraph/base/librarian_client.py b/trustgraph-base/trustgraph/base/librarian_client.py index 5ad97f47..1876602b 100644 --- a/trustgraph-base/trustgraph/base/librarian_client.py +++ b/trustgraph-base/trustgraph/base/librarian_client.py @@ -10,7 +10,7 @@ Usage: id=id, backend=self.pubsub, taskgroup=self.taskgroup, **params ) await self.librarian.start() - content = await self.librarian.fetch_document_content(doc_id, user) + content = await self.librarian.fetch_document_content(doc_id, workspace) """ import asyncio @@ -150,7 +150,7 @@ class LibrarianClient: finally: self._streams.pop(request_id, None) - async def fetch_document_content(self, document_id, user, timeout=120): + async def fetch_document_content(self, document_id, workspace, timeout=120): """Fetch document content using streaming. Returns base64-encoded content. Caller is responsible for decoding. @@ -158,7 +158,7 @@ class LibrarianClient: req = LibrarianRequest( operation="stream-document", document_id=document_id, - user=user, + workspace=workspace, ) chunks = await self.stream(req, timeout=timeout) @@ -176,24 +176,24 @@ class LibrarianClient: return base64.b64encode(raw) - async def fetch_document_text(self, document_id, user, timeout=120): + async def fetch_document_text(self, document_id, workspace, timeout=120): """Fetch document content and decode as UTF-8 text.""" content = await self.fetch_document_content( - document_id, user, timeout=timeout, + document_id, workspace, timeout=timeout, ) return base64.b64decode(content).decode("utf-8") - async def fetch_document_metadata(self, document_id, user, timeout=120): + async def fetch_document_metadata(self, document_id, workspace, timeout=120): """Fetch document metadata from the librarian.""" req = LibrarianRequest( operation="get-document-metadata", document_id=document_id, - user=user, + workspace=workspace, ) response = await self.request(req, timeout=timeout) return response.document_metadata - async def save_child_document(self, doc_id, parent_id, user, content, + async def save_child_document(self, doc_id, parent_id, workspace, content, document_type="chunk", title=None, kind="text/plain", timeout=120): """Save a child document to the librarian.""" @@ -202,7 +202,7 @@ class LibrarianClient: doc_metadata = DocumentMetadata( id=doc_id, - user=user, + workspace=workspace, kind=kind, title=title or doc_id, parent_id=parent_id, @@ -218,7 +218,7 @@ class LibrarianClient: await self.request(req, timeout=timeout) return doc_id - async def save_document(self, doc_id, user, content, title=None, + async def save_document(self, doc_id, workspace, content, title=None, document_type="answer", kind="text/plain", timeout=120): """Save a document to the librarian.""" @@ -227,7 +227,7 @@ class LibrarianClient: doc_metadata = DocumentMetadata( id=doc_id, - user=user, + workspace=workspace, kind=kind, title=title or doc_id, document_type=document_type, @@ -238,7 +238,7 @@ class LibrarianClient: document_id=doc_id, document_metadata=doc_metadata, content=base64.b64encode(content).decode("utf-8"), - user=user, + workspace=workspace, ) await self.request(req, timeout=timeout) diff --git a/trustgraph-base/trustgraph/base/request_response_spec.py b/trustgraph-base/trustgraph/base/request_response_spec.py index b91c655c..aa934a7f 100644 --- a/trustgraph-base/trustgraph/base/request_response_spec.py +++ b/trustgraph-base/trustgraph/base/request_response_spec.py @@ -133,8 +133,9 @@ class RequestResponseSpec(Spec): # Make subscription names unique, so that all subscribers get # to see all response messages subscription = ( - processor.id + "--" + flow.name + "--" + self.request_name + - "--" + str(uuid.uuid4()) + processor.id + "--" + flow.workspace + "--" + + flow.name + "--" + self.request_name + "--" + + str(uuid.uuid4()) ), consumer_name = flow.id, request_topic = definition["topics"][self.request_name], diff --git a/trustgraph-base/trustgraph/base/row_embeddings_query_client.py b/trustgraph-base/trustgraph/base/row_embeddings_query_client.py index 811adf40..98c2e0a7 100644 --- a/trustgraph-base/trustgraph/base/row_embeddings_query_client.py +++ b/trustgraph-base/trustgraph/base/row_embeddings_query_client.py @@ -3,13 +3,12 @@ from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse class RowEmbeddingsQueryClient(RequestResponse): async def row_embeddings_query( - self, vector, schema_name, user="trustgraph", collection="default", + self, vector, schema_name, collection="default", index_name=None, limit=10, timeout=600 ): request = RowEmbeddingsRequest( vector=vector, schema_name=schema_name, - user=user, collection=collection, limit=limit ) diff --git a/trustgraph-base/trustgraph/base/structured_query_client.py b/trustgraph-base/trustgraph/base/structured_query_client.py index 84d6bff3..49b30cd1 100644 --- a/trustgraph-base/trustgraph/base/structured_query_client.py +++ b/trustgraph-base/trustgraph/base/structured_query_client.py @@ -2,11 +2,10 @@ from . request_response_spec import RequestResponse, RequestResponseSpec from .. schema import StructuredQueryRequest, StructuredQueryResponse class StructuredQueryClient(RequestResponse): - async def structured_query(self, question, user="trustgraph", collection="default", timeout=600): + async def structured_query(self, question, collection="default", timeout=600): resp = await self.request( StructuredQueryRequest( question = question, - user = user, collection = collection ), timeout=timeout diff --git a/trustgraph-base/trustgraph/base/subscriber_spec.py b/trustgraph-base/trustgraph/base/subscriber_spec.py index bf35f869..80f9b0d5 100644 --- a/trustgraph-base/trustgraph/base/subscriber_spec.py +++ b/trustgraph-base/trustgraph/base/subscriber_spec.py @@ -21,7 +21,7 @@ class SubscriberSpec(Spec): subscriber = Subscriber( backend = processor.pubsub, topic = definition["topics"][self.name], - subscription = flow.id, + subscription = flow.id + "--" + flow.workspace + "--" + flow.name, consumer_name = flow.id, schema = self.schema, metrics = subscriber_metrics, diff --git a/trustgraph-base/trustgraph/base/tool_service.py b/trustgraph-base/trustgraph/base/tool_service.py index 3ff977d1..eeaced6a 100644 --- a/trustgraph-base/trustgraph/base/tool_service.py +++ b/trustgraph-base/trustgraph/base/tool_service.py @@ -64,6 +64,7 @@ class ToolService(FlowProcessor): id = msg.properties()["id"] response = await self.invoke_tool( + flow.workspace, request.name, json.loads(request.parameters) if request.parameters else {}, ) diff --git a/trustgraph-base/trustgraph/base/tool_service_client.py b/trustgraph-base/trustgraph/base/tool_service_client.py index 81930ba0..db5946e9 100644 --- a/trustgraph-base/trustgraph/base/tool_service_client.py +++ b/trustgraph-base/trustgraph/base/tool_service_client.py @@ -11,12 +11,11 @@ logger = logging.getLogger(__name__) class ToolServiceClient(RequestResponse): """Client for invoking dynamically configured tool services.""" - async def call(self, user, config, arguments, timeout=600): + async def call(self, config, arguments, timeout=600): """ Call a tool service. Args: - user: User context for multi-tenancy config: Dict of config values (e.g., {"collection": "customers"}) arguments: Dict of arguments from LLM timeout: Request timeout in seconds @@ -26,7 +25,6 @@ class ToolServiceClient(RequestResponse): """ resp = await self.request( ToolServiceRequest( - user=user, config=json.dumps(config) if config else "{}", arguments=json.dumps(arguments) if arguments else "{}", ), @@ -38,12 +36,11 @@ class ToolServiceClient(RequestResponse): return resp.response - async def call_streaming(self, user, config, arguments, callback, timeout=600): + async def call_streaming(self, config, arguments, callback, timeout=600): """ Call a tool service with streaming response. Args: - user: User context for multi-tenancy config: Dict of config values arguments: Dict of arguments from LLM callback: Async function called with each response chunk @@ -66,7 +63,6 @@ class ToolServiceClient(RequestResponse): await self.request( ToolServiceRequest( - user=user, config=json.dumps(config) if config else "{}", arguments=json.dumps(arguments) if arguments else "{}", ), diff --git a/trustgraph-base/trustgraph/base/triples_client.py b/trustgraph-base/trustgraph/base/triples_client.py index a81a5cd0..2601a1e1 100644 --- a/trustgraph-base/trustgraph/base/triples_client.py +++ b/trustgraph-base/trustgraph/base/triples_client.py @@ -45,7 +45,7 @@ def from_value(x: Any) -> Any: class TriplesClient(RequestResponse): async def query(self, s=None, p=None, o=None, limit=20, - user="trustgraph", collection="default", + collection="default", timeout=30, g=None): resp = await self.request( @@ -54,7 +54,6 @@ class TriplesClient(RequestResponse): p = from_value(p), o = from_value(o), limit = limit, - user = user, collection = collection, g = g, ), @@ -72,7 +71,7 @@ class TriplesClient(RequestResponse): return triples async def query_stream(self, s=None, p=None, o=None, limit=20, - user="trustgraph", collection="default", + collection="default", batch_size=20, timeout=30, batch_callback=None, g=None): """ @@ -81,7 +80,6 @@ class TriplesClient(RequestResponse): Args: s, p, o: Triple pattern (None for wildcard) limit: Maximum total triples to return - user: User/keyspace collection: Collection name batch_size: Triples per batch timeout: Request timeout in seconds @@ -116,7 +114,6 @@ class TriplesClient(RequestResponse): p=from_value(p), o=from_value(o), limit=limit, - user=user, collection=collection, streaming=True, batch_size=batch_size, diff --git a/trustgraph-base/trustgraph/base/triples_query_service.py b/trustgraph-base/trustgraph/base/triples_query_service.py index 832ff6f1..5850307c 100644 --- a/trustgraph-base/trustgraph/base/triples_query_service.py +++ b/trustgraph-base/trustgraph/base/triples_query_service.py @@ -58,9 +58,13 @@ class TriplesQueryService(FlowProcessor): logger.debug(f"Handling triples query request {id}...") + workspace = flow.workspace + if request.streaming: # Streaming mode: send batches - async for batch, is_final in self.query_triples_stream(request): + async for batch, is_final in self.query_triples_stream( + workspace, request, + ): r = TriplesQueryResponse( triples=batch, error=None, @@ -70,7 +74,7 @@ class TriplesQueryService(FlowProcessor): logger.debug("Triples query streaming completed") else: # Non-streaming mode: single response - triples = await self.query_triples(request) + triples = await self.query_triples(workspace, request) logger.debug("Sending triples query response...") r = TriplesQueryResponse(triples=triples, error=None) await flow("response").send(r, properties={"id": id}) @@ -92,13 +96,13 @@ class TriplesQueryService(FlowProcessor): await flow("response").send(r, properties={"id": id}) - async def query_triples_stream(self, request): + async def query_triples_stream(self, workspace, request): """ Streaming query - yields (batch, is_final) tuples. Default implementation batches results from query_triples. Override for true streaming from backend. """ - triples = await self.query_triples(request) + triples = await self.query_triples(workspace, request) batch_size = request.batch_size if request.batch_size > 0 else 20 for i in range(0, len(triples), batch_size): diff --git a/trustgraph-base/trustgraph/base/triples_store_service.py b/trustgraph-base/trustgraph/base/triples_store_service.py index abd3aab8..7c44fe29 100644 --- a/trustgraph-base/trustgraph/base/triples_store_service.py +++ b/trustgraph-base/trustgraph/base/triples_store_service.py @@ -45,7 +45,10 @@ class TriplesStoreService(FlowProcessor): request = msg.value() - await self.store_triples(request) + # Workspace is derived from the flow the message arrived on, + # not from fields in the message payload. Topic routing is + # the isolation boundary. + await self.store_triples(flow.workspace, request) except TooManyRequests as e: raise e diff --git a/trustgraph-base/trustgraph/clients/config_client.py b/trustgraph-base/trustgraph/clients/config_client.py index 78b62688..25c1af94 100644 --- a/trustgraph-base/trustgraph/clients/config_client.py +++ b/trustgraph-base/trustgraph/clients/config_client.py @@ -33,6 +33,7 @@ class ConfigClient(BaseClient): subscriber=None, input_queue=None, output_queue=None, + workspace="default", **pubsub_config, ): @@ -51,10 +52,13 @@ class ConfigClient(BaseClient): **pubsub_config, ) + self.workspace = workspace + def get(self, keys, timeout=300): resp = self.call( operation="get", + workspace=self.workspace, keys=[ ConfigKey( type = k["type"], @@ -78,6 +82,7 @@ class ConfigClient(BaseClient): resp = self.call( operation="list", + workspace=self.workspace, type=type, timeout=timeout ) @@ -88,6 +93,7 @@ class ConfigClient(BaseClient): resp = self.call( operation="getvalues", + workspace=self.workspace, type=type, timeout=timeout ) @@ -101,10 +107,31 @@ class ConfigClient(BaseClient): for v in resp.values ] + def getvalues_all_ws(self, type, timeout=300): + """Fetch all values of a given type across all workspaces. + Returns a list of dicts including a 'workspace' field.""" + + resp = self.call( + operation="getvalues-all-ws", + type=type, + timeout=timeout + ) + + return [ + { + "workspace": v.workspace, + "type": v.type, + "key": v.key, + "value": v.value, + } + for v in resp.values + ] + def delete(self, keys, timeout=300): resp = self.call( operation="delete", + workspace=self.workspace, keys=[ ConfigKey( type = k["type"], @@ -121,6 +148,7 @@ class ConfigClient(BaseClient): resp = self.call( operation="put", + workspace=self.workspace, values=[ ConfigValue( type = v["type"], @@ -138,6 +166,7 @@ class ConfigClient(BaseClient): resp = self.call( operation="config", + workspace=self.workspace, timeout=timeout ) diff --git a/trustgraph-base/trustgraph/clients/document_embeddings_client.py b/trustgraph-base/trustgraph/clients/document_embeddings_client.py index ebbad397..ad20206c 100644 --- a/trustgraph-base/trustgraph/clients/document_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/document_embeddings_client.py @@ -35,11 +35,11 @@ class DocumentEmbeddingsClient(BaseClient): ) def request( - self, vector, user="trustgraph", collection="default", + self, vector, collection="default", limit=10, timeout=300 ): return self.call( - user=user, collection=collection, + collection=collection, vector=vector, limit=limit, timeout=timeout ).chunks diff --git a/trustgraph-base/trustgraph/clients/document_rag_client.py b/trustgraph-base/trustgraph/clients/document_rag_client.py index 365ea09d..e8deaafd 100644 --- a/trustgraph-base/trustgraph/clients/document_rag_client.py +++ b/trustgraph-base/trustgraph/clients/document_rag_client.py @@ -33,14 +33,13 @@ class DocumentRagClient(BaseClient): output_schema=DocumentRagResponse, ) - def request(self, query, user="trustgraph", collection="default", + def request(self, query, collection="default", chunk_callback=None, explain_callback=None, timeout=300): """ Request a document RAG query with optional streaming callbacks. Args: query: The question to ask - user: User identifier collection: Collection identifier chunk_callback: Optional callback(text, end_of_stream) for text chunks explain_callback: Optional callback(explain_id, explain_graph, explain_triples) for explain notifications @@ -71,7 +70,7 @@ class DocumentRagClient(BaseClient): return False # Continue receiving self.call( - query=query, user=user, collection=collection, + query=query, collection=collection, inspect=inspect, timeout=timeout ) diff --git a/trustgraph-base/trustgraph/clients/graph_embeddings_client.py b/trustgraph-base/trustgraph/clients/graph_embeddings_client.py index 62a55609..9b38a11b 100644 --- a/trustgraph-base/trustgraph/clients/graph_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/graph_embeddings_client.py @@ -35,11 +35,11 @@ class GraphEmbeddingsClient(BaseClient): ) def request( - self, vector, user="trustgraph", collection="default", + self, vector, collection="default", limit=10, timeout=300 ): return self.call( - user=user, collection=collection, + collection=collection, vector=vector, limit=limit, timeout=timeout ).entities diff --git a/trustgraph-base/trustgraph/clients/graph_rag_client.py b/trustgraph-base/trustgraph/clients/graph_rag_client.py index 0d33bf91..f1d2374e 100644 --- a/trustgraph-base/trustgraph/clients/graph_rag_client.py +++ b/trustgraph-base/trustgraph/clients/graph_rag_client.py @@ -34,7 +34,7 @@ class GraphRagClient(BaseClient): ) def request( - self, query, user="trustgraph", collection="default", + self, query, collection="default", chunk_callback=None, explain_callback=None, timeout=500 @@ -44,7 +44,6 @@ class GraphRagClient(BaseClient): Args: query: The question to ask - user: User identifier collection: Collection identifier chunk_callback: Optional callback(text, end_of_stream) for text chunks explain_callback: Optional callback(explain_id, explain_graph, explain_triples) for explain notifications @@ -76,7 +75,7 @@ class GraphRagClient(BaseClient): return False # Continue receiving self.call( - user=user, collection=collection, query=query, + collection=collection, query=query, inspect=inspect, timeout=timeout ) diff --git a/trustgraph-base/trustgraph/clients/row_embeddings_client.py b/trustgraph-base/trustgraph/clients/row_embeddings_client.py index 6e10de29..c2329f9d 100644 --- a/trustgraph-base/trustgraph/clients/row_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/row_embeddings_client.py @@ -35,11 +35,11 @@ class RowEmbeddingsClient(BaseClient): ) def request( - self, vector, schema_name, user="trustgraph", collection="default", + self, vector, schema_name, collection="default", index_name=None, limit=10, timeout=300 ): kwargs = dict( - user=user, collection=collection, + collection=collection, vector=vector, schema_name=schema_name, limit=limit, timeout=timeout ) diff --git a/trustgraph-base/trustgraph/clients/triples_query_client.py b/trustgraph-base/trustgraph/clients/triples_query_client.py index 403d02ea..864f4442 100644 --- a/trustgraph-base/trustgraph/clients/triples_query_client.py +++ b/trustgraph-base/trustgraph/clients/triples_query_client.py @@ -45,16 +45,15 @@ class TriplesQueryClient(BaseClient): return Term(type=LITERAL, value=ent) def request( - self, + self, s, p, o, - user="trustgraph", collection="default", + collection="default", limit=10, timeout=120, ): return self.call( s=self.create_value(s), p=self.create_value(p), o=self.create_value(o), - user=user, collection=collection, limit=limit, timeout=timeout, diff --git a/trustgraph-base/trustgraph/messaging/translators/agent.py b/trustgraph-base/trustgraph/messaging/translators/agent.py index 7df59907..d1e13e33 100644 --- a/trustgraph-base/trustgraph/messaging/translators/agent.py +++ b/trustgraph-base/trustgraph/messaging/translators/agent.py @@ -13,7 +13,6 @@ class AgentRequestTranslator(MessageTranslator): state=data.get("state", None), group=data.get("group", None), history=data.get("history", []), - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), streaming=data.get("streaming", False), session_id=data.get("session_id", ""), @@ -33,7 +32,6 @@ class AgentRequestTranslator(MessageTranslator): "state": obj.state, "group": obj.group, "history": obj.history, - "user": obj.user, "collection": getattr(obj, "collection", "default"), "streaming": getattr(obj, "streaming", False), "session_id": getattr(obj, "session_id", ""), diff --git a/trustgraph-base/trustgraph/messaging/translators/collection.py b/trustgraph-base/trustgraph/messaging/translators/collection.py index 2e39e8c2..cd07bc99 100644 --- a/trustgraph-base/trustgraph/messaging/translators/collection.py +++ b/trustgraph-base/trustgraph/messaging/translators/collection.py @@ -9,7 +9,7 @@ class CollectionManagementRequestTranslator(MessageTranslator): def decode(self, data: Dict[str, Any]) -> CollectionManagementRequest: return CollectionManagementRequest( operation=data.get("operation"), - user=data.get("user"), + workspace=data.get("workspace", ""), collection=data.get("collection"), timestamp=data.get("timestamp"), name=data.get("name"), @@ -24,8 +24,8 @@ class CollectionManagementRequestTranslator(MessageTranslator): if obj.operation is not None: result["operation"] = obj.operation - if obj.user is not None: - result["user"] = obj.user + if obj.workspace: + result["workspace"] = obj.workspace if obj.collection is not None: result["collection"] = obj.collection if obj.timestamp is not None: @@ -63,7 +63,6 @@ class CollectionManagementResponseTranslator(MessageTranslator): if "collections" in data: for coll_data in data["collections"]: collections.append(CollectionMetadata( - user=coll_data.get("user"), collection=coll_data.get("collection"), name=coll_data.get("name"), description=coll_data.get("description"), @@ -91,7 +90,6 @@ class CollectionManagementResponseTranslator(MessageTranslator): result["collections"] = [] for coll in obj.collections: result["collections"].append({ - "user": coll.user, "collection": coll.collection, "name": coll.name, "description": coll.description, diff --git a/trustgraph-base/trustgraph/messaging/translators/config.py b/trustgraph-base/trustgraph/messaging/translators/config.py index e166362a..223db6c8 100644 --- a/trustgraph-base/trustgraph/messaging/translators/config.py +++ b/trustgraph-base/trustgraph/messaging/translators/config.py @@ -23,13 +23,15 @@ class ConfigRequestTranslator(MessageTranslator): ConfigValue( type=v["type"], key=v["key"], - value=v["value"] + value=v["value"], + workspace=v.get("workspace", ""), ) for v in data["values"] ] return ConfigRequest( operation=data.get("operation"), + workspace=data.get("workspace", ""), keys=keys, type=data.get("type"), values=values @@ -37,10 +39,13 @@ class ConfigRequestTranslator(MessageTranslator): def encode(self, obj: ConfigRequest) -> Dict[str, Any]: result = {} - + if obj.operation is not None: result["operation"] = obj.operation + if obj.workspace is not None: + result["workspace"] = obj.workspace + if obj.type is not None: result["type"] = obj.type @@ -56,13 +61,14 @@ class ConfigRequestTranslator(MessageTranslator): if obj.values is not None: result["values"] = [ { + **({"workspace": v.workspace} if v.workspace else {}), "type": v.type, "key": v.key, - "value": v.value + "value": v.value, } for v in obj.values ] - + return result @@ -81,13 +87,14 @@ class ConfigResponseTranslator(MessageTranslator): if obj.values is not None: result["values"] = [ { + **({"workspace": v.workspace} if v.workspace else {}), "type": v.type, "key": v.key, - "value": v.value + "value": v.value, } for v in obj.values ] - + if obj.directory is not None: result["directory"] = obj.directory diff --git a/trustgraph-base/trustgraph/messaging/translators/document_loading.py b/trustgraph-base/trustgraph/messaging/translators/document_loading.py index df2aa3ba..61917321 100644 --- a/trustgraph-base/trustgraph/messaging/translators/document_loading.py +++ b/trustgraph-base/trustgraph/messaging/translators/document_loading.py @@ -39,7 +39,6 @@ class DocumentTranslator(SendTranslator): metadata=Metadata( id=data.get("id"), root=data.get("root", ""), - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), ), data=base64.b64encode(doc).decode("utf-8") @@ -56,8 +55,6 @@ class DocumentTranslator(SendTranslator): metadata_dict["id"] = obj.metadata.id if obj.metadata.root: metadata_dict["root"] = obj.metadata.root - if obj.metadata.user: - metadata_dict["user"] = obj.metadata.user if obj.metadata.collection: metadata_dict["collection"] = obj.metadata.collection @@ -79,7 +76,6 @@ class TextDocumentTranslator(SendTranslator): metadata=Metadata( id=data.get("id"), root=data.get("root", ""), - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), ), text=text.encode("utf-8") @@ -96,8 +92,6 @@ class TextDocumentTranslator(SendTranslator): metadata_dict["id"] = obj.metadata.id if obj.metadata.root: metadata_dict["root"] = obj.metadata.root - if obj.metadata.user: - metadata_dict["user"] = obj.metadata.user if obj.metadata.collection: metadata_dict["collection"] = obj.metadata.collection @@ -115,7 +109,6 @@ class ChunkTranslator(SendTranslator): metadata=Metadata( id=data.get("id"), root=data.get("root", ""), - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), ), chunk=data["chunk"].encode("utf-8") if isinstance(data["chunk"], str) else data["chunk"] @@ -132,8 +125,6 @@ class ChunkTranslator(SendTranslator): metadata_dict["id"] = obj.metadata.id if obj.metadata.root: metadata_dict["root"] = obj.metadata.root - if obj.metadata.user: - metadata_dict["user"] = obj.metadata.user if obj.metadata.collection: metadata_dict["collection"] = obj.metadata.collection @@ -161,7 +152,6 @@ class DocumentEmbeddingsTranslator(SendTranslator): metadata=Metadata( id=metadata.get("id"), root=metadata.get("root", ""), - user=metadata.get("user", "trustgraph"), collection=metadata.get("collection", "default"), ), chunks=chunks @@ -184,8 +174,6 @@ class DocumentEmbeddingsTranslator(SendTranslator): metadata_dict["id"] = obj.metadata.id if obj.metadata.root: metadata_dict["root"] = obj.metadata.root - if obj.metadata.user: - metadata_dict["user"] = obj.metadata.user if obj.metadata.collection: metadata_dict["collection"] = obj.metadata.collection diff --git a/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py b/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py index fce1625e..c435ba48 100644 --- a/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py @@ -15,7 +15,6 @@ class DocumentEmbeddingsRequestTranslator(MessageTranslator): return DocumentEmbeddingsRequest( vector=data["vector"], limit=int(data.get("limit", 10)), - user=data.get("user", "trustgraph"), collection=data.get("collection", "default") ) @@ -23,7 +22,6 @@ class DocumentEmbeddingsRequestTranslator(MessageTranslator): return { "vector": obj.vector, "limit": obj.limit, - "user": obj.user, "collection": obj.collection } @@ -60,7 +58,6 @@ class GraphEmbeddingsRequestTranslator(MessageTranslator): return GraphEmbeddingsRequest( vector=data["vector"], limit=int(data.get("limit", 10)), - user=data.get("user", "trustgraph"), collection=data.get("collection", "default") ) @@ -68,7 +65,6 @@ class GraphEmbeddingsRequestTranslator(MessageTranslator): return { "vector": obj.vector, "limit": obj.limit, - "user": obj.user, "collection": obj.collection } @@ -108,7 +104,6 @@ class RowEmbeddingsRequestTranslator(MessageTranslator): return RowEmbeddingsRequest( vector=data["vector"], limit=int(data.get("limit", 10)), - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), schema_name=data.get("schema_name", ""), index_name=data.get("index_name") @@ -118,7 +113,6 @@ class RowEmbeddingsRequestTranslator(MessageTranslator): result = { "vector": obj.vector, "limit": obj.limit, - "user": obj.user, "collection": obj.collection, "schema_name": obj.schema_name, } diff --git a/trustgraph-base/trustgraph/messaging/translators/flow.py b/trustgraph-base/trustgraph/messaging/translators/flow.py index 2047475e..07304c18 100644 --- a/trustgraph-base/trustgraph/messaging/translators/flow.py +++ b/trustgraph-base/trustgraph/messaging/translators/flow.py @@ -9,18 +9,21 @@ class FlowRequestTranslator(MessageTranslator): def decode(self, data: Dict[str, Any]) -> FlowRequest: return FlowRequest( operation=data.get("operation"), + workspace=data.get("workspace", ""), blueprint_name=data.get("blueprint-name"), blueprint_definition=data.get("blueprint-definition"), description=data.get("description"), flow_id=data.get("flow-id"), parameters=data.get("parameters") ) - + def encode(self, obj: FlowRequest) -> Dict[str, Any]: result = {} if obj.operation is not None: result["operation"] = obj.operation + if obj.workspace is not None: + result["workspace"] = obj.workspace if obj.blueprint_name is not None: result["blueprint-name"] = obj.blueprint_name if obj.blueprint_definition is not None: diff --git a/trustgraph-base/trustgraph/messaging/translators/knowledge.py b/trustgraph-base/trustgraph/messaging/translators/knowledge.py index f819dc9c..83cdbbf4 100644 --- a/trustgraph-base/trustgraph/messaging/translators/knowledge.py +++ b/trustgraph-base/trustgraph/messaging/translators/knowledge.py @@ -21,7 +21,6 @@ class KnowledgeRequestTranslator(MessageTranslator): metadata=Metadata( id=data["triples"]["metadata"]["id"], root=data["triples"]["metadata"].get("root", ""), - user=data["triples"]["metadata"]["user"], collection=data["triples"]["metadata"]["collection"] ), triples=self.subgraph_translator.decode(data["triples"]["triples"]), @@ -33,7 +32,6 @@ class KnowledgeRequestTranslator(MessageTranslator): metadata=Metadata( id=data["graph-embeddings"]["metadata"]["id"], root=data["graph-embeddings"]["metadata"].get("root", ""), - user=data["graph-embeddings"]["metadata"]["user"], collection=data["graph-embeddings"]["metadata"]["collection"] ), entities=[ @@ -47,7 +45,7 @@ class KnowledgeRequestTranslator(MessageTranslator): return KnowledgeRequest( operation=data.get("operation"), - user=data.get("user"), + workspace=data.get("workspace", ""), id=data.get("id"), flow=data.get("flow"), collection=data.get("collection"), @@ -60,8 +58,8 @@ class KnowledgeRequestTranslator(MessageTranslator): if obj.operation: result["operation"] = obj.operation - if obj.user: - result["user"] = obj.user + if obj.workspace: + result["workspace"] = obj.workspace if obj.id: result["id"] = obj.id if obj.flow: @@ -74,7 +72,6 @@ class KnowledgeRequestTranslator(MessageTranslator): "metadata": { "id": obj.triples.metadata.id, "root": obj.triples.metadata.root, - "user": obj.triples.metadata.user, "collection": obj.triples.metadata.collection, }, "triples": self.subgraph_translator.encode(obj.triples.triples), @@ -85,7 +82,6 @@ class KnowledgeRequestTranslator(MessageTranslator): "metadata": { "id": obj.graph_embeddings.metadata.id, "root": obj.graph_embeddings.metadata.root, - "user": obj.graph_embeddings.metadata.user, "collection": obj.graph_embeddings.metadata.collection, }, "entities": [ @@ -122,7 +118,6 @@ class KnowledgeResponseTranslator(MessageTranslator): "metadata": { "id": obj.triples.metadata.id, "root": obj.triples.metadata.root, - "user": obj.triples.metadata.user, "collection": obj.triples.metadata.collection, }, "triples": self.subgraph_translator.encode(obj.triples.triples), @@ -136,7 +131,6 @@ class KnowledgeResponseTranslator(MessageTranslator): "metadata": { "id": obj.graph_embeddings.metadata.id, "root": obj.graph_embeddings.metadata.root, - "user": obj.graph_embeddings.metadata.user, "collection": obj.graph_embeddings.metadata.collection, }, "entities": [ diff --git a/trustgraph-base/trustgraph/messaging/translators/library.py b/trustgraph-base/trustgraph/messaging/translators/library.py index 7c77c39c..d528097e 100644 --- a/trustgraph-base/trustgraph/messaging/translators/library.py +++ b/trustgraph-base/trustgraph/messaging/translators/library.py @@ -49,7 +49,7 @@ class LibraryRequestTranslator(MessageTranslator): document_metadata=doc_metadata, processing_metadata=proc_metadata, content=content, - user=data.get("user", ""), + workspace=data.get("workspace", ""), collection=data.get("collection", ""), criteria=criteria, # Chunked upload fields @@ -76,8 +76,8 @@ class LibraryRequestTranslator(MessageTranslator): result["processing-metadata"] = self.proc_metadata_translator.encode(obj.processing_metadata) if obj.content: result["content"] = obj.content.decode("utf-8") if isinstance(obj.content, bytes) else obj.content - if obj.user: - result["user"] = obj.user + if obj.workspace: + result["workspace"] = obj.workspace if obj.collection: result["collection"] = obj.collection if obj.criteria is not None: diff --git a/trustgraph-base/trustgraph/messaging/translators/metadata.py b/trustgraph-base/trustgraph/messaging/translators/metadata.py index 3e141c19..9da5d5c0 100644 --- a/trustgraph-base/trustgraph/messaging/translators/metadata.py +++ b/trustgraph-base/trustgraph/messaging/translators/metadata.py @@ -19,7 +19,7 @@ class DocumentMetadataTranslator(Translator): title=data.get("title"), comments=data.get("comments"), metadata=self.subgraph_translator.decode(metadata) if metadata is not None else [], - user=data.get("user"), + workspace=data.get("workspace"), tags=data.get("tags"), parent_id=data.get("parent-id", ""), document_type=data.get("document-type", "source"), @@ -40,8 +40,8 @@ class DocumentMetadataTranslator(Translator): result["comments"] = obj.comments if obj.metadata is not None: result["metadata"] = self.subgraph_translator.encode(obj.metadata) - if obj.user: - result["user"] = obj.user + if obj.workspace: + result["workspace"] = obj.workspace if obj.tags is not None: result["tags"] = obj.tags if obj.parent_id: @@ -61,7 +61,7 @@ class ProcessingMetadataTranslator(Translator): document_id=data.get("document-id"), time=data.get("time"), flow=data.get("flow"), - user=data.get("user"), + workspace=data.get("workspace"), collection=data.get("collection"), tags=data.get("tags") ) @@ -77,8 +77,8 @@ class ProcessingMetadataTranslator(Translator): result["time"] = obj.time if obj.flow: result["flow"] = obj.flow - if obj.user: - result["user"] = obj.user + if obj.workspace: + result["workspace"] = obj.workspace if obj.collection: result["collection"] = obj.collection if obj.tags is not None: diff --git a/trustgraph-base/trustgraph/messaging/translators/retrieval.py b/trustgraph-base/trustgraph/messaging/translators/retrieval.py index e37b76e1..fe766522 100644 --- a/trustgraph-base/trustgraph/messaging/translators/retrieval.py +++ b/trustgraph-base/trustgraph/messaging/translators/retrieval.py @@ -10,7 +10,6 @@ class DocumentRagRequestTranslator(MessageTranslator): def decode(self, data: Dict[str, Any]) -> DocumentRagQuery: return DocumentRagQuery( query=data["query"], - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), doc_limit=int(data.get("doc-limit", 20)), streaming=data.get("streaming", False) @@ -19,7 +18,6 @@ class DocumentRagRequestTranslator(MessageTranslator): def encode(self, obj: DocumentRagQuery) -> Dict[str, Any]: return { "query": obj.query, - "user": obj.user, "collection": obj.collection, "doc-limit": obj.doc_limit, "streaming": getattr(obj, "streaming", False) @@ -96,7 +94,6 @@ class GraphRagRequestTranslator(MessageTranslator): def decode(self, data: Dict[str, Any]) -> GraphRagQuery: return GraphRagQuery( query=data["query"], - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), entity_limit=int(data.get("entity-limit", 50)), triple_limit=int(data.get("triple-limit", 30)), @@ -110,7 +107,6 @@ class GraphRagRequestTranslator(MessageTranslator): def encode(self, obj: GraphRagQuery) -> Dict[str, Any]: return { "query": obj.query, - "user": obj.user, "collection": obj.collection, "entity-limit": obj.entity_limit, "triple-limit": obj.triple_limit, diff --git a/trustgraph-base/trustgraph/messaging/translators/rows_query.py b/trustgraph-base/trustgraph/messaging/translators/rows_query.py index 6153901c..3d3f682f 100644 --- a/trustgraph-base/trustgraph/messaging/translators/rows_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/rows_query.py @@ -9,7 +9,6 @@ class RowsQueryRequestTranslator(MessageTranslator): def decode(self, data: Dict[str, Any]) -> RowsQueryRequest: return RowsQueryRequest( - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), query=data.get("query", ""), variables=data.get("variables", {}), @@ -18,7 +17,6 @@ class RowsQueryRequestTranslator(MessageTranslator): def encode(self, obj: RowsQueryRequest) -> Dict[str, Any]: result = { - "user": obj.user, "collection": obj.collection, "query": obj.query, "variables": dict(obj.variables) if obj.variables else {} diff --git a/trustgraph-base/trustgraph/messaging/translators/sparql_query.py b/trustgraph-base/trustgraph/messaging/translators/sparql_query.py index a8b13865..e69d998a 100644 --- a/trustgraph-base/trustgraph/messaging/translators/sparql_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/sparql_query.py @@ -12,7 +12,6 @@ class SparqlQueryRequestTranslator(MessageTranslator): def decode(self, data: Dict[str, Any]) -> SparqlQueryRequest: return SparqlQueryRequest( - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), query=data.get("query", ""), limit=int(data.get("limit", 10000)), @@ -22,7 +21,6 @@ class SparqlQueryRequestTranslator(MessageTranslator): def encode(self, obj: SparqlQueryRequest) -> Dict[str, Any]: return { - "user": obj.user, "collection": obj.collection, "query": obj.query, "limit": obj.limit, diff --git a/trustgraph-base/trustgraph/messaging/translators/structured_query.py b/trustgraph-base/trustgraph/messaging/translators/structured_query.py index 6b0b38a1..bb76f3e7 100644 --- a/trustgraph-base/trustgraph/messaging/translators/structured_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/structured_query.py @@ -10,14 +10,12 @@ class StructuredQueryRequestTranslator(MessageTranslator): def decode(self, data: Dict[str, Any]) -> StructuredQueryRequest: return StructuredQueryRequest( question=data.get("question", ""), - user=data.get("user", "trustgraph"), # Default fallback - collection=data.get("collection", "default") # Default fallback + collection=data.get("collection", "default") ) - + def encode(self, obj: StructuredQueryRequest) -> Dict[str, Any]: return { "question": obj.question, - "user": obj.user, "collection": obj.collection } diff --git a/trustgraph-base/trustgraph/messaging/translators/triples.py b/trustgraph-base/trustgraph/messaging/translators/triples.py index 21d2698f..7a48ff15 100644 --- a/trustgraph-base/trustgraph/messaging/translators/triples.py +++ b/trustgraph-base/trustgraph/messaging/translators/triples.py @@ -22,16 +22,14 @@ class TriplesQueryRequestTranslator(MessageTranslator): o=o, g=g, limit=int(data.get("limit", 10000)), - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), streaming=data.get("streaming", False), batch_size=int(data.get("batch-size", 20)), ) - + def encode(self, obj: TriplesQueryRequest) -> Dict[str, Any]: result = { "limit": obj.limit, - "user": obj.user, "collection": obj.collection, "streaming": obj.streaming, "batch-size": obj.batch_size, diff --git a/trustgraph-base/trustgraph/schema/core/metadata.py b/trustgraph-base/trustgraph/schema/core/metadata.py index a37a8d62..a307db4f 100644 --- a/trustgraph-base/trustgraph/schema/core/metadata.py +++ b/trustgraph-base/trustgraph/schema/core/metadata.py @@ -8,6 +8,7 @@ class Metadata: # Root document identifier (set by librarian, preserved through pipeline) root: str = "" - # Collection management - user: str = "" + # Collection the message belongs to. Workspace is NOT carried on the + # message — consumers derive it from flow.workspace (the flow the + # message arrived on), which is the trusted isolation boundary. collection: str = "" diff --git a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py index 0c4a9f7c..37969566 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py +++ b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py @@ -17,7 +17,7 @@ from .embeddings import GraphEmbeddings # <- (error) # list-kg-cores -# -> (user) +# -> (workspace) # <- () # <- (error) @@ -27,8 +27,8 @@ class KnowledgeRequest: # load-kg-core, unload-kg-core operation: str = "" - # list-kg-cores, delete-kg-core, put-kg-core - user: str = "" + # Workspace the cores belong to. Partition / isolation boundary. + workspace: str = "" # get-kg-core, list-kg-cores, delete-kg-core, put-kg-core, # load-kg-core, unload-kg-core diff --git a/trustgraph-base/trustgraph/schema/services/agent.py b/trustgraph-base/trustgraph/schema/services/agent.py index cd4a2b45..50ac1dd1 100644 --- a/trustgraph-base/trustgraph/schema/services/agent.py +++ b/trustgraph-base/trustgraph/schema/services/agent.py @@ -22,7 +22,6 @@ class AgentStep: action: str = "" arguments: dict[str, str] = field(default_factory=dict) observation: str = "" - user: str = "" # User context for the step step_type: str = "" # "react", "plan", "execute", "decompose", "synthesise" plan: list[PlanStep] = field(default_factory=list) # Plan steps (for plan-then-execute) subagent_results: dict[str, str] = field(default_factory=dict) # Subagent results keyed by goal @@ -33,7 +32,6 @@ class AgentRequest: state: str = "" group: list[str] | None = None history: list[AgentStep] = field(default_factory=list) - user: str = "" # User context for multi-tenancy collection: str = "default" # Collection for provenance traces streaming: bool = False # Enable streaming response delivery (default false) session_id: str = "" # For provenance tracking across iterations diff --git a/trustgraph-base/trustgraph/schema/services/collection.py b/trustgraph-base/trustgraph/schema/services/collection.py index f4b5fc6e..13dd0607 100644 --- a/trustgraph-base/trustgraph/schema/services/collection.py +++ b/trustgraph-base/trustgraph/schema/services/collection.py @@ -13,7 +13,6 @@ from ..core.topic import queue @dataclass class CollectionMetadata: """Collection metadata record""" - user: str = "" collection: str = "" name: str = "" description: str = "" @@ -23,11 +22,17 @@ class CollectionMetadata: @dataclass class CollectionManagementRequest: - """Request for collection management operations""" + """Request for collection management operations. + + Collection-management is a global (non-flow-scoped) service, so the + workspace has to travel on the wire — it's the isolation boundary + for which workspace's collections the request operates on. + """ operation: str = "" # e.g., "delete-collection" - # For 'list-collections' - user: str = "" + # Workspace the collection belongs to. + workspace: str = "" + collection: str = "" timestamp: str = "" # ISO timestamp name: str = "" diff --git a/trustgraph-base/trustgraph/schema/services/config.py b/trustgraph-base/trustgraph/schema/services/config.py index c08e96d7..3bcbc72c 100644 --- a/trustgraph-base/trustgraph/schema/services/config.py +++ b/trustgraph-base/trustgraph/schema/services/config.py @@ -7,12 +7,19 @@ from ..core.primitives import Error ############################################################################ # Config service: -# get(keys) -> (version, values) -# list(type) -> (version, values) -# getvalues(type) -> (version, values) -# put(values) -> () -# delete(keys) -> () -# config() -> (version, config) +# get(workspace, keys) -> (version, values) +# list(workspace, type) -> (version, directory) +# getvalues(workspace, type) -> (version, values) +# getvalues-all-ws(type) -> (version, values with workspace field) +# put(workspace, values) -> () +# delete(workspace, keys) -> () +# config(workspace) -> (version, config) +# +# Most operations are scoped to a workspace. The workspace field on the +# request identifies which workspace's config to read or modify. +# getvalues-all-ws returns values across all workspaces for a single +# type — used by shared processors to load type-scoped config at startup. + @dataclass class ConfigKey: type: str = "" @@ -23,16 +30,24 @@ class ConfigValue: type: str = "" key: str = "" value: str = "" + # Populated by getvalues-all-ws responses so callers can identify + # which workspace each value belongs to. Empty otherwise. + workspace: str = "" -# Prompt services, abstract the prompt generation @dataclass class ConfigRequest: - operation: str = "" # get, list, getvalues, delete, put, config + # Operations: get, list, getvalues, getvalues-all-ws, delete, put, + # config + operation: str = "" + + # Workspace scope — required on all operations except + # getvalues-all-ws which spans all workspaces. + workspace: str = "" # get, delete keys: list[ConfigKey] = field(default_factory=list) - # list, getvalues + # list, getvalues, getvalues-all-ws type: str = "" # put @@ -58,7 +73,12 @@ class ConfigResponse: @dataclass class ConfigPush: version: int = 0 - types: list[str] = field(default_factory=list) + + # Dict of config type -> list of affected workspaces. + # Handlers look up their registered type and get the list of + # workspaces that need refreshing. + # e.g. {"prompt": ["workspace-a", "workspace-b"], "schema": ["workspace-a"]} + changes: dict[str, list[str]] = field(default_factory=dict) config_request_queue = queue('config', cls='request') config_response_queue = queue('config', cls='response') diff --git a/trustgraph-base/trustgraph/schema/services/flow.py b/trustgraph-base/trustgraph/schema/services/flow.py index 0d497dd7..586c160d 100644 --- a/trustgraph-base/trustgraph/schema/services/flow.py +++ b/trustgraph-base/trustgraph/schema/services/flow.py @@ -17,12 +17,14 @@ from ..core.primitives import Error # start_flow(flowid, blueprintname) -> () # stop_flow(flowid) -> () -# Prompt services, abstract the prompt generation @dataclass class FlowRequest: operation: str = "" # list-blueprints, get-blueprint, put-blueprint, delete-blueprint # list-flows, get-flow, start-flow, stop-flow + # Workspace scope — all operations act within this workspace + workspace: str = "" + # get_blueprint, put_blueprint, delete_blueprint, start_flow blueprint_name: str = "" diff --git a/trustgraph-base/trustgraph/schema/services/library.py b/trustgraph-base/trustgraph/schema/services/library.py index f5d4592c..961b47dc 100644 --- a/trustgraph-base/trustgraph/schema/services/library.py +++ b/trustgraph-base/trustgraph/schema/services/library.py @@ -43,12 +43,12 @@ from ..core.metadata import Metadata # <- (error) # list-documents -# -> (user, collection?) +# -> (workspace, collection?) # <- (document_metadata[]) # <- (error) # list-processing -# -> (user, collection?) +# -> (workspace, collection?) # <- (processing_metadata[]) # <- (error) @@ -78,7 +78,7 @@ from ..core.metadata import Metadata # <- (error) # list-uploads -# -> (user) +# -> (workspace) # <- (uploads[]) # <- (error) @@ -90,7 +90,7 @@ class DocumentMetadata: title: str = "" comments: str = "" metadata: list[Triple] = field(default_factory=list) - user: str = "" + workspace: str = "" tags: list[str] = field(default_factory=list) # Child document support parent_id: str = "" # Empty for top-level docs, set for children @@ -107,7 +107,7 @@ class ProcessingMetadata: document_id: str = "" time: int = 0 flow: str = "" - user: str = "" + workspace: str = "" collection: str = "" tags: list[str] = field(default_factory=list) @@ -162,8 +162,8 @@ class LibrarianRequest: # add-document, upload-chunk content: bytes = b"" - # list-documents, list-processing, list-uploads - user: str = "" + # Workspace scopes every library operation. + workspace: str = "" # list-documents?, list-processing? collection: str = "" diff --git a/trustgraph-base/trustgraph/schema/services/query.py b/trustgraph-base/trustgraph/schema/services/query.py index f9f08658..9c11a157 100644 --- a/trustgraph-base/trustgraph/schema/services/query.py +++ b/trustgraph-base/trustgraph/schema/services/query.py @@ -11,7 +11,6 @@ from ..core.topic import queue class GraphEmbeddingsRequest: vector: list[float] = field(default_factory=list) limit: int = 0 - user: str = "" collection: str = "" @dataclass @@ -31,7 +30,6 @@ class GraphEmbeddingsResponse: @dataclass class TriplesQueryRequest: - user: str = "" collection: str = "" s: Term | None = None p: Term | None = None @@ -55,7 +53,6 @@ class TriplesQueryResponse: class DocumentEmbeddingsRequest: vector: list[float] = field(default_factory=list) limit: int = 0 - user: str = "" collection: str = "" @dataclass @@ -89,7 +86,6 @@ class RowEmbeddingsRequest: """Request for row embeddings semantic search""" vector: list[float] = field(default_factory=list) # Query vector limit: int = 10 # Max results to return - user: str = "" # User/keyspace collection: str = "" # Collection name schema_name: str = "" # Schema name to search within index_name: str | None = None # Optional: filter to specific index diff --git a/trustgraph-base/trustgraph/schema/services/retrieval.py b/trustgraph-base/trustgraph/schema/services/retrieval.py index a1af9170..e937e720 100644 --- a/trustgraph-base/trustgraph/schema/services/retrieval.py +++ b/trustgraph-base/trustgraph/schema/services/retrieval.py @@ -8,7 +8,6 @@ from ..core.primitives import Error, Term, Triple @dataclass class GraphRagQuery: query: str = "" - user: str = "" collection: str = "" entity_limit: int = 0 triple_limit: int = 0 @@ -40,7 +39,6 @@ class GraphRagResponse: @dataclass class DocumentRagQuery: query: str = "" - user: str = "" collection: str = "" doc_limit: int = 0 streaming: bool = False diff --git a/trustgraph-base/trustgraph/schema/services/rows_query.py b/trustgraph-base/trustgraph/schema/services/rows_query.py index e3c4f14c..ea0759f1 100644 --- a/trustgraph-base/trustgraph/schema/services/rows_query.py +++ b/trustgraph-base/trustgraph/schema/services/rows_query.py @@ -15,7 +15,6 @@ class GraphQLError: @dataclass class RowsQueryRequest: - user: str = "" # Cassandra keyspace (follows pattern from TriplesQueryRequest) collection: str = "" # Data collection identifier (required for partition key) query: str = "" # GraphQL query string variables: dict[str, str] = field(default_factory=dict) # GraphQL variables diff --git a/trustgraph-base/trustgraph/schema/services/sparql_query.py b/trustgraph-base/trustgraph/schema/services/sparql_query.py index 62c02c93..a5ed502f 100644 --- a/trustgraph-base/trustgraph/schema/services/sparql_query.py +++ b/trustgraph-base/trustgraph/schema/services/sparql_query.py @@ -16,7 +16,6 @@ class SparqlBinding: @dataclass class SparqlQueryRequest: - user: str = "" collection: str = "" query: str = "" # SPARQL query string limit: int = 10000 # Safety limit on results diff --git a/trustgraph-base/trustgraph/schema/services/structured_query.py b/trustgraph-base/trustgraph/schema/services/structured_query.py index 5f54ac16..272643ac 100644 --- a/trustgraph-base/trustgraph/schema/services/structured_query.py +++ b/trustgraph-base/trustgraph/schema/services/structured_query.py @@ -9,7 +9,6 @@ from ..core.primitives import Error @dataclass class StructuredQueryRequest: question: str = "" - user: str = "" # Cassandra keyspace identifier collection: str = "" # Data collection identifier @dataclass diff --git a/trustgraph-base/trustgraph/schema/services/tool_service.py b/trustgraph-base/trustgraph/schema/services/tool_service.py index 18315f29..a42fd5e3 100644 --- a/trustgraph-base/trustgraph/schema/services/tool_service.py +++ b/trustgraph-base/trustgraph/schema/services/tool_service.py @@ -7,8 +7,6 @@ from ..core.primitives import Error @dataclass class ToolServiceRequest: """Request to a dynamically configured tool service.""" - # User context for multi-tenancy - user: str = "" # Config values (collection, etc.) as JSON config: str = "" # Arguments from LLM as JSON diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index a60b2bba..a5738449 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -95,6 +95,8 @@ tg-list-config-items = "trustgraph.cli.list_config_items:main" tg-get-config-item = "trustgraph.cli.get_config_item:main" tg-put-config-item = "trustgraph.cli.put_config_item:main" tg-delete-config-item = "trustgraph.cli.delete_config_item:main" +tg-export-workspace-config = "trustgraph.cli.export_workspace_config:main" +tg-import-workspace-config = "trustgraph.cli.import_workspace_config:main" tg-list-collections = "trustgraph.cli.list_collections:main" tg-set-collection = "trustgraph.cli.set_collection:main" tg-delete-collection = "trustgraph.cli.delete_collection:main" diff --git a/trustgraph-cli/trustgraph/cli/add_library_document.py b/trustgraph-cli/trustgraph/cli/add_library_document.py index 3273e63d..8d08d11a 100644 --- a/trustgraph-cli/trustgraph/cli/add_library_document.py +++ b/trustgraph-cli/trustgraph/cli/add_library_document.py @@ -15,17 +15,17 @@ from trustgraph.knowledge import Organization, PublicationEvent from trustgraph.knowledge import DigitalDocument default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") class Loader: def __init__( - self, id, url, user, metadata, title, comments, kind, tags - ): + self, id, url, metadata, title, comments, kind, tags + , token=None, workspace="default"): - self.api = Api(url).library() + self.api = Api(url, token=token, workspace=workspace).library() - self.user = user self.metadata = metadata self.title = title self.comments = comments @@ -55,13 +55,13 @@ class Loader: else: id = hash(data) id = to_uri(PREF_DOC, id) - + self.metadata.id = id self.api.add_document( - document=data, id=id, metadata=self.metadata, - user=self.user, kind=self.kind, title=self.title, + document=data, id=id, metadata=self.metadata, + kind=self.kind, title=self.title, comments=self.comments, tags=self.tags ) @@ -83,11 +83,16 @@ def main(): default=default_url, help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -186,12 +191,13 @@ def main(): p = Loader( id=args.identifier, url=args.url, - user=args.user, metadata=document, title=args.name, comments=args.description, kind=args.kind, tags=args.tags, + token=args.token, + workspace=args.workspace, ) p.load(args.files) diff --git a/trustgraph-cli/trustgraph/cli/delete_collection.py b/trustgraph-cli/trustgraph/cli/delete_collection.py index 3e19ac09..aedd801a 100644 --- a/trustgraph-cli/trustgraph/cli/delete_collection.py +++ b/trustgraph-cli/trustgraph/cli/delete_collection.py @@ -7,9 +7,11 @@ import os from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = "trustgraph" +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def delete_collection(url, user, collection, confirm): + +def delete_collection(url, collection, confirm, token=None, workspace="default"): if not confirm: response = input(f"Are you sure you want to delete collection '{collection}' and all its data? (y/N): ") @@ -17,9 +19,9 @@ def delete_collection(url, user, collection, confirm): print("Operation cancelled.") return - api = Api(url).collection() + api = Api(url, token=token, workspace=workspace).collection() - api.delete_collection(user=user, collection=collection) + api.delete_collection(collection=collection) print(f"Collection '{collection}' deleted successfully.") @@ -41,27 +43,34 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-y', '--yes', action='store_true', help='Skip confirmation prompt' ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: delete_collection( url = args.api_url, - user = args.user, collection = args.collection, - confirm = args.yes + confirm = args.yes, + token = args.token, + workspace = args.workspace, ) except Exception as e: @@ -69,4 +78,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/delete_config_item.py b/trustgraph-cli/trustgraph/cli/delete_config_item.py index cf4cba93..801c2a99 100644 --- a/trustgraph-cli/trustgraph/cli/delete_config_item.py +++ b/trustgraph-cli/trustgraph/cli/delete_config_item.py @@ -9,10 +9,11 @@ from trustgraph.api.types import ConfigKey default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def delete_config_item(url, config_type, key, token=None): +def delete_config_item(url, config_type, key, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() config_key = ConfigKey(type=config_type, key=key) api.delete([config_key]) @@ -50,6 +51,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -59,6 +66,8 @@ def main(): config_type=args.type, key=args.key, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/delete_flow_blueprint.py b/trustgraph-cli/trustgraph/cli/delete_flow_blueprint.py index 9ff8aeba..62140f0e 100644 --- a/trustgraph-cli/trustgraph/cli/delete_flow_blueprint.py +++ b/trustgraph-cli/trustgraph/cli/delete_flow_blueprint.py @@ -9,10 +9,13 @@ from trustgraph.api import Api import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def delete_flow_blueprint(url, blueprint_name): +def delete_flow_blueprint(url, blueprint_name, token=None, + workspace="default"): - api = Api(url).flow() + api = Api(url, token=token, workspace=workspace).flow() blueprint_names = api.delete_blueprint(blueprint_name) @@ -29,6 +32,18 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-n', '--blueprint-name', help=f'Flow blueprint name', @@ -41,6 +56,8 @@ def main(): delete_flow_blueprint( url=args.api_url, blueprint_name=args.blueprint_name, + token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/delete_kg_core.py b/trustgraph-cli/trustgraph/cli/delete_kg_core.py index 81f95e45..0e0753e0 100644 --- a/trustgraph-cli/trustgraph/cli/delete_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/delete_kg_core.py @@ -1,20 +1,20 @@ """ -Deletes a flow class +Deletes a knowledge core """ import argparse import os -import tabulate from trustgraph.api import Api -import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def delete_kg_core(url, user, id): +def delete_kg_core(url, id, token=None, workspace="default"): - api = Api(url).knowledge() + api = Api(url, token=token, workspace=workspace).knowledge() - class_names = api.delete_kg_core(user = user, id = id) + api.delete_kg_core(id=id) def main(): @@ -29,26 +29,33 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default="trustgraph", - help='API URL (default: trustgraph)', - ) - parser.add_argument( '--id', '--identifier', required=True, help=f'Knowledge core ID', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: delete_kg_core( url=args.api_url, - user=args.user, id=args.id, + token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/delete_mcp_tool.py b/trustgraph-cli/trustgraph/cli/delete_mcp_tool.py index a3ae7e77..eed9ed21 100644 --- a/trustgraph-cli/trustgraph/cli/delete_mcp_tool.py +++ b/trustgraph-cli/trustgraph/cli/delete_mcp_tool.py @@ -10,12 +10,16 @@ import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def delete_mcp_tool( url : str, id : str, + token=None, + workspace="default", ): - api = Api(url).config() + api = Api(url, token=token, workspace=workspace).config() # Check if the tool exists first try: @@ -73,6 +77,18 @@ def main(): help='MCP tool ID to delete', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -81,8 +97,10 @@ def main(): raise RuntimeError("Must specify --id for MCP tool to delete") delete_mcp_tool( - url=args.api_url, - id=args.id + url=args.api_url, + id=args.id, + token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/delete_tool.py b/trustgraph-cli/trustgraph/cli/delete_tool.py index 961c9aa8..50f43fdd 100644 --- a/trustgraph-cli/trustgraph/cli/delete_tool.py +++ b/trustgraph-cli/trustgraph/cli/delete_tool.py @@ -12,12 +12,16 @@ import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def delete_tool( url : str, id : str, + token=None, + workspace="default", ): - api = Api(url).config() + api = Api(url, token=token, workspace=workspace).config() # Check if the tool configuration exists try: @@ -78,6 +82,18 @@ def main(): help='Tool ID to delete', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -86,8 +102,10 @@ def main(): raise RuntimeError("Must specify --id for tool to delete") delete_tool( - url=args.api_url, - id=args.id + url=args.api_url, + id=args.id, + token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/export_workspace_config.py b/trustgraph-cli/trustgraph/cli/export_workspace_config.py new file mode 100644 index 00000000..feef97de --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/export_workspace_config.py @@ -0,0 +1,114 @@ +""" +Exports a curated subset of a workspace's configuration to a JSON file +for later reload into another workspace (useful for cloning test setups). + +The subset covers the config types that define workspace behaviour: +mcp-tool, tool, flow-blueprint, token-cost, agent-pattern, +agent-task-type, parameter-type, interface-description, prompt. +""" + +import argparse +import os +import json +import sys +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") + +EXPORT_TYPES = [ + "mcp-tool", + "tool", + "flow-blueprint", + "token-cost", + "agent-pattern", + "agent-task-type", + "parameter-type", + "interface-description", + "prompt", +] + + +def export_workspace_config(url, workspace, output, token=None): + + api = Api(url, token=token, workspace=workspace).config() + + config, version = api.all() + + subset = {} + for t in EXPORT_TYPES: + if t in config: + subset[t] = config[t] + + payload = { + "source_workspace": workspace, + "source_version": version, + "config": subset, + } + + if output == "-": + json.dump(payload, sys.stdout, indent=2) + sys.stdout.write("\n") + else: + with open(output, "w") as f: + json.dump(payload, f, indent=2) + + total = sum(len(v) for v in subset.values()) + print( + f"Exported {total} items across {len(subset)} types " + f"from workspace '{workspace}' (version {version}).", + file=sys.stderr, + ) + + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-export-workspace-config', + description=__doc__, + ) + + parser.add_argument( + '-u', '--api-url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Source workspace (default: {default_workspace})', + ) + + parser.add_argument( + '-o', '--output', + required=True, + help='Output JSON file path (use "-" for stdout)', + ) + + args = parser.parse_args() + + try: + + export_workspace_config( + url=args.api_url, + workspace=args.workspace, + output=args.output, + token=args.token, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/get_config_item.py b/trustgraph-cli/trustgraph/cli/get_config_item.py index c2421e94..028cc064 100644 --- a/trustgraph-cli/trustgraph/cli/get_config_item.py +++ b/trustgraph-cli/trustgraph/cli/get_config_item.py @@ -10,10 +10,12 @@ from trustgraph.api.types import ConfigKey default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def get_config_item(url, config_type, key, format_type, token=None): +def get_config_item(url, config_type, key, format_type, token=None, + workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() config_key = ConfigKey(type=config_type, key=key) values = api.get([config_key]) @@ -66,6 +68,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -76,6 +84,7 @@ def main(): key=args.key, format_type=args.format, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/get_document_content.py b/trustgraph-cli/trustgraph/cli/get_document_content.py index 3d70f37d..62fa7ca2 100644 --- a/trustgraph-cli/trustgraph/cli/get_document_content.py +++ b/trustgraph-cli/trustgraph/cli/get_document_content.py @@ -9,21 +9,19 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = "trustgraph" +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def get_content(url, user, document_id, output_file, token=None): +def get_content(url, document_id, output_file, token=None, workspace="default"): - api = Api(url, token=token).library() + api = Api(url, token=token, workspace=workspace).library() - content = api.get_document_content(user=user, id=document_id) + content = api.get_document_content(id=document_id) if output_file: with open(output_file, 'wb') as f: f.write(content) print(f"Written {len(content)} bytes to {output_file}") else: - # Write to stdout - # Try to decode as text, fall back to binary info try: text = content.decode('utf-8') print(text) @@ -51,9 +49,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -73,10 +71,10 @@ def main(): get_content( url=args.api_url, - user=args.user, document_id=args.document_id, output_file=args.output, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/get_flow_blueprint.py b/trustgraph-cli/trustgraph/cli/get_flow_blueprint.py index 817b8f47..56d43a7c 100644 --- a/trustgraph-cli/trustgraph/cli/get_flow_blueprint.py +++ b/trustgraph-cli/trustgraph/cli/get_flow_blueprint.py @@ -9,10 +9,12 @@ from trustgraph.api import Api import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def get_flow_blueprint(url, blueprint_name): +def get_flow_blueprint(url, blueprint_name, token=None, workspace="default"): - api = Api(url).flow() + api = Api(url, token=token, workspace=workspace).flow() cls = api.get_blueprint(blueprint_name) @@ -31,6 +33,18 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-n', '--blueprint-name', required=True, @@ -44,6 +58,8 @@ def main(): get_flow_blueprint( url=args.api_url, blueprint_name=args.blueprint_name, + token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/get_kg_core.py b/trustgraph-cli/trustgraph/cli/get_kg_core.py index b75f7155..8bee4115 100644 --- a/trustgraph-cli/trustgraph/cli/get_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/get_kg_core.py @@ -5,7 +5,6 @@ to a local file in msgpack format. import argparse import os -import textwrap import uuid import asyncio import json @@ -13,17 +12,16 @@ from websockets.asyncio.client import connect import msgpack default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') -default_user = 'trustgraph' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def write_triple(f, data): msg = ( "t", { "m": { - "i": data["metadata"]["id"], + "i": data["metadata"]["id"], "m": data["metadata"]["metadata"], - "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, "t": data["triples"], @@ -36,9 +34,8 @@ def write_ge(f, data): "ge", { "m": { - "i": data["metadata"]["id"], + "i": data["metadata"]["id"], "m": data["metadata"]["metadata"], - "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, "e": [ @@ -52,7 +49,7 @@ def write_ge(f, data): ) f.write(msgpack.packb(msg, use_bin_type=True)) -async def fetch(url, user, id, output, token=None): +async def fetch(url, workspace, id, output, token=None): if not url.endswith("/"): url += "/" @@ -68,10 +65,11 @@ async def fetch(url, user, id, output, token=None): req = json.dumps({ "id": mid, + "workspace": workspace, "service": "knowledge", "request": { "operation": "get-kg-core", - "user": user, + "workspace": workspace, "id": id, } }) @@ -124,10 +122,11 @@ def main(): default=default_url, help=f'API URL (default: {default_url})', ) + parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -154,11 +153,11 @@ def main(): asyncio.run( fetch( - url = args.url, - user = args.user, - id = args.id, - output = args.output, - token = args.token, + url=args.url, + workspace=args.workspace, + id=args.id, + output=args.output, + token=args.token, ) ) @@ -167,4 +166,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/graph_to_turtle.py b/trustgraph-cli/trustgraph/cli/graph_to_turtle.py index 840f8574..4d4a94b3 100644 --- a/trustgraph-cli/trustgraph/cli/graph_to_turtle.py +++ b/trustgraph-cli/trustgraph/cli/graph_to_turtle.py @@ -13,9 +13,9 @@ import os from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' default_collection = 'default' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def term_to_rdflib(term): @@ -58,9 +58,10 @@ def term_to_rdflib(term): return rdflib.term.Literal(str(term)) -def show_graph(url, flow_id, user, collection, limit, batch_size, token=None): +def show_graph(url, flow_id, collection, limit, batch_size, + token=None, workspace="default"): - socket = Api(url, token=token).socket() + socket = Api(url, token=token, workspace=workspace).socket() flow = socket.flow(flow_id) g = rdflib.Graph() @@ -68,7 +69,7 @@ def show_graph(url, flow_id, user, collection, limit, batch_size, token=None): try: for batch in flow.triples_query_stream( s=None, p=None, o=None, - user=user, collection=collection, + collection=collection, limit=limit, batch_size=batch_size, ): @@ -108,12 +109,6 @@ def main(): help=f'Flow ID (default: default)' ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -126,6 +121,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-l', '--limit', type=int, @@ -147,11 +148,11 @@ def main(): show_graph( url = args.api_url, flow_id = args.flow_id, - user = args.user, collection = args.collection, limit = args.limit, batch_size = args.batch_size, token = args.token, + workspace = args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/import_workspace_config.py b/trustgraph-cli/trustgraph/cli/import_workspace_config.py new file mode 100644 index 00000000..3fe3be97 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/import_workspace_config.py @@ -0,0 +1,143 @@ +""" +Imports a workspace-config dump produced by tg-export-workspace-config +into a target workspace. Writes mcp-tool, tool, flow-blueprint, +token-cost, agent-pattern, agent-task-type, parameter-type, +interface-description and prompt items verbatim. + +Existing items with the same (type, key) are overwritten. +""" + +import argparse +import os +import json +import sys +from trustgraph.api import Api +from trustgraph.api.types import ConfigValue + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") + +IMPORT_TYPES = { + "mcp-tool", + "tool", + "flow-blueprint", + "token-cost", + "agent-pattern", + "agent-task-type", + "parameter-type", + "interface-description", + "prompt", +} + + +def import_workspace_config(url, workspace, input_path, token=None, + dry_run=False): + + if input_path == "-": + payload = json.load(sys.stdin) + else: + with open(input_path, "r") as f: + payload = json.load(f) + + # Accept both the wrapped export format and a bare {type: {key: value}} + # dict, so hand-written files are also loadable. + if isinstance(payload, dict) and "config" in payload \ + and isinstance(payload["config"], dict): + config = payload["config"] + source = payload.get("source_workspace") + else: + config = payload + source = None + + skipped_types = set(config.keys()) - IMPORT_TYPES + if skipped_types: + print( + f"Ignoring unsupported types: {sorted(skipped_types)}", + file=sys.stderr, + ) + + values = [] + for t in IMPORT_TYPES: + items = config.get(t, {}) + for key, value in items.items(): + values.append(ConfigValue(type=t, key=key, value=value)) + + if not values: + print("Nothing to import.", file=sys.stderr) + return + + if dry_run: + print( + f"[dry-run] would import {len(values)} items into " + f"workspace '{workspace}'" + + (f" (from '{source}')" if source else "") + ) + return + + api = Api(url, token=token, workspace=workspace).config() + api.put(values) + + print( + f"Imported {len(values)} items into workspace '{workspace}'" + + (f" (from '{source}')." if source else "."), + ) + + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-import-workspace-config', + description=__doc__, + ) + + parser.add_argument( + '-u', '--api-url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Target workspace (default: {default_workspace})', + ) + + parser.add_argument( + '-i', '--input', + required=True, + help='Input JSON file path (use "-" for stdin)', + ) + + parser.add_argument( + '--dry-run', + action='store_true', + help='Parse and validate the input without writing anything', + ) + + args = parser.parse_args() + + try: + + import_workspace_config( + url=args.api_url, + workspace=args.workspace, + input_path=args.input, + token=args.token, + dry_run=args.dry_run, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/init_trustgraph.py b/trustgraph-cli/trustgraph/cli/init_trustgraph.py index 18c240ef..d984f925 100644 --- a/trustgraph-cli/trustgraph/cli/init_trustgraph.py +++ b/trustgraph-cli/trustgraph/cli/init_trustgraph.py @@ -69,10 +69,11 @@ def ensure_namespace(url, tenant, namespace, config): print(f"Namespace {tenant}/{namespace} created.", flush=True) -def ensure_config(config, **pubsub_config): +def ensure_config(config, workspace="default", **pubsub_config): cli = ConfigClient( subscriber=subscriber, + workspace=workspace, **pubsub_config, ) @@ -147,7 +148,8 @@ def init_pulsar(pulsar_admin_url, tenant): }) -def push_config(config_json, config_file, **pubsub_config): +def push_config(config_json, config_file, workspace="default", + **pubsub_config): """Push initial config if provided.""" if config_json is not None: @@ -160,7 +162,7 @@ def push_config(config_json, config_file, **pubsub_config): print("Exception:", e, flush=True) raise e - ensure_config(dec, **pubsub_config) + ensure_config(dec, workspace=workspace, **pubsub_config) elif config_file is not None: @@ -172,7 +174,7 @@ def push_config(config_json, config_file, **pubsub_config): print("Exception:", e, flush=True) raise e - ensure_config(dec, **pubsub_config) + ensure_config(dec, workspace=workspace, **pubsub_config) else: print("No config to update.", flush=True) @@ -207,6 +209,12 @@ def main(): help=f'Tenant (default: tg)', ) + parser.add_argument( + '-w', '--workspace', + default="default", + help=f'Workspace (default: default)', + ) + add_pubsub_args(parser) args = parser.parse_args() @@ -216,7 +224,10 @@ def main(): # Extract pubsub config from args pubsub_config = { k: v for k, v in vars(args).items() - if k not in ('pulsar_admin_url', 'config', 'config_file', 'tenant') + if k not in ( + 'pulsar_admin_url', 'config', 'config_file', 'tenant', + 'workspace', + ) } while True: @@ -241,6 +252,7 @@ def main(): # Push config (works with any backend) push_config( args.config, args.config_file, + workspace=args.workspace, **pubsub_config, ) diff --git a/trustgraph-cli/trustgraph/cli/invoke_agent.py b/trustgraph-cli/trustgraph/cli/invoke_agent.py index b379c2df..d815aacd 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_agent.py +++ b/trustgraph-cli/trustgraph/cli/invoke_agent.py @@ -26,7 +26,7 @@ from trustgraph.api import ( default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' class Outputter: @@ -115,11 +115,12 @@ def output(text, prefix="> ", width=78): print(out) def question_explainable( - url, question_text, flow_id, user, collection, - state=None, group=None, verbose=False, token=None, debug=False + url, question_text, flow_id, collection, + state=None, group=None, verbose=False, token=None, debug=False, + workspace="default", ): """Execute agent with explainability - shows provenance events inline.""" - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10) @@ -132,7 +133,6 @@ def question_explainable( # Stream agent with explainability - process events as they arrive for item in flow.agent_explain( question=question_text, - user=user, collection=collection, state=state, group=group, @@ -191,7 +191,6 @@ def question_explainable( entity = explain_client.fetch_entity( prov_id, graph=explain_graph, - user=user, collection=collection ) @@ -269,11 +268,11 @@ def question_explainable( def question( - url, question, flow_id, user, collection, + url, question, flow_id, collection, plan=None, state=None, group=None, pattern=None, verbose=False, streaming=True, token=None, explainable=False, debug=False, - show_usage=False + show_usage=False, workspace="default", ): # Explainable mode uses the API to capture and process provenance events if explainable: @@ -281,13 +280,13 @@ def question( url=url, question_text=question, flow_id=flow_id, - user=user, collection=collection, state=state, group=group, verbose=verbose, token=token, - debug=debug + debug=debug, + workspace=workspace, ) return @@ -296,14 +295,13 @@ def question( print() # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) # Prepare request parameters request_params = { "question": question, - "user": user, "streaming": streaming, } @@ -418,6 +416,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -430,12 +434,6 @@ def main(): help=f'Question to answer', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -502,7 +500,6 @@ def main(): url = args.url, flow_id = args.flow_id, question = args.question, - user = args.user, collection = args.collection, plan = args.plan, state = args.state, @@ -514,6 +511,7 @@ def main(): explainable = args.explainable, debug = args.debug, show_usage = args.show_usage, + workspace = args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py b/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py index 43bcc985..ed851dff 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py +++ b/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py @@ -9,11 +9,12 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def query(url, flow_id, query_text, user, collection, limit, token=None): +def query(url, flow_id, query_text, collection, limit, token=None, workspace="default"): # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) @@ -21,7 +22,6 @@ def query(url, flow_id, query_text, user, collection, limit, token=None): # Call document embeddings query service result = flow.document_embeddings_query( text=query_text, - user=user, collection=collection, limit=limit ) @@ -59,15 +59,15 @@ def main(): ) parser.add_argument( - '-f', '--flow-id', - default="default", - help=f'Flow ID (default: default)' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( - '-U', '--user', - default="trustgraph", - help='User/keyspace (default: trustgraph)', + '-f', '--flow-id', + default="default", + help=f'Flow ID (default: default)' ) parser.add_argument( @@ -97,10 +97,10 @@ def main(): url=args.url, flow_id=args.flow_id, query_text=args.query[0], - user=args.user, collection=args.collection, limit=args.limit, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py index d566f51d..01512ac8 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py @@ -18,16 +18,17 @@ from trustgraph.api import ( default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' default_doc_limit = 10 def question_explainable( - url, flow_id, question_text, user, collection, doc_limit, token=None, debug=False + url, flow_id, question_text, collection, doc_limit, token=None, debug=False, + workspace="default", ): """Execute document RAG with explainability - shows provenance events inline.""" - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10) @@ -36,8 +37,7 @@ def question_explainable( # Stream DocumentRAG with explainability - process events as they arrive for item in flow.document_rag_explain( query=question_text, - user=user, - collection=collection, + collection=collection, doc_limit=doc_limit, ): if isinstance(item, RAGChunk): @@ -54,8 +54,7 @@ def question_explainable( entity = explain_client.fetch_entity( prov_id, graph=explain_graph, - user=user, - collection=collection + collection=collection ) if entity is None: @@ -98,9 +97,9 @@ def question_explainable( def question( - url, flow_id, question_text, user, collection, doc_limit, + url, flow_id, question_text, collection, doc_limit, streaming=True, token=None, explainable=False, debug=False, - show_usage=False + show_usage=False, workspace="default", ): # Explainable mode uses the API to capture and process provenance events if explainable: @@ -108,16 +107,16 @@ def question( url=url, flow_id=flow_id, question_text=question_text, - user=user, - collection=collection, + collection=collection, doc_limit=doc_limit, token=token, - debug=debug + debug=debug, + workspace=workspace, ) return # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) if streaming: # Use socket client for streaming @@ -127,8 +126,7 @@ def question( try: response = flow.document_rag( query=question_text, - user=user, - collection=collection, + collection=collection, doc_limit=doc_limit, streaming=True ) @@ -155,8 +153,7 @@ def question( flow = api.flow().id(flow_id) result = flow.document_rag( query=question_text, - user=user, - collection=collection, + collection=collection, doc_limit=doc_limit, ) print(result.text) @@ -189,6 +186,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -201,12 +204,6 @@ def main(): help=f'Question to answer', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -252,7 +249,6 @@ def main(): url=args.url, flow_id=args.flow_id, question_text=args.question, - user=args.user, collection=args.collection, doc_limit=args.doc_limit, streaming=not args.no_streaming, @@ -260,6 +256,7 @@ def main(): explainable=args.explainable, debug=args.debug, show_usage=args.show_usage, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_embeddings.py b/trustgraph-cli/trustgraph/cli/invoke_embeddings.py index 699a85cf..62eaa039 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_embeddings.py +++ b/trustgraph-cli/trustgraph/cli/invoke_embeddings.py @@ -9,11 +9,12 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def query(url, flow_id, texts, token=None): +def query(url, flow_id, texts, token=None, workspace="default"): # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) @@ -51,6 +52,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -72,6 +79,8 @@ def main(): flow_id=args.flow_id, texts=args.texts, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_embeddings.py b/trustgraph-cli/trustgraph/cli/invoke_graph_embeddings.py index 5b0f4c67..c7237c06 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_embeddings.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_embeddings.py @@ -9,11 +9,12 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def query(url, flow_id, query_text, user, collection, limit, token=None): +def query(url, flow_id, query_text, collection, limit, token=None, workspace="default"): # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) @@ -21,7 +22,6 @@ def query(url, flow_id, query_text, user, collection, limit, token=None): # Call graph embeddings query service result = flow.graph_embeddings_query( text=query_text, - user=user, collection=collection, limit=limit ) @@ -69,15 +69,15 @@ def main(): ) parser.add_argument( - '-f', '--flow-id', - default="default", - help=f'Flow ID (default: default)' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( - '-U', '--user', - default="trustgraph", - help='User/keyspace (default: trustgraph)', + '-f', '--flow-id', + default="default", + help=f'Flow ID (default: default)' ) parser.add_argument( @@ -107,10 +107,10 @@ def main(): url=args.url, flow_id=args.flow_id, query_text=args.query[0], - user=args.user, collection=args.collection, limit=args.limit, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py index c9efe54d..23d6bcac 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py @@ -22,7 +22,7 @@ from trustgraph.api import ( default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' default_entity_limit = 50 default_triple_limit = 30 @@ -108,7 +108,7 @@ def _format_provenance_details(event_type, triples): return lines -async def _query_triples_once(ws_url, flow_id, prov_id, user, collection, graph=None, debug=False): +async def _query_triples_once(ws_url, flow_id, prov_id, collection, graph=None, debug=False): """Query triples for a provenance node (single attempt)""" request = { "id": "triples-request", @@ -116,7 +116,6 @@ async def _query_triples_once(ws_url, flow_id, prov_id, user, collection, graph= "flow": flow_id, "request": { "s": {"t": "i", "i": prov_id}, - "user": user, "collection": collection, "limit": 100 } @@ -182,10 +181,10 @@ async def _query_triples_once(ws_url, flow_id, prov_id, user, collection, graph= return triples -async def _query_triples(ws_url, flow_id, prov_id, user, collection, graph=None, max_retries=5, retry_delay=0.2, debug=False): +async def _query_triples(ws_url, flow_id, prov_id, collection, graph=None, max_retries=5, retry_delay=0.2, debug=False): """Query triples for a provenance node with retries for race condition""" for attempt in range(max_retries): - triples = await _query_triples_once(ws_url, flow_id, prov_id, user, collection, graph=graph, debug=debug) + triples = await _query_triples_once(ws_url, flow_id, prov_id, collection, graph=graph, debug=debug) if triples: return triples # Wait before retry if empty (triples may not be stored yet) @@ -196,7 +195,7 @@ async def _query_triples(ws_url, flow_id, prov_id, user, collection, graph=None, return [] -async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, user, collection, debug=False): +async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, collection, debug=False): """ Query for provenance of an edge (s, p, o) in the knowledge graph. @@ -220,7 +219,6 @@ async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, user, "o": {"t": "i", "i": edge_o} if edge_o.startswith("http") or edge_o.startswith("urn:") else {"t": "l", "v": edge_o}, } }, - "user": user, "collection": collection, "limit": 10 } @@ -273,7 +271,6 @@ async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, user, "request": { "s": {"t": "i", "i": stmt_uri}, "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM}, - "user": user, "collection": collection, "limit": 10 } @@ -312,7 +309,7 @@ async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, user, return sources -async def _query_derived_from(ws_url, flow_id, uri, user, collection, debug=False): +async def _query_derived_from(ws_url, flow_id, uri, collection, debug=False): """Query for the prov:wasDerivedFrom parent of a URI. Returns None if no parent.""" request = { "id": "parent-request", @@ -321,7 +318,6 @@ async def _query_derived_from(ws_url, flow_id, uri, user, collection, debug=Fals "request": { "s": {"t": "i", "i": uri}, "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM}, - "user": user, "collection": collection, "limit": 1 } @@ -355,7 +351,7 @@ async def _query_derived_from(ws_url, flow_id, uri, user, collection, debug=Fals return None -async def _trace_provenance_chain(ws_url, flow_id, source_uri, user, collection, label_cache, debug=False): +async def _trace_provenance_chain(ws_url, flow_id, source_uri, collection, label_cache, debug=False): """ Trace the full provenance chain from a source URI up to the root document. Returns a list of (uri, label) tuples from leaf to root. @@ -369,11 +365,11 @@ async def _trace_provenance_chain(ws_url, flow_id, source_uri, user, collection, break # Get label for current entity - label = await _query_label(ws_url, flow_id, current, user, collection, label_cache, debug) + label = await _query_label(ws_url, flow_id, current, collection, label_cache, debug) chain.append((current, label)) # Get parent - parent = await _query_derived_from(ws_url, flow_id, current, user, collection, debug) + parent = await _query_derived_from(ws_url, flow_id, current, collection, debug) if not parent or parent == current: break current = parent @@ -401,7 +397,7 @@ def _is_iri(value): return value.startswith("http://") or value.startswith("https://") or value.startswith("urn:") -async def _query_label(ws_url, flow_id, iri, user, collection, label_cache, debug=False): +async def _query_label(ws_url, flow_id, iri, collection, label_cache, debug=False): """ Query for the rdfs:label of an IRI. Uses label_cache to avoid repeated queries. @@ -421,7 +417,6 @@ async def _query_label(ws_url, flow_id, iri, user, collection, label_cache, debu "request": { "s": {"t": "i", "i": iri}, "p": {"t": "i", "i": RDFS_LABEL}, - "user": user, "collection": collection, "limit": 1 } @@ -460,7 +455,7 @@ async def _query_label(ws_url, flow_id, iri, user, collection, label_cache, debu return label -async def _resolve_edge_labels(ws_url, flow_id, edge_triple, user, collection, label_cache, debug=False): +async def _resolve_edge_labels(ws_url, flow_id, edge_triple, collection, label_cache, debug=False): """ Resolve labels for all IRI components of an edge triple. Returns (s_label, p_label, o_label). @@ -469,15 +464,15 @@ async def _resolve_edge_labels(ws_url, flow_id, edge_triple, user, collection, l p = edge_triple.get("p", "?") o = edge_triple.get("o", "?") - s_label = await _query_label(ws_url, flow_id, s, user, collection, label_cache, debug) - p_label = await _query_label(ws_url, flow_id, p, user, collection, label_cache, debug) - o_label = await _query_label(ws_url, flow_id, o, user, collection, label_cache, debug) + s_label = await _query_label(ws_url, flow_id, s, collection, label_cache, debug) + p_label = await _query_label(ws_url, flow_id, p, collection, label_cache, debug) + o_label = await _query_label(ws_url, flow_id, o, collection, label_cache, debug) return s_label, p_label, o_label async def _question_explainable( - url, flow_id, question, user, collection, entity_limit, triple_limit, + url, flow_id, question, collection, entity_limit, triple_limit, max_subgraph_size, max_path_length, token=None, debug=False ): """Execute graph RAG with explainability - shows provenance events with details""" @@ -502,7 +497,6 @@ async def _question_explainable( "flow": flow_id, "request": { "query": question, - "user": user, "collection": collection, "entity-limit": entity_limit, "triple-limit": triple_limit, @@ -549,7 +543,7 @@ async def _question_explainable( # Query triples for this explain node (using named graph filter) triples = await _query_triples( - ws_url, flow_id, explain_id, user, collection, graph=explain_graph, debug=debug + ws_url, flow_id, explain_id, collection, graph=explain_graph, debug=debug ) # Format and display details @@ -564,7 +558,7 @@ async def _question_explainable( print(f" Seed entities: {len(entity_iris)}", file=sys.stderr) for iri in entity_iris: label = await _query_label( - ws_url, flow_id, iri, user, collection, + ws_url, flow_id, iri, collection, label_cache, debug=debug ) print(f" - {label}", file=sys.stderr) @@ -579,7 +573,7 @@ async def _question_explainable( print(f" [debug] querying edge selection: {o}", file=sys.stderr) # Query the edge selection entity (using named graph filter) edge_triples = await _query_triples( - ws_url, flow_id, o, user, collection, graph=explain_graph, debug=debug + ws_url, flow_id, o, collection, graph=explain_graph, debug=debug ) if debug: print(f" [debug] got {len(edge_triples)} edge triples", file=sys.stderr) @@ -597,7 +591,7 @@ async def _question_explainable( if edge_triple: # Resolve labels for edge components s_label, p_label, o_label = await _resolve_edge_labels( - ws_url, flow_id, edge_triple, user, collection, + ws_url, flow_id, edge_triple, collection, label_cache, debug=debug ) print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr) @@ -605,21 +599,21 @@ async def _question_explainable( r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning print(f" Reason: {r_short}", file=sys.stderr) - # Trace edge provenance in the user's collection (not explainability) + # Trace edge provenance in the workspace collection (not explainability) if edge_triple: sources = await _query_edge_provenance( ws_url, flow_id, edge_triple.get("s", ""), edge_triple.get("p", ""), edge_triple.get("o", ""), - user, collection, # Use the query collection, not explainability + collection, # Use the query collection, not explainability debug=debug ) if sources: for src in sources: # Trace full chain from source to root document chain = await _trace_provenance_chain( - ws_url, flow_id, src, user, collection, + ws_url, flow_id, src, collection, label_cache, debug=debug ) chain_str = _format_provenance_chain(chain) @@ -639,12 +633,12 @@ async def _question_explainable( def _question_explainable_api( - url, flow_id, question_text, user, collection, entity_limit, triple_limit, + url, flow_id, question_text, collection, entity_limit, triple_limit, max_subgraph_size, max_path_length, edge_score_limit=30, - edge_limit=25, token=None, debug=False + edge_limit=25, token=None, debug=False, workspace="default", ): """Execute graph RAG with explainability using the new API classes.""" - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10) @@ -653,8 +647,7 @@ def _question_explainable_api( # Stream GraphRAG with explainability - process events as they arrive for item in flow.graph_rag_explain( query=question_text, - user=user, - collection=collection, + collection=collection, entity_limit=entity_limit, triple_limit=triple_limit, max_subgraph_size=max_subgraph_size, @@ -676,8 +669,7 @@ def _question_explainable_api( entity = explain_client.fetch_entity( prov_id, graph=explain_graph, - user=user, - collection=collection + collection=collection ) if entity is None: @@ -707,7 +699,7 @@ def _question_explainable_api( if entity.entities: print(f" Seed entities: {len(entity.entities)}", file=sys.stderr) for ent in entity.entities: - label = explain_client.resolve_label(ent, user, collection) + label = explain_client.resolve_label(ent, collection) print(f" - {label}", file=sys.stderr) elif isinstance(entity, Focus): @@ -719,15 +711,14 @@ def _question_explainable_api( focus_full = explain_client.fetch_focus_with_edges( prov_id, graph=explain_graph, - user=user, - collection=collection + collection=collection ) if focus_full and focus_full.edge_selections: for edge_sel in focus_full.edge_selections: if edge_sel.edge: # Resolve labels for edge components s_label, p_label, o_label = explain_client.resolve_edge_labels( - edge_sel.edge, user, collection + edge_sel.edge, collection ) print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr) if edge_sel.reasoning: @@ -750,10 +741,11 @@ def _question_explainable_api( def question( - url, flow_id, question, user, collection, entity_limit, triple_limit, + url, flow_id, question, collection, entity_limit, triple_limit, max_subgraph_size, max_path_length, edge_score_limit=50, edge_limit=25, streaming=True, token=None, - explainable=False, debug=False, show_usage=False + explainable=False, debug=False, show_usage=False, + workspace="default", ): # Explainable mode uses the API to capture and process provenance events @@ -762,8 +754,7 @@ def question( url=url, flow_id=flow_id, question_text=question, - user=user, - collection=collection, + collection=collection, entity_limit=entity_limit, triple_limit=triple_limit, max_subgraph_size=max_subgraph_size, @@ -771,12 +762,13 @@ def question( edge_score_limit=edge_score_limit, edge_limit=edge_limit, token=token, - debug=debug + debug=debug, + workspace=workspace, ) return # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) if streaming: # Use socket client for streaming @@ -786,8 +778,7 @@ def question( try: response = flow.graph_rag( query=question, - user=user, - collection=collection, + collection=collection, entity_limit=entity_limit, triple_limit=triple_limit, max_subgraph_size=max_subgraph_size, @@ -819,8 +810,7 @@ def question( flow = api.flow().id(flow_id) result = flow.graph_rag( query=question, - user=user, - collection=collection, + collection=collection, entity_limit=entity_limit, triple_limit=triple_limit, max_subgraph_size=max_subgraph_size, @@ -857,6 +847,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -869,12 +865,6 @@ def main(): help=f'Question to answer', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -955,7 +945,6 @@ def main(): url=args.url, flow_id=args.flow_id, question=args.question, - user=args.user, collection=args.collection, entity_limit=args.entity_limit, triple_limit=args.triple_limit, @@ -968,6 +957,7 @@ def main(): explainable=args.explainable, debug=args.debug, show_usage=args.show_usage, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_llm.py b/trustgraph-cli/trustgraph/cli/invoke_llm.py index 3bf521f6..2006e9e8 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_llm.py +++ b/trustgraph-cli/trustgraph/cli/invoke_llm.py @@ -9,12 +9,13 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def query(url, flow_id, system, prompt, streaming=True, token=None, - show_usage=False): + show_usage=False, workspace="default"): # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) @@ -74,6 +75,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( 'system', nargs=1, @@ -116,6 +123,7 @@ def main(): streaming=not args.no_streaming, token=args.token, show_usage=args.show_usage, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_mcp_tool.py b/trustgraph-cli/trustgraph/cli/invoke_mcp_tool.py index c5700c5c..32c20768 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_mcp_tool.py +++ b/trustgraph-cli/trustgraph/cli/invoke_mcp_tool.py @@ -11,10 +11,12 @@ import json from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def query(url, flow_id, name, parameters): +def query(url, flow_id, name, parameters, token=None, workspace="default"): - api = Api(url).flow().id(flow_id) + api = Api(url, token=token, workspace=workspace).flow().id(flow_id) resp = api.mcp_tool(name=name, parameters=parameters) @@ -36,6 +38,18 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -68,6 +82,8 @@ def main(): flow_id = args.flow_id, name = args.name, parameters = parameters, + token = args.token, + workspace = args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_nlp_query.py b/trustgraph-cli/trustgraph/cli/invoke_nlp_query.py index 8b01187c..332531db 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_nlp_query.py +++ b/trustgraph-cli/trustgraph/cli/invoke_nlp_query.py @@ -10,9 +10,11 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -def nlp_query(url, flow_id, question, max_results, output_format='json'): +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") +def nlp_query(url, flow_id, question, max_results, output_format='json', token=None, workspace="default"): - api = Api(url).flow().id(flow_id) + api = Api(url, token=token, workspace=workspace).flow().id(flow_id) resp = api.nlp_query( question=question, @@ -63,6 +65,17 @@ def main(): default=default_url, help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) parser.add_argument( '-f', '--flow-id', @@ -100,6 +113,11 @@ def main(): question=args.question, max_results=args.max_results, output_format=args.format, + + token = args.token, + + workspace = args.workspace, + ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_prompt.py b/trustgraph-cli/trustgraph/cli/invoke_prompt.py index 86f7a024..ed47df90 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_prompt.py +++ b/trustgraph-cli/trustgraph/cli/invoke_prompt.py @@ -14,12 +14,13 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def query(url, flow_id, template_id, variables, streaming=True, token=None, - show_usage=False): + show_usage=False, workspace="default"): # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) @@ -80,6 +81,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -135,6 +142,7 @@ specified multiple times''', streaming=not args.no_streaming, token=args.token, show_usage=args.show_usage, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_row_embeddings.py b/trustgraph-cli/trustgraph/cli/invoke_row_embeddings.py index 7393b4c3..8244ae99 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_row_embeddings.py +++ b/trustgraph-cli/trustgraph/cli/invoke_row_embeddings.py @@ -9,11 +9,12 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def query(url, flow_id, query_text, schema_name, user, collection, index_name, limit, token=None): +def query(url, flow_id, query_text, schema_name, collection, index_name, limit, token=None, workspace="default"): # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) @@ -22,7 +23,6 @@ def query(url, flow_id, query_text, schema_name, user, collection, index_name, l result = flow.row_embeddings_query( text=query_text, schema_name=schema_name, - user=user, collection=collection, index_name=index_name, limit=limit @@ -60,15 +60,15 @@ def main(): ) parser.add_argument( - '-f', '--flow-id', - default="default", - help=f'Flow ID (default: default)' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( - '-U', '--user', - default="trustgraph", - help='User/keyspace (default: trustgraph)', + '-f', '--flow-id', + default="default", + help=f'Flow ID (default: default)' ) parser.add_argument( @@ -111,11 +111,11 @@ def main(): flow_id=args.flow_id, query_text=args.query[0], schema_name=args.schema_name, - user=args.user, collection=args.collection, index_name=args.index_name, limit=args.limit, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_rows_query.py b/trustgraph-cli/trustgraph/cli/invoke_rows_query.py index 962f353c..46fba4d7 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_rows_query.py +++ b/trustgraph-cli/trustgraph/cli/invoke_rows_query.py @@ -12,10 +12,11 @@ from trustgraph.api import Api from tabulate import tabulate default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' -def format_output(data, output_format): +def format_output(data, output_format, token=None, workspace="default"): """Format GraphQL response data in the specified format""" if not data: return "No data returned" @@ -82,10 +83,10 @@ def format_table_data(rows, table_name, output_format): return json.dumps({table_name: rows}, indent=2) def rows_query( - url, flow_id, query, user, collection, variables, operation_name, output_format='table' + url, flow_id, query, collection, variables, operation_name, output_format='table', token=None, workspace="default" ): - api = Api(url).flow().id(flow_id) + api = Api(url, token=token, workspace=workspace).flow().id(flow_id) # Parse variables if provided as JSON string parsed_variables = {} @@ -98,7 +99,6 @@ def rows_query( resp = api.rows_query( query=query, - user=user, collection=collection, variables=parsed_variables if parsed_variables else None, operation_name=operation_name @@ -135,6 +135,17 @@ def main(): default=default_url, help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) parser.add_argument( '-f', '--flow-id', @@ -148,12 +159,6 @@ def main(): help='GraphQL query to execute', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -185,11 +190,13 @@ def main(): url=args.url, flow_id=args.flow_id, query=args.query, - user=args.user, collection=args.collection, variables=args.variables, operation_name=args.operation_name, output_format=args.format, + token=args.token, + workspace=args.workspace, + ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py b/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py index 7b1ae9a6..26e03929 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py +++ b/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py @@ -9,7 +9,8 @@ import sys from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' @@ -44,10 +45,10 @@ def _term_str(val): return str(val) -def sparql_query(url, token, flow_id, query, user, collection, limit, - batch_size, output_format): +def sparql_query(url, token, flow_id, query, collection, limit, + batch_size, output_format, workspace="default"): - socket = Api(url=url, token=token).socket() + socket = Api(url=url, token=token, workspace=workspace).socket() flow = socket.flow(flow_id) variables = None @@ -57,7 +58,6 @@ def sparql_query(url, token, flow_id, query, user, collection, limit, for response in flow.sparql_query_stream( query=query, - user=user, collection=collection, limit=limit, batch_size=batch_size, @@ -154,8 +154,14 @@ def main(): parser.add_argument( '-t', '--token', - default=os.getenv("TRUSTGRAPH_TOKEN"), - help='API bearer token (default: TRUSTGRAPH_TOKEN env var)', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -174,12 +180,6 @@ def main(): help='Read SPARQL query from file (use - for stdin)', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})', - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -228,11 +228,11 @@ def main(): token=args.token, flow_id=args.flow_id, query=query, - user=args.user, collection=args.collection, limit=args.limit, batch_size=args.batch_size, output_format=args.format, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_structured_query.py b/trustgraph-cli/trustgraph/cli/invoke_structured_query.py index 9f5f8540..af2060bb 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_structured_query.py +++ b/trustgraph-cli/trustgraph/cli/invoke_structured_query.py @@ -13,7 +13,9 @@ from tabulate import tabulate default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -def format_output(data, output_format): +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") +def format_output(data, output_format, token=None, workspace="default"): """Format structured query response data in the specified format""" if not data: return "No data returned" @@ -79,11 +81,11 @@ def format_table_data(rows, table_name, output_format): else: return json.dumps({table_name: rows}, indent=2) -def structured_query(url, flow_id, question, user='trustgraph', collection='default', output_format='table'): +def structured_query(url, flow_id, question, collection='default', output_format='table', token=None, workspace="default"): - api = Api(url).flow().id(flow_id) + api = Api(url, token=token, workspace=workspace).flow().id(flow_id) - resp = api.structured_query(question=question, user=user, collection=collection) + resp = api.structured_query(question=question, collection=collection) # Check for errors if "error" in resp and resp["error"]: @@ -119,6 +121,17 @@ def main(): default=default_url, help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) parser.add_argument( '-f', '--flow-id', @@ -132,12 +145,6 @@ def main(): help='Natural language question to execute', ) - parser.add_argument( - '--user', - default='trustgraph', - help='Cassandra keyspace identifier (default: trustgraph)' - ) - parser.add_argument( '--collection', default='default', @@ -159,9 +166,12 @@ def main(): url=args.url, flow_id=args.flow_id, question=args.question, - user=args.user, collection=args.collection, output_format=args.format, + token=args.token, + + workspace = args.workspace, + ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/list_collections.py b/trustgraph-cli/trustgraph/cli/list_collections.py index 4086f471..e2f90f56 100644 --- a/trustgraph-cli/trustgraph/cli/list_collections.py +++ b/trustgraph-cli/trustgraph/cli/list_collections.py @@ -1,23 +1,22 @@ """ -List collections for a user +List collections in a workspace """ import argparse import os import tabulate from trustgraph.api import Api -import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = "trustgraph" +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def list_collections(url, user, tag_filter): +def list_collections(url, tag_filter, token=None, workspace="default"): - api = Api(url).collection() + api = Api(url, token=token, workspace=workspace).collection() - collections = api.list_collections(user=user, tag_filter=tag_filter) + collections = api.list_collections(tag_filter=tag_filter) - # Handle None or empty collections if not collections or len(collections) == 0: print("No collections found.") return @@ -54,26 +53,33 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-t', '--tag-filter', action='append', help='Filter by tags (can be specified multiple times)' ) + parser.add_argument( + '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: list_collections( url = args.api_url, - user = args.user, - tag_filter = args.tag_filter + tag_filter = args.tag_filter, + token = args.token, + workspace = args.workspace, ) except Exception as e: @@ -81,4 +87,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/list_config_items.py b/trustgraph-cli/trustgraph/cli/list_config_items.py index 5cd0f233..8bc3f683 100644 --- a/trustgraph-cli/trustgraph/cli/list_config_items.py +++ b/trustgraph-cli/trustgraph/cli/list_config_items.py @@ -9,10 +9,12 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def list_config_items(url, config_type, format_type, token=None): +def list_config_items(url, config_type, format_type, token=None, + workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() keys = api.list(config_type) @@ -54,6 +56,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -63,6 +71,7 @@ def main(): config_type=args.type, format_type=args.format, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/list_explain_traces.py b/trustgraph-cli/trustgraph/cli/list_explain_traces.py index e6d1e075..9bc87db6 100644 --- a/trustgraph-cli/trustgraph/cli/list_explain_traces.py +++ b/trustgraph-cli/trustgraph/cli/list_explain_traces.py @@ -18,7 +18,7 @@ from trustgraph.api import Api, ExplainabilityClient default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' # Retrieval graph @@ -86,9 +86,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})', + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -120,7 +120,7 @@ def main(): args = parser.parse_args() try: - api = Api(args.api_url, token=args.token) + api = Api(args.api_url, token=args.token, workspace=args.workspace) socket = api.socket() flow = socket.flow(args.flow_id) explain_client = ExplainabilityClient(flow) @@ -129,7 +129,6 @@ def main(): # List all sessions — uses persistent websocket via SocketClient questions = explain_client.list_sessions( graph=RETRIEVAL_GRAPH, - user=args.user, collection=args.collection, limit=args.limit, ) @@ -141,7 +140,6 @@ def main(): session_type = explain_client.detect_session_type( q.uri, graph=RETRIEVAL_GRAPH, - user=args.user, collection=args.collection ) diff --git a/trustgraph-cli/trustgraph/cli/load_doc_embeds.py b/trustgraph-cli/trustgraph/cli/load_doc_embeds.py index 20c78515..a776c59b 100644 --- a/trustgraph-cli/trustgraph/cli/load_doc_embeds.py +++ b/trustgraph-cli/trustgraph/cli/load_doc_embeds.py @@ -46,7 +46,6 @@ async def load_de(running, queue, url): "metadata": { "id": msg["m"]["i"], "metadata": msg["m"]["m"], - "user": msg["m"]["u"], "collection": msg["m"]["c"], }, "chunks": [ @@ -77,7 +76,7 @@ async def stats(running): f"Graph embeddings: {de_counts:10d}" ) -async def loader(running, de_queue, path, format, user, collection): +async def loader(running, de_queue, path, format, collection): if format == "json": @@ -96,9 +95,6 @@ async def loader(running, de_queue, path, format, user, collection): except: break - if user: - unpacked["metadata"]["user"] = user - if collection: unpacked["metadata"]["collection"] = collection @@ -148,9 +144,9 @@ async def run(running, **args): running=running, de_queue=de_q, path=args["input_file"], format=args["format"], - user=args["user"], collection=args["collection"], + collection=args["collection"], ) - + ) de_task = asyncio.create_task( @@ -178,7 +174,6 @@ async def main(running): ) default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/") - default_user = "trustgraph" collection = "default" parser.add_argument( @@ -207,11 +202,6 @@ async def main(running): help=f'Output format (default: msgpack)', ) - parser.add_argument( - '--user', - help=f'User ID to load as (default: from input)' - ) - parser.add_argument( '--collection', help=f'Collection ID to load as (default: from input)' diff --git a/trustgraph-cli/trustgraph/cli/load_kg_core.py b/trustgraph-cli/trustgraph/cli/load_kg_core.py index 008b124f..281255be 100644 --- a/trustgraph-cli/trustgraph/cli/load_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/load_kg_core.py @@ -6,20 +6,19 @@ run this utility. import argparse import os -import tabulate from trustgraph.api import Api -import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_flow = "default" default_collection = "default" -def load_kg_core(url, user, id, flow, collection): +def load_kg_core(url, id, flow, collection, token=None, workspace="default"): - api = Api(url).knowledge() + api = Api(url, token=token, workspace=workspace).knowledge() - class_names = api.load_kg_core(user = user, id = id, flow=flow, - collection=collection) + api.load_kg_core(id=id, flow=flow, collection=collection) def main(): @@ -34,12 +33,6 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default="trustgraph", - help='API URL (default: trustgraph)', - ) - parser.add_argument( '--id', '--identifier', required=True, @@ -49,13 +42,25 @@ def main(): parser.add_argument( '-f', '--flow-id', default=default_flow, - help=f'Flow ID (default: {default_flow}', + help=f'Flow ID (default: {default_flow})', ) parser.add_argument( '-C', '--collection', default=default_collection, - help=f'Collection ID (default: {default_collection}', + help=f'Collection ID (default: {default_collection})', + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) args = parser.parse_args() @@ -64,10 +69,11 @@ def main(): load_kg_core( url=args.api_url, - user=args.user, id=args.id, flow=args.flow_id, collection=args.collection, + token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/load_knowledge.py b/trustgraph-cli/trustgraph/cli/load_knowledge.py index 5e96850f..7e9dadd4 100644 --- a/trustgraph-cli/trustgraph/cli/load_knowledge.py +++ b/trustgraph-cli/trustgraph/cli/load_knowledge.py @@ -13,7 +13,7 @@ from trustgraph.log_level import LogLevel default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' class KnowledgeLoader: @@ -22,19 +22,18 @@ class KnowledgeLoader: self, files, flow, - user, collection, document_id, url=default_url, - token=None, + token=None, workspace="default", ): self.files = files self.flow = flow - self.user = user self.collection = collection self.document_id = document_id self.url = url self.token = token + self.workspace = workspace def load_triples_from_file(self, file) -> Iterator[Triple]: """Generator that yields Triple objects from a Turtle file""" @@ -43,11 +42,9 @@ class KnowledgeLoader: g.parse(file, format="turtle") for e in g: - # Extract subject, predicate, object s_value = str(e[0]) p_value = str(e[1]) - # Check if object is a URI or literal if isinstance(e[2], rdflib.term.URIRef): o_value = str(e[2]) o_is_uri = True @@ -55,9 +52,6 @@ class KnowledgeLoader: o_value = str(e[2]) o_is_uri = False - # Create Triple object - # Note: The Triple dataclass has 's', 'p', 'o' fields as strings - # The API will handle the metadata wrapping yield Triple(s=s_value, p=p_value, o=o_value) def load_entity_contexts_from_file(self, file) -> Iterator[Tuple[str, str]]: @@ -67,11 +61,9 @@ class KnowledgeLoader: g.parse(file, format="turtle") for s, p, o in g: - # If object is a URI, skip (we only want literal contexts) if isinstance(o, rdflib.term.URIRef): continue - # If object is a literal, create entity context for subject s_str = str(s) o_str = str(o) @@ -81,11 +73,9 @@ class KnowledgeLoader: """Load triples and entity contexts using Python API""" try: - # Create API client - api = Api(url=self.url, token=self.token) + api = Api(url=self.url, token=self.token, workspace=self.workspace) bulk = api.bulk() - # Load triples from all files print("Loading triples...") total_triples = 0 for file in self.files: @@ -104,7 +94,6 @@ class KnowledgeLoader: metadata={ "id": self.document_id, "metadata": [], - "user": self.user, "collection": self.collection } ) @@ -113,20 +102,16 @@ class KnowledgeLoader: print(f"Triples loaded. Total: {total_triples}") - # Load entity contexts from all files print("Loading entity contexts...") total_contexts = 0 for file in self.files: print(f" Processing {file}...") count = 0 - # Convert tuples to the format expected by import_entity_contexts - # Entity must be in Term format: {"t": "i", "i": uri} for IRI def entity_context_generator(): nonlocal count for entity, context in self.load_entity_contexts_from_file(file): count += 1 - # Entities from RDF are URIs, use IRI term format yield { "entity": {"t": "i", "i": entity}, "context": context @@ -138,7 +123,6 @@ class KnowledgeLoader: metadata={ "id": self.document_id, "metadata": [], - "user": self.user, "collection": self.collection } ) @@ -170,6 +154,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-i', '--document-id', required=True, @@ -182,12 +172,6 @@ def main(): help=f'Flow ID (default: default)' ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -210,8 +194,8 @@ def main(): token=args.token, flow=args.flow_id, files=args.files, - user=args.user, collection=args.collection, + workspace=args.workspace, ) loader.run() diff --git a/trustgraph-cli/trustgraph/cli/load_sample_documents.py b/trustgraph-cli/trustgraph/cli/load_sample_documents.py index 186006a8..0398864c 100644 --- a/trustgraph-cli/trustgraph/cli/load_sample_documents.py +++ b/trustgraph-cli/trustgraph/cli/load_sample_documents.py @@ -12,8 +12,8 @@ from trustgraph.api import Api from trustgraph.api.types import hash, Uri, Literal, Triple default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") from requests.adapters import HTTPAdapter @@ -656,11 +656,10 @@ documents = [ class Loader: def __init__( - self, url, user, token=None + self, url, token=None, workspace="default", ): - self.api = Api(url, token=token).library() - self.user = user + self.api = Api(url, token=token, workspace=workspace).library() def load(self, documents): @@ -689,10 +688,10 @@ class Loader: print(" adding...") self.api.add_document( - id = doc["id"], metadata = doc["metadata"], - user = self.user, kind = doc["kind"], title = doc["title"], - comments = doc["comments"], tags = doc["tags"], - document = content + id=doc["id"], metadata=doc["metadata"], + kind=doc["kind"], title=doc["title"], + comments=doc["comments"], tags=doc["tags"], + document=content, ) print(" successful.") @@ -714,26 +713,26 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-t', '--token', default=default_token, help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: p = Loader( url=args.url, - user=args.user, token=args.token, + workspace=args.workspace, ) p.load(documents) diff --git a/trustgraph-cli/trustgraph/cli/load_structured_data.py b/trustgraph-cli/trustgraph/cli/load_structured_data.py index fa167917..3cd2a229 100644 --- a/trustgraph-cli/trustgraph/cli/load_structured_data.py +++ b/trustgraph-cli/trustgraph/cli/load_structured_data.py @@ -23,6 +23,7 @@ logger = logging.getLogger(__name__) default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def load_structured_data( @@ -39,11 +40,11 @@ def load_structured_data( sample_chars: int = 500, schema_name: str = None, flow: str = 'default', - user: str = 'trustgraph', collection: str = 'default', dry_run: bool = False, verbose: bool = False, - token: str = None + token: str = None, + workspace: str = "default", ): """ Load structured data using a descriptor configuration. @@ -62,7 +63,6 @@ def load_structured_data( sample_chars: Maximum characters to read for sampling schema_name: Target schema name for generation flow: TrustGraph flow name to use for prompts - user: User name for metadata (default: trustgraph) collection: Collection name for metadata (default: default) dry_run: If True, validate but don't import data verbose: Enable verbose logging @@ -78,7 +78,7 @@ def load_structured_data( logger.info("Step 1: Analyzing data to discover best matching schema...") # Step 1: Auto-discover schema (reuse discover_schema logic) - discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger) + discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace) if not discovered_schema: logger.error("Failed to discover suitable schema automatically") print("❌ Could not automatically determine the best schema for your data.") @@ -90,7 +90,7 @@ def load_structured_data( # Step 2: Auto-generate descriptor logger.info("Step 2: Generating descriptor configuration...") - auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger) + auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, workspace=workspace) if not auto_descriptor: logger.error("Failed to generate descriptor automatically") print("❌ Could not automatically generate descriptor configuration.") @@ -110,7 +110,7 @@ def load_structured_data( try: # Use shared pipeline for preview (small sample) - preview_objects, _ = _process_data_pipeline(input_file, temp_descriptor.name, user, collection, sample_size=5) + preview_objects, _ = _process_data_pipeline(input_file, temp_descriptor.name, collection, sample_size=5) # Show preview print("📊 Data Preview (first few records):") @@ -131,13 +131,13 @@ def load_structured_data( print("🚀 Importing data to TrustGraph...") # Use shared pipeline for full processing (no sample limit) - output_objects, descriptor = _process_data_pipeline(input_file, temp_descriptor.name, user, collection) + output_objects, descriptor = _process_data_pipeline(input_file, temp_descriptor.name, collection) # Get batch size from descriptor batch_size = descriptor.get('output', {}).get('options', {}).get('batch_size', 1000) # Send to TrustGraph using shared function - imported_count = _send_to_trustgraph(output_objects, api_url, flow, batch_size, token=token) + imported_count = _send_to_trustgraph(output_objects, api_url, flow, batch_size, token=token, workspace=workspace) # Summary format_info = descriptor.get('format', {}) @@ -172,7 +172,7 @@ def load_structured_data( logger.info(f"Sample chars: {sample_chars} characters") # Use the helper function to discover schema (get raw response for display) - response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True) + response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, workspace=workspace) if response: # Debug: print response type and content @@ -203,7 +203,7 @@ def load_structured_data( # If no schema specified, discover it first if not schema_name: logger.info("No schema specified, auto-discovering...") - schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger) + schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace) if not schema_name: print("Error: Could not determine schema automatically.") print("Please specify a schema using --schema-name or run --discover-schema first.") @@ -213,7 +213,7 @@ def load_structured_data( logger.info(f"Target schema: {schema_name}") # Generate descriptor using helper function - descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger) + descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace=workspace) if descriptor: # Output the generated descriptor @@ -242,7 +242,7 @@ def load_structured_data( logger.info(f"Parsing {input_file} with descriptor {descriptor_file}...") # Use shared pipeline - output_records, descriptor = _process_data_pipeline(input_file, descriptor_file, user, collection, sample_size) + output_records, descriptor = _process_data_pipeline(input_file, descriptor_file, collection, sample_size) # Output results if output_file: @@ -286,7 +286,7 @@ def load_structured_data( logger.info(f"Loading {input_file} to TrustGraph using descriptor {descriptor_file}...") # Use shared pipeline (no sample_size limit for full load) - output_records, descriptor = _process_data_pipeline(input_file, descriptor_file, user, collection) + output_records, descriptor = _process_data_pipeline(input_file, descriptor_file, collection) # Get batch size from descriptor or use default batch_size = descriptor.get('output', {}).get('options', {}).get('batch_size', 1000) @@ -527,18 +527,17 @@ def _apply_transformations(records, mappings): return processed_records -def _format_extracted_objects(processed_records, descriptor, user, collection): +def _format_extracted_objects(processed_records, descriptor, collection): """Convert to TrustGraph ExtractedObject format""" output_records = [] schema_name = descriptor.get('output', {}).get('schema_name', 'default') confidence = descriptor.get('output', {}).get('options', {}).get('confidence', 0.9) - + for record in processed_records: output_record = { "metadata": { "id": f"parsed-{len(output_records)+1}", "metadata": [], # Empty metadata triples - "user": user, "collection": collection }, "schema_name": schema_name, @@ -551,7 +550,7 @@ def _format_extracted_objects(processed_records, descriptor, user, collection): return output_records -def _process_data_pipeline(input_file, descriptor_file, user, collection, sample_size=None): +def _process_data_pipeline(input_file, descriptor_file, collection, sample_size=None): """Shared pipeline: load descriptor → read → parse → transform → format""" # Load descriptor configuration descriptor = _load_descriptor(descriptor_file) @@ -568,12 +567,12 @@ def _process_data_pipeline(input_file, descriptor_file, user, collection, sample processed_records = _apply_transformations(parsed_records, mappings) # Format output for TrustGraph ExtractedObject structure - output_records = _format_extracted_objects(processed_records, descriptor, user, collection) + output_records = _format_extracted_objects(processed_records, descriptor, collection) return output_records, descriptor -def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None): +def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None, workspace="default"): """Send ExtractedObject records to TrustGraph using Python API""" from trustgraph.api import Api @@ -582,7 +581,7 @@ def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None): logger.info(f"Importing {total_records} records to TrustGraph...") # Use Python API bulk import - api = Api(api_url, token=token) + api = Api(api_url, token=token, workspace=workspace) bulk = api.bulk() bulk.import_rows(flow=flow, rows=iter(rows)) @@ -604,7 +603,7 @@ def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None): # Helper functions for auto mode -def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False): +def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, workspace="default"): """Auto-discover the best matching schema for the input data Args: @@ -627,7 +626,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur # Import API modules from trustgraph.api import Api from trustgraph.api.types import ConfigKey - api = Api(api_url) + api = Api(api_url, workspace=workspace) config_api = api.config() # Get available schemas @@ -708,7 +707,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur return None -def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger): +def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace="default"): """Auto-generate descriptor configuration for the discovered schema""" try: # Read sample data @@ -718,7 +717,7 @@ def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, fl # Import API modules from trustgraph.api import Api from trustgraph.api.types import ConfigKey - api = Api(api_url) + api = Api(api_url, workspace=workspace) config_api = api.config() # Get schema definition @@ -885,12 +884,6 @@ For more information on the descriptor format, see: help='TrustGraph flow name to use for prompts and import (default: default)' ) - parser.add_argument( - '--user', - default='trustgraph', - help='User name for metadata (default: trustgraph)' - ) - parser.add_argument( '--collection', default='default', @@ -997,6 +990,12 @@ For more information on the descriptor format, see: help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() # Input validation @@ -1046,11 +1045,11 @@ For more information on the descriptor format, see: sample_chars=args.sample_chars, schema_name=args.schema_name, flow=args.flow, - user=args.user, collection=args.collection, dry_run=args.dry_run, verbose=args.verbose, - token=args.token + token=args.token, + workspace=args.workspace, ) except FileNotFoundError as e: print(f"Error: File not found - {e}", file=sys.stderr) diff --git a/trustgraph-cli/trustgraph/cli/load_turtle.py b/trustgraph-cli/trustgraph/cli/load_turtle.py index adb578f5..43ef9e6f 100644 --- a/trustgraph-cli/trustgraph/cli/load_turtle.py +++ b/trustgraph-cli/trustgraph/cli/load_turtle.py @@ -13,7 +13,7 @@ from trustgraph.log_level import LogLevel default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' class Loader: @@ -22,15 +22,14 @@ class Loader: self, files, flow, - user, collection, document_id, url=default_url, - token=None, + token=None, workspace="default", ): self.files = files self.flow = flow - self.user = user + self.workspace = workspace self.collection = collection self.document_id = document_id self.url = url @@ -43,28 +42,23 @@ class Loader: g.parse(file, format="turtle") for e in g: - # Extract subject, predicate, object s_value = str(e[0]) p_value = str(e[1]) - # Check if object is a URI or literal if isinstance(e[2], rdflib.term.URIRef): o_value = str(e[2]) else: o_value = str(e[2]) - # Create Triple object yield Triple(s=s_value, p=p_value, o=o_value) def run(self): """Load triples using Python API""" try: - # Create API client - api = Api(url=self.url, token=self.token) + api = Api(url=self.url, token=self.token, workspace=self.workspace) bulk = api.bulk() - # Load triples from all files print("Loading triples...") for file in self.files: print(f" Processing {file}...") @@ -76,7 +70,6 @@ class Loader: metadata={ "id": self.document_id, "metadata": [], - "user": self.user, "collection": self.collection } ) @@ -106,6 +99,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-i', '--document-id', required=True, @@ -118,12 +117,6 @@ def main(): help=f'Flow ID (default: default)' ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -146,8 +139,8 @@ def main(): token=args.token, flow=args.flow_id, files=args.files, - user=args.user, collection=args.collection, + workspace=args.workspace, ) loader.run() diff --git a/trustgraph-cli/trustgraph/cli/put_config_item.py b/trustgraph-cli/trustgraph/cli/put_config_item.py index d79864a4..fda9cbeb 100644 --- a/trustgraph-cli/trustgraph/cli/put_config_item.py +++ b/trustgraph-cli/trustgraph/cli/put_config_item.py @@ -10,10 +10,12 @@ from trustgraph.api.types import ConfigValue default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def put_config_item(url, config_type, key, value, token=None): +def put_config_item(url, config_type, key, value, token=None, + workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() config_value = ConfigValue(type=config_type, key=key, value=value) api.put([config_value]) @@ -63,6 +65,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -78,6 +86,7 @@ def main(): key=args.key, value=value, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/put_flow_blueprint.py b/trustgraph-cli/trustgraph/cli/put_flow_blueprint.py index 740a224a..96db6bec 100644 --- a/trustgraph-cli/trustgraph/cli/put_flow_blueprint.py +++ b/trustgraph-cli/trustgraph/cli/put_flow_blueprint.py @@ -10,10 +10,12 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def put_flow_blueprint(url, blueprint_name, config, token=None): +def put_flow_blueprint(url, blueprint_name, config, token=None, + workspace="default"): - api = Api(url, token=token) + api = Api(url, token=token, workspace=workspace) blueprint_names = api.flow().put_blueprint(blueprint_name, config) @@ -36,6 +38,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-n', '--blueprint-name', help=f'Flow blueprint name', @@ -55,6 +63,7 @@ def main(): blueprint_name=args.blueprint_name, config=json.loads(args.config), token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/put_kg_core.py b/trustgraph-cli/trustgraph/cli/put_kg_core.py index cd0738fe..bd3169c8 100644 --- a/trustgraph-cli/trustgraph/cli/put_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/put_kg_core.py @@ -1,10 +1,9 @@ """ -Uses the agent service to answer a question +Puts a knowledge core into the knowledge manager via the API socket. """ import argparse import os -import textwrap import uuid import asyncio import json @@ -12,18 +11,17 @@ from websockets.asyncio.client import connect import msgpack default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') -default_user = 'trustgraph' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") + +def read_message(unpacked, id): -def read_message(unpacked, id, user): - if unpacked[0] == "ge": msg = unpacked[1] return "ge", { "metadata": { "id": id, "metadata": msg["m"]["m"], - "user": user, "collection": "default", # Not used? }, "entities": [ @@ -40,7 +38,6 @@ def read_message(unpacked, id, user): "metadata": { "id": id, "metadata": msg["m"]["m"], - "user": user, "collection": "default", # Not used by receiver? }, "triples": msg["t"], @@ -48,7 +45,7 @@ def read_message(unpacked, id, user): else: raise RuntimeError("Unpacked unexpected messsage type", unpacked[0]) -async def put(url, user, id, input, token=None): +async def put(url, workspace, id, input, token=None): if not url.endswith("/"): url += "/" @@ -60,7 +57,6 @@ async def put(url, user, id, input, token=None): async with connect(url) as ws: - ge = 0 t = 0 @@ -75,7 +71,7 @@ async def put(url, user, id, input, token=None): except: break - kind, msg = read_message(unpacked, id, user) + kind, msg = read_message(unpacked, id) mid = str(uuid.uuid4()) @@ -85,10 +81,11 @@ async def put(url, user, id, input, token=None): req = json.dumps({ "id": mid, + "workspace": workspace, "service": "knowledge", "request": { "operation": "put-kg-core", - "user": user, + "workspace": workspace, "id": id, "graph-embeddings": msg } @@ -100,10 +97,11 @@ async def put(url, user, id, input, token=None): req = json.dumps({ "id": mid, + "workspace": workspace, "service": "knowledge", "request": { "operation": "put-kg-core", - "user": user, + "workspace": workspace, "id": id, "triples": msg } @@ -117,7 +115,7 @@ async def put(url, user, id, input, token=None): # Retry loop, wait for right response to come back while True: - + msg = await ws.recv() msg = json.loads(msg) @@ -146,10 +144,11 @@ def main(): default=default_url, help=f'API URL (default: {default_url})', ) + parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -176,11 +175,11 @@ def main(): asyncio.run( put( - url = args.url, - user = args.user, - id = args.id, - input = args.input, - token = args.token, + url=args.url, + workspace=args.workspace, + id=args.id, + input=args.input, + token=args.token, ) ) @@ -189,4 +188,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/query_graph.py b/trustgraph-cli/trustgraph/cli/query_graph.py index a2c38353..091f0599 100644 --- a/trustgraph-cli/trustgraph/cli/query_graph.py +++ b/trustgraph-cli/trustgraph/cli/query_graph.py @@ -23,9 +23,9 @@ import sys from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' default_collection = 'default' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def parse_inline_quoted_triple(value): @@ -285,15 +285,16 @@ def output_jsonl(triples): def query_graph( - url, flow_id, user, collection, limit, batch_size, + url, flow_id, collection, limit, batch_size, subject=None, predicate=None, obj=None, graph=None, - output_format="space", headers=False, token=None + output_format="space", headers=False, token=None, + workspace="default", ): """Query the triple store with pattern matching. Uses the API's triples_query_stream for efficient streaming delivery. """ - socket = Api(url, token=token).socket() + socket = Api(url, token=token, workspace=workspace).socket() flow = socket.flow(flow_id) all_triples = [] @@ -305,7 +306,6 @@ def query_graph( p=predicate, o=obj, g=graph, - user=user, collection=collection, limit=limit, batch_size=batch_size, @@ -456,13 +456,6 @@ def main(): help='Flow ID (default: default)' ) - std_group.add_argument( - '-U', '--user', - default=default_user, - metavar='USER', - help=f'User/keyspace (default: {default_user})' - ) - std_group.add_argument( '-C', '--collection', default=default_collection, @@ -477,6 +470,12 @@ def main(): help='Auth token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + std_group.add_argument( '-l', '--limit', type=int, @@ -550,7 +549,6 @@ def main(): query_graph( url=args.api_url, flow_id=args.flow_id, - user=args.user, collection=args.collection, limit=args.limit, batch_size=args.batch_size, @@ -561,6 +559,8 @@ def main(): output_format=args.format, headers=args.headers, token=args.token, + + workspace=args.workspace, ) except json.JSONDecodeError as e: diff --git a/trustgraph-cli/trustgraph/cli/remove_library_document.py b/trustgraph-cli/trustgraph/cli/remove_library_document.py index 07a1fd59..d6500d50 100644 --- a/trustgraph-cli/trustgraph/cli/remove_library_document.py +++ b/trustgraph-cli/trustgraph/cli/remove_library_document.py @@ -4,20 +4,19 @@ Remove a document from the library import argparse import os -import uuid from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def remove_doc(url, user, id, token=None): +def remove_doc(url, id, token=None, workspace="default"): - api = Api(url, token=token).library() + api = Api(url, token=token, workspace=workspace).library() - api.remove_document(user=user, id=id) + api.remove_document(id=id) def main(): @@ -32,12 +31,6 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '--identifier', '--id', required=True, @@ -50,15 +43,24 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: - remove_doc(args.url, args.user, args.identifier, token=args.token) + remove_doc( + args.url, args.identifier, + token=args.token, workspace=args.workspace, + ) except Exception as e: print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/save_doc_embeds.py b/trustgraph-cli/trustgraph/cli/save_doc_embeds.py index ca8d25de..99d6b4db 100644 --- a/trustgraph-cli/trustgraph/cli/save_doc_embeds.py +++ b/trustgraph-cli/trustgraph/cli/save_doc_embeds.py @@ -21,7 +21,7 @@ class Running: def get(self): return self.running def stop(self): self.running = False -async def fetch_de(running, queue, user, collection, url): +async def fetch_de(running, queue, collection, url): async with aiohttp.ClientSession() as session: @@ -38,10 +38,6 @@ async def fetch_de(running, queue, user, collection, url): data = msg.json() - if user: - if data["metadata"]["user"] != user: - continue - if collection: if data["metadata"]["collection"] != collection: continue @@ -52,7 +48,6 @@ async def fetch_de(running, queue, user, collection, url): "m": { "i": data["metadata"]["id"], "m": data["metadata"]["metadata"], - "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, "c": [ @@ -119,7 +114,7 @@ async def run(running, **args): de_task = asyncio.create_task( fetch_de( running=running, - queue=q, user=args["user"], collection=args["collection"], + queue=q, collection=args["collection"], url = f"{url}api/v1/flow/{flow_id}/export/document-embeddings" ) ) @@ -148,7 +143,6 @@ async def main(running): ) default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/") - default_user = "trustgraph" collection = "default" parser.add_argument( @@ -177,11 +171,6 @@ async def main(running): help=f'Output format (default: msgpack)', ) - parser.add_argument( - '--user', - help=f'User ID to filter on (default: no filter)' - ) - parser.add_argument( '--collection', help=f'Collection ID to filter on (default: no filter)' diff --git a/trustgraph-cli/trustgraph/cli/set_collection.py b/trustgraph-cli/trustgraph/cli/set_collection.py index dd4148ea..53aaa74d 100644 --- a/trustgraph-cli/trustgraph/cli/set_collection.py +++ b/trustgraph-cli/trustgraph/cli/set_collection.py @@ -8,15 +8,14 @@ import tabulate from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = "trustgraph" default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def set_collection(url, user, collection, name, description, tags, token=None): +def set_collection(url, collection, name, description, tags, token=None, workspace="default"): - api = Api(url, token=token).collection() + api = Api(url, token=token, workspace=workspace).collection() result = api.update_collection( - user=user, collection=collection, name=name, description=description, @@ -59,12 +58,6 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-n', '--name', help='Collection name' @@ -88,18 +81,24 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: set_collection( url = args.api_url, - user = args.user, collection = args.collection, name = args.name, description = args.description, tags = args.tags, - token = args.token + token = args.token, + workspace=args.workspace, ) except Exception as e: @@ -107,4 +106,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/set_mcp_tool.py b/trustgraph-cli/trustgraph/cli/set_mcp_tool.py index 7976adbc..65c640c6 100644 --- a/trustgraph-cli/trustgraph/cli/set_mcp_tool.py +++ b/trustgraph-cli/trustgraph/cli/set_mcp_tool.py @@ -21,6 +21,7 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def set_mcp_tool( url : str, @@ -29,9 +30,10 @@ def set_mcp_tool( tool_url : str, auth_token : str = None, token : str = None, + workspace : str = "default", ): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() # Build the MCP tool configuration config = { @@ -80,6 +82,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-i', '--id', required=True, @@ -126,6 +134,8 @@ def main(): tool_url=args.tool_url, auth_token=args.auth_token, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/set_prompt.py b/trustgraph-cli/trustgraph/cli/set_prompt.py index bffc2cf2..dbf9c326 100644 --- a/trustgraph-cli/trustgraph/cli/set_prompt.py +++ b/trustgraph-cli/trustgraph/cli/set_prompt.py @@ -11,10 +11,11 @@ import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def set_system(url, system, token=None): +def set_system(url, system, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() api.put([ ConfigValue(type="prompt", key="system", value=json.dumps(system)) @@ -22,9 +23,9 @@ def set_system(url, system, token=None): print("System prompt set.") -def set_prompt(url, id, prompt, response, schema, token=None): +def set_prompt(url, id, prompt, response, schema, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() values = api.get([ ConfigKey(type="prompt", key="template-index") @@ -78,6 +79,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '--id', help=f'Prompt ID', diff --git a/trustgraph-cli/trustgraph/cli/set_token_costs.py b/trustgraph-cli/trustgraph/cli/set_token_costs.py index 19b8c703..9b046a7d 100644 --- a/trustgraph-cli/trustgraph/cli/set_token_costs.py +++ b/trustgraph-cli/trustgraph/cli/set_token_costs.py @@ -11,10 +11,11 @@ import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def set_costs(api_url, model, input_costs, output_costs, token=None): +def set_costs(api_url, model, input_costs, output_costs, token=None, workspace="default"): - api = Api(api_url, token=token).config() + api = Api(api_url, token=token, workspace=workspace).config() api.put([ ConfigValue( @@ -26,9 +27,9 @@ def set_costs(api_url, model, input_costs, output_costs, token=None): ), ]) -def set_prompt(url, id, prompt, response, schema): +def set_prompt(url, id, prompt, response, schema, workspace="default"): - api = Api(url) + api = Api(url, workspace=workspace) values = api.config_get([ ConfigKey(type="prompt", key="template-index") @@ -102,6 +103,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: diff --git a/trustgraph-cli/trustgraph/cli/set_tool.py b/trustgraph-cli/trustgraph/cli/set_tool.py index c6412e48..45295089 100644 --- a/trustgraph-cli/trustgraph/cli/set_tool.py +++ b/trustgraph-cli/trustgraph/cli/set_tool.py @@ -28,6 +28,7 @@ import dataclasses default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") @dataclasses.dataclass class Argument: @@ -73,9 +74,10 @@ def set_tool( state : str, applicable_states : List[str], token : str = None, + workspace : str = "default", ): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() values = api.get([ ConfigKey(type="agent", key="tool-index") @@ -181,6 +183,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '--id', help=f'Unique tool identifier', @@ -303,6 +311,8 @@ def main(): state=args.state, applicable_states=args.applicable_states, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_config.py b/trustgraph-cli/trustgraph/cli/show_config.py index 6f426533..130c59b7 100644 --- a/trustgraph-cli/trustgraph/cli/show_config.py +++ b/trustgraph-cli/trustgraph/cli/show_config.py @@ -9,10 +9,11 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_config(url, token=None): +def show_config(url, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() config, version = api.all() @@ -38,6 +39,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -45,6 +52,7 @@ def main(): show_config( url=args.api_url, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_explain_trace.py b/trustgraph-cli/trustgraph/cli/show_explain_trace.py index 90c0e452..17aaca1a 100644 --- a/trustgraph-cli/trustgraph/cli/show_explain_trace.py +++ b/trustgraph-cli/trustgraph/cli/show_explain_trace.py @@ -36,7 +36,7 @@ from trustgraph.api import ( default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' # Graphs @@ -50,13 +50,12 @@ PROV = "http://www.w3.org/ns/prov#" PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom" -def trace_edge_provenance(flow, user, collection, edge, label_cache, explain_client): +def trace_edge_provenance(flow, collection, edge, label_cache, explain_client): """ Trace an edge back to its source document via reification. Args: flow: SocketFlowInstance - user: User identifier collection: Collection identifier edge: Dict with s, p, o keys label_cache: Dict for caching labels @@ -90,7 +89,6 @@ def trace_edge_provenance(flow, user, collection, edge, label_cache, explain_cli p=TG_CONTAINS, o=quoted_triple, g=SOURCE_GRAPH, - user=user, collection=collection, limit=10 ) @@ -108,14 +106,14 @@ def trace_edge_provenance(flow, user, collection, edge, label_cache, explain_cli # For each statement, trace wasDerivedFrom chain provenance_chains = [] for stmt_uri in stmt_uris: - chain = trace_provenance_chain(flow, user, collection, stmt_uri, label_cache, explain_client) + chain = trace_provenance_chain(flow, collection, stmt_uri, label_cache, explain_client) if chain: provenance_chains.append(chain) return provenance_chains -def trace_provenance_chain(flow, user, collection, start_uri, label_cache, explain_client, max_depth=10): +def trace_provenance_chain(flow, collection, start_uri, label_cache, explain_client, max_depth=10): """Trace prov:wasDerivedFrom chain from start_uri to root.""" chain = [] current = start_uri @@ -128,7 +126,7 @@ def trace_provenance_chain(flow, user, collection, start_uri, label_cache, expla if current in label_cache: label = label_cache[current] else: - label = explain_client.resolve_label(current, user, collection) + label = explain_client.resolve_label(current, collection) label_cache[current] = label chain.append({"uri": current, "label": label}) @@ -139,7 +137,6 @@ def trace_provenance_chain(flow, user, collection, start_uri, label_cache, expla s=current, p=PROV_WAS_DERIVED_FROM, g=SOURCE_GRAPH, - user=user, collection=collection, limit=1 ) @@ -167,7 +164,7 @@ def format_provenance_chain(chain): return " -> ".join(labels) -def print_graphrag_text(trace, explain_client, flow, user, collection, api=None, show_provenance=False): +def print_graphrag_text(trace, explain_client, flow, collection, api=None, show_provenance=False): """Print GraphRAG trace in text format.""" question = trace.get("question") @@ -202,7 +199,7 @@ def print_graphrag_text(trace, explain_client, flow, user, collection, api=None, for i, edge_sel in enumerate(edges, 1): if edge_sel.edge: s_label, p_label, o_label = explain_client.resolve_edge_labels( - edge_sel.edge, user, collection + edge_sel.edge, collection ) print(f" {i}. ({s_label}, {p_label}, {o_label})") @@ -212,7 +209,7 @@ def print_graphrag_text(trace, explain_client, flow, user, collection, api=None, if show_provenance and edge_sel.edge: provenance = trace_edge_provenance( - flow, user, collection, edge_sel.edge, + flow, collection, edge_sel.edge, label_cache, explain_client ) for chain in provenance: @@ -238,7 +235,7 @@ def print_graphrag_text(trace, explain_client, flow, user, collection, api=None, content = "" if synthesis.document and api: content = explain_client.fetch_document_content( - synthesis.document, api, user + synthesis.document, api ) if content: print("Answer:") @@ -252,7 +249,7 @@ def print_graphrag_text(trace, explain_client, flow, user, collection, api=None, print("No synthesis data found") -def print_docrag_text(trace, explain_client, api, user): +def print_docrag_text(trace, explain_client, api): """Print DocRAG trace in text format.""" question = trace.get("question") @@ -288,7 +285,7 @@ def print_docrag_text(trace, explain_client, api, user): content = "" if synthesis.document and api: content = explain_client.fetch_document_content( - synthesis.document, api, user + synthesis.document, api ) if content: print("Answer:") @@ -302,14 +299,14 @@ def print_docrag_text(trace, explain_client, api, user): print("No synthesis data found") -def _print_document_content(explain_client, api, user, document_uri, label="Answer"): +def _print_document_content(explain_client, api, document_uri, label="Answer"): """Fetch and print document content, or fall back to URI.""" if not document_uri: return content = "" if api: content = explain_client.fetch_document_content( - document_uri, api, user + document_uri, api ) if content: print(f"{label}:") @@ -319,7 +316,7 @@ def _print_document_content(explain_client, api, user, document_uri, label="Answ print(f"Document: {document_uri}") -def print_agent_text(trace, explain_client, api, user): +def print_agent_text(trace, explain_client, api): """Print Agent trace in text format.""" question = trace.get("question") @@ -348,7 +345,7 @@ def print_agent_text(trace, explain_client, api, user): print("--- Finding ---") print(f"Goal: {step.goal}") _print_document_content( - explain_client, api, user, step.document, "Result", + explain_client, api, step.document, "Result", ) print() @@ -363,7 +360,7 @@ def print_agent_text(trace, explain_client, api, user): print("--- Step Result ---") print(f"Step: {step.step}") _print_document_content( - explain_client, api, user, step.document, "Result", + explain_client, api, step.document, "Result", ) print() @@ -385,21 +382,21 @@ def print_agent_text(trace, explain_client, api, user): elif isinstance(step, Observation): print("--- Observation ---") _print_document_content( - explain_client, api, user, step.document, "Content", + explain_client, api, step.document, "Content", ) print() elif isinstance(step, Synthesis): print("--- Synthesis ---") _print_document_content( - explain_client, api, user, step.document, "Answer", + explain_client, api, step.document, "Answer", ) print() elif isinstance(step, Conclusion): print("--- Conclusion ---") _print_document_content( - explain_client, api, user, step.document, "Answer", + explain_client, api, step.document, "Answer", ) print() @@ -559,9 +556,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})', + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -599,7 +596,7 @@ def main(): args = parser.parse_args() try: - api = Api(args.api_url, token=args.token) + api = Api(args.api_url, token=args.token, workspace=args.workspace) socket = api.socket() flow = socket.flow(args.flow_id) explain_client = ExplainabilityClient(flow) @@ -609,7 +606,6 @@ def main(): trace_type = explain_client.detect_session_type( args.question_id, graph=RETRIEVAL_GRAPH, - user=args.user, collection=args.collection, ) @@ -618,7 +614,6 @@ def main(): trace = explain_client.fetch_agent_trace( args.question_id, graph=RETRIEVAL_GRAPH, - user=args.user, collection=args.collection, api=api, max_content=args.max_answer, @@ -627,14 +622,13 @@ def main(): if args.format == 'json': print(json.dumps(trace_to_dict(trace, "agent"), indent=2)) else: - print_agent_text(trace, explain_client, api, args.user) + print_agent_text(trace, explain_client, api) elif trace_type == "docrag": # Fetch and display DocRAG trace trace = explain_client.fetch_docrag_trace( args.question_id, graph=RETRIEVAL_GRAPH, - user=args.user, collection=args.collection, api=api, max_content=args.max_answer, @@ -643,14 +637,13 @@ def main(): if args.format == 'json': print(json.dumps(trace_to_dict(trace, "docrag"), indent=2)) else: - print_docrag_text(trace, explain_client, api, args.user) + print_docrag_text(trace, explain_client, api) else: # Fetch and display GraphRAG trace trace = explain_client.fetch_graphrag_trace( args.question_id, graph=RETRIEVAL_GRAPH, - user=args.user, collection=args.collection, api=api, max_content=args.max_answer, @@ -661,7 +654,7 @@ def main(): else: print_graphrag_text( trace, explain_client, flow, - args.user, args.collection, + args.collection, api=api, show_provenance=args.show_provenance ) diff --git a/trustgraph-cli/trustgraph/cli/show_extraction_provenance.py b/trustgraph-cli/trustgraph/cli/show_extraction_provenance.py index 4f87712c..49bf78ee 100644 --- a/trustgraph-cli/trustgraph/cli/show_extraction_provenance.py +++ b/trustgraph-cli/trustgraph/cli/show_extraction_provenance.py @@ -17,7 +17,7 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' # Predicates @@ -45,10 +45,9 @@ TYPE_MAP = { SOURCE_GRAPH = "urn:graph:source" -def query_triples(socket, flow_id, user, collection, s=None, p=None, o=None, g=None, limit=1000): +def query_triples(socket, flow_id, collection, s=None, p=None, o=None, g=None, limit=1000): """Query triples using the socket API.""" request = { - "user": user, "collection": collection, "limit": limit, "streaming": False, @@ -120,9 +119,9 @@ def extract_value(term): return str(term) -def get_node_metadata(socket, flow_id, user, collection, node_uri): +def get_node_metadata(socket, flow_id, collection, node_uri): """Get metadata for a node (label, types, title, format).""" - triples = query_triples(socket, flow_id, user, collection, s=node_uri, g=SOURCE_GRAPH) + triples = query_triples(socket, flow_id, collection, s=node_uri, g=SOURCE_GRAPH) metadata = {"uri": node_uri, "types": []} for s, p, o in triples: @@ -146,20 +145,20 @@ def classify_node(metadata): return "unknown" -def get_children(socket, flow_id, user, collection, parent_uri): +def get_children(socket, flow_id, collection, parent_uri): """Get children of a node via prov:wasDerivedFrom.""" triples = query_triples( - socket, flow_id, user, collection, + socket, flow_id, collection, p=PROV_WAS_DERIVED_FROM, o=parent_uri, g=SOURCE_GRAPH ) return [s for s, p, o in triples] -def get_document_content(api, user, doc_id, max_content): +def get_document_content(api, doc_id, max_content): """Fetch document content from librarian API.""" try: library = api.library() - content = library.get_document_content(user=user, id=doc_id) + content = library.get_document_content(id=doc_id) # Try to decode as text try: @@ -173,7 +172,7 @@ def get_document_content(api, user, doc_id, max_content): return f"[Error fetching content: {e}]" -def build_hierarchy(socket, flow_id, user, collection, root_uri, api=None, show_content=False, max_content=200, visited=None): +def build_hierarchy(socket, flow_id, collection, root_uri, api=None, show_content=False, max_content=200, visited=None): """Build document hierarchy tree recursively.""" if visited is None: visited = set() @@ -182,7 +181,7 @@ def build_hierarchy(socket, flow_id, user, collection, root_uri, api=None, show_ return None visited.add(root_uri) - metadata = get_node_metadata(socket, flow_id, user, collection, root_uri) + metadata = get_node_metadata(socket, flow_id, collection, root_uri) node_type = classify_node(metadata) node = { @@ -195,21 +194,21 @@ def build_hierarchy(socket, flow_id, user, collection, root_uri, api=None, show_ # Fetch content if requested if show_content and api: - content = get_document_content(api, user, root_uri, max_content) + content = get_document_content(api, root_uri, max_content) if content: node["content"] = content # Get children - children_uris = get_children(socket, flow_id, user, collection, root_uri) + children_uris = get_children(socket, flow_id, collection, root_uri) for child_uri in children_uris: - child_metadata = get_node_metadata(socket, flow_id, user, collection, child_uri) + child_metadata = get_node_metadata(socket, flow_id, collection, child_uri) child_type = classify_node(child_metadata) if child_type == "subgraph": # Subgraphs contain extracted edges — inline them contains_triples = query_triples( - socket, flow_id, user, collection, + socket, flow_id, collection, s=child_uri, p=TG_CONTAINS, g=SOURCE_GRAPH ) for _, _, edge in contains_triples: @@ -218,7 +217,7 @@ def build_hierarchy(socket, flow_id, user, collection, root_uri, api=None, show_ else: # Recurse into pages, chunks, etc. child_node = build_hierarchy( - socket, flow_id, user, collection, child_uri, + socket, flow_id, collection, child_uri, api=api, show_content=show_content, max_content=max_content, visited=visited ) @@ -331,9 +330,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})', + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -371,14 +370,13 @@ def main(): args = parser.parse_args() try: - api = Api(args.api_url, token=args.token) + api = Api(args.api_url, token=args.token, workspace=args.workspace) socket = api.socket() try: hierarchy = build_hierarchy( socket=socket, flow_id=args.flow_id, - user=args.user, collection=args.collection, root_uri=args.document_id, api=api if args.show_content else None, diff --git a/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py b/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py index 8d16d098..4924c925 100644 --- a/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py +++ b/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py @@ -11,6 +11,7 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def format_parameters(params_metadata, param_type_defs): """ @@ -44,12 +45,13 @@ def format_parameters(params_metadata, param_type_defs): return "\n".join(param_list) -async def fetch_data(client): +async def fetch_data(client, workspace): """Fetch all data needed for show_flow_blueprints concurrently.""" # Round 1: list blueprints resp = await client._send_request("flow", None, { "operation": "list-blueprints", + "workspace": workspace, }) blueprint_names = resp.get("blueprint-names", []) @@ -60,6 +62,7 @@ async def fetch_data(client): blueprint_tasks = [ client._send_request("flow", None, { "operation": "get-blueprint", + "workspace": workspace, "blueprint-name": name, }) for name in blueprint_names @@ -84,6 +87,7 @@ async def fetch_data(client): param_type_tasks = [ client._send_request("config", None, { "operation": "get", + "workspace": workspace, "keys": [{"type": "parameter-type", "key": pt}], }) for pt in param_types_needed @@ -100,14 +104,16 @@ async def fetch_data(client): return blueprint_names, blueprints, param_type_defs -async def _show_flow_blueprints_async(url, token=None): +async def _show_flow_blueprints_async(url, token=None, workspace="default"): async with AsyncSocketClient(url, timeout=60, token=token) as client: - return await fetch_data(client) + return await fetch_data(client, workspace) -def show_flow_blueprints(url, token=None): +def show_flow_blueprints(url, token=None, workspace="default"): blueprint_names, blueprints, param_type_defs = asyncio.run( - _show_flow_blueprints_async(url, token=token) + _show_flow_blueprints_async( + url, token=token, workspace=workspace, + ) ) if not blueprint_names: @@ -156,6 +162,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -163,6 +175,7 @@ def main(): show_flow_blueprints( url=args.api_url, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_flow_state.py b/trustgraph-cli/trustgraph/cli/show_flow_state.py index d5d87f2c..8fec04ec 100644 --- a/trustgraph-cli/trustgraph/cli/show_flow_state.py +++ b/trustgraph-cli/trustgraph/cli/show_flow_state.py @@ -10,10 +10,12 @@ import os default_metrics_url = "http://localhost:8088/api/metrics" default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def dump_status(metrics_url, api_url, flow_id, token=None): +def dump_status(metrics_url, api_url, flow_id, token=None, + workspace="default"): - api = Api(api_url, token=token).flow() + api = Api(api_url, token=token, workspace=workspace).flow() flow = api.get(flow_id) blueprint_name = flow["blueprint-name"] @@ -84,11 +86,20 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: - dump_status(args.metrics_url, args.api_url, args.flow_id, token=args.token) + dump_status( + args.metrics_url, args.api_url, args.flow_id, + token=args.token, workspace=args.workspace, + ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_flows.py b/trustgraph-cli/trustgraph/cli/show_flows.py index f7a14469..6e9479f9 100644 --- a/trustgraph-cli/trustgraph/cli/show_flows.py +++ b/trustgraph-cli/trustgraph/cli/show_flows.py @@ -11,6 +11,7 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def describe_interfaces(intdefs, flow): @@ -97,17 +98,19 @@ def format_parameters(flow_params, blueprint_params_metadata, param_type_defs): return "\n".join(param_list) if param_list else "None" -async def fetch_show_flows(client): +async def fetch_show_flows(client, workspace): """Fetch all data needed for show_flows concurrently.""" # Round 1: list interfaces and list flows in parallel interface_names_resp, flow_ids_resp = await asyncio.gather( client._send_request("config", None, { "operation": "list", + "workspace": workspace, "type": "interface-description", }), client._send_request("flow", None, { "operation": "list-flows", + "workspace": workspace, }), ) @@ -115,12 +118,13 @@ async def fetch_show_flows(client): flow_ids = flow_ids_resp.get("flow-ids", []) if not flow_ids: - return {}, [], {}, {} + return {}, [], {}, {}, {} # Round 2: get all interfaces + all flows in parallel interface_tasks = [ client._send_request("config", None, { "operation": "get", + "workspace": workspace, "keys": [{"type": "interface-description", "key": name}], }) for name in interface_names @@ -129,6 +133,7 @@ async def fetch_show_flows(client): flow_tasks = [ client._send_request("flow", None, { "operation": "get-flow", + "workspace": workspace, "flow-id": fid, }) for fid in flow_ids @@ -163,6 +168,7 @@ async def fetch_show_flows(client): blueprint_tasks = [ client._send_request("flow", None, { "operation": "get-blueprint", + "workspace": workspace, "blueprint-name": bp_name, }) for bp_name in blueprint_names @@ -186,6 +192,7 @@ async def fetch_show_flows(client): param_type_tasks = [ client._send_request("config", None, { "operation": "get", + "workspace": workspace, "keys": [{"type": "parameter-type", "key": pt}], }) for pt in param_types_needed @@ -204,14 +211,16 @@ async def fetch_show_flows(client): return interface_defs, flow_ids, flows, blueprints, param_type_defs -async def _show_flows_async(url, token=None): +async def _show_flows_async(url, token=None, workspace="default"): async with AsyncSocketClient(url, timeout=60, token=token) as client: - return await fetch_show_flows(client) + return await fetch_show_flows(client, workspace) -def show_flows(url, token=None): +def show_flows(url, token=None, workspace="default"): - result = asyncio.run(_show_flows_async(url, token=token)) + result = asyncio.run(_show_flows_async( + url, token=token, workspace=workspace, + )) interface_defs, flow_ids, flows, blueprints, param_type_defs = result @@ -269,6 +278,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -276,6 +291,7 @@ def main(): show_flows( url=args.api_url, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_graph.py b/trustgraph-cli/trustgraph/cli/show_graph.py index 8db4edf4..6063b05a 100644 --- a/trustgraph-cli/trustgraph/cli/show_graph.py +++ b/trustgraph-cli/trustgraph/cli/show_graph.py @@ -13,9 +13,9 @@ import os from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' default_collection = 'default' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") # Named graph constants for convenience GRAPH_DEFAULT = "" @@ -23,14 +23,13 @@ GRAPH_SOURCE = "urn:graph:source" GRAPH_RETRIEVAL = "urn:graph:retrieval" -def show_graph(url, flow_id, user, collection, limit, batch_size, graph=None, show_graph_column=False, token=None): +def show_graph(url, flow_id, collection, limit, batch_size, graph=None, show_graph_column=False, token=None, workspace="default"): - socket = Api(url, token=token).socket() + socket = Api(url, token=token, workspace=workspace).socket() flow = socket.flow(flow_id) try: for batch in flow.triples_query_stream( - user=user, collection=collection, s=None, p=None, o=None, g=graph, # Filter by named graph (None = all graphs) @@ -73,12 +72,6 @@ def main(): help=f'Flow ID (default: default)' ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -91,6 +84,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-l', '--limit', type=int, @@ -129,13 +128,13 @@ def main(): show_graph( url = args.api_url, flow_id = args.flow_id, - user = args.user, collection = args.collection, limit = args.limit, batch_size = args.batch_size, graph = graph, show_graph_column = args.show_graph, token = args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_kg_cores.py b/trustgraph-cli/trustgraph/cli/show_kg_cores.py index ea295543..c9d47889 100644 --- a/trustgraph-cli/trustgraph/cli/show_kg_cores.py +++ b/trustgraph-cli/trustgraph/cli/show_kg_cores.py @@ -4,16 +4,15 @@ Shows knowledge cores import argparse import os -import tabulate -from trustgraph.api import Api, ConfigKey -import json +from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_cores(url, user, token=None): +def show_cores(url, token=None, workspace="default"): - api = Api(url, token=token).knowledge() + api = Api(url, token=token, workspace=workspace).knowledge() ids = api.list_kg_cores() @@ -26,7 +25,7 @@ def show_cores(url, user, token=None): def main(): parser = argparse.ArgumentParser( - prog='tg-show-flows', + prog='tg-show-kg-cores', description=__doc__, ) @@ -43,9 +42,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default="trustgraph", - help='API URL (default: trustgraph)', + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) args = parser.parse_args() @@ -54,8 +53,8 @@ def main(): show_cores( url=args.api_url, - user=args.user, token=args.token, + workspace=args.workspace, ) except Exception as e: @@ -63,4 +62,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/show_library_documents.py b/trustgraph-cli/trustgraph/cli/show_library_documents.py index 6eeceb70..12a89f1a 100644 --- a/trustgraph-cli/trustgraph/cli/show_library_documents.py +++ b/trustgraph-cli/trustgraph/cli/show_library_documents.py @@ -10,13 +10,13 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = "trustgraph" +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_docs(url, user, token=None): +def show_docs(url, token=None, workspace="default"): - api = Api(url, token=token).library() + api = Api(url, token=token, workspace=workspace).library() - docs = api.get_documents(user=user) + docs = api.get_documents() if len(docs) == 0: print("No documents.") @@ -60,9 +60,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) args = parser.parse_args() @@ -71,8 +71,8 @@ def main(): show_docs( url = args.api_url, - user = args.user, token = args.token, + workspace = args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_library_processing.py b/trustgraph-cli/trustgraph/cli/show_library_processing.py index 9ab69355..700a0f83 100644 --- a/trustgraph-cli/trustgraph/cli/show_library_processing.py +++ b/trustgraph-cli/trustgraph/cli/show_library_processing.py @@ -4,18 +4,17 @@ import argparse import os import tabulate -from trustgraph.api import Api, ConfigKey -import json +from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = "trustgraph" default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_procs(url, user, token=None): +def show_procs(url, token=None, workspace="default"): - api = Api(url, token=token).library() + api = Api(url, token=token, workspace=workspace).library() - procs = api.get_processings(user = user) + procs = api.get_processings() if len(procs) == 0: print("No processing objects.") @@ -52,24 +51,26 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-t', '--token', default=default_token, help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: show_procs( - url = args.api_url, user = args.user, token = args.token + url=args.api_url, + token=args.token, + workspace=args.workspace, ) except Exception as e: @@ -77,4 +78,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/show_mcp_tools.py b/trustgraph-cli/trustgraph/cli/show_mcp_tools.py index 24cbfcfe..d5f7a1c1 100644 --- a/trustgraph-cli/trustgraph/cli/show_mcp_tools.py +++ b/trustgraph-cli/trustgraph/cli/show_mcp_tools.py @@ -11,10 +11,11 @@ import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_config(url, token=None): +def show_config(url, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() values = api.get_values(type="mcp") @@ -64,6 +65,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -71,6 +78,8 @@ def main(): show_config( url=args.api_url, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_prompts.py b/trustgraph-cli/trustgraph/cli/show_prompts.py index 0e1cb2ae..cad6f317 100644 --- a/trustgraph-cli/trustgraph/cli/show_prompts.py +++ b/trustgraph-cli/trustgraph/cli/show_prompts.py @@ -11,10 +11,11 @@ import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_config(url, token=None): +def show_config(url, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() values = api.get([ ConfigKey(type="prompt", key="system"), @@ -85,6 +86,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -92,6 +99,8 @@ def main(): show_config( url=args.api_url, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_token_costs.py b/trustgraph-cli/trustgraph/cli/show_token_costs.py index adc13ad7..c7a7bff2 100644 --- a/trustgraph-cli/trustgraph/cli/show_token_costs.py +++ b/trustgraph-cli/trustgraph/cli/show_token_costs.py @@ -13,10 +13,11 @@ tabulate.PRESERVE_WHITESPACE = True default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_config(url, token=None): +def show_config(url, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() models = api.list("token-cost") @@ -68,6 +69,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -75,6 +82,8 @@ def main(): show_config( url=args.api_url, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_tools.py b/trustgraph-cli/trustgraph/cli/show_tools.py index d77f1fae..51aeacbf 100644 --- a/trustgraph-cli/trustgraph/cli/show_tools.py +++ b/trustgraph-cli/trustgraph/cli/show_tools.py @@ -19,10 +19,11 @@ import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_config(url, token=None): +def show_config(url, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() values = api.get_values(type="tool") @@ -116,6 +117,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -123,6 +130,8 @@ def main(): show_config( url=args.api_url, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/start_flow.py b/trustgraph-cli/trustgraph/cli/start_flow.py index e04e241d..f65ffc49 100644 --- a/trustgraph-cli/trustgraph/cli/start_flow.py +++ b/trustgraph-cli/trustgraph/cli/start_flow.py @@ -18,10 +18,12 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def start_flow(url, blueprint_name, flow_id, description, parameters=None, token=None): +def start_flow(url, blueprint_name, flow_id, description, parameters=None, + token=None, workspace="default"): - api = Api(url, token=token).flow() + api = Api(url, token=token, workspace=workspace).flow() api.start( blueprint_name = blueprint_name, @@ -49,6 +51,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-n', '--blueprint-name', required=True, @@ -120,6 +128,7 @@ def main(): description = args.description, parameters = parameters, token = args.token, + workspace = args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/start_library_processing.py b/trustgraph-cli/trustgraph/cli/start_library_processing.py index ff87ea9f..27b5f33d 100644 --- a/trustgraph-cli/trustgraph/cli/start_library_processing.py +++ b/trustgraph-cli/trustgraph/cli/start_library_processing.py @@ -4,19 +4,18 @@ Submits a library document for processing import argparse import os -import tabulate -from trustgraph.api import Api, ConfigKey -import json +from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = "trustgraph" +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def start_processing( - url, user, document_id, id, flow, collection, tags, token=None + url, document_id, id, flow, collection, tags, + token=None, workspace="default", ): - api = Api(url, token=token).library() + api = Api(url, token=token, workspace=workspace).library() if tags: tags = tags.split(",") @@ -27,9 +26,8 @@ def start_processing( id = id, document_id = document_id, flow = flow, - user = user, collection = collection, - tags = tags + tags = tags, ) def main(): @@ -52,9 +50,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -91,14 +89,14 @@ def main(): try: start_processing( - url = args.api_url, - user = args.user, - document_id = args.document_id, - id = args.id, - flow = args.flow_id, - collection = args.collection, - tags = args.tags, - token = args.token, + url=args.api_url, + document_id=args.document_id, + id=args.id, + flow=args.flow_id, + collection=args.collection, + tags=args.tags, + token=args.token, + workspace=args.workspace, ) except Exception as e: @@ -106,4 +104,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/stop_flow.py b/trustgraph-cli/trustgraph/cli/stop_flow.py index ae3a0415..7e2d0798 100644 --- a/trustgraph-cli/trustgraph/cli/stop_flow.py +++ b/trustgraph-cli/trustgraph/cli/stop_flow.py @@ -10,10 +10,11 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def stop_flow(url, flow_id, token=None): +def stop_flow(url, flow_id, token=None, workspace="default"): - api = Api(url, token=token).flow() + api = Api(url, token=token, workspace=workspace).flow() api.stop(id = flow_id) @@ -36,6 +37,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-i', '--flow-id', required=True, @@ -50,6 +57,7 @@ def main(): url=args.api_url, flow_id=args.flow_id, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/stop_library_processing.py b/trustgraph-cli/trustgraph/cli/stop_library_processing.py index 3d8a2c56..72a8dbb8 100644 --- a/trustgraph-cli/trustgraph/cli/stop_library_processing.py +++ b/trustgraph-cli/trustgraph/cli/stop_library_processing.py @@ -5,21 +5,17 @@ procesing, it doesn't stop in-flight processing at the moment. import argparse import os -import tabulate -from trustgraph.api import Api, ConfigKey -import json +from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = "trustgraph" +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def stop_processing( - url, user, id, token=None -): +def stop_processing(url, id, token=None, workspace="default"): - api = Api(url, token=token).library() + api = Api(url, token=token, workspace=workspace).library() - api.stop_processing(user = user, id = id) + api.stop_processing(id=id) def main(): @@ -41,9 +37,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -57,10 +53,10 @@ def main(): try: stop_processing( - url = args.api_url, - user = args.user, - id = args.id, - token = args.token, + url=args.api_url, + id=args.id, + token=args.token, + workspace=args.workspace, ) except Exception as e: @@ -68,4 +64,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/unload_kg_core.py b/trustgraph-cli/trustgraph/cli/unload_kg_core.py index 47f811f3..45c56067 100644 --- a/trustgraph-cli/trustgraph/cli/unload_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/unload_kg_core.py @@ -1,25 +1,21 @@ """ -Starts a load operation on a knowledge core which is already stored by -the knowledge manager. You could load a core with tg-put-kg-core and then -run this utility. +Unloads a knowledge core from a flow. """ import argparse import os -import tabulate from trustgraph.api import Api -import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_flow = "default" -default_collection = "default" -def unload_kg_core(url, user, id, flow, token=None): +def unload_kg_core(url, id, flow, token=None, workspace="default"): - api = Api(url, token=token).knowledge() + api = Api(url, token=token, workspace=workspace).knowledge() - class_names = api.unload_kg_core(user = user, id = id, flow=flow) + api.unload_kg_core(id=id, flow=flow) def main(): @@ -41,9 +37,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default="trustgraph", - help='API URL (default: trustgraph)', + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -55,7 +51,7 @@ def main(): parser.add_argument( '-f', '--flow-id', default=default_flow, - help=f'Flow ID (default: {default_flow}', + help=f'Flow ID (default: {default_flow})', ) args = parser.parse_args() @@ -64,10 +60,10 @@ def main(): unload_kg_core( url=args.api_url, - user=args.user, id=args.id, flow=args.flow_id, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/verify_system_status.py b/trustgraph-cli/trustgraph/cli/verify_system_status.py index 9491deaa..4ec055b7 100644 --- a/trustgraph-cli/trustgraph/cli/verify_system_status.py +++ b/trustgraph-cli/trustgraph/cli/verify_system_status.py @@ -25,6 +25,7 @@ default_pulsar_url = "http://localhost:8080" default_api_url = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/") default_ui_url = "http://localhost:8888" default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") class HealthChecker: @@ -210,10 +211,10 @@ def check_processors(url: str, min_processors: int, timeout: int, tr, token: Opt return False, tr.t("cli.verify_system_status.processors.error", error=str(e)) -def check_flow_blueprints(url: str, timeout: int, tr, token: Optional[str] = None) -> Tuple[bool, str]: +def check_flow_blueprints(url: str, timeout: int, tr, token: Optional[str] = None, workspace: str = "default") -> Tuple[bool, str]: """Check if flow blueprints are loaded.""" try: - api = Api(url, token=token, timeout=timeout) + api = Api(url, token=token, timeout=timeout, workspace=workspace) flow_api = api.flow() blueprints = flow_api.list_blueprints() @@ -227,10 +228,10 @@ def check_flow_blueprints(url: str, timeout: int, tr, token: Optional[str] = Non return False, tr.t("cli.verify_system_status.flow_blueprints.error", error=str(e)) -def check_flows(url: str, timeout: int, tr, token: Optional[str] = None) -> Tuple[bool, str]: +def check_flows(url: str, timeout: int, tr, token: Optional[str] = None, workspace: str = "default") -> Tuple[bool, str]: """Check if flow manager is responding.""" try: - api = Api(url, token=token, timeout=timeout) + api = Api(url, token=token, timeout=timeout, workspace=workspace) flow_api = api.flow() flows = flow_api.list() @@ -242,10 +243,10 @@ def check_flows(url: str, timeout: int, tr, token: Optional[str] = None) -> Tupl return False, tr.t("cli.verify_system_status.flows.error", error=str(e)) -def check_prompts(url: str, timeout: int, tr, token: Optional[str] = None) -> Tuple[bool, str]: +def check_prompts(url: str, timeout: int, tr, token: Optional[str] = None, workspace: str = "default") -> Tuple[bool, str]: """Check if prompts are loaded.""" try: - api = Api(url, token=token, timeout=timeout) + api = Api(url, token=token, timeout=timeout, workspace=workspace) config = api.config() # Import ConfigKey here to avoid top-level import issues @@ -268,14 +269,14 @@ def check_prompts(url: str, timeout: int, tr, token: Optional[str] = None) -> Tu return False, tr.t("cli.verify_system_status.prompts.error", error=str(e)) -def check_library(url: str, timeout: int, tr, token: Optional[str] = None) -> Tuple[bool, str]: +def check_library(url: str, timeout: int, tr, token: Optional[str] = None, workspace: str = "default") -> Tuple[bool, str]: """Check if library service is responding.""" try: - api = Api(url, token=token, timeout=timeout) + api = Api(url, token=token, timeout=timeout, workspace=workspace) library_api = api.library() - # Try to get documents (with default user) - docs = library_api.get_documents(user="trustgraph") + # Try to get documents + docs = library_api.get_documents() # Success if we get a valid response (even if empty) return True, tr.t("cli.verify_system_status.library.responding", count=len(docs)) @@ -376,6 +377,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)' ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-v', '--verbose', action='store_true', @@ -438,6 +445,7 @@ def main(): args.check_timeout, tr, args.token, + args.workspace, ) checker.run_check( @@ -447,6 +455,7 @@ def main(): args.check_timeout, tr, args.token, + args.workspace, ) checker.run_check( @@ -456,6 +465,7 @@ def main(): args.check_timeout, tr, args.token, + args.workspace, ) print() @@ -471,6 +481,7 @@ def main(): args.check_timeout, tr, args.token, + args.workspace, ) print() diff --git a/trustgraph-flow/trustgraph/agent/mcp_tool/service.py b/trustgraph-flow/trustgraph/agent/mcp_tool/service.py index c793f9ca..8ea72260 100755 --- a/trustgraph-flow/trustgraph/agent/mcp_tool/service.py +++ b/trustgraph-flow/trustgraph/agent/mcp_tool/service.py @@ -26,42 +26,50 @@ class Service(ToolService): self.register_config_handler(self.on_mcp_config, types=["mcp"]) + # Per-workspace MCP service registries self.mcp_services = {} - async def on_mcp_config(self, config, version): + async def on_mcp_config(self, workspace, config, version): - logger.info(f"Got config version {version}") + logger.info( + f"Got config version {version} for workspace {workspace}" + ) if "mcp" not in config: - self.mcp_services = {} + self.mcp_services[workspace] = {} return - self.mcp_services = { + self.mcp_services[workspace] = { k: json.loads(v) for k, v in config["mcp"].items() } - async def invoke_tool(self, name, parameters): + async def invoke_tool(self, workspace, name, parameters): try: - if name not in self.mcp_services: - raise RuntimeError(f"MCP service {name} not known") + ws_services = self.mcp_services.get(workspace, {}) - if "url" not in self.mcp_services[name]: + if name not in ws_services: + raise RuntimeError( + f"MCP service {name} not known in workspace " + f"{workspace}" + ) + + if "url" not in ws_services[name]: raise RuntimeError(f"MCP service {name} URL not defined") - url = self.mcp_services[name]["url"] + url = ws_services[name]["url"] - if "remote-name" in self.mcp_services[name]: - remote_name = self.mcp_services[name]["remote-name"] + if "remote-name" in ws_services[name]: + remote_name = ws_services[name]["remote-name"] else: remote_name = name # Build headers with optional bearer token headers = {} - if "auth-token" in self.mcp_services[name]: - token = self.mcp_services[name]["auth-token"] + if "auth-token" in ws_services[name]: + token = ws_services[name]["auth-token"] headers["Authorization"] = f"Bearer {token}" logger.info(f"Invoking {remote_name} at {url}") diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py b/trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py index cc5eb85c..c06b8c54 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py @@ -108,7 +108,7 @@ class Aggregator: ) def build_synthesis_request(self, correlation_id, original_question, - user, collection): + collection): """ Build the AgentRequest that triggers the synthesis phase. """ @@ -139,7 +139,6 @@ class Aggregator: state="", group=template.group if template else [], history=history, - user=user, collection=collection, streaming=template.streaming if template else False, session_id=parent_session_id, diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py index 6daba1a1..01abedf3 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py @@ -46,25 +46,20 @@ from ..tool_filter import filter_tools_by_group_and_state, get_next_state logger = logging.getLogger(__name__) -class UserAwareContext: - """Wraps flow interface to inject user context for tools that need it.""" +class FlowContext: + """Wraps flow interface with orchestrator-only scratch state + (explain URIs, response handle, streaming flag). Workspace isolation + is enforced by the flow layer (flow.workspace), not by this class.""" - def __init__(self, flow, user, respond=None, streaming=False): + def __init__(self, flow, respond=None, streaming=False): self._flow = flow - self._user = user self.respond = respond self.streaming = streaming self.current_explain_uri = None self.last_sub_explain_uri = None def __call__(self, service_name): - client = self._flow(service_name) - if service_name in ( - "structured-query-request", - "row-embeddings-query-request", - ): - client._current_user = self._user - return client + return self._flow(service_name) class UsageTracker: @@ -131,7 +126,6 @@ class PatternBase: state="", group=getattr(request, 'group', []), history=[completion_step], - user=request.user, collection=getattr(request, 'collection', 'default'), streaming=False, session_id=getattr(request, 'session_id', ''), @@ -158,9 +152,9 @@ class PatternBase: current_state=getattr(request, 'state', None), ) - def make_context(self, flow, user, respond=None, streaming=False): - """Create a user-aware context wrapper.""" - return UserAwareContext(flow, user, respond=respond, streaming=streaming) + def make_context(self, flow, respond=None, streaming=False): + """Create a flow context wrapper.""" + return FlowContext(flow, respond=respond, streaming=streaming) def build_history(self, request): """Convert AgentStep history into Action objects.""" @@ -249,7 +243,7 @@ class PatternBase: # ---- Provenance emission ------------------------------------------------ - async def emit_session_triples(self, flow, session_uri, question, user, + async def emit_session_triples(self, flow, session_uri, question, collection, respond, streaming, parent_uri=None): """Emit provenance triples for a new session.""" @@ -264,7 +258,6 @@ class PatternBase: await flow("explainability").send(Triples( metadata=Metadata( id=session_uri, - user=user, collection=collection, ), triples=triples, @@ -281,7 +274,7 @@ class PatternBase: async def emit_pattern_decision_triples( self, flow, session_id, session_uri, pattern, task_type, - user, collection, respond, + collection, respond, ): """Emit provenance triples for a meta-router pattern decision.""" uri = agent_pattern_decision_uri(session_id) @@ -292,7 +285,7 @@ class PatternBase: GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( - metadata=Metadata(id=uri, user=user, collection=collection), + metadata=Metadata(id=uri, collection=collection), triples=triples, )) await respond(AgentResponse( @@ -329,7 +322,7 @@ class PatternBase: try: await self.processor.save_answer_content( doc_id=thought_doc_id, - user=request.user, + workspace=flow.workspace, content=act.thought, title=f"Agent Thought: {act.name}", ) @@ -360,7 +353,6 @@ class PatternBase: await flow("explainability").send(Triples( metadata=Metadata( id=iteration_uri, - user=request.user, collection=getattr(request, 'collection', 'default'), ), triples=iter_triples, @@ -399,7 +391,7 @@ class PatternBase: try: await self.processor.save_answer_content( doc_id=observation_doc_id, - user=request.user, + workspace=flow.workspace, content=observation_text, title=f"Agent Observation", ) @@ -420,7 +412,6 @@ class PatternBase: await flow("explainability").send(Triples( metadata=Metadata( id=observation_entity_uri, - user=request.user, collection=getattr(request, 'collection', 'default'), ), triples=obs_triples, @@ -456,7 +447,7 @@ class PatternBase: try: await self.processor.save_answer_content( doc_id=answer_doc_id, - user=request.user, + workspace=flow.workspace, content=answer_text, title=f"Agent Answer: {request.question[:50]}...", ) @@ -478,7 +469,6 @@ class PatternBase: await flow("explainability").send(Triples( metadata=Metadata( id=final_uri, - user=request.user, collection=getattr(request, 'collection', 'default'), ), triples=final_triples, @@ -496,7 +486,7 @@ class PatternBase: # ---- Orchestrator provenance helpers ------------------------------------ async def emit_decomposition_triples( - self, flow, session_id, session_uri, goals, user, collection, + self, flow, session_id, session_uri, goals, collection, respond, streaming, ): """Emit provenance for a supervisor decomposition step.""" @@ -506,7 +496,7 @@ class PatternBase: GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( - metadata=Metadata(id=uri, user=user, collection=collection), + metadata=Metadata(id=uri, collection=collection), triples=triples, )) await respond(AgentResponse( @@ -516,7 +506,7 @@ class PatternBase: )) async def emit_finding_triples( - self, flow, session_id, index, goal, answer_text, user, collection, + self, flow, session_id, index, goal, answer_text, collection, respond, streaming, subagent_session_id="", ): """Emit provenance for a subagent finding.""" @@ -532,7 +522,7 @@ class PatternBase: doc_id = f"urn:trustgraph:agent:{session_id}/finding/{index}/doc" try: await self.processor.save_answer_content( - doc_id=doc_id, user=user, + doc_id=doc_id, workspace=flow.workspace, content=answer_text, title=f"Finding: {goal[:60]}", ) @@ -545,7 +535,7 @@ class PatternBase: GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( - metadata=Metadata(id=uri, user=user, collection=collection), + metadata=Metadata(id=uri, collection=collection), triples=triples, )) await respond(AgentResponse( @@ -555,7 +545,7 @@ class PatternBase: )) async def emit_plan_triples( - self, flow, session_id, session_uri, steps, user, collection, + self, flow, session_id, session_uri, steps, collection, respond, streaming, ): """Emit provenance for a plan creation.""" @@ -565,7 +555,7 @@ class PatternBase: GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( - metadata=Metadata(id=uri, user=user, collection=collection), + metadata=Metadata(id=uri, collection=collection), triples=triples, )) await respond(AgentResponse( @@ -575,7 +565,7 @@ class PatternBase: )) async def emit_step_result_triples( - self, flow, session_id, index, goal, answer_text, user, collection, + self, flow, session_id, index, goal, answer_text, collection, respond, streaming, ): """Emit provenance for a plan step result.""" @@ -585,7 +575,7 @@ class PatternBase: doc_id = f"urn:trustgraph:agent:{session_id}/step/{index}/doc" try: await self.processor.save_answer_content( - doc_id=doc_id, user=user, + doc_id=doc_id, workspace=flow.workspace, content=answer_text, title=f"Step result: {goal[:60]}", ) @@ -598,7 +588,7 @@ class PatternBase: GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( - metadata=Metadata(id=uri, user=user, collection=collection), + metadata=Metadata(id=uri, collection=collection), triples=triples, )) await respond(AgentResponse( @@ -608,7 +598,7 @@ class PatternBase: )) async def emit_synthesis_triples( - self, flow, session_id, previous_uris, answer_text, user, collection, + self, flow, session_id, previous_uris, answer_text, collection, respond, streaming, termination_reason=None, ): """Emit provenance for a synthesis answer.""" @@ -617,7 +607,7 @@ class PatternBase: doc_id = f"urn:trustgraph:agent:{session_id}/synthesis/doc" try: await self.processor.save_answer_content( - doc_id=doc_id, user=user, + doc_id=doc_id, workspace=flow.workspace, content=answer_text, title="Synthesis", ) @@ -633,7 +623,7 @@ class PatternBase: GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( - metadata=Metadata(id=uri, user=user, collection=collection), + metadata=Metadata(id=uri, collection=collection), triples=triples, )) await respond(AgentResponse( @@ -751,7 +741,6 @@ class PatternBase: ) for h in history ], - user=request.user, collection=collection, streaming=streaming, session_id=session_id, diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py index 1de31a92..0cc9013f 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py @@ -53,7 +53,7 @@ class PlanThenExecutePattern(PatternBase): if iteration_num == 1: await self.emit_session_triples( flow, session_uri, request.question, - request.user, collection, respond, streaming, + collection, respond, streaming, ) logger.info( @@ -109,11 +109,17 @@ class PlanThenExecutePattern(PatternBase): think = self.make_think_callback(respond, streaming) - tools = self.filter_tools(self.processor.agent.tools, request) + agent = self.processor.agents.get(flow.workspace) + if agent is None: + raise RuntimeError( + f"No agent configuration for workspace {flow.workspace}" + ) + + tools = self.filter_tools(agent.tools, request) framing = getattr(request, 'framing', '') context = self.make_context( - flow, request.user, + flow, respond=respond, streaming=streaming, ) client = context("prompt-request") @@ -147,7 +153,7 @@ class PlanThenExecutePattern(PatternBase): step_goals = [ps.get("goal", "") for ps in plan_steps] await self.emit_plan_triples( flow, session_id, session_uri, step_goals, - request.user, collection, respond, streaming, + collection, respond, streaming, ) # Build PlanStep objects @@ -179,7 +185,6 @@ class PlanThenExecutePattern(PatternBase): state=request.state, group=getattr(request, 'group', []), history=new_history, - user=request.user, collection=collection, streaming=streaming, session_id=session_id, @@ -237,9 +242,15 @@ class PlanThenExecutePattern(PatternBase): "result": dep_result, }) - tools = self.filter_tools(self.processor.agent.tools, request) + agent = self.processor.agents.get(flow.workspace) + if agent is None: + raise RuntimeError( + f"No agent configuration for workspace {flow.workspace}" + ) + + tools = self.filter_tools(agent.tools, request) context = self.make_context( - flow, request.user, + flow, respond=respond, streaming=streaming, ) @@ -307,7 +318,7 @@ class PlanThenExecutePattern(PatternBase): # Emit step result provenance await self.emit_step_result_triples( flow, session_id, pending_idx, goal, step_result, - request.user, collection, respond, streaming, + collection, respond, streaming, ) # Build execution step for history @@ -327,7 +338,6 @@ class PlanThenExecutePattern(PatternBase): state=request.state, group=getattr(request, 'group', []), history=new_history, - user=request.user, collection=collection, streaming=streaming, session_id=session_id, @@ -352,7 +362,7 @@ class PlanThenExecutePattern(PatternBase): framing = getattr(request, 'framing', '') context = self.make_context( - flow, request.user, + flow, respond=respond, streaming=streaming, ) client = context("prompt-request") @@ -387,7 +397,7 @@ class PlanThenExecutePattern(PatternBase): last_step_uri = make_step_result_uri(session_id, len(plan) - 1) await self.emit_synthesis_triples( flow, session_id, last_step_uri, - response_text, request.user, collection, respond, streaming, + response_text, collection, respond, streaming, termination_reason="plan-complete", ) diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py index 25264c26..4920ebf1 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py @@ -61,7 +61,7 @@ class ReactPattern(PatternBase): ) await self.emit_session_triples( flow, session_uri, request.question, - request.user, collection, respond, streaming, + collection, respond, streaming, parent_uri=parent_uri, ) @@ -80,13 +80,20 @@ class ReactPattern(PatternBase): observe = self.make_observe_callback(respond, streaming, message_id=observation_msg_id) answer_cb = self.make_answer_callback(respond, streaming, message_id=answer_msg_id) + # Look up the per-workspace agent + agent = self.processor.agents.get(flow.workspace) + if agent is None: + raise RuntimeError( + f"No agent configuration for workspace {flow.workspace}" + ) + # Filter tools filtered_tools = self.filter_tools( - self.processor.agent.tools, request, + agent.tools, request, ) # Create temporary agent with filtered tools and optional framing - additional_context = self.processor.agent.additional_context + additional_context = agent.additional_context framing = getattr(request, 'framing', '') if framing: if additional_context: @@ -100,7 +107,7 @@ class ReactPattern(PatternBase): ) context = self.make_context( - flow, request.user, + flow, respond=respond, streaming=streaming, ) diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/service.py b/trustgraph-flow/trustgraph/agent/orchestrator/service.py index 3d08154d..b57ca79d 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/service.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/service.py @@ -42,7 +42,7 @@ from ..tool_filter import validate_tool_config from ..react.types import Final, Action, Tool, Argument from . meta_router import MetaRouter -from . pattern_base import PatternBase, UserAwareContext +from . pattern_base import PatternBase, FlowContext from . react_pattern import ReactPattern from . plan_pattern import PlanThenExecutePattern from . supervisor_pattern import SupervisorPattern @@ -76,10 +76,9 @@ class Processor(AgentService): } ) - self.agent = AgentManager( - tools={}, - additional_context="", - ) + # Per-workspace agent managers and meta-routers + self.agents = {} + self.meta_routers = {} self.tool_service_clients = {} @@ -91,9 +90,6 @@ class Processor(AgentService): # Aggregator for supervisor fan-in self.aggregator = Aggregator() - # Meta-router (initialised on first config load) - self.meta_router = None - self.register_config_handler( self.on_tools_config, types=["tool", "tool-service"] ) @@ -204,13 +200,13 @@ class Processor(AgentService): future = self.pending_librarian_requests.pop(request_id) future.set_result(response) - async def save_answer_content(self, doc_id, user, content, title=None, + async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120): request_id = str(uuid.uuid4()) doc_metadata = DocumentMetadata( id=doc_id, - user=user, + workspace=workspace, kind="text/plain", title=title or "Agent Answer", document_type="answer", @@ -221,7 +217,7 @@ class Processor(AgentService): document_id=doc_id, document_metadata=doc_metadata, content=base64.b64encode(content.encode("utf-8")).decode("utf-8"), - user=user, + workspace=workspace, ) future = asyncio.get_event_loop().create_future() @@ -247,9 +243,12 @@ class Processor(AgentService): def provenance_session_uri(self, session_id): return agent_session_uri(session_id) - async def on_tools_config(self, config, version): + async def on_tools_config(self, workspace, config, version): - logger.info(f"Loading configuration version {version}") + logger.info( + f"Loading configuration version {version} " + f"for workspace {workspace}" + ) try: tools = {} @@ -316,7 +315,6 @@ class Processor(AgentService): impl = functools.partial( StructuredQueryImpl, collection=data.get("collection"), - user=None, ) arguments = StructuredQueryImpl.get_arguments() elif impl_id == "row-embeddings-query": @@ -324,7 +322,6 @@ class Processor(AgentService): RowEmbeddingsQueryImpl, schema_name=data.get("schema-name"), collection=data.get("collection"), - user=None, index_name=data.get("index-name"), limit=int(data.get("limit", 10)), ) @@ -408,15 +405,17 @@ class Processor(AgentService): agent_config = config[self.config_key] additional = agent_config.get("additional-context", None) - self.agent = AgentManager( + self.agents[workspace] = AgentManager( tools=tools, additional_context=additional, ) - # Re-initialise meta-router with config - self.meta_router = MetaRouter(config=config) + # Re-initialise meta-router with config for this workspace + self.meta_routers[workspace] = MetaRouter(config=config) - logger.info(f"Loaded {len(tools)} tools") + logger.info( + f"Loaded {len(tools)} tools for workspace {workspace}" + ) except Exception as e: logger.error( @@ -466,7 +465,7 @@ class Processor(AgentService): await self.supervisor_pattern.emit_finding_triples( flow, parent_session_id, finding_index, subagent_goal, answer_text, - template.user, collection, + collection, respond, template.streaming, subagent_session_id=subagent_session_id, ) @@ -486,7 +485,6 @@ class Processor(AgentService): synthesis_request = self.aggregator.build_synthesis_request( correlation_id, original_question=template.question, - user=template.user, collection=getattr(template, 'collection', 'default'), ) @@ -515,10 +513,11 @@ class Processor(AgentService): # If no pattern set and this is the first iteration, route if not pattern and not request.history: - context = UserAwareContext(flow, request.user) + context = FlowContext(flow) - if self.meta_router: - pattern, task_type, framing = await self.meta_router.route( + meta_router = self.meta_routers.get(flow.workspace) + if meta_router: + pattern, task_type, framing = await meta_router.route( request.question, context, usage=usage, ) else: @@ -553,7 +552,6 @@ class Processor(AgentService): await selected.emit_pattern_decision_triples( flow, session_id, session_uri, pattern, getattr(request, 'task_type', ''), - request.user, getattr(request, 'collection', 'default'), respond, ) diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py index 973a9966..f9a1751d 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py @@ -54,7 +54,7 @@ class SupervisorPattern(PatternBase): if iteration_num == 1: await self.emit_session_triples( flow, session_uri, request.question, - request.user, collection, respond, streaming, + collection, respond, streaming, ) logger.info( @@ -99,10 +99,16 @@ class SupervisorPattern(PatternBase): ) framing = getattr(request, 'framing', '') - tools = self.filter_tools(self.processor.agent.tools, request) + agent = self.processor.agents.get(flow.workspace) + if agent is None: + raise RuntimeError( + f"No agent configuration for workspace {flow.workspace}" + ) + + tools = self.filter_tools(agent.tools, request) context = self.make_context( - flow, request.user, + flow, respond=respond, streaming=streaming, ) client = context("prompt-request") @@ -144,7 +150,7 @@ class SupervisorPattern(PatternBase): # Emit decomposition provenance await self.emit_decomposition_triples( flow, session_id, session_uri, goals, - request.user, collection, respond, streaming, + collection, respond, streaming, ) # Fan out: emit a subagent request for each goal @@ -155,7 +161,6 @@ class SupervisorPattern(PatternBase): state="", group=getattr(request, 'group', []), history=[], - user=request.user, collection=collection, streaming=False, # Subagents don't stream session_id=subagent_session, @@ -207,7 +212,7 @@ class SupervisorPattern(PatternBase): subagent_results = {"(no results)": "No subagent results available"} context = self.make_context( - flow, request.user, + flow, respond=respond, streaming=streaming, ) client = context("prompt-request") @@ -237,7 +242,7 @@ class SupervisorPattern(PatternBase): ] await self.emit_synthesis_triples( flow, session_id, finding_uris, - response_text, request.user, collection, respond, streaming, + response_text, collection, respond, streaming, termination_reason="subagents-complete", ) diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index 1512fa83..7140284f 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py +++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -10,6 +10,7 @@ import sys import functools import logging import uuid +from typing import Dict from datetime import datetime, timezone # Module logger @@ -73,10 +74,8 @@ class Processor(AgentService): } ) - self.agent = AgentManager( - tools={}, - additional_context="", - ) + # Per-workspace agent managers + self.agents: Dict[str, AgentManager] = {} # Track active tool service clients for cleanup self.tool_service_clients = {} @@ -193,13 +192,13 @@ class Processor(AgentService): future = self.pending_librarian_requests.pop(request_id) future.set_result(response) - async def save_answer_content(self, doc_id, user, content, title=None, timeout=120): + async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120): """ Save answer content to the librarian. Args: doc_id: ID for the answer document - user: User ID + workspace: Workspace for isolation content: Answer text content title: Optional title timeout: Request timeout in seconds @@ -211,7 +210,7 @@ class Processor(AgentService): doc_metadata = DocumentMetadata( id=doc_id, - user=user, + workspace=workspace, kind="text/plain", title=title or "Agent Answer", document_type="answer", @@ -222,7 +221,7 @@ class Processor(AgentService): document_id=doc_id, document_metadata=doc_metadata, content=base64.b64encode(content.encode("utf-8")).decode("utf-8"), - user=user, + workspace=workspace, ) # Create future for response @@ -249,9 +248,12 @@ class Processor(AgentService): self.pending_librarian_requests.pop(request_id, None) raise RuntimeError(f"Timeout saving answer document {doc_id}") - async def on_tools_config(self, config, version): + async def on_tools_config(self, workspace, config, version): - logger.info(f"Loading configuration version {version}") + logger.info( + f"Loading configuration version {version} " + f"for workspace {workspace}" + ) try: @@ -321,7 +323,6 @@ class Processor(AgentService): impl = functools.partial( StructuredQueryImpl, collection=data.get("collection"), - user=None # User will be provided dynamically via context ) arguments = StructuredQueryImpl.get_arguments() elif impl_id == "row-embeddings-query": @@ -329,7 +330,6 @@ class Processor(AgentService): RowEmbeddingsQueryImpl, schema_name=data.get("schema-name"), collection=data.get("collection"), - user=None, # User will be provided dynamically via context index_name=data.get("index-name"), # Optional filter limit=int(data.get("limit", 10)) # Max results ) @@ -409,13 +409,17 @@ class Processor(AgentService): agent_config = config[self.config_key] additional = agent_config.get("additional-context", None) - self.agent = AgentManager( + self.agents[workspace] = AgentManager( tools=tools, additional_context=additional ) - logger.info(f"Loaded {len(tools)} tools") - logger.info("Tool configuration reloaded.") + logger.info( + f"Loaded {len(tools)} tools for workspace {workspace}" + ) + logger.info( + f"Tool configuration reloaded for workspace {workspace}." + ) except Exception as e: @@ -460,7 +464,6 @@ class Processor(AgentService): await flow("explainability").send(Triples( metadata=Metadata( id=session_uri, - user=request.user, collection=collection, ), triples=triples, @@ -557,35 +560,41 @@ class Processor(AgentService): await respond(r) + # Look up the agent for this workspace + workspace = flow.workspace + agent = self.agents.get(workspace) + if agent is None: + logger.error( + f"No agent configuration loaded for workspace " + f"{workspace}" + ) + raise RuntimeError( + f"No agent configuration for workspace {workspace}" + ) + # Apply tool filtering based on request groups and state filtered_tools = filter_tools_by_group_and_state( - tools=self.agent.tools, + tools=agent.tools, requested_groups=getattr(request, 'group', None), current_state=getattr(request, 'state', None) ) - + # Create temporary agent with filtered tools temp_agent = AgentManager( tools=filtered_tools, - additional_context=self.agent.additional_context + additional_context=agent.additional_context ) logger.debug("Call React") - # Create user-aware context wrapper that preserves the flow interface - # but adds user information for tools that need it - class UserAwareContext: - def __init__(self, flow, user): + # Thin wrapper around flow — carries only explain URI state. + class _Context: + def __init__(self, flow): self._flow = flow - self._user = user self.last_sub_explain_uri = None def __call__(self, service_name): - client = self._flow(service_name) - # For query clients that need user context, store it - if service_name in ("structured-query-request", "row-embeddings-query-request"): - client._current_user = self._user - return client + return self._flow(service_name) # Callback: emit Analysis+ToolUse triples before tool executes async def on_action(act_decision): @@ -604,7 +613,7 @@ class Processor(AgentService): try: await self.save_answer_content( doc_id=t_doc_id, - user=request.user, + workspace=flow.workspace, content=act_decision.thought, title=f"Agent Thought: {act_decision.name}", ) @@ -629,7 +638,6 @@ class Processor(AgentService): await flow("explainability").send(Triples( metadata=Metadata( id=iter_uri, - user=request.user, collection=collection, ), triples=iter_triples, @@ -644,7 +652,7 @@ class Processor(AgentService): explain_triples=iter_triples, )) - user_context = UserAwareContext(flow, request.user) + user_context = _Context(flow) act = await temp_agent.react( question = request.question, @@ -685,7 +693,7 @@ class Processor(AgentService): try: await self.save_answer_content( doc_id=answer_doc_id, - user=request.user, + workspace=flow.workspace, content=f, title=f"Agent Answer: {request.question[:50]}...", ) @@ -706,7 +714,6 @@ class Processor(AgentService): await flow("explainability").send(Triples( metadata=Metadata( id=final_uri, - user=request.user, collection=collection, ), triples=final_triples, @@ -763,7 +770,7 @@ class Processor(AgentService): try: await self.save_answer_content( doc_id=observation_doc_id, - user=request.user, + workspace=flow.workspace, content=act.observation, title=f"Agent Observation", ) @@ -783,7 +790,6 @@ class Processor(AgentService): await flow("explainability").send(Triples( metadata=Metadata( id=observation_entity_uri, - user=request.user, collection=collection, ), triples=obs_triples, @@ -820,7 +826,6 @@ class Processor(AgentService): ) for h in history ], - user=request.user, collection=collection, streaming=streaming, session_id=session_id, # Pass session_id for provenance continuity diff --git a/trustgraph-flow/trustgraph/agent/react/tools.py b/trustgraph-flow/trustgraph/agent/react/tools.py index 6674c999..ae9507ab 100644 --- a/trustgraph-flow/trustgraph/agent/react/tools.py +++ b/trustgraph-flow/trustgraph/agent/react/tools.py @@ -116,31 +116,26 @@ class McpToolImpl: # This tool implementation knows how to query structured data using natural language class StructuredQueryImpl: - def __init__(self, context, collection=None, user=None): + def __init__(self, context, collection=None): self.context = context - self.collection = collection # For multi-tenant scenarios - self.user = user # User context for multi-tenancy - + self.collection = collection + @staticmethod def get_arguments(): return [ Argument( name="question", - type="string", + type="string", description="Natural language question about structured data (tables, databases, etc.)" ) ] - + async def invoke(self, **arguments): client = self.context("structured-query-request") logger.debug("Structured query question...") - - # Get user from client context if available, otherwise use instance user or default - user = getattr(client, '_current_user', self.user or "trustgraph") - + result = await client.structured_query( question=arguments.get("question"), - user=user, collection=self.collection or "default" ) @@ -159,11 +154,10 @@ class StructuredQueryImpl: # This tool implementation knows how to query row embeddings for semantic search class RowEmbeddingsQueryImpl: - def __init__(self, context, schema_name, collection=None, user=None, index_name=None, limit=10): + def __init__(self, context, schema_name, collection=None, index_name=None, limit=10): self.context = context self.schema_name = schema_name self.collection = collection - self.user = user self.index_name = index_name # Optional: filter to specific index self.limit = limit # Max results to return @@ -190,13 +184,9 @@ class RowEmbeddingsQueryImpl: client = self.context("row-embeddings-query-request") logger.debug("Row embeddings query...") - # Get user from client context if available - user = getattr(client, '_current_user', self.user or "trustgraph") - matches = await client.row_embeddings_query( vector=vector, schema_name=self.schema_name, - user=user, collection=self.collection or "default", index_name=self.index_name, limit=self.limit @@ -250,7 +240,7 @@ class ToolServiceImpl: Initialize a tool service implementation. Args: - context: The context function (provides user info) + context: Flow context (callable resolving service names to clients) request_queue: Full Pulsar topic for requests response_queue: Full Pulsar topic for responses config_values: Dict of config values (e.g., {"collection": "customers"}) @@ -325,17 +315,10 @@ class ToolServiceImpl: logger.debug(f"Config: {self.config_values}") logger.debug(f"Arguments: {arguments}") - # Get user from context if available - user = "trustgraph" - if hasattr(self.context, '_user'): - user = self.context._user - # Get or create the client client = await self._get_or_create_client() - # Call the tool service response = await client.call( - user=user, config=self.config_values, arguments=arguments, ) diff --git a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py index dc7b357c..a0052c79 100755 --- a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py @@ -95,7 +95,7 @@ class Processor(ChunkingService): logger.info(f"Chunking document {v.metadata.id}...") # Get text content (fetches from librarian if needed) - text = await self.get_document_text(v) + text = await self.get_document_text(v, flow.workspace) # Extract chunk parameters from flow (allows runtime override) chunk_size, chunk_overlap = await self.chunk_document( @@ -144,7 +144,7 @@ class Processor(ChunkingService): await self.librarian.save_child_document( doc_id=chunk_doc_id, parent_id=parent_doc_id, - user=v.metadata.user, + workspace=flow.workspace, content=chunk_content, document_type="chunk", title=f"Chunk {chunk_index}", @@ -168,7 +168,6 @@ class Processor(ChunkingService): metadata=Metadata( id=c_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), triples=set_graph(prov_triples, GRAPH_SOURCE), @@ -179,7 +178,6 @@ class Processor(ChunkingService): metadata=Metadata( id=c_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), chunk=chunk_content, diff --git a/trustgraph-flow/trustgraph/chunking/token/chunker.py b/trustgraph-flow/trustgraph/chunking/token/chunker.py index 3f31beb9..c3935e4b 100755 --- a/trustgraph-flow/trustgraph/chunking/token/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/token/chunker.py @@ -92,7 +92,7 @@ class Processor(ChunkingService): logger.info(f"Chunking document {v.metadata.id}...") # Get text content (fetches from librarian if needed) - text = await self.get_document_text(v) + text = await self.get_document_text(v, flow.workspace) # Extract chunk parameters from flow (allows runtime override) chunk_size, chunk_overlap = await self.chunk_document( @@ -140,7 +140,7 @@ class Processor(ChunkingService): await self.librarian.save_child_document( doc_id=chunk_doc_id, parent_id=parent_doc_id, - user=v.metadata.user, + workspace=flow.workspace, content=chunk_content, document_type="chunk", title=f"Chunk {chunk_index}", @@ -164,7 +164,6 @@ class Processor(ChunkingService): metadata=Metadata( id=c_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), triples=set_graph(prov_triples, GRAPH_SOURCE), @@ -175,7 +174,6 @@ class Processor(ChunkingService): metadata=Metadata( id=c_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), chunk=chunk_content, diff --git a/trustgraph-flow/trustgraph/config/service/config.py b/trustgraph-flow/trustgraph/config/service/config.py index 6c897f6b..36af6026 100644 --- a/trustgraph-flow/trustgraph/config/service/config.py +++ b/trustgraph-flow/trustgraph/config/service/config.py @@ -9,42 +9,8 @@ from ... tables.config import ConfigTableStore # Module logger logger = logging.getLogger(__name__) -class ConfigurationClass: - - async def keys(self): - return await self.table_store.get_keys(self.type) - - async def values(self): - vals = await self.table_store.get_values(self.type) - return { - v[0]: v[1] - for v in vals - } - - async def get(self, key): - return await self.table_store.get_value(self.type, key) - - async def put(self, key, value): - return await self.table_store.put_config(self.type, key, value) - - async def delete(self, key): - return await self.table_store.delete_key(self.type, key) - - async def has(self, key): - val = await self.table_store.get_value(self.type, key) - return val is not None - class Configuration: - # FIXME: The state is held internally. This only works if there's - # one config service. Should be more than one, and use a - # back-end state store. - - # FIXME: This has state now, but does it address all of the above? - # REVIEW: Above - - # FIXME: Some version vs config race conditions - def __init__(self, push, host, username, password, keyspace): # External function to respond to update @@ -60,34 +26,17 @@ class Configuration: async def get_version(self): return await self.table_store.get_version() - def get(self, type): - - c = ConfigurationClass() - c.table_store = self.table_store - c.type = type - - return c - async def handle_get(self, v): - # for k in v.keys: - # if k.type not in self or k.key not in self[k.type]: - # return ConfigResponse( - # version = None, - # values = None, - # directory = None, - # config = None, - # error = Error( - # type = "key-error", - # message = f"Key error" - # ) - # ) + workspace = v.workspace values = [ ConfigValue( type = k.type, key = k.key, - value = await self.table_store.get_value(k.type, k.key) + value = await self.table_store.get_value( + workspace, k.type, k.key + ) ) for k in v.keys ] @@ -96,43 +45,19 @@ class Configuration: version = await self.get_version(), values = values, ) - + async def handle_list(self, v): - # if v.type not in self: - - # return ConfigResponse( - # version = None, - # values = None, - # directory = None, - # config = None, - # error = Error( - # type = "key-error", - # message = "No such type", - # ), - # ) - return ConfigResponse( version = await self.get_version(), - directory = await self.table_store.get_keys(v.type), + directory = await self.table_store.get_keys( + v.workspace, v.type + ), ) async def handle_getvalues(self, v): - # if v.type not in self: - - # return ConfigResponse( - # version = None, - # values = None, - # directory = None, - # config = None, - # error = Error( - # type = "key-error", - # message = f"Key error" - # ) - # ) - - vals = await self.table_store.get_values(v.type) + vals = await self.table_store.get_values(v.workspace, v.type) values = map( lambda x: ConfigValue( @@ -146,39 +71,63 @@ class Configuration: values = list(values), ) + async def handle_getvalues_all_ws(self, v): + """Fetch all values of a given type across all workspaces. + Used by shared processors to load type-scoped config at + startup without enumerating workspaces separately.""" + + vals = await self.table_store.get_values_all_ws(v.type) + + values = [ + ConfigValue( + workspace = row[0], + type = v.type, + key = row[1], + value = row[2], + ) + for row in vals + ] + + return ConfigResponse( + version = await self.get_version(), + values = values, + ) + async def handle_delete(self, v): + workspace = v.workspace types = list(set(k.type for k in v.keys)) for k in v.keys: - - await self.table_store.delete_key(k.type, k.key) + await self.table_store.delete_key(workspace, k.type, k.key) await self.inc_version() - await self.push(types=types) + await self.push(changes={t: [workspace] for t in types}) return ConfigResponse( ) async def handle_put(self, v): + workspace = v.workspace types = list(set(k.type for k in v.values)) for k in v.values: - - await self.table_store.put_config(k.type, k.key, k.value) + await self.table_store.put_config( + workspace, k.type, k.key, k.value + ) await self.inc_version() - await self.push(types=types) + await self.push(changes={t: [workspace] for t in types}) return ConfigResponse( ) - async def get_config(self): + async def get_config(self, workspace): - table = await self.table_store.get_all() + table = await self.table_store.get_all_for_workspace(workspace) config = {} @@ -191,7 +140,7 @@ class Configuration: async def handle_config(self, v): - config = await self.get_config() + config = await self.get_config(v.workspace) return ConfigResponse( version = await self.get_version(), @@ -200,7 +149,20 @@ class Configuration: async def handle(self, msg): - logger.debug(f"Handling config message: {msg.operation}") + logger.debug( + f"Handling config message: {msg.operation} " + f"workspace={msg.workspace}" + ) + + # getvalues-all-ws spans all workspaces, so no workspace + # required; everything else is workspace-scoped. + if msg.operation != "getvalues-all-ws" and not msg.workspace: + return ConfigResponse( + error=Error( + type = "bad-request", + message = "Workspace is required" + ) + ) if msg.operation == "get": @@ -214,6 +176,10 @@ class Configuration: resp = await self.handle_getvalues(msg) + elif msg.operation == "getvalues-all-ws": + + resp = await self.handle_getvalues_all_ws(msg) + elif msg.operation == "delete": resp = await self.handle_delete(msg) diff --git a/trustgraph-flow/trustgraph/config/service/service.py b/trustgraph-flow/trustgraph/config/service/service.py index fe44b852..56a54ee0 100644 --- a/trustgraph-flow/trustgraph/config/service/service.py +++ b/trustgraph-flow/trustgraph/config/service/service.py @@ -128,18 +128,21 @@ class Processor(AsyncProcessor): await self.push() # Startup poke: empty types = everything await self.config_request_consumer.start() - async def push(self, types=None): + async def push(self, changes=None): version = await self.config.get_version() resp = ConfigPush( version = version, - types = types or [], + changes = changes or {}, ) await self.config_push_producer.send(resp) - logger.info(f"Pushed config poke version {version}, types={resp.types}") + logger.info( + f"Pushed config poke version {version}, " + f"changes={resp.changes}" + ) async def on_config_request(self, msg, consumer, flow): diff --git a/trustgraph-flow/trustgraph/cores/knowledge.py b/trustgraph-flow/trustgraph/cores/knowledge.py index d03d4ed6..ab5f78f0 100644 --- a/trustgraph-flow/trustgraph/cores/knowledge.py +++ b/trustgraph-flow/trustgraph/cores/knowledge.py @@ -33,7 +33,7 @@ class KnowledgeManager: logger.info("Deleting knowledge core...") await self.table_store.delete_kg_core( - request.user, request.id + request.workspace, request.id ) await respond( @@ -63,7 +63,7 @@ class KnowledgeManager: # Remove doc table row await self.table_store.get_triples( - request.user, + request.workspace, request.id, publish_triples, ) @@ -81,7 +81,7 @@ class KnowledgeManager: # Remove doc table row await self.table_store.get_graph_embeddings( - request.user, + request.workspace, request.id, publish_ge, ) @@ -100,7 +100,7 @@ class KnowledgeManager: async def list_kg_cores(self, request, respond): - ids = await self.table_store.list_kg_cores(request.user) + ids = await self.table_store.list_kg_cores(request.workspace) await respond( KnowledgeResponse( @@ -114,12 +114,14 @@ class KnowledgeManager: async def put_kg_core(self, request, respond): + workspace = request.workspace + if request.triples: - await self.table_store.add_triples(request.triples) + await self.table_store.add_triples(workspace, request.triples) if request.graph_embeddings: await self.table_store.add_graph_embeddings( - request.graph_embeddings + workspace, request.graph_embeddings ) await respond( @@ -178,10 +180,15 @@ class KnowledgeManager: if request.flow is None: raise RuntimeError("Flow ID must be specified") - if request.flow not in self.flow_config.flows: - raise RuntimeError("Invalid flow") + workspace = request.workspace + ws_flows = self.flow_config.flows.get(workspace, {}) + if request.flow not in ws_flows: + raise RuntimeError( + f"Invalid flow {request.flow} for workspace " + f"{workspace}" + ) - flow = self.flow_config.flows[request.flow] + flow = ws_flows[request.flow] if "interfaces" not in flow: raise RuntimeError("No defined interfaces") @@ -257,7 +264,7 @@ class KnowledgeManager: # Remove doc table row await self.table_store.get_triples( - request.user, + request.workspace, request.id, publish_triples, ) @@ -272,7 +279,7 @@ class KnowledgeManager: # Remove doc table row await self.table_store.get_graph_embeddings( - request.user, + request.workspace, request.id, publish_ge, ) diff --git a/trustgraph-flow/trustgraph/cores/service.py b/trustgraph-flow/trustgraph/cores/service.py index 93017c30..15e8feb6 100755 --- a/trustgraph-flow/trustgraph/cores/service.py +++ b/trustgraph-flow/trustgraph/cores/service.py @@ -124,19 +124,21 @@ class Processor(AsyncProcessor): await self.knowledge_request_consumer.start() await self.knowledge_response_producer.start() - async def on_knowledge_config(self, config, version): + async def on_knowledge_config(self, workspace, config, version): - logger.info(f"Configuration version: {version}") + logger.info( + f"Configuration version: {version} workspace: {workspace}" + ) if "flow" in config: - self.flows = { + self.flows[workspace] = { k: json.loads(v) for k, v in config["flow"].items() } else: - self.flows = {} + self.flows[workspace] = {} - logger.debug(f"Flows: {self.flows}") + logger.debug(f"Flows for {workspace}: {self.flows[workspace]}") async def process_request(self, v, id): diff --git a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py index 40b8c566..3436ca51 100755 --- a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py +++ b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py @@ -200,7 +200,7 @@ class Processor(FlowProcessor): if v.document_id: doc_meta = await self.librarian.fetch_document_metadata( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf": logger.error( @@ -215,7 +215,7 @@ class Processor(FlowProcessor): logger.info(f"Fetching document {v.document_id} from librarian...") content = await self.librarian.fetch_document_content( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) if isinstance(content, str): content = content.encode('utf-8') @@ -243,7 +243,7 @@ class Processor(FlowProcessor): await self.librarian.save_child_document( doc_id=page_doc_id, parent_id=source_doc_id, - user=v.metadata.user, + workspace=flow.workspace, content=page_content, document_type="page", title=f"Page {page_num}", @@ -265,7 +265,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=pg_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), triples=set_graph(prov_triples, GRAPH_SOURCE), @@ -277,7 +276,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=pg_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), document_id=page_doc_id, diff --git a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py index 7f9ca71d..f3eb3881 100755 --- a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py +++ b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py @@ -93,7 +93,7 @@ class Processor(FlowProcessor): if v.document_id: doc_meta = await self.librarian.fetch_document_metadata( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf": logger.error( @@ -114,7 +114,7 @@ class Processor(FlowProcessor): content = await self.librarian.fetch_document_content( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) # Content is base64 encoded @@ -157,7 +157,7 @@ class Processor(FlowProcessor): await self.librarian.save_child_document( doc_id=page_doc_id, parent_id=source_doc_id, - user=v.metadata.user, + workspace=flow.workspace, content=page_content, document_type="page", title=f"Page {page_num}", @@ -179,7 +179,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=pg_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), triples=set_graph(prov_triples, GRAPH_SOURCE), @@ -191,7 +190,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=pg_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), document_id=page_doc_id, diff --git a/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py b/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py index 66bfe31f..4564fe1f 100644 --- a/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py +++ b/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py @@ -6,9 +6,9 @@ import re logger = logging.getLogger(__name__) -def make_safe_collection_name(user, collection, prefix): +def make_safe_collection_name(workspace, collection, prefix): """ - Create a safe Milvus collection name from user/collection parameters. + Create a safe Milvus collection name from workspace/collection parameters. Milvus only allows letters, numbers, and underscores. """ def sanitize(s): @@ -23,10 +23,10 @@ def make_safe_collection_name(user, collection, prefix): safe = 'default' return safe - safe_user = sanitize(user) + safe_workspace = sanitize(workspace) safe_collection = sanitize(collection) - return f"{prefix}_{safe_user}_{safe_collection}" + return f"{prefix}_{safe_workspace}_{safe_collection}" class DocVectors: @@ -49,26 +49,26 @@ class DocVectors: self.next_reload = time.time() + self.reload_time logger.debug(f"Reload at {self.next_reload}") - def collection_exists(self, user, collection): + def collection_exists(self, workspace, collection): """ - Check if any collection exists for this user/collection combination. + Check if any collection exists for this workspace/collection combination. Since collections are dimension-specific, this checks if ANY dimension variant exists. """ - base_name = make_safe_collection_name(user, collection, self.prefix) + base_name = make_safe_collection_name(workspace, collection, self.prefix) prefix = f"{base_name}_" all_collections = self.client.list_collections() return any(coll.startswith(prefix) for coll in all_collections) - def create_collection(self, user, collection, dimension=384): + def create_collection(self, workspace, collection, dimension=384): """ No-op for explicit collection creation. Collections are created lazily on first insert with actual dimension. """ - logger.info(f"Collection creation requested for {user}/{collection} - will be created lazily on first insert") + logger.info(f"Collection creation requested for {workspace}/{collection} - will be created lazily on first insert") - def init_collection(self, dimension, user, collection): + def init_collection(self, dimension, workspace, collection): - base_name = make_safe_collection_name(user, collection, self.prefix) + base_name = make_safe_collection_name(workspace, collection, self.prefix) collection_name = f"{base_name}_{dimension}" pkey_field = FieldSchema( @@ -116,15 +116,15 @@ class DocVectors: index_params=index_params ) - self.collections[(dimension, user, collection)] = collection_name + self.collections[(dimension, workspace, collection)] = collection_name logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}") - def insert(self, embeds, chunk_id, user, collection): + def insert(self, embeds, chunk_id, workspace, collection): dim = len(embeds) - if (dim, user, collection) not in self.collections: - self.init_collection(dim, user, collection) + if (dim, workspace, collection) not in self.collections: + self.init_collection(dim, workspace, collection) data = [ { @@ -134,25 +134,25 @@ class DocVectors: ] self.client.insert( - collection_name=self.collections[(dim, user, collection)], + collection_name=self.collections[(dim, workspace, collection)], data=data ) - def search(self, embeds, user, collection, fields=["chunk_id"], limit=10): + def search(self, embeds, workspace, collection, fields=["chunk_id"], limit=10): dim = len(embeds) # Check if collection exists - return empty if not - if (dim, user, collection) not in self.collections: - base_name = make_safe_collection_name(user, collection, self.prefix) + if (dim, workspace, collection) not in self.collections: + base_name = make_safe_collection_name(workspace, collection, self.prefix) collection_name = f"{base_name}_{dim}" if not self.client.has_collection(collection_name): logger.info(f"Collection {collection_name} does not exist, returning empty results") return [] # Collection exists but not in cache, add it - self.collections[(dim, user, collection)] = collection_name + self.collections[(dim, workspace, collection)] = collection_name - coll = self.collections[(dim, user, collection)] + coll = self.collections[(dim, workspace, collection)] logger.debug("Loading...") self.client.load_collection( @@ -181,12 +181,12 @@ class DocVectors: return res - def delete_collection(self, user, collection): + def delete_collection(self, workspace, collection): """ - Delete all dimension variants of the collection for the given user/collection. + Delete all dimension variants of the collection for the given workspace/collection. Since collections are created with dimension suffixes, we need to find and delete all. """ - base_name = make_safe_collection_name(user, collection, self.prefix) + base_name = make_safe_collection_name(workspace, collection, self.prefix) prefix = f"{base_name}_" # Get all collections and filter for matches @@ -199,10 +199,10 @@ class DocVectors: for collection_name in matching_collections: self.client.drop_collection(collection_name) logger.info(f"Deleted Milvus collection: {collection_name}") - logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}") + logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}") # Remove from our local cache - keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection] + keys_to_remove = [key for key in self.collections.keys() if key[1] == workspace and key[2] == collection] for key in keys_to_remove: del self.collections[key] diff --git a/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py b/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py index dcbf6734..7d5a640b 100644 --- a/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py +++ b/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py @@ -6,9 +6,9 @@ import re logger = logging.getLogger(__name__) -def make_safe_collection_name(user, collection, prefix): +def make_safe_collection_name(workspace, collection, prefix): """ - Create a safe Milvus collection name from user/collection parameters. + Create a safe Milvus collection name from workspace/collection parameters. Milvus only allows letters, numbers, and underscores. """ def sanitize(s): @@ -23,10 +23,10 @@ def make_safe_collection_name(user, collection, prefix): safe = 'default' return safe - safe_user = sanitize(user) + safe_workspace = sanitize(workspace) safe_collection = sanitize(collection) - return f"{prefix}_{safe_user}_{safe_collection}" + return f"{prefix}_{safe_workspace}_{safe_collection}" class EntityVectors: @@ -49,26 +49,26 @@ class EntityVectors: self.next_reload = time.time() + self.reload_time logger.debug(f"Reload at {self.next_reload}") - def collection_exists(self, user, collection): + def collection_exists(self, workspace, collection): """ - Check if any collection exists for this user/collection combination. + Check if any collection exists for this workspace/collection combination. Since collections are dimension-specific, this checks if ANY dimension variant exists. """ - base_name = make_safe_collection_name(user, collection, self.prefix) + base_name = make_safe_collection_name(workspace, collection, self.prefix) prefix = f"{base_name}_" all_collections = self.client.list_collections() return any(coll.startswith(prefix) for coll in all_collections) - def create_collection(self, user, collection, dimension=384): + def create_collection(self, workspace, collection, dimension=384): """ No-op for explicit collection creation. Collections are created lazily on first insert with actual dimension. """ - logger.info(f"Collection creation requested for {user}/{collection} - will be created lazily on first insert") + logger.info(f"Collection creation requested for {workspace}/{collection} - will be created lazily on first insert") - def init_collection(self, dimension, user, collection): + def init_collection(self, dimension, workspace, collection): - base_name = make_safe_collection_name(user, collection, self.prefix) + base_name = make_safe_collection_name(workspace, collection, self.prefix) collection_name = f"{base_name}_{dimension}" pkey_field = FieldSchema( @@ -122,15 +122,15 @@ class EntityVectors: index_params=index_params ) - self.collections[(dimension, user, collection)] = collection_name + self.collections[(dimension, workspace, collection)] = collection_name logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}") - def insert(self, embeds, entity, user, collection, chunk_id=""): + def insert(self, embeds, entity, workspace, collection, chunk_id=""): dim = len(embeds) - if (dim, user, collection) not in self.collections: - self.init_collection(dim, user, collection) + if (dim, workspace, collection) not in self.collections: + self.init_collection(dim, workspace, collection) data = [ { @@ -141,25 +141,25 @@ class EntityVectors: ] self.client.insert( - collection_name=self.collections[(dim, user, collection)], + collection_name=self.collections[(dim, workspace, collection)], data=data ) - def search(self, embeds, user, collection, fields=["entity"], limit=10): + def search(self, embeds, workspace, collection, fields=["entity"], limit=10): dim = len(embeds) # Check if collection exists - return empty if not - if (dim, user, collection) not in self.collections: - base_name = make_safe_collection_name(user, collection, self.prefix) + if (dim, workspace, collection) not in self.collections: + base_name = make_safe_collection_name(workspace, collection, self.prefix) collection_name = f"{base_name}_{dim}" if not self.client.has_collection(collection_name): logger.info(f"Collection {collection_name} does not exist, returning empty results") return [] # Collection exists but not in cache, add it - self.collections[(dim, user, collection)] = collection_name + self.collections[(dim, workspace, collection)] = collection_name - coll = self.collections[(dim, user, collection)] + coll = self.collections[(dim, workspace, collection)] logger.debug("Loading...") self.client.load_collection( @@ -188,12 +188,12 @@ class EntityVectors: return res - def delete_collection(self, user, collection): + def delete_collection(self, workspace, collection): """ - Delete all dimension variants of the collection for the given user/collection. + Delete all dimension variants of the collection for the given workspace/collection. Since collections are created with dimension suffixes, we need to find and delete all. """ - base_name = make_safe_collection_name(user, collection, self.prefix) + base_name = make_safe_collection_name(workspace, collection, self.prefix) prefix = f"{base_name}_" # Get all collections and filter for matches @@ -206,10 +206,10 @@ class EntityVectors: for collection_name in matching_collections: self.client.drop_collection(collection_name) logger.info(f"Deleted Milvus collection: {collection_name}") - logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}") + logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}") # Remove from our local cache - keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection] + keys_to_remove = [key for key in self.collections.keys() if key[1] == workspace and key[2] == collection] for key in keys_to_remove: del self.collections[key] diff --git a/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py index 362bdec9..12f4cdc6 100644 --- a/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py +++ b/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py @@ -69,19 +69,26 @@ class Processor(CollectionConfigHandler, FlowProcessor): self.register_config_handler(self.on_schema_config, types=["schema"]) self.register_config_handler(self.on_collection_config, types=["collection"]) - # Schema storage: name -> RowSchema - self.schemas: Dict[str, RowSchema] = {} + # Per-workspace schema storage: {workspace: {name: RowSchema}} + self.schemas: Dict[str, Dict[str, RowSchema]] = {} - async def on_schema_config(self, config, version): + async def on_schema_config(self, workspace, config, version): """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") + logger.info( + f"Loading schema configuration version {version} " + f"for workspace {workspace}" + ) - # Clear existing schemas - self.schemas = {} + # Replace existing schemas for this workspace + ws_schemas: Dict[str, RowSchema] = {} + self.schemas[workspace] = ws_schemas # Check if our config type exists if self.config_key not in config: - logger.warning(f"No '{self.config_key}' type in configuration") + logger.warning( + f"No '{self.config_key}' type in configuration " + f"for {workspace}" + ) return # Get the schemas dictionary for our type @@ -115,13 +122,19 @@ class Processor(CollectionConfigHandler, FlowProcessor): fields=fields ) - self.schemas[schema_name] = row_schema - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") + ws_schemas[schema_name] = row_schema + logger.info( + f"Loaded schema: {schema_name} with " + f"{len(fields)} fields for {workspace}" + ) except Exception as e: logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + logger.info( + f"Schema configuration loaded for {workspace}: " + f"{len(ws_schemas)} schemas" + ) def get_index_names(self, schema: RowSchema) -> List[str]: """Get all index names for a schema.""" @@ -149,23 +162,29 @@ class Processor(CollectionConfigHandler, FlowProcessor): """Process incoming ExtractedObject and compute embeddings""" obj = msg.value() + workspace = flow.workspace logger.info( f"Computing embeddings for {len(obj.values)} rows, " - f"schema {obj.schema_name}, doc {obj.metadata.id}" + f"schema {obj.schema_name}, doc {obj.metadata.id}, " + f"workspace {workspace}" ) # Validate collection exists before processing - if not self.collection_exists(obj.metadata.user, obj.metadata.collection): + if not self.collection_exists(workspace, obj.metadata.collection): logger.warning( - f"Collection {obj.metadata.collection} for user {obj.metadata.user} " + f"Collection {obj.metadata.collection} for workspace {workspace} " f"does not exist in config. Dropping message." ) return - # Get schema definition - schema = self.schemas.get(obj.schema_name) + # Get schema definition for this workspace + ws_schemas = self.schemas.get(workspace, {}) + schema = ws_schemas.get(obj.schema_name) if not schema: - logger.warning(f"No schema found for {obj.schema_name} - skipping") + logger.warning( + f"No schema found for {obj.schema_name} in " + f"workspace {workspace} - skipping" + ) return # Get all index names for this schema @@ -239,13 +258,13 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.error("Exception during embedding computation", exc_info=True) raise e - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """Collection creation notification - no action needed for embedding stage""" - logger.debug(f"Row embeddings collection notification for {user}/{collection}") + logger.debug(f"Row embeddings collection notification for {workspace}/{collection}") - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Collection deletion notification - no action needed for embedding stage""" - logger.debug(f"Row embeddings collection delete notification for {user}/{collection}") + logger.debug(f"Row embeddings collection delete notification for {workspace}/{collection}") @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/extract/kg/agent/extract.py b/trustgraph-flow/trustgraph/extract/kg/agent/extract.py index ce8d6aae..285b956c 100644 --- a/trustgraph-flow/trustgraph/extract/kg/agent/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/agent/extract.py @@ -75,24 +75,36 @@ class Processor(FlowProcessor): ) ) - # Null configuration, should reload quickly - self.manager = PromptManager() + # Per-workspace prompt managers + self.managers = {} - async def on_prompt_config(self, config, version): + async def on_prompt_config(self, workspace, config, version): - logger.info(f"Loading configuration version {version}") + logger.info( + f"Loading configuration version {version} " + f"for workspace {workspace}" + ) if self.config_key not in config: - logger.warning(f"No key {self.config_key} in config") + logger.warning( + f"No key {self.config_key} in config for {workspace}" + ) return - config = config[self.config_key] + prompt_config = config[self.config_key] try: - self.manager.load_config(config) + manager = self.managers.get(workspace) + if manager is None: + manager = PromptManager() + self.managers[workspace] = manager - logger.info("Prompt configuration reloaded") + manager.load_config(prompt_config) + + logger.info( + f"Prompt configuration reloaded for {workspace}" + ) except Exception as e: @@ -107,7 +119,6 @@ class Processor(FlowProcessor): metadata = Metadata( id = metadata.id, root = metadata.root, - user = metadata.user, collection = metadata.collection, ), triples = triples, @@ -120,7 +131,6 @@ class Processor(FlowProcessor): metadata = Metadata( id = metadata.id, root = metadata.root, - user = metadata.user, collection = metadata.collection, ), entities = entity_contexts, @@ -170,13 +180,24 @@ class Processor(FlowProcessor): try: v = msg.value() + workspace = flow.workspace # Extract chunk text chunk_text = v.chunk.decode('utf-8') - logger.debug("Processing chunk for agent extraction") + logger.debug( + f"Processing chunk for agent extraction, " + f"workspace {workspace}" + ) - prompt = self.manager.render( + manager = self.managers.get(workspace) + if manager is None: + logger.error( + f"No prompt configuration for workspace {workspace}" + ) + return + + prompt = manager.render( self.template_id, { "text": chunk_text diff --git a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py index 9b5bbb79..31f45ae9 100755 --- a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py @@ -213,7 +213,6 @@ class Processor(FlowProcessor): Metadata( id=v.metadata.id, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), batch @@ -227,7 +226,6 @@ class Processor(FlowProcessor): Metadata( id=v.metadata.id, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), batch diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py index e024ad40..a05f4dfe 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py @@ -109,20 +109,22 @@ class Processor(FlowProcessor): # Register config handler for ontology updates self.register_config_handler(self.on_ontology_config, types=["ontology"]) - # Shared components (not flow-specific) - self.ontology_loader = OntologyLoader() + # Per-workspace ontology loaders + self.ontology_loaders = {} # workspace -> OntologyLoader self.text_processor = TextProcessor() - # Per-flow components (each flow gets its own embedder/vector store/selector) - self.flow_components = {} # flow_id -> {embedder, vector_store, selector} + # Per-flow components (each flow gets its own embedder/vector + # store/selector). Keyed by id(flow) — Flow objects are unique + # per (workspace, flow), so this is implicitly workspace-scoped. + self.flow_components = {} # Configuration self.top_k = params.get("top_k", 10) self.similarity_threshold = params.get("similarity_threshold", 0.3) - # Track loaded ontology version - self.current_ontology_version = None - self.loaded_ontology_ids = set() + # Per-workspace ontology version tracking + self.current_ontology_versions = {} # workspace -> version + self.loaded_ontology_ids = {} # workspace -> set of ids async def initialize_flow_components(self, flow): """Initialize per-flow OntoRAG components. @@ -167,17 +169,23 @@ class Processor(FlowProcessor): vector_store=vector_store ) - # Embed all loaded ontologies for this flow - if self.ontology_loader.get_all_ontologies(): - logger.info(f"Embedding ontologies for flow {flow_id}") - for ont_id, ontology in self.ontology_loader.get_all_ontologies().items(): + workspace = flow.workspace + + # Embed all loaded ontologies for this workspace + loader = self.ontology_loaders.get(workspace) + if loader is not None and loader.get_all_ontologies(): + logger.info( + f"Embedding ontologies for flow {flow_id} " + f"(workspace {workspace})" + ) + for ont_id, ontology in loader.get_all_ontologies().items(): await ontology_embedder.embed_ontology(ontology) logger.info(f"Embedded {ontology_embedder.get_embedded_count()} ontology elements for flow {flow_id}") # Initialize ontology selector ontology_selector = OntologySelector( ontology_embedder=ontology_embedder, - ontology_loader=self.ontology_loader, + ontology_loader=loader, top_k=self.top_k, similarity_threshold=self.similarity_threshold ) @@ -187,7 +195,8 @@ class Processor(FlowProcessor): 'embedder': ontology_embedder, 'vector_store': vector_store, 'selector': ontology_selector, - 'dimension': dimension + 'dimension': dimension, + 'workspace': workspace, } logger.info(f"Flow {flow_id} components initialized successfully (dimension={dimension})") @@ -197,31 +206,27 @@ class Processor(FlowProcessor): logger.error(f"Failed to initialize flow {flow_id} components: {e}", exc_info=True) raise - async def on_ontology_config(self, config, version): - """ - Handle ontology configuration updates from ConfigPush queue. - - Parses and stores ontologies. Embedding happens per-flow on first message. - - Called automatically when: - - Processor starts (gets full config history via start_of_messages=True) - - Config service pushes updates (immediate event-driven notification) - - Args: - config: Full configuration map - config[type][key] = value - version: Config version number (monotonically increasing) - """ + async def on_ontology_config(self, workspace, config, version): + """Handle ontology configuration updates for a workspace.""" try: - logger.info(f"Received ontology config update, version={version}") + logger.info( + f"Received ontology config update, " + f"version={version} workspace={workspace}" + ) - # Skip if we've already processed this version - if version == self.current_ontology_version: - logger.debug(f"Already at version {version}, skipping") + # Skip if we've already processed this version for this workspace + if version == self.current_ontology_versions.get(workspace): + logger.debug( + f"Already at version {version} for {workspace}, " + f"skipping" + ) return # Extract ontology configurations if "ontology" not in config: - logger.warning("No 'ontology' section in config") + logger.warning( + f"No 'ontology' section in config for {workspace}" + ) return ontology_configs = config["ontology"] @@ -235,38 +240,65 @@ class Processor(FlowProcessor): logger.error(f"Failed to parse ontology '{ont_id}': {e}") continue - logger.info(f"Loaded {len(ontologies)} ontology definitions") + logger.info( + f"Loaded {len(ontologies)} ontology definitions " + f"for {workspace}" + ) - # Determine what changed (for incremental updates) + # Determine what changed for this workspace + ws_loaded_ids = self.loaded_ontology_ids.get(workspace, set()) new_ids = set(ontologies.keys()) - added_ids = new_ids - self.loaded_ontology_ids - removed_ids = self.loaded_ontology_ids - new_ids - updated_ids = new_ids & self.loaded_ontology_ids # May have changed content + added_ids = new_ids - ws_loaded_ids + removed_ids = ws_loaded_ids - new_ids + updated_ids = new_ids & ws_loaded_ids # May have changed content if added_ids: - logger.info(f"New ontologies: {added_ids}") + logger.info(f"New ontologies in {workspace}: {added_ids}") if removed_ids: - logger.info(f"Removed ontologies: {removed_ids}") + logger.info(f"Removed ontologies in {workspace}: {removed_ids}") if updated_ids: - logger.info(f"Updated ontologies: {updated_ids}") + logger.info(f"Updated ontologies in {workspace}: {updated_ids}") - # Update ontology loader's internal state - self.ontology_loader.update_ontologies(ontologies) + # Get or create per-workspace loader + loader = self.ontology_loaders.get(workspace) + if loader is None: + loader = OntologyLoader() + self.ontology_loaders[workspace] = loader + loader.update_ontologies(ontologies) - # Clear all flow components to force re-embedding with new ontologies + # Clear flow components for this workspace to force + # re-embedding with new ontologies. if added_ids or removed_ids or updated_ids: - logger.info("Clearing flow components to trigger re-embedding") - self.flow_components.clear() + self._clear_workspace_flow_components(workspace) # Update tracking - self.current_ontology_version = version - self.loaded_ontology_ids = new_ids + self.current_ontology_versions[workspace] = version + self.loaded_ontology_ids[workspace] = new_ids - logger.info(f"Ontology config update complete, version={version}") + logger.info( + f"Ontology config update complete for {workspace}, " + f"version={version}" + ) except Exception as e: logger.error(f"Failed to process ontology config: {e}", exc_info=True) + def _clear_workspace_flow_components(self, workspace): + """Drop cached flow components belonging to the given workspace + so they're re-initialised on next message with fresh ontology + embeddings.""" + to_remove = [ + fid for fid, comp in self.flow_components.items() + if comp.get("workspace") == workspace + ] + if to_remove: + logger.info( + f"Clearing {len(to_remove)} flow components for " + f"workspace {workspace}" + ) + for fid in to_remove: + del self.flow_components[fid] + async def on_message(self, msg, consumer, flow): """Process incoming chunk message.""" v = msg.value() @@ -624,7 +656,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=metadata.id, root=metadata.root, - user=metadata.user, collection=metadata.collection, ), triples=triples, @@ -637,7 +668,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=metadata.id, root=metadata.root, - user=metadata.user, collection=metadata.collection, ), entities=entities, diff --git a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py index 8068a23d..ee3e2ed2 100755 --- a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py @@ -207,7 +207,6 @@ class Processor(FlowProcessor): Metadata( id=v.metadata.id, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), batch diff --git a/trustgraph-flow/trustgraph/extract/kg/rows/processor.py b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py index 973bb3d7..f1dd4fe0 100644 --- a/trustgraph-flow/trustgraph/extract/kg/rows/processor.py +++ b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py @@ -84,32 +84,39 @@ class Processor(FlowProcessor): # Register config handler for schema updates self.register_config_handler(self.on_schema_config, types=["schema"]) - # Schema storage: name -> RowSchema - self.schemas: Dict[str, RowSchema] = {} + # Per-workspace schema storage: {workspace: {name: RowSchema}} + self.schemas: Dict[str, Dict[str, RowSchema]] = {} - async def on_schema_config(self, config, version): + async def on_schema_config(self, workspace, config, version): """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") + logger.info( + f"Loading schema configuration version {version} " + f"for workspace {workspace}" + ) - # Clear existing schemas - self.schemas = {} + # Replace existing schemas for this workspace + ws_schemas: Dict[str, RowSchema] = {} + self.schemas[workspace] = ws_schemas # Check if our config type exists if self.config_key not in config: - logger.warning(f"No '{self.config_key}' type in configuration") + logger.warning( + f"No '{self.config_key}' type in configuration " + f"for {workspace}" + ) return # Get the schemas dictionary for our type schemas_config = config[self.config_key] - + # Process each schema in the schemas config for schema_name, schema_json in schemas_config.items(): - + try: # Parse the JSON schema definition schema_def = json.loads(schema_json) - + # Create Field objects fields = [] for field_def in schema_def.get("fields", []): @@ -124,21 +131,27 @@ class Processor(FlowProcessor): indexed=field_def.get("indexed", False) ) fields.append(field) - + # Create RowSchema row_schema = RowSchema( name=schema_def.get("name", schema_name), description=schema_def.get("description", ""), fields=fields ) - - self.schemas[schema_name] = row_schema - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") - + + ws_schemas[schema_name] = row_schema + logger.info( + f"Loaded schema: {schema_name} with " + f"{len(fields)} fields for {workspace}" + ) + except Exception as e: logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + logger.info( + f"Schema configuration loaded for {workspace}: " + f"{len(ws_schemas)} schemas" + ) async def extract_objects_for_schema(self, text: str, schema_name: str, schema: RowSchema, flow) -> List[Dict[str, Any]]: """Extract objects from text for a specific schema""" @@ -234,18 +247,26 @@ class Processor(FlowProcessor): """Process incoming chunk and extract objects""" v = msg.value() - logger.info(f"Extracting objects from chunk {v.metadata.id}...") + workspace = flow.workspace + logger.info( + f"Extracting objects from chunk {v.metadata.id} " + f"(workspace {workspace})..." + ) chunk_text = v.chunk.decode("utf-8") - # If no schemas configured, log warning and return - if not self.schemas: - logger.warning("No schemas configured - skipping extraction") + # If no schemas configured for this workspace, log and return + ws_schemas = self.schemas.get(workspace, {}) + if not ws_schemas: + logger.warning( + f"No schemas configured for workspace {workspace} " + f"- skipping extraction" + ) return try: # Extract objects for each configured schema - for schema_name, schema in self.schemas.items(): + for schema_name, schema in ws_schemas.items(): logger.debug(f"Extracting {schema_name} objects from chunk") @@ -274,7 +295,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=f"{v.metadata.id}:{schema_name}", root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), schema_name=schema_name, diff --git a/trustgraph-flow/trustgraph/flow/service/flow.py b/trustgraph-flow/trustgraph/flow/service/flow.py index a5e4a7e1..ed0158f6 100644 --- a/trustgraph-flow/trustgraph/flow/service/flow.py +++ b/trustgraph-flow/trustgraph/flow/service/flow.py @@ -17,14 +17,18 @@ class FlowConfig: self.config = config self.pubsub = pubsub - # Cache for parameter type definitions to avoid repeated lookups + # Per-workspace cache for parameter type definitions + # Keyed by (workspace, type-name) self.param_type_cache = {} - async def resolve_parameters(self, flow_blueprint, user_params): + async def resolve_parameters( + self, workspace, flow_blueprint, user_params + ): """ Resolve parameters by merging user-provided values with defaults. Args: + workspace: Workspace containing the parameter-type definitions flow_blueprint: The flow blueprint definition dict user_params: User-provided parameters dict (may be None or empty) @@ -55,24 +59,25 @@ class FlowConfig: # Look up the parameter type definition param_type = param_meta.get("type") if param_type: + cache_key = (workspace, param_type) # Check cache first - if param_type not in self.param_type_cache: + if cache_key not in self.param_type_cache: try: # Fetch parameter type definition from config store type_def = await self.config.get( - "parameter-type", param_type + workspace, "parameter-type", param_type ) if type_def: - self.param_type_cache[param_type] = json.loads(type_def) + self.param_type_cache[cache_key] = json.loads(type_def) else: logger.warning(f"Parameter type '{param_type}' not found in config") - self.param_type_cache[param_type] = {} + self.param_type_cache[cache_key] = {} except Exception as e: logger.error(f"Error fetching parameter type '{param_type}': {e}") - self.param_type_cache[param_type] = {} + self.param_type_cache[cache_key] = {} # Apply default from type definition (as string) - type_def = self.param_type_cache[param_type] + type_def = self.param_type_cache[cache_key] if "default" in type_def: default_value = type_def["default"] # Convert to string based on type @@ -94,8 +99,9 @@ class FlowConfig: else: # Controller has no value, try to get default from type definition param_type = param_meta.get("type") - if param_type and param_type in self.param_type_cache: - type_def = self.param_type_cache[param_type] + cache_key = (workspace, param_type) if param_type else None + if cache_key and cache_key in self.param_type_cache: + type_def = self.param_type_cache[cache_key] if "default" in type_def: default_value = type_def["default"] # Convert to string based on type @@ -114,7 +120,9 @@ class FlowConfig: async def handle_list_blueprints(self, msg): - names = list(await self.config.keys("flow-blueprint")) + names = list(await self.config.keys( + msg.workspace, "flow-blueprint" + )) return FlowResponse( error = None, @@ -126,14 +134,14 @@ class FlowConfig: return FlowResponse( error = None, blueprint_definition = await self.config.get( - "flow-blueprint", msg.blueprint_name + msg.workspace, "flow-blueprint", msg.blueprint_name ), ) async def handle_put_blueprint(self, msg): await self.config.put( - "flow-blueprint", + msg.workspace, "flow-blueprint", msg.blueprint_name, msg.blueprint_definition ) @@ -145,7 +153,9 @@ class FlowConfig: logger.debug(f"Flow config message: {msg}") - await self.config.delete("flow-blueprint", msg.blueprint_name) + await self.config.delete( + msg.workspace, "flow-blueprint", msg.blueprint_name + ) return FlowResponse( error = None, @@ -153,7 +163,7 @@ class FlowConfig: async def handle_list_flows(self, msg): - names = list(await self.config.keys("flow")) + names = list(await self.config.keys(msg.workspace, "flow")) return FlowResponse( error = None, @@ -162,7 +172,9 @@ class FlowConfig: async def handle_get_flow(self, msg): - flow_data = await self.config.get("flow", msg.flow_id) + flow_data = await self.config.get( + msg.workspace, "flow", msg.flow_id + ) flow = json.loads(flow_data) return FlowResponse( @@ -174,37 +186,49 @@ class FlowConfig: async def handle_start_flow(self, msg): + workspace = msg.workspace + if msg.blueprint_name is None: raise RuntimeError("No blueprint name") if msg.flow_id is None: raise RuntimeError("No flow ID") - if msg.flow_id in await self.config.keys("flow"): + if msg.flow_id in await self.config.keys(workspace, "flow"): raise RuntimeError("Flow already exists") if msg.description is None: raise RuntimeError("No description") - if msg.blueprint_name not in await self.config.keys("flow-blueprint"): + if msg.blueprint_name not in await self.config.keys( + workspace, "flow-blueprint" + ): raise RuntimeError("Blueprint does not exist") cls = json.loads( - await self.config.get("flow-blueprint", msg.blueprint_name) + await self.config.get( + workspace, "flow-blueprint", msg.blueprint_name + ) ) # Resolve parameters by merging user-provided values with defaults user_params = msg.parameters if msg.parameters else {} - parameters = await self.resolve_parameters(cls, user_params) + parameters = await self.resolve_parameters( + workspace, cls, user_params + ) # Log the resolved parameters for debugging logger.debug(f"User provided parameters: {user_params}") logger.debug(f"Resolved parameters (with defaults): {parameters}") - # Apply parameter substitution to template replacement function + # Apply parameter substitution to template replacement function. + # {workspace} is substituted from msg.workspace to isolate + # queue names across workspaces. def repl_template_with_params(tmp): result = tmp.replace( + "{workspace}", workspace + ).replace( "{blueprint}", msg.blueprint_name ).replace( "{id}", msg.flow_id @@ -253,7 +277,7 @@ class FlowConfig: json.dumps(entry), )) - await self.config.put_many(updates) + await self.config.put_many(workspace, updates) def repl_interface(i): return { @@ -270,7 +294,7 @@ class FlowConfig: interfaces = {} await self.config.put( - "flow", msg.flow_id, + workspace, "flow", msg.flow_id, json.dumps({ "description": msg.description, "blueprint-name": msg.blueprint_name, @@ -283,68 +307,77 @@ class FlowConfig: error = None, ) - async def ensure_existing_flow_topics(self): - """Ensure topics exist for all already-running flows. + async def ensure_existing_flow_topics(self, workspaces): + """Ensure topics exist for all already-running flows across + the given workspaces. Called on startup to handle flows that were started before this version of the flow service was deployed, or before a restart. """ - flow_ids = await self.config.keys("flow") + for workspace in workspaces: + flow_ids = await self.config.keys(workspace, "flow") - for flow_id in flow_ids: - try: - flow_data = await self.config.get("flow", flow_id) - if flow_data is None: - continue - - flow = json.loads(flow_data) - - blueprint_name = flow.get("blueprint-name") - if blueprint_name is None: - continue - - # Skip flows that are mid-shutdown - if flow.get("status") == "stopping": - continue - - parameters = flow.get("parameters", {}) - - blueprint_data = await self.config.get( - "flow-blueprint", blueprint_name - ) - if blueprint_data is None: - logger.warning( - f"Blueprint '{blueprint_name}' not found for " - f"flow '{flow_id}', skipping topic creation" + for flow_id in flow_ids: + try: + flow_data = await self.config.get( + workspace, "flow", flow_id ) - continue + if flow_data is None: + continue - cls = json.loads(blueprint_data) + flow = json.loads(flow_data) - def repl_template(tmp): - result = tmp.replace( - "{blueprint}", blueprint_name - ).replace( - "{id}", flow_id + blueprint_name = flow.get("blueprint-name") + if blueprint_name is None: + continue + + # Skip flows that are mid-shutdown + if flow.get("status") == "stopping": + continue + + parameters = flow.get("parameters", {}) + + blueprint_data = await self.config.get( + workspace, "flow-blueprint", blueprint_name ) - for param_name, param_value in parameters.items(): - result = result.replace( - f"{{{param_name}}}", str(param_value) + if blueprint_data is None: + logger.warning( + f"Blueprint '{blueprint_name}' not found " + f"for flow '{workspace}/{flow_id}', skipping " + f"topic creation" ) - return result + continue - topics = self._collect_flow_topics(cls, repl_template) - for topic in topics: - await self.pubsub.ensure_topic(topic) + cls = json.loads(blueprint_data) - logger.info( - f"Ensured topics for existing flow '{flow_id}'" - ) + def repl_template(tmp): + result = tmp.replace( + "{workspace}", workspace + ).replace( + "{blueprint}", blueprint_name + ).replace( + "{id}", flow_id + ) + for param_name, param_value in parameters.items(): + result = result.replace( + f"{{{param_name}}}", str(param_value) + ) + return result - except Exception as e: - logger.error( - f"Failed to ensure topics for flow '{flow_id}': {e}" - ) + topics = self._collect_flow_topics(cls, repl_template) + for topic in topics: + await self.pubsub.ensure_topic(topic) + + logger.info( + f"Ensured topics for existing flow " + f"'{workspace}/{flow_id}'" + ) + + except Exception as e: + logger.error( + f"Failed to ensure topics for flow " + f"'{workspace}/{flow_id}': {e}" + ) def _collect_flow_topics(self, cls, repl_template): """Collect unique topic identifiers from the blueprint. @@ -393,79 +426,95 @@ class FlowConfig: return topics - async def _live_owned_topic_closure(self, exclude_flow_id=None): - """Union of flow-owned topics referenced by all live flows. + async def _live_owned_topic_closure( + self, exclude_workspace=None, exclude_flow_id=None, + ): + """Union of flow-owned topics referenced by all live flows, + across every workspace. Walks every flow record currently registered in the config - service (except ``exclude_flow_id``, typically the flow being - torn down), resolves its blueprint + parameter templates, and - collects the set of flow-owned topics those templates produce. + service (except the single ``(exclude_workspace, exclude_flow_id)`` + pair — typically the flow being torn down), resolves its + blueprint + parameter templates, and collects the set of + flow-owned topics those templates produce. Used to drive closure-based topic cleanup on flow stop: a - topic may only be deleted if no remaining live flow would - still template to it. This handles all three scoping cases - transparently — ``{id}`` topics have no other references once - their flow is excluded; ``{blueprint}`` topics stay alive - while another flow of the same blueprint exists; ``{workspace}`` - (when introduced) stays alive while any flow in the workspace - exists. + topic may only be deleted if no remaining live flow (in any + workspace) would still template to it. This handles all + scoping cases transparently — ``{id}`` topics have no other + references once their flow is excluded; ``{blueprint}`` topics + stay alive while another flow of the same blueprint exists; + ``{workspace}`` topics stay alive while any flow in the same + workspace remains. """ live = set() - flow_ids = await self.config.keys("flow") + workspaces = await self.config.workspaces_for_type("flow") - for fid in flow_ids: + for ws in workspaces: - if fid == exclude_flow_id: - continue + flow_ids = await self.config.keys(ws, "flow") - try: - frec_raw = await self.config.get("flow", fid) - if frec_raw is None: + for fid in flow_ids: + + if ws == exclude_workspace and fid == exclude_flow_id: continue - frec = json.loads(frec_raw) - except Exception as e: - logger.warning( - f"Closure sweep: skipping flow {fid}: {e}" - ) - continue - # Flows mid-shutdown don't keep their topics alive. - if frec.get("status") == "stopping": - continue - - bp_name = frec.get("blueprint-name") - if bp_name is None: - continue - - try: - bp_raw = await self.config.get("flow-blueprint", bp_name) - if bp_raw is None: - continue - bp = json.loads(bp_raw) - except Exception as e: - logger.warning( - f"Closure sweep: skipping flow {fid} " - f"(blueprint {bp_name}): {e}" - ) - continue - - parameters = frec.get("parameters", {}) - - def repl(tmp, bp_name=bp_name, fid=fid, parameters=parameters): - result = tmp.replace( - "{blueprint}", bp_name - ).replace( - "{id}", fid - ) - for pname, pvalue in parameters.items(): - result = result.replace( - f"{{{pname}}}", str(pvalue) + try: + frec_raw = await self.config.get(ws, "flow", fid) + if frec_raw is None: + continue + frec = json.loads(frec_raw) + except Exception as e: + logger.warning( + f"Closure sweep: skipping flow {ws}/{fid}: {e}" ) - return result + continue - live.update(self._collect_owned_topics(bp, repl)) + # Flows mid-shutdown don't keep their topics alive. + if frec.get("status") == "stopping": + continue + + bp_name = frec.get("blueprint-name") + if bp_name is None: + continue + + try: + bp_raw = await self.config.get( + ws, "flow-blueprint", bp_name + ) + if bp_raw is None: + continue + bp = json.loads(bp_raw) + except Exception as e: + logger.warning( + f"Closure sweep: skipping flow {ws}/{fid} " + f"(blueprint {bp_name}): {e}" + ) + continue + + parameters = frec.get("parameters", {}) + + def repl( + tmp, + ws=ws, bp_name=bp_name, fid=fid, + parameters=parameters, + ): + result = tmp.replace( + "{workspace}", ws + ).replace( + "{blueprint}", bp_name + ).replace( + "{id}", fid + ) + for pname, pvalue in parameters.items(): + result = result.replace( + f"{{{pname}}}", str(pvalue) + ) + return result + + live.update(self._collect_owned_topics(bp, repl)) return live @@ -501,13 +550,17 @@ class FlowConfig: async def handle_stop_flow(self, msg): + workspace = msg.workspace + if msg.flow_id is None: raise RuntimeError("No flow ID") - if msg.flow_id not in await self.config.keys("flow"): + if msg.flow_id not in await self.config.keys(workspace, "flow"): raise RuntimeError("Flow ID invalid") - flow = json.loads(await self.config.get("flow", msg.flow_id)) + flow = json.loads( + await self.config.get(workspace, "flow", msg.flow_id) + ) if "blueprint-name" not in flow: raise RuntimeError("Internal error: flow has no flow blueprint") @@ -516,11 +569,15 @@ class FlowConfig: parameters = flow.get("parameters", {}) cls = json.loads( - await self.config.get("flow-blueprint", blueprint_name) + await self.config.get( + workspace, "flow-blueprint", blueprint_name + ) ) def repl_template(tmp): result = tmp.replace( + "{workspace}", workspace + ).replace( "{blueprint}", blueprint_name ).replace( "{id}", msg.flow_id @@ -539,7 +596,7 @@ class FlowConfig: # The config push tells processors to shut down their consumers. flow["status"] = "stopping" await self.config.put( - "flow", msg.flow_id, json.dumps(flow) + workspace, "flow", msg.flow_id, json.dumps(flow) ) # Delete all processor config entries for this flow. @@ -552,7 +609,7 @@ class FlowConfig: deletes.append((f"processor:{processor}", variant)) - await self.config.delete_many(deletes) + await self.config.delete_many(workspace, deletes) # Phase 2: Closure-based sweep. Only delete topics that no # other live flow still references via its blueprint templates. @@ -560,6 +617,7 @@ class FlowConfig: # of the same blueprint is still running, and {workspace}-scoped # topics while any flow in that workspace remains. live_owned = await self._live_owned_topic_closure( + exclude_workspace=workspace, exclude_flow_id=msg.flow_id, ) @@ -571,13 +629,13 @@ class FlowConfig: kept = this_flow_owned - to_delete if kept: logger.info( - f"Flow {msg.flow_id}: keeping {len(kept)} topics " - f"still referenced by other live flows" + f"Flow {workspace}/{msg.flow_id}: keeping {len(kept)} " + f"topics still referenced by other live flows" ) # Phase 3: Remove the flow record. - if msg.flow_id in await self.config.keys("flow"): - await self.config.delete("flow", msg.flow_id) + if msg.flow_id in await self.config.keys(workspace, "flow"): + await self.config.delete(workspace, "flow", msg.flow_id) return FlowResponse( error = None, @@ -585,7 +643,18 @@ class FlowConfig: async def handle(self, msg): - logger.debug(f"Handling flow message: {msg.operation}") + logger.debug( + f"Handling flow message: {msg.operation} " + f"workspace={msg.workspace}" + ) + + if not msg.workspace: + return FlowResponse( + error=Error( + type="bad-request", + message="Workspace is required", + ), + ) if msg.operation == "list-blueprints": resp = await self.handle_list_blueprints(msg) diff --git a/trustgraph-flow/trustgraph/flow/service/service.py b/trustgraph-flow/trustgraph/flow/service/service.py index e1997452..74077ccb 100644 --- a/trustgraph-flow/trustgraph/flow/service/service.py +++ b/trustgraph-flow/trustgraph/flow/service/service.py @@ -103,7 +103,12 @@ class Processor(AsyncProcessor): await self.pubsub.ensure_topic(self.flow_request_topic) await self.config_client.start() - await self.flow.ensure_existing_flow_topics() + + # Discover workspaces with existing flow config and ensure + # their topics exist before we start accepting requests. + workspaces = await self.config_client.workspaces_for_type("flow") + await self.flow.ensure_existing_flow_topics(workspaces) + await self.flow_request_consumer.start() async def on_flow_request(self, msg, consumer, flow): diff --git a/trustgraph-flow/trustgraph/gateway/config/receiver.py b/trustgraph-flow/trustgraph/gateway/config/receiver.py index c721a46a..5bc781a9 100755 --- a/trustgraph-flow/trustgraph/gateway/config/receiver.py +++ b/trustgraph-flow/trustgraph/gateway/config/receiver.py @@ -30,6 +30,7 @@ class ConfigReceiver: self.flow_handlers = [] + # Per-workspace flow tracking: {workspace: {flow_id: flow_def}} self.flows = {} self.config_version = 0 @@ -43,7 +44,7 @@ class ConfigReceiver: v = msg.value() notify_version = v.version - notify_types = set(v.types) + changes = v.changes # Skip if we already have this version or newer if notify_version <= self.config_version: @@ -53,20 +54,27 @@ class ConfigReceiver: ) return - # Gateway cares about flow config - if notify_types and "flow" not in notify_types: + # Gateway cares about flow config — check if any flow + # types changed in any workspace + flow_workspaces = changes.get("flow", []) + if changes and not flow_workspaces: logger.debug( f"Ignoring config notify v{notify_version}, " - f"no flow types in {notify_types}" + f"no flow changes" ) self.config_version = notify_version return logger.info( - f"Config notify v{notify_version}, fetching config..." + f"Config notify v{notify_version} " + f"types={list(changes.keys())}, fetching config..." ) - await self.fetch_and_apply() + # Refresh config for each affected workspace + for workspace in flow_workspaces: + await self.fetch_and_apply_workspace(workspace) + + self.config_version = notify_version except Exception as e: logger.error( @@ -98,20 +106,25 @@ class ConfigReceiver: response_metrics=config_resp_metrics, ) - async def fetch_and_apply(self, retry=False): - """Fetch full config and apply flow changes. + async def fetch_and_apply_workspace(self, workspace, retry=False): + """Fetch config for a single workspace and apply flow changes. If retry=True, keeps retrying until successful.""" while True: try: - logger.info("Fetching config from config service...") + logger.info( + f"Fetching config for workspace {workspace}..." + ) client = self._create_config_client() try: await client.start() resp = await client.request( - ConfigRequest(operation="config"), + ConfigRequest( + operation="config", + workspace=workspace, + ), timeout=10, ) finally: @@ -137,18 +150,22 @@ class ConfigReceiver: flows = config.get("flow", {}) + ws_flows = self.flows.get(workspace, {}) + wanted = list(flows.keys()) - current = list(self.flows.keys()) + current = list(ws_flows.keys()) for k in wanted: if k not in current: - self.flows[k] = json.loads(flows[k]) - await self.start_flow(k, self.flows[k]) + ws_flows[k] = json.loads(flows[k]) + await self.start_flow(workspace, k, ws_flows[k]) for k in current: if k not in wanted: - await self.stop_flow(k, self.flows[k]) - del self.flows[k] + await self.stop_flow(workspace, k, ws_flows[k]) + del ws_flows[k] + + self.flows[workspace] = ws_flows return @@ -164,27 +181,91 @@ class ConfigReceiver: ) return - async def start_flow(self, id, flow): + async def fetch_all_workspaces(self, retry=False): + """Fetch config for all workspaces at startup. + Discovers workspaces via the config service getvalues-all-ws + operation on the flow type.""" - logger.info(f"Starting flow: {id}") + while True: + + try: + logger.info("Discovering workspaces with flows...") + + client = self._create_config_client() + try: + await client.start() + + # Discover workspaces that have any flow config + resp = await client.request( + ConfigRequest( + operation="getvalues-all-ws", + type="flow", + ), + timeout=10, + ) + + if resp.error: + raise RuntimeError( + f"Config error: {resp.error.message}" + ) + + workspaces = { + v.workspace for v in resp.values if v.workspace + } + + # Always include the default workspace, even if + # empty, so that newly-created flows in it can be + # picked up by subsequent notifications. + workspaces.add("default") + + logger.info( + f"Found workspaces with flows: {workspaces}" + ) + + finally: + await client.stop() + + # Fetch and apply config for each workspace + for workspace in workspaces: + await self.fetch_and_apply_workspace( + workspace, retry=retry + ) + + return + + except Exception as e: + if retry: + logger.warning( + f"Workspace fetch failed: {e}, retrying in 2s..." + ) + await asyncio.sleep(2) + continue + logger.error( + f"Workspace fetch exception: {e}", exc_info=True + ) + return + + async def start_flow(self, workspace, id, flow): + + logger.info(f"Starting flow: {workspace}/{id}") for handler in self.flow_handlers: try: - await handler.start_flow(id, flow) + await handler.start_flow(workspace, id, flow) except Exception as e: logger.error( f"Config processing exception: {e}", exc_info=True ) - async def stop_flow(self, id, flow): + async def stop_flow(self, workspace, id, flow): - logger.info(f"Stopping flow: {id}") + logger.info(f"Stopping flow: {workspace}/{id}") for handler in self.flow_handlers: try: - await handler.stop_flow(id, flow) + await handler.stop_flow(workspace, id, flow) except Exception as e: logger.error( f"Config processing exception: {e}", exc_info=True @@ -218,7 +299,7 @@ class ConfigReceiver: # Fetch current config (subscribe-then-fetch pattern) # Retry until config service is available - await self.fetch_and_apply(retry=True) + await self.fetch_all_workspaces(retry=True) logger.info( "Config loader initialised, waiting for notifys..." diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py index 3a37c4e3..6696afbe 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py @@ -16,7 +16,7 @@ class CoreExport: async def process(self, data, error, ok, request): id = request.query["id"] - user = request.query["user"] + workspace = request.query.get("workspace", "default") response = await ok() @@ -41,7 +41,6 @@ class CoreExport: { "m": { "i": data["metadata"]["id"], - "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, "e": [ @@ -65,7 +64,6 @@ class CoreExport: { "m": { "i": data["metadata"]["id"], - "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, "t": data["triples"], @@ -78,7 +76,7 @@ class CoreExport: await kr.process( { "operation": "get-kg-core", - "user": user, + "workspace": workspace, "id": id, }, responder diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py index 0ca07319..d03d4efd 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py @@ -17,7 +17,7 @@ class CoreImport: async def process(self, data, error, ok, request): id = request.query["id"] - user = request.query["user"] + workspace = request.query.get("workspace", "default") kr = KnowledgeRequestor( backend = self.backend, @@ -43,12 +43,11 @@ class CoreImport: msg = unpacked[1] msg = { "operation": "put-kg-core", - "user": user, + "workspace": workspace, "id": id, "triples": { "metadata": { "id": id, - "user": user, "collection": "default", # Not used? }, "triples": msg["t"], @@ -61,12 +60,11 @@ class CoreImport: msg = unpacked[1] msg = { "operation": "put-kg-core", - "user": user, + "workspace": workspace, "id": id, "graph-embeddings": { "metadata": { "id": id, - "user": user, "collection": "default", # Not used? }, "entities": [ diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py index e70bf6de..2992d99f 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py @@ -14,12 +14,12 @@ class DocumentStreamExport: async def process(self, data, error, ok, request): - user = request.query.get("user") + workspace = request.query.get("workspace", "default") document_id = request.query.get("document-id") chunk_size = int(request.query.get("chunk-size", 1024 * 1024)) - if not user or not document_id: - return await error("Missing required parameters: user, document-id") + if not document_id: + return await error("Missing required parameter: document-id") response = await ok() @@ -45,7 +45,7 @@ class DocumentStreamExport: await lr.process( { "operation": "stream-document", - "user": user, + "workspace": workspace, "document-id": document_id, "chunk-size": chunk_size, }, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py index de0fe52d..91e47aaf 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py @@ -48,7 +48,6 @@ class EntityContextsImport: elt = EntityContexts( metadata=Metadata( id=data["metadata"]["id"], - user=data["metadata"]["user"], collection=data["metadata"]["collection"], ), entities=[ diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py index 7c7dc915..3e246335 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py @@ -48,7 +48,6 @@ class GraphEmbeddingsImport: elt = GraphEmbeddings( metadata=Metadata( id=data["metadata"]["id"], - user=data["metadata"]["user"], collection=data["metadata"]["collection"], ), entities=[ diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index 592120b1..f3db3290 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -116,18 +116,20 @@ class DispatcherManager: # Format: {"config": {"request": "...", "response": "..."}, ...} self.queue_overrides = queue_overrides or {} + # Flows keyed by (workspace, flow_id) self.flows = {} + # Dispatchers keyed by (workspace, flow_id, kind) self.dispatchers = {} self.dispatcher_lock = asyncio.Lock() - async def start_flow(self, id, flow): - logger.info(f"Starting flow {id}") - self.flows[id] = flow + async def start_flow(self, workspace, id, flow): + logger.info(f"Starting flow {workspace}/{id}") + self.flows[(workspace, id)] = flow return - async def stop_flow(self, id, flow): - logger.info(f"Stopping flow {id}") - del self.flows[id] + async def stop_flow(self, workspace, id, flow): + logger.info(f"Stopping flow {workspace}/{id}") + del self.flows[(workspace, id)] return def dispatch_global_service(self): @@ -203,18 +205,20 @@ class DispatcherManager: async def process_flow_import(self, ws, running, params): + workspace = params.get("workspace", "default") flow = params.get("flow") kind = params.get("kind") - if flow not in self.flows: - raise RuntimeError("Invalid flow") + flow_key = (workspace, flow) + if flow_key not in self.flows: + raise RuntimeError(f"Invalid flow {workspace}/{flow}") if kind not in import_dispatchers: raise RuntimeError("Invalid kind") - key = (flow, kind) + key = (workspace, flow, kind) - intf_defs = self.flows[flow]["interfaces"] + intf_defs = self.flows[flow_key]["interfaces"] # FIXME: The -store bit, does it make sense? if kind == "entity-contexts": @@ -242,18 +246,20 @@ class DispatcherManager: async def process_flow_export(self, ws, running, params): + workspace = params.get("workspace", "default") flow = params.get("flow") kind = params.get("kind") - if flow not in self.flows: - raise RuntimeError("Invalid flow") + flow_key = (workspace, flow) + if flow_key not in self.flows: + raise RuntimeError(f"Invalid flow {workspace}/{flow}") if kind not in export_dispatchers: raise RuntimeError("Invalid kind") - key = (flow, kind) + key = (workspace, flow, kind) - intf_defs = self.flows[flow]["interfaces"] + intf_defs = self.flows[flow_key]["interfaces"] # FIXME: The -store bit, does it make sense? if kind == "entity-contexts": @@ -286,22 +292,36 @@ class DispatcherManager: async def process_flow_service(self, data, responder, params): + # Workspace can come from URL or from request body, defaulting + # to "default". Having it in the URL allows gateway routing to + # be workspace-aware without touching the body. + workspace = params.get("workspace") + if not workspace and isinstance(data, dict): + workspace = data.get("workspace") + if not workspace: + workspace = "default" + flow = params.get("flow") kind = params.get("kind") - return await self.invoke_flow_service(data, responder, flow, kind) + return await self.invoke_flow_service( + data, responder, workspace, flow, kind, + ) - async def invoke_flow_service(self, data, responder, flow, kind): + async def invoke_flow_service( + self, data, responder, workspace, flow, kind, + ): - if flow not in self.flows: - raise RuntimeError("Invalid flow") + flow_key = (workspace, flow) + if flow_key not in self.flows: + raise RuntimeError(f"Invalid flow {workspace}/{flow}") - key = (flow, kind) + key = (workspace, flow, kind) if key not in self.dispatchers: async with self.dispatcher_lock: if key not in self.dispatchers: - intf_defs = self.flows[flow]["interfaces"] + intf_defs = self.flows[flow_key]["interfaces"] if kind not in intf_defs: raise RuntimeError("This kind not supported by flow") @@ -314,8 +334,8 @@ class DispatcherManager: request_queue = qconfig["request"], response_queue = qconfig["response"], timeout = 120, - consumer = f"{self.prefix}-{flow}-{kind}-request", - subscriber = f"{self.prefix}-{flow}-{kind}-request", + consumer = f"{self.prefix}-{workspace}-{flow}-{kind}-request", + subscriber = f"{self.prefix}-{workspace}-{flow}-{kind}-request", ) elif kind in sender_dispatchers: dispatcher = sender_dispatchers[kind]( diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py index fabd5c44..3d610dca 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py @@ -47,7 +47,9 @@ class Mux: raise RuntimeError("Bad message") await self.q.put(( - data["id"], data.get("flow"), + data["id"], + data.get("workspace", "default"), + data.get("flow"), data["service"], data["request"] )) @@ -87,8 +89,10 @@ class Mux: # worker[0] still running, move on break - async def start_request_task(self, ws, id, flow, svc, request, workers): - + async def start_request_task( + self, ws, id, workspace, flow, svc, request, workers, + ): + # Wait for outstanding requests to go below MAX_OUTSTANDING_REQUESTS while len(workers) > MAX_OUTSTANDING_REQUESTS: @@ -106,19 +110,23 @@ class Mux: }) worker = asyncio.create_task( - self.request_task(id, request, responder, flow, svc) + self.request_task( + id, request, responder, workspace, flow, svc, + ) ) workers.append(worker) - async def request_task(self, id, request, responder, flow, svc): + async def request_task( + self, id, request, responder, workspace, flow, svc, + ): try: if flow: await self.dispatcher_manager.invoke_flow_service( - request, responder, flow, svc + request, responder, workspace, flow, svc, ) else: @@ -148,7 +156,7 @@ class Mux: # Get next request on queue item = await asyncio.wait_for(self.q.get(), 1) - id, flow, svc, request = item + id, workspace, flow, svc, request = item except TimeoutError: continue @@ -172,7 +180,7 @@ class Mux: try: await self.start_request_task( - self.ws, id, flow, svc, request, workers + self.ws, id, workspace, flow, svc, request, workers ) except Exception as e: diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py index ad634cab..8f92fa59 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py @@ -53,7 +53,6 @@ class RowsImport: elt = ExtractedObject( metadata=Metadata( id=data["metadata"]["id"], - user=data["metadata"]["user"], collection=data["metadata"]["collection"], ), schema_name=data["schema_name"], diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py b/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py index 7267e320..28b0ded5 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py @@ -38,7 +38,6 @@ def serialize_triples(message): "metadata": { "id": message.metadata.id, "root": message.metadata.root, - "user": message.metadata.user, "collection": message.metadata.collection, }, "triples": serialize_subgraph(message.triples), @@ -50,7 +49,6 @@ def serialize_graph_embeddings(message): "metadata": { "id": message.metadata.id, "root": message.metadata.root, - "user": message.metadata.user, "collection": message.metadata.collection, }, "entities": [ @@ -68,7 +66,6 @@ def serialize_entity_contexts(message): "metadata": { "id": message.metadata.id, "root": message.metadata.root, - "user": message.metadata.user, "collection": message.metadata.collection, }, "entities": [ @@ -86,7 +83,6 @@ def serialize_document_embeddings(message): "metadata": { "id": message.metadata.id, "root": message.metadata.root, - "user": message.metadata.user, "collection": message.metadata.collection, }, "chunks": [ @@ -120,8 +116,8 @@ def serialize_document_metadata(message): if message.metadata: ret["metadata"] = serialize_subgraph(message.metadata) - if message.user: - ret["user"] = message.user + if message.workspace: + ret["workspace"] = message.workspace if message.tags is not None: ret["tags"] = message.tags @@ -144,8 +140,8 @@ def serialize_processing_metadata(message): if message.flow: ret["flow"] = message.flow - if message.user: - ret["user"] = message.user + if message.workspace: + ret["workspace"] = message.workspace if message.collection: ret["collection"] = message.collection @@ -164,7 +160,7 @@ def to_document_metadata(x): title = x.get("title", None), comments = x.get("comments", None), metadata = to_subgraph(x["metadata"]), - user = x.get("user", None), + workspace = x.get("workspace", None), tags = x.get("tags", None), ) @@ -175,7 +171,7 @@ def to_processing_metadata(x): document_id = x.get("document-id", None), time = x.get("time", None), flow = x.get("flow", None), - user = x.get("user", None), + workspace = x.get("workspace", None), collection = x.get("collection", None), tags = x.get("tags", None), ) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py index 37f123fa..358faa8d 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py @@ -49,7 +49,6 @@ class TriplesImport: metadata=Metadata( id=data["metadata"]["id"], root=data["metadata"].get("root", ""), - user=data["metadata"]["user"], collection=data["metadata"]["collection"], ), triples=to_subgraph(data["triples"]), diff --git a/trustgraph-flow/trustgraph/librarian/collection_manager.py b/trustgraph-flow/trustgraph/librarian/collection_manager.py index 34ce1de8..09932adf 100644 --- a/trustgraph-flow/trustgraph/librarian/collection_manager.py +++ b/trustgraph-flow/trustgraph/librarian/collection_manager.py @@ -3,6 +3,7 @@ Collection management for the librarian - uses config service for storage """ import asyncio +import dataclasses import logging import json import uuid @@ -20,7 +21,6 @@ logger = logging.getLogger(__name__) def metadata_to_dict(metadata: CollectionMetadata) -> dict: """Convert CollectionMetadata to dictionary for JSON serialization""" return { - 'user': metadata.user, 'collection': metadata.collection, 'name': metadata.name, 'description': metadata.description, @@ -92,38 +92,38 @@ class CollectionManager: self.pending_config_requests[response_id + "_response"] = response self.pending_config_requests[response_id].set() - async def ensure_collection_exists(self, user: str, collection: str): + async def ensure_collection_exists(self, workspace: str, collection: str): """ Ensure a collection exists, creating it if necessary Args: - user: User ID + workspace: Workspace ID collection: Collection ID """ try: # Check if collection exists via config service request = ConfigRequest( operation='get', - keys=[ConfigKey(type='collection', key=f'{user}:{collection}')] + workspace=workspace, + keys=[ConfigKey(type='collection', key=collection)] ) response = await self.send_config_request(request) # Validate response if not response.values or len(response.values) == 0: - raise Exception(f"Invalid response from config service when checking collection {user}/{collection}") + raise Exception(f"Invalid response from config service when checking collection {workspace}/{collection}") # Check if collection exists (value not None means it exists) if response.values[0].value is not None: - logger.debug(f"Collection {user}/{collection} already exists") + logger.debug(f"Collection {workspace}/{collection} already exists") return # Collection doesn't exist (value is None), proceed to create # Create new collection with default metadata - logger.info(f"Auto-creating collection {user}/{collection}") + logger.info(f"Auto-creating collection {workspace}/{collection}") metadata = CollectionMetadata( - user=user, collection=collection, name=collection, # Default name to collection ID description="", @@ -132,9 +132,10 @@ class CollectionManager: request = ConfigRequest( operation='put', + workspace=workspace, values=[ConfigValue( type='collection', - key=f'{user}:{collection}', + key=collection, value=json.dumps(metadata_to_dict(metadata)) )] ) @@ -144,7 +145,7 @@ class CollectionManager: if response.error: raise RuntimeError(f"Config update failed: {response.error.message}") - logger.info(f"Collection {user}/{collection} auto-created in config service") + logger.info(f"Collection {workspace}/{collection} auto-created in config service") except Exception as e: logger.error(f"Error ensuring collection exists: {e}") @@ -161,9 +162,10 @@ class CollectionManager: CollectionManagementResponse with list of collections """ try: - # Get all collections from config service + # Get all collections in this workspace from config service config_request = ConfigRequest( operation='getvalues', + workspace=request.workspace, type='collection' ) @@ -172,15 +174,19 @@ class CollectionManager: if response.error: raise RuntimeError(f"Config query failed: {response.error.message}") - # Parse collections and filter by user + # Every value in this workspace is a collection. + # Filter to fields the current schema knows about — older + # persisted values may carry fields that have since been + # dropped (e.g. `user` from the pre-workspace-refactor era). + known_fields = {f.name for f in dataclasses.fields(CollectionMetadata)} collections = [] for config_value in response.values: - if ":" in config_value.key: - coll_user, coll_name = config_value.key.split(":", 1) - if coll_user == request.user: - metadata_dict = json.loads(config_value.value) - metadata = CollectionMetadata(**metadata_dict) - collections.append(metadata) + metadata_dict = json.loads(config_value.value) + metadata_dict = { + k: v for k, v in metadata_dict.items() if k in known_fields + } + metadata = CollectionMetadata(**metadata_dict) + collections.append(metadata) # Apply tag filtering if specified if request.tag_filter: @@ -221,7 +227,6 @@ class CollectionManager: tags = list(request.tags) if request.tags else [] metadata = CollectionMetadata( - user=request.user, collection=request.collection, name=name, description=description, @@ -231,9 +236,10 @@ class CollectionManager: # Send put request to config service config_request = ConfigRequest( operation='put', + workspace=request.workspace, values=[ConfigValue( type='collection', - key=f'{request.user}:{request.collection}', + key=request.collection, value=json.dumps(metadata_to_dict(metadata)) )] ) @@ -243,7 +249,7 @@ class CollectionManager: if response.error: raise RuntimeError(f"Config update failed: {response.error.message}") - logger.info(f"Collection {request.user}/{request.collection} updated in config service") + logger.info(f"Collection {request.workspace}/{request.collection} updated in config service") # Config service will trigger config push automatically # Storage services will receive update and create/update collections @@ -269,12 +275,13 @@ class CollectionManager: CollectionManagementResponse indicating success or failure """ try: - logger.info(f"Deleting collection {request.user}/{request.collection}") + logger.info(f"Deleting collection {request.workspace}/{request.collection}") # Send delete request to config service config_request = ConfigRequest( operation='delete', - keys=[ConfigKey(type='collection', key=f'{request.user}:{request.collection}')] + workspace=request.workspace, + keys=[ConfigKey(type='collection', key=request.collection)] ) response = await self.send_config_request(config_request) @@ -282,7 +289,7 @@ class CollectionManager: if response.error: raise RuntimeError(f"Config delete failed: {response.error.message}") - logger.info(f"Collection {request.user}/{request.collection} deleted from config service") + logger.info(f"Collection {request.workspace}/{request.collection} deleted from config service") # Config service will trigger config push automatically # Storage services will receive update and delete collections diff --git a/trustgraph-flow/trustgraph/librarian/librarian.py b/trustgraph-flow/trustgraph/librarian/librarian.py index 77232650..af1d69b1 100644 --- a/trustgraph-flow/trustgraph/librarian/librarian.py +++ b/trustgraph-flow/trustgraph/librarian/librarian.py @@ -48,7 +48,7 @@ class Librarian: raise RequestError("Document kind (MIME type) is required") if await self.table_store.document_exists( - request.document_metadata.user, + request.document_metadata.workspace, request.document_metadata.id ): raise RuntimeError("Document already exists") @@ -78,7 +78,7 @@ class Librarian: logger.debug("Removing document...") if not await self.table_store.document_exists( - request.user, + request.workspace, request.document_id, ): raise RuntimeError("Document does not exist") @@ -89,17 +89,17 @@ class Librarian: logger.debug(f"Cascade deleting child document {child.id}") try: child_object_id = await self.table_store.get_document_object_id( - child.user, + child.workspace, child.id ) await self.blob_store.remove(child_object_id) - await self.table_store.remove_document(child.user, child.id) + await self.table_store.remove_document(child.workspace, child.id) except Exception as e: logger.warning(f"Failed to delete child document {child.id}: {e}") # Now remove the parent document object_id = await self.table_store.get_document_object_id( - request.user, + request.workspace, request.document_id ) @@ -108,7 +108,7 @@ class Librarian: # Remove doc table row await self.table_store.remove_document( - request.user, + request.workspace, request.document_id ) @@ -120,10 +120,10 @@ class Librarian: logger.debug("Updating document...") - # You can't update the document ID, user or kind. + # You can't update the document ID, workspace or kind. if not await self.table_store.document_exists( - request.document_metadata.user, + request.document_metadata.workspace, request.document_metadata.id ): raise RuntimeError("Document does not exist") @@ -139,7 +139,7 @@ class Librarian: logger.debug("Getting document metadata...") doc = await self.table_store.get_document( - request.user, + request.workspace, request.document_id ) @@ -156,7 +156,7 @@ class Librarian: logger.debug("Getting document content...") object_id = await self.table_store.get_document_object_id( - request.user, + request.workspace, request.document_id ) @@ -180,18 +180,18 @@ class Librarian: raise RuntimeError("Collection parameter is required") if await self.table_store.processing_exists( - request.processing_metadata.user, + request.processing_metadata.workspace, request.processing_metadata.id ): raise RuntimeError("Processing already exists") doc = await self.table_store.get_document( - request.processing_metadata.user, + request.processing_metadata.workspace, request.processing_metadata.document_id ) object_id = await self.table_store.get_document_object_id( - request.processing_metadata.user, + request.processing_metadata.workspace, request.processing_metadata.document_id ) @@ -222,14 +222,14 @@ class Librarian: logger.debug("Removing processing metadata...") if not await self.table_store.processing_exists( - request.user, + request.workspace, request.processing_id, ): raise RuntimeError("Processing object does not exist") # Remove doc table row await self.table_store.remove_processing( - request.user, + request.workspace, request.processing_id ) @@ -239,7 +239,7 @@ class Librarian: async def list_documents(self, request): - docs = await self.table_store.list_documents(request.user) + docs = await self.table_store.list_documents(request.workspace) # Filter out child documents and answer documents by default include_children = getattr(request, 'include_children', False) @@ -256,7 +256,7 @@ class Librarian: async def list_processing(self, request): - procs = await self.table_store.list_processing(request.user) + procs = await self.table_store.list_processing(request.workspace) return LibrarianResponse( processing_metadatas = procs, @@ -276,7 +276,7 @@ class Librarian: raise RequestError("Document kind (MIME type) is required") if await self.table_store.document_exists( - request.document_metadata.user, + request.document_metadata.workspace, request.document_metadata.id ): raise RequestError("Document already exists") @@ -312,14 +312,14 @@ class Librarian: "kind": request.document_metadata.kind, "title": request.document_metadata.title, "comments": request.document_metadata.comments, - "user": request.document_metadata.user, + "workspace": request.document_metadata.workspace, "tags": request.document_metadata.tags, }) # Store session in Cassandra await self.table_store.create_upload_session( upload_id=upload_id, - user=request.document_metadata.user, + workspace=request.document_metadata.workspace, document_id=request.document_metadata.id, document_metadata=doc_meta_json, s3_upload_id=s3_upload_id, @@ -352,7 +352,7 @@ class Librarian: raise RequestError("Upload session not found or expired") # Validate ownership - if session["user"] != request.user: + if session["workspace"] != request.workspace: raise RequestError("Not authorized to upload to this session") # Validate chunk index @@ -419,7 +419,7 @@ class Librarian: raise RequestError("Upload session not found or expired") # Validate ownership - if session["user"] != request.user: + if session["workspace"] != request.workspace: raise RequestError("Not authorized to complete this upload") # Verify all chunks received @@ -457,7 +457,7 @@ class Librarian: kind=doc_meta_dict["kind"], title=doc_meta_dict.get("title", ""), comments=doc_meta_dict.get("comments", ""), - user=doc_meta_dict["user"], + workspace=doc_meta_dict["workspace"], tags=doc_meta_dict.get("tags", []), metadata=[], # Triples not supported in chunked upload yet ) @@ -488,7 +488,7 @@ class Librarian: raise RequestError("Upload session not found or expired") # Validate ownership - if session["user"] != request.user: + if session["workspace"] != request.workspace: raise RequestError("Not authorized to abort this upload") # Abort S3 multipart upload @@ -520,7 +520,7 @@ class Librarian: ) # Validate ownership - if session["user"] != request.user: + if session["workspace"] != request.workspace: raise RequestError("Not authorized to view this upload") chunks_received = session["chunks_received"] @@ -548,11 +548,11 @@ class Librarian: async def list_uploads(self, request): """ - List all in-progress uploads for a user. + List all in-progress uploads for a workspace. """ - logger.debug(f"Listing uploads for user {request.user}") + logger.debug(f"Listing uploads for workspace {request.workspace}") - sessions = await self.table_store.list_upload_sessions(request.user) + sessions = await self.table_store.list_upload_sessions(request.workspace) upload_sessions = [ UploadSession( @@ -591,7 +591,7 @@ class Librarian: # Verify parent exists if not await self.table_store.document_exists( - request.document_metadata.user, + request.document_metadata.workspace, request.document_metadata.parent_id ): raise RequestError( @@ -599,7 +599,7 @@ class Librarian: ) if await self.table_store.document_exists( - request.document_metadata.user, + request.document_metadata.workspace, request.document_metadata.id ): raise RequestError("Document already exists") @@ -665,7 +665,7 @@ class Librarian: ) object_id = await self.table_store.get_document_object_id( - request.user, + request.workspace, request.document_id ) diff --git a/trustgraph-flow/trustgraph/librarian/service.py b/trustgraph-flow/trustgraph/librarian/service.py index ed005298..c24a5fe8 100755 --- a/trustgraph-flow/trustgraph/librarian/service.py +++ b/trustgraph-flow/trustgraph/librarian/service.py @@ -277,18 +277,22 @@ class Processor(AsyncProcessor): """Forward config responses to collection manager""" await self.collection_manager.on_config_response(message, consumer, flow) - async def on_librarian_config(self, config, version): + async def on_librarian_config(self, workspace, config, version): - logger.info(f"Configuration version: {version}") + logger.info( + f"Configuration version: {version} workspace: {workspace}" + ) if "flow" in config: - self.flows = { + self.flows[workspace] = { k: json.loads(v) for k, v in config["flow"].items() } + else: + self.flows[workspace] = {} - logger.debug(f"Flows: {self.flows}") + logger.debug(f"Flows for {workspace}: {self.flows[workspace]}") def __del__(self): @@ -345,7 +349,6 @@ class Processor(AsyncProcessor): metadata=Metadata( id=doc_uri, root=document.id, - user=processing.user, collection=processing.collection, ), triples=all_triples, @@ -363,10 +366,15 @@ class Processor(AsyncProcessor): logger.debug(f"Document: {document}, processing: {processing}, content length: {len(content)}") - if processing.flow not in self.flows: - raise RuntimeError("Invalid flow ID") + workspace = processing.workspace + ws_flows = self.flows.get(workspace, {}) + if processing.flow not in ws_flows: + raise RuntimeError( + f"Invalid flow ID {processing.flow} for workspace " + f"{workspace}" + ) - flow = self.flows[processing.flow] + flow = ws_flows[processing.flow] if document.kind == "text/plain": kind = "text-load" @@ -386,7 +394,6 @@ class Processor(AsyncProcessor): metadata = Metadata( id = document.id, root = document.id, - user = processing.user, collection = processing.collection ), document_id = document.id, @@ -398,7 +405,6 @@ class Processor(AsyncProcessor): metadata = Metadata( id = document.id, root = document.id, - user = processing.user, collection = processing.collection ), document_id = document.id, @@ -429,9 +435,9 @@ class Processor(AsyncProcessor): """ # Ensure collection exists when processing is added if hasattr(request, 'processing_metadata') and request.processing_metadata: - user = request.processing_metadata.user + workspace = request.processing_metadata.workspace collection = request.processing_metadata.collection - await self.collection_manager.ensure_collection_exists(user, collection) + await self.collection_manager.ensure_collection_exists(workspace, collection) # Call the original add_processing method return await self.librarian.add_processing(request) diff --git a/trustgraph-flow/trustgraph/metering/counter.py b/trustgraph-flow/trustgraph/metering/counter.py index 46460b1f..a63b60ae 100644 --- a/trustgraph-flow/trustgraph/metering/counter.py +++ b/trustgraph-flow/trustgraph/metering/counter.py @@ -50,30 +50,37 @@ class Processor(FlowProcessor): ) ) + # Per-workspace price tables self.prices = {} self.config_key = "token-cost" - # Load token costs from the config service - async def on_cost_config(self, config, version): + async def on_cost_config(self, workspace, config, version): - logger.info(f"Loading metering configuration version {version}") + logger.info( + f"Loading metering configuration version {version} " + f"for workspace {workspace}" + ) if self.config_key not in config: - logger.warning(f"No key {self.config_key} in config") + logger.warning( + f"No key {self.config_key} in config for {workspace}" + ) + self.prices[workspace] = {} return - config = config[self.config_key] + prices = config[self.config_key] - self.prices = { + self.prices[workspace] = { k: json.loads(v) - for k, v in config.items() + for k, v in prices.items() } - def get_prices(self, modelname): + def get_prices(self, workspace, modelname): - if modelname in self.prices: - model = self.prices[modelname] + ws_prices = self.prices.get(workspace, {}) + if modelname in ws_prices: + model = ws_prices[modelname] return model["input_price"], model["output_price"] return None, None # Return None if model is not found @@ -81,6 +88,8 @@ class Processor(FlowProcessor): v = msg.value() + workspace = flow.workspace + modelname = v.model or "unknown" num_in = v.in_token or 0 num_out = v.out_token or 0 @@ -89,7 +98,9 @@ class Processor(FlowProcessor): __class__.token_metric.labels(model=modelname, direction="input").inc(num_in) __class__.token_metric.labels(model=modelname, direction="output").inc(num_out) - model_input_price, model_output_price = self.get_prices(modelname) + model_input_price, model_output_price = self.get_prices( + workspace, modelname + ) if model_input_price == None: cost_per_call = f"Model Not Found in Price list" diff --git a/trustgraph-flow/trustgraph/prompt/template/service.py b/trustgraph-flow/trustgraph/prompt/template/service.py index c599ce77..5da329d3 100755 --- a/trustgraph-flow/trustgraph/prompt/template/service.py +++ b/trustgraph-flow/trustgraph/prompt/template/service.py @@ -66,24 +66,37 @@ class Processor(FlowProcessor): self.register_config_handler(self.on_prompt_config, types=["prompt"]) - # Null configuration, should reload quickly - self.manager = PromptManager() + # Per-workspace prompt managers. Populated lazily as config + # arrives for each workspace. + self.managers = {} - async def on_prompt_config(self, config, version): + async def on_prompt_config(self, workspace, config, version): - logger.info(f"Loading prompt configuration version {version}") + logger.info( + f"Loading prompt configuration version {version} " + f"for workspace {workspace}" + ) if self.config_key not in config: - logger.warning(f"No key {self.config_key} in config") + logger.warning( + f"No key {self.config_key} in config for {workspace}" + ) return - config = config[self.config_key] + prompt_config = config[self.config_key] try: - self.manager.load_config(config) + manager = self.managers.get(workspace) + if manager is None: + manager = PromptManager() + self.managers[workspace] = manager - logger.info("Prompt configuration reloaded") + manager.load_config(prompt_config) + + logger.info( + f"Prompt configuration reloaded for {workspace}" + ) except Exception as e: @@ -103,6 +116,29 @@ class Processor(FlowProcessor): # Check if streaming is requested streaming = getattr(v, 'streaming', False) + # Look up the prompt manager for this workspace. If none is + # loaded yet, the request can't be handled. + workspace = flow.workspace + manager = self.managers.get(workspace) + if manager is None: + logger.error( + f"No prompt configuration loaded for workspace {workspace}" + ) + r = PromptResponse( + error=Error( + type="no-configuration", + message=( + f"No prompt configuration for workspace " + f"{workspace}" + ), + ), + text=None, + object=None, + end_of_stream=True, + ) + await flow("response").send(r, properties={"id": id}) + return + try: logger.debug(f"Prompt terms: {v.terms}") @@ -149,7 +185,7 @@ class Processor(FlowProcessor): return "" try: - await self.manager.invoke(kind, input, llm_streaming) + await manager.invoke(kind, input, llm_streaming) except Exception as e: logger.error(f"Prompt streaming exception: {e}", exc_info=True) raise e @@ -177,7 +213,7 @@ class Processor(FlowProcessor): return None try: - resp = await self.manager.invoke(kind, input, llm) + resp = await manager.invoke(kind, input, llm) except Exception as e: logger.error(f"Prompt invocation exception: {e}", exc_info=True) raise e diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py index 98350961..0a1d8e0f 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py @@ -31,7 +31,7 @@ class Processor(DocumentEmbeddingsQueryService): self.vecstore = DocVectors(store_uri) - async def query_document_embeddings(self, msg): + async def query_document_embeddings(self, workspace, msg): try: @@ -45,7 +45,7 @@ class Processor(DocumentEmbeddingsQueryService): resp = self.vecstore.search( vec, - msg.user, + workspace, msg.collection, limit=msg.limit ) diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py index 406f979c..e1bc39fc 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py @@ -48,7 +48,7 @@ class Processor(DocumentEmbeddingsQueryService): } ) - async def query_document_embeddings(self, msg): + async def query_document_embeddings(self, workspace, msg): try: @@ -63,7 +63,7 @@ class Processor(DocumentEmbeddingsQueryService): dim = len(vec) # Use dimension suffix in index name - index_name = f"d-{msg.user}-{msg.collection}-{dim}" + index_name = f"d-{workspace}-{msg.collection}-{dim}" # Check if index exists - return empty if not if not self.pinecone.has_index(index_name): diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py index f056b1c1..1d59c835 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py @@ -65,7 +65,7 @@ class Processor(DocumentEmbeddingsQueryService): """Check if collection exists (no implicit creation)""" return self.qdrant.collection_exists(collection) - async def query_document_embeddings(self, msg): + async def query_document_embeddings(self, workspace, msg): try: @@ -75,7 +75,7 @@ class Processor(DocumentEmbeddingsQueryService): # Use dimension suffix in collection name dim = len(vec) - collection = f"d_{msg.user}_{msg.collection}_{dim}" + collection = f"d_{workspace}_{msg.collection}_{dim}" # Check if collection exists - return empty if not if not self.collection_exists(collection): diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py index 94eee387..1c5e8160 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py @@ -37,7 +37,7 @@ class Processor(GraphEmbeddingsQueryService): else: return Term(type=LITERAL, value=ent) - async def query_graph_embeddings(self, msg): + async def query_graph_embeddings(self, workspace, msg): try: @@ -51,7 +51,7 @@ class Processor(GraphEmbeddingsQueryService): resp = self.vecstore.search( vec, - msg.user, + workspace, msg.collection, limit=msg.limit * 2 ) diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py index ca443a6f..f612e3e8 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py @@ -55,7 +55,7 @@ class Processor(GraphEmbeddingsQueryService): else: return Term(type=LITERAL, value=ent) - async def query_graph_embeddings(self, msg): + async def query_graph_embeddings(self, workspace, msg): try: @@ -70,7 +70,7 @@ class Processor(GraphEmbeddingsQueryService): dim = len(vec) # Use dimension suffix in index name - index_name = f"t-{msg.user}-{msg.collection}-{dim}" + index_name = f"t-{workspace}-{msg.collection}-{dim}" # Check if index exists - return empty if not if not self.pinecone.has_index(index_name): diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py index df93ad8b..b8fb1361 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py @@ -71,7 +71,7 @@ class Processor(GraphEmbeddingsQueryService): else: return Term(type=LITERAL, value=ent) - async def query_graph_embeddings(self, msg): + async def query_graph_embeddings(self, workspace, msg): try: @@ -81,7 +81,7 @@ class Processor(GraphEmbeddingsQueryService): # Use dimension suffix in collection name dim = len(vec) - collection = f"t_{msg.user}_{msg.collection}_{dim}" + collection = f"t_{workspace}_{msg.collection}_{dim}" # Check if collection exists - return empty if not if not self.collection_exists(collection): diff --git a/trustgraph-flow/trustgraph/query/graphql/schema.py b/trustgraph-flow/trustgraph/query/graphql/schema.py index 0c97b1d9..af136cf7 100644 --- a/trustgraph-flow/trustgraph/query/graphql/schema.py +++ b/trustgraph-flow/trustgraph/query/graphql/schema.py @@ -70,7 +70,7 @@ class GraphQLSchemaBuilder: Build the GraphQL schema with the provided query callback. The query callback will be invoked when resolving queries, with: - - user: str + - workspace: str - collection: str - schema_name: str - row_schema: RowSchema @@ -228,7 +228,7 @@ class GraphQLSchemaBuilder: limit: Optional[int] = 100 ) -> List[graphql_type]: # Get context values - user = info.context["user"] + workspace = info.context["workspace"] collection = info.context["collection"] # Parse the where clause @@ -236,7 +236,7 @@ class GraphQLSchemaBuilder: # Call the query backend results = await query_callback( - user, collection, schema_name, row_schema, + workspace, collection, schema_name, row_schema, filters, limit, order_by, direction ) diff --git a/trustgraph-flow/trustgraph/query/ontology/query_explanation.py b/trustgraph-flow/trustgraph/query/ontology/query_explanation.py index bd72aedc..6cced915 100644 --- a/trustgraph-flow/trustgraph/query/ontology/query_explanation.py +++ b/trustgraph-flow/trustgraph/query/ontology/query_explanation.py @@ -167,7 +167,7 @@ class QueryExplainer: question_components, query_results, processing_metadata ) - # Generate user-friendly explanation + # Generate workspace-friendly explanation user_friendly_explanation = self._generate_user_friendly_explanation( question, question_components, ontology_subsets, final_answer ) @@ -503,7 +503,7 @@ class QueryExplainer: question_components: QuestionComponents, ontology_subsets: List[QueryOntologySubset], final_answer: str) -> str: - """Generate user-friendly explanation of the process.""" + """Generate workspace-friendly explanation of the process.""" explanation_parts = [] # Introduction diff --git a/trustgraph-flow/trustgraph/query/ontology/query_service.py b/trustgraph-flow/trustgraph/query/ontology/query_service.py index ec7884ed..c6057cc1 100644 --- a/trustgraph-flow/trustgraph/query/ontology/query_service.py +++ b/trustgraph-flow/trustgraph/query/ontology/query_service.py @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) @dataclass class QueryRequest: - """Query request from user.""" + """Query request from workspace.""" question: str context: Optional[str] = None ontology_hint: Optional[str] = None diff --git a/trustgraph-flow/trustgraph/query/ontology/question_analyzer.py b/trustgraph-flow/trustgraph/query/ontology/question_analyzer.py index 3e48ac78..de39a89c 100644 --- a/trustgraph-flow/trustgraph/query/ontology/question_analyzer.py +++ b/trustgraph-flow/trustgraph/query/ontology/question_analyzer.py @@ -1,6 +1,6 @@ """ Question analyzer for ontology-sensitive query system. -Decomposes user questions into semantic components. +Decomposes workspace questions into semantic components. """ import logging diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py index 7fc20303..dd89a8d8 100644 --- a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py @@ -1,7 +1,7 @@ """ Row embeddings query service for Qdrant. -Input is query vectors plus user/collection/schema context. +Input is query vectors plus workspace/collection/schema context. Output is matching row index information (index_name, index_value) for use in subsequent Cassandra lookups. """ @@ -70,10 +70,10 @@ class Processor(FlowProcessor): safe_name = 'r_' + safe_name return safe_name.lower() - def find_collection(self, user: str, collection: str, schema_name: str) -> Optional[str]: - """Find the Qdrant collection for a given user/collection/schema""" + def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]: + """Find the Qdrant collection for a given workspace/collection/schema""" prefix = ( - f"rows_{self.sanitize_name(user)}_" + f"rows_{self.sanitize_name(workspace)}_" f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_" ) @@ -93,22 +93,22 @@ class Processor(FlowProcessor): return None - async def query_row_embeddings(self, request: RowEmbeddingsRequest): + async def query_row_embeddings(self, workspace, request: RowEmbeddingsRequest): """Execute row embeddings query""" vec = request.vector if not vec: return [] - # Find the collection for this user/collection/schema + # Find the collection for this workspace/collection/schema qdrant_collection = self.find_collection( - request.user, request.collection, request.schema_name + workspace, request.collection, request.schema_name ) if not qdrant_collection: logger.info( f"No Qdrant collection found for " - f"{request.user}/{request.collection}/{request.schema_name}" + f"{workspace}/{request.collection}/{request.schema_name}" ) return [] @@ -163,11 +163,11 @@ class Processor(FlowProcessor): logger.debug( f"Handling row embeddings query for " - f"{request.user}/{request.collection}/{request.schema_name}..." + f"{flow.workspace}/{request.collection}/{request.schema_name}..." ) # Execute query - matches = await self.query_row_embeddings(request) + matches = await self.query_row_embeddings(flow.workspace, request) response = RowEmbeddingsResponse( error=None, diff --git a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py index 019d5610..cabdf617 100644 --- a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py @@ -87,12 +87,12 @@ class Processor(FlowProcessor): # Register config handler for schema updates self.register_config_handler(self.on_schema_config, types=["schema"]) - # Schema storage: name -> RowSchema - self.schemas: Dict[str, RowSchema] = {} + # Per-workspace schema storage: {workspace: {name: RowSchema}} + self.schemas: Dict[str, Dict[str, RowSchema]] = {} - # GraphQL schema builder and generated schema - self.schema_builder = GraphQLSchemaBuilder() - self.graphql_schema = None + # Per-workspace GraphQL schema builders and compiled schemas + self.schema_builders: Dict[str, GraphQLSchemaBuilder] = {} + self.graphql_schemas: Dict[str, Any] = {} # Cassandra session self.cluster = None @@ -133,17 +133,27 @@ class Processor(FlowProcessor): safe_name = 'r_' + safe_name return safe_name.lower() - async def on_schema_config(self, config, version): + async def on_schema_config(self, workspace, config, version): """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") + logger.info( + f"Loading schema configuration version {version} " + f"for workspace {workspace}" + ) - # Clear existing schemas - self.schemas = {} - self.schema_builder.clear() + # Replace existing schemas for this workspace + ws_schemas: Dict[str, RowSchema] = {} + self.schemas[workspace] = ws_schemas + + builder = GraphQLSchemaBuilder() + self.schema_builders[workspace] = builder # Check if our config type exists if self.config_key not in config: - logger.warning(f"No '{self.config_key}' type in configuration") + logger.warning( + f"No '{self.config_key}' type in configuration " + f"for {workspace}" + ) + self.graphql_schemas[workspace] = None return # Get the schemas dictionary for our type @@ -177,17 +187,23 @@ class Processor(FlowProcessor): fields=fields ) - self.schemas[schema_name] = row_schema - self.schema_builder.add_schema(schema_name, row_schema) - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") + ws_schemas[schema_name] = row_schema + builder.add_schema(schema_name, row_schema) + logger.info( + f"Loaded schema: {schema_name} with " + f"{len(fields)} fields for {workspace}" + ) except Exception as e: logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + logger.info( + f"Schema configuration loaded for {workspace}: " + f"{len(ws_schemas)} schemas" + ) - # Regenerate GraphQL schema - self.graphql_schema = self.schema_builder.build(self.query_cassandra) + # Regenerate GraphQL schema for this workspace + self.graphql_schemas[workspace] = builder.build(self.query_cassandra) def get_index_names(self, schema: RowSchema) -> List[str]: """Get all index names for a schema.""" @@ -222,7 +238,7 @@ class Processor(FlowProcessor): async def query_cassandra( self, - user: str, + workspace: str, collection: str, schema_name: str, row_schema: RowSchema, @@ -240,7 +256,7 @@ class Processor(FlowProcessor): # Connect if needed self.connect_cassandra() - safe_keyspace = self.sanitize_name(user) + safe_keyspace = self.sanitize_name(workspace) # Try to find an index that matches the filters index_match = self.find_matching_index(row_schema, filters) @@ -389,26 +405,30 @@ class Processor(FlowProcessor): async def execute_graphql_query( self, + workspace: str, query: str, variables: Dict[str, Any], operation_name: Optional[str], - user: str, collection: str ) -> Dict[str, Any]: - """Execute a GraphQL query""" + """Execute a GraphQL query against the workspace's schema""" - if not self.graphql_schema: - raise RuntimeError("No GraphQL schema available - no schemas loaded") + graphql_schema = self.graphql_schemas.get(workspace) + if not graphql_schema: + raise RuntimeError( + f"No GraphQL schema available for workspace {workspace} " + f"- no schemas loaded" + ) # Create context for the query context = { "processor": self, - "user": user, + "workspace": workspace, "collection": collection } # Execute the query - result = await self.graphql_schema.execute( + result = await graphql_schema.execute( query, variable_values=variables, operation_name=operation_name, @@ -454,10 +474,10 @@ class Processor(FlowProcessor): # Execute GraphQL query result = await self.execute_graphql_query( + workspace=flow.workspace, query=request.query, variables=dict(request.variables) if request.variables else {}, operation_name=request.operation_name, - user=request.user, collection=request.collection ) diff --git a/trustgraph-flow/trustgraph/query/sparql/algebra.py b/trustgraph-flow/trustgraph/query/sparql/algebra.py index eda83efb..bff9a336 100644 --- a/trustgraph-flow/trustgraph/query/sparql/algebra.py +++ b/trustgraph-flow/trustgraph/query/sparql/algebra.py @@ -30,14 +30,14 @@ class EvaluationError(Exception): pass -async def evaluate(node, triples_client, user, collection, limit=10000): +async def evaluate(node, triples_client, workspace, collection, limit=10000): """ Evaluate a SPARQL algebra node. Args: node: rdflib CompValue algebra node triples_client: TriplesClient instance for triple pattern queries - user: user/keyspace identifier + workspace: workspace/keyspace identifier collection: collection identifier limit: safety limit on results @@ -55,24 +55,24 @@ async def evaluate(node, triples_client, user, collection, limit=10000): logger.warning(f"Unsupported algebra node: {name}") return [{}] - return await handler(node, triples_client, user, collection, limit) + return await handler(node, triples_client, workspace, collection, limit) # --- Node handlers --- -async def _eval_select_query(node, tc, user, collection, limit): +async def _eval_select_query(node, tc, workspace, collection, limit): """Evaluate a SelectQuery node.""" - return await evaluate(node.p, tc, user, collection, limit) + return await evaluate(node.p, tc, workspace, collection, limit) -async def _eval_project(node, tc, user, collection, limit): +async def _eval_project(node, tc, workspace, collection, limit): """Evaluate a Project node (SELECT variable projection).""" - solutions = await evaluate(node.p, tc, user, collection, limit) + solutions = await evaluate(node.p, tc, workspace, collection, limit) variables = [str(v) for v in node.PV] return project(solutions, variables) -async def _eval_bgp(node, tc, user, collection, limit): +async def _eval_bgp(node, tc, workspace, collection, limit): """ Evaluate a Basic Graph Pattern. @@ -107,7 +107,7 @@ async def _eval_bgp(node, tc, user, collection, limit): # Query the triples store results = await _query_pattern( - tc, s_val, p_val, o_val, user, collection, limit + tc, s_val, p_val, o_val, workspace, collection, limit ) # Map results back to variable bindings, @@ -130,17 +130,17 @@ async def _eval_bgp(node, tc, user, collection, limit): return solutions[:limit] -async def _eval_join(node, tc, user, collection, limit): +async def _eval_join(node, tc, workspace, collection, limit): """Evaluate a Join node.""" - left = await evaluate(node.p1, tc, user, collection, limit) - right = await evaluate(node.p2, tc, user, collection, limit) + left = await evaluate(node.p1, tc, workspace, collection, limit) + right = await evaluate(node.p2, tc, workspace, collection, limit) return hash_join(left, right)[:limit] -async def _eval_left_join(node, tc, user, collection, limit): +async def _eval_left_join(node, tc, workspace, collection, limit): """Evaluate a LeftJoin node (OPTIONAL).""" - left_sols = await evaluate(node.p1, tc, user, collection, limit) - right_sols = await evaluate(node.p2, tc, user, collection, limit) + left_sols = await evaluate(node.p1, tc, workspace, collection, limit) + right_sols = await evaluate(node.p2, tc, workspace, collection, limit) filter_fn = None if hasattr(node, "expr") and node.expr is not None: @@ -153,16 +153,16 @@ async def _eval_left_join(node, tc, user, collection, limit): return left_join(left_sols, right_sols, filter_fn)[:limit] -async def _eval_union(node, tc, user, collection, limit): +async def _eval_union(node, tc, workspace, collection, limit): """Evaluate a Union node.""" - left = await evaluate(node.p1, tc, user, collection, limit) - right = await evaluate(node.p2, tc, user, collection, limit) + left = await evaluate(node.p1, tc, workspace, collection, limit) + right = await evaluate(node.p2, tc, workspace, collection, limit) return union(left, right)[:limit] -async def _eval_filter(node, tc, user, collection, limit): +async def _eval_filter(node, tc, workspace, collection, limit): """Evaluate a Filter node.""" - solutions = await evaluate(node.p, tc, user, collection, limit) + solutions = await evaluate(node.p, tc, workspace, collection, limit) expr = node.expr return [ sol for sol in solutions @@ -170,22 +170,22 @@ async def _eval_filter(node, tc, user, collection, limit): ] -async def _eval_distinct(node, tc, user, collection, limit): +async def _eval_distinct(node, tc, workspace, collection, limit): """Evaluate a Distinct node.""" - solutions = await evaluate(node.p, tc, user, collection, limit) + solutions = await evaluate(node.p, tc, workspace, collection, limit) return distinct(solutions) -async def _eval_reduced(node, tc, user, collection, limit): +async def _eval_reduced(node, tc, workspace, collection, limit): """Evaluate a Reduced node (like Distinct but implementation-defined).""" # Treat same as Distinct - solutions = await evaluate(node.p, tc, user, collection, limit) + solutions = await evaluate(node.p, tc, workspace, collection, limit) return distinct(solutions) -async def _eval_order_by(node, tc, user, collection, limit): +async def _eval_order_by(node, tc, workspace, collection, limit): """Evaluate an OrderBy node.""" - solutions = await evaluate(node.p, tc, user, collection, limit) + solutions = await evaluate(node.p, tc, workspace, collection, limit) key_fns = [] for cond in node.expr: @@ -206,7 +206,7 @@ async def _eval_order_by(node, tc, user, collection, limit): return order_by(solutions, key_fns) -async def _eval_slice(node, tc, user, collection, limit): +async def _eval_slice(node, tc, workspace, collection, limit): """Evaluate a Slice node (LIMIT/OFFSET).""" # Pass tighter limit downstream if possible inner_limit = limit @@ -214,13 +214,13 @@ async def _eval_slice(node, tc, user, collection, limit): offset = node.start or 0 inner_limit = min(limit, offset + node.length) - solutions = await evaluate(node.p, tc, user, collection, inner_limit) + solutions = await evaluate(node.p, tc, workspace, collection, inner_limit) return slice_solutions(solutions, node.start or 0, node.length) -async def _eval_extend(node, tc, user, collection, limit): +async def _eval_extend(node, tc, workspace, collection, limit): """Evaluate an Extend node (BIND).""" - solutions = await evaluate(node.p, tc, user, collection, limit) + solutions = await evaluate(node.p, tc, workspace, collection, limit) var_name = str(node.var) expr = node.expr @@ -246,9 +246,9 @@ async def _eval_extend(node, tc, user, collection, limit): return result -async def _eval_group(node, tc, user, collection, limit): +async def _eval_group(node, tc, workspace, collection, limit): """Evaluate a Group node (GROUP BY with aggregation).""" - solutions = await evaluate(node.p, tc, user, collection, limit) + solutions = await evaluate(node.p, tc, workspace, collection, limit) # Extract grouping expressions group_exprs = [] @@ -289,9 +289,9 @@ async def _eval_group(node, tc, user, collection, limit): return result -async def _eval_aggregate_join(node, tc, user, collection, limit): +async def _eval_aggregate_join(node, tc, workspace, collection, limit): """Evaluate an AggregateJoin (aggregation functions after GROUP BY).""" - solutions = await evaluate(node.p, tc, user, collection, limit) + solutions = await evaluate(node.p, tc, workspace, collection, limit) result = [] for sol in solutions: @@ -310,7 +310,7 @@ async def _eval_aggregate_join(node, tc, user, collection, limit): return result -async def _eval_graph(node, tc, user, collection, limit): +async def _eval_graph(node, tc, workspace, collection, limit): """Evaluate a Graph node (GRAPH clause).""" term = node.term @@ -319,16 +319,16 @@ async def _eval_graph(node, tc, user, collection, limit): # We'd need to pass graph to triples queries # For now, evaluate inner pattern normally logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired") - return await evaluate(node.p, tc, user, collection, limit) + return await evaluate(node.p, tc, workspace, collection, limit) elif isinstance(term, Variable): # GRAPH ?g { ... } — variable graph logger.info(f"GRAPH ?{term} clause - variable graph not yet wired") - return await evaluate(node.p, tc, user, collection, limit) + return await evaluate(node.p, tc, workspace, collection, limit) else: - return await evaluate(node.p, tc, user, collection, limit) + return await evaluate(node.p, tc, workspace, collection, limit) -async def _eval_values(node, tc, user, collection, limit): +async def _eval_values(node, tc, workspace, collection, limit): """Evaluate a VALUES clause (inline data).""" variables = [str(v) for v in node.var] solutions = [] @@ -343,9 +343,9 @@ async def _eval_values(node, tc, user, collection, limit): return solutions -async def _eval_to_multiset(node, tc, user, collection, limit): +async def _eval_to_multiset(node, tc, workspace, collection, limit): """Evaluate a ToMultiSet node (subquery).""" - return await evaluate(node.p, tc, user, collection, limit) + return await evaluate(node.p, tc, workspace, collection, limit) # --- Aggregate computation --- @@ -487,7 +487,7 @@ def _resolve_term(tmpl, solution): return rdflib_term_to_term(tmpl) -async def _query_pattern(tc, s, p, o, user, collection, limit): +async def _query_pattern(tc, s, p, o, workspace, collection, limit): """ Issue a streaming triple pattern query via TriplesClient. @@ -496,7 +496,7 @@ async def _query_pattern(tc, s, p, o, user, collection, limit): results = await tc.query( s=s, p=p, o=o, limit=limit, - user=user, + workspace=workspace, collection=collection, ) return results diff --git a/trustgraph-flow/trustgraph/query/sparql/service.py b/trustgraph-flow/trustgraph/query/sparql/service.py index 38488032..983cd4f6 100644 --- a/trustgraph-flow/trustgraph/query/sparql/service.py +++ b/trustgraph-flow/trustgraph/query/sparql/service.py @@ -141,7 +141,7 @@ class Processor(FlowProcessor): solutions = await evaluate( parsed.algebra, triples_client, - user=request.user or "trustgraph", + workspace=flow.workspace, collection=request.collection or "default", limit=request.limit or 10000, ) diff --git a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py index 905aaaf2..efce5968 100755 --- a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py @@ -178,34 +178,34 @@ class Processor(TriplesQueryService): self.cassandra_password = password self.table = None - def ensure_connection(self, user): + def ensure_connection(self, workspace): """Ensure we have a connection to the correct keyspace.""" - if user != self.table: + if workspace != self.table: KGClass = EntityCentricKnowledgeGraph if self.cassandra_username and self.cassandra_password: self.tg = KGClass( hosts=self.cassandra_host, - keyspace=user, + keyspace=workspace, username=self.cassandra_username, password=self.cassandra_password ) else: self.tg = KGClass( hosts=self.cassandra_host, - keyspace=user, + keyspace=workspace, ) - self.table = user + self.table = workspace - async def query_triples(self, query): + async def query_triples(self, workspace, query): try: # ensure_connection may construct a fresh # EntityCentricKnowledgeGraph which does sync schema # setup against Cassandra. Push it to a worker thread - # so the event loop doesn't block on first-use per user. - await asyncio.to_thread(self.ensure_connection, query.user) + # so the event loop doesn't block on first-use per workspace. + await asyncio.to_thread(self.ensure_connection, workspace) # Extract values from query s_val = get_term_value(query.s) @@ -359,13 +359,13 @@ class Processor(TriplesQueryService): logger.error(f"Exception querying triples: {e}", exc_info=True) raise e - async def query_triples_stream(self, query): + async def query_triples_stream(self, workspace, query): """ Streaming query - yields (batch, is_final) tuples. Uses Cassandra's paging to fetch results incrementally. """ try: - await asyncio.to_thread(self.ensure_connection, query.user) + await asyncio.to_thread(self.ensure_connection, workspace) batch_size = query.batch_size if query.batch_size > 0 else 20 limit = query.limit if query.limit > 0 else 10000 @@ -395,7 +395,7 @@ class Processor(TriplesQueryService): else: # For specific patterns, fall back to non-streaming # (these typically return small result sets anyway) - async for batch, is_final in self._fallback_stream(query, batch_size): + async for batch, is_final in self._fallback_stream(workspace, query, batch_size): yield batch, is_final return @@ -452,9 +452,9 @@ class Processor(TriplesQueryService): logger.error(f"Exception in streaming query: {e}", exc_info=True) raise e - async def _fallback_stream(self, query, batch_size): + async def _fallback_stream(self, workspace, query, batch_size): """Fallback to non-streaming query with post-hoc batching.""" - triples = await self.query_triples(query) + triples = await self.query_triples(workspace, query) for i in range(0, len(triples), batch_size): batch = triples[i:i + batch_size] diff --git a/trustgraph-flow/trustgraph/query/triples/falkordb/service.py b/trustgraph-flow/trustgraph/query/triples/falkordb/service.py index 14b24d52..9781aaaf 100755 --- a/trustgraph-flow/trustgraph/query/triples/falkordb/service.py +++ b/trustgraph-flow/trustgraph/query/triples/falkordb/service.py @@ -58,7 +58,7 @@ class Processor(TriplesQueryService): else: return Term(type=LITERAL, value=ent) - async def query_triples(self, query): + async def query_triples(self, workspace, query): try: diff --git a/trustgraph-flow/trustgraph/query/triples/memgraph/service.py b/trustgraph-flow/trustgraph/query/triples/memgraph/service.py index 37633f34..173f07dd 100755 --- a/trustgraph-flow/trustgraph/query/triples/memgraph/service.py +++ b/trustgraph-flow/trustgraph/query/triples/memgraph/service.py @@ -63,12 +63,11 @@ class Processor(TriplesQueryService): else: return Term(type=LITERAL, value=ent) - async def query_triples(self, query): + async def query_triples(self, workspace, query): try: - # Extract user and collection, use defaults if not provided - user = query.user if query.user else "default" + workspace = workspace collection = query.collection if query.collection else "default" triples = [] @@ -80,13 +79,13 @@ class Processor(TriplesQueryService): # SPO records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN $src as src " "LIMIT " + str(query.limit), src=get_term_value(query.s), rel=get_term_value(query.p), value=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -94,13 +93,13 @@ class Processor(TriplesQueryService): triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o))) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Node {uri: $uri, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) " "RETURN $src as src " "LIMIT " + str(query.limit), src=get_term_value(query.s), rel=get_term_value(query.p), uri=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -112,13 +111,13 @@ class Processor(TriplesQueryService): # SP records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN dest.value as dest " "LIMIT " + str(query.limit), src=get_term_value(query.s), rel=get_term_value(query.p), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -127,13 +126,13 @@ class Processor(TriplesQueryService): triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN dest.uri as dest " "LIMIT " + str(query.limit), src=get_term_value(query.s), rel=get_term_value(query.p), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -148,13 +147,13 @@ class Processor(TriplesQueryService): # SO records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel " "LIMIT " + str(query.limit), src=get_term_value(query.s), value=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -163,13 +162,13 @@ class Processor(TriplesQueryService): triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o))) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {uri: $uri, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel " "LIMIT " + str(query.limit), src=get_term_value(query.s), uri=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -182,13 +181,13 @@ class Processor(TriplesQueryService): # S records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel, dest.value as dest " "LIMIT " + str(query.limit), src=get_term_value(query.s), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -197,13 +196,13 @@ class Processor(TriplesQueryService): triples.append((get_term_value(query.s), data["rel"], data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel, dest.uri as dest " "LIMIT " + str(query.limit), src=get_term_value(query.s), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -221,13 +220,13 @@ class Processor(TriplesQueryService): # PO records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src " "LIMIT " + str(query.limit), uri=get_term_value(query.p), value=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -236,13 +235,13 @@ class Processor(TriplesQueryService): triples.append((data["src"], get_term_value(query.p), get_term_value(query.o))) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Node {uri: $dest, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src " "LIMIT " + str(query.limit), uri=get_term_value(query.p), dest=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -255,13 +254,13 @@ class Processor(TriplesQueryService): # P records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, dest.value as dest " "LIMIT " + str(query.limit), uri=get_term_value(query.p), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -270,13 +269,13 @@ class Processor(TriplesQueryService): triples.append((data["src"], get_term_value(query.p), data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, dest.uri as dest " "LIMIT " + str(query.limit), uri=get_term_value(query.p), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -291,13 +290,13 @@ class Processor(TriplesQueryService): # O records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel " "LIMIT " + str(query.limit), value=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -306,13 +305,13 @@ class Processor(TriplesQueryService): triples.append((data["src"], data["rel"], get_term_value(query.o))) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {uri: $uri, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel " "LIMIT " + str(query.limit), uri=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -325,12 +324,12 @@ class Processor(TriplesQueryService): # * records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.value as dest " "LIMIT " + str(query.limit), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -339,12 +338,12 @@ class Processor(TriplesQueryService): triples.append((data["src"], data["rel"], data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.uri as dest " "LIMIT " + str(query.limit), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) diff --git a/trustgraph-flow/trustgraph/query/triples/neo4j/service.py b/trustgraph-flow/trustgraph/query/triples/neo4j/service.py index 4cb1ab21..b47d49a9 100755 --- a/trustgraph-flow/trustgraph/query/triples/neo4j/service.py +++ b/trustgraph-flow/trustgraph/query/triples/neo4j/service.py @@ -63,14 +63,12 @@ class Processor(TriplesQueryService): else: return Term(type=LITERAL, value=ent) - async def query_triples(self, query): + async def query_triples(self, workspace, query): try: - # Extract user and collection, use defaults if not provided - user = query.user if query.user else "default" collection = query.collection if query.collection else "default" - + triples = [] if query.s is not None: @@ -80,13 +78,13 @@ class Processor(TriplesQueryService): # SPO records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN $src as src " "LIMIT " + str(query.limit), src=get_term_value(query.s), rel=get_term_value(query.p), value=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -94,13 +92,13 @@ class Processor(TriplesQueryService): triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o))) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Node {uri: $uri, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) " "RETURN $src as src " "LIMIT " + str(query.limit), src=get_term_value(query.s), rel=get_term_value(query.p), uri=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -112,13 +110,13 @@ class Processor(TriplesQueryService): # SP records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN dest.value as dest " "LIMIT " + str(query.limit), src=get_term_value(query.s), rel=get_term_value(query.p), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -127,13 +125,13 @@ class Processor(TriplesQueryService): triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN dest.uri as dest " "LIMIT " + str(query.limit), src=get_term_value(query.s), rel=get_term_value(query.p), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -148,13 +146,13 @@ class Processor(TriplesQueryService): # SO records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel " "LIMIT " + str(query.limit), src=get_term_value(query.s), value=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -163,13 +161,13 @@ class Processor(TriplesQueryService): triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o))) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {uri: $uri, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel " "LIMIT " + str(query.limit), src=get_term_value(query.s), uri=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -182,13 +180,13 @@ class Processor(TriplesQueryService): # S records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel, dest.value as dest " "LIMIT " + str(query.limit), src=get_term_value(query.s), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -197,13 +195,13 @@ class Processor(TriplesQueryService): triples.append((get_term_value(query.s), data["rel"], data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel, dest.uri as dest " "LIMIT " + str(query.limit), src=get_term_value(query.s), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -221,13 +219,13 @@ class Processor(TriplesQueryService): # PO records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src " "LIMIT " + str(query.limit), uri=get_term_value(query.p), value=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -236,13 +234,13 @@ class Processor(TriplesQueryService): triples.append((data["src"], get_term_value(query.p), get_term_value(query.o))) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Node {uri: $dest, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src " "LIMIT " + str(query.limit), uri=get_term_value(query.p), dest=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -255,13 +253,13 @@ class Processor(TriplesQueryService): # P records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, dest.value as dest " "LIMIT " + str(query.limit), uri=get_term_value(query.p), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -270,13 +268,13 @@ class Processor(TriplesQueryService): triples.append((data["src"], get_term_value(query.p), data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, dest.uri as dest " "LIMIT " + str(query.limit), uri=get_term_value(query.p), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -291,13 +289,13 @@ class Processor(TriplesQueryService): # O records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel " "LIMIT " + str(query.limit), value=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -306,13 +304,13 @@ class Processor(TriplesQueryService): triples.append((data["src"], data["rel"], get_term_value(query.o))) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {uri: $uri, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel " "LIMIT " + str(query.limit), uri=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -325,12 +323,12 @@ class Processor(TriplesQueryService): # * records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.value as dest " "LIMIT " + str(query.limit), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -339,12 +337,12 @@ class Processor(TriplesQueryService): triples.append((data["src"], data["rel"], data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.uri as dest " "LIMIT " + str(query.limit), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -367,7 +365,7 @@ class Processor(TriplesQueryService): logger.error(f"Exception querying triples: {e}", exc_info=True) raise e - + @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py index dfe4e051..1864e1ad 100644 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py @@ -26,11 +26,11 @@ LABEL="http://www.w3.org/2000/01/rdf-schema#label" class Query: def __init__( - self, rag, user, collection, verbose, + self, rag, workspace, collection, verbose, doc_limit=20, track_usage=None, ): self.rag = rag - self.user = user + self.workspace = workspace self.collection = collection self.verbose = verbose self.doc_limit = doc_limit @@ -97,7 +97,7 @@ class Query: async def query_concept(vec): return await self.rag.doc_embeddings_client.query( vector=vec, limit=per_concept_limit, - user=self.user, collection=self.collection, + collection=self.collection, ) results = await asyncio.gather( @@ -122,7 +122,7 @@ class Query: for match in chunk_matches: if match.chunk_id: try: - content = await self.rag.fetch_chunk(match.chunk_id, self.user) + content = await self.rag.fetch_chunk(match.chunk_id, self.workspace) docs.append(content) chunk_ids.append(match.chunk_id) except Exception as e: @@ -154,7 +154,7 @@ class DocumentRag: logger.debug("DocumentRag initialized") async def query( - self, query, user="trustgraph", collection="default", + self, query, workspace="default", collection="default", doc_limit=20, streaming=False, chunk_callback=None, explain_callback=None, save_answer_callback=None, ): @@ -163,7 +163,7 @@ class DocumentRag: Args: query: The query string - user: User identifier + workspace: Workspace for isolation (also scopes chunk lookup) collection: Collection identifier doc_limit: Max chunks to retrieve streaming: Enable streaming LLM response @@ -210,7 +210,8 @@ class DocumentRag: await explain_callback(q_triples, q_uri) q = Query( - rag=self, user=user, collection=collection, verbose=self.verbose, + rag=self, workspace=workspace, collection=collection, + verbose=self.verbose, doc_limit=doc_limit, track_usage=track_usage, ) diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py index dc7296ad..30333c0e 100755 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py @@ -96,19 +96,19 @@ class Processor(FlowProcessor): await super(Processor, self).start() await self.librarian.start() - async def fetch_chunk_content(self, chunk_id, user, timeout=120): + async def fetch_chunk_content(self, chunk_id, workspace, timeout=120): """Fetch chunk content from librarian. Chunks are small so single request-response is fine.""" return await self.librarian.fetch_document_text( - document_id=chunk_id, user=user, timeout=timeout, + document_id=chunk_id, workspace=workspace, timeout=timeout, ) - async def save_answer_content(self, doc_id, user, content, title=None, timeout=120): + async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120): """Save answer content to the librarian.""" doc_metadata = DocumentMetadata( id=doc_id, - user=user, + workspace=workspace, kind="text/plain", title=title or "DocumentRAG Answer", document_type="answer", @@ -119,7 +119,7 @@ class Processor(FlowProcessor): document_id=doc_id, document_metadata=doc_metadata, content=base64.b64encode(content.encode("utf-8")).decode("utf-8"), - user=user, + workspace=workspace, ) await self.librarian.request(request, timeout=timeout) @@ -150,14 +150,13 @@ class Processor(FlowProcessor): doc_limit = self.doc_limit # Real-time explainability callback - emits triples and IDs as they're generated - # Triples are stored in the user's collection with a named graph (urn:graph:retrieval) + # Triples are stored in the request's collection with a named graph (urn:graph:retrieval) async def send_explainability(triples, explain_id): # Send triples to explainability queue - stores in same collection with named graph await flow("explainability").send(Triples( metadata=Metadata( id=explain_id, - user=v.user, - collection=v.collection, # Store in user's collection + collection=v.collection, ), triples=triples, )) @@ -178,7 +177,7 @@ class Processor(FlowProcessor): async def save_answer(doc_id, answer_text): await self.save_answer_content( doc_id=doc_id, - user=v.user, + workspace=flow.workspace, content=answer_text, title=f"DocumentRAG Answer: {v.query[:50]}...", ) @@ -202,7 +201,7 @@ class Processor(FlowProcessor): # All chunks (including final one with end_of_stream=True) are sent via callback response, usage = await self.rag.query( v.query, - user=v.user, + workspace=flow.workspace, collection=v.collection, doc_limit=doc_limit, streaming=True, @@ -227,7 +226,7 @@ class Processor(FlowProcessor): # Non-streaming path - single response with answer and token usage response, usage = await self.rag.query( v.query, - user=v.user, + workspace=flow.workspace, collection=v.collection, doc_limit=doc_limit, explain_callback=send_explainability, diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index a4b14644..81dc8fe2 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -75,12 +75,11 @@ def edge_id(s, p, o): class LRUCacheWithTTL: - """LRU cache with TTL for label caching + """LRU cache with TTL for label caching. - CRITICAL SECURITY WARNING: - This cache is shared within a GraphRag instance but GraphRag instances - are created per-request. Cache keys MUST include user:collection prefix - to ensure data isolation between different security contexts. + GraphRag instances are created per-request, so this cache is + request-scoped. Cache keys include the collection prefix to keep + entries from different collections distinct within one request. """ def __init__(self, max_size=5000, ttl=300): @@ -119,12 +118,11 @@ class LRUCacheWithTTL: class Query: def __init__( - self, rag, user, collection, verbose, + self, rag, collection, verbose, entity_limit=50, triple_limit=30, max_subgraph_size=1000, max_path_length=2, track_usage=None, ): self.rag = rag - self.user = user self.collection = collection self.verbose = verbose self.entity_limit = entity_limit @@ -194,7 +192,7 @@ class Query: entity_tasks = [ self.rag.graph_embeddings_client.query( vector=v, limit=per_concept_limit, - user=self.user, collection=self.collection, + collection=self.collection, ) for v in vectors ] @@ -222,18 +220,18 @@ class Query: async def maybe_label(self, e): - # CRITICAL SECURITY: Cache key MUST include user and collection - # to prevent data leakage between different contexts - cache_key = f"{self.user}:{self.collection}:{e}" + # The label cache lives on a per-request GraphRag instance — no + # cross-request isolation concern. The collection prefix keeps + # entries from different collections distinct within one request. + cache_key = f"{self.collection}:{e}" - # Check LRU cache first with isolated key cached_label = self.rag.label_cache.get(cache_key) if cached_label is not None: return cached_label res = await self.rag.triples_client.query( s=e, p=LABEL, o=None, limit=1, - user=self.user, collection=self.collection, + collection=self.collection, g="", ) @@ -255,19 +253,19 @@ class Query: self.rag.triples_client.query_stream( s=entity, p=None, o=None, limit=limit_per_entity, - user=self.user, collection=self.collection, + collection=self.collection, batch_size=20, g="", ), self.rag.triples_client.query_stream( s=None, p=entity, o=None, limit=limit_per_entity, - user=self.user, collection=self.collection, + collection=self.collection, batch_size=20, g="", ), self.rag.triples_client.query_stream( s=None, p=None, o=entity, limit=limit_per_entity, - user=self.user, collection=self.collection, + collection=self.collection, batch_size=20, g="", ) ]) @@ -468,7 +466,7 @@ class Query: subgraph_tasks.append( self.rag.triples_client.query( s=None, p=TG_CONTAINS, o=quoted, limit=1, - user=self.user, collection=self.collection, + collection=self.collection, g=GRAPH_SOURCE, ) ) @@ -501,7 +499,7 @@ class Query: derivation_tasks = [ self.rag.triples_client.query( s=uri, p=PROV_WAS_DERIVED_FROM, o=None, limit=5, - user=self.user, collection=self.collection, + collection=self.collection, g=GRAPH_SOURCE, ) for uri in current_uris @@ -535,7 +533,7 @@ class Query: metadata_tasks = [ self.rag.triples_client.query( s=uri, p=None, o=None, limit=50, - user=self.user, collection=self.collection, + collection=self.collection, ) for uri in doc_uris ] @@ -560,11 +558,9 @@ class Query: class GraphRag: """ - CRITICAL SECURITY: - This class MUST be instantiated per-request to ensure proper isolation - between users and collections. The cache within this instance will only - live for the duration of a single request, preventing cross-contamination - of data between different security contexts. + Must be instantiated per-request so the label cache lives only for + the duration of a single request. Workspace isolation is enforced + by the trusted flow layer (flow.workspace), not by this class. """ def __init__( @@ -587,7 +583,7 @@ class GraphRag: logger.debug("GraphRag initialized") async def query( - self, query, user = "trustgraph", collection = "default", + self, query, collection = "default", entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000, max_path_length = 2, edge_score_limit = 30, edge_limit = 25, streaming = False, @@ -600,7 +596,6 @@ class GraphRag: Args: query: The query string - user: User identifier collection: Collection identifier entity_limit: Max entities to retrieve triple_limit: Max triples per entity @@ -657,7 +652,7 @@ class GraphRag: await explain_callback(q_triples, q_uri) q = Query( - rag = self, user = user, collection = collection, + rag = self, collection = collection, verbose = self.verbose, entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index 15c30ba1..acb111e1 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -62,9 +62,9 @@ class Processor(FlowProcessor): self.default_edge_score_limit = edge_score_limit self.default_edge_limit = edge_limit - # CRITICAL SECURITY: NEVER share data between users or collections - # Each user/collection combination MUST have isolated data access - # Caching must NEVER allow information leakage across these boundaries + # Workspace isolation is enforced by the flow layer (flow.workspace). + # Per-request caching (see GraphRag) keeps within-request state + # scoped; no cross-request sharing here. self.register_specification( ConsumerSpec( @@ -170,13 +170,13 @@ class Processor(FlowProcessor): future = self.pending_librarian_requests.pop(request_id) future.set_result(response) - async def save_answer_content(self, doc_id, user, content, title=None, timeout=120): + async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120): """ Save answer content to the librarian. Args: doc_id: ID for the answer document - user: User ID + workspace: Workspace for isolation content: Answer text content title: Optional title timeout: Request timeout in seconds @@ -188,7 +188,7 @@ class Processor(FlowProcessor): doc_metadata = DocumentMetadata( id=doc_id, - user=user, + workspace=workspace, kind="text/plain", title=title or "GraphRAG Answer", document_type="answer", @@ -199,7 +199,7 @@ class Processor(FlowProcessor): document_id=doc_id, document_metadata=doc_metadata, content=base64.b64encode(content.encode("utf-8")).decode("utf-8"), - user=user, + workspace=workspace, ) # Create future for response @@ -241,14 +241,13 @@ class Processor(FlowProcessor): explainability_refs_emitted = [] # Real-time explainability callback - emits triples and IDs as they're generated - # Triples are stored in the user's collection with a named graph (urn:graph:retrieval) + # Triples are stored in the request's collection with a named graph (urn:graph:retrieval) async def send_explainability(triples, explain_id): # Send triples to explainability queue - stores in same collection with named graph await flow("explainability").send(Triples( metadata=Metadata( id=explain_id, - user=v.user, - collection=v.collection, # Store in user's collection, not separate explainability collection + collection=v.collection, ), triples=triples, )) @@ -266,9 +265,9 @@ class Processor(FlowProcessor): explainability_refs_emitted.append(explain_id) - # CRITICAL SECURITY: Create new GraphRag instance per request - # This ensures proper isolation between users and collections - # Flow clients are request-scoped and must not be shared + # Create new GraphRag instance per request — its label cache + # is request-scoped, and flow clients must not be shared + # across requests. rag = GraphRag( embeddings_client=flow("embeddings-request"), graph_embeddings_client=flow("graph-embeddings-request"), @@ -311,7 +310,7 @@ class Processor(FlowProcessor): async def save_answer(doc_id, answer_text): await self.save_answer_content( doc_id=doc_id, - user=v.user, + workspace=flow.workspace, content=answer_text, title=f"GraphRAG Answer: {v.query[:50]}...", ) @@ -333,7 +332,7 @@ class Processor(FlowProcessor): # Query with streaming and real-time explain response, usage = await rag.query( - query = v.query, user = v.user, collection = v.collection, + query = v.query, collection = v.collection, entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, @@ -349,7 +348,7 @@ class Processor(FlowProcessor): else: # Non-streaming path with real-time explain response, usage = await rag.query( - query = v.query, user = v.user, collection = v.collection, + query = v.query, collection = v.collection, entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, @@ -464,7 +463,7 @@ class Processor(FlowProcessor): help=f'Max edges after LLM scoring (default: 25)' ) - # Note: Explainability triples are now stored in the user's collection + # Note: Explainability triples are now stored in the request's collection # with the named graph urn:graph:retrieval (no separate collection needed) def run(): diff --git a/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py b/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py index b567cc7b..091069ad 100644 --- a/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py +++ b/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py @@ -66,32 +66,39 @@ class Processor(FlowProcessor): # Register config handler for schema updates self.register_config_handler(self.on_schema_config, types=["schema"]) - # Schema storage: name -> RowSchema - self.schemas: Dict[str, RowSchema] = {} - + # Per-workspace schema storage: {workspace: {name: RowSchema}} + self.schemas: Dict[str, Dict[str, RowSchema]] = {} + logger.info("NLP Query service initialized") - async def on_schema_config(self, config, version): + async def on_schema_config(self, workspace, config, version): """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") - - # Clear existing schemas - self.schemas = {} - + logger.info( + f"Loading schema configuration version {version} " + f"for workspace {workspace}" + ) + + # Replace existing schemas for this workspace + ws_schemas: Dict[str, RowSchema] = {} + self.schemas[workspace] = ws_schemas + # Check if our config type exists if self.config_key not in config: - logger.warning(f"No '{self.config_key}' type in configuration") + logger.warning( + f"No '{self.config_key}' type in configuration " + f"for {workspace}" + ) return - + # Get the schemas dictionary for our type schemas_config = config[self.config_key] - + # Process each schema in the schemas config for schema_name, schema_json in schemas_config.items(): try: # Parse the JSON schema definition schema_def = json.loads(schema_json) - + # Create Field objects fields = [] for field_def in schema_def.get("fields", []): @@ -106,29 +113,37 @@ class Processor(FlowProcessor): indexed=field_def.get("indexed", False) ) fields.append(field) - + # Create RowSchema row_schema = RowSchema( name=schema_def.get("name", schema_name), description=schema_def.get("description", ""), fields=fields ) - - self.schemas[schema_name] = row_schema - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") - + + ws_schemas[schema_name] = row_schema + logger.info( + f"Loaded schema: {schema_name} with " + f"{len(fields)} fields for {workspace}" + ) + except Exception as e: logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + + logger.info( + f"Schema configuration loaded for {workspace}: " + f"{len(ws_schemas)} schemas" + ) async def phase1_select_schemas(self, question: str, flow) -> List[str]: """Phase 1: Use prompt service to select relevant schemas for the question""" logger.info("Starting Phase 1: Schema selection") - + + ws_schemas = self.schemas.get(flow.workspace, {}) + # Prepare schema information for the prompt schema_info = [] - for name, schema in self.schemas.items(): + for name, schema in ws_schemas.items(): schema_desc = { "name": name, "description": schema.description, @@ -176,12 +191,14 @@ class Processor(FlowProcessor): async def phase2_generate_graphql(self, question: str, selected_schemas: List[str], flow) -> Dict[str, Any]: """Phase 2: Generate GraphQL query using selected schemas""" logger.info(f"Starting Phase 2: GraphQL generation for schemas: {selected_schemas}") - + + ws_schemas = self.schemas.get(flow.workspace, {}) + # Get detailed schema information for selected schemas only selected_schema_info = [] for schema_name in selected_schemas: - if schema_name in self.schemas: - schema = self.schemas[schema_name] + if schema_name in ws_schemas: + schema = ws_schemas[schema_name] schema_desc = { "name": schema_name, "description": schema.description, diff --git a/trustgraph-flow/trustgraph/retrieval/structured_diag/service.py b/trustgraph-flow/trustgraph/retrieval/structured_diag/service.py index b878bf61..6dd79cbb 100644 --- a/trustgraph-flow/trustgraph/retrieval/structured_diag/service.py +++ b/trustgraph-flow/trustgraph/retrieval/structured_diag/service.py @@ -72,21 +72,28 @@ class Processor(FlowProcessor): # Register config handler for schema updates self.register_config_handler(self.on_schema_config, types=["schema"]) - # Schema storage: name -> RowSchema - self.schemas: Dict[str, RowSchema] = {} + # Per-workspace schema storage: {workspace: {name: RowSchema}} + self.schemas: Dict[str, Dict[str, RowSchema]] = {} logger.info("Structured Data Diagnosis service initialized") - async def on_schema_config(self, config, version): + async def on_schema_config(self, workspace, config, version): """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") + logger.info( + f"Loading schema configuration version {version} " + f"for workspace {workspace}" + ) - # Clear existing schemas - self.schemas = {} + # Replace existing schemas for this workspace + ws_schemas: Dict[str, RowSchema] = {} + self.schemas[workspace] = ws_schemas # Check if our config type exists if self.config_key not in config: - logger.warning(f"No '{self.config_key}' type in configuration") + logger.warning( + f"No '{self.config_key}' type in configuration " + f"for {workspace}" + ) return # Get the schemas dictionary for our type @@ -120,13 +127,19 @@ class Processor(FlowProcessor): fields=fields ) - self.schemas[schema_name] = row_schema - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") + ws_schemas[schema_name] = row_schema + logger.info( + f"Loaded schema: {schema_name} with " + f"{len(fields)} fields for {workspace}" + ) except Exception as e: logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + logger.info( + f"Schema configuration loaded for {workspace}: " + f"{len(ws_schemas)} schemas" + ) async def on_message(self, msg, consumer, flow): """Handle incoming structured data diagnosis request""" @@ -216,15 +229,19 @@ class Processor(FlowProcessor): ) return StructuredDataDiagnosisResponse(error=error, operation=request.operation) - # Get target schema - if request.schema_name not in self.schemas: + # Get target schema from this workspace's schemas + ws_schemas = self.schemas.get(flow.workspace, {}) + if request.schema_name not in ws_schemas: error = Error( type="SchemaNotFound", - message=f"Schema '{request.schema_name}' not found in configuration" + message=( + f"Schema '{request.schema_name}' not found " + f"in configuration for workspace {flow.workspace}" + ) ) return StructuredDataDiagnosisResponse(error=error, operation=request.operation) - target_schema = self.schemas[request.schema_name] + target_schema = ws_schemas[request.schema_name] # Generate descriptor using prompt service descriptor = await self.generate_descriptor_with_prompt( @@ -260,26 +277,33 @@ class Processor(FlowProcessor): return StructuredDataDiagnosisResponse(error=error, operation=request.operation) # Step 2: Use provided schema name or auto-select first available + ws_schemas = self.schemas.get(flow.workspace, {}) schema_name = request.schema_name - if not schema_name and self.schemas: - schema_name = list(self.schemas.keys())[0] + if not schema_name and ws_schemas: + schema_name = list(ws_schemas.keys())[0] logger.info(f"Auto-selected schema: {schema_name}") if not schema_name: error = Error( type="NoSchemaAvailable", - message="No schema specified and no schemas available in configuration" + message=( + f"No schema specified and no schemas available " + f"in configuration for workspace {flow.workspace}" + ) ) return StructuredDataDiagnosisResponse(error=error, operation=request.operation) - if schema_name not in self.schemas: + if schema_name not in ws_schemas: error = Error( type="SchemaNotFound", - message=f"Schema '{schema_name}' not found in configuration" + message=( + f"Schema '{schema_name}' not found in " + f"configuration for workspace {flow.workspace}" + ) ) return StructuredDataDiagnosisResponse(error=error, operation=request.operation) - target_schema = self.schemas[schema_name] + target_schema = ws_schemas[schema_name] # Step 3: Generate descriptor descriptor = await self.generate_descriptor_with_prompt( @@ -316,8 +340,9 @@ class Processor(FlowProcessor): logger.info("Processing schema-selection operation") # Prepare all schemas for the prompt - match the original config format + ws_schemas = self.schemas.get(flow.workspace, {}) all_schemas = [] - for schema_name, row_schema in self.schemas.items(): + for schema_name, row_schema in ws_schemas.items(): schema_info = { "name": row_schema.name, "description": row_schema.description, diff --git a/trustgraph-flow/trustgraph/retrieval/structured_query/service.py b/trustgraph-flow/trustgraph/retrieval/structured_query/service.py index e39f9041..151703cb 100644 --- a/trustgraph-flow/trustgraph/retrieval/structured_query/service.py +++ b/trustgraph-flow/trustgraph/retrieval/structured_query/service.py @@ -111,9 +111,9 @@ class Processor(FlowProcessor): else: variables_as_strings[key] = str(value) - # Use user/collection values from request + # Use collection from request. Workspace isolation is + # enforced by flow.workspace at the rows-query service. objects_request = RowsQueryRequest( - user=request.user, collection=request.collection, query=nlp_response.graphql_query, variables=variables_as_strings, diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py index f5c12441..7c8db19d 100755 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py @@ -33,7 +33,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) - async def store_document_embeddings(self, message): + async def store_document_embeddings(self, workspace, message): for emb in message.chunks: @@ -45,7 +45,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): if vec: self.vecstore.insert( vec, chunk_id, - message.metadata.user, + workspace, message.metadata.collection ) @@ -60,27 +60,27 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): help=f'Milvus store URI (default: {default_store_uri})' ) - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """ Create collection via config push - collections are created lazily on first write with the correct dimension determined from the actual embeddings. """ try: - logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") - self.vecstore.create_collection(user, collection) + logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write") + self.vecstore.create_collection(workspace, collection) except Exception as e: - logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete the collection for document embeddings via config push""" try: - self.vecstore.delete_collection(user, collection) - logger.info(f"Successfully deleted collection {user}/{collection}") + self.vecstore.delete_collection(workspace, collection) + logger.info(f"Successfully deleted collection {workspace}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py index 31a70f23..41a1e5a5 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py @@ -88,12 +88,12 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): "Gave up waiting for index creation" ) - async def store_document_embeddings(self, message): + async def store_document_embeddings(self, workspace, message): # Validate collection exists in config before processing - if not self.collection_exists(message.metadata.user, message.metadata.collection): + if not self.collection_exists(workspace, message.metadata.collection): logger.warning( - f"Collection {message.metadata.collection} for user {message.metadata.user} " + f"Collection {message.metadata.collection} for workspace {workspace} " f"does not exist in config (likely deleted while data was in-flight). " f"Dropping message." ) @@ -112,7 +112,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): # Create index name with dimension suffix for lazy creation dim = len(vec) index_name = ( - f"d-{message.metadata.user}-{message.metadata.collection}-{dim}" + f"d-{workspace}-{message.metadata.collection}-{dim}" ) # Lazily create index if it doesn't exist (but only if authorized in config) @@ -165,22 +165,22 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): help=f'Pinecone region, (default: {default_region}' ) - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """ Create collection via config push - indexes are created lazily on first write with the correct dimension determined from the actual embeddings. """ try: - logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") + logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write") except Exception as e: - logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete the collection for document embeddings via config push""" try: - prefix = f"d-{user}-{collection}-" + prefix = f"d-{workspace}-{collection}-" # Get all indexes and filter for matches all_indexes = self.pinecone.list_indexes() @@ -195,10 +195,10 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): for index_name in matching_indexes: self.pinecone.delete_index(index_name) logger.info(f"Deleted Pinecone index: {index_name}") - logger.info(f"Deleted {len(matching_indexes)} index(es) for {user}/{collection}") + logger.info(f"Deleted {len(matching_indexes)} index(es) for {workspace}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index e5e7e705..fb7166b5 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -39,12 +39,12 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) - async def store_document_embeddings(self, message): + async def store_document_embeddings(self, workspace, message): # Validate collection exists in config before processing - if not self.collection_exists(message.metadata.user, message.metadata.collection): + if not self.collection_exists(workspace, message.metadata.collection): logger.warning( - f"Collection {message.metadata.collection} for user {message.metadata.user} " + f"Collection {message.metadata.collection} for workspace {workspace} " f"does not exist in config (likely deleted while data was in-flight). " f"Dropping message." ) @@ -63,7 +63,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): # Create collection name with dimension suffix for lazy creation dim = len(vec) collection = ( - f"d_{message.metadata.user}_{message.metadata.collection}_{dim}" + f"d_{workspace}_{message.metadata.collection}_{dim}" ) # Lazily create collection if it doesn't exist (but only if authorized in config) @@ -107,22 +107,22 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): help=f'Qdrant API key (default: None)' ) - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """ Create collection via config push - collections are created lazily on first write with the correct dimension determined from the actual embeddings. """ try: - logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") + logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write") except Exception as e: - logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete the collection for document embeddings via config push""" try: - prefix = f"d_{user}_{collection}_" + prefix = f"d_{workspace}_{collection}_" # Get all collections and filter for matches all_collections = self.qdrant.get_collections().collections @@ -137,10 +137,10 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): for collection_name in matching_collections: self.qdrant.delete_collection(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") - logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}") + logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py index 9346c948..2068d58c 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py @@ -47,7 +47,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) - async def store_graph_embeddings(self, message): + async def store_graph_embeddings(self, workspace, message): for entity in message.entities: entity_value = get_term_value(entity.entity) @@ -57,7 +57,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): if vec: self.vecstore.insert( vec, entity_value, - message.metadata.user, + workspace, message.metadata.collection, chunk_id=entity.chunk_id or "", ) @@ -73,27 +73,27 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): help=f'Milvus store URI (default: {default_store_uri})' ) - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """ Create collection via config push - collections are created lazily on first write with the correct dimension determined from the actual embeddings. """ try: - logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") - self.vecstore.create_collection(user, collection) + logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write") + self.vecstore.create_collection(workspace, collection) except Exception as e: - logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete the collection for graph embeddings via config push""" try: - self.vecstore.delete_collection(user, collection) - logger.info(f"Successfully deleted collection {user}/{collection}") + self.vecstore.delete_collection(workspace, collection) + logger.info(f"Successfully deleted collection {workspace}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py index 6a95a38d..23662f7f 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py @@ -102,12 +102,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): "Gave up waiting for index creation" ) - async def store_graph_embeddings(self, message): + async def store_graph_embeddings(self, workspace, message): # Validate collection exists in config before processing - if not self.collection_exists(message.metadata.user, message.metadata.collection): + if not self.collection_exists(workspace, message.metadata.collection): logger.warning( - f"Collection {message.metadata.collection} for user {message.metadata.user} " + f"Collection {message.metadata.collection} for workspace {workspace} " f"does not exist in config (likely deleted while data was in-flight). " f"Dropping message." ) @@ -126,7 +126,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): # Create index name with dimension suffix for lazy creation dim = len(vec) index_name = ( - f"t-{message.metadata.user}-{message.metadata.collection}-{dim}" + f"t-{workspace}-{message.metadata.collection}-{dim}" ) # Lazily create index if it doesn't exist (but only if authorized in config) @@ -183,22 +183,22 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): help=f'Pinecone region, (default: {default_region}' ) - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """ Create collection via config push - indexes are created lazily on first write with the correct dimension determined from the actual embeddings. """ try: - logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") + logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write") except Exception as e: - logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete the collection for graph embeddings via config push""" try: - prefix = f"t-{user}-{collection}-" + prefix = f"t-{workspace}-{collection}-" # Get all indexes and filter for matches all_indexes = self.pinecone.list_indexes() @@ -213,10 +213,10 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): for index_name in matching_indexes: self.pinecone.delete_index(index_name) logger.info(f"Deleted Pinecone index: {index_name}") - logger.info(f"Deleted {len(matching_indexes)} index(es) for {user}/{collection}") + logger.info(f"Deleted {len(matching_indexes)} index(es) for {workspace}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index 9a7672f8..391c2a04 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -54,12 +54,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) - async def store_graph_embeddings(self, message): + async def store_graph_embeddings(self, workspace, message): # Validate collection exists in config before processing - if not self.collection_exists(message.metadata.user, message.metadata.collection): + if not self.collection_exists(workspace, message.metadata.collection): logger.warning( - f"Collection {message.metadata.collection} for user {message.metadata.user} " + f"Collection {message.metadata.collection} for workspace {workspace} " f"does not exist in config (likely deleted while data was in-flight). " f"Dropping message." ) @@ -78,7 +78,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): # Create collection name with dimension suffix for lazy creation dim = len(vec) collection = ( - f"t_{message.metadata.user}_{message.metadata.collection}_{dim}" + f"t_{workspace}_{message.metadata.collection}_{dim}" ) # Lazily create collection if it doesn't exist (but only if authorized in config) @@ -126,22 +126,22 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): help=f'Qdrant API key' ) - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """ Create collection via config push - collections are created lazily on first write with the correct dimension determined from the actual embeddings. """ try: - logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") + logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write") except Exception as e: - logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete the collection for graph embeddings via config push""" try: - prefix = f"t_{user}_{collection}_" + prefix = f"t_{workspace}_{collection}_" # Get all collections and filter for matches all_collections = self.qdrant.get_collections().collections @@ -156,10 +156,10 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): for collection_name in matching_collections: self.qdrant.delete_collection(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") - logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}") + logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/storage/knowledge/store.py b/trustgraph-flow/trustgraph/storage/knowledge/store.py index 475604b6..57e1fe48 100644 --- a/trustgraph-flow/trustgraph/storage/knowledge/store.py +++ b/trustgraph-flow/trustgraph/storage/knowledge/store.py @@ -65,13 +65,13 @@ class Processor(FlowProcessor): v = msg.value() if v.triples: - await self.table_store.add_triples(v) + await self.table_store.add_triples(flow.workspace, v) async def on_graph_embeddings(self, msg, consumer, flow): v = msg.value() if v.entities: - await self.table_store.add_graph_embeddings(v) + await self.table_store.add_graph_embeddings(flow.workspace, v) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py index a6ec4ff7..32d87871 100644 --- a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py @@ -2,13 +2,13 @@ Row embeddings writer for Qdrant (Stage 2). Consumes RowEmbeddings messages (which already contain computed vectors) -and writes them to Qdrant. One Qdrant collection per (user, collection, schema_name) pair. +and writes them to Qdrant. One Qdrant collection per (workspace, collection, schema_name) pair. This follows the two-stage pattern used by graph-embeddings and document-embeddings: Stage 1 (row-embeddings): Compute embeddings Stage 2 (this processor): Store embeddings -Collection naming: rows_{user}_{collection}_{schema_name}_{dimension} +Collection naming: rows_{workspace}_{collection}_{schema_name}_{dimension} Payload structure: - index_name: The indexed field(s) this embedding represents @@ -77,10 +77,10 @@ class Processor(CollectionConfigHandler, FlowProcessor): return safe_name.lower() def get_collection_name( - self, user: str, collection: str, schema_name: str, dimension: int + self, workspace: str, collection: str, schema_name: str, dimension: int ) -> str: """Generate Qdrant collection name""" - safe_user = self.sanitize_name(user) + safe_user = self.sanitize_name(workspace) safe_collection = self.sanitize_name(collection) safe_schema = self.sanitize_name(schema_name) return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}" @@ -114,18 +114,19 @@ class Processor(CollectionConfigHandler, FlowProcessor): f"{embeddings.schema_name} from {embeddings.metadata.id}" ) + workspace = flow.workspace + # Validate collection exists in config before processing if not self.collection_exists( - embeddings.metadata.user, embeddings.metadata.collection + workspace, embeddings.metadata.collection ): logger.warning( - f"Collection {embeddings.metadata.collection} for user " - f"{embeddings.metadata.user} does not exist in config. " + f"Collection {embeddings.metadata.collection} for workspace " + f"{workspace} does not exist in config. " f"Dropping message." ) return - user = embeddings.metadata.user collection = embeddings.metadata.collection schema_name = embeddings.schema_name @@ -145,7 +146,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): # Create/get collection name (lazily on first vector) if qdrant_collection is None: qdrant_collection = self.get_collection_name( - user, collection, schema_name, dimension + workspace, collection, schema_name, dimension ) self.ensure_collection(qdrant_collection, dimension) @@ -168,17 +169,17 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.info(f"Wrote {embeddings_written} embeddings to Qdrant") - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """Collection creation via config push - collections created lazily on first write""" logger.info( - f"Row embeddings collection create request for {user}/{collection} - " + f"Row embeddings collection create request for {workspace}/{collection} - " f"will be created lazily on first write" ) - async def delete_collection(self, user: str, collection: str): - """Delete all Qdrant collections for a given user/collection""" + async def delete_collection(self, workspace: str, collection: str): + """Delete all Qdrant collections for a given workspace/collection""" try: - prefix = f"rows_{self.sanitize_name(user)}_{self.sanitize_name(collection)}_" + prefix = f"rows_{self.sanitize_name(workspace)}_{self.sanitize_name(collection)}_" # Get all collections and filter for matches all_collections = self.qdrant.get_collections().collections @@ -196,23 +197,23 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.info(f"Deleted Qdrant collection: {collection_name}") logger.info( f"Deleted {len(matching_collections)} collection(s) " - f"for {user}/{collection}" + f"for {workspace}/{collection}" ) except Exception as e: logger.error( - f"Failed to delete collection {user}/{collection}: {e}", + f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True ) raise async def delete_collection_schema( - self, user: str, collection: str, schema_name: str + self, workspace: str, collection: str, schema_name: str ): - """Delete Qdrant collection for a specific user/collection/schema""" + """Delete Qdrant collection for a specific workspace/collection/schema""" try: prefix = ( - f"rows_{self.sanitize_name(user)}_" + f"rows_{self.sanitize_name(workspace)}_" f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_" ) @@ -233,7 +234,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): except Exception as e: logger.error( - f"Failed to delete collection {user}/{collection}/{schema_name}: {e}", + f"Failed to delete collection {workspace}/{collection}/{schema_name}: {e}", exc_info=True ) raise diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py index d0eec2e1..acfe00d2 100755 --- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py @@ -119,19 +119,27 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True) raise - async def on_schema_config(self, config, version): + async def on_schema_config(self, workspace, config, version): """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") + logger.info( + f"Loading schema configuration version {version} " + f"for workspace {workspace}" + ) - # Track which schemas changed so we can clear partition cache - old_schema_names = set(self.schemas.keys()) + # Track which schemas changed in this workspace + old_schemas = self.schemas.get(workspace, {}) + old_schema_names = set(old_schemas.keys()) - # Clear existing schemas - self.schemas = {} + # Replace existing schemas for this workspace + ws_schemas: Dict[str, RowSchema] = {} + self.schemas[workspace] = ws_schemas # Check if our config type exists if self.config_key not in config: - logger.warning(f"No '{self.config_key}' type in configuration") + logger.warning( + f"No '{self.config_key}' type in configuration " + f"for {workspace}" + ) return # Get the schemas dictionary for our type @@ -165,24 +173,32 @@ class Processor(CollectionConfigHandler, FlowProcessor): fields=fields ) - self.schemas[schema_name] = row_schema - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") + ws_schemas[schema_name] = row_schema + logger.info( + f"Loaded schema: {schema_name} with " + f"{len(fields)} fields for {workspace}" + ) except Exception as e: logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + logger.info( + f"Schema configuration loaded for {workspace}: " + f"{len(ws_schemas)} schemas" + ) - # Clear partition cache for schemas that changed - # This ensures next write will re-register partitions - new_schema_names = set(self.schemas.keys()) + # Clear partition cache for schemas that changed in this workspace + new_schema_names = set(ws_schemas.keys()) changed_schemas = old_schema_names.symmetric_difference(new_schema_names) if changed_schemas: self.registered_partitions = { (col, sch) for col, sch in self.registered_partitions if sch not in changed_schemas } - logger.info(f"Cleared partition cache for changed schemas: {changed_schemas}") + logger.info( + f"Cleared partition cache for changed schemas " + f"in {workspace}: {changed_schemas}" + ) def sanitize_name(self, name: str) -> str: """Sanitize names for Cassandra compatibility""" @@ -286,7 +302,10 @@ class Processor(CollectionConfigHandler, FlowProcessor): return index_names - def register_partitions(self, keyspace: str, collection: str, schema_name: str): + def register_partitions( + self, keyspace: str, collection: str, schema_name: str, + workspace: str, + ): """ Register partition entries for a (collection, schema_name) pair. Called once on first row for each pair. @@ -295,9 +314,13 @@ class Processor(CollectionConfigHandler, FlowProcessor): if cache_key in self.registered_partitions: return - schema = self.schemas.get(schema_name) + ws_schemas = self.schemas.get(workspace, {}) + schema = ws_schemas.get(schema_name) if not schema: - logger.warning(f"Cannot register partitions - schema {schema_name} not found") + logger.warning( + f"Cannot register partitions - schema {schema_name} " + f"not found in workspace {workspace}" + ) return safe_keyspace = self.sanitize_name(keyspace) @@ -338,13 +361,14 @@ class Processor(CollectionConfigHandler, FlowProcessor): """Process incoming ExtractedObject and store in Cassandra""" obj = msg.value() + workspace = flow.workspace logger.info( f"Storing {len(obj.values)} rows for schema {obj.schema_name} " - f"from {obj.metadata.id}" + f"from {obj.metadata.id} (workspace {workspace})" ) # Validate collection exists before accepting writes - if not self.collection_exists(obj.metadata.user, obj.metadata.collection): + if not self.collection_exists(workspace, obj.metadata.collection): error_msg = ( f"Collection {obj.metadata.collection} does not exist. " f"Create it first via collection management API." @@ -352,13 +376,17 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.error(error_msg) raise ValueError(error_msg) - # Get schema definition - schema = self.schemas.get(obj.schema_name) + # Get schema definition for this workspace + ws_schemas = self.schemas.get(workspace, {}) + schema = ws_schemas.get(obj.schema_name) if not schema: - logger.warning(f"No schema found for {obj.schema_name} - skipping") + logger.warning( + f"No schema found for {obj.schema_name} in " + f"workspace {workspace} - skipping" + ) return - keyspace = obj.metadata.user + keyspace = workspace collection = obj.metadata.collection schema_name = obj.schema_name source = getattr(obj.metadata, 'source', '') or '' @@ -370,7 +398,8 @@ class Processor(CollectionConfigHandler, FlowProcessor): # Register partitions if first time seeing this (collection, schema_name) await asyncio.to_thread( - self.register_partitions, keyspace, collection, schema_name + self.register_partitions, + keyspace, collection, schema_name, workspace, ) safe_keyspace = self.sanitize_name(keyspace) @@ -430,25 +459,25 @@ class Processor(CollectionConfigHandler, FlowProcessor): f"({len(index_names)} indexes per row)" ) - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """Create/verify collection exists in Cassandra row store""" # Connect if not already connected (sync, push to thread) await asyncio.to_thread(self.connect_cassandra) # Ensure tables exist (sync DDL, push to thread) - await asyncio.to_thread(self.ensure_tables, user) + await asyncio.to_thread(self.ensure_tables, workspace) - logger.info(f"Collection {collection} ready for user {user}") + logger.info(f"Collection {collection} ready for workspace {workspace}") - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete all data for a specific collection using partition tracking""" # Connect if not already connected await asyncio.to_thread(self.connect_cassandra) - safe_keyspace = self.sanitize_name(user) + safe_keyspace = self.sanitize_name(workspace) # Check if keyspace exists - if user not in self.known_keyspaces: + if workspace not in self.known_keyspaces: check_keyspace_cql = """ SELECT keyspace_name FROM system_schema.keyspaces WHERE keyspace_name = %s @@ -459,7 +488,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): if not result: logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete") return - self.known_keyspaces.add(user) + self.known_keyspaces.add(workspace) # Discover all partitions for this collection select_partitions_cql = f""" @@ -522,12 +551,12 @@ class Processor(CollectionConfigHandler, FlowProcessor): f"from keyspace {safe_keyspace}" ) - async def delete_collection_schema(self, user: str, collection: str, schema_name: str): + async def delete_collection_schema(self, workspace: str, collection: str, schema_name: str): """Delete all data for a specific collection + schema combination""" # Connect if not already connected await asyncio.to_thread(self.connect_cassandra) - safe_keyspace = self.sanitize_name(user) + safe_keyspace = self.sanitize_name(workspace) # Discover partitions for this collection + schema select_partitions_cql = f""" diff --git a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py index 01d95c8b..05331d09 100755 --- a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py @@ -147,9 +147,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService): # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) - async def store_triples(self, message): - - user = message.metadata.user + async def store_triples(self, workspace, message): # The cassandra-driver work below — connection, schema # setup, and per-triple inserts — is all synchronous. @@ -159,7 +157,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService): def _do_store(): - if self.table is None or self.table != user: + if self.table is None or self.table != workspace: self.tg = None @@ -170,21 +168,21 @@ class Processor(CollectionConfigHandler, TriplesStoreService): if self.cassandra_username and self.cassandra_password: self.tg = KGClass( hosts=self.cassandra_host, - keyspace=message.metadata.user, + keyspace=workspace, username=self.cassandra_username, password=self.cassandra_password, ) else: self.tg = KGClass( hosts=self.cassandra_host, - keyspace=message.metadata.user, + keyspace=workspace, ) except Exception as e: logger.error(f"Exception: {e}", exc_info=True) time.sleep(1) raise e - self.table = user + self.table = workspace for t in message.triples: # Extract values from Term objects @@ -212,12 +210,12 @@ class Processor(CollectionConfigHandler, TriplesStoreService): await asyncio.to_thread(_do_store) - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """Create a collection in Cassandra triple store via config push""" def _do_create(): - # Create or reuse connection for this user's keyspace - if self.table is None or self.table != user: + # Create or reuse connection for this workspace's keyspace + if self.table is None or self.table != workspace: self.tg = None # Use factory function to select implementation @@ -227,23 +225,23 @@ class Processor(CollectionConfigHandler, TriplesStoreService): if self.cassandra_username and self.cassandra_password: self.tg = KGClass( hosts=self.cassandra_host, - keyspace=user, + keyspace=workspace, username=self.cassandra_username, password=self.cassandra_password, ) else: self.tg = KGClass( hosts=self.cassandra_host, - keyspace=user, + keyspace=workspace, ) except Exception as e: - logger.error(f"Failed to connect to Cassandra for user {user}: {e}") + logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}") raise - self.table = user + self.table = workspace # Create collection using the built-in method - logger.info(f"Creating collection {collection} for user {user}") + logger.info(f"Creating collection {collection} for workspace {workspace}") if self.tg.collection_exists(collection): logger.info(f"Collection {collection} already exists") @@ -254,15 +252,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService): try: await asyncio.to_thread(_do_create) except Exception as e: - logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete all data for a specific collection from the unified triples table""" def _do_delete(): - # Create or reuse connection for this user's keyspace - if self.table is None or self.table != user: + # Create or reuse connection for this workspace's keyspace + if self.table is None or self.table != workspace: self.tg = None # Use factory function to select implementation @@ -272,29 +270,29 @@ class Processor(CollectionConfigHandler, TriplesStoreService): if self.cassandra_username and self.cassandra_password: self.tg = KGClass( hosts=self.cassandra_host, - keyspace=user, + keyspace=workspace, username=self.cassandra_username, password=self.cassandra_password, ) else: self.tg = KGClass( hosts=self.cassandra_host, - keyspace=user, + keyspace=workspace, ) except Exception as e: - logger.error(f"Failed to connect to Cassandra for user {user}: {e}") + logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}") raise - self.table = user + self.table = workspace # Delete all triples for this collection using the built-in method self.tg.delete_collection(collection) - logger.info(f"Deleted all triples for collection {collection} from keyspace {user}") + logger.info(f"Deleted all triples for collection {collection} from keyspace {workspace}") try: await asyncio.to_thread(_do_delete) except Exception as e: - logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise @staticmethod diff --git a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py index 86f9a6e3..77c32919 100755 --- a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py @@ -59,15 +59,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService): # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) - def create_node(self, uri, user, collection): + def create_node(self, uri, workspace, collection): - logger.debug(f"Create node {uri} for user={user}, collection={collection}") + logger.debug(f"Create node {uri} for workspace={workspace}, collection={collection}") res = self.io.query( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", params={ "uri": uri, - "user": user, + "workspace": workspace, "collection": collection, }, ) @@ -77,15 +77,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=res.run_time_ms )) - def create_literal(self, value, user, collection): + def create_literal(self, value, workspace, collection): - logger.debug(f"Create literal {value} for user={user}, collection={collection}") + logger.debug(f"Create literal {value} for workspace={workspace}, collection={collection}") res = self.io.query( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", params={ "value": value, - "user": user, + "workspace": workspace, "collection": collection, }, ) @@ -95,19 +95,19 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=res.run_time_ms )) - def relate_node(self, src, uri, dest, user, collection): + def relate_node(self, src, uri, dest, workspace, collection): - logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}") + logger.debug(f"Create node rel {src} {uri} {dest} for workspace={workspace}, collection={collection}") res = self.io.query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", params={ "src": src, "dest": dest, "uri": uri, - "user": user, + "workspace": workspace, "collection": collection, }, ) @@ -117,19 +117,19 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=res.run_time_ms )) - def relate_literal(self, src, uri, dest, user, collection): + def relate_literal(self, src, uri, dest, workspace, collection): - logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}") + logger.debug(f"Create literal rel {src} {uri} {dest} for workspace={workspace}, collection={collection}") res = self.io.query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", params={ "src": src, "dest": dest, "uri": uri, - "user": user, + "workspace": workspace, "collection": collection, }, ) @@ -139,36 +139,34 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=res.run_time_ms )) - def collection_exists(self, user, collection): + def collection_exists(self, workspace, collection): """Check if collection metadata node exists""" result = self.io.query( - "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) " + "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) " "RETURN c LIMIT 1", - params={"user": user, "collection": collection} + params={"workspace": workspace, "collection": collection} ) return result.result_set is not None and len(result.result_set) > 0 - def create_collection(self, user, collection): + def create_collection(self, workspace, collection): """Create collection metadata node""" import datetime self.io.query( - "MERGE (c:CollectionMetadata {user: $user, collection: $collection}) " + "MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) " "SET c.created_at = $created_at", params={ - "user": user, + "workspace": workspace, "collection": collection, "created_at": datetime.datetime.now().isoformat() } ) - logger.info(f"Created collection metadata node for {user}/{collection}") + logger.info(f"Created collection metadata node for {workspace}/{collection}") - async def store_triples(self, message): - # Extract user and collection from metadata - user = message.metadata.user if message.metadata.user else "default" + async def store_triples(self, workspace, message): collection = message.metadata.collection if message.metadata.collection else "default" # Validate collection exists before accepting writes - if not self.collection_exists(user, collection): + if not self.collection_exists(workspace, collection): error_msg = ( f"Collection {collection} does not exist. " f"Create it first via collection management API." @@ -182,14 +180,14 @@ class Processor(CollectionConfigHandler, TriplesStoreService): p_val = get_term_value(t.p) o_val = get_term_value(t.o) - self.create_node(s_val, user, collection) + self.create_node(s_val, workspace, collection) if t.o.type == IRI: - self.create_node(o_val, user, collection) - self.relate_node(s_val, p_val, o_val, user, collection) + self.create_node(o_val, workspace, collection) + self.relate_node(s_val, p_val, o_val, workspace, collection) else: - self.create_literal(o_val, user, collection) - self.relate_literal(s_val, p_val, o_val, user, collection) + self.create_literal(o_val, workspace, collection) + self.relate_literal(s_val, p_val, o_val, workspace, collection) @staticmethod def add_args(parser): @@ -208,58 +206,58 @@ class Processor(CollectionConfigHandler, TriplesStoreService): help=f'FalkorDB database (default: {default_database})' ) - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """Create collection metadata in FalkorDB via config push""" try: # Check if collection exists result = self.io.query( - "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) RETURN c LIMIT 1", - params={"user": user, "collection": collection} + "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) RETURN c LIMIT 1", + params={"workspace": workspace, "collection": collection} ) if result.result_set: - logger.info(f"Collection {user}/{collection} already exists") + logger.info(f"Collection {workspace}/{collection} already exists") else: # Create collection metadata node import datetime self.io.query( - "MERGE (c:CollectionMetadata {user: $user, collection: $collection}) " + "MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) " "SET c.created_at = $created_at", params={ - "user": user, + "workspace": workspace, "collection": collection, "created_at": datetime.datetime.now().isoformat() } ) - logger.info(f"Created collection {user}/{collection}") + logger.info(f"Created collection {workspace}/{collection}") except Exception as e: - logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete the collection for FalkorDB triples via config push""" try: - # Delete all nodes and literals for this user/collection + # Delete all nodes and literals for this workspace/collection node_result = self.io.query( - "MATCH (n:Node {user: $user, collection: $collection}) DETACH DELETE n", - params={"user": user, "collection": collection} + "MATCH (n:Node {workspace: $workspace, collection: $collection}) DETACH DELETE n", + params={"workspace": workspace, "collection": collection} ) literal_result = self.io.query( - "MATCH (n:Literal {user: $user, collection: $collection}) DETACH DELETE n", - params={"user": user, "collection": collection} + "MATCH (n:Literal {workspace: $workspace, collection: $collection}) DETACH DELETE n", + params={"workspace": workspace, "collection": collection} ) # Delete collection metadata node metadata_result = self.io.query( - "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) DELETE c", - params={"user": user, "collection": collection} + "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) DELETE c", + params={"workspace": workspace, "collection": collection} ) - logger.info(f"Deleted {node_result.nodes_deleted} nodes, {literal_result.nodes_deleted} literals, and {metadata_result.nodes_deleted} metadata nodes for collection {user}/{collection}") + logger.info(f"Deleted {node_result.nodes_deleted} nodes, {literal_result.nodes_deleted} literals, and {metadata_result.nodes_deleted} metadata nodes for collection {workspace}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py index 16a7d3ed..3e1a8288 100755 --- a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py @@ -117,10 +117,10 @@ class Processor(CollectionConfigHandler, TriplesStoreService): # Maybe index already exists logger.warning("Index create failure ignored") - # New indexes for user/collection filtering + # New indexes for workspace/collection filtering try: session.run( - "CREATE INDEX ON :Node(user)" + "CREATE INDEX ON :Node(workspace)" ) except Exception as e: logger.warning(f"User index create failure: {e}") @@ -136,7 +136,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService): try: session.run( - "CREATE INDEX ON :Literal(user)" + "CREATE INDEX ON :Literal(workspace)" ) except Exception as e: logger.warning(f"User index create failure: {e}") @@ -152,13 +152,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService): logger.info("Index creation done") - def create_node(self, uri, user, collection): + def create_node(self, uri, workspace, collection): - logger.debug(f"Create node {uri} for user={user}, collection={collection}") + logger.debug(f"Create node {uri} for workspace={workspace}, collection={collection}") summary = self.io.execute_query( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", - uri=uri, user=user, collection=collection, + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", + uri=uri, workspace=workspace, collection=collection, database_=self.db, ).summary @@ -167,13 +167,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=summary.result_available_after )) - def create_literal(self, value, user, collection): + def create_literal(self, value, workspace, collection): - logger.debug(f"Create literal {value} for user={user}, collection={collection}") + logger.debug(f"Create literal {value} for workspace={workspace}, collection={collection}") summary = self.io.execute_query( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", - value=value, user=user, collection=collection, + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", + value=value, workspace=workspace, collection=collection, database_=self.db, ).summary @@ -182,15 +182,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=summary.result_available_after )) - def relate_node(self, src, uri, dest, user, collection): + def relate_node(self, src, uri, dest, workspace, collection): - logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}") + logger.debug(f"Create node rel {src} {uri} {dest} for workspace={workspace}, collection={collection}") summary = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", - src=src, dest=dest, uri=uri, user=user, collection=collection, + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", + src=src, dest=dest, uri=uri, workspace=workspace, collection=collection, database_=self.db, ).summary @@ -199,15 +199,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=summary.result_available_after )) - def relate_literal(self, src, uri, dest, user, collection): + def relate_literal(self, src, uri, dest, workspace, collection): - logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}") + logger.debug(f"Create literal rel {src} {uri} {dest} for workspace={workspace}, collection={collection}") summary = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", - src=src, dest=dest, uri=uri, user=user, collection=collection, + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", + src=src, dest=dest, uri=uri, workspace=workspace, collection=collection, database_=self.db, ).summary @@ -216,7 +216,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=summary.result_available_after )) - def create_triple(self, tx, t, user, collection): + def create_triple(self, tx, t, workspace, collection): s_val = get_term_value(t.s) p_val = get_term_value(t.p) @@ -224,48 +224,46 @@ class Processor(CollectionConfigHandler, TriplesStoreService): # Create new s node with given uri, if not exists result = tx.run( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", - uri=s_val, user=user, collection=collection + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", + uri=s_val, workspace=workspace, collection=collection ) if t.o.type == IRI: # Create new o node with given uri, if not exists result = tx.run( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", - uri=o_val, user=user, collection=collection + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", + uri=o_val, workspace=workspace, collection=collection ) result = tx.run( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", - src=s_val, dest=o_val, uri=p_val, user=user, collection=collection, + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", + src=s_val, dest=o_val, uri=p_val, workspace=workspace, collection=collection, ) else: # Create new o literal with given uri, if not exists result = tx.run( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", - value=o_val, user=user, collection=collection + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", + value=o_val, workspace=workspace, collection=collection ) result = tx.run( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", - src=s_val, dest=o_val, uri=p_val, user=user, collection=collection, + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", + src=s_val, dest=o_val, uri=p_val, workspace=workspace, collection=collection, ) - async def store_triples(self, message): + async def store_triples(self, workspace, message): - # Extract user and collection from metadata - user = message.metadata.user if message.metadata.user else "default" collection = message.metadata.collection if message.metadata.collection else "default" # Validate collection exists before accepting writes - if not self.collection_exists(user, collection): + if not self.collection_exists(workspace, collection): error_msg = ( f"Collection {collection} does not exist. " f"Create it first via collection management API." @@ -279,18 +277,18 @@ class Processor(CollectionConfigHandler, TriplesStoreService): p_val = get_term_value(t.p) o_val = get_term_value(t.o) - self.create_node(s_val, user, collection) + self.create_node(s_val, workspace, collection) if t.o.type == IRI: - self.create_node(o_val, user, collection) - self.relate_node(s_val, p_val, o_val, user, collection) + self.create_node(o_val, workspace, collection) + self.relate_node(s_val, p_val, o_val, workspace, collection) else: - self.create_literal(o_val, user, collection) - self.relate_literal(s_val, p_val, o_val, user, collection) + self.create_literal(o_val, workspace, collection) + self.relate_literal(s_val, p_val, o_val, workspace, collection) # Alternative implementation using transactions # with self.io.session(database=self.db) as session: - # session.execute_write(self.create_triple, t, user, collection) + # session.execute_write(self.create_triple, t, workspace, collection) @staticmethod def add_args(parser): @@ -321,72 +319,72 @@ class Processor(CollectionConfigHandler, TriplesStoreService): help=f'Memgraph database (default: {default_database})' ) - def _collection_exists_in_db(self, user, collection): + def _collection_exists_in_db(self, workspace, collection): """Check if collection metadata node exists""" with self.io.session(database=self.db) as session: result = session.run( - "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) " + "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) " "RETURN c LIMIT 1", - user=user, collection=collection + workspace=workspace, collection=collection ) return bool(list(result)) - def _create_collection_in_db(self, user, collection): + def _create_collection_in_db(self, workspace, collection): """Create collection metadata node""" import datetime with self.io.session(database=self.db) as session: session.run( - "MERGE (c:CollectionMetadata {user: $user, collection: $collection}) " + "MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) " "SET c.created_at = $created_at", - user=user, collection=collection, + workspace=workspace, collection=collection, created_at=datetime.datetime.now().isoformat() ) - logger.info(f"Created collection metadata node for {user}/{collection}") + logger.info(f"Created collection metadata node for {workspace}/{collection}") - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """Create collection metadata in Memgraph via config push""" try: - if self._collection_exists_in_db(user, collection): - logger.info(f"Collection {user}/{collection} already exists") + if self._collection_exists_in_db(workspace, collection): + logger.info(f"Collection {workspace}/{collection} already exists") else: - self._create_collection_in_db(user, collection) - logger.info(f"Created collection {user}/{collection}") + self._create_collection_in_db(workspace, collection) + logger.info(f"Created collection {workspace}/{collection}") except Exception as e: - logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete all data for a specific collection via config push""" try: with self.io.session(database=self.db) as session: - # Delete all nodes for this user and collection + # Delete all nodes for this workspace and collection node_result = session.run( - "MATCH (n:Node {user: $user, collection: $collection}) " + "MATCH (n:Node {workspace: $workspace, collection: $collection}) " "DETACH DELETE n", - user=user, collection=collection + workspace=workspace, collection=collection ) nodes_deleted = node_result.consume().counters.nodes_deleted - # Delete all literals for this user and collection + # Delete all literals for this workspace and collection literal_result = session.run( - "MATCH (n:Literal {user: $user, collection: $collection}) " + "MATCH (n:Literal {workspace: $workspace, collection: $collection}) " "DETACH DELETE n", - user=user, collection=collection + workspace=workspace, collection=collection ) literals_deleted = literal_result.consume().counters.nodes_deleted # Delete collection metadata node metadata_result = session.run( - "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) " + "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) " "DELETE c", - user=user, collection=collection + workspace=workspace, collection=collection ) metadata_deleted = metadata_result.consume().counters.nodes_deleted # Note: Relationships are automatically deleted with DETACH DELETE - logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {user}/{collection}") + logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {workspace}/{collection}") except Exception as e: logger.error(f"Failed to delete collection: {e}") diff --git a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py index f7b2d947..22e25153 100755 --- a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py @@ -80,14 +80,12 @@ class Processor(CollectionConfigHandler, TriplesStoreService): logger.info("Create indexes...") - # Legacy indexes for backwards compatibility try: session.run( "CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)", ) except Exception as e: logger.warning(f"Index create failure: {e}") - # Maybe index already exists logger.warning("Index create failure ignored") try: @@ -96,7 +94,6 @@ class Processor(CollectionConfigHandler, TriplesStoreService): ) except Exception as e: logger.warning(f"Index create failure: {e}") - # Maybe index already exists logger.warning("Index create failure ignored") try: @@ -105,13 +102,11 @@ class Processor(CollectionConfigHandler, TriplesStoreService): ) except Exception as e: logger.warning(f"Index create failure: {e}") - # Maybe index already exists logger.warning("Index create failure ignored") - # New compound indexes for user/collection filtering try: session.run( - "CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)", + "CREATE INDEX node_workspace_collection_uri FOR (n:Node) ON (n.workspace, n.collection, n.uri)", ) except Exception as e: logger.warning(f"Compound index create failure: {e}") @@ -119,17 +114,16 @@ class Processor(CollectionConfigHandler, TriplesStoreService): try: session.run( - "CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)", + "CREATE INDEX literal_workspace_collection_value FOR (n:Literal) ON (n.workspace, n.collection, n.value)", ) except Exception as e: logger.warning(f"Compound index create failure: {e}") logger.warning("Index create failure ignored") - # Note: Neo4j doesn't support compound indexes on relationships in all versions - # Try to create individual indexes on relationship properties + # Neo4j doesn't support compound indexes on relationships in all versions try: session.run( - "CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)", + "CREATE INDEX rel_workspace FOR ()-[r:Rel]-() ON (r.workspace)", ) except Exception as e: logger.warning(f"Relationship index create failure: {e}") @@ -145,13 +139,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService): logger.info("Index creation done") - def create_node(self, uri, user, collection): + def create_node(self, uri, workspace, collection): - logger.debug(f"Create node {uri} for user={user}, collection={collection}") + logger.debug(f"Create node {uri} for workspace={workspace}, collection={collection}") summary = self.io.execute_query( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", - uri=uri, user=user, collection=collection, + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", + uri=uri, workspace=workspace, collection=collection, database_=self.db, ).summary @@ -160,13 +154,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=summary.result_available_after )) - def create_literal(self, value, user, collection): + def create_literal(self, value, workspace, collection): - logger.debug(f"Create literal {value} for user={user}, collection={collection}") + logger.debug(f"Create literal {value} for workspace={workspace}, collection={collection}") summary = self.io.execute_query( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", - value=value, user=user, collection=collection, + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", + value=value, workspace=workspace, collection=collection, database_=self.db, ).summary @@ -175,15 +169,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=summary.result_available_after )) - def relate_node(self, src, uri, dest, user, collection): + def relate_node(self, src, uri, dest, workspace, collection): - logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}") + logger.debug(f"Create node rel {src} {uri} {dest} for workspace={workspace}, collection={collection}") summary = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", - src=src, dest=dest, uri=uri, user=user, collection=collection, + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", + src=src, dest=dest, uri=uri, workspace=workspace, collection=collection, database_=self.db, ).summary @@ -192,15 +186,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=summary.result_available_after )) - def relate_literal(self, src, uri, dest, user, collection): + def relate_literal(self, src, uri, dest, workspace, collection): - logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}") + logger.debug(f"Create literal rel {src} {uri} {dest} for workspace={workspace}, collection={collection}") summary = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", - src=src, dest=dest, uri=uri, user=user, collection=collection, + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", + src=src, dest=dest, uri=uri, workspace=workspace, collection=collection, database_=self.db, ).summary @@ -209,14 +203,12 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=summary.result_available_after )) - async def store_triples(self, message): + async def store_triples(self, workspace, message): - # Extract user and collection from metadata - user = message.metadata.user if message.metadata.user else "default" collection = message.metadata.collection if message.metadata.collection else "default" # Validate collection exists before accepting writes - if not self.collection_exists(user, collection): + if not self.collection_exists(workspace, collection): error_msg = ( f"Collection {collection} does not exist. " f"Create it first via collection management API." @@ -230,14 +222,14 @@ class Processor(CollectionConfigHandler, TriplesStoreService): p_val = get_term_value(t.p) o_val = get_term_value(t.o) - self.create_node(s_val, user, collection) + self.create_node(s_val, workspace, collection) if t.o.type == IRI: - self.create_node(o_val, user, collection) - self.relate_node(s_val, p_val, o_val, user, collection) + self.create_node(o_val, workspace, collection) + self.relate_node(s_val, p_val, o_val, workspace, collection) else: - self.create_literal(o_val, user, collection) - self.relate_literal(s_val, p_val, o_val, user, collection) + self.create_literal(o_val, workspace, collection) + self.relate_literal(s_val, p_val, o_val, workspace, collection) @staticmethod def add_args(parser): @@ -268,75 +260,70 @@ class Processor(CollectionConfigHandler, TriplesStoreService): help=f'Neo4j database (default: {default_database})' ) - def _collection_exists_in_db(self, user, collection): + def _collection_exists_in_db(self, workspace, collection): """Check if collection metadata node exists""" with self.io.session(database=self.db) as session: result = session.run( - "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) " + "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) " "RETURN c LIMIT 1", - user=user, collection=collection + workspace=workspace, collection=collection ) return bool(list(result)) - def _create_collection_in_db(self, user, collection): + def _create_collection_in_db(self, workspace, collection): """Create collection metadata node""" import datetime with self.io.session(database=self.db) as session: session.run( - "MERGE (c:CollectionMetadata {user: $user, collection: $collection}) " + "MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) " "SET c.created_at = $created_at", - user=user, collection=collection, + workspace=workspace, collection=collection, created_at=datetime.datetime.now().isoformat() ) - logger.info(f"Created collection metadata node for {user}/{collection}") + logger.info(f"Created collection metadata node for {workspace}/{collection}") - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """Create collection metadata in Neo4j via config push""" try: - if self._collection_exists_in_db(user, collection): - logger.info(f"Collection {user}/{collection} already exists") + if self._collection_exists_in_db(workspace, collection): + logger.info(f"Collection {workspace}/{collection} already exists") else: - self._create_collection_in_db(user, collection) - logger.info(f"Created collection {user}/{collection}") + self._create_collection_in_db(workspace, collection) + logger.info(f"Created collection {workspace}/{collection}") except Exception as e: - logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete all data for a specific collection via config push""" try: with self.io.session(database=self.db) as session: - # Delete all nodes for this user and collection node_result = session.run( - "MATCH (n:Node {user: $user, collection: $collection}) " + "MATCH (n:Node {workspace: $workspace, collection: $collection}) " "DETACH DELETE n", - user=user, collection=collection + workspace=workspace, collection=collection ) nodes_deleted = node_result.consume().counters.nodes_deleted - # Delete all literals for this user and collection literal_result = session.run( - "MATCH (n:Literal {user: $user, collection: $collection}) " + "MATCH (n:Literal {workspace: $workspace, collection: $collection}) " "DETACH DELETE n", - user=user, collection=collection + workspace=workspace, collection=collection ) literals_deleted = literal_result.consume().counters.nodes_deleted - # Note: Relationships are automatically deleted with DETACH DELETE - - # Delete collection metadata node metadata_result = session.run( - "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) " + "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) " "DELETE c", - user=user, collection=collection + workspace=workspace, collection=collection ) metadata_deleted = metadata_result.consume().counters.nodes_deleted - logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {user}/{collection}") + logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {workspace}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/tables/config.py b/trustgraph-flow/trustgraph/tables/config.py index d9a8711b..8fd00427 100644 --- a/trustgraph-flow/trustgraph/tables/config.py +++ b/trustgraph-flow/trustgraph/tables/config.py @@ -72,10 +72,11 @@ class ConfigTableStore: self.cassandra.execute(""" CREATE TABLE IF NOT EXISTS config ( + workspace text, class text, key text, value text, - PRIMARY KEY (class, key) + PRIMARY KEY ((workspace, class), key) ); """); @@ -124,52 +125,63 @@ class ConfigTableStore: def prepare_statements(self): self.put_config_stmt = self.cassandra.prepare(""" - INSERT INTO config ( class, key, value ) - VALUES (?, ?, ?) - """) - - self.get_classes_stmt = self.cassandra.prepare(""" - SELECT DISTINCT class FROM config; + INSERT INTO config ( workspace, class, key, value ) + VALUES (?, ?, ?, ?) """) self.get_keys_stmt = self.cassandra.prepare(""" - SELECT key FROM config WHERE class = ?; + SELECT key FROM config + WHERE workspace = ? AND class = ?; """) self.get_value_stmt = self.cassandra.prepare(""" - SELECT value FROM config WHERE class = ? AND key = ?; + SELECT value FROM config + WHERE workspace = ? AND class = ? AND key = ?; """) self.delete_key_stmt = self.cassandra.prepare(""" DELETE FROM config - WHERE class = ? AND key = ?; + WHERE workspace = ? AND class = ? AND key = ?; """) self.get_all_stmt = self.cassandra.prepare(""" - SELECT class AS cls, key, value FROM config; + SELECT workspace, class AS cls, key, value FROM config; + """) + + self.get_all_for_workspace_stmt = self.cassandra.prepare(""" + SELECT class AS cls, key, value FROM config + WHERE workspace = ? + ALLOW FILTERING; """) self.get_values_stmt = self.cassandra.prepare(""" - SELECT key, value FROM config WHERE class = ?; + SELECT key, value FROM config + WHERE workspace = ? AND class = ?; """) - async def put_config(self, cls, key, value): + self.get_values_all_ws_stmt = self.cassandra.prepare(""" + SELECT workspace, key, value FROM config + WHERE class = ? + ALLOW FILTERING; + """) + + async def put_config(self, workspace, cls, key, value): try: await async_execute( self.cassandra, self.put_config_stmt, - (cls, key, value), + (workspace, cls, key, value), ) except Exception: logger.error("Exception occurred", exc_info=True) raise - async def get_value(self, cls, key): + async def get_value(self, workspace, cls, key): try: rows = await async_execute( self.cassandra, self.get_value_stmt, - (cls, key), + (workspace, cls, key), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -179,12 +191,12 @@ class ConfigTableStore: return row[0] return None - async def get_values(self, cls): + async def get_values(self, workspace, cls): try: rows = await async_execute( self.cassandra, self.get_values_stmt, - (cls,), + (workspace, cls), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -192,18 +204,20 @@ class ConfigTableStore: return [[row[0], row[1]] for row in rows] - async def get_classes(self): + async def get_values_all_ws(self, cls): + """Return (workspace, key, value) tuples for all workspaces + with entries of the given class.""" try: rows = await async_execute( self.cassandra, - self.get_classes_stmt, - (), + self.get_values_all_ws_stmt, + (cls,), ) except Exception: logger.error("Exception occurred", exc_info=True) raise - return [row[0] for row in rows] + return [(row[0], row[1], row[2]) for row in rows] async def get_all(self): try: @@ -216,14 +230,27 @@ class ConfigTableStore: logger.error("Exception occurred", exc_info=True) raise + return [(row[0], row[1], row[2], row[3]) for row in rows] + + async def get_all_for_workspace(self, workspace): + try: + rows = await async_execute( + self.cassandra, + self.get_all_for_workspace_stmt, + (workspace,), + ) + except Exception: + logger.error("Exception occurred", exc_info=True) + raise + return [(row[0], row[1], row[2]) for row in rows] - async def get_keys(self, cls): + async def get_keys(self, workspace, cls): try: rows = await async_execute( self.cassandra, self.get_keys_stmt, - (cls,), + (workspace, cls), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -231,12 +258,12 @@ class ConfigTableStore: return [row[0] for row in rows] - async def delete_key(self, cls, key): + async def delete_key(self, workspace, cls, key): try: await async_execute( self.cassandra, self.delete_key_stmt, - (cls, key), + (workspace, cls, key), ) except Exception: logger.error("Exception occurred", exc_info=True) diff --git a/trustgraph-flow/trustgraph/tables/knowledge.py b/trustgraph-flow/trustgraph/tables/knowledge.py index b06f4862..4d729956 100644 --- a/trustgraph-flow/trustgraph/tables/knowledge.py +++ b/trustgraph-flow/trustgraph/tables/knowledge.py @@ -88,7 +88,7 @@ class KnowledgeTableStore: self.cassandra.execute(""" CREATE TABLE IF NOT EXISTS triples ( - user text, + workspace text, document_id text, id uuid, time timestamp, @@ -98,7 +98,7 @@ class KnowledgeTableStore: triples list>, - PRIMARY KEY ((user, document_id), id) + PRIMARY KEY ((workspace, document_id), id) ); """); @@ -106,7 +106,7 @@ class KnowledgeTableStore: self.cassandra.execute(""" create table if not exists graph_embeddings ( - user text, + workspace text, document_id text, id uuid, time timestamp, @@ -119,20 +119,20 @@ class KnowledgeTableStore: list > >, - PRIMARY KEY ((user, document_id), id) + PRIMARY KEY ((workspace, document_id), id) ); """); self.cassandra.execute(""" - CREATE INDEX IF NOT EXISTS graph_embeddings_user ON - graph_embeddings ( user ); + CREATE INDEX IF NOT EXISTS graph_embeddings_workspace ON + graph_embeddings ( workspace ); """); logger.debug("document_embeddings table...") self.cassandra.execute(""" create table if not exists document_embeddings ( - user text, + workspace text, document_id text, id uuid, time timestamp, @@ -145,13 +145,13 @@ class KnowledgeTableStore: list > >, - PRIMARY KEY ((user, document_id), id) + PRIMARY KEY ((workspace, document_id), id) ); """); self.cassandra.execute(""" - CREATE INDEX IF NOT EXISTS document_embeddings_user ON - document_embeddings ( user ); + CREATE INDEX IF NOT EXISTS document_embeddings_workspace ON + document_embeddings ( workspace ); """); logger.info("Cassandra schema OK.") @@ -161,7 +161,7 @@ class KnowledgeTableStore: self.insert_triples_stmt = self.cassandra.prepare(""" INSERT INTO triples ( - id, user, document_id, + id, workspace, document_id, time, metadata, triples ) VALUES (?, ?, ?, ?, ?, ?) @@ -170,7 +170,7 @@ class KnowledgeTableStore: self.insert_graph_embeddings_stmt = self.cassandra.prepare(""" INSERT INTO graph_embeddings ( - id, user, document_id, time, metadata, entity_embeddings + id, workspace, document_id, time, metadata, entity_embeddings ) VALUES (?, ?, ?, ?, ?, ?) """) @@ -178,45 +178,45 @@ class KnowledgeTableStore: self.insert_document_embeddings_stmt = self.cassandra.prepare(""" INSERT INTO document_embeddings ( - id, user, document_id, time, metadata, chunks + id, workspace, document_id, time, metadata, chunks ) VALUES (?, ?, ?, ?, ?, ?) """) self.list_cores_stmt = self.cassandra.prepare(""" - SELECT DISTINCT user, document_id FROM graph_embeddings - WHERE user = ? + SELECT DISTINCT workspace, document_id FROM graph_embeddings + WHERE workspace = ? """) self.get_triples_stmt = self.cassandra.prepare(""" SELECT id, time, metadata, triples FROM triples - WHERE user = ? AND document_id = ? + WHERE workspace = ? AND document_id = ? """) self.get_graph_embeddings_stmt = self.cassandra.prepare(""" SELECT id, time, metadata, entity_embeddings FROM graph_embeddings - WHERE user = ? AND document_id = ? + WHERE workspace = ? AND document_id = ? """) self.get_document_embeddings_stmt = self.cassandra.prepare(""" SELECT id, time, metadata, chunks FROM document_embeddings - WHERE user = ? AND document_id = ? + WHERE workspace = ? AND document_id = ? """) self.delete_triples_stmt = self.cassandra.prepare(""" DELETE FROM triples - WHERE user = ? AND document_id = ? + WHERE workspace = ? AND document_id = ? """) self.delete_graph_embeddings_stmt = self.cassandra.prepare(""" DELETE FROM graph_embeddings - WHERE user = ? AND document_id = ? + WHERE workspace = ? AND document_id = ? """) - async def add_triples(self, m): + async def add_triples(self, workspace, m): when = int(time.time() * 1000) @@ -232,7 +232,7 @@ class KnowledgeTableStore: self.cassandra, self.insert_triples_stmt, ( - uuid.uuid4(), m.metadata.user, + uuid.uuid4(), workspace, m.metadata.root or m.metadata.id, when, [], triples, ), @@ -241,7 +241,7 @@ class KnowledgeTableStore: logger.error("Exception occurred", exc_info=True) raise - async def add_graph_embeddings(self, m): + async def add_graph_embeddings(self, workspace, m): when = int(time.time() * 1000) @@ -258,7 +258,7 @@ class KnowledgeTableStore: self.cassandra, self.insert_graph_embeddings_stmt, ( - uuid.uuid4(), m.metadata.user, + uuid.uuid4(), workspace, m.metadata.root or m.metadata.id, when, [], entities, ), @@ -267,7 +267,7 @@ class KnowledgeTableStore: logger.error("Exception occurred", exc_info=True) raise - async def add_document_embeddings(self, m): + async def add_document_embeddings(self, workspace, m): when = int(time.time() * 1000) @@ -284,7 +284,7 @@ class KnowledgeTableStore: self.cassandra, self.insert_document_embeddings_stmt, ( - uuid.uuid4(), m.metadata.user, + uuid.uuid4(), workspace, m.metadata.root or m.metadata.id, when, [], chunks, ), @@ -293,7 +293,7 @@ class KnowledgeTableStore: logger.error("Exception occurred", exc_info=True) raise - async def list_kg_cores(self, user): + async def list_kg_cores(self, workspace): logger.debug("List kg cores...") @@ -301,7 +301,7 @@ class KnowledgeTableStore: rows = await async_execute( self.cassandra, self.list_cores_stmt, - (user,), + (workspace,), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -313,7 +313,7 @@ class KnowledgeTableStore: return lst - async def delete_kg_core(self, user, document_id): + async def delete_kg_core(self, workspace, document_id): logger.debug("Delete kg cores...") @@ -321,7 +321,7 @@ class KnowledgeTableStore: await async_execute( self.cassandra, self.delete_triples_stmt, - (user, document_id), + (workspace, document_id), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -331,13 +331,13 @@ class KnowledgeTableStore: await async_execute( self.cassandra, self.delete_graph_embeddings_stmt, - (user, document_id), + (workspace, document_id), ) except Exception: logger.error("Exception occurred", exc_info=True) raise - async def get_triples(self, user, document_id, receiver): + async def get_triples(self, workspace, document_id, receiver): logger.debug("Get triples...") @@ -345,7 +345,7 @@ class KnowledgeTableStore: rows = await async_execute( self.cassandra, self.get_triples_stmt, - (user, document_id), + (workspace, document_id), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -369,7 +369,6 @@ class KnowledgeTableStore: Triples( metadata = Metadata( id = document_id, - user = user, collection = "default", # FIXME: What to put here? ), triples = triples @@ -378,7 +377,7 @@ class KnowledgeTableStore: logger.debug("Done") - async def get_graph_embeddings(self, user, document_id, receiver): + async def get_graph_embeddings(self, workspace, document_id, receiver): logger.debug("Get GE...") @@ -386,7 +385,7 @@ class KnowledgeTableStore: rows = await async_execute( self.cassandra, self.get_graph_embeddings_stmt, - (user, document_id), + (workspace, document_id), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -409,12 +408,11 @@ class KnowledgeTableStore: GraphEmbeddings( metadata = Metadata( id = document_id, - user = user, collection = "default", # FIXME: What to put here? ), entities = entities ) - ) + ) logger.debug("Done") diff --git a/trustgraph-flow/trustgraph/tables/library.py b/trustgraph-flow/trustgraph/tables/library.py index c85ae72a..86706079 100644 --- a/trustgraph-flow/trustgraph/tables/library.py +++ b/trustgraph-flow/trustgraph/tables/library.py @@ -64,7 +64,7 @@ class LibraryTableStore: self.cluster = Cluster(cassandra_host) self.cassandra = self.cluster.connect() - + logger.info("Connected.") self.ensure_cassandra_schema() @@ -76,13 +76,13 @@ class LibraryTableStore: logger.debug("Ensure Cassandra schema...") logger.debug("Keyspace...") - + # FIXME: Replication factor should be configurable self.cassandra.execute(f""" create keyspace if not exists {self.keyspace} - with replication = {{ - 'class' : 'SimpleStrategy', - 'replication_factor' : 1 + with replication = {{ + 'class' : 'SimpleStrategy', + 'replication_factor' : 1 }}; """); @@ -93,7 +93,7 @@ class LibraryTableStore: self.cassandra.execute(""" CREATE TABLE IF NOT EXISTS document ( id text, - user text, + workspace text, time timestamp, kind text, title text, @@ -103,7 +103,9 @@ class LibraryTableStore: >>, tags list, object_id uuid, - PRIMARY KEY (user, id) + parent_id text, + document_type text, + PRIMARY KEY (workspace, id) ); """); @@ -114,27 +116,6 @@ class LibraryTableStore: ON document (object_id) """); - # Add parent_id and document_type columns for child document support - logger.debug("document table parent_id column...") - - try: - self.cassandra.execute(""" - ALTER TABLE document ADD parent_id text - """); - except Exception as e: - # Column may already exist - if "already exists" not in str(e).lower() and "Invalid column name" not in str(e): - logger.debug(f"parent_id column may already exist: {e}") - - try: - self.cassandra.execute(""" - ALTER TABLE document ADD document_type text - """); - except Exception as e: - # Column may already exist - if "already exists" not in str(e).lower() and "Invalid column name" not in str(e): - logger.debug(f"document_type column may already exist: {e}") - logger.debug("document parent index...") self.cassandra.execute(""" @@ -150,10 +131,10 @@ class LibraryTableStore: document_id text, time timestamp, flow text, - user text, + workspace text, collection text, tags list, - PRIMARY KEY (user, id) + PRIMARY KEY (workspace, id) ); """); @@ -162,7 +143,7 @@ class LibraryTableStore: self.cassandra.execute(""" CREATE TABLE IF NOT EXISTS upload_session ( upload_id text PRIMARY KEY, - user text, + workspace text, document_id text, document_metadata text, s3_upload_id text, @@ -176,11 +157,11 @@ class LibraryTableStore: ) WITH default_time_to_live = 86400; """); - logger.debug("upload_session user index...") + logger.debug("upload_session workspace index...") self.cassandra.execute(""" - CREATE INDEX IF NOT EXISTS upload_session_user - ON upload_session (user) + CREATE INDEX IF NOT EXISTS upload_session_workspace + ON upload_session (workspace) """); logger.info("Cassandra schema OK.") @@ -190,7 +171,7 @@ class LibraryTableStore: self.insert_document_stmt = self.cassandra.prepare(""" INSERT INTO document ( - id, user, time, + id, workspace, time, kind, title, comments, metadata, tags, object_id, parent_id, document_type @@ -202,25 +183,25 @@ class LibraryTableStore: UPDATE document SET time = ?, title = ?, comments = ?, metadata = ?, tags = ? - WHERE user = ? AND id = ? + WHERE workspace = ? AND id = ? """) self.get_document_stmt = self.cassandra.prepare(""" SELECT time, kind, title, comments, metadata, tags, object_id, parent_id, document_type FROM document - WHERE user = ? AND id = ? + WHERE workspace = ? AND id = ? """) self.delete_document_stmt = self.cassandra.prepare(""" DELETE FROM document - WHERE user = ? AND id = ? + WHERE workspace = ? AND id = ? """) self.test_document_exists_stmt = self.cassandra.prepare(""" SELECT id FROM document - WHERE user = ? AND id = ? + WHERE workspace = ? AND id = ? LIMIT 1 """) @@ -229,7 +210,7 @@ class LibraryTableStore: id, time, kind, title, comments, metadata, tags, object_id, parent_id, document_type FROM document - WHERE user = ? + WHERE workspace = ? """) self.list_document_by_tag_stmt = self.cassandra.prepare(""" @@ -237,7 +218,7 @@ class LibraryTableStore: id, time, kind, title, comments, metadata, tags, object_id, parent_id, document_type FROM document - WHERE user = ? AND tags CONTAINS ? + WHERE workspace = ? AND tags CONTAINS ? ALLOW FILTERING """) @@ -245,7 +226,7 @@ class LibraryTableStore: INSERT INTO processing ( id, document_id, time, - flow, user, collection, + flow, workspace, collection, tags ) VALUES (?, ?, ?, ?, ?, ?, ?) @@ -253,13 +234,13 @@ class LibraryTableStore: self.delete_processing_stmt = self.cassandra.prepare(""" DELETE FROM processing - WHERE user = ? AND id = ? + WHERE workspace = ? AND id = ? """) self.test_processing_exists_stmt = self.cassandra.prepare(""" SELECT id FROM processing - WHERE user = ? AND id = ? + WHERE workspace = ? AND id = ? LIMIT 1 """) @@ -267,14 +248,14 @@ class LibraryTableStore: SELECT id, document_id, time, flow, collection, tags FROM processing - WHERE user = ? + WHERE workspace = ? """) # Upload session prepared statements self.insert_upload_session_stmt = self.cassandra.prepare(""" INSERT INTO upload_session ( - upload_id, user, document_id, document_metadata, + upload_id, workspace, document_id, document_metadata, s3_upload_id, object_id, total_size, chunk_size, total_chunks, chunks_received, created_at, updated_at ) @@ -283,7 +264,7 @@ class LibraryTableStore: self.get_upload_session_stmt = self.cassandra.prepare(""" SELECT - upload_id, user, document_id, document_metadata, + upload_id, workspace, document_id, document_metadata, s3_upload_id, object_id, total_size, chunk_size, total_chunks, chunks_received, created_at, updated_at FROM upload_session @@ -308,25 +289,25 @@ class LibraryTableStore: total_size, chunk_size, total_chunks, chunks_received, created_at, updated_at FROM upload_session - WHERE user = ? + WHERE workspace = ? """) # Child document queries self.list_children_stmt = self.cassandra.prepare(""" SELECT - id, user, time, kind, title, comments, metadata, tags, + id, workspace, time, kind, title, comments, metadata, tags, object_id, parent_id, document_type FROM document WHERE parent_id = ? ALLOW FILTERING """) - async def document_exists(self, user, id): + async def document_exists(self, workspace, id): rows = await async_execute( self.cassandra, self.test_document_exists_stmt, - (user, id), + (workspace, id), ) return bool(rows) @@ -351,7 +332,7 @@ class LibraryTableStore: self.cassandra, self.insert_document_stmt, ( - document.id, document.user, int(document.time * 1000), + document.id, document.workspace, int(document.time * 1000), document.kind, document.title, document.comments, metadata, document.tags, object_id, parent_id, document_type @@ -381,7 +362,7 @@ class LibraryTableStore: ( int(document.time * 1000), document.title, document.comments, metadata, document.tags, - document.user, document.id + document.workspace, document.id ), ) except Exception: @@ -390,7 +371,7 @@ class LibraryTableStore: logger.debug("Update complete") - async def remove_document(self, user, document_id): + async def remove_document(self, workspace, document_id): logger.info(f"Removing document {document_id}") @@ -398,7 +379,7 @@ class LibraryTableStore: await async_execute( self.cassandra, self.delete_document_stmt, - (user, document_id), + (workspace, document_id), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -406,7 +387,7 @@ class LibraryTableStore: logger.debug("Delete complete") - async def list_documents(self, user): + async def list_documents(self, workspace): logger.debug("List documents...") @@ -414,7 +395,7 @@ class LibraryTableStore: rows = await async_execute( self.cassandra, self.list_document_stmt, - (user,), + (workspace,), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -423,7 +404,7 @@ class LibraryTableStore: lst = [ DocumentMetadata( id = row[0], - user = user, + workspace = workspace, time = int(time.mktime(row[1].timetuple())), kind = row[2], title = row[3], @@ -465,7 +446,7 @@ class LibraryTableStore: lst = [ DocumentMetadata( id = row[0], - user = row[1], + workspace = row[1], time = int(time.mktime(row[2].timetuple())), kind = row[3], title = row[4], @@ -489,7 +470,7 @@ class LibraryTableStore: return lst - async def get_document(self, user, id): + async def get_document(self, workspace, id): logger.debug("Get document") @@ -497,7 +478,7 @@ class LibraryTableStore: rows = await async_execute( self.cassandra, self.get_document_stmt, - (user, id), + (workspace, id), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -506,7 +487,7 @@ class LibraryTableStore: for row in rows: doc = DocumentMetadata( id = id, - user = user, + workspace = workspace, time = int(time.mktime(row[0].timetuple())), kind = row[1], title = row[2], @@ -529,7 +510,7 @@ class LibraryTableStore: raise RuntimeError("No such document row?") - async def get_document_object_id(self, user, id): + async def get_document_object_id(self, workspace, id): logger.debug("Get document obj ID") @@ -537,7 +518,7 @@ class LibraryTableStore: rows = await async_execute( self.cassandra, self.get_document_stmt, - (user, id), + (workspace, id), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -549,12 +530,12 @@ class LibraryTableStore: raise RuntimeError("No such document row?") - async def processing_exists(self, user, id): + async def processing_exists(self, workspace, id): rows = await async_execute( self.cassandra, self.test_processing_exists_stmt, - (user, id), + (workspace, id), ) return bool(rows) @@ -570,7 +551,7 @@ class LibraryTableStore: ( processing.id, processing.document_id, int(processing.time * 1000), processing.flow, - processing.user, processing.collection, + processing.workspace, processing.collection, processing.tags ), ) @@ -580,7 +561,7 @@ class LibraryTableStore: logger.debug("Add complete") - async def remove_processing(self, user, processing_id): + async def remove_processing(self, workspace, processing_id): logger.info(f"Removing processing {processing_id}") @@ -588,7 +569,7 @@ class LibraryTableStore: await async_execute( self.cassandra, self.delete_processing_stmt, - (user, processing_id), + (workspace, processing_id), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -596,7 +577,7 @@ class LibraryTableStore: logger.debug("Delete complete") - async def list_processing(self, user): + async def list_processing(self, workspace): logger.debug("List processing objects") @@ -604,7 +585,7 @@ class LibraryTableStore: rows = await async_execute( self.cassandra, self.list_processing_stmt, - (user,), + (workspace,), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -616,7 +597,7 @@ class LibraryTableStore: document_id = row[1], time = int(time.mktime(row[2].timetuple())), flow = row[3], - user = user, + workspace = workspace, collection = row[4], tags = row[5] if row[5] else [], ) @@ -632,7 +613,7 @@ class LibraryTableStore: async def create_upload_session( self, upload_id, - user, + workspace, document_id, document_metadata, s3_upload_id, @@ -652,7 +633,7 @@ class LibraryTableStore: self.cassandra, self.insert_upload_session_stmt, ( - upload_id, user, document_id, document_metadata, + upload_id, workspace, document_id, document_metadata, s3_upload_id, object_id, total_size, chunk_size, total_chunks, {}, now, now ), @@ -681,7 +662,7 @@ class LibraryTableStore: for row in rows: session = { "upload_id": row[0], - "user": row[1], + "workspace": row[1], "document_id": row[2], "document_metadata": row[3], "s3_upload_id": row[4], @@ -738,16 +719,16 @@ class LibraryTableStore: logger.debug("Upload session deleted") - async def list_upload_sessions(self, user): - """List all upload sessions for a user.""" + async def list_upload_sessions(self, workspace): + """List all upload sessions for a workspace.""" - logger.debug(f"List upload sessions for {user}") + logger.debug(f"List upload sessions for {workspace}") try: rows = await async_execute( self.cassandra, self.list_upload_sessions_stmt, - (user,), + (workspace,), ) except Exception: logger.error("Exception occurred", exc_info=True) diff --git a/trustgraph-flow/trustgraph/tool_service/joke/service.py b/trustgraph-flow/trustgraph/tool_service/joke/service.py index d9b7cde0..171156d8 100644 --- a/trustgraph-flow/trustgraph/tool_service/joke/service.py +++ b/trustgraph-flow/trustgraph/tool_service/joke/service.py @@ -2,7 +2,6 @@ Joke Tool Service - An example dynamic tool service. This service demonstrates the tool service integration by: -- Using the 'user' field to personalize responses - Using config params (style) to customize joke style - Using arguments (topic) to generate topic-specific jokes @@ -143,17 +142,16 @@ class Processor(DynamicToolService): super(Processor, self).__init__(**params) logger.info("Joke service initialized") - async def invoke(self, user, config, arguments): + async def invoke(self, config, arguments): """ Generate a joke based on the topic and style. Args: - user: The user requesting the joke config: Config values including 'style' (pun, dad-joke, one-liner) arguments: Arguments including 'topic' (programming, animals, food) Returns: - A personalized joke string + A joke string """ # Get style from config (default: random) style = config.get("style", random.choice(["pun", "dad-joke", "one-liner"])) @@ -183,10 +181,9 @@ class Processor(DynamicToolService): # Pick a random joke joke = random.choice(jokes) - # Personalize the response - response = f"Hey {user}! Here's a {style} for you:\n\n{joke}" + response = f"Here's a {style} for you:\n\n{joke}" - logger.debug(f"Generated joke for user={user}, style={style}, topic={topic}") + logger.debug(f"Generated joke: style={style}, topic={topic}") return response diff --git a/trustgraph-mcp/trustgraph/mcp_server/mcp.py b/trustgraph-mcp/trustgraph/mcp_server/mcp.py index eadd841b..7378db64 100755 --- a/trustgraph-mcp/trustgraph/mcp_server/mcp.py +++ b/trustgraph-mcp/trustgraph/mcp_server/mcp.py @@ -49,26 +49,26 @@ async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8 logging.info("Shutdown complete") -async def get_socket_manager(ctx, user): +async def get_socket_manager(ctx): lifespan_context = ctx.request_context.lifespan_context sockets = lifespan_context.sockets websocket_url = lifespan_context.websocket_url gateway_token = lifespan_context.gateway_token - if user in sockets: + if "default" in sockets: logging.info("Return existing socket manager") - return sockets[user] + return sockets["default"] logging.info(f"Opening socket to {websocket_url}...") # Create manager with empty pending requests manager = WebSocketManager(websocket_url, token=gateway_token) - + # Start reader task with the proper manager await manager.start() - - sockets[user] = manager + + sockets["default"] = manager logging.info("Return new socket manager") return manager @@ -372,7 +372,6 @@ class McpServer: async def graph_rag( self, question: str, - user: str | None = None, collection: str | None = None, entity_limit: int | None = None, triple_limit: int | None = None, @@ -391,7 +390,6 @@ class McpServer: Args: question: The question or query to answer using the knowledge graph. The system will find relevant entities and relationships to inform the response. - user: User identifier for access control and personalization (default: "trustgraph"). collection: Knowledge collection to query (default: "default"). Different collections may contain domain-specific knowledge. entity_limit: Maximum number of entities to retrieve during graph traversal. @@ -414,7 +412,6 @@ class McpServer: - Perform research queries across connected information """ - if user is None: user = "trustgraph" if collection is None: collection = "default" if flow_id is None: flow_id = "default" @@ -423,7 +420,7 @@ class McpServer: logging.info("GraphRAG request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -437,7 +434,6 @@ class McpServer: "query": question } - if user: request_data["user"] = user if collection: request_data["collection"] = collection if entity_limit: request_data["entity_limit"] = entity_limit if triple_limit: request_data["triple_limit"] = triple_limit @@ -466,7 +462,6 @@ class McpServer: async def agent( self, question: str, - user: str | None = None, collection: str | None = None, flow_id: str | None = None, ctx: Context = None, @@ -481,7 +476,6 @@ class McpServer: Args: question: The question or task for the agent to solve. Can be complex queries requiring multiple steps, analysis, or tool usage. - user: User identifier for personalization and access control (default: "trustgraph"). collection: Knowledge collection the agent can access (default: "default"). Determines what information and tools are available. flow_id: Agent workflow to use (default: "default"). Different flows @@ -501,7 +495,6 @@ class McpServer: through log messages, so you can follow its reasoning steps. """ - if user is None: user = "trustgraph" if collection is None: collection = "default" if flow_id is None: flow_id = "default" @@ -510,7 +503,7 @@ class McpServer: logging.info("Agent request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -524,7 +517,6 @@ class McpServer: "question": question } - if user: request_data["user"] = user if collection: request_data["collection"] = collection gen = manager.request("agent", request_data, flow_id) @@ -1143,23 +1135,18 @@ class McpServer: async def get_knowledge_cores( self, - user: str | None = None, ctx: Context = None, ) -> KnowledgeCoresResponse: """ - List all available knowledge graph cores for a user. - + List all available knowledge graph cores in the current workspace. + Knowledge cores are packaged collections of structured knowledge that can be loaded into the system for querying and reasoning. They contain entities, relationships, and facts organized as knowledge graphs. - - Args: - user: User identifier to list cores for (default: "trustgraph"). - Different users may have access to different knowledge cores. - + Returns: KnowledgeCoresResponse containing a list of available knowledge core IDs. - + Use this for: - Discovering available knowledge collections - Understanding what knowledge domains are accessible @@ -1167,14 +1154,12 @@ class McpServer: - Managing knowledge resources """ - if user is None: user = "trustgraph" - if ctx is None: raise RuntimeError("No context provided") logging.info("Get knowledge cores request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -1185,7 +1170,6 @@ class McpServer: request_data = { "operation": "list-kg-cores", - "user": user } gen = manager.request("knowledge", request_data, None) @@ -1199,40 +1183,35 @@ class McpServer: async def delete_kg_core( self, core_id: str, - user: str | None = None, ctx: Context = None, ) -> DeleteKgCoreResponse: """ Permanently delete a knowledge graph core. - + This operation removes a knowledge core from storage. Use with caution as this action cannot be undone. - + Args: core_id: Unique identifier of the knowledge core to delete. - user: User identifier (default: "trustgraph"). Only cores owned - by this user can be deleted. - + Returns: DeleteKgCoreResponse confirming the deletion. - + Use this for: - Cleaning up obsolete knowledge cores - Removing test or experimental data - Managing storage space - Maintaining organized knowledge collections - + Warning: This permanently deletes the knowledge core and all its data. """ - if user is None: user = "trustgraph" - if ctx is None: raise RuntimeError("No context provided") logging.info("Delete KG core request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -1244,7 +1223,6 @@ class McpServer: request_data = { "operation": "delete-kg-core", "id": core_id, - "user": user } gen = manager.request("knowledge", request_data, None) @@ -1258,27 +1236,25 @@ class McpServer: self, core_id: str, flow: str, - user: str | None = None, collection: str | None = None, ctx: Context = None, ) -> LoadKgCoreResponse: """ Load a knowledge graph core into the active system for querying. - + This operation makes a knowledge core available for GraphRAG queries, triple searches, and other knowledge-based operations. - + Args: core_id: Unique identifier of the knowledge core to load. flow: Processing flow to use for loading the core. Different flows may apply different processing, indexing, or optimization steps. - user: User identifier (default: "trustgraph"). collection: Target collection name (default: "default"). The loaded knowledge will be available under this collection name. - + Returns: LoadKgCoreResponse confirming the core has been loaded. - + Use this for: - Making knowledge cores available for queries - Switching between different knowledge domains @@ -1286,7 +1262,6 @@ class McpServer: - Preparing knowledge for GraphRAG operations """ - if user is None: user = "trustgraph" if collection is None: collection = "default" if ctx is None: @@ -1294,7 +1269,7 @@ class McpServer: logging.info("Load KG core request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -1307,7 +1282,6 @@ class McpServer: "operation": "load-kg-core", "id": core_id, "flow": flow, - "user": user, "collection": collection } @@ -1321,42 +1295,38 @@ class McpServer: async def get_kg_core( self, core_id: str, - user: str | None = None, ctx: Context = None, ) -> GetKgCoreResponse: """ Download and retrieve the complete content of a knowledge graph core. - + This tool streams the entire content of a knowledge core, returning all entities, relationships, and metadata. Due to potentially large data sizes, the content is streamed in chunks. - + Args: core_id: Unique identifier of the knowledge core to retrieve. - user: User identifier (default: "trustgraph"). - + Returns: GetKgCoreResponse containing all chunks of the knowledge core data. Each chunk contains part of the knowledge graph structure. - + Use this for: - Examining knowledge core content and structure - Debugging knowledge graph data - Exporting knowledge for backup or analysis - Understanding the scope and quality of knowledge - + Note: Large knowledge cores may take significant time to download. Progress updates are provided through log messages during streaming. """ - if user is None: user = "trustgraph" - if ctx is None: raise RuntimeError("No context provided") logging.info("Get KG core request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -1368,7 +1338,6 @@ class McpServer: request_data = { "operation": "get-kg-core", "id": core_id, - "user": user } # Collect all streaming responses @@ -1713,27 +1682,22 @@ class McpServer: async def get_documents( self, - user: str | None = None, ctx: Context = None, ) -> DocumentsResponse: """ List all documents stored in the TrustGraph document library. - + This tool returns metadata for all documents that have been uploaded to the system, including their processing status and properties. - - Args: - user: User identifier to list documents for (default: "trustgraph"). - Only documents owned by this user will be returned. - + Returns: DocumentsResponse containing metadata for each document including: - Document ID and title - - Upload timestamp and user + - Upload timestamp - MIME type and size information - Tags and custom metadata - Processing status - + Use this for: - Browsing available documents - Managing document collections @@ -1741,14 +1705,12 @@ class McpServer: - Auditing document storage """ - if user is None: user = "trustgraph" - if ctx is None: raise RuntimeError("No context provided") logging.info("Get documents request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -1759,7 +1721,6 @@ class McpServer: request_data = { "operation": "list-documents", - "user": user } gen = manager.request("librarian", request_data, None) @@ -1772,26 +1733,21 @@ class McpServer: async def get_processing( self, - user: str | None = None, ctx: Context = None, ) -> ProcessingResponse: """ List all documents currently in the processing queue. - + This tool shows documents that are being processed or waiting to be processed, along with their processing status and configuration. - - Args: - user: User identifier (default: "trustgraph"). Only processing - jobs for this user will be returned. - + Returns: ProcessingResponse containing processing metadata including: - Processing job ID and document ID - Processing flow and status - - Target collection and user + - Target collection - Timestamp and progress information - + Use this for: - Monitoring document processing progress - Debugging processing issues @@ -1799,14 +1755,12 @@ class McpServer: - Understanding system workload """ - if user is None: user = "trustgraph" - if ctx is None: raise RuntimeError("No context provided") logging.info("Get processing request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -1817,7 +1771,6 @@ class McpServer: request_data = { "operation": "list-processing", - "user": user } gen = manager.request("librarian", request_data, None) @@ -1837,16 +1790,15 @@ class McpServer: title: str = "", comments: str = "", tags: List[str] | None = None, - user: str | None = None, ctx: Context = None, ) -> LoadDocumentResponse: """ Upload a document to the TrustGraph document library. - + This tool stores documents with rich metadata for later processing, search, and knowledge extraction. Documents can be text files, PDFs, or other supported formats. - + Args: document: The document content as a string. For binary files, this should be base64-encoded content. @@ -1856,11 +1808,10 @@ class McpServer: title: Human-readable title for the document. comments: Optional description or notes about the document. tags: List of tags for categorizing and finding the document. - user: User identifier (default: "trustgraph"). - + Returns: LoadDocumentResponse confirming the document has been stored. - + Use this for: - Adding new documents to the knowledge base - Storing reference materials and data sources @@ -1868,7 +1819,6 @@ class McpServer: - Importing external content for analysis """ - if user is None: user = "trustgraph" if tags is None: tags = [] if ctx is None: @@ -1876,7 +1826,7 @@ class McpServer: logging.info("Load document request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -1897,7 +1847,6 @@ class McpServer: "title": title, "comments": comments, "metadata": metadata, - "user": user, "tags": tags }, "content": document @@ -1913,40 +1862,35 @@ class McpServer: async def remove_document( self, document_id: str, - user: str | None = None, ctx: Context = None, ) -> RemoveDocumentResponse: """ Permanently remove a document from the library. - + This operation deletes a document and all its associated metadata. Use with caution as this action cannot be undone. - + Args: document_id: Unique identifier of the document to remove. - user: User identifier (default: "trustgraph"). Only documents - owned by this user can be removed. - + Returns: RemoveDocumentResponse confirming the document has been deleted. - + Use this for: - Cleaning up obsolete or incorrect documents - Managing storage space - Removing sensitive or inappropriate content - Maintaining organized document collections - + Warning: This permanently deletes the document and all its metadata. """ - if user is None: user = "trustgraph" - if ctx is None: raise RuntimeError("No context provided") logging.info("Remove document request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -1958,7 +1902,6 @@ class McpServer: request_data = { "operation": "remove-document", "document-id": document_id, - "user": user } gen = manager.request("librarian", request_data, None) @@ -1973,42 +1916,39 @@ class McpServer: processing_id: str, document_id: str, flow: str, - user: str | None = None, collection: str | None = None, tags: List[str] | None = None, ctx: Context = None, ) -> AddProcessingResponse: """ Queue a document for processing through a specific workflow. - + This tool adds a document to the processing queue where it will be processed by the specified flow to extract knowledge, create embeddings, or perform other analysis operations. - + Args: processing_id: Unique identifier for this processing job. document_id: ID of the document to process (must exist in library). flow: Processing flow to use. Different flows perform different types of analysis (e.g., knowledge extraction, summarization). - user: User identifier (default: "trustgraph"). collection: Target collection for processed knowledge (default: "default"). Results will be stored under this collection name. tags: Optional tags for categorizing this processing job. - + Returns: AddProcessingResponse confirming the document has been queued. - + Use this for: - Processing uploaded documents into knowledge - Extracting entities and relationships from text - Creating searchable embeddings - Converting documents into structured knowledge - + Note: Processing may take time depending on document size and flow complexity. Use get_processing to monitor progress. """ - if user is None: user = "trustgraph" if collection is None: collection = "default" if tags is None: tags = [] @@ -2017,7 +1957,7 @@ class McpServer: logging.info("Add processing request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -2036,7 +1976,6 @@ class McpServer: "document-id": document_id, "time": timestamp, "flow": flow, - "user": user, "collection": collection, "tags": tags } diff --git a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py index 4844b104..9d955d17 100755 --- a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py +++ b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py @@ -91,7 +91,7 @@ class Processor(FlowProcessor): if v.document_id: doc_meta = await self.librarian.fetch_document_metadata( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf": logger.error( @@ -106,7 +106,7 @@ class Processor(FlowProcessor): logger.info(f"Fetching document {v.document_id} from librarian...") content = await self.librarian.fetch_document_content( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) if isinstance(content, str): content = content.encode('utf-8') @@ -141,7 +141,7 @@ class Processor(FlowProcessor): await self.librarian.save_child_document( doc_id=page_doc_id, parent_id=source_doc_id, - user=v.metadata.user, + workspace=flow.workspace, content=page_content, document_type="page", title=f"Page {page_num}", @@ -163,7 +163,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=pg_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), triples=set_graph(prov_triples, GRAPH_SOURCE), @@ -175,7 +174,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=pg_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), document_id=page_doc_id, diff --git a/trustgraph-unstructured/trustgraph/decoding/universal/processor.py b/trustgraph-unstructured/trustgraph/decoding/universal/processor.py index 6b7d0246..b3723655 100644 --- a/trustgraph-unstructured/trustgraph/decoding/universal/processor.py +++ b/trustgraph-unstructured/trustgraph/decoding/universal/processor.py @@ -275,7 +275,7 @@ class Processor(FlowProcessor): await self.librarian.save_child_document( doc_id=doc_id, parent_id=parent_doc_id, - user=metadata.user, + workspace=flow.workspace, content=page_content, document_type="page" if is_page else "section", title=label, @@ -303,7 +303,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=entity_uri, root=metadata.root, - user=metadata.user, collection=metadata.collection, ), triples=set_graph(prov_triples, GRAPH_SOURCE), @@ -314,7 +313,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=entity_uri, root=metadata.root, - user=metadata.user, collection=metadata.collection, ), document_id=doc_id, @@ -356,7 +354,7 @@ class Processor(FlowProcessor): await self.librarian.save_child_document( doc_id=img_uri, parent_id=parent_doc_id, - user=metadata.user, + workspace=flow.workspace, content=img_content, document_type="image", title=f"Image from page {page_number}" if page_number else "Image", @@ -379,7 +377,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=img_uri, root=metadata.root, - user=metadata.user, collection=metadata.collection, ), triples=set_graph(prov_triples, GRAPH_SOURCE), @@ -404,13 +401,13 @@ class Processor(FlowProcessor): doc_meta = await self.librarian.fetch_document_metadata( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) mime_type = doc_meta.kind if doc_meta else None content = await self.librarian.fetch_document_content( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) if isinstance(content, str): From 8be128aa59561f28ecfa9bd2df23e48afe261bad Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Wed, 22 Apr 2026 12:05:24 +0100 Subject: [PATCH 08/21] fix: api-gateway evicts cached dispatchers when a flow stops (#841) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DispatcherManager caches one ServiceRequestor per (flow_id, kind) in self.dispatchers, lazily created on first use. stop_flow dropped the flow from self.flows but never touched the cached dispatchers, so their publisher/subscriber connections persisted — bound to the per-flow exchanges that flow-svc tears down when the flow stops. If the same flow id was later re-created, flow-svc re-declared fresh per-flow exchanges, but the gateway's cached dispatcher still held a subscription queue bound to the now-gone old response exchange. Requests went out fine (publishers target exchanges by name and the new exchange has the right name), but responses landed on an exchange with no binding to the dispatcher's queue and were silently dropped. The calling CLI or websocket session hung waiting for a reply that would never arrive. Reproduction before fix: tg-start-flow -i test-flow-1 ... # any query on test-flow-1 works tg-stop-flow -i test-flow-1 tg-start-flow -i test-flow-1 ... tg-show-graph -f test-flow-1 -C # hangs Flows that were never stopped (e.g. "default" in a typical session) were unaffected — their cached dispatcher still pointed at live plumbing. That's why the bug appeared flow-name-specific at first glance; it's actually lifecycle-specific. Fix: in stop_flow, evict and cleanly stop() every cached dispatcher keyed on the stopped flow id. Next request after restart constructs a fresh dispatcher against the freshly-declared exchanges. Tuple shape check preserves global dispatchers, which use (None, kind) as their key and must survive flow churn. Uses pop(id, None) instead of del in case stop_flow is invoked defensively for a flow the gateway never saw. --- .../trustgraph/gateway/dispatch/manager.py | 30 ++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index f3db3290..b238bb5b 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -129,7 +129,35 @@ class DispatcherManager: async def stop_flow(self, workspace, id, flow): logger.info(f"Stopping flow {workspace}/{id}") - del self.flows[(workspace, id)] + self.flows.pop((workspace, id), None) + + # Drop any cached dispatchers for this (workspace, flow). + # Their publishers and subscribers were wired to the flow's + # per-flow exchanges, which flow-svc tears down when the flow + # stops. Leaving the cached dispatcher in place means a + # subsequent restart of the same flow id would reuse a + # dispatcher whose subscription queue is still bound to the + # torn-down (now re-created) response exchange — requests go + # out but responses are silently dropped and the caller hangs. + # + # Per-flow dispatchers are keyed (workspace, flow_id, kind). + # Global dispatchers are keyed (None, kind) — the len==3 + # check naturally excludes them. + async with self.dispatcher_lock: + stale_keys = [ + k for k in self.dispatchers + if isinstance(k, tuple) and len(k) == 3 + and k[0] == workspace and k[1] == id + ] + for key in stale_keys: + dispatcher = self.dispatchers.pop(key) + try: + await dispatcher.stop() + except Exception as e: + logger.warning( + f"Error stopping cached dispatcher {key}: {e}" + ) + return def dispatch_global_service(self): From 6cbaf88fc6415476acc02892d3dc00643cee7901 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Wed, 22 Apr 2026 12:05:47 +0100 Subject: [PATCH 09/21] fix: ontology extractor reads .objects, not .object, from PromptResult (#842) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The extract-with-ontologies prompt is a JSONL prompt, which means the prompt service returns a PromptResult with response_type="jsonl" and the parsed items in `.objects` (plural). The ontology extractor was reading `.object` (singular) — the field used for response_type="json" — which is always None for JSONL prompts. Effect: the parser received None on every chunk, hit its "Unexpected response type: " branch, returned no ExtractionResult, and extract_with_simplified_format returned []. Every extraction silently produced zero triples. Graphs populated only with the seed ontology schema (TBox) and document/chunk provenance — no instance triples at all. The e2e test threshold of >=100 edges per collection was met by schema + provenance alone, so the failure mode was invisible until RAG queries couldn't find any content. Regression introduced in v2.3 with the token-usage work (commit 56d700f3 / 14e49d83) when PromptClient.prompt() began returning a PromptResult wrapper instead of the raw text/dict/list. All other call sites of .prompt() across retrieval/, agent/, orchestrator/ were already reading the correct field for their prompt's response_type; ontology extraction was the sole stranded caller. Also adds tests/unit/test_extract/test_ontology/test_extract_with_simplified_format.py covering: - happy path: populated .objects produces non-empty triples - production failure shape: .objects=None returns [] cleanly - empty .objects returns [] without raising - defensive: do not silently fall back to .object for a JSONL prompt --- .../test_extract_with_simplified_format.py | 200 ++++++++++++++++++ .../trustgraph/extract/kg/ontology/extract.py | 8 +- 2 files changed, 207 insertions(+), 1 deletion(-) create mode 100644 tests/unit/test_extract/test_ontology/test_extract_with_simplified_format.py diff --git a/tests/unit/test_extract/test_ontology/test_extract_with_simplified_format.py b/tests/unit/test_extract/test_ontology/test_extract_with_simplified_format.py new file mode 100644 index 00000000..7130bd73 --- /dev/null +++ b/tests/unit/test_extract/test_ontology/test_extract_with_simplified_format.py @@ -0,0 +1,200 @@ +""" +Unit tests for extract_with_simplified_format. + +Regression guard for the bug where the extractor read +``result.object`` (singular, used for response_type="json") instead of +``result.objects`` (plural, used for response_type="jsonl"). The +extract-with-ontologies prompt is JSONL, so reading the wrong field +silently dropped every extraction and left the knowledge graph +populated only by ontology schema + document provenance. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from trustgraph.extract.kg.ontology.extract import Processor +from trustgraph.extract.kg.ontology.ontology_selector import OntologySubset +from trustgraph.base import PromptResult + + +@pytest.fixture +def extractor(): + """Create a Processor instance without running its heavy __init__. + + Matches the pattern used in test_prompt_and_extraction.py: only + the attributes the code under test touches need to be set. + """ + ex = object.__new__(Processor) + ex.URI_PREFIXES = { + "rdf:": "http://www.w3.org/1999/02/22-rdf-syntax-ns#", + "rdfs:": "http://www.w3.org/2000/01/rdf-schema#", + "owl:": "http://www.w3.org/2002/07/owl#", + "xsd:": "http://www.w3.org/2001/XMLSchema#", + } + return ex + + +@pytest.fixture +def food_subset(): + """A minimal food ontology subset the extracted entities reference.""" + return OntologySubset( + ontology_id="food", + classes={ + "Recipe": { + "uri": "http://purl.org/ontology/fo/Recipe", + "type": "owl:Class", + "labels": [{"value": "Recipe", "lang": "en-gb"}], + "comment": "A Recipe.", + }, + "Food": { + "uri": "http://purl.org/ontology/fo/Food", + "type": "owl:Class", + "labels": [{"value": "Food", "lang": "en-gb"}], + "comment": "A Food.", + }, + }, + object_properties={ + "ingredients": { + "uri": "http://purl.org/ontology/fo/ingredients", + "type": "owl:ObjectProperty", + "labels": [{"value": "ingredients", "lang": "en-gb"}], + "comment": "Relates a recipe to its ingredients.", + "domain": "Recipe", + "range": "Food", + }, + }, + datatype_properties={}, + metadata={ + "name": "Food Ontology", + "namespace": "http://purl.org/ontology/fo/", + }, + ) + + +def _flow_with_prompt_result(prompt_result): + """Build the ``flow(name)`` callable the extractor invokes. + + ``extract_with_simplified_format`` calls + ``flow("prompt-request").prompt(...)`` — so we need ``flow`` to be + callable, return an object whose ``.prompt`` is an AsyncMock that + resolves to ``prompt_result``. + """ + prompt_service = MagicMock() + prompt_service.prompt = AsyncMock(return_value=prompt_result) + + def flow(name): + assert name == "prompt-request", ( + f"extractor should only invoke flow('prompt-request'), " + f"got {name!r}" + ) + return prompt_service + + return flow, prompt_service.prompt + + +class TestReadsObjectsForJsonlPrompt: + """extract-with-ontologies is a JSONL prompt; the extractor must + read ``result.objects``, not ``result.object``.""" + + async def test_populated_objects_produces_triples( + self, extractor, food_subset, + ): + """Happy path: PromptResult with populated .objects -> non-empty + triples list.""" + + prompt_result = PromptResult( + response_type="jsonl", + objects=[ + {"type": "entity", "entity": "Cornish Pasty", + "entity_type": "Recipe"}, + {"type": "entity", "entity": "beef", + "entity_type": "Food"}, + {"type": "relationship", + "subject": "Cornish Pasty", "subject_type": "Recipe", + "relation": "ingredients", + "object": "beef", "object_type": "Food"}, + ], + ) + + flow, prompt_mock = _flow_with_prompt_result(prompt_result) + + triples = await extractor.extract_with_simplified_format( + flow, "some chunk", food_subset, {"text": "some chunk"}, + ) + + prompt_mock.assert_awaited_once() + assert triples, ( + "extract_with_simplified_format returned no triples; if " + "this fails, the extractor is probably reading .object " + "instead of .objects again" + ) + + async def test_none_objects_returns_empty_without_crashing( + self, extractor, food_subset, + ): + """The exact shape that hit production on v2.3: the extractor + was reading ``.object`` for a JSONL prompt, which returned + ``None`` and tripped the parser's 'Unexpected response type' + path. With the fix we read ``.objects``; if that's also + ``None`` we must still return ``[]`` cleanly, not crash.""" + + prompt_result = PromptResult( + response_type="jsonl", + objects=None, + ) + + flow, _ = _flow_with_prompt_result(prompt_result) + + triples = await extractor.extract_with_simplified_format( + flow, "chunk", food_subset, {"text": "chunk"}, + ) + + assert triples == [] + + async def test_empty_objects_returns_empty( + self, extractor, food_subset, + ): + """Valid JSONL response with zero entries should yield zero + triples, not raise.""" + + prompt_result = PromptResult( + response_type="jsonl", + objects=[], + ) + + flow, _ = _flow_with_prompt_result(prompt_result) + + triples = await extractor.extract_with_simplified_format( + flow, "chunk", food_subset, {"text": "chunk"}, + ) + + assert triples == [] + + async def test_ignores_object_field_for_jsonl_prompt( + self, extractor, food_subset, + ): + """If ``.object`` is somehow set but ``.objects`` is None, the + extractor must not silently fall back to ``.object``. This + guards against a well-meaning regression that "helpfully" + re-adds fallback fields. + + The extractor should read only ``.objects`` for this prompt; + when that is None we expect the empty-result path. + """ + + prompt_result = PromptResult( + response_type="json", + object={"not": "the field we should be reading"}, + objects=None, + ) + + flow, _ = _flow_with_prompt_result(prompt_result) + + triples = await extractor.extract_with_simplified_format( + flow, "chunk", food_subset, {"text": "chunk"}, + ) + + assert triples == [], ( + "Extractor fell back to .object for a JSONL prompt — " + "this is the regression shape we are trying to prevent" + ) diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py index a05f4dfe..ef9a7331 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py @@ -412,7 +412,13 @@ class Processor(FlowProcessor): id="extract-with-ontologies", variables=prompt_variables ) - extraction_response = result.object + + # extract-with-ontologies is a JSONL prompt, so PromptResult + # always populates .objects (a list of dicts). Reading .object + # (singular) silently gives None for JSONL responses and drops + # every extraction. + extraction_response = result.objects + logger.debug(f"Simplified extraction response: {extraction_response}") # Parse response into structured format From 7521e152b9d66ed5724cbe919aaa9667f857341e Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Wed, 22 Apr 2026 15:17:17 +0100 Subject: [PATCH 10/21] fix: uuid-ify flow-svc ConfigClient subscription to avoid Pulsar ConsumerBusy on restart (#843) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit flow-svc's long-lived ConfigClient was constructed with subscription=f"{self.id}--config--{id}", where id=params.get("id") is the deterministic processor id. On Pulsar the config-response topic maps to class=response -> Exclusive subscription; when the supervisor restarts flow-svc within Pulsar's inactive-subscription TTL (minutes), the previous process's ghost consumer still holds the subscription and the new process's re-subscribe is rejected with ConsumerBusy, crash-looping flow-svc. This is a v2.2 -> v2.3 regression in practice, but not a change in subscription semantics: the Exclusive mapping for response/notify is identical between releases. The regression is that PR #822 split flow-svc out of config-svc and added this new, long-lived request/response call site — the new site simply didn't follow the uuid convention used by the equivalent sites elsewhere (gateway/config/receiver.py, AsyncProcessor._create_config_client). Fix: generate a fresh uuid per process instance for the subscription suffix, matching that convention. --- trustgraph-flow/trustgraph/flow/service/service.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/trustgraph-flow/trustgraph/flow/service/service.py b/trustgraph-flow/trustgraph/flow/service/service.py index 74077ccb..3dacf47d 100644 --- a/trustgraph-flow/trustgraph/flow/service/service.py +++ b/trustgraph-flow/trustgraph/flow/service/service.py @@ -5,6 +5,7 @@ by coordinating with the config service via pub/sub. """ import logging +import uuid from trustgraph.schema import Error @@ -83,9 +84,17 @@ class Processor(AsyncProcessor): processor=self.id, flow=None, name="config-response", ) + # Unique subscription suffix per process instance. Pulsar's + # exclusive subscriptions reject a second consumer on the same + # (topic, subscription-name) — so a deterministic name here + # collides with its own ghost when the supervisor restarts the + # process before Pulsar has timed out the previous session + # (ConsumerBusy). Matches the uuid convention used elsewhere + # (gateway/config/receiver.py, AsyncProcessor._create_config_client). + config_rr_id = str(uuid.uuid4()) self.config_client = ConfigClient( backend=self.pubsub, - subscription=f"{self.id}--config--{id}", + subscription=f"{self.id}--config--{config_rr_id}", consumer_name=self.id, request_topic=config_request_queue, request_schema=ConfigRequest, From 95c3b62ef12833977d9da8f2990654780083b9ed Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Wed, 22 Apr 2026 16:16:36 +0100 Subject: [PATCH 11/21] Correct output from tg-show-flow-state and tg-show-processor-state (#846) --- .../trustgraph/cli/show_flow_state.py | 20 ++++++++++--------- .../trustgraph/cli/show_processor_state.py | 2 +- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/trustgraph-cli/trustgraph/cli/show_flow_state.py b/trustgraph-cli/trustgraph/cli/show_flow_state.py index 8fec04ec..3a733270 100644 --- a/trustgraph-cli/trustgraph/cli/show_flow_state.py +++ b/trustgraph-cli/trustgraph/cli/show_flow_state.py @@ -44,16 +44,18 @@ def show_processors(metrics_url, flow_label): obj = resp.json() - tbl = [ - [ - m["metric"]["job"], - "\U0001f49a" if int(m["value"][1]) > 0 else "\U0000274c" - ] - for m in obj["data"]["result"] - ] + # consumer_state is one sample per consumer (queue); a processor + # with N subscriptions shows up N times. Aggregate to one row per + # processor: green only if every consumer is running. + by_proc = {} + for m in obj["data"]["result"]: + name = m["metric"].get("processor", m["metric"]["job"]) + running = int(m["value"][1]) > 0 + by_proc[name] = by_proc.get(name, True) and running - for row in tbl: - print(f"- {row[0]:30} {row[1]}") + for name in sorted(by_proc): + icon = "\U0001f49a" if by_proc[name] else "\U0000274c" + print(f"- {name:30} {icon}") def main(): diff --git a/trustgraph-cli/trustgraph/cli/show_processor_state.py b/trustgraph-cli/trustgraph/cli/show_processor_state.py index b4ae4a16..9de05bc6 100644 --- a/trustgraph-cli/trustgraph/cli/show_processor_state.py +++ b/trustgraph-cli/trustgraph/cli/show_processor_state.py @@ -17,7 +17,7 @@ def dump_status(url): tbl = [ [ - m["metric"]["job"], + m["metric"].get("processor", m["metric"]["job"]), "\U0001f49a" ] for m in obj["data"]["result"] From 31027e30ae65eda82cbbc2565a8dccf26697fb9d Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Wed, 22 Apr 2026 16:16:57 +0100 Subject: [PATCH 12/21] Better reporting from api-gateway's metric endpoint (#845) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Connect failures (DNS, connect refused, server disconnect) now return 502 Bad Gateway with a body that names the upstream URL. - Other exceptions still return 500 but now include the exception message in the body and log with exc_info=True so the stack trace lands in the gateway log. - Also fixed the logging.error → logger.error inconsistency in the same block (module had a named logger at the top that wasn't being used). --- .../trustgraph/gateway/endpoint/metrics.py | 33 ++++++++++++------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/metrics.py b/trustgraph-flow/trustgraph/gateway/endpoint/metrics.py index d17d111b..903a199c 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/metrics.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/metrics.py @@ -47,26 +47,37 @@ class MetricsEndpoint: if not self.auth.permitted(token, self.operation): return web.HTTPUnauthorized() + path = request.match_info["path"] + url = ( + self.prometheus_url + "/api/v1/" + path + "?" + + request.query_string + ) + try: - path = request.match_info["path"] - async with aiohttp.ClientSession() as session: - - url = ( - self.prometheus_url + "/api/v1/" + path + "?" + - request.query_string - ) - async with session.get(url) as resp: return web.Response( status=resp.status, text=await resp.text() ) + except aiohttp.ClientConnectionError as e: + + # Upstream unreachable (connect refused, DNS failure, + # server disconnect). Distinguish from our own errors so + # callers know where the fault is. + logger.error(f"Metrics upstream {url} unreachable: {e}") + return web.Response( + status=502, + text=f"Bad Gateway: metrics upstream unreachable: {e}", + ) + except Exception as e: - logging.error(f"Exception: {e}") - - raise web.HTTPInternalServerError() + logger.error(f"Metrics proxy exception: {e}", exc_info=True) + return web.Response( + status=500, + text=f"Internal Server Error: {e}", + ) From ae9936c9ccffa96596ba6b4d03dfb7b8ad9044c7 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Wed, 22 Apr 2026 18:03:46 +0100 Subject: [PATCH 13/21] feat: pluggable bootstrap framework with ordered initialisers (#847) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A generic, long-running bootstrap processor that converges a deployment to its configured initial state and then idles. Replaces the previous one-shot `tg-init-trustgraph` container model and provides an extension point for enterprise / third-party initialisers. See docs/tech-specs/bootstrap.md for the full design. Bootstrapper ------------ A single AsyncProcessor (trustgraph.bootstrap.bootstrapper.Processor) that: * Reads a list of initialiser specifications (class, name, flag, params) from either a direct `initialisers` parameter (processor-group embedding) or a YAML/JSON file (`-c`, CLI). * On each wake, runs a cheap service-gate (config-svc + flow-svc round-trips), then iterates the initialiser list, running each whose configured flag differs from the one stored in __system__/init-state/. * Stores per-initialiser completion state in the reserved __system__ workspace. * Adapts cadence: ~5s on gate failure, ~15s while converging, ~300s in steady state. * Isolates failures — one initialiser's exception does not block others in the same cycle; the failed one retries next wake. Initialiser contract -------------------- * Subclass trustgraph.bootstrap.base.Initialiser. * Implement async run(ctx, old_flag, new_flag). * Opt out of the service gate with class attr wait_for_services=False (only used by PulsarTopology, since config-svc cannot come up until Pulsar namespaces exist). * ctx carries short-lived config and flow-svc clients plus a scoped logger. Core initialisers (trustgraph.bootstrap.initialisers.*) ------------------------------------------------------- * PulsarTopology — creates Pulsar tenant + namespaces (pre-gate, blocking HTTP offloaded to executor). * TemplateSeed — seeds __template__ from an external JSON file; re-run is upsert-missing by default, overwrite-all opt-in. * WorkspaceInit — populates a named workspace from either the full contents of __template__ or a seed file; raises cleanly if the template isn't seeded yet so the bootstrapper retries on the next cycle. * DefaultFlowStart — starts a specific flow in a workspace; no-ops if the flow is already running. Enterprise or third-party initialisers plug in via fully-qualified dotted class paths in the bootstrapper's configuration — no core code change required. Config service -------------- * push(): filter out reserved workspaces (ids starting with "_") from the change notifications. Stored config is preserved; only the broadcast is suppressed, so bootstrap / template state lives in config-svc without live processors ever reacting to it. Config client ------------- * ConfigClient.get_all(workspace): wraps the existing `config` operation to return {type: {key: value}} for a workspace. WorkspaceInit uses it to copy __template__ without needing a hardcoded types list. pyproject.toml -------------- * Adds a `bootstrap` console script pointing at the new Processor. * Remove tg-init-trustgraph, superceded by bootstrap processor --- docs/tech-specs/bootstrap.md | 297 +++++++++++++ docs/tech-specs/iam.md | 1 - .../trustgraph/base/config_client.py | 12 + trustgraph-cli/pyproject.toml | 1 - .../trustgraph/cli/init_trustgraph.py | 271 ------------ trustgraph-flow/pyproject.toml | 1 + .../trustgraph/bootstrap/__init__.py | 0 trustgraph-flow/trustgraph/bootstrap/base.py | 68 +++ .../bootstrap/bootstrapper/__init__.py | 1 + .../bootstrap/bootstrapper/__main__.py | 6 + .../bootstrap/bootstrapper/service.py | 414 ++++++++++++++++++ .../bootstrap/initialisers/__init__.py | 20 + .../initialisers/default_flow_start.py | 101 +++++ .../bootstrap/initialisers/pulsar_topology.py | 131 ++++++ .../bootstrap/initialisers/template_seed.py | 93 ++++ .../bootstrap/initialisers/workspace_init.py | 138 ++++++ .../trustgraph/config/service/service.py | 30 ++ 17 files changed, 1312 insertions(+), 273 deletions(-) create mode 100644 docs/tech-specs/bootstrap.md delete mode 100644 trustgraph-cli/trustgraph/cli/init_trustgraph.py create mode 100644 trustgraph-flow/trustgraph/bootstrap/__init__.py create mode 100644 trustgraph-flow/trustgraph/bootstrap/base.py create mode 100644 trustgraph-flow/trustgraph/bootstrap/bootstrapper/__init__.py create mode 100644 trustgraph-flow/trustgraph/bootstrap/bootstrapper/__main__.py create mode 100644 trustgraph-flow/trustgraph/bootstrap/bootstrapper/service.py create mode 100644 trustgraph-flow/trustgraph/bootstrap/initialisers/__init__.py create mode 100644 trustgraph-flow/trustgraph/bootstrap/initialisers/default_flow_start.py create mode 100644 trustgraph-flow/trustgraph/bootstrap/initialisers/pulsar_topology.py create mode 100644 trustgraph-flow/trustgraph/bootstrap/initialisers/template_seed.py create mode 100644 trustgraph-flow/trustgraph/bootstrap/initialisers/workspace_init.py diff --git a/docs/tech-specs/bootstrap.md b/docs/tech-specs/bootstrap.md new file mode 100644 index 00000000..af7387d1 --- /dev/null +++ b/docs/tech-specs/bootstrap.md @@ -0,0 +1,297 @@ +--- +layout: default +title: "Bootstrap Framework Technical Specification" +parent: "Tech Specs" +--- + +# Bootstrap Framework Technical Specification + +## Overview + +A generic, pluggable framework for running one-time initialisation steps +against a TrustGraph deployment — replacing the dedicated +`tg-init-trustgraph` container with a long-running processor that +converges the system to a desired initial state and then idles. + +The framework is content-agnostic. It knows how to run, retry, +mark-as-done, and surface failures; the actual init work lives in +small pluggable classes called **initialisers**. Core initialisers +ship in the `trustgraph-flow` package; enterprise and third-party +initialisers can be loaded by dotted path without any core code +change. + +## Motivation + +The existing `tg-init-trustgraph` is a one-shot CLI run in its own +container. It performs two very different jobs (Pulsar topology +setup and config seeding) in a single script, is wasteful as a whole +container, cannot handle partial-success states, and has no way to +extend the boot process with enterprise-specific concerns (user +provisioning, workspace initialisation, IAM scaffolding) without +forking the tool. + +A pluggable, long-running reconciler addresses all of this and slots +naturally into the existing processor-group model. + +## Design + +### Bootstrapper Processor + +A single `AsyncProcessor` subclass. One entry in a processor group. +Parameters include the processor's own identity and a list of +**initialiser specifications** — each spec names a class (by dotted +path), a unique instance name, a flag string, and the parameters +that will be passed to the initialiser's constructor. + +On each wake the bootstrapper does the following, in order: + +1. Open a short-lived context (config client, flow-svc client, + logger). The context is torn down at the end of the wake so + steady-state idle cost is effectively nil. +2. Run all **pre-service initialisers** (those that opt out of the + service gate — principally `PulsarTopology`, which must run + before the services it gates on can even come up). +3. Check the **service gate**: cheap round-trips to config-svc and + flow-svc. If either fails, skip to the sleep step using the + short gate-retry cadence. +4. Run all **post-service initialisers** that haven't already + completed at the currently-configured flag. +5. Sleep. Cadence adapts to state (see below). + +### Initialiser Contract + +An initialiser is a class with: + +- A class-level `name` identifier, unique within the bootstrapper's + configuration. This is the key under which completion state is + stored. +- A class-level `wait_for_services` flag. When `True` (the default) + the initialiser runs only after the service gate passes. When + `False`, it runs before the gate, on every wake. +- A constructor that accepts the initialiser's own params as kwargs. +- An async `run(ctx, old_flag, new_flag)` method that performs the + init work and returns on success. Any raised exception is + logged and treated as a transient failure — the stored flag is + not updated and the initialiser will re-run on the next cycle. + +`old_flag` is the previously-stored flag string, or `None` if the +initialiser has never successfully run in this deployment. `new_flag` +is the flag the operator has configured for this run. This pair +lets an initialiser distinguish a clean first-run from a migration +between flag versions and behave accordingly (see "Flag change and +re-run safety" below). + +### Context + +The context is the bootstrapper-owned object passed to every +initialiser's `run()` method. Its fields are deliberately narrow: + +| Field | Purpose | +|---|---| +| `logger` | A child logger named for the initialiser instance | +| `config` | A short-lived `ConfigClient` for config-svc reads/writes | +| `flow` | A short-lived `RequestResponse` client for flow-svc | + +The context is always fully-populated regardless of which services +a given initialiser uses, for symmetry. Additional fields may be +added in future without breaking existing initialisers. Clients are +started at the beginning of a wake cycle and stopped at the end. + +Initialisers that need services beyond config-svc and flow-svc are +responsible for their own readiness checks and for raising cleanly +when a prerequisite is not met. + +### Completion State + +Per-initialiser completion state is stored in the reserved +`__system__` workspace, under a dedicated config type for bootstrap +state. The stored value is the flag string that was configured when +the initialiser last succeeded. + +On each cycle, for each initialiser, the bootstrapper reads the +stored flag and compares it to the currently-configured flag. If +they match, the initialiser is skipped silently. If they differ, +the initialiser runs; on success, the stored flag is updated. + +Because the state lives in a reserved (`_`-prefixed) workspace, it +is stored by config-svc but excluded from the config push broadcast. +Live processors never see it and cannot act on it. + +### The Service Gate + +The gate is a cheap, bootstrapper-internal check that config-svc +and flow-svc are both reachable and responsive. It is intentionally +a simple pair of low-cost round-trips — a config list against +`__system__` and a flow-svc `list-blueprints` — rather than any +deeper health check. + +Its purpose is to avoid filling logs with noise and to concentrate +retry effort during the brief window when services are coming up. +The gate is applied only to initialisers with +`wait_for_services=True` (the default); `False` is reserved for +initialisers that set up infrastructure the gate itself depends on. + +### Adaptive Cadence + +The sleep between wake cycles is chosen from three tiers based on +observed state: + +| Tier | Duration | When | +|---|---|---| +| Gate backoff | ~5 s | Services not responding — concentrate retry during startup | +| Init retry | ~15 s | Gate passes but at least one initialiser is not yet at its configured flag — transient failures, waiting on prereqs, recently-bumped flag not yet applied | +| Steady | ~300 s | All configured initialisers at their configured flag; gate passes; nothing to do | + +The short tiers ensure a fresh deployment converges quickly; +steady state costs a single round-trip per initialiser every few +minutes. + +### Failure Handling + +An initialiser raising an exception does not stop the bootstrapper +or block other initialisers. Each initialiser in the cycle is +attempted independently; failures are logged and retried on the next +cycle. This means there is no ordered-DAG enforcement: order of +initialisers in the configuration determines the attempt order +within a cycle, but a dependency between two initialisers is +expressed by the dependant raising cleanly when its prerequisite +isn't satisfied. Over successive cycles the system converges. + +### Flag Change and Re-run Safety + +Each initialiser's completion state is a string flag chosen by the +operator. Typically these follow a simple version pattern +(`v1`, `v2`, ...), but the bootstrapper imposes no format. + +Changing the flag in the group configuration causes the +corresponding initialiser to re-run on the next cycle. Initialisers +must be written so that re-running after a flag bump is safe — they +receive both the previous and the new flag and are responsible for +either cleanly re-applying the work or performing a step-change +migration from the prior state. + +This gives operators an explicit, visible mechanism for triggering +re-initialisation. Re-runs are never implicit. + +## Core Initialisers + +The following initialisers ship in `trustgraph.bootstrap.initialisers` +and cover the base deployment case. + +### PulsarTopology + +Creates the Pulsar tenant and the four namespaces +(`flow`, `request`, `response`, `notify`) with appropriate +retention policies if they don't exist. + +Opts out of the service gate (`wait_for_services = False`) because +config-svc and flow-svc cannot come online until the Pulsar +namespaces exist. + +Parameters: Pulsar admin URL, tenant name. + +Idempotent via the admin API (GET-then-PUT). Flag change causes +re-evaluation of all namespaces; any absent are created. + +### TemplateSeed + +Populates the reserved `__template__` workspace from an external +JSON seed file. The seed file has the standard shape of +`{config-type: {config-key: value}}`. + +Runs post-gate. Parameters: path to the seed file, overwrite +policy (upsert-missing only, or overwrite-all). + +On clean run, writes the whole file. On flag change, behaviour +depends on the overwrite policy — typically upsert-missing so +that operator-customised keys are preserved across seed-file +upgrades. + +### WorkspaceInit + +Creates a named workspace and populates it from the seed file or +from the full contents of the `__template__` workspace. + +Runs post-gate. Parameters: workspace name, source (seed file or +`__template__`), optional `seed_file` path, `overwrite` flag. + +When `source` is `template`, the initialiser copies every config +type and key present in `__template__` — there is no per-type +selection. Deployments that want to seed only a subset should +either curate the seed file they feed to `TemplateSeed` or use +`source: seed-file` directly here. + +Raises cleanly if its source does not exist — depends on +`TemplateSeed` having run in the same cycle or a prior one. + +### DefaultFlowStart + +Starts a specific flow in a specific workspace using a specific +blueprint. + +Runs post-gate. Parameters: workspace name, flow id, blueprint +name, description, optional parameter overrides. + +Separated from `WorkspaceInit` deliberately so that deployments +which want a workspace without an auto-started flow can simply omit +this initialiser from their bootstrap configuration. + +## Extensibility + +New initialisers are added by: + +1. Subclassing the initialiser base class. +2. Implementing `run(ctx, old_flag, new_flag)`. +3. Choosing `wait_for_services` (almost always `True`). +4. Adding an entry in the bootstrapper's configuration with the new + class's dotted path. + +No core code changes are required to add an enterprise or third-party +initialiser. Enterprise builds ship their own package with their own +initialiser classes (e.g. `CreateAdminUser`, `ProvisionWorkspaces`) +and reference them in the bootstrapper config alongside the core +initialisers. + +## Reserved Workspaces + +This specification relies on the "reserved workspace" convention: + +- Any workspace id beginning with `_` is reserved. +- Reserved workspaces are stored normally by config-svc but never + appear in the config push broadcast. +- Live processors cannot react to reserved-workspace state. + +The bootstrapper uses two reserved workspaces: + +- `__template__` — factory-default seed config, readable by + initialisers that copy-from-template. +- `__system__` — bootstrapper completion state (under the + `init-state` config type) and any other system-internal bookkeeping. + +See the reserved-workspace convention in the config service for +the general rule and its enforcement. + +## Non-Goals + +- No DAG scheduling across initialisers. Dependencies are expressed + by the dependant failing cleanly until its prerequisite is met, + and convergence over subsequent cycles. +- No parallel execution of initialisers within a cycle. A cycle runs + each initialiser sequentially. +- No implicit re-runs. Re-running an initialiser requires an explicit + flag change by the operator. +- No cross-initialiser atomicity. Each initialiser's completion is + recorded independently on its own success. + +## Operational Notes + +- Running the bootstrapper as a processor-group entry replaces the + previous `tg-init-trustgraph` container. The bootstrapper is also + CLI-invocable directly for standalone testing via + `Processor.launch(...)`. +- First-boot convergence is typically a handful of short cycles + followed by a transition to the steady cadence. Deployments + should expect the first few minutes of logs to show + initialisation activity, thereafter effective silence. +- Bumping a flag is a deliberate operational act. The log line + emitted on re-run makes the event visible for audit. diff --git a/docs/tech-specs/iam.md b/docs/tech-specs/iam.md index 5de50749..cb1399fe 100644 --- a/docs/tech-specs/iam.md +++ b/docs/tech-specs/iam.md @@ -848,7 +848,6 @@ service, not in the config service. Reasons: - **API key scoping.** API keys could be scoped to specific collections within a workspace rather than granting workspace-wide access. To be designed when the need arises. -- **tg-init-trustgraph** only initialises a single workspace. ## References diff --git a/trustgraph-base/trustgraph/base/config_client.py b/trustgraph-base/trustgraph/base/config_client.py index 504a6d58..eb3892f8 100644 --- a/trustgraph-base/trustgraph/base/config_client.py +++ b/trustgraph-base/trustgraph/base/config_client.py @@ -84,6 +84,18 @@ class ConfigClient(RequestResponse): ) return resp.directory + async def get_all(self, workspace, timeout=CONFIG_TIMEOUT): + """Return every config entry in ``workspace`` as a nested dict + ``{type: {key: value}}``. Values are returned as the raw + strings stored by config-svc (typically JSON); callers parse + as needed. An empty dict means the workspace has no config.""" + resp = await self._request( + operation="config", + workspace=workspace, + timeout=timeout, + ) + return resp.config + async def workspaces_for_type(self, type, timeout=CONFIG_TIMEOUT): """Return the set of distinct workspaces with any config of the given type.""" diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index a5738449..d316ae4f 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -40,7 +40,6 @@ tg-get-flow-blueprint = "trustgraph.cli.get_flow_blueprint:main" tg-get-kg-core = "trustgraph.cli.get_kg_core:main" tg-get-document-content = "trustgraph.cli.get_document_content:main" tg-graph-to-turtle = "trustgraph.cli.graph_to_turtle:main" -tg-init-trustgraph = "trustgraph.cli.init_trustgraph:main" tg-invoke-agent = "trustgraph.cli.invoke_agent:main" tg-invoke-document-rag = "trustgraph.cli.invoke_document_rag:main" tg-invoke-graph-rag = "trustgraph.cli.invoke_graph_rag:main" diff --git a/trustgraph-cli/trustgraph/cli/init_trustgraph.py b/trustgraph-cli/trustgraph/cli/init_trustgraph.py deleted file mode 100644 index d984f925..00000000 --- a/trustgraph-cli/trustgraph/cli/init_trustgraph.py +++ /dev/null @@ -1,271 +0,0 @@ -""" -Initialises TrustGraph pub/sub infrastructure and pushes initial config. - -For Pulsar: creates tenant, namespaces, and retention policies. -For RabbitMQ: queues are auto-declared, so only config push is needed. -""" - -import requests -import time -import argparse -import json - -from trustgraph.clients.config_client import ConfigClient -from trustgraph.base.pubsub import add_pubsub_args - -default_pulsar_admin_url = "http://pulsar:8080" -subscriber = "tg-init-pubsub" - - -def get_clusters(url): - - print("Get clusters...", flush=True) - - resp = requests.get(f"{url}/admin/v2/clusters") - - if resp.status_code != 200: raise RuntimeError("Could not fetch clusters") - - return resp.json() - -def ensure_tenant(url, tenant, clusters): - - resp = requests.get(f"{url}/admin/v2/tenants/{tenant}") - - if resp.status_code == 200: - print(f"Tenant {tenant} already exists.", flush=True) - return - - resp = requests.put( - f"{url}/admin/v2/tenants/{tenant}", - json={ - "adminRoles": [], - "allowedClusters": clusters, - } - ) - - if resp.status_code != 204: - print(resp.text, flush=True) - raise RuntimeError("Tenant creation failed.") - - print(f"Tenant {tenant} created.", flush=True) - -def ensure_namespace(url, tenant, namespace, config): - - resp = requests.get(f"{url}/admin/v2/namespaces/{tenant}/{namespace}") - - if resp.status_code == 200: - print(f"Namespace {tenant}/{namespace} already exists.", flush=True) - return - - resp = requests.put( - f"{url}/admin/v2/namespaces/{tenant}/{namespace}", - json=config, - ) - - if resp.status_code != 204: - print(resp.status_code, flush=True) - print(resp.text, flush=True) - raise RuntimeError(f"Namespace {tenant}/{namespace} creation failed.") - - print(f"Namespace {tenant}/{namespace} created.", flush=True) - -def ensure_config(config, workspace="default", **pubsub_config): - - cli = ConfigClient( - subscriber=subscriber, - workspace=workspace, - **pubsub_config, - ) - - while True: - - try: - - print("Get current config...", flush=True) - current, version = cli.config(timeout=5) - - except Exception as e: - - print("Exception:", e, flush=True) - time.sleep(2) - print("Retrying...", flush=True) - continue - - print("Current config version is", version, flush=True) - - if version != 0: - print("Already updated, not updating config. Done.", flush=True) - return - - print("Config is version 0, updating...", flush=True) - - batch = [] - - for type in config: - for key in config[type]: - print(f"Adding {type}/{key} to update.", flush=True) - batch.append({ - "type": type, - "key": key, - "value": json.dumps(config[type][key]), - }) - - try: - cli.put(batch, timeout=10) - print("Update succeeded.", flush=True) - break - except Exception as e: - print("Exception:", e, flush=True) - time.sleep(2) - print("Retrying...", flush=True) - continue - -def init_pulsar(pulsar_admin_url, tenant): - """Pulsar-specific setup: create tenant, namespaces, retention policies.""" - - clusters = get_clusters(pulsar_admin_url) - - ensure_tenant(pulsar_admin_url, tenant, clusters) - - ensure_namespace(pulsar_admin_url, tenant, "flow", {}) - - ensure_namespace(pulsar_admin_url, tenant, "request", {}) - - ensure_namespace(pulsar_admin_url, tenant, "response", { - "retention_policies": { - "retentionSizeInMB": -1, - "retentionTimeInMinutes": 3, - "subscriptionExpirationTimeMinutes": 30, - } - }) - - ensure_namespace(pulsar_admin_url, tenant, "notify", { - "retention_policies": { - "retentionSizeInMB": -1, - "retentionTimeInMinutes": 3, - "subscriptionExpirationTimeMinutes": 5, - } - }) - - -def push_config(config_json, config_file, workspace="default", - **pubsub_config): - """Push initial config if provided.""" - - if config_json is not None: - - try: - print("Decoding config...", flush=True) - dec = json.loads(config_json) - print("Decoded.", flush=True) - except Exception as e: - print("Exception:", e, flush=True) - raise e - - ensure_config(dec, workspace=workspace, **pubsub_config) - - elif config_file is not None: - - try: - print("Decoding config...", flush=True) - dec = json.load(open(config_file)) - print("Decoded.", flush=True) - except Exception as e: - print("Exception:", e, flush=True) - raise e - - ensure_config(dec, workspace=workspace, **pubsub_config) - - else: - print("No config to update.", flush=True) - - -def main(): - - parser = argparse.ArgumentParser( - prog='tg-init-trustgraph', - description=__doc__, - ) - - parser.add_argument( - '--pulsar-admin-url', - default=default_pulsar_admin_url, - help=f'Pulsar admin URL (default: {default_pulsar_admin_url})', - ) - - parser.add_argument( - '-c', '--config', - help=f'Initial configuration to load', - ) - - parser.add_argument( - '-C', '--config-file', - help=f'Initial configuration to load from file', - ) - - parser.add_argument( - '-t', '--tenant', - default="tg", - help=f'Tenant (default: tg)', - ) - - parser.add_argument( - '-w', '--workspace', - default="default", - help=f'Workspace (default: default)', - ) - - add_pubsub_args(parser) - - args = parser.parse_args() - - backend_type = args.pubsub_backend - - # Extract pubsub config from args - pubsub_config = { - k: v for k, v in vars(args).items() - if k not in ( - 'pulsar_admin_url', 'config', 'config_file', 'tenant', - 'workspace', - ) - } - - while True: - - try: - - # Pulsar-specific setup (tenants, namespaces) - if backend_type == 'pulsar': - print(flush=True) - print( - f"Initialising Pulsar at {args.pulsar_admin_url}...", - flush=True, - ) - init_pulsar(args.pulsar_admin_url, args.tenant) - else: - print(flush=True) - print( - f"Using {backend_type} backend (no admin setup needed).", - flush=True, - ) - - # Push config (works with any backend) - push_config( - args.config, args.config_file, - workspace=args.workspace, - **pubsub_config, - ) - - print("Initialisation complete.", flush=True) - break - - except Exception as e: - - print("Exception:", e, flush=True) - - print("Sleeping...", flush=True) - time.sleep(2) - print("Will retry...", flush=True) - -if __name__ == "__main__": - main() diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index 8ba85adf..cc7dac63 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -60,6 +60,7 @@ agent-orchestrator = "trustgraph.agent.orchestrator:run" api-gateway = "trustgraph.gateway:run" chunker-recursive = "trustgraph.chunking.recursive:run" chunker-token = "trustgraph.chunking.token:run" +bootstrap = "trustgraph.bootstrap.bootstrapper:run" config-svc = "trustgraph.config.service:run" flow-svc = "trustgraph.flow.service:run" doc-embeddings-query-milvus = "trustgraph.query.doc_embeddings.milvus:run" diff --git a/trustgraph-flow/trustgraph/bootstrap/__init__.py b/trustgraph-flow/trustgraph/bootstrap/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/trustgraph-flow/trustgraph/bootstrap/base.py b/trustgraph-flow/trustgraph/bootstrap/base.py new file mode 100644 index 00000000..cb022a16 --- /dev/null +++ b/trustgraph-flow/trustgraph/bootstrap/base.py @@ -0,0 +1,68 @@ +""" +Bootstrap framework: Initialiser base class and per-wake context. + +See docs/tech-specs/bootstrap.md for the full design. +""" + +import logging +from dataclasses import dataclass +from typing import Any + + +@dataclass +class InitContext: + """Shared per-wake context passed to each initialiser. + + The bootstrapper constructs one of these on every wake cycle, + tears it down at cycle end, and passes it into each initialiser's + ``run()`` method. Fields are short-lived and safe to use during + a single cycle only. + """ + + logger: logging.Logger + config: Any # ConfigClient + flow: Any # RequestResponse client for flow-svc + + +class Initialiser: + """Base class for bootstrap initialisers. + + Subclasses implement :meth:`run`. The bootstrapper manages + completion state, flag comparison, retry and error handling — + subclasses describe only the work to perform. + + Class attributes: + + * ``wait_for_services`` (bool, default ``True``): when ``True`` the + initialiser only runs after the bootstrapper's service gate has + passed (config-svc and flow-svc reachable). Set ``False`` for + initialisers that bring up infrastructure the gate itself + depends on — principally Pulsar topology, without which + config-svc cannot come online. + """ + + wait_for_services: bool = True + + def __init__(self, **params): + # Subclasses should consume their own params via keyword + # arguments in their own __init__ signatures. This catch-all + # is here so any kwargs that filter through unnoticed don't + # raise TypeError on construction. + pass + + async def run(self, ctx, old_flag, new_flag): + """Perform initialisation work. + + :param ctx: :class:`InitContext` with logger, config client, + flow-svc client. + :param old_flag: Previously-stored flag string, or ``None`` if + this initialiser has never successfully completed in this + deployment. + :param new_flag: Currently-configured flag. A string chosen + by the operator; typically something like ``"v1"``. + + :raises: Any exception on failure. The bootstrapper catches, + logs, and re-runs on the next cycle; completion state is + only written on clean return. + """ + raise NotImplementedError diff --git a/trustgraph-flow/trustgraph/bootstrap/bootstrapper/__init__.py b/trustgraph-flow/trustgraph/bootstrap/bootstrapper/__init__.py new file mode 100644 index 00000000..98f4d9da --- /dev/null +++ b/trustgraph-flow/trustgraph/bootstrap/bootstrapper/__init__.py @@ -0,0 +1 @@ +from . service import * diff --git a/trustgraph-flow/trustgraph/bootstrap/bootstrapper/__main__.py b/trustgraph-flow/trustgraph/bootstrap/bootstrapper/__main__.py new file mode 100644 index 00000000..da5a9021 --- /dev/null +++ b/trustgraph-flow/trustgraph/bootstrap/bootstrapper/__main__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from . service import run + +if __name__ == '__main__': + run() diff --git a/trustgraph-flow/trustgraph/bootstrap/bootstrapper/service.py b/trustgraph-flow/trustgraph/bootstrap/bootstrapper/service.py new file mode 100644 index 00000000..eb6238d3 --- /dev/null +++ b/trustgraph-flow/trustgraph/bootstrap/bootstrapper/service.py @@ -0,0 +1,414 @@ +""" +Bootstrapper processor. + +Runs a pluggable list of initialisers in a reconciliation loop. +Each initialiser's completion state is recorded in the reserved +``__system__`` workspace under the ``init-state`` config type. + +See docs/tech-specs/bootstrap.md for the full design. +""" + +import asyncio +import importlib +import json +import logging +import uuid +from argparse import ArgumentParser +from dataclasses import dataclass + +from trustgraph.base import AsyncProcessor +from trustgraph.base import ProducerMetrics, SubscriberMetrics +from trustgraph.base.config_client import ConfigClient +from trustgraph.base.request_response_spec import RequestResponse +from trustgraph.schema import ( + ConfigRequest, ConfigResponse, + config_request_queue, config_response_queue, +) +from trustgraph.schema import ( + FlowRequest, FlowResponse, + flow_request_queue, flow_response_queue, +) + +from .. base import Initialiser, InitContext + +logger = logging.getLogger(__name__) + +default_ident = "bootstrap" + +# Reserved workspace + config type under which completion state is +# stored. Reserved (`_`-prefix) workspaces are excluded from the +# config push broadcast — live processors never see these keys. +SYSTEM_WORKSPACE = "__system__" +INIT_STATE_TYPE = "init-state" + +# Cadence tiers. +GATE_BACKOFF = 5 # Services not responding; retry soon. +INIT_RETRY = 15 # Gate passed but something ran/failed; + # converge quickly. +STEADY_INTERVAL = 300 # Everything at target flag; idle cheaply. + + +@dataclass +class InitialiserSpec: + """One entry in the bootstrapper's configured list of initialisers.""" + name: str + flag: str + instance: Initialiser + + +def _resolve_class(dotted): + """Import and return a class by its dotted path.""" + module_path, _, class_name = dotted.rpartition(".") + if not module_path: + raise ValueError( + f"Initialiser class must be a dotted path, got {dotted!r}" + ) + module = importlib.import_module(module_path) + return getattr(module, class_name) + + +def _load_initialisers_file(path): + """Load the initialisers spec list from a YAML or JSON file. + + File shape: + + .. code-block:: yaml + + initialisers: + - class: trustgraph.bootstrap.initialisers.PulsarTopology + name: pulsar-topology + flag: v1 + params: + admin_url: http://pulsar:8080 + tenant: tg + - ... + """ + with open(path) as f: + content = f.read() + if path.endswith((".yaml", ".yml")): + import yaml + doc = yaml.safe_load(content) + else: + doc = json.loads(content) + if not isinstance(doc, dict) or "initialisers" not in doc: + raise RuntimeError( + f"{path}: expected a mapping with an 'initialisers' key" + ) + return doc["initialisers"] + + +class Processor(AsyncProcessor): + + def __init__(self, **params): + + super().__init__(**params) + + # Source the initialisers list either from a direct parameter + # (processor-group embedding) or from a file (CLI launch). + inits = params.get("initialisers") + if inits is None: + inits_file = params.get("initialisers_file") + if inits_file is None: + raise RuntimeError( + "Bootstrapper requires either the 'initialisers' " + "parameter or --initialisers-file" + ) + inits = _load_initialisers_file(inits_file) + + self.specs = [] + names = set() + + for entry in inits: + if not isinstance(entry, dict): + raise RuntimeError( + f"Initialiser entry must be a mapping, got: {entry!r}" + ) + for required in ("class", "name", "flag"): + if required not in entry: + raise RuntimeError( + f"Initialiser entry missing required field " + f"{required!r}: {entry!r}" + ) + + name = entry["name"] + if name in names: + raise RuntimeError(f"Duplicate initialiser name {name!r}") + names.add(name) + + cls = _resolve_class(entry["class"]) + + try: + instance = cls(**entry.get("params", {})) + except Exception as e: + raise RuntimeError( + f"Failed to instantiate initialiser " + f"{entry['class']!r} as {name!r}: " + f"{type(e).__name__}: {e}" + ) + + self.specs.append(InitialiserSpec( + name=name, + flag=entry["flag"], + instance=instance, + )) + + logger.info( + f"Bootstrapper: loaded {len(self.specs)} initialisers" + ) + + # ------------------------------------------------------------------ + # Client construction (short-lived per wake cycle). + # ------------------------------------------------------------------ + + def _make_config_client(self): + rr_id = str(uuid.uuid4()) + return ConfigClient( + backend=self.pubsub_backend, + subscription=f"{self.id}--config--{rr_id}", + consumer_name=self.id, + request_topic=config_request_queue, + request_schema=ConfigRequest, + request_metrics=ProducerMetrics( + processor=self.id, flow=None, name="config-request", + ), + response_topic=config_response_queue, + response_schema=ConfigResponse, + response_metrics=SubscriberMetrics( + processor=self.id, flow=None, name="config-response", + ), + ) + + def _make_flow_client(self): + rr_id = str(uuid.uuid4()) + return RequestResponse( + backend=self.pubsub_backend, + subscription=f"{self.id}--flow--{rr_id}", + consumer_name=self.id, + request_topic=flow_request_queue, + request_schema=FlowRequest, + request_metrics=ProducerMetrics( + processor=self.id, flow=None, name="flow-request", + ), + response_topic=flow_response_queue, + response_schema=FlowResponse, + response_metrics=SubscriberMetrics( + processor=self.id, flow=None, name="flow-response", + ), + ) + + async def _open_clients(self): + config = self._make_config_client() + flow = self._make_flow_client() + await config.start() + try: + await flow.start() + except Exception: + await self._safe_stop(config) + raise + return config, flow + + async def _safe_stop(self, client): + try: + await client.stop() + except Exception: + pass + + # ------------------------------------------------------------------ + # Service gate. + # ------------------------------------------------------------------ + + async def _gate_ready(self, config, flow): + try: + await config.keys(SYSTEM_WORKSPACE, INIT_STATE_TYPE) + except Exception as e: + logger.info( + f"Gate: config-svc not ready ({type(e).__name__}: {e})" + ) + return False + + try: + resp = await flow.request( + FlowRequest( + operation="list-blueprints", + workspace=SYSTEM_WORKSPACE, + ), + timeout=5, + ) + if resp.error: + logger.info( + f"Gate: flow-svc error: " + f"{resp.error.type}: {resp.error.message}" + ) + return False + except Exception as e: + logger.info( + f"Gate: flow-svc not ready ({type(e).__name__}: {e})" + ) + return False + + return True + + # ------------------------------------------------------------------ + # Completion state. + # ------------------------------------------------------------------ + + async def _stored_flag(self, config, name): + raw = await config.get(SYSTEM_WORKSPACE, INIT_STATE_TYPE, name) + if raw is None: + return None + try: + return json.loads(raw) + except Exception: + return raw + + async def _store_flag(self, config, name, flag): + await config.put( + SYSTEM_WORKSPACE, INIT_STATE_TYPE, name, + json.dumps(flag), + ) + + # ------------------------------------------------------------------ + # Per-spec execution. + # ------------------------------------------------------------------ + + async def _run_spec(self, spec, config, flow): + """Run a single initialiser spec. + + Returns one of: + - ``"skip"``: stored flag already matches target, nothing to do. + - ``"ran"``: initialiser ran and completion state was updated. + - ``"failed"``: initialiser raised. + - ``"failed-state-write"``: initialiser succeeded but we could + not persist the new flag (transient — will re-run next cycle). + """ + + try: + old_flag = await self._stored_flag(config, spec.name) + except Exception as e: + logger.warning( + f"{spec.name}: could not read stored flag " + f"({type(e).__name__}: {e})" + ) + return "failed" + + if old_flag == spec.flag: + return "skip" + + child_logger = logger.getChild(spec.name) + child_ctx = InitContext( + logger=child_logger, + config=config, + flow=flow, + ) + + child_logger.info( + f"Running (old_flag={old_flag!r} -> new_flag={spec.flag!r})" + ) + + try: + await spec.instance.run(child_ctx, old_flag, spec.flag) + except Exception as e: + child_logger.error( + f"Failed: {type(e).__name__}: {e}", exc_info=True, + ) + return "failed" + + try: + await self._store_flag(config, spec.name, spec.flag) + except Exception as e: + child_logger.warning( + f"Completed but could not persist state flag " + f"({type(e).__name__}: {e}); will re-run next cycle" + ) + return "failed-state-write" + + child_logger.info(f"Completed (flag={spec.flag!r})") + return "ran" + + # ------------------------------------------------------------------ + # Main loop. + # ------------------------------------------------------------------ + + async def run(self): + + logger.info( + f"Bootstrapper starting with {len(self.specs)} initialisers" + ) + + while self.running: + + sleep_for = STEADY_INTERVAL + + try: + config, flow = await self._open_clients() + except Exception as e: + logger.info( + f"Failed to open clients " + f"({type(e).__name__}: {e}); retry in {GATE_BACKOFF}s" + ) + await asyncio.sleep(GATE_BACKOFF) + continue + + try: + # Phase 1: pre-service initialisers run unconditionally. + pre_specs = [ + s for s in self.specs + if not s.instance.wait_for_services + ] + pre_results = {} + for spec in pre_specs: + pre_results[spec.name] = await self._run_spec( + spec, config, flow, + ) + + # Phase 2: gate. + gate_ok = await self._gate_ready(config, flow) + + # Phase 3: post-service initialisers, if gate passed. + post_results = {} + if gate_ok: + post_specs = [ + s for s in self.specs + if s.instance.wait_for_services + ] + for spec in post_specs: + post_results[spec.name] = await self._run_spec( + spec, config, flow, + ) + + # Cadence selection. + if not gate_ok: + sleep_for = GATE_BACKOFF + else: + all_results = {**pre_results, **post_results} + if any(r != "skip" for r in all_results.values()): + sleep_for = INIT_RETRY + else: + sleep_for = STEADY_INTERVAL + + finally: + await self._safe_stop(config) + await self._safe_stop(flow) + + await asyncio.sleep(sleep_for) + + # ------------------------------------------------------------------ + # CLI arg plumbing. + # ------------------------------------------------------------------ + + @staticmethod + def add_args(parser: ArgumentParser) -> None: + + AsyncProcessor.add_args(parser) + + parser.add_argument( + '-c', '--initialisers-file', + help='Path to YAML or JSON file describing the ' + 'initialisers to run. Ignored when the ' + "'initialisers' parameter is provided directly " + '(e.g. when running inside a processor group).', + ) + + +def run(): + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/bootstrap/initialisers/__init__.py b/trustgraph-flow/trustgraph/bootstrap/initialisers/__init__.py new file mode 100644 index 00000000..6171eb02 --- /dev/null +++ b/trustgraph-flow/trustgraph/bootstrap/initialisers/__init__.py @@ -0,0 +1,20 @@ +""" +Core bootstrap initialisers. + +These cover the base TrustGraph deployment case. Enterprise or +third-party initialisers live in their own packages and are +referenced in the bootstrapper's config by fully-qualified dotted +path. +""" + +from . pulsar_topology import PulsarTopology +from . template_seed import TemplateSeed +from . workspace_init import WorkspaceInit +from . default_flow_start import DefaultFlowStart + +__all__ = [ + "PulsarTopology", + "TemplateSeed", + "WorkspaceInit", + "DefaultFlowStart", +] diff --git a/trustgraph-flow/trustgraph/bootstrap/initialisers/default_flow_start.py b/trustgraph-flow/trustgraph/bootstrap/initialisers/default_flow_start.py new file mode 100644 index 00000000..7e7f96bd --- /dev/null +++ b/trustgraph-flow/trustgraph/bootstrap/initialisers/default_flow_start.py @@ -0,0 +1,101 @@ +""" +DefaultFlowStart initialiser — starts a named flow in a workspace +using a specified blueprint. + +Separated from WorkspaceInit so deployments that want a workspace +without an auto-started flow can simply omit this initialiser. + +Parameters +---------- +workspace : str (default "default") + Workspace in which to start the flow. +flow_id : str (default "default") + Identifier for the started flow. +blueprint : str (required) + Blueprint name (must already exist in the workspace's config, + typically via TemplateSeed -> WorkspaceInit). +description : str (default "Default") + Human-readable description passed to flow-svc. +parameters : dict (optional) + Optional parameter overrides passed to start-flow. +""" + +from trustgraph.schema import FlowRequest + +from .. base import Initialiser + + +class DefaultFlowStart(Initialiser): + + def __init__( + self, + workspace="default", + flow_id="default", + blueprint=None, + description="Default", + parameters=None, + **kwargs, + ): + super().__init__(**kwargs) + if not blueprint: + raise ValueError( + "DefaultFlowStart requires 'blueprint'" + ) + self.workspace = workspace + self.flow_id = flow_id + self.blueprint = blueprint + self.description = description + self.parameters = dict(parameters) if parameters else {} + + async def run(self, ctx, old_flag, new_flag): + + # Check whether the flow already exists. Belt-and-braces + # beyond the flag gate: if an operator stops and restarts the + # bootstrapper after the flow is already running, we don't + # want to blindly try to start it again. + list_resp = await ctx.flow.request( + FlowRequest( + operation="list-flows", + workspace=self.workspace, + ), + timeout=10, + ) + if list_resp.error: + raise RuntimeError( + f"list-flows failed: " + f"{list_resp.error.type}: {list_resp.error.message}" + ) + + if self.flow_id in (list_resp.flow_ids or []): + ctx.logger.info( + f"Flow {self.flow_id!r} already running in workspace " + f"{self.workspace!r}; nothing to do" + ) + return + + ctx.logger.info( + f"Starting flow {self.flow_id!r} " + f"(blueprint={self.blueprint!r}) " + f"in workspace {self.workspace!r}" + ) + + resp = await ctx.flow.request( + FlowRequest( + operation="start-flow", + workspace=self.workspace, + flow_id=self.flow_id, + blueprint_name=self.blueprint, + description=self.description, + parameters=self.parameters, + ), + timeout=30, + ) + if resp.error: + raise RuntimeError( + f"start-flow failed: " + f"{resp.error.type}: {resp.error.message}" + ) + + ctx.logger.info( + f"Flow {self.flow_id!r} started" + ) diff --git a/trustgraph-flow/trustgraph/bootstrap/initialisers/pulsar_topology.py b/trustgraph-flow/trustgraph/bootstrap/initialisers/pulsar_topology.py new file mode 100644 index 00000000..843fe056 --- /dev/null +++ b/trustgraph-flow/trustgraph/bootstrap/initialisers/pulsar_topology.py @@ -0,0 +1,131 @@ +""" +PulsarTopology initialiser — creates Pulsar tenant and namespaces +with their retention policies. + +Runs pre-gate (``wait_for_services = False``) because config-svc and +flow-svc can't connect to Pulsar until these namespaces exist. +Admin-API calls are idempotent so re-runs on flag change are safe. +""" + +import asyncio +import requests + +from .. base import Initialiser + +# Namespace configs. flow/request take broker defaults. response +# and notify get aggressive retention — those classes carry short-lived +# request/response and notification traffic only. +NAMESPACE_CONFIG = { + "flow": {}, + "request": {}, + "response": { + "retention_policies": { + "retentionSizeInMB": -1, + "retentionTimeInMinutes": 3, + "subscriptionExpirationTimeMinutes": 30, + }, + }, + "notify": { + "retention_policies": { + "retentionSizeInMB": -1, + "retentionTimeInMinutes": 3, + "subscriptionExpirationTimeMinutes": 5, + }, + }, +} + +REQUEST_TIMEOUT = 10 + + +class PulsarTopology(Initialiser): + + wait_for_services = False + + def __init__( + self, + admin_url="http://pulsar:8080", + tenant="tg", + **kwargs, + ): + super().__init__(**kwargs) + self.admin_url = admin_url.rstrip("/") + self.tenant = tenant + + async def run(self, ctx, old_flag, new_flag): + # requests is blocking; offload to executor so the loop stays + # responsive. + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, self._reconcile_sync, ctx.logger) + + # ------------------------------------------------------------------ + # Sync admin-API calls. + # ------------------------------------------------------------------ + + def _get_clusters(self): + resp = requests.get( + f"{self.admin_url}/admin/v2/clusters", + timeout=REQUEST_TIMEOUT, + ) + resp.raise_for_status() + return resp.json() + + def _tenant_exists(self): + resp = requests.get( + f"{self.admin_url}/admin/v2/tenants/{self.tenant}", + timeout=REQUEST_TIMEOUT, + ) + return resp.status_code == 200 + + def _create_tenant(self, clusters): + resp = requests.put( + f"{self.admin_url}/admin/v2/tenants/{self.tenant}", + json={"adminRoles": [], "allowedClusters": clusters}, + timeout=REQUEST_TIMEOUT, + ) + if resp.status_code != 204: + raise RuntimeError( + f"Tenant {self.tenant!r} create failed: " + f"{resp.status_code} {resp.text}" + ) + + def _namespace_exists(self, namespace): + resp = requests.get( + f"{self.admin_url}/admin/v2/namespaces/" + f"{self.tenant}/{namespace}", + timeout=REQUEST_TIMEOUT, + ) + return resp.status_code == 200 + + def _create_namespace(self, namespace, config): + resp = requests.put( + f"{self.admin_url}/admin/v2/namespaces/" + f"{self.tenant}/{namespace}", + json=config, + timeout=REQUEST_TIMEOUT, + ) + if resp.status_code != 204: + raise RuntimeError( + f"Namespace {self.tenant}/{namespace} create failed: " + f"{resp.status_code} {resp.text}" + ) + + def _reconcile_sync(self, logger): + if not self._tenant_exists(): + clusters = self._get_clusters() + logger.info( + f"Creating tenant {self.tenant!r} with clusters {clusters}" + ) + self._create_tenant(clusters) + else: + logger.debug(f"Tenant {self.tenant!r} already exists") + + for namespace, config in NAMESPACE_CONFIG.items(): + if self._namespace_exists(namespace): + logger.debug( + f"Namespace {self.tenant}/{namespace} already exists" + ) + continue + logger.info( + f"Creating namespace {self.tenant}/{namespace}" + ) + self._create_namespace(namespace, config) diff --git a/trustgraph-flow/trustgraph/bootstrap/initialisers/template_seed.py b/trustgraph-flow/trustgraph/bootstrap/initialisers/template_seed.py new file mode 100644 index 00000000..5f1e4c19 --- /dev/null +++ b/trustgraph-flow/trustgraph/bootstrap/initialisers/template_seed.py @@ -0,0 +1,93 @@ +""" +TemplateSeed initialiser — populates the reserved ``__template__`` +workspace from an external JSON seed file. + +Seed file shape: + +.. code-block:: json + + { + "flow-blueprint": { + "ontology": { ... }, + "agent": { ... } + }, + "prompt": { + ... + }, + ... + } + +Top-level keys are config types; nested keys are config entries. +Values are arbitrary JSON (they'll be ``json.dumps()``'d on write). + +Parameters +---------- +config_file : str + Path to the seed file on disk. +overwrite : bool (default False) + On re-run (flag change), if True overwrite all keys; if False + upsert-missing-only (preserves any operator customisation of + the template). +""" + +import json + +from .. base import Initialiser + +TEMPLATE_WORKSPACE = "__template__" + + +class TemplateSeed(Initialiser): + + def __init__(self, config_file, overwrite=False, **kwargs): + super().__init__(**kwargs) + if not config_file: + raise ValueError("TemplateSeed requires 'config_file'") + self.config_file = config_file + self.overwrite = overwrite + + async def run(self, ctx, old_flag, new_flag): + + with open(self.config_file) as f: + seed = json.load(f) + + if old_flag is None: + # Clean first run — write every entry. + await self._write_all(ctx, seed) + return + + # Re-run after flag change. + if self.overwrite: + await self._write_all(ctx, seed) + else: + await self._upsert_missing(ctx, seed) + + async def _write_all(self, ctx, seed): + values = [] + for type_name, entries in seed.items(): + for key, value in entries.items(): + values.append((type_name, key, json.dumps(value))) + if values: + await ctx.config.put_many(TEMPLATE_WORKSPACE, values) + ctx.logger.info( + f"Template seeded with {len(values)} entries" + ) + + async def _upsert_missing(self, ctx, seed): + written = 0 + for type_name, entries in seed.items(): + existing = set( + await ctx.config.keys(TEMPLATE_WORKSPACE, type_name) + ) + values = [] + for key, value in entries.items(): + if key not in existing: + values.append( + (type_name, key, json.dumps(value)) + ) + if values: + await ctx.config.put_many(TEMPLATE_WORKSPACE, values) + written += len(values) + ctx.logger.info( + f"Template upsert-missing: {written} new entries" + ) diff --git a/trustgraph-flow/trustgraph/bootstrap/initialisers/workspace_init.py b/trustgraph-flow/trustgraph/bootstrap/initialisers/workspace_init.py new file mode 100644 index 00000000..10aefe9d --- /dev/null +++ b/trustgraph-flow/trustgraph/bootstrap/initialisers/workspace_init.py @@ -0,0 +1,138 @@ +""" +WorkspaceInit initialiser — creates a workspace and populates it from +either the ``__template__`` workspace or a seed file on disk. + +Parameters +---------- +workspace : str + Target workspace to create / populate. +source : str + Either ``"template"`` (copy the full contents of the + ``__template__`` workspace) or ``"seed-file"`` (read from + ``seed_file``). +seed_file : str (required when source=="seed-file") + Path to a JSON seed file with the same shape TemplateSeed consumes. +overwrite : bool (default False) + On re-run (flag change), if True overwrite all keys; if False, + upsert-missing-only (preserves in-workspace customisations). + +Raises (in ``run``) +------------------- +When source is ``"template"``, raises ``RuntimeError`` if the +``__template__`` workspace is empty — indicating that TemplateSeed +hasn't run yet. The bootstrapper's retry loop will re-attempt on +the next cycle once the prerequisite is satisfied. +""" + +import json + +from .. base import Initialiser + +TEMPLATE_WORKSPACE = "__template__" + + +class WorkspaceInit(Initialiser): + + def __init__( + self, + workspace="default", + source="template", + seed_file=None, + overwrite=False, + **kwargs, + ): + super().__init__(**kwargs) + + if source not in ("template", "seed-file"): + raise ValueError( + f"WorkspaceInit: source must be 'template' or " + f"'seed-file', got {source!r}" + ) + if source == "seed-file" and not seed_file: + raise ValueError( + "WorkspaceInit: seed_file required when source='seed-file'" + ) + + self.workspace = workspace + self.source = source + self.seed_file = seed_file + self.overwrite = overwrite + + async def run(self, ctx, old_flag, new_flag): + if self.source == "seed-file": + tree = self._load_seed_file() + else: + tree = await self._load_from_template(ctx) + + if old_flag is None or self.overwrite: + await self._write_all(ctx, tree) + else: + await self._upsert_missing(ctx, tree) + + def _load_seed_file(self): + with open(self.seed_file) as f: + return json.load(f) + + async def _load_from_template(self, ctx): + """Build a seed tree from the entire ``__template__`` workspace. + Raises if the workspace is empty, so the bootstrapper knows + the prerequisite isn't met yet.""" + + raw_tree = await ctx.config.get_all(TEMPLATE_WORKSPACE) + + tree = {} + total = 0 + for type_name, entries in raw_tree.items(): + parsed = {} + for key, raw in entries.items(): + if raw is None: + continue + try: + parsed[key] = json.loads(raw) + except Exception: + parsed[key] = raw + total += 1 + if parsed: + tree[type_name] = parsed + + if total == 0: + raise RuntimeError( + "Template workspace is empty — has TemplateSeed run yet?" + ) + + ctx.logger.debug( + f"Loaded {total} template entries across {len(tree)} types" + ) + return tree + + async def _write_all(self, ctx, tree): + values = [] + for type_name, entries in tree.items(): + for key, value in entries.items(): + values.append((type_name, key, json.dumps(value))) + if values: + await ctx.config.put_many(self.workspace, values) + ctx.logger.info( + f"Workspace {self.workspace!r} populated with " + f"{len(values)} entries" + ) + + async def _upsert_missing(self, ctx, tree): + written = 0 + for type_name, entries in tree.items(): + existing = set( + await ctx.config.keys(self.workspace, type_name) + ) + values = [] + for key, value in entries.items(): + if key not in existing: + values.append( + (type_name, key, json.dumps(value)) + ) + if values: + await ctx.config.put_many(self.workspace, values) + written += len(values) + ctx.logger.info( + f"Workspace {self.workspace!r} upsert-missing: " + f"{written} new entries" + ) diff --git a/trustgraph-flow/trustgraph/config/service/service.py b/trustgraph-flow/trustgraph/config/service/service.py index 56a54ee0..058f4e4b 100644 --- a/trustgraph-flow/trustgraph/config/service/service.py +++ b/trustgraph-flow/trustgraph/config/service/service.py @@ -24,6 +24,21 @@ logger = logging.getLogger(__name__) default_ident = "config-svc" + +def is_reserved_workspace(workspace): + """Reserved workspaces are storage-only. + + Any workspace id beginning with ``_`` is reserved for internal use + (e.g. ``__template__`` holding factory-default seed config). + Reads and writes work normally so bootstrap and provisioning code + can use the standard config API, but **change notifications for + reserved workspaces are suppressed**. Services subscribed to the + config push therefore never see reserved-workspace events and + cannot accidentally act on template content as if it were live + state. + """ + return workspace.startswith("_") + default_config_request_queue = config_request_queue default_config_response_queue = config_response_queue default_config_push_queue = config_push_queue @@ -130,6 +145,21 @@ class Processor(AsyncProcessor): async def push(self, changes=None): + # Suppress notifications from reserved workspaces (ids starting + # with "_", e.g. "__template__"). Stored config is preserved; + # only the broadcast is filtered. Keeps services oblivious to + # template / bootstrap state. + if changes: + filtered = {} + for type_name, workspaces in changes.items(): + visible = [ + w for w in workspaces + if not is_reserved_workspace(w) + ] + if visible: + filtered[type_name] = visible + changes = filtered + version = await self.config.get_version() resp = ConfigPush( From 67b2fc448fb7792594163d43b813415d5231ec66 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Fri, 24 Apr 2026 17:29:10 +0100 Subject: [PATCH 14/21] feat: IAM service, gateway auth middleware, capability model, and CLIs (#849) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the legacy GATEWAY_SECRET shared-token gate with an IAM-backed identity and authorisation model. The gateway no longer has an "allow-all" or "no auth" mode; every request is authenticated via the IAM service, authorised against a capability model that encodes both the operation and the workspace it targets, and rejected with a deliberately-uninformative 401 / 403 on any failure. IAM service (trustgraph-flow/trustgraph/iam, trustgraph-base/schema/iam) ----------------------------------------------------------------------- * New backend service (iam-svc) owning users, workspaces, API keys, passwords and JWT signing keys in Cassandra. Reached over the standard pub/sub request/response pattern; gateway is the only caller. * Operations: bootstrap, resolve-api-key, login, get-signing-key-public, rotate-signing-key, create/list/get/update/disable/delete/enable-user, change-password, reset-password, create/list/get/update/disable- workspace, create/list/revoke-api-key. * Ed25519 JWT signing (alg=EdDSA). Key rotation writes a new kid and retires the previous one; validation is grace-period friendly. * Passwords: PBKDF2-HMAC-SHA-256, 600k iterations, per-user salt. * API keys: 128-bit random, SHA-256 hashed. Plaintext returned once. * Bootstrap is explicit: --bootstrap-mode {token,bootstrap} is a required startup argument with no permissive default. Masked "auth failure" errors hide whether a refused bootstrap request was due to mode, state, or authorisation. Gateway authentication (trustgraph-flow/trustgraph/gateway/auth.py) ------------------------------------------------------------------- * IamAuth replaces the legacy Authenticator. Distinguishes JWTs (three-segment dotted) from API keys by shape; verifies JWTs locally using the cached IAM public key; resolves API keys via IAM with a short-TTL hash-keyed cache. Every failure path surfaces the same 401 body ("auth failure") so callers cannot enumerate credential state. * Public key is fetched at gateway startup with a bounded retry loop; traffic does not begin flowing until auth has started. Capability model (trustgraph-flow/trustgraph/gateway/capabilities.py) --------------------------------------------------------------------- * Roles have two dimensions: a capability set and a workspace scope. OSS ships reader / writer / admin; the first two are workspace- assigned, admin is cross-workspace ("*"). No "cross-workspace" pseudo-capability — workspace permission is a property of the role. * check(identity, capability, target_workspace=None) is the single authorisation test: some role must grant the capability *and* be active in the target workspace. * enforce_workspace validates a request-body workspace against the caller's role scopes and injects the resolved value. Cross- workspace admin is permitted by role scope, not by a bypass. * Gateway endpoints declare a required capability explicitly — no permissive default. Construction fails fast if omitted. Enterprise editions can replace the role table without changing the wire protocol. WebSocket first-frame auth (dispatch/mux.py, endpoint/socket.py) ---------------------------------------------------------------- * /api/v1/socket handshake unconditionally accepts; authentication runs on the first WebSocket frame ({"type":"auth","token":"..."}) with {"type":"auth-ok","workspace":"..."} / {"type":"auth-failed"}. The socket stays open on failure so the client can re-authenticate — browsers treat a handshake-time 401 as terminal, breaking reconnection. * Mux.receive rejects every non-auth frame before auth succeeds, enforces the caller's workspace (envelope + inner payload) using the role-scope resolver, and supports mid-session re-auth. * Flow import/export streaming endpoints keep the legacy ?token= handshake (URL-scoped short-lived transfers; no re-auth need). Auth surface ------------ * POST /api/v1/auth/login — public, returns a JWT. * POST /api/v1/auth/bootstrap — public; forwards to IAM's bootstrap op which itself enforces mode + tables-empty. * POST /api/v1/auth/change-password — any authenticated user. * POST /api/v1/iam — admin-only generic forwarder for the rest of the IAM API (per-op REST endpoints to follow in a later change). Removed / breaking ------------------ * GATEWAY_SECRET / --api-token / default_api_token and the legacy Authenticator.permitted contract. The gateway cannot run without IAM. * ?token= on /api/v1/socket. * DispatcherManager and Mux both raise on auth=None — no silent downgrade path. CLI tools (trustgraph-cli) -------------------------- tg-bootstrap-iam, tg-login, tg-create-user, tg-list-users, tg-disable-user, tg-enable-user, tg-delete-user, tg-change-password, tg-reset-password, tg-create-api-key, tg-list-api-keys, tg-revoke-api-key, tg-create-workspace, tg-list-workspaces. Passwords read via getpass; tokens / one-time secrets written to stdout with operator context on stderr so shell composition works cleanly. AsyncSocketClient / SocketClient updated to the first-frame auth protocol. Specifications -------------- * docs/tech-specs/iam.md updated with the error policy, workspace resolver extension point, and OSS role-scope model. * docs/tech-specs/iam-protocol.md (new) — transport, dataclasses, operation table, error taxonomy, bootstrap modes. * docs/tech-specs/capabilities.md (new) — capability vocabulary, OSS role bundles, agent-as-composition note, enforcement-boundary policy, enterprise extensibility. Tests ----- * test_auth.py (rewritten) — IamAuth + JWT round-trip with real Ed25519 keypairs + API-key cache behaviour. * test_capabilities.py (new) — role table sanity, check across role x workspace combinations, enforce_workspace paths, unknown-cap / unknown-role fail-closed. * Every endpoint test construction now names its capability explicitly (no permissive defaults relied upon). New tests pin the fail-closed invariants: DispatcherManager / Mux refuse auth=None; i18n path-traversal defense is exercised. * test_socket_graceful_shutdown rewritten against IamAuth. --- docs/tech-specs/capabilities.md | 218 ++++ docs/tech-specs/iam-protocol.md | 329 +++++ docs/tech-specs/iam.md | 41 + iam-testing.txt | 252 ++++ tests/unit/test_gateway/test_auth.py | 351 ++++- tests/unit/test_gateway/test_capabilities.py | 203 +++ .../test_gateway/test_dispatch_manager.py | 94 +- tests/unit/test_gateway/test_dispatch_mux.py | 34 +- .../test_gateway/test_endpoint_constant.py | 32 +- tests/unit/test_gateway/test_endpoint_i18n.py | 78 +- .../test_gateway/test_endpoint_manager.py | 34 +- .../test_gateway/test_endpoint_metrics.py | 29 +- .../unit/test_gateway/test_endpoint_socket.py | 56 +- .../unit/test_gateway/test_endpoint_stream.py | 69 +- .../test_gateway/test_endpoint_variable.py | 32 +- tests/unit/test_gateway/test_service.py | 468 +++---- .../test_socket_graceful_shutdown.py | 137 +- .../trustgraph/api/async_socket_client.py | 56 +- .../trustgraph/api/socket_client.py | 53 +- trustgraph-base/trustgraph/base/iam_client.py | 279 ++++ .../trustgraph/messaging/__init__.py | 11 +- .../trustgraph/messaging/translators/iam.py | 194 +++ .../trustgraph/schema/services/__init__.py | 1 + .../trustgraph/schema/services/iam.py | 142 +++ trustgraph-cli/pyproject.toml | 14 + trustgraph-cli/trustgraph/cli/_iam.py | 75 ++ .../trustgraph/cli/bootstrap_iam.py | 94 ++ .../trustgraph/cli/change_password.py | 46 + .../trustgraph/cli/create_api_key.py | 71 ++ trustgraph-cli/trustgraph/cli/create_user.py | 87 ++ .../trustgraph/cli/create_workspace.py | 46 + trustgraph-cli/trustgraph/cli/delete_user.py | 62 + trustgraph-cli/trustgraph/cli/disable_user.py | 45 + trustgraph-cli/trustgraph/cli/enable_user.py | 45 + .../trustgraph/cli/list_api_keys.py | 69 + trustgraph-cli/trustgraph/cli/list_users.py | 65 + .../trustgraph/cli/list_workspaces.py | 53 + trustgraph-cli/trustgraph/cli/login.py | 62 + .../trustgraph/cli/reset_password.py | 54 + .../trustgraph/cli/revoke_api_key.py | 44 + trustgraph-flow/pyproject.toml | 1 + trustgraph-flow/trustgraph/gateway/auth.py | 266 +++- .../trustgraph/gateway/capabilities.py | 238 ++++ .../trustgraph/gateway/dispatch/iam.py | 40 + .../trustgraph/gateway/dispatch/manager.py | 38 +- .../trustgraph/gateway/dispatch/mux.py | 106 +- .../gateway/endpoint/auth_endpoints.py | 115 ++ .../gateway/endpoint/constant_endpoint.py | 41 +- .../trustgraph/gateway/endpoint/i18n.py | 23 +- .../trustgraph/gateway/endpoint/manager.py | 289 ++++- .../trustgraph/gateway/endpoint/metrics.py | 18 +- .../trustgraph/gateway/endpoint/socket.py | 58 +- .../gateway/endpoint/stream_endpoint.py | 64 +- .../gateway/endpoint/variable_endpoint.py | 42 +- trustgraph-flow/trustgraph/gateway/service.py | 33 +- trustgraph-flow/trustgraph/iam/__init__.py | 0 .../trustgraph/iam/service/__init__.py | 1 + .../trustgraph/iam/service/__main__.py | 4 + trustgraph-flow/trustgraph/iam/service/iam.py | 1132 +++++++++++++++++ .../trustgraph/iam/service/service.py | 210 +++ trustgraph-flow/trustgraph/tables/iam.py | 422 ++++++ 61 files changed, 6474 insertions(+), 792 deletions(-) create mode 100644 docs/tech-specs/capabilities.md create mode 100644 docs/tech-specs/iam-protocol.md create mode 100644 iam-testing.txt create mode 100644 tests/unit/test_gateway/test_capabilities.py create mode 100644 trustgraph-base/trustgraph/base/iam_client.py create mode 100644 trustgraph-base/trustgraph/messaging/translators/iam.py create mode 100644 trustgraph-base/trustgraph/schema/services/iam.py create mode 100644 trustgraph-cli/trustgraph/cli/_iam.py create mode 100644 trustgraph-cli/trustgraph/cli/bootstrap_iam.py create mode 100644 trustgraph-cli/trustgraph/cli/change_password.py create mode 100644 trustgraph-cli/trustgraph/cli/create_api_key.py create mode 100644 trustgraph-cli/trustgraph/cli/create_user.py create mode 100644 trustgraph-cli/trustgraph/cli/create_workspace.py create mode 100644 trustgraph-cli/trustgraph/cli/delete_user.py create mode 100644 trustgraph-cli/trustgraph/cli/disable_user.py create mode 100644 trustgraph-cli/trustgraph/cli/enable_user.py create mode 100644 trustgraph-cli/trustgraph/cli/list_api_keys.py create mode 100644 trustgraph-cli/trustgraph/cli/list_users.py create mode 100644 trustgraph-cli/trustgraph/cli/list_workspaces.py create mode 100644 trustgraph-cli/trustgraph/cli/login.py create mode 100644 trustgraph-cli/trustgraph/cli/reset_password.py create mode 100644 trustgraph-cli/trustgraph/cli/revoke_api_key.py create mode 100644 trustgraph-flow/trustgraph/gateway/capabilities.py create mode 100644 trustgraph-flow/trustgraph/gateway/dispatch/iam.py create mode 100644 trustgraph-flow/trustgraph/gateway/endpoint/auth_endpoints.py create mode 100644 trustgraph-flow/trustgraph/iam/__init__.py create mode 100644 trustgraph-flow/trustgraph/iam/service/__init__.py create mode 100644 trustgraph-flow/trustgraph/iam/service/__main__.py create mode 100644 trustgraph-flow/trustgraph/iam/service/iam.py create mode 100644 trustgraph-flow/trustgraph/iam/service/service.py create mode 100644 trustgraph-flow/trustgraph/tables/iam.py diff --git a/docs/tech-specs/capabilities.md b/docs/tech-specs/capabilities.md new file mode 100644 index 00000000..60f5acbf --- /dev/null +++ b/docs/tech-specs/capabilities.md @@ -0,0 +1,218 @@ +--- +layout: default +title: "Capability Vocabulary Technical Specification" +parent: "Tech Specs" +--- + +# Capability Vocabulary Technical Specification + +## Overview + +Authorisation in TrustGraph is **capability-based**. Every gateway +endpoint maps to exactly one *capability*; a user's roles each grant +a set of capabilities; an authenticated request is permitted when +the required capability is a member of the union of the caller's +role capability sets. + +This document defines the capability vocabulary — the closed list +of capability strings that the gateway recognises — and the +open-source edition's role bundles. + +The capability mechanism is shared between open-source and potential +3rd party enterprise capability. The open-source edition ships a +fixed three-role bundle (`reader`, `writer`, `admin`). Enterprise +capability may define additional roles by composing their own +capability bundles from the same vocabulary; no protocol, gateway, +or backend-service change is required. + +## Motivation + +The original IAM spec used hierarchical "minimum role" checks +(`admin` implies `writer` implies `reader`). That shape is simple +but paints the role model into a corner: any enterprise need to +grant a subset of admin abilities (helpdesk that can reset +passwords but not edit flows; analyst who can query but not ingest) +requires a protocol-level change. + +A capability vocabulary decouples "what a request needs" from +"what roles a user has" and makes the role table pure data. The +open-source bundles can stay coarse while the enterprise role +table expands without any code movement. + +## Design + +### Capability string format + +`:` or `` (for capabilities with no +natural read/write split). All lowercase, kebab-case for +multi-word subsystems. + +### Capability list + +**Data plane** + +| Capability | Covers | +|---|---| +| `agent` | agent (query-only; no write counterpart) | +| `graph:read` | graph-rag, graph-embeddings-query, triples-query, sparql, graph-embeddings-export, triples-export | +| `graph:write` | triples-import, graph-embeddings-import | +| `documents:read` | document-rag, document-embeddings-query, document-embeddings-export, entity-contexts-export, document-stream-export, library list / fetch | +| `documents:write` | document-embeddings-import, entity-contexts-import, text-load, document-load, library add / replace / delete | +| `rows:read` | rows-query, row-embeddings-query, nlp-query, structured-query, structured-diag | +| `rows:write` | rows-import | +| `llm` | text-completion, prompt (stateless invocation) | +| `embeddings` | Raw text-embedding service (stateless compute; typed-data embedding stores live under their data-subject capability) | +| `mcp` | mcp-tool | +| `collections:read` | List / describe collections | +| `collections:write` | Create / delete collections | +| `knowledge:read` | List / get knowledge cores | +| `knowledge:write` | Create / delete knowledge cores | + +**Control plane** + +| Capability | Covers | +|---|---| +| `config:read` | Read workspace config | +| `config:write` | Write workspace config | +| `flows:read` | List / describe flows, blueprints, flow classes | +| `flows:write` | Start / stop / update flows | +| `users:read` | List / get users within the workspace | +| `users:write` | Create / update / disable users within the workspace | +| `users:admin` | Assign / remove roles on users within the workspace | +| `keys:self` | Create / revoke / list **own** API keys | +| `keys:admin` | Create / revoke / list **any user's** API keys within the workspace | +| `workspaces:admin` | Create / delete / disable workspaces (system-level) | +| `iam:admin` | JWT signing-key rotation, IAM-level operations | +| `metrics:read` | Prometheus metrics proxy | + +### Open-source role bundles + +The open-source edition ships three roles: + +| Role | Capabilities | +|---|---| +| `reader` | `agent`, `graph:read`, `documents:read`, `rows:read`, `llm`, `embeddings`, `mcp`, `collections:read`, `knowledge:read`, `flows:read`, `config:read`, `keys:self` | +| `writer` | everything in `reader` **+** `graph:write`, `documents:write`, `rows:write`, `collections:write`, `knowledge:write` | +| `admin` | everything in `writer` **+** `config:write`, `flows:write`, `users:read`, `users:write`, `users:admin`, `keys:admin`, `workspaces:admin`, `iam:admin`, `metrics:read` | + +Open-source bundles are deliberately coarse. `workspaces:admin` and +`iam:admin` live inside `admin` without a separate role; a single +`admin` user holds the keys to the whole deployment. + +### The `agent` capability and composition + +The `agent` capability is granted independently of the capabilities +it composes under the hood (`llm`, `graph`, `documents`, `rows`, +`mcp`, etc.). A user holding `agent` but not `llm` can still cause +LLM invocations because the agent implementation chooses which +services to invoke on the caller's behalf. + +This is deliberate. A common policy is "allow controlled access +via the agent, deny raw model calls" — granting `agent` without +granting `llm` expresses exactly that. An administrator granting +`agent` should treat it as a grant of everything the agent +composes at deployment time. + +### Authorisation evaluation + +For a request bearing a resolved set of roles +`R = {r1, r2, ...}` against an endpoint that requires capability +`c`: + +``` +allow if c IN union(bundle(r) for r in R) +``` + +No hierarchy, no precedence, no role-order sensitivity. A user +with a single role is the common case; a user with multiple roles +gets the union of their bundles. + +### Enforcement boundary + +Capability checks — and authentication — are applied **only at the +API gateway**, on requests arriving from external callers. +Operations originating inside the platform (backend service to +backend service, agent to LLM, flow-svc to config-svc, bootstrap +initialisers, scheduled reconcilers, autonomous flow steps) are +**not capability-checked**. Backend services trust the workspace +set by the gateway on inbound pub/sub messages and trust +internally-originated messages without further authorisation. + +This policy has four consequences that are part of the spec, not +accidents of implementation: + +1. **The gateway is the single trust boundary for user + authorisation.** Every backend service is a downstream consumer + of an already-authorised workspace scope. +2. **Pub/sub carries workspace, not user identity.** Messages on + the bus do not carry credentials or the identity that originated + a request; they carry the resolved workspace only. This keeps + the bus protocol free of secrets and aligns with the workspace + resolver's role as the gateway-side narrowing step. +3. **Composition is transitive.** Granting a capability that the + platform composes internally (for example, `agent`) transitively + grants everything that capability composes under the hood, + because the downstream calls are internal-origin and are not + re-checked. The composite nature of `agent` described above is + a consequence of this policy, not a special case. +4. **Internal-origin operations have no user.** Bootstrap, + reconcilers, and other platform-initiated work act with + system-level authority. The workspace field on such messages + identifies which workspace's data is being touched, not who + asked. + +**Trust model.** Whoever has pub/sub access is implicitly trusted +to act as any workspace. Defense-in-depth within the backend is +not part of this design; the security perimeter is the gateway +and the bus itself (TLS / network isolation between the bus and +any untrusted network). + +### Unknown capabilities and unknown roles + +- An endpoint declaring an unknown capability is a server-side bug + and fails closed (403, logged). +- A user carrying a role name that is not defined in the role table + is ignored for authorisation purposes and logged as a warning. + Behaviour is deterministic: unknown roles contribute zero + capabilities. + +### Capability scope + +Every capability is **implicitly scoped to the caller's resolved +workspace**. A `users:write` capability does not permit a user +in workspace `acme` to create users in workspace `beta` — the +workspace-resolver has already narrowed the request to one +workspace before the capability check runs. See the IAM +specification for the workspace-resolver contract. + +The three exceptions are the system-level capabilities +`workspaces:admin` and `iam:admin`, which operate across +workspaces by definition, and `metrics:read`, which returns +process-level series not scoped to any workspace. + +## Enterprise extensibility + +Enterprise editions extend the role table additively: + +``` +data-analyst: {query, library:read, collections:read, knowledge:read} +helpdesk: {users:read, users:write, users:admin, keys:admin} +data-engineer: writer + {flows:read, config:read} +workspace-owner: admin − {workspaces:admin, iam:admin} +``` + +None of this requires a protocol change — the wire-protocol `roles` +field on user records is already a set, the gateway's +capability-check is already capability-based, and the capability +vocabulary is closed. Enterprises may introduce roles whose bundles +compose the same capabilities differently. + +When an enterprise introduces a new capability (e.g. for a feature +that does not exist in open source), the capability string is +added to the vocabulary and recognised by the gateway build that +ships that feature. + +## References + +- [Identity and Access Management Specification](iam.md) +- [Architecture Principles](architecture-principles.md) diff --git a/docs/tech-specs/iam-protocol.md b/docs/tech-specs/iam-protocol.md new file mode 100644 index 00000000..8638e7e9 --- /dev/null +++ b/docs/tech-specs/iam-protocol.md @@ -0,0 +1,329 @@ +--- +layout: default +title: "IAM Service Protocol Technical Specification" +parent: "Tech Specs" +--- + +# IAM Service Protocol Technical Specification + +## Overview + +The IAM service is a backend processor, reached over the standard +request/response pub/sub pattern. It is the authority for users, +workspaces, API keys, and login credentials. The API gateway +delegates to it for authentication resolution and for all user / +workspace / key management. + +This document defines the wire protocol: the `IamRequest` and +`IamResponse` dataclasses, the operation set, the per-operation +input and output fields, the error taxonomy, and the initial HTTP +forwarding endpoint used while IAM is being integrated into the +gateway. + +Architectural context — roles, capabilities, workspace scoping, +enforcement boundary — lives in [`iam.md`](iam.md) and +[`capabilities.md`](capabilities.md). + +## Transport + +- **Request topic:** `request:tg/request/iam-request` +- **Response topic:** `response:tg/response/iam-response` +- **Pattern:** request/response, correlated by the `id` message + property, the same pattern used by `config-svc` and `flow-svc`. +- **Caller:** the API gateway only. Under the enforcement-boundary + policy (see capabilities spec), the IAM service trusts the bus + and performs no per-request authentication or capability check + against the caller. The gateway has already evaluated capability + membership and workspace scoping before sending the request. + +## Dataclasses + +### `IamRequest` + +```python +@dataclass +class IamRequest: + # One of the operation strings below. + operation: str = "" + + # Scope of this request. Required on every workspace-scoped + # operation. Omitted (or empty) for system-level ops + # (workspace CRUD, signing-key ops, bootstrap, resolve-api-key, + # login). + workspace: str = "" + + # Acting user id, for audit. Set by the gateway to the + # authenticated caller's id on user-initiated operations. + # Empty for internal-origin (bootstrap, reconcilers) and for + # resolve-api-key / login (no actor yet). + actor: str = "" + + # --- identity selectors --- + user_id: str = "" + username: str = "" # login; unique within a workspace + key_id: str = "" # revoke-api-key, list-api-keys (own) + api_key: str = "" # resolve-api-key (plaintext) + + # --- credentials --- + password: str = "" # login, change-password (current) + new_password: str = "" # change-password + + # --- user fields --- + user: UserInput | None = None # create-user, update-user + + # --- workspace fields --- + workspace_record: WorkspaceInput | None = None # create-workspace, update-workspace + + # --- api key fields --- + key: ApiKeyInput | None = None # create-api-key +``` + +### `IamResponse` + +```python +@dataclass +class IamResponse: + # Populated on success of operations that return them. + user: UserRecord | None = None # create-user, get-user, update-user + users: list[UserRecord] = field(default_factory=list) # list-users + workspace: WorkspaceRecord | None = None # create-workspace, get-workspace, update-workspace + workspaces: list[WorkspaceRecord] = field(default_factory=list) # list-workspaces + + # create-api-key returns the plaintext once. Never populated + # on any other operation. + api_key_plaintext: str = "" + api_key: ApiKeyRecord | None = None # create-api-key + api_keys: list[ApiKeyRecord] = field(default_factory=list) # list-api-keys + + # login, rotate-signing-key + jwt: str = "" + jwt_expires: str = "" # ISO-8601 UTC + + # get-signing-key-public + signing_key_public: str = "" # PEM + + # resolve-api-key returns who this key authenticates as. + resolved_user_id: str = "" + resolved_workspace: str = "" + resolved_roles: list[str] = field(default_factory=list) + + # reset-password + temporary_password: str = "" # returned once to the operator + + # bootstrap: on first run, the initial admin's one-time API key + # is returned for the operator to capture. + bootstrap_admin_user_id: str = "" + bootstrap_admin_api_key: str = "" + + # Present on any failed operation. + error: Error | None = None +``` + +### Value types + +```python +@dataclass +class UserInput: + username: str = "" + name: str = "" + email: str = "" + password: str = "" # only on create-user; never on update-user + roles: list[str] = field(default_factory=list) + enabled: bool = True + must_change_password: bool = False + +@dataclass +class UserRecord: + id: str = "" + workspace: str = "" + username: str = "" + name: str = "" + email: str = "" + roles: list[str] = field(default_factory=list) + enabled: bool = True + must_change_password: bool = False + created: str = "" # ISO-8601 UTC + # Password hash is never included in any response. + +@dataclass +class WorkspaceInput: + id: str = "" + name: str = "" + enabled: bool = True + +@dataclass +class WorkspaceRecord: + id: str = "" + name: str = "" + enabled: bool = True + created: str = "" # ISO-8601 UTC + +@dataclass +class ApiKeyInput: + user_id: str = "" + name: str = "" # operator-facing label, e.g. "laptop" + expires: str = "" # optional ISO-8601 UTC; empty = no expiry + +@dataclass +class ApiKeyRecord: + id: str = "" + user_id: str = "" + name: str = "" + prefix: str = "" # first 4 chars of plaintext, for identification in lists + expires: str = "" # empty = no expiry + created: str = "" + last_used: str = "" # empty if never used + # key_hash is never included in any response. +``` + +## Operations + +| Operation | Request fields | Response fields | Notes | +|---|---|---|---| +| `login` | `username`, `password`, `workspace` (optional) | `jwt`, `jwt_expires` | If `workspace` omitted, IAM resolves to the user's assigned workspace. | +| `resolve-api-key` | `api_key` (plaintext) | `resolved_user_id`, `resolved_workspace`, `resolved_roles` | Gateway-internal. Service returns `auth-failed` for unknown / expired / revoked keys. | +| `change-password` | `user_id`, `password` (current), `new_password` | — | Self-service. IAM validates `password` against stored hash. | +| `reset-password` | `user_id` | `temporary_password` | Admin-initiated. IAM generates a random password, sets `must_change_password=true` on the user, returns the plaintext once. | +| `create-user` | `workspace`, `user` | `user` | Admin-only. `user.password` is hashed and stored; `user.roles` must be subset of known roles. | +| `list-users` | `workspace` | `users` | | +| `get-user` | `workspace`, `user_id` | `user` | | +| `update-user` | `workspace`, `user_id`, `user` | `user` | `password` field on `user` is rejected; use `change-password` / `reset-password`. | +| `disable-user` | `workspace`, `user_id` | — | Soft-delete; sets `enabled=false`. Revokes all the user's API keys. | +| `create-workspace` | `workspace_record` | `workspace` | System-level. | +| `list-workspaces` | — | `workspaces` | System-level. | +| `get-workspace` | `workspace_record` (id only) | `workspace` | System-level. | +| `update-workspace` | `workspace_record` | `workspace` | System-level. | +| `disable-workspace` | `workspace_record` (id only) | — | System-level. Sets `enabled=false`; revokes all workspace API keys; disables all users in the workspace. | +| `create-api-key` | `workspace`, `key` | `api_key_plaintext`, `api_key` | Plaintext returned **once**; only hash stored. `key.name` required. | +| `list-api-keys` | `workspace`, `user_id` | `api_keys` | | +| `revoke-api-key` | `workspace`, `key_id` | — | Deletes the key record. | +| `get-signing-key-public` | — | `signing_key_public` | Gateway fetches this at startup. | +| `rotate-signing-key` | — | — | System-level. Introduces a new signing key; old key continues to validate JWTs for a grace period (implementation-defined, minimum 1h). | +| `bootstrap` | — | `bootstrap_admin_user_id`, `bootstrap_admin_api_key` | If IAM tables are empty, creates the initial `default` workspace, an `admin` user, an initial API key, and an initial signing key; returns them once. No-op on subsequent calls (returns empty fields). | + +## Error taxonomy + +All errors are carried in the `IamResponse.error` field. `error.type` +is one of the values below; `error.message` is a human-readable +string that is **not** surfaced verbatim to external callers (the +gateway maps to `auth failure` / `access denied` per the IAM error +policy). + +| `type` | When | +|---|---| +| `invalid-argument` | Malformed request (missing required field, unknown operation, invalid format). | +| `not-found` | Named resource does not exist (`user_id`, `key_id`, workspace). | +| `duplicate` | Create operation collides with an existing resource (username, workspace id, key name). | +| `auth-failed` | `login` with wrong credentials; `resolve-api-key` with unknown / expired / revoked key; `change-password` with wrong current password. Single bucket to deny oracle attacks. | +| `weak-password` | Password does not meet policy (length, complexity — policy defined at service level). | +| `disabled` | Target user or workspace has `enabled=false`. | +| `operation-not-permitted` | Non-admin attempting system-level operation, or workspace-scoped operation attempting to affect another workspace. | +| `internal-error` | Unexpected IAM-side failure. Log and surface as 500 at the gateway. | + +The gateway is responsible for translating `auth-failed` and +`operation-not-permitted` into the obfuscated external error +response (`"auth failure"` / `"access denied"`); `invalid-argument` +becomes a descriptive 400; `not-found` / `duplicate` / +`weak-password` / `disabled` become descriptive 4xx but never leak +IAM-internal detail. + +## Credential storage + +- **Passwords** are stored using a slow KDF (bcrypt / argon2id — the + service picks; documented as an implementation detail). The + `password_hash` column stores the full KDF-encoded string + (algorithm, cost, salt, hash). Not a plain SHA-256. +- **API keys** are stored as SHA-256 of the plaintext. API keys + are 128-bit random values (`tg_` + base64url); the entropy + makes a slow hash unnecessary. The hash serves as the primary + key on the `iam_api_keys` table, enabling O(1) lookup on + `resolve-api-key`. +- **JWT signing key** is stored as an RSA or Ed25519 private key + (implementation choice) in a dedicated `iam_signing_keys` table + with a `kid`, `created`, and optional `retired` timestamp. At + most one active key; up to N retired keys are kept for a grace + period to validate previously-issued JWTs. + +Passwords, API-key plaintext, and signing-key private material are +never returned in any response other than the explicit one-time +responses above (`reset-password`, `create-api-key`, `bootstrap`). + +## Bootstrap modes + +`iam-svc` requires a bootstrap mode to be chosen at startup. There is +no default — an unset or invalid mode causes the service to refuse +to start. The purpose is to force the operator to make an explicit +security decision rather than rely on an implicit "safe" fallback. + +| Mode | Startup behaviour | `bootstrap` operation | Suitability | +|---|---|---|---| +| `token` | On first start with empty tables, auto-seeds the `default` workspace, admin user, admin API key (using the operator-provided `--bootstrap-token`), and an initial signing key. No-op on subsequent starts. | Refused — returns `auth-failed` / `"auth failure"` regardless of caller. | Production, any public-exposure deployment. | +| `bootstrap` | No startup seeding. Tables remain empty until the `bootstrap` operation is invoked over the pub/sub bus (typically via `tg-bootstrap-iam`). | Live while tables are empty. Generates and returns the admin API key once. Refused (`auth-failed`) once tables are populated. | Dev / compose up / CI. **Not safe under public exposure** — any caller reaching the gateway's `/api/v1/iam` forwarder before the operator can cause a token to be issued to them. Operators choosing this mode accept that risk. | + +### Error masking + +In both modes, any refused invocation of the `bootstrap` operation +returns the same error (`auth-failed` / `"auth failure"`). A caller +cannot distinguish: + +- "service is in token mode" +- "service is in bootstrap mode but already bootstrapped" +- "operation forbidden" + +This matches the general IAM error-policy stance (see `iam.md`) and +prevents externally enumerating IAM's state. + +### Bootstrap-token lifecycle + +The bootstrap token — whether operator-supplied (`token` mode) or +service-generated (`bootstrap` mode) — is a one-time credential. It +is stored as admin's single API key, tagged `name="bootstrap"`. The +operator's first admin action after bootstrap should be: + +1. Create a durable admin user and API key (or issue a durable API + key to the bootstrap admin). +2. Revoke the bootstrap key via `revoke-api-key`. +3. Remove the bootstrap token from any deployment configuration. + +The `name="bootstrap"` marker makes bootstrap keys easy to detect in +tooling (e.g. a `tg-list-api-keys` filter). + +## HTTP forwarding (initial integration) + +For the initial gateway integration — before the IAM service is +wired into the authentication middleware — the gateway exposes a +single forwarding endpoint: + +``` +POST /api/v1/iam +``` + +- Request body is a JSON encoding of `IamRequest`. +- Response body is a JSON encoding of `IamResponse`. +- The gateway's existing authentication (`GATEWAY_SECRET` bearer) + gates access to this endpoint so the IAM protocol can be + exercised end-to-end in tests without touching the live auth + path. +- This endpoint is **not** the final shape. Once the middleware is + in place, per-operation REST endpoints replace it (for example + `POST /api/v1/auth/login`, `POST /api/v1/users`, `DELETE + /api/v1/api-keys/{id}`), and this generic forwarder is removed. + +The endpoint performs only message marshalling: it does not read +or rewrite fields in the request, and it applies no capability +check. All authorisation for user / workspace / key management +lands in the subsequent middleware work. + +## Non-goals for this spec + +- REST endpoint shape for the final gateway surface — covered in + Phase 2 of the IAM implementation plan, not here. +- OIDC / SAML external IdP protocol — out of scope for open source. +- Key-signing algorithm choice, password KDF choice, JWT claim + layout — implementation details captured in code + ADRs, not + locked in the protocol spec. + +## References + +- [Identity and Access Management Specification](iam.md) +- [Capability Vocabulary Specification](capabilities.md) diff --git a/docs/tech-specs/iam.md b/docs/tech-specs/iam.md index cb1399fe..50b64444 100644 --- a/docs/tech-specs/iam.md +++ b/docs/tech-specs/iam.md @@ -423,6 +423,37 @@ resolve API keys and to handle login requests. User management operations (create user, revoke key, etc.) also go through the IAM service. +### Error policy + +External error responses carry **no diagnostic detail** for +authentication or access-control failures. The goal is to give an +attacker probing the endpoint no signal about which condition they +tripped. + +| Category | HTTP | Body | WebSocket frame | +|----------|------|------|-----------------| +| Authentication failure | `401 Unauthorized` | `{"error": "auth failure"}` | `{"type": "auth-failed", "error": "auth failure"}` | +| Access control failure | `403 Forbidden` | `{"error": "access denied"}` | `{"error": "access denied"}` (endpoint-specific frame type) | + +"Authentication failure" covers missing credential, malformed +credential, invalid signature, expired token, revoked API key, and +unknown API key — all indistinguishable to the caller. + +"Access control failure" covers role insufficient, workspace +mismatch, user disabled, and workspace disabled — all +indistinguishable to the caller. + +**Server-side logging is richer.** The audit log records the specific +reason (`"workspace-mismatch: user alice assigned 'acme', requested +'beta'"`, `"role-insufficient: admin required, user has writer"`, +etc.) for operators and post-incident forensics. These messages never +appear in responses. + +Other error classes (bad request, internal error) remain descriptive +because they do not reveal anything about the auth or access-control +surface — e.g. `"missing required field 'workspace'"` or +`"invalid JSON"` is fine. + ### Gateway changes The current `Authenticator` class is replaced with a thin authentication @@ -713,6 +744,16 @@ These are not implemented but the architecture does not preclude them: - **Multi-workspace access.** Users could be granted access to additional workspaces beyond their primary assignment. The workspace validation step checks a grant list instead of a single assignment. +- **Workspace resolver.** Workspace resolution on each authenticated + request — "given this user and this requested workspace, which + workspace (if any) may the request operate on?" — is encapsulated + in a single pluggable resolver. The open-source edition ships a + resolver that permits only the user's single assigned workspace; + enterprise editions that implement multi-workspace access swap in a + resolver that consults a permitted set. The wire protocol (the + optional `workspace` field on the authenticated request) is + identical in both editions, so clients written against one edition + work unchanged against the other. - **Rules-based access control.** A separate access control service could evaluate fine-grained policies (per-collection permissions, operation-level restrictions, time-based access). The gateway diff --git a/iam-testing.txt b/iam-testing.txt new file mode 100644 index 00000000..0d03ffc3 --- /dev/null +++ b/iam-testing.txt @@ -0,0 +1,252 @@ + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation": "bootstrap"}' + + + + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation": "resolve-api-key", "api_key": "tg_r-n43hDWV9WOY06w6o5YpevAxirlS33D"}' + + + + + + + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation": "resolve-api-key", "api_key": "asdalsdjasdkasdasda"}' + + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"list-users","workspace":"default"}' + + + + # 1. Admin creates a writer user "alice" + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{ + "operation": "create-user", + "workspace": "default", + "user": { + "username": "alice", + "name": "Alice", + "email": "alice@example.com", + "password": "changeme", + "roles": ["writer"] + } + }' + # expect: {"user": {"id": "", ...}} — grab alice's uuid + + # 2. Issue alice an API key + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{ + "operation": "create-api-key", + "workspace": "default", + "key": { + "user_id": "f2363a10-3b83-44ea-a008-43caae8ba607", + "name": "alice-laptop" + } + }' + # expect: {"api_key_plaintext": "tg_...", "api_key": {"id": "", "prefix": "tg_xxxx", ...}} + + # 3. Resolve alice's key — should return alice's id + workspace + writer role + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"resolve-api-key","api_key":"tg_gt4buvk5NG-QS7oP_0Gk5yTWyj1qensf"}' + + # expect: {"resolved_user_id":"","resolved_workspace":"default","resolved_roles":["writer"]} + + # 4. List alice's keys (admin view of alice's keys) + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"list-api-keys","workspace":"default","user_id":"f2363a10-3b83-44ea-a008-43caae8ba607"}' + # expect: {"api_keys": [{"id":"","user_id":"","name":"alice-laptop","prefix":"tg_xxxx",...}]} + + # 5. Revoke alice's key + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"revoke-api-key","workspace":"default","key_id":"55f1c1f7-5448-49fd-9eda-56c192b61177"}' + + + # expect: {} (empty, no error) + + # 6. Confirm the revoked key no longer resolves + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"resolve-api-key","api_key":"tg_gt4buvk5NG-QS7oP_0Gk5yTWyj1qensf"}' + # expect: {"error":{"type":"auth-failed","message":"unknown api key"}} + + + +---------------------------------------------------------------------------- + + You'll want to re-bootstrap a fresh deployment to pick up the new signing-key row (or accept that login will lazily generate one on first + call). Then: + + # 1. Create a user with a known password (admin's password is random) + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"create-user","workspace":"default","user":{"username":"alice","password":"s3cret","roles":["writer"]}}' + + + + # 2. Log alice in + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"login","username":"alice","password":"s3cret"}' + # expect: {"jwt":"eyJ...","jwt_expires":"2026-..."} + + # 3. Fetch the public key (what the gateway will use later to verify) + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"get-signing-key-public"}' + + # expect: {"signing_key_public":"-----BEGIN PUBLIC KEY-----\n..."} + + # 4. Wrong password + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Authorization: Bearer $GATEWAY_SECRET" \ + -H "Content-Type: application/json" \ + -d '{"operation":"login","username":"alice","password":"nope"}' + + + + # expect: {"error":{"type":"auth-failed","message":"bad credentials"}} + + + + + +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAseLB/a9Bo/RN/Rb/x763 ++vdxmUKG75oWsXBmbwZGDXyN6fwqZ3L7cEje93qK0PYFuCHxhY1Hn0gW7FZ8ovH+ +qEksekUlpfPYqKGiT5Mb0DKk49D4yKkIbJFugWalpwIilvRbQO0jy3V8knqGQ1xL +NfNYFrI2Rxe0Tq2OHVYc5YwYbyj1nz2TY5fd9qrzXtGRv5HZztkl25lWhRvG9G0K +urKDdBDbi894gIYorXvcwZw/b1GDXG/aUy/By1Oy3hXnCLsN8pA3nA437TTTWxHx +QgPH15jIF9hezO+3/ESZ7EhVEtgmwTxPddfXRa0ZoT6JyWOgcloKtnP4Lp9eQ4va +yQIDAQAB +-----END PUBLIC KEY----- + + + + + + New operations: + - change-password — self-service. Requires current + new password. + - reset-password — admin-driven. Generates a random temporary, sets must_change_password=true, returns plaintext once. + - get-user, update-user, disable-user — workspace-scoped. update-user refuses to change username (immutable — error if different) and refuses + password-via-update. disable-user also revokes all the user's API keys, per spec. + - create-workspace, list-workspaces, get-workspace, update-workspace, disable-workspace — system-level. disable-workspace cascades: disables + all users + revokes all their keys. Rejects ids starting with _ (reserved, per the bootstrap framework convention). + - rotate-signing-key — generates a new Ed25519 key, retires the current one (sets retired timestamp; row stays for future grace-period + validation), switches the in-memory cache. + + Touched files: + - trustgraph-flow/trustgraph/tables/iam.py — added retire_signing_key, update_user_profile, update_user_password, update_user_enabled, + update_workspace. + - trustgraph-flow/trustgraph/iam/service/iam.py — 12 new handlers + dispatch entries. + - trustgraph-base/trustgraph/base/iam_client.py — matching client helpers for all of them. + + Smoke-test suggestions: + + # change password for alice (from "s3cret" → "n3wer") + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"change-password","user_id":"b2960feb-caef-401d-af65-01bdb6960cad","password":"s3cret","new_password":"n3wer"}' + + # login with new password + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"login","username":"alice","password":"n3wer"}' + + # admin resets alice's password + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"reset-password","workspace":"default","user_id":"b2960feb-caef-401d-af65-01bdb6960cad"}' + + + # → {"temporary_password":"..."} + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"login","username":"alice","password":"fH2ttyrIcVXCIkH_"}' + + + # create a second workspace + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"create-workspace","workspace_record":{"id":"acme","name":"Acme Corp","enabled":true}}' + + + # rotate signing key (next login produces a JWT signed by a new kid) + + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"rotate-signing-key"}' + + + + + + + curl -s -X POST "http://localhost:8088/api/v1/flow" \ + -H "Authorization: Bearer tg_bs_kBAhfejiEJmbcO1gElbxk3MpV7wQFygP" \ + -H "Content-Type: application/json" \ + -d '{"operation":"list-flows"}' + + curl -s -X POST "http://localhost:8088/api/v1/iam" \ + -H "Authorization: Bearer tg_bs_kBAhfejiEJmbcO1gElbxk3MpV7wQFygP" \ + -H "Content-Type: application/json" \ + -d '{"operation":"list-users"}' + + + + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer tg_bs_kBAhfejiEJmbcO1gElbxk3MpV7wQFygP" \ + -d '{ + "operation": "create-user", + "workspace": "default", + "user": { + "username": "alice", + "name": "Alice", + "email": "alice@example.com", + "password": "s3cret", + "roles": ["writer"] + } + }' + + + + + # Login (public, no token needed) → returns a JWT + curl -s -X POST "http://localhost:8088/api/v1/auth/login" \ + -H "Content-Type: application/json" \ + -d '{"username":"alice","password":"s3cret"}' + + + + export TRUSTGRAPH_TOKEN=$(tg-bootstrap-iam) # on fresh bootstrap-mode deployment + # or set to your existing admin API key + + tg-create-user --username alice --roles writer + # → prints alice's user id + + ALICE_ID= + + ALICE_KEY=$(tg-create-api-key --user-id $ALICE_ID --name alice-laptop) + # → alice's plaintext API key + + tg-list-users + tg-list-api-keys --user-id $ALICE_ID + + tg-revoke-api-key --key-id <...> + tg-disable-user --user-id $ALICE_ID + + # User self-service: + tg-login --username alice # prompts for password, prints JWT + tg-change-password # prompts for current + new + + diff --git a/tests/unit/test_gateway/test_auth.py b/tests/unit/test_gateway/test_auth.py index d4d4fc2b..ba2b9bc2 100644 --- a/tests/unit/test_gateway/test_auth.py +++ b/tests/unit/test_gateway/test_auth.py @@ -1,69 +1,312 @@ """ -Tests for Gateway Authentication +Tests for gateway/auth.py — IamAuth, JWT verification, API key +resolution cache. + +JWTs are signed with real Ed25519 keypairs generated per-test, so +the crypto path is exercised end-to-end without mocks. API-key +resolution is tested against a stubbed IamClient since the real +one requires pub/sub. """ +import base64 +import json +import time +from unittest.mock import AsyncMock, Mock, patch + import pytest +from aiohttp import web +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 -from trustgraph.gateway.auth import Authenticator +from trustgraph.gateway.auth import ( + IamAuth, Identity, + _b64url_decode, _verify_jwt_eddsa, + API_KEY_CACHE_TTL, +) -class TestAuthenticator: - """Test cases for Authenticator class""" +# -- helpers --------------------------------------------------------------- - def test_authenticator_initialization_with_token(self): - """Test Authenticator initialization with valid token""" - auth = Authenticator(token="test-token-123") - - assert auth.token == "test-token-123" - assert auth.allow_all is False - def test_authenticator_initialization_with_allow_all(self): - """Test Authenticator initialization with allow_all=True""" - auth = Authenticator(allow_all=True) - - assert auth.token is None - assert auth.allow_all is True +def _b64url(data: bytes) -> str: + return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") - def test_authenticator_initialization_without_token_raises_error(self): - """Test Authenticator initialization without token raises RuntimeError""" - with pytest.raises(RuntimeError, match="Need a token"): - Authenticator() - def test_authenticator_initialization_with_empty_token_raises_error(self): - """Test Authenticator initialization with empty token raises RuntimeError""" - with pytest.raises(RuntimeError, match="Need a token"): - Authenticator(token="") +def make_keypair(): + priv = ed25519.Ed25519PrivateKey.generate() + public_pem = priv.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode("ascii") + return priv, public_pem - def test_permitted_with_allow_all_returns_true(self): - """Test permitted method returns True when allow_all is enabled""" - auth = Authenticator(allow_all=True) - - # Should return True regardless of token or roles - assert auth.permitted("any-token", []) is True - assert auth.permitted("different-token", ["admin"]) is True - assert auth.permitted(None, ["user"]) is True - def test_permitted_with_matching_token_returns_true(self): - """Test permitted method returns True with matching token""" - auth = Authenticator(token="secret-token") - - # Should return True when tokens match - assert auth.permitted("secret-token", []) is True - assert auth.permitted("secret-token", ["admin", "user"]) is True +def sign_jwt(priv, claims, alg="EdDSA"): + header = {"alg": alg, "typ": "JWT", "kid": "kid-test"} + h = _b64url(json.dumps(header, separators=(",", ":"), sort_keys=True).encode()) + p = _b64url(json.dumps(claims, separators=(",", ":"), sort_keys=True).encode()) + signing_input = f"{h}.{p}".encode("ascii") + if alg == "EdDSA": + sig = priv.sign(signing_input) + else: + raise ValueError(f"test helper doesn't sign {alg}") + return f"{h}.{p}.{_b64url(sig)}" - def test_permitted_with_non_matching_token_returns_false(self): - """Test permitted method returns False with non-matching token""" - auth = Authenticator(token="secret-token") - - # Should return False when tokens don't match - assert auth.permitted("wrong-token", []) is False - assert auth.permitted("different-token", ["admin"]) is False - assert auth.permitted(None, ["user"]) is False - def test_permitted_with_token_and_allow_all_returns_true(self): - """Test permitted method with both token and allow_all set""" - auth = Authenticator(token="test-token", allow_all=True) - - # allow_all should take precedence - assert auth.permitted("any-token", []) is True - assert auth.permitted("wrong-token", ["admin"]) is True \ No newline at end of file +def make_request(auth_header): + """Minimal stand-in for an aiohttp request — IamAuth only reads + ``request.headers["Authorization"]``.""" + req = Mock() + req.headers = {} + if auth_header is not None: + req.headers["Authorization"] = auth_header + return req + + +# -- pure helpers ---------------------------------------------------------- + + +class TestB64UrlDecode: + + def test_round_trip_without_padding(self): + data = b"hello" + encoded = _b64url(data) + assert _b64url_decode(encoded) == data + + def test_handles_various_lengths(self): + for s in (b"a", b"ab", b"abc", b"abcd", b"abcde"): + assert _b64url_decode(_b64url(s)) == s + + +# -- JWT verification ----------------------------------------------------- + + +class TestVerifyJwtEddsa: + + def test_valid_jwt_passes(self): + priv, pub = make_keypair() + claims = { + "sub": "user-1", "workspace": "default", + "roles": ["reader"], + "iat": int(time.time()), + "exp": int(time.time()) + 60, + } + token = sign_jwt(priv, claims) + got = _verify_jwt_eddsa(token, pub) + assert got["sub"] == "user-1" + assert got["workspace"] == "default" + + def test_expired_jwt_rejected(self): + priv, pub = make_keypair() + claims = { + "sub": "user-1", "workspace": "default", "roles": [], + "iat": int(time.time()) - 3600, + "exp": int(time.time()) - 1, + } + token = sign_jwt(priv, claims) + with pytest.raises(ValueError, match="expired"): + _verify_jwt_eddsa(token, pub) + + def test_bad_signature_rejected(self): + priv_a, _ = make_keypair() + _, pub_b = make_keypair() + claims = { + "sub": "user-1", "workspace": "default", "roles": [], + "iat": int(time.time()), + "exp": int(time.time()) + 60, + } + token = sign_jwt(priv_a, claims) + # pub_b never signed this token. + with pytest.raises(Exception): + _verify_jwt_eddsa(token, pub_b) + + def test_malformed_jwt_rejected(self): + _, pub = make_keypair() + with pytest.raises(ValueError, match="malformed"): + _verify_jwt_eddsa("not-a-jwt", pub) + + def test_unsupported_algorithm_rejected(self): + priv, pub = make_keypair() + # Manually build an "alg":"HS256" header — no signer needed + # since we expect it to bail before verifying. + header = {"alg": "HS256", "typ": "JWT", "kid": "x"} + payload = { + "sub": "user-1", "workspace": "default", "roles": [], + "iat": int(time.time()), "exp": int(time.time()) + 60, + } + h = _b64url(json.dumps(header, separators=(",", ":")).encode()) + p = _b64url(json.dumps(payload, separators=(",", ":")).encode()) + sig = _b64url(b"not-a-real-sig") + token = f"{h}.{p}.{sig}" + with pytest.raises(ValueError, match="unsupported alg"): + _verify_jwt_eddsa(token, pub) + + +# -- Identity -------------------------------------------------------------- + + +class TestIdentity: + + def test_fields(self): + i = Identity( + user_id="u", workspace="w", roles=["reader"], source="api-key", + ) + assert i.user_id == "u" + assert i.workspace == "w" + assert i.roles == ["reader"] + assert i.source == "api-key" + + +# -- IamAuth.authenticate -------------------------------------------------- + + +class TestIamAuthDispatch: + """``authenticate()`` chooses between the JWT and API-key paths + by shape of the bearer.""" + + @pytest.mark.asyncio + async def test_no_authorization_header_raises_401(self): + auth = IamAuth(backend=Mock()) + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request(None)) + + @pytest.mark.asyncio + async def test_non_bearer_header_raises_401(self): + auth = IamAuth(backend=Mock()) + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request("Basic whatever")) + + @pytest.mark.asyncio + async def test_empty_bearer_raises_401(self): + auth = IamAuth(backend=Mock()) + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request("Bearer ")) + + @pytest.mark.asyncio + async def test_unknown_format_raises_401(self): + # Not tg_... and not dotted-JWT shape. + auth = IamAuth(backend=Mock()) + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request("Bearer garbage")) + + @pytest.mark.asyncio + async def test_valid_jwt_resolves_to_identity(self): + priv, pub = make_keypair() + claims = { + "sub": "user-1", "workspace": "default", + "roles": ["writer"], + "iat": int(time.time()), + "exp": int(time.time()) + 60, + } + token = sign_jwt(priv, claims) + + auth = IamAuth(backend=Mock()) + auth._signing_public_pem = pub + + ident = await auth.authenticate( + make_request(f"Bearer {token}") + ) + assert ident.user_id == "user-1" + assert ident.workspace == "default" + assert ident.roles == ["writer"] + assert ident.source == "jwt" + + @pytest.mark.asyncio + async def test_jwt_without_public_key_fails(self): + # If the gateway hasn't fetched IAM's public key yet, JWTs + # must not validate — even ones that would otherwise pass. + priv, _ = make_keypair() + claims = { + "sub": "user-1", "workspace": "default", "roles": [], + "iat": int(time.time()), "exp": int(time.time()) + 60, + } + token = sign_jwt(priv, claims) + auth = IamAuth(backend=Mock()) + # _signing_public_pem defaults to None + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request(f"Bearer {token}")) + + @pytest.mark.asyncio + async def test_api_key_path(self): + auth = IamAuth(backend=Mock()) + + async def fake_resolve(api_key): + assert api_key == "tg_testkey" + return ("user-xyz", "default", ["admin"]) + + async def fake_with_client(op): + return await op(Mock(resolve_api_key=fake_resolve)) + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + ident = await auth.authenticate( + make_request("Bearer tg_testkey") + ) + assert ident.user_id == "user-xyz" + assert ident.workspace == "default" + assert ident.roles == ["admin"] + assert ident.source == "api-key" + + @pytest.mark.asyncio + async def test_api_key_rejection_masked_as_401(self): + auth = IamAuth(backend=Mock()) + + async def fake_with_client(op): + raise RuntimeError("auth-failed: unknown api key") + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate( + make_request("Bearer tg_bogus") + ) + + +# -- API key cache --------------------------------------------------------- + + +class TestApiKeyCache: + + @pytest.mark.asyncio + async def test_cache_hit_skips_iam(self): + auth = IamAuth(backend=Mock()) + calls = {"n": 0} + + async def fake_with_client(op): + calls["n"] += 1 + return await op(Mock( + resolve_api_key=AsyncMock( + return_value=("u", "default", ["reader"]), + ) + )) + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + await auth.authenticate(make_request("Bearer tg_k1")) + await auth.authenticate(make_request("Bearer tg_k1")) + await auth.authenticate(make_request("Bearer tg_k1")) + + # Only the first lookup reaches IAM; the rest are cache hits. + assert calls["n"] == 1 + + @pytest.mark.asyncio + async def test_different_keys_are_separately_cached(self): + auth = IamAuth(backend=Mock()) + seen = [] + + async def fake_with_client(op): + async def resolve(plaintext): + seen.append(plaintext) + return ("u-" + plaintext, "default", ["reader"]) + return await op(Mock(resolve_api_key=resolve)) + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + a = await auth.authenticate(make_request("Bearer tg_a")) + b = await auth.authenticate(make_request("Bearer tg_b")) + + assert a.user_id == "u-tg_a" + assert b.user_id == "u-tg_b" + assert seen == ["tg_a", "tg_b"] + + @pytest.mark.asyncio + async def test_cache_has_ttl_constant_set(self): + # Not a behaviour test — just ensures we don't accidentally + # set TTL to 0 (which would defeat the cache) or to a week. + assert 10 <= API_KEY_CACHE_TTL <= 3600 diff --git a/tests/unit/test_gateway/test_capabilities.py b/tests/unit/test_gateway/test_capabilities.py new file mode 100644 index 00000000..063e9ea4 --- /dev/null +++ b/tests/unit/test_gateway/test_capabilities.py @@ -0,0 +1,203 @@ +""" +Tests for gateway/capabilities.py — the capability + role + workspace +model that underpins all gateway authorisation. +""" + +import pytest +from aiohttp import web + +from trustgraph.gateway.capabilities import ( + PUBLIC, AUTHENTICATED, + KNOWN_CAPABILITIES, ROLE_DEFINITIONS, + check, enforce_workspace, access_denied, auth_failure, +) + + +# -- test fixtures --------------------------------------------------------- + + +class _Identity: + """Minimal stand-in for auth.Identity — the capability module + accesses ``.workspace`` and ``.roles``.""" + def __init__(self, workspace, roles): + self.user_id = "user-1" + self.workspace = workspace + self.roles = list(roles) + + +def reader_in(ws): + return _Identity(ws, ["reader"]) + + +def writer_in(ws): + return _Identity(ws, ["writer"]) + + +def admin_in(ws): + return _Identity(ws, ["admin"]) + + +# -- role table sanity ----------------------------------------------------- + + +class TestRoleTable: + + def test_oss_roles_present(self): + assert set(ROLE_DEFINITIONS.keys()) == {"reader", "writer", "admin"} + + def test_admin_is_cross_workspace(self): + assert ROLE_DEFINITIONS["admin"]["workspace_scope"] == "*" + + def test_reader_writer_are_assigned_scope(self): + assert ROLE_DEFINITIONS["reader"]["workspace_scope"] == "assigned" + assert ROLE_DEFINITIONS["writer"]["workspace_scope"] == "assigned" + + def test_admin_superset_of_writer(self): + admin = ROLE_DEFINITIONS["admin"]["capabilities"] + writer = ROLE_DEFINITIONS["writer"]["capabilities"] + assert writer.issubset(admin) + + def test_writer_superset_of_reader(self): + writer = ROLE_DEFINITIONS["writer"]["capabilities"] + reader = ROLE_DEFINITIONS["reader"]["capabilities"] + assert reader.issubset(writer) + + def test_admin_has_users_admin(self): + assert "users:admin" in ROLE_DEFINITIONS["admin"]["capabilities"] + + def test_writer_does_not_have_users_admin(self): + assert "users:admin" not in ROLE_DEFINITIONS["writer"]["capabilities"] + + def test_every_bundled_capability_is_known(self): + for role in ROLE_DEFINITIONS.values(): + for cap in role["capabilities"]: + assert cap in KNOWN_CAPABILITIES + + +# -- check() --------------------------------------------------------------- + + +class TestCheck: + + def test_reader_has_reader_cap_in_own_workspace(self): + assert check(reader_in("default"), "graph:read", "default") + + def test_reader_does_not_have_writer_cap(self): + assert not check(reader_in("default"), "graph:write", "default") + + def test_reader_cannot_act_in_other_workspace(self): + assert not check(reader_in("default"), "graph:read", "acme") + + def test_writer_has_write_in_own_workspace(self): + assert check(writer_in("default"), "graph:write", "default") + + def test_writer_cannot_act_in_other_workspace(self): + assert not check(writer_in("default"), "graph:write", "acme") + + def test_admin_has_everything_everywhere(self): + for cap in ("graph:read", "graph:write", "config:write", + "users:admin", "metrics:read"): + assert check(admin_in("default"), cap, "acme"), ( + f"admin should have {cap} in acme" + ) + + def test_admin_has_caps_without_explicit_workspace(self): + assert check(admin_in("default"), "users:admin") + + def test_default_target_is_identity_workspace(self): + # Reader with no target workspace → should check against own + assert check(reader_in("default"), "graph:read") + + def test_unknown_capability_returns_false(self): + assert not check(admin_in("default"), "nonsense:cap", "default") + + def test_unknown_role_contributes_nothing(self): + ident = _Identity("default", ["made-up-role"]) + assert not check(ident, "graph:read", "default") + + def test_multi_role_union(self): + # If a user is both reader and admin, they inherit admin's + # cross-workspace powers. + ident = _Identity("default", ["reader", "admin"]) + assert check(ident, "users:admin", "acme") + + +# -- enforce_workspace() --------------------------------------------------- + + +class TestEnforceWorkspace: + + def test_reader_in_own_workspace_allowed(self): + data = {"workspace": "default", "operation": "x"} + enforce_workspace(data, reader_in("default")) + assert data["workspace"] == "default" + + def test_reader_no_workspace_injects_assigned(self): + data = {"operation": "x"} + enforce_workspace(data, reader_in("default")) + assert data["workspace"] == "default" + + def test_reader_mismatched_workspace_denied(self): + data = {"workspace": "acme", "operation": "x"} + with pytest.raises(web.HTTPForbidden): + enforce_workspace(data, reader_in("default")) + + def test_admin_can_target_any_workspace(self): + data = {"workspace": "acme", "operation": "x"} + enforce_workspace(data, admin_in("default")) + assert data["workspace"] == "acme" + + def test_admin_no_workspace_defaults_to_assigned(self): + data = {"operation": "x"} + enforce_workspace(data, admin_in("default")) + assert data["workspace"] == "default" + + def test_writer_same_workspace_specified_allowed(self): + data = {"workspace": "default"} + enforce_workspace(data, writer_in("default")) + assert data["workspace"] == "default" + + def test_non_dict_passthrough(self): + # Non-dict bodies are returned unchanged (e.g. streaming). + result = enforce_workspace("not-a-dict", reader_in("default")) + assert result == "not-a-dict" + + def test_with_capability_tightens_check(self): + # Reader lacks graph:write; workspace-only check would pass + # (scope is fine), but combined check must reject. + data = {"workspace": "default"} + with pytest.raises(web.HTTPForbidden): + enforce_workspace( + data, reader_in("default"), capability="graph:write", + ) + + def test_with_capability_passes_when_granted(self): + data = {"workspace": "default"} + enforce_workspace( + data, reader_in("default"), capability="graph:read", + ) + assert data["workspace"] == "default" + + +# -- helpers --------------------------------------------------------------- + + +class TestResponseHelpers: + + def test_auth_failure_is_401(self): + exc = auth_failure() + assert exc.status == 401 + assert "auth failure" in exc.text + + def test_access_denied_is_403(self): + exc = access_denied() + assert exc.status == 403 + assert "access denied" in exc.text + + +class TestSentinels: + + def test_public_and_authenticated_are_distinct(self): + assert PUBLIC != AUTHENTICATED + assert PUBLIC not in KNOWN_CAPABILITIES + assert AUTHENTICATED not in KNOWN_CAPABILITIES diff --git a/tests/unit/test_gateway/test_dispatch_manager.py b/tests/unit/test_gateway/test_dispatch_manager.py index f091a46d..e399d712 100644 --- a/tests/unit/test_gateway/test_dispatch_manager.py +++ b/tests/unit/test_gateway/test_dispatch_manager.py @@ -42,7 +42,7 @@ class TestDispatcherManager: mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) assert manager.backend == mock_backend assert manager.config_receiver == mock_config_receiver @@ -59,7 +59,10 @@ class TestDispatcherManager: mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver, prefix="custom-prefix") + manager = DispatcherManager( + mock_backend, mock_config_receiver, + auth=Mock(), prefix="custom-prefix", + ) assert manager.prefix == "custom-prefix" @@ -68,7 +71,7 @@ class TestDispatcherManager: """Test start_flow method""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) flow_data = {"name": "test_flow", "steps": []} @@ -82,7 +85,7 @@ class TestDispatcherManager: """Test stop_flow method""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) # Pre-populate with a flow flow_data = {"name": "test_flow", "steps": []} @@ -96,7 +99,7 @@ class TestDispatcherManager: """Test dispatch_global_service returns DispatcherWrapper""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) wrapper = manager.dispatch_global_service() @@ -107,7 +110,7 @@ class TestDispatcherManager: """Test dispatch_core_export returns DispatcherWrapper""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) wrapper = manager.dispatch_core_export() @@ -118,7 +121,7 @@ class TestDispatcherManager: """Test dispatch_core_import returns DispatcherWrapper""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) wrapper = manager.dispatch_core_import() @@ -130,7 +133,7 @@ class TestDispatcherManager: """Test process_core_import method""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) with patch('trustgraph.gateway.dispatch.manager.CoreImport') as mock_core_import: mock_importer = Mock() @@ -148,7 +151,7 @@ class TestDispatcherManager: """Test process_core_export method""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) with patch('trustgraph.gateway.dispatch.manager.CoreExport') as mock_core_export: mock_exporter = Mock() @@ -166,7 +169,7 @@ class TestDispatcherManager: """Test process_global_service method""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) manager.invoke_global_service = AsyncMock(return_value="global_result") @@ -181,7 +184,7 @@ class TestDispatcherManager: """Test invoke_global_service with existing dispatcher""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) # Pre-populate with existing dispatcher mock_dispatcher = Mock() @@ -198,7 +201,7 @@ class TestDispatcherManager: """Test invoke_global_service creates new dispatcher""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) with patch('trustgraph.gateway.dispatch.manager.global_dispatchers') as mock_dispatchers: mock_dispatcher_class = Mock() @@ -230,7 +233,7 @@ class TestDispatcherManager: """Test dispatch_flow_import returns correct method""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) result = manager.dispatch_flow_import() @@ -240,7 +243,7 @@ class TestDispatcherManager: """Test dispatch_flow_export returns correct method""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) result = manager.dispatch_flow_export() @@ -250,7 +253,7 @@ class TestDispatcherManager: """Test dispatch_socket returns correct method""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) result = manager.dispatch_socket() @@ -260,7 +263,7 @@ class TestDispatcherManager: """Test dispatch_flow_service returns DispatcherWrapper""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) wrapper = manager.dispatch_flow_service() @@ -272,7 +275,7 @@ class TestDispatcherManager: """Test process_flow_import with valid flow and kind""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) # Setup test flow manager.flows[("default", "test_flow")] = { @@ -308,7 +311,7 @@ class TestDispatcherManager: """Test process_flow_import with invalid flow""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) params = {"flow": "invalid_flow", "kind": "triples"} @@ -323,7 +326,7 @@ class TestDispatcherManager: warnings.simplefilter("ignore", RuntimeWarning) mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) # Setup test flow manager.flows[("default", "test_flow")] = { @@ -345,7 +348,7 @@ class TestDispatcherManager: """Test process_flow_export with valid flow and kind""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) # Setup test flow manager.flows[("default", "test_flow")] = { @@ -378,26 +381,47 @@ class TestDispatcherManager: @pytest.mark.asyncio async def test_process_socket(self): - """Test process_socket method""" + """process_socket constructs a Mux with the manager's auth + instance passed through — this is the gateway's trust path + for first-frame WebSocket authentication. A Mux cannot be + built without auth (tested separately); this test pins that + the dispatcher-manager threads the correct auth value into + the Mux constructor call.""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) - + mock_auth = Mock() + manager = DispatcherManager( + mock_backend, mock_config_receiver, auth=mock_auth, + ) + with patch('trustgraph.gateway.dispatch.manager.Mux') as mock_mux: mock_mux_instance = Mock() mock_mux.return_value = mock_mux_instance - + result = await manager.process_socket("ws", "running", {}) - - mock_mux.assert_called_once_with(manager, "ws", "running") + + mock_mux.assert_called_once_with( + manager, "ws", "running", auth=mock_auth, + ) assert result == mock_mux_instance + def test_dispatcher_manager_requires_auth(self): + """Constructing a DispatcherManager without an auth argument + must fail — a no-auth DispatcherManager would produce a + Mux without authentication, silently downgrading the socket + auth path.""" + mock_backend = Mock() + mock_config_receiver = Mock() + + with pytest.raises(ValueError, match="auth"): + DispatcherManager(mock_backend, mock_config_receiver, auth=None) + @pytest.mark.asyncio async def test_process_flow_service(self): """Test process_flow_service method""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) manager.invoke_flow_service = AsyncMock(return_value="flow_result") @@ -412,7 +436,7 @@ class TestDispatcherManager: """Test invoke_flow_service with existing dispatcher""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) # Add flow to the flows dictionary manager.flows[("default", "test_flow")] = {"services": {"agent": {}}} @@ -432,7 +456,7 @@ class TestDispatcherManager: """Test invoke_flow_service creates request-response dispatcher""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) # Setup test flow manager.flows[("default", "test_flow")] = { @@ -476,7 +500,7 @@ class TestDispatcherManager: """Test invoke_flow_service creates sender dispatcher""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) # Setup test flow manager.flows[("default", "test_flow")] = { @@ -516,7 +540,7 @@ class TestDispatcherManager: """Test invoke_flow_service with invalid flow""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) with pytest.raises(RuntimeError, match="Invalid flow"): await manager.invoke_flow_service("data", "responder", "default", "invalid_flow", "agent") @@ -526,7 +550,7 @@ class TestDispatcherManager: """Test invoke_flow_service with kind not supported by flow""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) # Setup test flow without agent interface manager.flows[("default", "test_flow")] = { @@ -543,7 +567,7 @@ class TestDispatcherManager: """Test invoke_flow_service with invalid kind""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) # Setup test flow with interface but unsupported kind manager.flows[("default", "test_flow")] = { @@ -570,7 +594,7 @@ class TestDispatcherManager: """ mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) async def slow_start(): # Yield to the event loop so other coroutines get a chance to run, @@ -606,7 +630,7 @@ class TestDispatcherManager: """ mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) manager.flows[("default", "test_flow")] = { "interfaces": { diff --git a/tests/unit/test_gateway/test_dispatch_mux.py b/tests/unit/test_gateway/test_dispatch_mux.py index a0bc9460..c1baa920 100644 --- a/tests/unit/test_gateway/test_dispatch_mux.py +++ b/tests/unit/test_gateway/test_dispatch_mux.py @@ -12,6 +12,19 @@ from trustgraph.gateway.dispatch.mux import Mux, MAX_QUEUE_SIZE class TestMux: """Test cases for Mux class""" + def test_mux_requires_auth(self): + """Constructing a Mux without an ``auth`` argument must + fail. The Mux implements the first-frame auth protocol and + there is no no-auth mode — a no-auth Mux would silently + accept every frame without authenticating it.""" + with pytest.raises(ValueError, match="auth"): + Mux( + dispatcher_manager=MagicMock(), + ws=MagicMock(), + running=MagicMock(), + auth=None, + ) + def test_mux_initialization(self): """Test Mux initialization""" mock_dispatcher_manager = MagicMock() @@ -21,7 +34,8 @@ class TestMux: mux = Mux( dispatcher_manager=mock_dispatcher_manager, ws=mock_ws, - running=mock_running + running=mock_running, + auth=MagicMock(), ) assert mux.dispatcher_manager == mock_dispatcher_manager @@ -40,7 +54,8 @@ class TestMux: mux = Mux( dispatcher_manager=mock_dispatcher_manager, ws=mock_ws, - running=mock_running + running=mock_running, + auth=MagicMock(), ) # Call destroy @@ -61,7 +76,8 @@ class TestMux: mux = Mux( dispatcher_manager=mock_dispatcher_manager, ws=None, - running=mock_running + running=mock_running, + auth=MagicMock(), ) # Call destroy @@ -81,7 +97,8 @@ class TestMux: mux = Mux( dispatcher_manager=mock_dispatcher_manager, ws=mock_ws, - running=mock_running + running=mock_running, + auth=MagicMock(), ) # Mock message with valid JSON @@ -108,7 +125,8 @@ class TestMux: mux = Mux( dispatcher_manager=mock_dispatcher_manager, ws=mock_ws, - running=mock_running + running=mock_running, + auth=MagicMock(), ) # Mock message without request field @@ -137,7 +155,8 @@ class TestMux: mux = Mux( dispatcher_manager=mock_dispatcher_manager, ws=mock_ws, - running=mock_running + running=mock_running, + auth=MagicMock(), ) # Mock message without id field @@ -164,7 +183,8 @@ class TestMux: mux = Mux( dispatcher_manager=mock_dispatcher_manager, ws=mock_ws, - running=mock_running + running=mock_running, + auth=MagicMock(), ) # Mock message with invalid JSON diff --git a/tests/unit/test_gateway/test_endpoint_constant.py b/tests/unit/test_gateway/test_endpoint_constant.py index f208c967..98588e55 100644 --- a/tests/unit/test_gateway/test_endpoint_constant.py +++ b/tests/unit/test_gateway/test_endpoint_constant.py @@ -13,29 +13,36 @@ class TestConstantEndpoint: """Test cases for ConstantEndpoint class""" def test_constant_endpoint_initialization(self): - """Test ConstantEndpoint initialization""" + """Construction records the configured capability on the + instance. The capability is a required argument — no + permissive default — and the test passes an explicit + value to demonstrate the contract.""" mock_auth = MagicMock() mock_dispatcher = MagicMock() - + endpoint = ConstantEndpoint( endpoint_path="/api/test", auth=mock_auth, - dispatcher=mock_dispatcher + dispatcher=mock_dispatcher, + capability="config:read", ) - + assert endpoint.path == "/api/test" assert endpoint.auth == mock_auth assert endpoint.dispatcher == mock_dispatcher - assert endpoint.operation == "service" + assert endpoint.capability == "config:read" @pytest.mark.asyncio async def test_constant_endpoint_start_method(self): """Test ConstantEndpoint start method (should be no-op)""" mock_auth = MagicMock() mock_dispatcher = MagicMock() - - endpoint = ConstantEndpoint("/api/test", mock_auth, mock_dispatcher) - + + endpoint = ConstantEndpoint( + "/api/test", mock_auth, mock_dispatcher, + capability="config:read", + ) + # start() should complete without error await endpoint.start() @@ -44,10 +51,13 @@ class TestConstantEndpoint: mock_auth = MagicMock() mock_dispatcher = MagicMock() mock_app = MagicMock() - - endpoint = ConstantEndpoint("/api/test", mock_auth, mock_dispatcher) + + endpoint = ConstantEndpoint( + "/api/test", mock_auth, mock_dispatcher, + capability="config:read", + ) endpoint.add_routes(mock_app) - + # Verify add_routes was called with POST route mock_app.add_routes.assert_called_once() # The call should include web.post with the path and handler diff --git a/tests/unit/test_gateway/test_endpoint_i18n.py b/tests/unit/test_gateway/test_endpoint_i18n.py index ab693cdf..c2b51568 100644 --- a/tests/unit/test_gateway/test_endpoint_i18n.py +++ b/tests/unit/test_gateway/test_endpoint_i18n.py @@ -1,4 +1,12 @@ -"""Tests for Gateway i18n pack endpoint.""" +"""Tests for Gateway i18n pack endpoint. + +Production registers this endpoint with ``capability=PUBLIC``: the +login UI needs to render its own i18n strings before any user has +authenticated, so the endpoint is deliberately pre-auth. These +tests exercise the PUBLIC configuration — that is the production +contract. Behaviour of authenticated endpoints is covered by the +IamAuth tests in ``test_auth.py``. +""" import json from unittest.mock import MagicMock @@ -7,6 +15,7 @@ import pytest from aiohttp import web from trustgraph.gateway.endpoint.i18n import I18nPackEndpoint +from trustgraph.gateway.capabilities import PUBLIC class TestI18nPackEndpoint: @@ -17,23 +26,28 @@ class TestI18nPackEndpoint: endpoint = I18nPackEndpoint( endpoint_path="/api/v1/i18n/packs/{lang}", auth=mock_auth, + capability=PUBLIC, ) assert endpoint.path == "/api/v1/i18n/packs/{lang}" assert endpoint.auth == mock_auth - assert endpoint.operation == "service" + assert endpoint.capability == PUBLIC @pytest.mark.asyncio async def test_i18n_endpoint_start_method(self): mock_auth = MagicMock() - endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth) + endpoint = I18nPackEndpoint( + "/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC, + ) await endpoint.start() def test_add_routes_registers_get_handler(self): mock_auth = MagicMock() mock_app = MagicMock() - endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth) + endpoint = I18nPackEndpoint( + "/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC, + ) endpoint.add_routes(mock_app) mock_app.add_routes.assert_called_once() @@ -41,35 +55,55 @@ class TestI18nPackEndpoint: assert len(call_args) == 1 @pytest.mark.asyncio - async def test_handle_unauthorized_on_invalid_auth_scheme(self): + async def test_handle_returns_pack_without_authenticating(self): + """The PUBLIC endpoint serves the language pack without + invoking the auth handler at all — pre-login UI must be + reachable. The test uses an auth mock that raises if + touched, so any auth attempt by the endpoint is caught.""" mock_auth = MagicMock() - mock_auth.permitted.return_value = True - endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth) + def _should_not_be_called(*args, **kwargs): + raise AssertionError( + "PUBLIC endpoint must not invoke auth.authenticate" + ) + mock_auth.authenticate = _should_not_be_called + + endpoint = I18nPackEndpoint( + "/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC, + ) request = MagicMock() request.path = "/api/v1/i18n/packs/en" + # A caller-supplied Authorization header of any form should + # be ignored — PUBLIC means we don't look at it. request.headers = {"Authorization": "Token abc"} request.match_info = {"lang": "en"} - resp = await endpoint.handle(request) - assert isinstance(resp, web.HTTPUnauthorized) - - @pytest.mark.asyncio - async def test_handle_returns_pack_when_permitted(self): - mock_auth = MagicMock() - mock_auth.permitted.return_value = True - - endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth) - - request = MagicMock() - request.path = "/api/v1/i18n/packs/en" - request.headers = {} - request.match_info = {"lang": "en"} - resp = await endpoint.handle(request) assert resp.status == 200 payload = json.loads(resp.body.decode("utf-8")) assert isinstance(payload, dict) assert "cli.verify_system_status.title" in payload + + @pytest.mark.asyncio + async def test_handle_rejects_path_traversal(self): + """The ``lang`` path parameter is reflected through to the + filesystem-backed pack loader. The endpoint contains an + explicit defense against ``/`` and ``..`` in the value; this + test pins that defense in place.""" + mock_auth = MagicMock() + endpoint = I18nPackEndpoint( + "/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC, + ) + + for bad in ("../../etc/passwd", "en/../fr", "a/b"): + request = MagicMock() + request.path = f"/api/v1/i18n/packs/{bad}" + request.headers = {} + request.match_info = {"lang": bad} + + resp = await endpoint.handle(request) + assert isinstance(resp, web.HTTPBadRequest), ( + f"path-traversal defense did not reject lang={bad!r}" + ) diff --git a/tests/unit/test_gateway/test_endpoint_manager.py b/tests/unit/test_gateway/test_endpoint_manager.py index 4766f8d7..cf12565c 100644 --- a/tests/unit/test_gateway/test_endpoint_manager.py +++ b/tests/unit/test_gateway/test_endpoint_manager.py @@ -12,30 +12,24 @@ class TestEndpointManager: """Test cases for EndpointManager class""" def test_endpoint_manager_initialization(self): - """Test EndpointManager initialization creates all endpoints""" + """EndpointManager wires up the full endpoint set and + records dispatcher_manager / timeout on the instance.""" mock_dispatcher_manager = MagicMock() mock_auth = MagicMock() - - # Mock dispatcher methods - mock_dispatcher_manager.dispatch_global_service.return_value = MagicMock() - mock_dispatcher_manager.dispatch_socket.return_value = MagicMock() - mock_dispatcher_manager.dispatch_flow_service.return_value = MagicMock() - mock_dispatcher_manager.dispatch_flow_import.return_value = MagicMock() - mock_dispatcher_manager.dispatch_flow_export.return_value = MagicMock() - mock_dispatcher_manager.dispatch_core_import.return_value = MagicMock() - mock_dispatcher_manager.dispatch_core_export.return_value = MagicMock() - + + # The dispatcher_manager exposes a small set of factory + # methods — MagicMock auto-creates them, returning fresh + # MagicMocks on each call. manager = EndpointManager( dispatcher_manager=mock_dispatcher_manager, auth=mock_auth, prometheus_url="http://prometheus:9090", - timeout=300 + timeout=300, ) - + assert manager.dispatcher_manager == mock_dispatcher_manager assert manager.timeout == 300 - assert manager.services == {} - assert len(manager.endpoints) > 0 # Should have multiple endpoints + assert len(manager.endpoints) > 0 def test_endpoint_manager_with_default_timeout(self): """Test EndpointManager with default timeout value""" @@ -79,9 +73,15 @@ class TestEndpointManager: prometheus_url="http://test:9090" ) - # Verify all dispatcher methods were called during initialization + # Each dispatcher factory is invoked exactly once during + # construction — one per endpoint that needs a dedicated + # wire. dispatch_auth_iam is the dedicated factory for the + # AuthEndpoints forwarder (login / bootstrap / + # change-password), distinct from dispatch_global_service + # (the generic /api/v1/{kind} route). mock_dispatcher_manager.dispatch_global_service.assert_called_once() - mock_dispatcher_manager.dispatch_socket.assert_called() # Called twice + mock_dispatcher_manager.dispatch_auth_iam.assert_called_once() + mock_dispatcher_manager.dispatch_socket.assert_called_once() mock_dispatcher_manager.dispatch_flow_service.assert_called_once() mock_dispatcher_manager.dispatch_flow_import.assert_called_once() mock_dispatcher_manager.dispatch_flow_export.assert_called_once() diff --git a/tests/unit/test_gateway/test_endpoint_metrics.py b/tests/unit/test_gateway/test_endpoint_metrics.py index bacf551d..6d911bbd 100644 --- a/tests/unit/test_gateway/test_endpoint_metrics.py +++ b/tests/unit/test_gateway/test_endpoint_metrics.py @@ -12,31 +12,35 @@ class TestMetricsEndpoint: """Test cases for MetricsEndpoint class""" def test_metrics_endpoint_initialization(self): - """Test MetricsEndpoint initialization""" + """Construction records the configured capability on the + instance. In production MetricsEndpoint is gated by + 'metrics:read' so that's the natural value to pass.""" mock_auth = MagicMock() - + endpoint = MetricsEndpoint( prometheus_url="http://prometheus:9090", endpoint_path="/metrics", - auth=mock_auth + auth=mock_auth, + capability="metrics:read", ) - + assert endpoint.prometheus_url == "http://prometheus:9090" assert endpoint.path == "/metrics" assert endpoint.auth == mock_auth - assert endpoint.operation == "service" + assert endpoint.capability == "metrics:read" @pytest.mark.asyncio async def test_metrics_endpoint_start_method(self): """Test MetricsEndpoint start method (should be no-op)""" mock_auth = MagicMock() - + endpoint = MetricsEndpoint( prometheus_url="http://localhost:9090", endpoint_path="/metrics", - auth=mock_auth + auth=mock_auth, + capability="metrics:read", ) - + # start() should complete without error await endpoint.start() @@ -44,15 +48,16 @@ class TestMetricsEndpoint: """Test add_routes method registers GET route with wildcard path""" mock_auth = MagicMock() mock_app = MagicMock() - + endpoint = MetricsEndpoint( prometheus_url="http://prometheus:9090", endpoint_path="/metrics", - auth=mock_auth + auth=mock_auth, + capability="metrics:read", ) - + endpoint.add_routes(mock_app) - + # Verify add_routes was called with GET route mock_app.add_routes.assert_called_once() # The call should include web.get with wildcard path pattern diff --git a/tests/unit/test_gateway/test_endpoint_socket.py b/tests/unit/test_gateway/test_endpoint_socket.py index 83eb38c2..189bc32b 100644 --- a/tests/unit/test_gateway/test_endpoint_socket.py +++ b/tests/unit/test_gateway/test_endpoint_socket.py @@ -1,5 +1,12 @@ """ -Tests for Gateway Socket Endpoint +Tests for Gateway Socket Endpoint. + +In production the only SocketEndpoint registered with HTTP-layer +auth is ``/api/v1/socket`` using ``capability=AUTHENTICATED`` with +``in_band_auth=True`` (first-frame auth over the websocket frames, +not at the handshake). The tests below use AUTHENTICATED as the +representative capability; construction / worker / listener +behaviour is independent of which capability is configured. """ import pytest @@ -7,41 +14,47 @@ from unittest.mock import MagicMock, AsyncMock from aiohttp import WSMsgType from trustgraph.gateway.endpoint.socket import SocketEndpoint +from trustgraph.gateway.capabilities import AUTHENTICATED class TestSocketEndpoint: """Test cases for SocketEndpoint class""" def test_socket_endpoint_initialization(self): - """Test SocketEndpoint initialization""" + """Construction records the configured capability on the + instance. No permissive default is applied.""" mock_auth = MagicMock() mock_dispatcher = MagicMock() - + endpoint = SocketEndpoint( endpoint_path="/api/socket", auth=mock_auth, - dispatcher=mock_dispatcher + dispatcher=mock_dispatcher, + capability=AUTHENTICATED, ) - + assert endpoint.path == "/api/socket" assert endpoint.auth == mock_auth assert endpoint.dispatcher == mock_dispatcher - assert endpoint.operation == "socket" + assert endpoint.capability == AUTHENTICATED @pytest.mark.asyncio async def test_worker_method(self): """Test SocketEndpoint worker method""" mock_auth = MagicMock() mock_dispatcher = AsyncMock() - - endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher) - + + endpoint = SocketEndpoint( + "/api/socket", mock_auth, mock_dispatcher, + capability=AUTHENTICATED, + ) + mock_ws = MagicMock() mock_running = MagicMock() - + # Call worker method await endpoint.worker(mock_ws, mock_dispatcher, mock_running) - + # Verify dispatcher.run was called mock_dispatcher.run.assert_called_once() @@ -50,8 +63,11 @@ class TestSocketEndpoint: """Test SocketEndpoint listener method with text message""" mock_auth = MagicMock() mock_dispatcher = AsyncMock() - - endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher) + + endpoint = SocketEndpoint( + "/api/socket", mock_auth, mock_dispatcher, + capability=AUTHENTICATED, + ) # Mock websocket with text message mock_msg = MagicMock() @@ -80,8 +96,11 @@ class TestSocketEndpoint: """Test SocketEndpoint listener method with binary message""" mock_auth = MagicMock() mock_dispatcher = AsyncMock() - - endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher) + + endpoint = SocketEndpoint( + "/api/socket", mock_auth, mock_dispatcher, + capability=AUTHENTICATED, + ) # Mock websocket with binary message mock_msg = MagicMock() @@ -110,8 +129,11 @@ class TestSocketEndpoint: """Test SocketEndpoint listener method with close message""" mock_auth = MagicMock() mock_dispatcher = AsyncMock() - - endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher) + + endpoint = SocketEndpoint( + "/api/socket", mock_auth, mock_dispatcher, + capability=AUTHENTICATED, + ) # Mock websocket with close message mock_msg = MagicMock() diff --git a/tests/unit/test_gateway/test_endpoint_stream.py b/tests/unit/test_gateway/test_endpoint_stream.py index b99946c8..a3b49465 100644 --- a/tests/unit/test_gateway/test_endpoint_stream.py +++ b/tests/unit/test_gateway/test_endpoint_stream.py @@ -12,48 +12,57 @@ class TestStreamEndpoint: """Test cases for StreamEndpoint class""" def test_stream_endpoint_initialization_with_post(self): - """Test StreamEndpoint initialization with POST method""" + """Construction records the configured capability on the + instance. StreamEndpoint is used in production for the + core-import / core-export / document-stream routes; a + document-write capability is a realistic value for a POST + stream (e.g. core-import).""" mock_auth = MagicMock() mock_dispatcher = MagicMock() - + endpoint = StreamEndpoint( endpoint_path="/api/stream", auth=mock_auth, dispatcher=mock_dispatcher, - method="POST" + capability="documents:write", + method="POST", ) - + assert endpoint.path == "/api/stream" assert endpoint.auth == mock_auth assert endpoint.dispatcher == mock_dispatcher - assert endpoint.operation == "service" + assert endpoint.capability == "documents:write" assert endpoint.method == "POST" def test_stream_endpoint_initialization_with_get(self): - """Test StreamEndpoint initialization with GET method""" + """GET stream — export-style endpoint, read capability.""" mock_auth = MagicMock() mock_dispatcher = MagicMock() - + endpoint = StreamEndpoint( endpoint_path="/api/stream", auth=mock_auth, dispatcher=mock_dispatcher, - method="GET" + capability="documents:read", + method="GET", ) - + assert endpoint.method == "GET" def test_stream_endpoint_initialization_default_method(self): - """Test StreamEndpoint initialization with default POST method""" + """Test StreamEndpoint initialization with default POST method. + The method default is cosmetic; the capability is not + defaulted — it is always required.""" mock_auth = MagicMock() mock_dispatcher = MagicMock() - + endpoint = StreamEndpoint( endpoint_path="/api/stream", auth=mock_auth, - dispatcher=mock_dispatcher + dispatcher=mock_dispatcher, + capability="documents:write", ) - + assert endpoint.method == "POST" # Default value @pytest.mark.asyncio @@ -61,9 +70,12 @@ class TestStreamEndpoint: """Test StreamEndpoint start method (should be no-op)""" mock_auth = MagicMock() mock_dispatcher = MagicMock() - - endpoint = StreamEndpoint("/api/stream", mock_auth, mock_dispatcher) - + + endpoint = StreamEndpoint( + "/api/stream", mock_auth, mock_dispatcher, + capability="documents:write", + ) + # start() should complete without error await endpoint.start() @@ -72,16 +84,17 @@ class TestStreamEndpoint: mock_auth = MagicMock() mock_dispatcher = MagicMock() mock_app = MagicMock() - + endpoint = StreamEndpoint( endpoint_path="/api/stream", auth=mock_auth, dispatcher=mock_dispatcher, - method="POST" + capability="documents:write", + method="POST", ) - + endpoint.add_routes(mock_app) - + # Verify add_routes was called with POST route mock_app.add_routes.assert_called_once() call_args = mock_app.add_routes.call_args[0][0] @@ -92,16 +105,17 @@ class TestStreamEndpoint: mock_auth = MagicMock() mock_dispatcher = MagicMock() mock_app = MagicMock() - + endpoint = StreamEndpoint( endpoint_path="/api/stream", auth=mock_auth, dispatcher=mock_dispatcher, - method="GET" + capability="documents:read", + method="GET", ) - + endpoint.add_routes(mock_app) - + # Verify add_routes was called with GET route mock_app.add_routes.assert_called_once() call_args = mock_app.add_routes.call_args[0][0] @@ -112,13 +126,14 @@ class TestStreamEndpoint: mock_auth = MagicMock() mock_dispatcher = MagicMock() mock_app = MagicMock() - + endpoint = StreamEndpoint( endpoint_path="/api/stream", auth=mock_auth, dispatcher=mock_dispatcher, - method="INVALID" + capability="documents:write", + method="INVALID", ) - + with pytest.raises(RuntimeError, match="Bad method"): endpoint.add_routes(mock_app) \ No newline at end of file diff --git a/tests/unit/test_gateway/test_endpoint_variable.py b/tests/unit/test_gateway/test_endpoint_variable.py index ffaf4e9a..1cdc8f9f 100644 --- a/tests/unit/test_gateway/test_endpoint_variable.py +++ b/tests/unit/test_gateway/test_endpoint_variable.py @@ -12,29 +12,36 @@ class TestVariableEndpoint: """Test cases for VariableEndpoint class""" def test_variable_endpoint_initialization(self): - """Test VariableEndpoint initialization""" + """Construction records the configured capability on the + instance. VariableEndpoint is used in production for the + /api/v1/{kind} admin-scoped global service routes, so a + write-side capability is a realistic value for the test.""" mock_auth = MagicMock() mock_dispatcher = MagicMock() - + endpoint = VariableEndpoint( endpoint_path="/api/variable", auth=mock_auth, - dispatcher=mock_dispatcher + dispatcher=mock_dispatcher, + capability="config:write", ) - + assert endpoint.path == "/api/variable" assert endpoint.auth == mock_auth assert endpoint.dispatcher == mock_dispatcher - assert endpoint.operation == "service" + assert endpoint.capability == "config:write" @pytest.mark.asyncio async def test_variable_endpoint_start_method(self): """Test VariableEndpoint start method (should be no-op)""" mock_auth = MagicMock() mock_dispatcher = MagicMock() - - endpoint = VariableEndpoint("/api/var", mock_auth, mock_dispatcher) - + + endpoint = VariableEndpoint( + "/api/var", mock_auth, mock_dispatcher, + capability="config:write", + ) + # start() should complete without error await endpoint.start() @@ -43,10 +50,13 @@ class TestVariableEndpoint: mock_auth = MagicMock() mock_dispatcher = MagicMock() mock_app = MagicMock() - - endpoint = VariableEndpoint("/api/variable", mock_auth, mock_dispatcher) + + endpoint = VariableEndpoint( + "/api/variable", mock_auth, mock_dispatcher, + capability="config:write", + ) endpoint.add_routes(mock_app) - + # Verify add_routes was called with POST route mock_app.add_routes.assert_called_once() call_args = mock_app.add_routes.call_args[0][0] diff --git a/tests/unit/test_gateway/test_service.py b/tests/unit/test_gateway/test_service.py index 71428db4..107e6819 100644 --- a/tests/unit/test_gateway/test_service.py +++ b/tests/unit/test_gateway/test_service.py @@ -1,355 +1,179 @@ """ -Tests for Gateway Service API +Tests for gateway/service.py — the Api class that wires together +the pub/sub backend, IAM auth, config receiver, dispatcher manager, +and endpoint manager. + +The legacy ``GATEWAY_SECRET`` / ``default_api_token`` / allow-all +surface is gone, so the tests here focus on the Api's construction +and composition rather than the removed auth behaviour. IamAuth's +own behaviour is covered in test_auth.py. """ import pytest -import asyncio -from unittest.mock import Mock, patch, MagicMock, AsyncMock +from unittest.mock import AsyncMock, Mock, patch from aiohttp import web -import pulsar -from trustgraph.gateway.service import Api, run, default_pulsar_host, default_prometheus_url, default_timeout, default_port, default_api_token - -# Tests for Gateway Service API +from trustgraph.gateway.service import ( + Api, + default_pulsar_host, default_prometheus_url, + default_timeout, default_port, +) +from trustgraph.gateway.auth import IamAuth -class TestApi: - """Test cases for Api class""" - +# -- constants ------------------------------------------------------------- - def test_api_initialization_with_defaults(self): - """Test Api initialization with default values""" - with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub: - mock_backend = Mock() - mock_get_pubsub.return_value = mock_backend - api = Api() +class TestDefaults: - assert api.port == default_port - assert api.timeout == default_timeout - assert api.pulsar_host == default_pulsar_host - assert api.pulsar_api_key is None - assert api.prometheus_url == default_prometheus_url + "/" - assert api.auth.allow_all is True + def test_exports_default_constants(self): + # These are consumed by CLIs / tests / docs. Sanity-check + # that they're the expected shape. + assert default_port == 8088 + assert default_timeout == 600 + assert default_pulsar_host.startswith("pulsar://") + assert default_prometheus_url.startswith("http") - # Verify get_pubsub was called - mock_get_pubsub.assert_called_once() - def test_api_initialization_with_custom_config(self): - """Test Api initialization with custom configuration""" +# -- Api construction ------------------------------------------------------ + + +@pytest.fixture +def mock_backend(): + return Mock() + + +@pytest.fixture +def api(mock_backend): + with patch( + "trustgraph.gateway.service.get_pubsub", + return_value=mock_backend, + ): + yield Api() + + +class TestApiConstruction: + + def test_defaults(self, api): + assert api.port == default_port + assert api.timeout == default_timeout + assert api.pulsar_host == default_pulsar_host + assert api.pulsar_api_key is None + # prometheus_url gets normalised with a trailing slash + assert api.prometheus_url == default_prometheus_url + "/" + + def test_auth_is_iam_backed(self, api): + # Any Api always gets an IamAuth. There is no "no auth" mode + # (GATEWAY_SECRET / allow_all has been removed — see IAM spec). + assert isinstance(api.auth, IamAuth) + + def test_components_wired(self, api): + assert api.config_receiver is not None + assert api.dispatcher_manager is not None + assert api.endpoint_manager is not None + + def test_dispatcher_manager_has_auth(self, api): + # The Mux uses this handle for first-frame socket auth. + assert api.dispatcher_manager.auth is api.auth + + def test_custom_config(self, mock_backend): config = { "port": 9000, "timeout": 300, "pulsar_host": "pulsar://custom-host:6650", - "pulsar_api_key": "test-api-key", - "pulsar_listener": "custom-listener", + "pulsar_api_key": "custom-key", "prometheus_url": "http://custom-prometheus:9090", - "api_token": "secret-token" } + with patch( + "trustgraph.gateway.service.get_pubsub", + return_value=mock_backend, + ): + a = Api(**config) - with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub: - mock_backend = Mock() - mock_get_pubsub.return_value = mock_backend + assert a.port == 9000 + assert a.timeout == 300 + assert a.pulsar_host == "pulsar://custom-host:6650" + assert a.pulsar_api_key == "custom-key" + # Trailing slash added. + assert a.prometheus_url == "http://custom-prometheus:9090/" - api = Api(**config) + def test_prometheus_url_already_has_trailing_slash(self, mock_backend): + with patch( + "trustgraph.gateway.service.get_pubsub", + return_value=mock_backend, + ): + a = Api(prometheus_url="http://p:9090/") + assert a.prometheus_url == "http://p:9090/" - assert api.port == 9000 - assert api.timeout == 300 - assert api.pulsar_host == "pulsar://custom-host:6650" - assert api.pulsar_api_key == "test-api-key" - assert api.prometheus_url == "http://custom-prometheus:9090/" - assert api.auth.token == "secret-token" - assert api.auth.allow_all is False + def test_queue_overrides_parsed_for_config(self, mock_backend): + with patch( + "trustgraph.gateway.service.get_pubsub", + return_value=mock_backend, + ): + a = Api( + config_request_queue="alt-config-req", + config_response_queue="alt-config-resp", + ) + overrides = a.dispatcher_manager.queue_overrides + assert overrides.get("config", {}).get("request") == "alt-config-req" + assert overrides.get("config", {}).get("response") == "alt-config-resp" - # Verify get_pubsub was called with config - mock_get_pubsub.assert_called_once_with(**config) - def test_api_initialization_with_pulsar_api_key(self): - """Test Api initialization with Pulsar API key authentication""" - with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub: - mock_get_pubsub.return_value = Mock() +# -- app_factory ----------------------------------------------------------- - api = Api(pulsar_api_key="test-key") - # Verify api key was stored - assert api.pulsar_api_key == "test-key" - mock_get_pubsub.assert_called_once() - - def test_api_initialization_prometheus_url_normalization(self): - """Test that prometheus_url gets normalized with trailing slash""" - with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub: - mock_get_pubsub.return_value = Mock() - - # Test URL without trailing slash - api = Api(prometheus_url="http://prometheus:9090") - assert api.prometheus_url == "http://prometheus:9090/" - - # Test URL with trailing slash - api = Api(prometheus_url="http://prometheus:9090/") - assert api.prometheus_url == "http://prometheus:9090/" - - def test_api_initialization_empty_api_token_means_no_auth(self): - """Test that empty API token results in allow_all authentication""" - with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub: - mock_get_pubsub.return_value = Mock() - - api = Api(api_token="") - assert api.auth.allow_all is True - - def test_api_initialization_none_api_token_means_no_auth(self): - """Test that None API token results in allow_all authentication""" - with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub: - mock_get_pubsub.return_value = Mock() - - api = Api(api_token=None) - assert api.auth.allow_all is True +class TestAppFactory: @pytest.mark.asyncio - async def test_app_factory_creates_application(self): - """Test that app_factory creates aiohttp application""" - with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub: - mock_get_pubsub.return_value = Mock() - - api = Api() - - # Mock the dependencies - api.config_receiver = Mock() - api.config_receiver.start = AsyncMock() - api.endpoint_manager = Mock() - api.endpoint_manager.add_routes = Mock() - api.endpoint_manager.start = AsyncMock() - - app = await api.app_factory() - - assert isinstance(app, web.Application) - assert app._client_max_size == 256 * 1024 * 1024 - - # Verify that config receiver was started - api.config_receiver.start.assert_called_once() - - # Verify that endpoint manager was configured - api.endpoint_manager.add_routes.assert_called_once_with(app) - api.endpoint_manager.start.assert_called_once() + async def test_creates_aiohttp_app(self, api): + # Stub out the long-tail dependencies that reach out to IAM / + # pub/sub so we can exercise the factory in isolation. + api.auth.start = AsyncMock() + api.config_receiver = Mock() + api.config_receiver.start = AsyncMock() + api.endpoint_manager = Mock() + api.endpoint_manager.add_routes = Mock() + api.endpoint_manager.start = AsyncMock() + api.endpoints = [] + + app = await api.app_factory() + + assert isinstance(app, web.Application) + assert app._client_max_size == 256 * 1024 * 1024 + api.auth.start.assert_called_once() + api.config_receiver.start.assert_called_once() + api.endpoint_manager.add_routes.assert_called_once_with(app) + api.endpoint_manager.start.assert_called_once() @pytest.mark.asyncio - async def test_app_factory_with_custom_endpoints(self): - """Test app_factory with custom endpoints""" - with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub: - mock_get_pubsub.return_value = Mock() - - api = Api() - - # Mock custom endpoints - mock_endpoint1 = Mock() - mock_endpoint1.add_routes = Mock() - mock_endpoint1.start = AsyncMock() - - mock_endpoint2 = Mock() - mock_endpoint2.add_routes = Mock() - mock_endpoint2.start = AsyncMock() - - api.endpoints = [mock_endpoint1, mock_endpoint2] - - # Mock the dependencies - api.config_receiver = Mock() - api.config_receiver.start = AsyncMock() - api.endpoint_manager = Mock() - api.endpoint_manager.add_routes = Mock() - api.endpoint_manager.start = AsyncMock() - - app = await api.app_factory() - - # Verify custom endpoints were configured - mock_endpoint1.add_routes.assert_called_once_with(app) - mock_endpoint1.start.assert_called_once() - mock_endpoint2.add_routes.assert_called_once_with(app) - mock_endpoint2.start.assert_called_once() + async def test_auth_start_runs_before_accepting_traffic(self, api): + """``auth.start()`` fetches the IAM signing key, and must + complete (or time out) before the gateway begins accepting + requests. It's the first await in app_factory.""" + order = [] - def test_run_method_calls_web_run_app(self): - """Test that run method calls web.run_app""" - with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub, \ - patch('aiohttp.web.run_app') as mock_run_app: - mock_get_pubsub.return_value = Mock() + # AsyncMock.side_effect expects a sync callable (its return + # value becomes the coroutine's return); a plain list.append + # avoids the "coroutine was never awaited" trap of an async + # side_effect. + api.auth.start = AsyncMock( + side_effect=lambda: order.append("auth"), + ) + api.config_receiver = Mock() + api.config_receiver.start = AsyncMock( + side_effect=lambda: order.append("config"), + ) + api.endpoint_manager = Mock() + api.endpoint_manager.add_routes = Mock() + api.endpoint_manager.start = AsyncMock( + side_effect=lambda: order.append("endpoints"), + ) + api.endpoints = [] - # Api.run() passes self.app_factory() — a coroutine — to - # web.run_app, which would normally consume it inside its own - # event loop. Since we mock run_app, close the coroutine here - # so it doesn't leak as an "unawaited coroutine" RuntimeWarning. - def _consume_coro(coro, **kwargs): - coro.close() - mock_run_app.side_effect = _consume_coro + await api.app_factory() - api = Api(port=8080) - api.run() - - # Verify run_app was called once with the correct port - mock_run_app.assert_called_once() - args, kwargs = mock_run_app.call_args - assert len(args) == 1 # Should have one positional arg (the coroutine) - assert kwargs == {'port': 8080} # Should have port keyword arg - - def test_api_components_initialization(self): - """Test that all API components are properly initialized""" - with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub: - mock_get_pubsub.return_value = Mock() - - api = Api() - - # Verify all components are initialized - assert api.config_receiver is not None - assert api.dispatcher_manager is not None - assert api.endpoint_manager is not None - assert api.endpoints == [] - - # Verify component relationships - assert api.dispatcher_manager.backend == api.pubsub_backend - assert api.dispatcher_manager.config_receiver == api.config_receiver - assert api.endpoint_manager.dispatcher_manager == api.dispatcher_manager - # EndpointManager doesn't store auth directly, it passes it to individual endpoints - - -class TestRunFunction: - """Test cases for the run() function""" - - def test_run_function_with_metrics_enabled(self): - """Test run function with metrics enabled""" - import warnings - # Suppress the specific async warning with a broader pattern - warnings.filterwarnings("ignore", message=".*Api.app_factory.*was never awaited", category=RuntimeWarning) - - with patch('argparse.ArgumentParser.parse_args') as mock_parse_args, \ - patch('trustgraph.gateway.service.start_http_server') as mock_start_http_server: - - # Mock command line arguments - mock_args = Mock() - mock_args.metrics = True - mock_args.metrics_port = 8000 - mock_parse_args.return_value = mock_args - - # Create a simple mock instance without any async methods - mock_api_instance = Mock() - mock_api_instance.run = Mock() - - # Create a mock Api class without importing the real one - mock_api = Mock(return_value=mock_api_instance) - - # Patch using context manager to avoid importing the real Api class - with patch('trustgraph.gateway.service.Api', mock_api): - # Mock vars() to return a dict - with patch('builtins.vars') as mock_vars: - mock_vars.return_value = { - 'metrics': True, - 'metrics_port': 8000, - 'pulsar_host': default_pulsar_host, - 'timeout': default_timeout - } - - run() - - # Verify metrics server was started - mock_start_http_server.assert_called_once_with(8000) - - # Verify Api was created and run was called - mock_api.assert_called_once() - mock_api_instance.run.assert_called_once() - - @patch('trustgraph.gateway.service.start_http_server') - @patch('argparse.ArgumentParser.parse_args') - def test_run_function_with_metrics_disabled(self, mock_parse_args, mock_start_http_server): - """Test run function with metrics disabled""" - # Mock command line arguments - mock_args = Mock() - mock_args.metrics = False - mock_parse_args.return_value = mock_args - - # Create a simple mock instance without any async methods - mock_api_instance = Mock() - mock_api_instance.run = Mock() - - # Patch the Api class inside the test without using decorators - with patch('trustgraph.gateway.service.Api') as mock_api: - mock_api.return_value = mock_api_instance - - # Mock vars() to return a dict - with patch('builtins.vars') as mock_vars: - mock_vars.return_value = { - 'metrics': False, - 'metrics_port': 8000, - 'pulsar_host': default_pulsar_host, - 'timeout': default_timeout - } - - run() - - # Verify metrics server was NOT started - mock_start_http_server.assert_not_called() - - # Verify Api was created and run was called - mock_api.assert_called_once() - mock_api_instance.run.assert_called_once() - - @patch('argparse.ArgumentParser.parse_args') - def test_run_function_argument_parsing(self, mock_parse_args): - """Test that run function properly parses command line arguments""" - # Mock command line arguments - mock_args = Mock() - mock_args.metrics = False - mock_parse_args.return_value = mock_args - - # Create a simple mock instance without any async methods - mock_api_instance = Mock() - mock_api_instance.run = Mock() - - # Mock vars() to return a dict with all expected arguments - expected_args = { - 'pulsar_host': 'pulsar://test:6650', - 'pulsar_api_key': 'test-key', - 'pulsar_listener': 'test-listener', - 'prometheus_url': 'http://test-prometheus:9090', - 'port': 9000, - 'timeout': 300, - 'api_token': 'secret', - 'log_level': 'INFO', - 'metrics': False, - 'metrics_port': 8001 - } - - # Patch the Api class inside the test without using decorators - with patch('trustgraph.gateway.service.Api') as mock_api: - mock_api.return_value = mock_api_instance - - with patch('builtins.vars') as mock_vars: - mock_vars.return_value = expected_args - - run() - - # Verify Api was created with the parsed arguments - mock_api.assert_called_once_with(**expected_args) - mock_api_instance.run.assert_called_once() - - def test_run_function_creates_argument_parser(self): - """Test that run function creates argument parser with correct arguments""" - with patch('argparse.ArgumentParser') as mock_parser_class: - mock_parser = Mock() - mock_parser_class.return_value = mock_parser - mock_parser.parse_args.return_value = Mock(metrics=False) - - with patch('trustgraph.gateway.service.Api') as mock_api, \ - patch('builtins.vars') as mock_vars: - mock_vars.return_value = {'metrics': False} - mock_api.return_value = Mock() - - run() - - # Verify ArgumentParser was created - mock_parser_class.assert_called_once() - - # Verify add_argument was called for each expected argument - expected_arguments = [ - 'pulsar-host', 'pulsar-api-key', 'pulsar-listener', - 'prometheus-url', 'port', 'timeout', 'api-token', - 'log-level', 'metrics', 'metrics-port' - ] - - # Check that add_argument was called multiple times (once for each arg) - assert mock_parser.add_argument.call_count >= len(expected_arguments) \ No newline at end of file + # auth.start must be first (before config receiver, before + # any endpoint starts). + assert order[0] == "auth" + # All three must have run. + assert set(order) == {"auth", "config", "endpoints"} diff --git a/tests/unit/test_gateway/test_socket_graceful_shutdown.py b/tests/unit/test_gateway/test_socket_graceful_shutdown.py index 1a63227d..23f22d30 100644 --- a/tests/unit/test_gateway/test_socket_graceful_shutdown.py +++ b/tests/unit/test_gateway/test_socket_graceful_shutdown.py @@ -1,4 +1,15 @@ -"""Unit tests for SocketEndpoint graceful shutdown functionality.""" +"""Unit tests for SocketEndpoint graceful shutdown functionality. + +These tests exercise SocketEndpoint in its handshake-auth +configuration (``in_band_auth=False``) — the mode used in production +for the flow import/export streaming endpoints. The mux socket at +``/api/v1/socket`` uses ``in_band_auth=True`` instead, where the +handshake always accepts and authentication runs on the first +WebSocket frame; that path is covered by the Mux tests. + +Every endpoint constructor here passes an explicit capability — no +permissive default is relied upon. +""" import pytest import asyncio @@ -6,13 +17,31 @@ from unittest.mock import AsyncMock, MagicMock, patch from aiohttp import web, WSMsgType from trustgraph.gateway.endpoint.socket import SocketEndpoint from trustgraph.gateway.running import Running +from trustgraph.gateway.auth import Identity + + +# Representative capability used across these tests — corresponds to +# the flow-import streaming endpoint pattern that uses this class. +TEST_CAP = "graph:write" + + +def _valid_identity(roles=("admin",)): + return Identity( + user_id="test-user", + workspace="default", + roles=list(roles), + source="api-key", + ) @pytest.fixture def mock_auth(): - """Mock authentication service.""" + """Mock IAM-backed authenticator. Successful by default — + ``authenticate`` returns a valid admin identity. Tests that + need the auth failure path override the ``authenticate`` + attribute locally.""" auth = MagicMock() - auth.permitted.return_value = True + auth.authenticate = AsyncMock(return_value=_valid_identity()) return auth @@ -25,7 +54,7 @@ def mock_dispatcher_factory(): dispatcher.receive = AsyncMock() dispatcher.destroy = AsyncMock() return dispatcher - + return dispatcher_factory @@ -35,7 +64,8 @@ def socket_endpoint(mock_auth, mock_dispatcher_factory): return SocketEndpoint( endpoint_path="/test-socket", auth=mock_auth, - dispatcher=mock_dispatcher_factory + dispatcher=mock_dispatcher_factory, + capability=TEST_CAP, ) @@ -61,7 +91,10 @@ def mock_request(): @pytest.mark.asyncio async def test_listener_graceful_shutdown_on_close(): """Test listener handles websocket close gracefully.""" - socket_endpoint = SocketEndpoint("/test", MagicMock(), AsyncMock()) + socket_endpoint = SocketEndpoint( + "/test", MagicMock(), AsyncMock(), + capability=TEST_CAP, + ) # Mock websocket that closes after one message ws = AsyncMock() @@ -99,9 +132,9 @@ async def test_listener_graceful_shutdown_on_close(): @pytest.mark.asyncio async def test_handle_normal_flow(): - """Test normal websocket handling flow.""" + """Valid bearer → handshake accepted, dispatcher created.""" mock_auth = MagicMock() - mock_auth.permitted.return_value = True + mock_auth.authenticate = AsyncMock(return_value=_valid_identity()) dispatcher_created = False async def mock_dispatcher_factory(ws, running, match_info): @@ -111,7 +144,10 @@ async def test_handle_normal_flow(): dispatcher.destroy = AsyncMock() return dispatcher - socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory) + socket_endpoint = SocketEndpoint( + "/test", mock_auth, mock_dispatcher_factory, + capability=TEST_CAP, + ) request = MagicMock() request.query = {"token": "valid-token"} @@ -155,7 +191,7 @@ async def test_handle_normal_flow(): async def test_handle_exception_group_cleanup(): """Test exception group triggers dispatcher cleanup.""" mock_auth = MagicMock() - mock_auth.permitted.return_value = True + mock_auth.authenticate = AsyncMock(return_value=_valid_identity()) mock_dispatcher = AsyncMock() mock_dispatcher.destroy = AsyncMock() @@ -163,7 +199,10 @@ async def test_handle_exception_group_cleanup(): async def mock_dispatcher_factory(ws, running, match_info): return mock_dispatcher - socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory) + socket_endpoint = SocketEndpoint( + "/test", mock_auth, mock_dispatcher_factory, + capability=TEST_CAP, + ) request = MagicMock() request.query = {"token": "valid-token"} @@ -222,7 +261,7 @@ async def test_handle_exception_group_cleanup(): async def test_handle_dispatcher_cleanup_timeout(): """Test dispatcher cleanup with timeout.""" mock_auth = MagicMock() - mock_auth.permitted.return_value = True + mock_auth.authenticate = AsyncMock(return_value=_valid_identity()) # Mock dispatcher that takes long to destroy mock_dispatcher = AsyncMock() @@ -231,7 +270,10 @@ async def test_handle_dispatcher_cleanup_timeout(): async def mock_dispatcher_factory(ws, running, match_info): return mock_dispatcher - socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory) + socket_endpoint = SocketEndpoint( + "/test", mock_auth, mock_dispatcher_factory, + capability=TEST_CAP, + ) request = MagicMock() request.query = {"token": "valid-token"} @@ -285,49 +327,67 @@ async def test_handle_dispatcher_cleanup_timeout(): @pytest.mark.asyncio async def test_handle_unauthorized_request(): - """Test handling of unauthorized requests.""" + """A bearer that the IAM layer rejects causes the handshake to + fail with 401. IamAuth surfaces an HTTPUnauthorized; the + endpoint propagates it. Note that the endpoint intentionally + does NOT distinguish 'bad token', 'expired', 'revoked', etc. — + that's the IAM error-masking policy.""" mock_auth = MagicMock() - mock_auth.permitted.return_value = False # Unauthorized - - socket_endpoint = SocketEndpoint("/test", mock_auth, AsyncMock()) - + mock_auth.authenticate = AsyncMock(side_effect=web.HTTPUnauthorized( + text='{"error":"auth failure"}', + content_type="application/json", + )) + + socket_endpoint = SocketEndpoint( + "/test", mock_auth, AsyncMock(), + capability=TEST_CAP, + ) + request = MagicMock() request.query = {"token": "invalid-token"} - + result = await socket_endpoint.handle(request) - - # Should return HTTP 401 + assert isinstance(result, web.HTTPUnauthorized) - - # Should have checked permission - mock_auth.permitted.assert_called_once_with("invalid-token", "socket") + # authenticate must have been invoked with a synthetic request + # carrying Bearer . The endpoint wraps the query- + # string token into an Authorization header for a uniform auth + # path — the IAM layer does not look at query strings directly. + mock_auth.authenticate.assert_called_once() + passed_req = mock_auth.authenticate.call_args.args[0] + assert passed_req.headers["Authorization"] == "Bearer invalid-token" @pytest.mark.asyncio async def test_handle_missing_token(): - """Test handling of requests with missing token.""" + """Request with no ``token`` query param → 401 before any + IAM call is made (cheap short-circuit).""" mock_auth = MagicMock() - mock_auth.permitted.return_value = False - - socket_endpoint = SocketEndpoint("/test", mock_auth, AsyncMock()) - + mock_auth.authenticate = AsyncMock( + side_effect=AssertionError( + "authenticate must not be invoked when no token is present" + ), + ) + + socket_endpoint = SocketEndpoint( + "/test", mock_auth, AsyncMock(), + capability=TEST_CAP, + ) + request = MagicMock() request.query = {} # No token - + result = await socket_endpoint.handle(request) - - # Should return HTTP 401 + assert isinstance(result, web.HTTPUnauthorized) - - # Should have checked permission with empty token - mock_auth.permitted.assert_called_once_with("", "socket") + mock_auth.authenticate.assert_not_called() @pytest.mark.asyncio async def test_handle_websocket_already_closed(): """Test handling when websocket is already closed.""" mock_auth = MagicMock() - mock_auth.permitted.return_value = True + mock_auth.authenticate = AsyncMock(return_value=_valid_identity()) mock_dispatcher = AsyncMock() mock_dispatcher.destroy = AsyncMock() @@ -335,7 +395,10 @@ async def test_handle_websocket_already_closed(): async def mock_dispatcher_factory(ws, running, match_info): return mock_dispatcher - socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory) + socket_endpoint = SocketEndpoint( + "/test", mock_auth, mock_dispatcher_factory, + capability=TEST_CAP, + ) request = MagicMock() request.query = {"token": "valid-token"} diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index e5d553ea..ca9146b9 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -49,21 +49,67 @@ class AsyncSocketClient: return f"ws://{url}" def _build_ws_url(self): - ws_url = f"{self.url.rstrip('/')}/api/v1/socket" - if self.token: - ws_url = f"{ws_url}?token={self.token}" - return ws_url + # /api/v1/socket uses the first-frame auth protocol — the + # token is sent as the first frame after connecting rather + # than in the URL. This avoids browser issues with 401 on + # the WebSocket handshake and lets long-lived sockets + # refresh credentials mid-session. + return f"{self.url.rstrip('/')}/api/v1/socket" async def connect(self): - """Establish the persistent websocket connection.""" + """Establish the persistent websocket connection and run the + first-frame auth handshake.""" if self._connected: return + if not self.token: + raise ProtocolException( + "AsyncSocketClient requires a token for first-frame " + "auth against /api/v1/socket" + ) + ws_url = self._build_ws_url() self._connect_cm = websockets.connect( ws_url, ping_interval=20, ping_timeout=self.timeout ) self._socket = await self._connect_cm.__aenter__() + + # First-frame auth: send {"type":"auth","token":"..."} and + # wait for auth-ok / auth-failed. Run before starting the + # reader task so the response isn't consumed by the reader's + # id-based routing. + await self._socket.send(json.dumps({ + "type": "auth", "token": self.token, + })) + try: + raw = await asyncio.wait_for( + self._socket.recv(), timeout=self.timeout, + ) + except asyncio.TimeoutError: + await self._socket.close() + raise ProtocolException("Timeout waiting for auth response") + + try: + resp = json.loads(raw) + except Exception: + await self._socket.close() + raise ProtocolException( + f"Unexpected non-JSON auth response: {raw!r}" + ) + + if resp.get("type") == "auth-ok": + self.workspace = resp.get("workspace", self.workspace) + elif resp.get("type") == "auth-failed": + await self._socket.close() + raise ProtocolException( + f"auth failure: {resp.get('error', 'unknown')}" + ) + else: + await self._socket.close() + raise ProtocolException( + f"Unexpected auth response: {resp!r}" + ) + self._connected = True self._reader_task = asyncio.create_task(self._reader()) diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index 4eade3e8..aeb15f85 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -112,10 +112,10 @@ class SocketClient: return f"ws://{url}" def _build_ws_url(self): - ws_url = f"{self.url.rstrip('/')}/api/v1/socket" - if self.token: - ws_url = f"{ws_url}?token={self.token}" - return ws_url + # /api/v1/socket uses the first-frame auth protocol — the + # token is sent as the first frame after connecting rather + # than in the URL. + return f"{self.url.rstrip('/')}/api/v1/socket" def _get_loop(self): """Get or create the event loop, reusing across calls.""" @@ -132,15 +132,58 @@ class SocketClient: return self._loop async def _ensure_connected(self): - """Lazily establish the persistent websocket connection.""" + """Lazily establish the persistent websocket connection and + run the first-frame auth handshake.""" if self._connected: return + if not self.token: + raise ProtocolException( + "SocketClient requires a token for first-frame auth " + "against /api/v1/socket" + ) + ws_url = self._build_ws_url() self._connect_cm = websockets.connect( ws_url, ping_interval=20, ping_timeout=self.timeout ) self._socket = await self._connect_cm.__aenter__() + + # First-frame auth — run before starting the reader so the + # auth-ok / auth-failed response isn't consumed by the reader + # loop's id-based routing. + await self._socket.send(json.dumps({ + "type": "auth", "token": self.token, + })) + try: + raw = await asyncio.wait_for( + self._socket.recv(), timeout=self.timeout, + ) + except asyncio.TimeoutError: + await self._socket.close() + raise ProtocolException("Timeout waiting for auth response") + + try: + resp = json.loads(raw) + except Exception: + await self._socket.close() + raise ProtocolException( + f"Unexpected non-JSON auth response: {raw!r}" + ) + + if resp.get("type") == "auth-ok": + self.workspace = resp.get("workspace", self.workspace) + elif resp.get("type") == "auth-failed": + await self._socket.close() + raise ProtocolException( + f"auth failure: {resp.get('error', 'unknown')}" + ) + else: + await self._socket.close() + raise ProtocolException( + f"Unexpected auth response: {resp!r}" + ) + self._connected = True self._reader_task = asyncio.create_task(self._reader()) diff --git a/trustgraph-base/trustgraph/base/iam_client.py b/trustgraph-base/trustgraph/base/iam_client.py new file mode 100644 index 00000000..5cfda7c8 --- /dev/null +++ b/trustgraph-base/trustgraph/base/iam_client.py @@ -0,0 +1,279 @@ + +from . request_response_spec import RequestResponse, RequestResponseSpec +from .. schema import ( + IamRequest, IamResponse, + UserInput, WorkspaceInput, ApiKeyInput, +) + +IAM_TIMEOUT = 10 + + +class IamClient(RequestResponse): + """Client for the IAM service request/response pub/sub protocol. + + Mirrors ``ConfigClient``: a thin wrapper around ``RequestResponse`` + that knows the IAM request / response schemas. Only the subset of + operations actually implemented by the server today has helper + methods here; callers that need an unimplemented operation can + build ``IamRequest`` and call ``request()`` directly. + """ + + async def _request(self, timeout=IAM_TIMEOUT, **kwargs): + resp = await self.request( + IamRequest(**kwargs), + timeout=timeout, + ) + if resp.error: + raise RuntimeError( + f"{resp.error.type}: {resp.error.message}" + ) + return resp + + async def bootstrap(self, timeout=IAM_TIMEOUT): + """Initial-run IAM self-seed. Returns a tuple of + ``(admin_user_id, admin_api_key_plaintext)``. Both are empty + strings on repeat calls — the operation is a no-op once the + IAM tables are populated.""" + resp = await self._request( + operation="bootstrap", timeout=timeout, + ) + return resp.bootstrap_admin_user_id, resp.bootstrap_admin_api_key + + async def resolve_api_key(self, api_key, timeout=IAM_TIMEOUT): + """Resolve a plaintext API key to its identity triple. + + Returns ``(user_id, workspace, roles)`` or raises + ``RuntimeError`` with error type ``auth-failed`` if the key is + unknown / expired / revoked.""" + resp = await self._request( + operation="resolve-api-key", + api_key=api_key, + timeout=timeout, + ) + return ( + resp.resolved_user_id, + resp.resolved_workspace, + list(resp.resolved_roles), + ) + + async def create_user(self, workspace, user, actor="", + timeout=IAM_TIMEOUT): + """Create a user. ``user`` is a ``UserInput``.""" + resp = await self._request( + operation="create-user", + workspace=workspace, + actor=actor, + user=user, + timeout=timeout, + ) + return resp.user + + async def list_users(self, workspace, actor="", timeout=IAM_TIMEOUT): + resp = await self._request( + operation="list-users", + workspace=workspace, + actor=actor, + timeout=timeout, + ) + return list(resp.users) + + async def create_api_key(self, workspace, key, actor="", + timeout=IAM_TIMEOUT): + """Create an API key. ``key`` is an ``ApiKeyInput``. Returns + ``(plaintext, record)`` — plaintext is returned once and the + caller is responsible for surfacing it to the operator.""" + resp = await self._request( + operation="create-api-key", + workspace=workspace, + actor=actor, + key=key, + timeout=timeout, + ) + return resp.api_key_plaintext, resp.api_key + + async def list_api_keys(self, workspace, user_id, actor="", + timeout=IAM_TIMEOUT): + resp = await self._request( + operation="list-api-keys", + workspace=workspace, + actor=actor, + user_id=user_id, + timeout=timeout, + ) + return list(resp.api_keys) + + async def revoke_api_key(self, workspace, key_id, actor="", + timeout=IAM_TIMEOUT): + await self._request( + operation="revoke-api-key", + workspace=workspace, + actor=actor, + key_id=key_id, + timeout=timeout, + ) + + async def login(self, username, password, workspace="", + timeout=IAM_TIMEOUT): + """Validate credentials and return ``(jwt, expires_iso)``. + ``workspace`` is optional; defaults at the server to the + OSS default workspace.""" + resp = await self._request( + operation="login", + workspace=workspace, + username=username, + password=password, + timeout=timeout, + ) + return resp.jwt, resp.jwt_expires + + async def get_signing_key_public(self, timeout=IAM_TIMEOUT): + """Return the active JWT signing public key in PEM. The + gateway calls this at startup and caches the result.""" + resp = await self._request( + operation="get-signing-key-public", + timeout=timeout, + ) + return resp.signing_key_public + + async def change_password(self, user_id, current_password, + new_password, timeout=IAM_TIMEOUT): + await self._request( + operation="change-password", + user_id=user_id, + password=current_password, + new_password=new_password, + timeout=timeout, + ) + + async def reset_password(self, workspace, user_id, actor="", + timeout=IAM_TIMEOUT): + """Admin-driven password reset. Returns the plaintext + temporary password (returned once).""" + resp = await self._request( + operation="reset-password", + workspace=workspace, + actor=actor, + user_id=user_id, + timeout=timeout, + ) + return resp.temporary_password + + async def get_user(self, workspace, user_id, actor="", + timeout=IAM_TIMEOUT): + resp = await self._request( + operation="get-user", + workspace=workspace, + actor=actor, + user_id=user_id, + timeout=timeout, + ) + return resp.user + + async def update_user(self, workspace, user_id, user, actor="", + timeout=IAM_TIMEOUT): + resp = await self._request( + operation="update-user", + workspace=workspace, + actor=actor, + user_id=user_id, + user=user, + timeout=timeout, + ) + return resp.user + + async def disable_user(self, workspace, user_id, actor="", + timeout=IAM_TIMEOUT): + await self._request( + operation="disable-user", + workspace=workspace, + actor=actor, + user_id=user_id, + timeout=timeout, + ) + + async def enable_user(self, workspace, user_id, actor="", + timeout=IAM_TIMEOUT): + await self._request( + operation="enable-user", + workspace=workspace, + actor=actor, + user_id=user_id, + timeout=timeout, + ) + + async def delete_user(self, workspace, user_id, actor="", + timeout=IAM_TIMEOUT): + await self._request( + operation="delete-user", + workspace=workspace, + actor=actor, + user_id=user_id, + timeout=timeout, + ) + + async def create_workspace(self, workspace_record, actor="", + timeout=IAM_TIMEOUT): + resp = await self._request( + operation="create-workspace", + actor=actor, + workspace_record=workspace_record, + timeout=timeout, + ) + return resp.workspace + + async def list_workspaces(self, actor="", timeout=IAM_TIMEOUT): + resp = await self._request( + operation="list-workspaces", + actor=actor, + timeout=timeout, + ) + return list(resp.workspaces) + + async def get_workspace(self, workspace_id, actor="", + timeout=IAM_TIMEOUT): + from ..schema import WorkspaceInput + resp = await self._request( + operation="get-workspace", + actor=actor, + workspace_record=WorkspaceInput(id=workspace_id), + timeout=timeout, + ) + return resp.workspace + + async def update_workspace(self, workspace_record, actor="", + timeout=IAM_TIMEOUT): + resp = await self._request( + operation="update-workspace", + actor=actor, + workspace_record=workspace_record, + timeout=timeout, + ) + return resp.workspace + + async def disable_workspace(self, workspace_id, actor="", + timeout=IAM_TIMEOUT): + from ..schema import WorkspaceInput + await self._request( + operation="disable-workspace", + actor=actor, + workspace_record=WorkspaceInput(id=workspace_id), + timeout=timeout, + ) + + async def rotate_signing_key(self, actor="", timeout=IAM_TIMEOUT): + await self._request( + operation="rotate-signing-key", + actor=actor, + timeout=timeout, + ) + + +class IamClientSpec(RequestResponseSpec): + def __init__(self, request_name, response_name): + super().__init__( + request_name=request_name, + request_schema=IamRequest, + response_name=response_name, + response_schema=IamResponse, + impl=IamClient, + ) diff --git a/trustgraph-base/trustgraph/messaging/__init__.py b/trustgraph-base/trustgraph/messaging/__init__.py index 30f5061c..9fcfa6f7 100644 --- a/trustgraph-base/trustgraph/messaging/__init__.py +++ b/trustgraph-base/trustgraph/messaging/__init__.py @@ -15,6 +15,7 @@ from .translators.library import LibraryRequestTranslator, LibraryResponseTransl from .translators.document_loading import DocumentTranslator, TextDocumentTranslator from .translators.config import ConfigRequestTranslator, ConfigResponseTranslator from .translators.flow import FlowRequestTranslator, FlowResponseTranslator +from .translators.iam import IamRequestTranslator, IamResponseTranslator from .translators.prompt import PromptRequestTranslator, PromptResponseTranslator from .translators.tool import ToolRequestTranslator, ToolResponseTranslator from .translators.embeddings_query import ( @@ -85,11 +86,17 @@ TranslatorRegistry.register_service( ) TranslatorRegistry.register_service( - "flow", - FlowRequestTranslator(), + "flow", + FlowRequestTranslator(), FlowResponseTranslator() ) +TranslatorRegistry.register_service( + "iam", + IamRequestTranslator(), + IamResponseTranslator() +) + TranslatorRegistry.register_service( "prompt", PromptRequestTranslator(), diff --git a/trustgraph-base/trustgraph/messaging/translators/iam.py b/trustgraph-base/trustgraph/messaging/translators/iam.py new file mode 100644 index 00000000..4a717bba --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/iam.py @@ -0,0 +1,194 @@ +from typing import Dict, Any, Tuple + +from ...schema import IamRequest, IamResponse +from ...schema import ( + UserInput, UserRecord, + WorkspaceInput, WorkspaceRecord, + ApiKeyInput, ApiKeyRecord, +) +from .base import MessageTranslator + + +def _user_input_from_dict(d): + if d is None: + return None + return UserInput( + username=d.get("username", ""), + name=d.get("name", ""), + email=d.get("email", ""), + password=d.get("password", ""), + roles=list(d.get("roles", [])), + enabled=d.get("enabled", True), + must_change_password=d.get("must_change_password", False), + ) + + +def _workspace_input_from_dict(d): + if d is None: + return None + return WorkspaceInput( + id=d.get("id", ""), + name=d.get("name", ""), + enabled=d.get("enabled", True), + ) + + +def _api_key_input_from_dict(d): + if d is None: + return None + return ApiKeyInput( + user_id=d.get("user_id", ""), + name=d.get("name", ""), + expires=d.get("expires", ""), + ) + + +def _user_record_to_dict(r): + if r is None: + return None + return { + "id": r.id, + "workspace": r.workspace, + "username": r.username, + "name": r.name, + "email": r.email, + "roles": list(r.roles), + "enabled": r.enabled, + "must_change_password": r.must_change_password, + "created": r.created, + } + + +def _workspace_record_to_dict(r): + if r is None: + return None + return { + "id": r.id, + "name": r.name, + "enabled": r.enabled, + "created": r.created, + } + + +def _api_key_record_to_dict(r): + if r is None: + return None + return { + "id": r.id, + "user_id": r.user_id, + "name": r.name, + "prefix": r.prefix, + "expires": r.expires, + "created": r.created, + "last_used": r.last_used, + } + + +class IamRequestTranslator(MessageTranslator): + + def decode(self, data: Dict[str, Any]) -> IamRequest: + return IamRequest( + operation=data.get("operation", ""), + workspace=data.get("workspace", ""), + actor=data.get("actor", ""), + user_id=data.get("user_id", ""), + username=data.get("username", ""), + key_id=data.get("key_id", ""), + api_key=data.get("api_key", ""), + password=data.get("password", ""), + new_password=data.get("new_password", ""), + user=_user_input_from_dict(data.get("user")), + workspace_record=_workspace_input_from_dict( + data.get("workspace_record") + ), + key=_api_key_input_from_dict(data.get("key")), + ) + + def encode(self, obj: IamRequest) -> Dict[str, Any]: + result = {"operation": obj.operation} + for fname in ( + "workspace", "actor", "user_id", "username", "key_id", + "api_key", "password", "new_password", + ): + v = getattr(obj, fname, "") + if v: + result[fname] = v + if obj.user is not None: + result["user"] = { + "username": obj.user.username, + "name": obj.user.name, + "email": obj.user.email, + "password": obj.user.password, + "roles": list(obj.user.roles), + "enabled": obj.user.enabled, + "must_change_password": obj.user.must_change_password, + } + if obj.workspace_record is not None: + result["workspace_record"] = { + "id": obj.workspace_record.id, + "name": obj.workspace_record.name, + "enabled": obj.workspace_record.enabled, + } + if obj.key is not None: + result["key"] = { + "user_id": obj.key.user_id, + "name": obj.key.name, + "expires": obj.key.expires, + } + return result + + +class IamResponseTranslator(MessageTranslator): + + def decode(self, data: Dict[str, Any]) -> IamResponse: + raise NotImplementedError( + "IamResponse is a server-produced message; no HTTP→schema " + "path is needed" + ) + + def encode(self, obj: IamResponse) -> Dict[str, Any]: + result: Dict[str, Any] = {} + + if obj.user is not None: + result["user"] = _user_record_to_dict(obj.user) + if obj.users: + result["users"] = [_user_record_to_dict(u) for u in obj.users] + if obj.workspace is not None: + result["workspace"] = _workspace_record_to_dict(obj.workspace) + if obj.workspaces: + result["workspaces"] = [ + _workspace_record_to_dict(w) for w in obj.workspaces + ] + if obj.api_key_plaintext: + result["api_key_plaintext"] = obj.api_key_plaintext + if obj.api_key is not None: + result["api_key"] = _api_key_record_to_dict(obj.api_key) + if obj.api_keys: + result["api_keys"] = [ + _api_key_record_to_dict(k) for k in obj.api_keys + ] + if obj.jwt: + result["jwt"] = obj.jwt + if obj.jwt_expires: + result["jwt_expires"] = obj.jwt_expires + if obj.signing_key_public: + result["signing_key_public"] = obj.signing_key_public + if obj.resolved_user_id: + result["resolved_user_id"] = obj.resolved_user_id + if obj.resolved_workspace: + result["resolved_workspace"] = obj.resolved_workspace + if obj.resolved_roles: + result["resolved_roles"] = list(obj.resolved_roles) + if obj.temporary_password: + result["temporary_password"] = obj.temporary_password + if obj.bootstrap_admin_user_id: + result["bootstrap_admin_user_id"] = obj.bootstrap_admin_user_id + if obj.bootstrap_admin_api_key: + result["bootstrap_admin_api_key"] = obj.bootstrap_admin_api_key + + return result + + def encode_with_completion( + self, obj: IamResponse, + ) -> Tuple[Dict[str, Any], bool]: + return self.encode(obj), True diff --git a/trustgraph-base/trustgraph/schema/services/__init__.py b/trustgraph-base/trustgraph/schema/services/__init__.py index 550b7d12..2a214201 100644 --- a/trustgraph-base/trustgraph/schema/services/__init__.py +++ b/trustgraph-base/trustgraph/schema/services/__init__.py @@ -5,6 +5,7 @@ from .agent import * from .flow import * from .prompt import * from .config import * +from .iam import * from .library import * from .lookup import * from .nlp_query import * diff --git a/trustgraph-base/trustgraph/schema/services/iam.py b/trustgraph-base/trustgraph/schema/services/iam.py new file mode 100644 index 00000000..1e3ab1ab --- /dev/null +++ b/trustgraph-base/trustgraph/schema/services/iam.py @@ -0,0 +1,142 @@ + +from dataclasses import dataclass, field + +from ..core.topic import queue +from ..core.primitives import Error + +############################################################################ + +# IAM service — see docs/tech-specs/iam-protocol.md for the full protocol. +# +# Transport: request/response pub/sub, correlated by the `id` message +# property. Caller is the API gateway only; the IAM service trusts +# the bus per the enforcement-boundary policy (no per-request auth +# against the caller). + + +@dataclass +class UserInput: + username: str = "" + name: str = "" + email: str = "" + # Only populated on create-user; never on update-user. + password: str = "" + roles: list[str] = field(default_factory=list) + enabled: bool = True + must_change_password: bool = False + + +@dataclass +class UserRecord: + id: str = "" + workspace: str = "" + username: str = "" + name: str = "" + email: str = "" + roles: list[str] = field(default_factory=list) + enabled: bool = True + must_change_password: bool = False + created: str = "" + + +@dataclass +class WorkspaceInput: + id: str = "" + name: str = "" + enabled: bool = True + + +@dataclass +class WorkspaceRecord: + id: str = "" + name: str = "" + enabled: bool = True + created: str = "" + + +@dataclass +class ApiKeyInput: + user_id: str = "" + name: str = "" + expires: str = "" + + +@dataclass +class ApiKeyRecord: + id: str = "" + user_id: str = "" + name: str = "" + # First 4 chars of the plaintext token, for operator identification + # in list-api-keys. Never enough to reconstruct the key. + prefix: str = "" + expires: str = "" + created: str = "" + last_used: str = "" + + +@dataclass +class IamRequest: + operation: str = "" + + # Workspace scope. Required on workspace-scoped operations; + # omitted for system-level ops (workspace CRUD, signing-key + # ops, bootstrap, resolve-api-key, login). + workspace: str = "" + + # Acting user id for audit. Empty for internal-origin and for + # operations that resolve an identity (login, resolve-api-key). + actor: str = "" + + user_id: str = "" + username: str = "" + key_id: str = "" + api_key: str = "" + + password: str = "" + new_password: str = "" + + user: UserInput | None = None + workspace_record: WorkspaceInput | None = None + key: ApiKeyInput | None = None + + +@dataclass +class IamResponse: + user: UserRecord | None = None + users: list[UserRecord] = field(default_factory=list) + + workspace: WorkspaceRecord | None = None + workspaces: list[WorkspaceRecord] = field(default_factory=list) + + # create-api-key returns the plaintext once; never populated + # on any other operation. + api_key_plaintext: str = "" + api_key: ApiKeyRecord | None = None + api_keys: list[ApiKeyRecord] = field(default_factory=list) + + # login, rotate-signing-key + jwt: str = "" + jwt_expires: str = "" + + # get-signing-key-public + signing_key_public: str = "" + + # resolve-api-key + resolved_user_id: str = "" + resolved_workspace: str = "" + resolved_roles: list[str] = field(default_factory=list) + + # reset-password + temporary_password: str = "" + + # bootstrap + bootstrap_admin_user_id: str = "" + bootstrap_admin_api_key: str = "" + + error: Error | None = None + + +iam_request_queue = queue('iam', cls='request') +iam_response_queue = queue('iam', cls='response') + +############################################################################ diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index d316ae4f..728079c8 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -40,6 +40,20 @@ tg-get-flow-blueprint = "trustgraph.cli.get_flow_blueprint:main" tg-get-kg-core = "trustgraph.cli.get_kg_core:main" tg-get-document-content = "trustgraph.cli.get_document_content:main" tg-graph-to-turtle = "trustgraph.cli.graph_to_turtle:main" +tg-bootstrap-iam = "trustgraph.cli.bootstrap_iam:main" +tg-login = "trustgraph.cli.login:main" +tg-create-user = "trustgraph.cli.create_user:main" +tg-list-users = "trustgraph.cli.list_users:main" +tg-disable-user = "trustgraph.cli.disable_user:main" +tg-enable-user = "trustgraph.cli.enable_user:main" +tg-delete-user = "trustgraph.cli.delete_user:main" +tg-change-password = "trustgraph.cli.change_password:main" +tg-reset-password = "trustgraph.cli.reset_password:main" +tg-create-api-key = "trustgraph.cli.create_api_key:main" +tg-list-api-keys = "trustgraph.cli.list_api_keys:main" +tg-revoke-api-key = "trustgraph.cli.revoke_api_key:main" +tg-list-workspaces = "trustgraph.cli.list_workspaces:main" +tg-create-workspace = "trustgraph.cli.create_workspace:main" tg-invoke-agent = "trustgraph.cli.invoke_agent:main" tg-invoke-document-rag = "trustgraph.cli.invoke_document_rag:main" tg-invoke-graph-rag = "trustgraph.cli.invoke_graph_rag:main" diff --git a/trustgraph-cli/trustgraph/cli/_iam.py b/trustgraph-cli/trustgraph/cli/_iam.py new file mode 100644 index 00000000..f5278c0c --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/_iam.py @@ -0,0 +1,75 @@ +""" +Shared helpers for IAM CLI tools. + +All IAM operations go through the gateway's ``/api/v1/iam`` forwarder, +with the three public auth operations (``login``, ``bootstrap``, +``change-password``) served via ``/api/v1/auth/...`` instead. These +helpers encapsulate the HTTP plumbing so each CLI can stay focused +on its own argument parsing and output formatting. +""" + +import json +import os +import sys + +import requests + + +DEFAULT_URL = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/") +DEFAULT_TOKEN = os.getenv("TRUSTGRAPH_TOKEN", None) + + +def _fmt_error(resp_json): + err = resp_json.get("error", {}) + if isinstance(err, dict): + t = err.get("type", "") + m = err.get("message", "") + return f"{t}: {m}" if t else m or "error" + return str(err) + + +def _post(url, path, token, body): + endpoint = url.rstrip("/") + path + headers = {"Content-Type": "application/json"} + if token: + headers["Authorization"] = f"Bearer {token}" + + resp = requests.post( + endpoint, headers=headers, data=json.dumps(body), + ) + + if resp.status_code != 200: + try: + payload = resp.json() + detail = _fmt_error(payload) + except Exception: + detail = resp.text + raise RuntimeError(f"HTTP {resp.status_code}: {detail}") + + body = resp.json() + if "error" in body: + raise RuntimeError(_fmt_error(body)) + return body + + +def call_iam(url, token, request): + """Forward an IAM request through ``/api/v1/iam``. ``request`` is + the ``IamRequest`` dict shape.""" + return _post(url, "/api/v1/iam", token, request) + + +def call_auth(url, path, token, body): + """Hit one of the public auth endpoints + (``/api/v1/auth/login``, ``/api/v1/auth/change-password``, etc.). + ``token`` is optional — login and bootstrap don't need one.""" + return _post(url, path, token, body) + + +def run_main(fn, parser): + """Standard error-handling wrapper for CLI main() bodies.""" + args = parser.parse_args() + try: + fn(args) + except Exception as e: + print("Exception:", e, file=sys.stderr, flush=True) + sys.exit(1) diff --git a/trustgraph-cli/trustgraph/cli/bootstrap_iam.py b/trustgraph-cli/trustgraph/cli/bootstrap_iam.py new file mode 100644 index 00000000..99a789e2 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/bootstrap_iam.py @@ -0,0 +1,94 @@ +""" +Bootstraps the IAM service. Only works when iam-svc is running in +bootstrap mode with empty tables. Prints the initial admin API key +to stdout. + +This is a one-time, trust-sensitive operation. The resulting token +is shown once and never again — capture it on use. Rotate and +revoke it as soon as a real admin API key has been issued. +""" + +import argparse +import json +import os +import sys + +import requests + +default_url = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/") + + +def bootstrap(url): + + # Unauthenticated public endpoint — IAM refuses the bootstrap + # operation unless the service is running in bootstrap mode with + # empty tables, so the safety gate lives on the server side. + endpoint = url.rstrip("/") + "/api/v1/auth/bootstrap" + + headers = {"Content-Type": "application/json"} + + resp = requests.post( + endpoint, + headers=headers, + data=json.dumps({}), + ) + + if resp.status_code != 200: + raise RuntimeError( + f"HTTP {resp.status_code}: {resp.text}" + ) + + body = resp.json() + + if "error" in body: + raise RuntimeError( + f"IAM {body['error'].get('type', 'error')}: " + f"{body['error'].get('message', '')}" + ) + + api_key = body.get("bootstrap_admin_api_key") + user_id = body.get("bootstrap_admin_user_id") + + if not api_key: + raise RuntimeError( + "IAM response did not contain a bootstrap token — the " + "service may already be bootstrapped, or may be running " + "in token mode." + ) + + return user_id, api_key + + +def main(): + + parser = argparse.ArgumentParser( + prog="tg-bootstrap-iam", + description=__doc__, + ) + + parser.add_argument( + "-u", "--api-url", + default=default_url, + help=f"API URL (default: {default_url})", + ) + + args = parser.parse_args() + + try: + user_id, api_key = bootstrap(args.api_url) + except Exception as e: + print("Exception:", e, file=sys.stderr, flush=True) + sys.exit(1) + + # Stdout gets machine-readable output (the key). Any operator + # context goes to stderr. + print(f"Admin user id: {user_id}", file=sys.stderr) + print( + "Admin API key (shown once, capture now):", + file=sys.stderr, + ) + print(api_key) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/change_password.py b/trustgraph-cli/trustgraph/cli/change_password.py new file mode 100644 index 00000000..c914b30f --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/change_password.py @@ -0,0 +1,46 @@ +""" +Change your own password. Requires the current password. +""" + +import argparse +import getpass + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_auth, run_main + + +def do_change_password(args): + current = args.current or getpass.getpass("Current password: ") + new = args.new or getpass.getpass("New password: ") + + call_auth( + args.api_url, "/api/v1/auth/change-password", args.token, + {"current_password": current, "new_password": new}, + ) + print("Password changed.") + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-change-password", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "--current", default=None, + help="Current password (prompted if omitted)", + ) + parser.add_argument( + "--new", default=None, + help="New password (prompted if omitted)", + ) + run_main(do_change_password, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/create_api_key.py b/trustgraph-cli/trustgraph/cli/create_api_key.py new file mode 100644 index 00000000..2b269041 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/create_api_key.py @@ -0,0 +1,71 @@ +""" +Create an API key for a user. Prints the plaintext key to stdout — +shown once only. +""" + +import argparse +import sys + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_create_api_key(args): + key = { + "user_id": args.user_id, + "name": args.name, + } + if args.expires: + key["expires"] = args.expires + + req = {"operation": "create-api-key", "key": key} + if args.workspace: + req["workspace"] = args.workspace + resp = call_iam(args.api_url, args.token, req) + + plaintext = resp.get("api_key_plaintext", "") + rec = resp.get("api_key", {}) + print(f"Key id: {rec.get('id', '')}", file=sys.stderr) + print(f"Name: {rec.get('name', '')}", file=sys.stderr) + print(f"Prefix: {rec.get('prefix', '')}", file=sys.stderr) + print( + "API key (shown once, capture now):", file=sys.stderr, + ) + print(plaintext) + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-create-api-key", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "--user-id", required=True, + help="Owner user id", + ) + parser.add_argument( + "--name", required=True, + help="Operator-facing label (e.g. 'laptop', 'ci')", + ) + parser.add_argument( + "--expires", default=None, + help="ISO-8601 expiry (optional; empty = no expiry)", + ) + parser.add_argument( + "-w", "--workspace", default=None, + help=( + "Target workspace (admin only; defaults to caller's " + "assigned workspace)" + ), + ) + run_main(do_create_api_key, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/create_user.py b/trustgraph-cli/trustgraph/cli/create_user.py new file mode 100644 index 00000000..c9253aca --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/create_user.py @@ -0,0 +1,87 @@ +""" +Create a user in the caller's workspace. Prints the new user id. +""" + +import argparse +import getpass +import sys + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_create_user(args): + password = args.password + if not password: + password = getpass.getpass( + f"Password for new user {args.username}: " + ) + + user = { + "username": args.username, + "password": password, + "roles": args.roles, + } + if args.name: + user["name"] = args.name + if args.email: + user["email"] = args.email + if args.must_change_password: + user["must_change_password"] = True + + req = {"operation": "create-user", "user": user} + if args.workspace: + req["workspace"] = args.workspace + resp = call_iam(args.api_url, args.token, req) + + rec = resp.get("user", {}) + print(f"User id: {rec.get('id', '')}", file=sys.stderr) + print(f"Username: {rec.get('username', '')}", file=sys.stderr) + print(f"Roles: {', '.join(rec.get('roles', []))}", file=sys.stderr) + print(rec.get("id", "")) + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-create-user", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "--username", required=True, help="Username (unique in workspace)", + ) + parser.add_argument( + "--password", default=None, + help="Password (prompted if omitted)", + ) + parser.add_argument( + "--name", default=None, help="Display name", + ) + parser.add_argument( + "--email", default=None, help="Email", + ) + parser.add_argument( + "--roles", nargs="+", default=["reader"], + help="One or more role names (default: reader)", + ) + parser.add_argument( + "--must-change-password", action="store_true", + help="Force password change on next login", + ) + parser.add_argument( + "-w", "--workspace", default=None, + help=( + "Target workspace (admin only; defaults to caller's " + "assigned workspace)" + ), + ) + run_main(do_create_user, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/create_workspace.py b/trustgraph-cli/trustgraph/cli/create_workspace.py new file mode 100644 index 00000000..f8367720 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/create_workspace.py @@ -0,0 +1,46 @@ +""" +Create a workspace (system-level; requires admin). +""" + +import argparse + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_create_workspace(args): + ws = {"id": args.workspace_id, "enabled": True} + if args.name: + ws["name"] = args.name + + resp = call_iam(args.api_url, args.token, { + "operation": "create-workspace", + "workspace_record": ws, + }) + rec = resp.get("workspace", {}) + print(f"Workspace created: {rec.get('id', '')}") + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-create-workspace", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "--workspace-id", required=True, + help="New workspace id (must not start with '_')", + ) + parser.add_argument( + "--name", default=None, help="Display name", + ) + run_main(do_create_workspace, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/delete_user.py b/trustgraph-cli/trustgraph/cli/delete_user.py new file mode 100644 index 00000000..dbdf7877 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/delete_user.py @@ -0,0 +1,62 @@ +""" +Delete a user. Removes the user record, their username lookup, +and all their API keys. The freed username becomes available for +re-use. + +Irreversible. Use tg-disable-user if you want to preserve the +record (audit trail, username squatting protection). +""" + +import argparse + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_delete_user(args): + if not args.yes: + confirm = input( + f"Delete user {args.user_id}? This is irreversible. " + f"[type 'yes' to confirm]: " + ) + if confirm.strip() != "yes": + print("Aborted.") + return + + req = {"operation": "delete-user", "user_id": args.user_id} + if args.workspace: + req["workspace"] = args.workspace + call_iam(args.api_url, args.token, req) + print(f"Deleted user {args.user_id}") + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-delete-user", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "--user-id", required=True, help="User id to delete", + ) + parser.add_argument( + "-w", "--workspace", default=None, + help=( + "Target workspace (admin only; defaults to caller's " + "assigned workspace)" + ), + ) + parser.add_argument( + "--yes", action="store_true", + help="Skip the interactive confirmation prompt", + ) + run_main(do_delete_user, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/disable_user.py b/trustgraph-cli/trustgraph/cli/disable_user.py new file mode 100644 index 00000000..e142644b --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/disable_user.py @@ -0,0 +1,45 @@ +""" +Disable a user. Soft-deletes (enabled=false) and revokes all their +API keys. +""" + +import argparse + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_disable_user(args): + req = {"operation": "disable-user", "user_id": args.user_id} + if args.workspace: + req["workspace"] = args.workspace + call_iam(args.api_url, args.token, req) + print(f"Disabled user {args.user_id}") + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-disable-user", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "--user-id", required=True, help="User id to disable", + ) + parser.add_argument( + "-w", "--workspace", default=None, + help=( + "Target workspace (admin only; defaults to caller's " + "assigned workspace)" + ), + ) + run_main(do_disable_user, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/enable_user.py b/trustgraph-cli/trustgraph/cli/enable_user.py new file mode 100644 index 00000000..c762366a --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/enable_user.py @@ -0,0 +1,45 @@ +""" +Re-enable a previously disabled user. Does not restore their API +keys — those must be re-issued by an admin. +""" + +import argparse + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_enable_user(args): + req = {"operation": "enable-user", "user_id": args.user_id} + if args.workspace: + req["workspace"] = args.workspace + call_iam(args.api_url, args.token, req) + print(f"Enabled user {args.user_id}") + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-enable-user", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "--user-id", required=True, help="User id to enable", + ) + parser.add_argument( + "-w", "--workspace", default=None, + help=( + "Target workspace (admin only; defaults to caller's " + "assigned workspace)" + ), + ) + run_main(do_enable_user, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/list_api_keys.py b/trustgraph-cli/trustgraph/cli/list_api_keys.py new file mode 100644 index 00000000..f969890e --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/list_api_keys.py @@ -0,0 +1,69 @@ +""" +List the API keys for a user. +""" + +import argparse + +import tabulate + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_list_api_keys(args): + req = {"operation": "list-api-keys", "user_id": args.user_id} + if args.workspace: + req["workspace"] = args.workspace + resp = call_iam(args.api_url, args.token, req) + + keys = resp.get("api_keys", []) + if not keys: + print("No keys.") + return + + rows = [ + [ + k.get("id", ""), + k.get("name", ""), + k.get("prefix", ""), + k.get("created", ""), + k.get("last_used", "") or "—", + k.get("expires", "") or "never", + ] + for k in keys + ] + print(tabulate.tabulate( + rows, + headers=["id", "name", "prefix", "created", "last used", "expires"], + tablefmt="pretty", + stralign="left", + )) + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-list-api-keys", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "--user-id", required=True, + help="Owner user id", + ) + parser.add_argument( + "-w", "--workspace", default=None, + help=( + "Target workspace (admin only; defaults to caller's " + "assigned workspace)" + ), + ) + run_main(do_list_api_keys, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/list_users.py b/trustgraph-cli/trustgraph/cli/list_users.py new file mode 100644 index 00000000..25bc1901 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/list_users.py @@ -0,0 +1,65 @@ +""" +List users in the caller's workspace. +""" + +import argparse + +import tabulate + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_list_users(args): + req = {"operation": "list-users"} + if args.workspace: + req["workspace"] = args.workspace + resp = call_iam(args.api_url, args.token, req) + + users = resp.get("users", []) + if not users: + print("No users.") + return + + rows = [ + [ + u.get("id", ""), + u.get("username", ""), + u.get("name", ""), + ", ".join(u.get("roles", [])), + "yes" if u.get("enabled") else "no", + "yes" if u.get("must_change_password") else "no", + ] + for u in users + ] + print(tabulate.tabulate( + rows, + headers=["id", "username", "name", "roles", "enabled", "change-pw"], + tablefmt="pretty", + stralign="left", + )) + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-list-users", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "-w", "--workspace", default=None, + help=( + "Target workspace (admin only; defaults to caller's " + "assigned workspace)" + ), + ) + run_main(do_list_users, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/list_workspaces.py b/trustgraph-cli/trustgraph/cli/list_workspaces.py new file mode 100644 index 00000000..170d330c --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/list_workspaces.py @@ -0,0 +1,53 @@ +""" +List workspaces (system-level; requires admin). +""" + +import argparse + +import tabulate + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_list_workspaces(args): + resp = call_iam( + args.api_url, args.token, {"operation": "list-workspaces"}, + ) + workspaces = resp.get("workspaces", []) + if not workspaces: + print("No workspaces.") + return + rows = [ + [ + w.get("id", ""), + w.get("name", ""), + "yes" if w.get("enabled") else "no", + w.get("created", ""), + ] + for w in workspaces + ] + print(tabulate.tabulate( + rows, + headers=["id", "name", "enabled", "created"], + tablefmt="pretty", + stralign="left", + )) + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-list-workspaces", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + run_main(do_list_workspaces, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/login.py b/trustgraph-cli/trustgraph/cli/login.py new file mode 100644 index 00000000..0e87c3b0 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/login.py @@ -0,0 +1,62 @@ +""" +Log in with username / password. Prints the resulting JWT to +stdout so it can be captured for subsequent CLI use. +""" + +import argparse +import getpass +import sys + +from ._iam import DEFAULT_URL, call_auth, run_main + + +def do_login(args): + password = args.password + if not password: + password = getpass.getpass(f"Password for {args.username}: ") + + body = { + "username": args.username, + "password": password, + } + if args.workspace: + body["workspace"] = args.workspace + + resp = call_auth(args.api_url, "/api/v1/auth/login", None, body) + + jwt = resp.get("jwt", "") + expires = resp.get("jwt_expires", "") + + if expires: + print(f"JWT expires: {expires}", file=sys.stderr) + # Machine-readable on stdout. + print(jwt) + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-login", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "--username", required=True, help="Username", + ) + parser.add_argument( + "--password", default=None, + help="Password (prompted if omitted)", + ) + parser.add_argument( + "-w", "--workspace", default=None, + help=( + "Optional workspace to log in against. Defaults to " + "the user's assigned workspace." + ), + ) + run_main(do_login, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/reset_password.py b/trustgraph-cli/trustgraph/cli/reset_password.py new file mode 100644 index 00000000..600f00e1 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/reset_password.py @@ -0,0 +1,54 @@ +""" +Admin: reset another user's password. Prints a one-time temporary +password to stdout. The user is forced to change it on next login. +""" + +import argparse +import sys + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_reset_password(args): + req = {"operation": "reset-password", "user_id": args.user_id} + if args.workspace: + req["workspace"] = args.workspace + resp = call_iam(args.api_url, args.token, req) + + tmp = resp.get("temporary_password", "") + if not tmp: + raise RuntimeError( + "IAM returned no temporary password — unexpected" + ) + print("Temporary password (shown once, capture now):", file=sys.stderr) + print(tmp) + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-reset-password", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "--user-id", required=True, + help="Target user id", + ) + parser.add_argument( + "-w", "--workspace", default=None, + help=( + "Target workspace (admin only; defaults to caller's " + "assigned workspace)" + ), + ) + run_main(do_reset_password, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/revoke_api_key.py b/trustgraph-cli/trustgraph/cli/revoke_api_key.py new file mode 100644 index 00000000..3976b56f --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/revoke_api_key.py @@ -0,0 +1,44 @@ +""" +Revoke an API key by id. +""" + +import argparse + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_revoke_api_key(args): + req = {"operation": "revoke-api-key", "key_id": args.key_id} + if args.workspace: + req["workspace"] = args.workspace + call_iam(args.api_url, args.token, req) + print(f"Revoked key {args.key_id}") + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-revoke-api-key", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "--key-id", required=True, help="Key id to revoke", + ) + parser.add_argument( + "-w", "--workspace", default=None, + help=( + "Target workspace (admin only; defaults to caller's " + "assigned workspace)" + ), + ) + run_main(do_revoke_api_key, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index cc7dac63..d8c690b5 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -63,6 +63,7 @@ chunker-token = "trustgraph.chunking.token:run" bootstrap = "trustgraph.bootstrap.bootstrapper:run" config-svc = "trustgraph.config.service:run" flow-svc = "trustgraph.flow.service:run" +iam-svc = "trustgraph.iam.service:run" doc-embeddings-query-milvus = "trustgraph.query.doc_embeddings.milvus:run" doc-embeddings-query-pinecone = "trustgraph.query.doc_embeddings.pinecone:run" doc-embeddings-query-qdrant = "trustgraph.query.doc_embeddings.qdrant:run" diff --git a/trustgraph-flow/trustgraph/gateway/auth.py b/trustgraph-flow/trustgraph/gateway/auth.py index a693ca32..95743261 100644 --- a/trustgraph-flow/trustgraph/gateway/auth.py +++ b/trustgraph-flow/trustgraph/gateway/auth.py @@ -1,22 +1,264 @@ +""" +IAM-backed authentication for the API gateway. -class Authenticator: +Replaces the legacy GATEWAY_SECRET shared-token Authenticator. The +gateway is now stateless with respect to credentials: it either +verifies a JWT locally using the active IAM signing public key, or +resolves an API key by hash with a short local cache backed by the +IAM service. - def __init__(self, token=None, allow_all=False): +Identity returned by authenticate() is the (user_id, workspace, +roles) triple the rest of the gateway — capability checks, workspace +resolver, audit logging — needs. +""" - if not allow_all and token is None: - raise RuntimeError("Need a token") +import asyncio +import base64 +import hashlib +import json +import logging +import time +import uuid +from dataclasses import dataclass - if not allow_all and token == "": - raise RuntimeError("Need a token") +from aiohttp import web - self.token = token - self.allow_all = allow_all +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 - def permitted(self, token, roles): +from ..base.iam_client import IamClient +from ..base.metrics import ProducerMetrics, SubscriberMetrics +from ..schema import ( + IamRequest, IamResponse, + iam_request_queue, iam_response_queue, +) - if self.allow_all: return True +logger = logging.getLogger("auth") - if self.token != token: return False +API_KEY_CACHE_TTL = 60 # seconds - return True +@dataclass +class Identity: + user_id: str + workspace: str + roles: list + source: str # "api-key" | "jwt" + + +def _auth_failure(): + return web.HTTPUnauthorized( + text='{"error":"auth failure"}', + content_type="application/json", + ) + + +def _access_denied(): + return web.HTTPForbidden( + text='{"error":"access denied"}', + content_type="application/json", + ) + + +def _b64url_decode(s): + pad = "=" * (-len(s) % 4) + return base64.urlsafe_b64decode(s + pad) + + +def _verify_jwt_eddsa(token, public_pem): + """Verify an Ed25519 JWT and return its claims. Raises on any + validation failure. Refuses non-EdDSA algorithms.""" + parts = token.split(".") + if len(parts) != 3: + raise ValueError("malformed JWT") + h_b64, p_b64, s_b64 = parts + signing_input = f"{h_b64}.{p_b64}".encode("ascii") + header = json.loads(_b64url_decode(h_b64)) + if header.get("alg") != "EdDSA": + raise ValueError(f"unsupported alg: {header.get('alg')!r}") + + key = serialization.load_pem_public_key(public_pem.encode("ascii")) + if not isinstance(key, ed25519.Ed25519PublicKey): + raise ValueError("public key is not Ed25519") + + signature = _b64url_decode(s_b64) + key.verify(signature, signing_input) # raises InvalidSignature + + claims = json.loads(_b64url_decode(p_b64)) + exp = claims.get("exp") + if exp is None or exp < time.time(): + raise ValueError("expired") + return claims + + +class IamAuth: + """Resolves bearer credentials via the IAM service. + + Used by every gateway endpoint that needs authentication. Fetches + the IAM signing public key at startup (cached in memory). API + keys are resolved via the IAM service with a local hash→identity + cache (short TTL so revoked keys stop working within the TTL + window without any push mechanism).""" + + def __init__(self, backend, id="api-gateway"): + self.backend = backend + self.id = id + + # Populated at start() via IAM. + self._signing_public_pem = None + + # API-key cache: plaintext_sha256_hex -> (Identity, expires_ts) + self._key_cache = {} + self._key_cache_lock = asyncio.Lock() + + # ------------------------------------------------------------------ + # Short-lived client helper. Mirrors the pattern used by the + # bootstrap framework and AsyncProcessor: a fresh uuid suffix per + # invocation so Pulsar exclusive subscriptions don't collide with + # ghosts from prior calls. + # ------------------------------------------------------------------ + + def _make_client(self): + rr_id = str(uuid.uuid4()) + return IamClient( + backend=self.backend, + subscription=f"{self.id}--iam--{rr_id}", + consumer_name=self.id, + request_topic=iam_request_queue, + request_schema=IamRequest, + request_metrics=ProducerMetrics( + processor=self.id, flow=None, name="iam-request", + ), + response_topic=iam_response_queue, + response_schema=IamResponse, + response_metrics=SubscriberMetrics( + processor=self.id, flow=None, name="iam-response", + ), + ) + + async def _with_client(self, op): + """Open a short-lived IamClient, run ``op(client)``, close.""" + client = self._make_client() + await client.start() + try: + return await op(client) + finally: + try: + await client.stop() + except Exception: + pass + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + async def start(self, max_retries=30, retry_delay=2.0): + """Fetch the signing public key from IAM. Retries on + failure — the gateway may be starting before IAM is ready.""" + + async def _fetch(client): + return await client.get_signing_key_public() + + for attempt in range(max_retries): + try: + pem = await self._with_client(_fetch) + if pem: + self._signing_public_pem = pem + logger.info( + "IamAuth: fetched IAM signing public key " + f"({len(pem)} bytes)" + ) + return + except Exception as e: + logger.info( + f"IamAuth: waiting for IAM signing key " + f"({type(e).__name__}: {e}); " + f"retry {attempt + 1}/{max_retries}" + ) + await asyncio.sleep(retry_delay) + + # Don't prevent startup forever. A later authenticate() call + # will try again via the JWT path. + logger.warning( + "IamAuth: could not fetch IAM signing key at startup; " + "JWT validation will fail until it's available" + ) + + # ------------------------------------------------------------------ + # Authentication + # ------------------------------------------------------------------ + + async def authenticate(self, request): + """Extract and validate the Bearer credential from an HTTP + request. Returns an ``Identity``. Raises HTTPUnauthorized + (401 / "auth failure") on any failure mode — the caller + cannot distinguish missing / malformed / invalid / expired / + revoked credentials.""" + + header = request.headers.get("Authorization", "") + if not header.startswith("Bearer "): + raise _auth_failure() + token = header[len("Bearer "):].strip() + if not token: + raise _auth_failure() + + # API keys always start with "tg_". JWTs have two dots and + # no "tg_" prefix. Discriminate cheaply. + if token.startswith("tg_"): + return await self._resolve_api_key(token) + if token.count(".") == 2: + return self._verify_jwt(token) + raise _auth_failure() + + def _verify_jwt(self, token): + if not self._signing_public_pem: + raise _auth_failure() + try: + claims = _verify_jwt_eddsa(token, self._signing_public_pem) + except Exception as e: + logger.debug(f"JWT validation failed: {type(e).__name__}: {e}") + raise _auth_failure() + + sub = claims.get("sub", "") + ws = claims.get("workspace", "") + roles = list(claims.get("roles", [])) + if not sub or not ws: + raise _auth_failure() + + return Identity( + user_id=sub, workspace=ws, roles=roles, source="jwt", + ) + + async def _resolve_api_key(self, plaintext): + h = hashlib.sha256(plaintext.encode("utf-8")).hexdigest() + + cached = self._key_cache.get(h) + now = time.time() + if cached and cached[1] > now: + return cached[0] + + async with self._key_cache_lock: + cached = self._key_cache.get(h) + if cached and cached[1] > now: + return cached[0] + + try: + async def _call(client): + return await client.resolve_api_key(plaintext) + user_id, workspace, roles = await self._with_client(_call) + except Exception as e: + logger.debug( + f"API key resolution failed: " + f"{type(e).__name__}: {e}" + ) + raise _auth_failure() + + if not user_id or not workspace: + raise _auth_failure() + + identity = Identity( + user_id=user_id, workspace=workspace, + roles=list(roles), source="api-key", + ) + self._key_cache[h] = (identity, now + API_KEY_CACHE_TTL) + return identity diff --git a/trustgraph-flow/trustgraph/gateway/capabilities.py b/trustgraph-flow/trustgraph/gateway/capabilities.py new file mode 100644 index 00000000..15e25684 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/capabilities.py @@ -0,0 +1,238 @@ +""" +Capability vocabulary, role definitions, and authorisation helpers. + +See docs/tech-specs/capabilities.md for the authoritative description. +The data here is the OSS bundle table in that spec. Enterprise +editions may replace this module with their own role table; the +vocabulary (capability strings) is shared. + +Role model +---------- +A role has two dimensions: + + 1. **capability set** — which operations the role grants. + 2. **workspace scope** — which workspaces the role is active in. + +The authorisation question is: *given the caller's roles, a required +capability, and a target workspace, does any role grant the +capability AND apply to the target workspace?* + +Workspace scope values recognised here: + + - ``"assigned"`` — the role applies only to the caller's own + assigned workspace (stored on their user record). + - ``"*"`` — the role applies to every workspace. + +Enterprise editions can add richer scopes (explicit permitted-set, +patterns, etc.) without changing the wire protocol. + +Sentinels +--------- +- ``PUBLIC`` — endpoint requires no authentication. +- ``AUTHENTICATED`` — endpoint requires a valid identity, no + specific capability. +""" + +from aiohttp import web + + +PUBLIC = "__public__" +AUTHENTICATED = "__authenticated__" + + +# Capability vocabulary. Mirrors the "Capability list" tables in +# capabilities.md. Kept as a set so the gateway can fail-closed on +# an endpoint that declares an unknown capability. +KNOWN_CAPABILITIES = { + # Data plane + "agent", + "graph:read", "graph:write", + "documents:read", "documents:write", + "rows:read", "rows:write", + "llm", + "embeddings", + "mcp", + # Control plane + "config:read", "config:write", + "flows:read", "flows:write", + "users:read", "users:write", "users:admin", + "keys:self", "keys:admin", + "workspaces:admin", + "iam:admin", + "metrics:read", + "collections:read", "collections:write", + "knowledge:read", "knowledge:write", +} + + +# Capability sets used below. +_READER_CAPS = { + "agent", + "graph:read", + "documents:read", + "rows:read", + "llm", + "embeddings", + "mcp", + "config:read", + "flows:read", + "collections:read", + "knowledge:read", + "keys:self", +} + +_WRITER_CAPS = _READER_CAPS | { + "graph:write", + "documents:write", + "rows:write", + "collections:write", + "knowledge:write", +} + +_ADMIN_CAPS = _WRITER_CAPS | { + "config:write", + "flows:write", + "users:read", "users:write", "users:admin", + "keys:admin", + "workspaces:admin", + "iam:admin", + "metrics:read", +} + + +# Role definitions. Each role has a capability set and a workspace +# scope. Enterprise overrides this mapping. +ROLE_DEFINITIONS = { + "reader": { + "capabilities": _READER_CAPS, + "workspace_scope": "assigned", + }, + "writer": { + "capabilities": _WRITER_CAPS, + "workspace_scope": "assigned", + }, + "admin": { + "capabilities": _ADMIN_CAPS, + "workspace_scope": "*", + }, +} + + +def _scope_permits(role_name, target_workspace, assigned_workspace): + """Does the given role apply to ``target_workspace``?""" + role = ROLE_DEFINITIONS.get(role_name) + if role is None: + return False + scope = role["workspace_scope"] + if scope == "*": + return True + if scope == "assigned": + return target_workspace == assigned_workspace + # Future scope types (lists, patterns) extend here. + return False + + +def check(identity, capability, target_workspace=None): + """Is ``identity`` permitted to invoke ``capability`` on + ``target_workspace``? + + Passes iff some role held by the caller both (a) grants + ``capability`` and (b) is active in ``target_workspace``. + + ``target_workspace`` defaults to the caller's assigned workspace, + which makes this function usable for system-level operations and + for authenticated endpoints that don't take a workspace argument + (the call collapses to "do any of my roles grant this cap?").""" + if capability not in KNOWN_CAPABILITIES: + return False + + target = target_workspace or identity.workspace + + for role_name in identity.roles: + role = ROLE_DEFINITIONS.get(role_name) + if role is None: + continue + if capability not in role["capabilities"]: + continue + if _scope_permits(role_name, target, identity.workspace): + return True + return False + + +def access_denied(): + return web.HTTPForbidden( + text='{"error":"access denied"}', + content_type="application/json", + ) + + +def auth_failure(): + return web.HTTPUnauthorized( + text='{"error":"auth failure"}', + content_type="application/json", + ) + + +async def enforce(request, auth, capability): + """Authenticate + capability-check for endpoints that carry no + workspace dimension on the request (metrics, i18n, etc.). + + For endpoints that carry a workspace field on the body, call + :func:`enforce_workspace` *after* parsing the body to validate + the workspace and re-check the capability in that scope. Most + endpoints do both. + + - ``PUBLIC``: no authentication, returns ``None``. + - ``AUTHENTICATED``: any valid identity. + - capability string: identity must have it, checked against the + caller's assigned workspace (adequate for endpoints whose + capability is system-level, e.g. ``metrics:read``, or where + the real workspace-aware check happens in + :func:`enforce_workspace` after body parsing).""" + if capability == PUBLIC: + return None + + identity = await auth.authenticate(request) + + if capability == AUTHENTICATED: + return identity + + if not check(identity, capability): + raise access_denied() + + return identity + + +def enforce_workspace(data, identity, capability=None): + """Resolve + validate the workspace on a request body. + + - Target workspace = ``data["workspace"]`` if supplied, else the + caller's assigned workspace. + - At least one of the caller's roles must (a) be active in the + target workspace and, if ``capability`` is given, (b) grant + ``capability``. Otherwise 403. + - On success, ``data["workspace"]`` is overwritten with the + resolved value — callers can rely on the outgoing message + having the gateway's chosen workspace rather than any + caller-supplied value. + + For ``capability=None`` the workspace scope alone is checked — + useful when the body has a workspace but the endpoint already + passed its capability check (e.g. via :func:`enforce`).""" + if not isinstance(data, dict): + return data + + requested = data.get("workspace", "") + target = requested or identity.workspace + + for role_name in identity.roles: + role = ROLE_DEFINITIONS.get(role_name) + if role is None: + continue + if capability is not None and capability not in role["capabilities"]: + continue + if _scope_permits(role_name, target, identity.workspace): + data["workspace"] = target + return data + + raise access_denied() diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/iam.py b/trustgraph-flow/trustgraph/gateway/dispatch/iam.py new file mode 100644 index 00000000..386233f5 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/dispatch/iam.py @@ -0,0 +1,40 @@ + +from ... schema import IamRequest, IamResponse +from ... schema import iam_request_queue, iam_response_queue +from ... messaging import TranslatorRegistry + +from . requestor import ServiceRequestor + + +class IamRequestor(ServiceRequestor): + def __init__(self, backend, consumer, subscriber, timeout=120, + request_queue=None, response_queue=None): + + if request_queue is None: + request_queue = iam_request_queue + if response_queue is None: + response_queue = iam_response_queue + + super().__init__( + backend=backend, + consumer_name=consumer, + subscription=subscriber, + request_queue=request_queue, + response_queue=response_queue, + request_schema=IamRequest, + response_schema=IamResponse, + timeout=timeout, + ) + + self.request_translator = ( + TranslatorRegistry.get_request_translator("iam") + ) + self.response_translator = ( + TranslatorRegistry.get_response_translator("iam") + ) + + def to_request(self, body): + return self.request_translator.decode(body) + + def from_response(self, message): + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index b238bb5b..ea8770d7 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -9,6 +9,7 @@ logger = logging.getLogger(__name__) from . config import ConfigRequestor from . flow import FlowRequestor +from . iam import IamRequestor from . librarian import LibrarianRequestor from . knowledge import KnowledgeRequestor from . collection_management import CollectionManagementRequestor @@ -72,6 +73,7 @@ request_response_dispatchers = { global_dispatchers = { "config": ConfigRequestor, "flow": FlowRequestor, + "iam": IamRequestor, "librarian": LibrarianRequestor, "knowledge": KnowledgeRequestor, "collection-management": CollectionManagementRequestor, @@ -105,13 +107,31 @@ class DispatcherWrapper: class DispatcherManager: - def __init__(self, backend, config_receiver, prefix="api-gateway", - queue_overrides=None): + def __init__(self, backend, config_receiver, auth, + prefix="api-gateway", queue_overrides=None): + """ + ``auth`` is required. It flows into the Mux for first-frame + WebSocket authentication and into downstream dispatcher + construction. There is no permissive default — constructing + a DispatcherManager without an authenticator would be a + silent downgrade to no-auth on the socket path. + """ + if auth is None: + raise ValueError( + "DispatcherManager requires an 'auth' argument — there " + "is no no-auth mode" + ) + self.backend = backend self.config_receiver = config_receiver self.config_receiver.add_handler(self) self.prefix = prefix + # Gateway IamAuth — used by the socket Mux for first-frame + # auth and by any dispatcher that needs to resolve caller + # identity out-of-band. + self.auth = auth + # Store queue overrides for global services # Format: {"config": {"request": "...", "response": "..."}, ...} self.queue_overrides = queue_overrides or {} @@ -163,6 +183,15 @@ class DispatcherManager: def dispatch_global_service(self): return DispatcherWrapper(self.process_global_service) + def dispatch_auth_iam(self): + """Pre-configured IAM dispatcher for the gateway's auth + endpoints (login, bootstrap, change-password). Pins the + kind to ``iam`` so these handlers don't have to supply URL + params the global dispatcher would expect.""" + async def _process(data, responder): + return await self.invoke_global_service(data, responder, "iam") + return DispatcherWrapper(_process) + def dispatch_core_export(self): return DispatcherWrapper(self.process_core_export) @@ -314,7 +343,10 @@ class DispatcherManager: async def process_socket(self, ws, running, params): - dispatcher = Mux(self, ws, running) + # The mux self-authenticates via the first-frame protocol; + # pass the gateway's IamAuth so it can validate tokens + # without reaching back into the endpoint layer. + dispatcher = Mux(self, ws, running, auth=self.auth) return dispatcher diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py index 3d610dca..013cd1ea 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py @@ -16,11 +16,28 @@ MAX_QUEUE_SIZE = 10 class Mux: - def __init__(self, dispatcher_manager, ws, running): + def __init__(self, dispatcher_manager, ws, running, auth): + """ + ``auth`` is required — the Mux implements the first-frame + auth protocol described in ``iam.md`` and will refuse any + non-auth frame until an ``auth-ok`` has been issued. There + is no no-auth mode. + """ + if auth is None: + raise ValueError( + "Mux requires an 'auth' argument — there is no " + "no-auth mode" + ) self.dispatcher_manager = dispatcher_manager self.ws = ws self.running = running + self.auth = auth + + # Authenticated identity, populated by the first-frame auth + # protocol. ``None`` means the socket is not yet + # authenticated; any non-auth frame is refused. + self.identity = None self.q = asyncio.Queue(maxsize=MAX_QUEUE_SIZE) @@ -31,6 +48,41 @@ class Mux: if self.ws: await self.ws.close() + async def _handle_auth_frame(self, data): + """Process a ``{"type": "auth", "token": "..."}`` frame. + On success, updates ``self.identity`` and returns an + ``auth-ok`` response frame. On failure, returns the masked + auth-failure frame. Never raises — auth failures keep the + socket open so the client can retry without reconnecting + (important for browsers, which treat a handshake-time 401 + as terminal).""" + token = data.get("token", "") + if not token: + await self.ws.send_json({ + "type": "auth-failed", + "error": "auth failure", + }) + return + + class _Shim: + def __init__(self, tok): + self.headers = {"Authorization": f"Bearer {tok}"} + + try: + identity = await self.auth.authenticate(_Shim(token)) + except Exception: + await self.ws.send_json({ + "type": "auth-failed", + "error": "auth failure", + }) + return + + self.identity = identity + await self.ws.send_json({ + "type": "auth-ok", + "workspace": identity.workspace, + }) + async def receive(self, msg): request_id = None @@ -38,6 +90,16 @@ class Mux: try: data = msg.json() + + # In-band auth protocol: the client sends + # ``{"type": "auth", "token": "..."}`` as its first frame + # (and any time it wants to re-auth: JWT refresh, token + # rotation, etc). Auth is always required on a Mux — + # there is no no-auth mode. + if isinstance(data, dict) and data.get("type") == "auth": + await self._handle_auth_frame(data) + return + request_id = data.get("id") if "request" not in data: @@ -46,9 +108,49 @@ class Mux: if "id" not in data: raise RuntimeError("Bad message") + # Reject all non-auth frames until an ``auth-ok`` has + # been issued. + if self.identity is None: + await self.ws.send_json({ + "id": request_id, + "error": { + "message": "auth failure", + "type": "auth-required", + }, + "complete": True, + }) + return + + # Workspace resolution. Role workspace scope determines + # which target workspaces are permitted. The resolved + # value is written to both the envelope and the inner + # request payload so clients don't have to repeat it + # per-message (same convenience HTTP callers get via + # enforce_workspace). + from ..capabilities import enforce_workspace + from aiohttp import web as _web + + try: + enforce_workspace(data, self.identity) + inner = data.get("request") + if isinstance(inner, dict): + enforce_workspace(inner, self.identity) + except _web.HTTPForbidden: + await self.ws.send_json({ + "id": request_id, + "error": { + "message": "access denied", + "type": "access-denied", + }, + "complete": True, + }) + return + + workspace = data["workspace"] + await self.q.put(( data["id"], - data.get("workspace", "default"), + workspace, data.get("flow"), data["service"], data["request"] diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/auth_endpoints.py b/trustgraph-flow/trustgraph/gateway/endpoint/auth_endpoints.py new file mode 100644 index 00000000..6037fc4b --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/endpoint/auth_endpoints.py @@ -0,0 +1,115 @@ +""" +Gateway auth endpoints. + +Three dedicated paths: + POST /api/v1/auth/login — unauthenticated; username/password → JWT + POST /api/v1/auth/bootstrap — unauthenticated; IAM bootstrap op + POST /api/v1/auth/change-password — authenticated; any role + +These are the only IAM-surface operations that can be reached from +outside. Everything else routes through ``/api/v1/iam`` gated by +``users:admin``. +""" + +import logging + +from aiohttp import web + +from .. capabilities import enforce, PUBLIC, AUTHENTICATED + +logger = logging.getLogger("auth-endpoints") +logger.setLevel(logging.INFO) + + +class AuthEndpoints: + """Groups the three auth-surface handlers. Each forwards to the + IAM service via the existing ``IamRequestor`` dispatcher.""" + + def __init__(self, iam_dispatcher, auth): + self.iam = iam_dispatcher + self.auth = auth + + async def start(self): + pass + + def add_routes(self, app): + app.add_routes([ + web.post("/api/v1/auth/login", self.login), + web.post("/api/v1/auth/bootstrap", self.bootstrap), + web.post( + "/api/v1/auth/change-password", + self.change_password, + ), + ]) + + async def _forward(self, body): + async def responder(x, fin): + pass + return await self.iam.process(body, responder) + + async def login(self, request): + """Public. Accepts {username, password, workspace?}. Returns + {jwt, jwt_expires} on success; IAM's masked auth failure on + anything else.""" + await enforce(request, self.auth, PUBLIC) + try: + body = await request.json() + except Exception: + return web.json_response( + {"error": "invalid json"}, status=400, + ) + req = { + "operation": "login", + "username": body.get("username", ""), + "password": body.get("password", ""), + "workspace": body.get("workspace", ""), + } + resp = await self._forward(req) + if "error" in resp: + return web.json_response( + {"error": "auth failure"}, status=401, + ) + return web.json_response(resp) + + async def bootstrap(self, request): + """Public. Valid only when IAM is running in bootstrap mode + with empty tables. In every other case the IAM service + returns a masked auth-failure.""" + await enforce(request, self.auth, PUBLIC) + resp = await self._forward({"operation": "bootstrap"}) + if "error" in resp: + return web.json_response( + {"error": "auth failure"}, status=401, + ) + return web.json_response(resp) + + async def change_password(self, request): + """Authenticated (any role). Accepts {current_password, + new_password}; user_id is taken from the authenticated + identity — the caller cannot change someone else's password + this way (reset-password is the admin path).""" + identity = await enforce(request, self.auth, AUTHENTICATED) + try: + body = await request.json() + except Exception: + return web.json_response( + {"error": "invalid json"}, status=400, + ) + req = { + "operation": "change-password", + "user_id": identity.user_id, + "password": body.get("current_password", ""), + "new_password": body.get("new_password", ""), + } + resp = await self._forward(req) + if "error" in resp: + err_type = resp.get("error", {}).get("type", "") + if err_type == "auth-failed": + return web.json_response( + {"error": "auth failure"}, status=401, + ) + return web.json_response( + {"error": resp.get("error", {}).get("message", "error")}, + status=400, + ) + return web.json_response(resp) diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/constant_endpoint.py b/trustgraph-flow/trustgraph/gateway/endpoint/constant_endpoint.py index 58ba1738..ee9c0447 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/constant_endpoint.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/constant_endpoint.py @@ -1,28 +1,27 @@ -import asyncio -from aiohttp import web -import uuid import logging +from aiohttp import web + +from .. capabilities import enforce, enforce_workspace + logger = logging.getLogger("endpoint") logger.setLevel(logging.INFO) + class ConstantEndpoint: - def __init__(self, endpoint_path, auth, dispatcher): + def __init__(self, endpoint_path, auth, dispatcher, capability): self.path = endpoint_path - self.auth = auth - self.operation = "service" - + self.capability = capability self.dispatcher = dispatcher async def start(self): pass def add_routes(self, app): - app.add_routes([ web.post(self.path, self.handle), ]) @@ -31,22 +30,14 @@ class ConstantEndpoint: logger.debug(f"Processing request: {request.path}") - try: - ht = request.headers["Authorization"] - tokens = ht.split(" ", 2) - if tokens[0] != "Bearer": - return web.HTTPUnauthorized() - token = tokens[1] - except: - token = "" - - if not self.auth.permitted(token, self.operation): - return web.HTTPUnauthorized() + identity = await enforce(request, self.auth, self.capability) try: - data = await request.json() + if identity is not None: + enforce_workspace(data, identity) + async def responder(x, fin): pass @@ -54,10 +45,8 @@ class ConstantEndpoint: return web.json_response(resp) + except web.HTTPException: + raise except Exception as e: - logging.error(f"Exception: {e}") - - return web.json_response( - { "error": str(e) } - ) - + logger.error(f"Exception: {e}", exc_info=True) + return web.json_response({"error": str(e)}) diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/i18n.py b/trustgraph-flow/trustgraph/gateway/endpoint/i18n.py index b949a499..f28f293d 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/i18n.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/i18n.py @@ -4,16 +4,18 @@ from aiohttp import web from trustgraph.i18n import get_language_pack +from .. capabilities import enforce + logger = logging.getLogger("endpoint") logger.setLevel(logging.INFO) class I18nPackEndpoint: - def __init__(self, endpoint_path: str, auth): + def __init__(self, endpoint_path: str, auth, capability): self.path = endpoint_path self.auth = auth - self.operation = "service" + self.capability = capability async def start(self): pass @@ -26,26 +28,13 @@ class I18nPackEndpoint: async def handle(self, request): logger.debug(f"Processing i18n pack request: {request.path}") - token = "" - try: - ht = request.headers["Authorization"] - tokens = ht.split(" ", 2) - if tokens[0] != "Bearer": - return web.HTTPUnauthorized() - token = tokens[1] - except Exception: - token = "" - - if not self.auth.permitted(token, self.operation): - return web.HTTPUnauthorized() + await enforce(request, self.auth, self.capability) lang = request.match_info.get("lang") or "en" - # This is a path traversal defense, and is a critical sec defense. - # Do not remove! + # Path-traversal defense — critical, do not remove. if "/" in lang or ".." in lang: return web.HTTPBadRequest(reason="Invalid language code") pack = get_language_pack(lang) - return web.json_response(pack) diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/manager.py b/trustgraph-flow/trustgraph/gateway/endpoint/manager.py index fb8b0b76..69b11e07 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/manager.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/manager.py @@ -8,72 +8,278 @@ from . variable_endpoint import VariableEndpoint from . socket import SocketEndpoint from . metrics import MetricsEndpoint from . i18n import I18nPackEndpoint +from . auth_endpoints import AuthEndpoints + +from .. capabilities import PUBLIC, AUTHENTICATED from .. dispatch.manager import DispatcherManager + +# Capability required for each kind on the /api/v1/{kind} generic +# endpoint (global services). Coarse gating — the IAM bundle split +# of "read vs write" per admin subsystem is not applied here because +# this endpoint forwards an opaque operation in the body. Writes +# are the upper bound on what the endpoint can do, so we gate on +# the write/admin capability. +GLOBAL_KIND_CAPABILITY = { + "config": "config:write", + "flow": "flows:write", + "librarian": "documents:write", + "knowledge": "knowledge:write", + "collection-management": "collections:write", + # IAM endpoints land on /api/v1/iam and require the admin bundle. + # Login / bootstrap / change-password are served by + # AuthEndpoints, which handle their own gating (PUBLIC / + # AUTHENTICATED). + "iam": "users:admin", +} + + +# Capability required for each kind on the +# /api/v1/flow/{flow}/service/{kind} endpoint (per-flow data-plane). +FLOW_KIND_CAPABILITY = { + "agent": "agent", + "text-completion": "llm", + "prompt": "llm", + "mcp-tool": "mcp", + "graph-rag": "graph:read", + "document-rag": "documents:read", + "embeddings": "embeddings", + "graph-embeddings": "graph:read", + "document-embeddings": "documents:read", + "triples": "graph:read", + "rows": "rows:read", + "nlp-query": "rows:read", + "structured-query": "rows:read", + "structured-diag": "rows:read", + "row-embeddings": "rows:read", + "sparql": "graph:read", +} + + +# Capability for the streaming flow import/export endpoints, +# keyed by the "kind" URL segment. +FLOW_IMPORT_CAPABILITY = { + "triples": "graph:write", + "graph-embeddings": "graph:write", + "document-embeddings": "documents:write", + "entity-contexts": "documents:write", + "rows": "rows:write", +} + +FLOW_EXPORT_CAPABILITY = { + "triples": "graph:read", + "graph-embeddings": "graph:read", + "document-embeddings": "documents:read", + "entity-contexts": "documents:read", +} + + +from .. capabilities import enforce, enforce_workspace +import logging as _mgr_logging +_mgr_logger = _mgr_logging.getLogger("endpoint") + + +class _RoutedVariableEndpoint: + """HTTP endpoint whose required capability is looked up per + request from the URL's ``kind`` parameter. Used for the two + generic dispatch paths (``/api/v1/{kind}`` and + ``/api/v1/flow/{flow}/service/{kind}``). Self-contained rather + than subclassing ``VariableEndpoint`` to avoid mutating shared + state across concurrent requests.""" + + def __init__(self, endpoint_path, auth, dispatcher, capability_map): + self.path = endpoint_path + self.auth = auth + self.dispatcher = dispatcher + self._capability_map = capability_map + + async def start(self): + pass + + def add_routes(self, app): + app.add_routes([web.post(self.path, self.handle)]) + + async def handle(self, request): + kind = request.match_info.get("kind", "") + cap = self._capability_map.get(kind) + if cap is None: + return web.json_response( + {"error": "unknown kind"}, status=404, + ) + + identity = await enforce(request, self.auth, cap) + + try: + data = await request.json() + if identity is not None: + enforce_workspace(data, identity) + + async def responder(x, fin): + pass + + resp = await self.dispatcher.process( + data, responder, request.match_info, + ) + return web.json_response(resp) + + except web.HTTPException: + raise + except Exception as e: + _mgr_logger.error(f"Exception: {e}", exc_info=True) + return web.json_response({"error": str(e)}) + + +class _RoutedSocketEndpoint: + """WebSocket endpoint whose required capability is looked up per + request from the URL's ``kind`` parameter. Used for the flow + import/export streaming endpoints.""" + + def __init__(self, endpoint_path, auth, dispatcher, capability_map): + self.path = endpoint_path + self.auth = auth + self.dispatcher = dispatcher + self._capability_map = capability_map + + async def start(self): + pass + + def add_routes(self, app): + app.add_routes([web.get(self.path, self.handle)]) + + async def handle(self, request): + from .. capabilities import check, auth_failure, access_denied + + kind = request.match_info.get("kind", "") + cap = self._capability_map.get(kind) + if cap is None: + return web.json_response( + {"error": "unknown kind"}, status=404, + ) + + token = request.query.get("token", "") + if not token: + return auth_failure() + + from . socket import _QueryTokenRequest + try: + identity = await self.auth.authenticate( + _QueryTokenRequest(token) + ) + except web.HTTPException as e: + return e + if not check(identity, cap): + return access_denied() + + # Delegate the websocket handling to a standalone SocketEndpoint + # with the resolved capability, bypassing the per-request mutation + # concern by instantiating fresh state. + ws_ep = SocketEndpoint( + endpoint_path=self.path, + auth=self.auth, + dispatcher=self.dispatcher, + capability=cap, + ) + return await ws_ep.handle(request) + + class EndpointManager: def __init__( - self, dispatcher_manager, auth, prometheus_url, timeout=600 + self, dispatcher_manager, auth, prometheus_url, timeout=600, ): self.dispatcher_manager = dispatcher_manager self.timeout = timeout - self.services = { - } - self.endpoints = [ + + # Auth surface — public / authenticated-any. Must come + # before the generic /api/v1/{kind} routes to win the + # match for /api/v1/auth/* paths. aiohttp routes in + # registration order, so we prepend here. + AuthEndpoints( + iam_dispatcher=dispatcher_manager.dispatch_auth_iam(), + auth=auth, + ), + I18nPackEndpoint( - endpoint_path = "/api/v1/i18n/packs/{lang}", - auth = auth, + endpoint_path="/api/v1/i18n/packs/{lang}", + auth=auth, + capability=PUBLIC, ), MetricsEndpoint( - endpoint_path = "/api/metrics", - prometheus_url = prometheus_url, - auth = auth, + endpoint_path="/api/metrics", + prometheus_url=prometheus_url, + auth=auth, + capability="metrics:read", ), - VariableEndpoint( - endpoint_path = "/api/v1/{kind}", auth = auth, - dispatcher = dispatcher_manager.dispatch_global_service(), + + # Global services: capability chosen per-kind. + _RoutedVariableEndpoint( + endpoint_path="/api/v1/{kind}", + auth=auth, + dispatcher=dispatcher_manager.dispatch_global_service(), + capability_map=GLOBAL_KIND_CAPABILITY, ), + + # /api/v1/socket: WebSocket handshake accepts + # unconditionally; the Mux dispatcher runs the + # first-frame auth protocol. Handshake-time 401s break + # browser reconnection, so authentication is always + # in-band for this endpoint. SocketEndpoint( - endpoint_path = "/api/v1/socket", - auth = auth, - dispatcher = dispatcher_manager.dispatch_socket() + endpoint_path="/api/v1/socket", + auth=auth, + dispatcher=dispatcher_manager.dispatch_socket(), + capability=AUTHENTICATED, # informational only; bypassed + in_band_auth=True, ), - VariableEndpoint( - endpoint_path = "/api/v1/flow/{flow}/service/{kind}", - auth = auth, - dispatcher = dispatcher_manager.dispatch_flow_service(), + + # Per-flow request/response services — capability per kind. + _RoutedVariableEndpoint( + endpoint_path="/api/v1/flow/{flow}/service/{kind}", + auth=auth, + dispatcher=dispatcher_manager.dispatch_flow_service(), + capability_map=FLOW_KIND_CAPABILITY, ), - SocketEndpoint( - endpoint_path = "/api/v1/flow/{flow}/import/{kind}", - auth = auth, - dispatcher = dispatcher_manager.dispatch_flow_import() + + # Per-flow streaming import/export — capability per kind. + _RoutedSocketEndpoint( + endpoint_path="/api/v1/flow/{flow}/import/{kind}", + auth=auth, + dispatcher=dispatcher_manager.dispatch_flow_import(), + capability_map=FLOW_IMPORT_CAPABILITY, ), - SocketEndpoint( - endpoint_path = "/api/v1/flow/{flow}/export/{kind}", - auth = auth, - dispatcher = dispatcher_manager.dispatch_flow_export() + _RoutedSocketEndpoint( + endpoint_path="/api/v1/flow/{flow}/export/{kind}", + auth=auth, + dispatcher=dispatcher_manager.dispatch_flow_export(), + capability_map=FLOW_EXPORT_CAPABILITY, + ), + + StreamEndpoint( + endpoint_path="/api/v1/import-core", + auth=auth, + method="POST", + dispatcher=dispatcher_manager.dispatch_core_import(), + # Cross-subject import — require the admin bundle via a + # single representative capability. + capability="users:admin", ), StreamEndpoint( - endpoint_path = "/api/v1/import-core", - auth = auth, - method = "POST", - dispatcher = dispatcher_manager.dispatch_core_import(), + endpoint_path="/api/v1/export-core", + auth=auth, + method="GET", + dispatcher=dispatcher_manager.dispatch_core_export(), + capability="users:admin", ), StreamEndpoint( - endpoint_path = "/api/v1/export-core", - auth = auth, - method = "GET", - dispatcher = dispatcher_manager.dispatch_core_export(), - ), - StreamEndpoint( - endpoint_path = "/api/v1/document-stream", - auth = auth, - method = "GET", - dispatcher = dispatcher_manager.dispatch_document_stream(), + endpoint_path="/api/v1/document-stream", + auth=auth, + method="GET", + dispatcher=dispatcher_manager.dispatch_document_stream(), + capability="documents:read", ), ] @@ -84,4 +290,3 @@ class EndpointManager: async def start(self): for ep in self.endpoints: await ep.start() - diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/metrics.py b/trustgraph-flow/trustgraph/gateway/endpoint/metrics.py index 903a199c..6832d1e3 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/metrics.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/metrics.py @@ -10,17 +10,19 @@ import asyncio import uuid import logging +from .. capabilities import enforce + logger = logging.getLogger("endpoint") logger.setLevel(logging.INFO) class MetricsEndpoint: - def __init__(self, prometheus_url, endpoint_path, auth): + def __init__(self, prometheus_url, endpoint_path, auth, capability): self.prometheus_url = prometheus_url self.path = endpoint_path self.auth = auth - self.operation = "service" + self.capability = capability async def start(self): pass @@ -35,17 +37,7 @@ class MetricsEndpoint: logger.debug(f"Processing metrics request: {request.path}") - try: - ht = request.headers["Authorization"] - tokens = ht.split(" ", 2) - if tokens[0] != "Bearer": - return web.HTTPUnauthorized() - token = tokens[1] - except: - token = "" - - if not self.auth.permitted(token, self.operation): - return web.HTTPUnauthorized() + await enforce(request, self.auth, self.capability) path = request.match_info["path"] url = ( diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py index 9065761c..08629ea2 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py @@ -4,6 +4,9 @@ from aiohttp import web, WSMsgType import logging from .. running import Running +from .. capabilities import ( + PUBLIC, AUTHENTICATED, check, auth_failure, access_denied, +) logger = logging.getLogger("socket") logger.setLevel(logging.INFO) @@ -11,12 +14,25 @@ logger.setLevel(logging.INFO) class SocketEndpoint: def __init__( - self, endpoint_path, auth, dispatcher, + self, endpoint_path, auth, dispatcher, capability, + in_band_auth=False, ): + """ + ``in_band_auth=True`` skips the handshake-time auth check. + The WebSocket handshake always succeeds; the dispatcher is + expected to gate itself via the first-frame auth protocol + (see ``Mux``). + + This avoids the browser problem where a 401 on the handshake + is treated as permanent and prevents reconnection, and lets + long-lived sockets refresh their credential mid-session by + sending a new auth frame. + """ self.path = endpoint_path self.auth = auth - self.operation = "socket" + self.capability = capability + self.in_band_auth = in_band_auth self.dispatcher = dispatcher @@ -61,15 +77,29 @@ class SocketEndpoint: raise async def handle(self, request): - """Enhanced handler with better cleanup""" - try: - token = request.query['token'] - except: - token = "" + """Enhanced handler with better cleanup. + + Auth: WebSocket clients pass the bearer token on the + ``?token=...`` query string; we wrap it into a synthetic + Authorization header before delegating to the standard auth + path so the IAM-backed flow (JWT / API key) applies uniformly. + The first-frame auth protocol described in the IAM spec is + a future upgrade.""" + + if not self.in_band_auth and self.capability != PUBLIC: + token = request.query.get("token", "") + if not token: + return auth_failure() + try: + identity = await self.auth.authenticate( + _QueryTokenRequest(token) + ) + except web.HTTPException as e: + return e + if self.capability != AUTHENTICATED: + if not check(identity, self.capability): + return access_denied() - if not self.auth.permitted(token, self.operation): - return web.HTTPUnauthorized() - # 50MB max message size ws = web.WebSocketResponse(max_msg_size=52428800) @@ -150,3 +180,11 @@ class SocketEndpoint: web.get(self.path, self.handle), ]) + +class _QueryTokenRequest: + """Minimal shim that exposes headers["Authorization"] to + IamAuth.authenticate(), derived from a query-string token.""" + + def __init__(self, token): + self.headers = {"Authorization": f"Bearer {token}"} + diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/stream_endpoint.py b/trustgraph-flow/trustgraph/gateway/endpoint/stream_endpoint.py index 38d8846f..7b0c4692 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/stream_endpoint.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/stream_endpoint.py @@ -1,82 +1,64 @@ -import asyncio -from aiohttp import web import logging +from aiohttp import web + +from .. capabilities import enforce + logger = logging.getLogger("endpoint") logger.setLevel(logging.INFO) + class StreamEndpoint: - def __init__(self, endpoint_path, auth, dispatcher, method="POST"): - + def __init__( + self, endpoint_path, auth, dispatcher, capability, method="POST", + ): self.path = endpoint_path - self.auth = auth - self.operation = "service" + self.capability = capability self.method = method - self.dispatcher = dispatcher async def start(self): pass def add_routes(self, app): - if self.method == "POST": - app.add_routes([ - web.post(self.path, self.handle), - ]) + app.add_routes([web.post(self.path, self.handle)]) elif self.method == "GET": - app.add_routes([ - web.get(self.path, self.handle), - ]) + app.add_routes([web.get(self.path, self.handle)]) else: - raise RuntimeError("Bad method" + self.method) + raise RuntimeError("Bad method " + self.method) async def handle(self, request): logger.debug(f"Processing request: {request.path}") - try: - ht = request.headers["Authorization"] - tokens = ht.split(" ", 2) - if tokens[0] != "Bearer": - return web.HTTPUnauthorized() - token = tokens[1] - except: - token = "" - - if not self.auth.permitted(token, self.operation): - return web.HTTPUnauthorized() + await enforce(request, self.auth, self.capability) try: - data = request.content async def error(err): - return web.HTTPInternalServerError(text = err) + return web.HTTPInternalServerError(text=err) async def ok( - status=200, reason="OK", type="application/octet-stream" + status=200, reason="OK", + type="application/octet-stream", ): response = web.StreamResponse( - status = status, reason = reason, - headers = {"Content-Type": type} + status=status, reason=reason, + headers={"Content-Type": type}, ) await response.prepare(request) return response - resp = await self.dispatcher.process( - data, error, ok, request - ) - + resp = await self.dispatcher.process(data, error, ok, request) return resp + except web.HTTPException: + raise except Exception as e: - logging.error(f"Exception: {e}") - - return web.json_response( - { "error": str(e) } - ) - + logger.error(f"Exception: {e}", exc_info=True) + return web.json_response({"error": str(e)}) diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/variable_endpoint.py b/trustgraph-flow/trustgraph/gateway/endpoint/variable_endpoint.py index 608de71b..5e0d9d21 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/variable_endpoint.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/variable_endpoint.py @@ -1,27 +1,27 @@ -import asyncio -from aiohttp import web import logging +from aiohttp import web + +from .. capabilities import enforce, enforce_workspace + logger = logging.getLogger("endpoint") logger.setLevel(logging.INFO) + class VariableEndpoint: - def __init__(self, endpoint_path, auth, dispatcher): + def __init__(self, endpoint_path, auth, dispatcher, capability): self.path = endpoint_path - self.auth = auth - self.operation = "service" - + self.capability = capability self.dispatcher = dispatcher async def start(self): pass def add_routes(self, app): - app.add_routes([ web.post(self.path, self.handle), ]) @@ -30,35 +30,25 @@ class VariableEndpoint: logger.debug(f"Processing request: {request.path}") - try: - ht = request.headers["Authorization"] - tokens = ht.split(" ", 2) - if tokens[0] != "Bearer": - return web.HTTPUnauthorized() - token = tokens[1] - except: - token = "" - - if not self.auth.permitted(token, self.operation): - return web.HTTPUnauthorized() + identity = await enforce(request, self.auth, self.capability) try: - data = await request.json() + if identity is not None: + enforce_workspace(data, identity) + async def responder(x, fin): pass resp = await self.dispatcher.process( - data, responder, request.match_info + data, responder, request.match_info, ) return web.json_response(resp) + except web.HTTPException: + raise except Exception as e: - logging.error(f"Exception: {e}") - - return web.json_response( - { "error": str(e) } - ) - + logger.error(f"Exception: {e}", exc_info=True) + return web.json_response({"error": str(e)}) diff --git a/trustgraph-flow/trustgraph/gateway/service.py b/trustgraph-flow/trustgraph/gateway/service.py index 4e465bf7..f75f3b25 100755 --- a/trustgraph-flow/trustgraph/gateway/service.py +++ b/trustgraph-flow/trustgraph/gateway/service.py @@ -12,7 +12,7 @@ import os from trustgraph.base.logging import setup_logging, add_logging_args from trustgraph.base.pubsub import get_pubsub, add_pubsub_args -from . auth import Authenticator +from . auth import IamAuth from . config.receiver import ConfigReceiver from . dispatch.manager import DispatcherManager @@ -35,7 +35,6 @@ default_prometheus_url = os.getenv("PROMETHEUS_URL", "http://prometheus:9090") default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None) default_timeout = 600 default_port = 8088 -default_api_token = os.getenv("GATEWAY_SECRET", "") class Api: @@ -60,13 +59,14 @@ class Api: if not self.prometheus_url.endswith("/"): self.prometheus_url += "/" - api_token = config.get("api_token", default_api_token) - - # Token not set, or token equal empty string means no auth - if api_token: - self.auth = Authenticator(token=api_token) - else: - self.auth = Authenticator(allow_all=True) + # IAM-backed authentication. The legacy GATEWAY_SECRET + # shared-token path has been removed — there is no + # "open for everyone" fallback. The gateway cannot + # authenticate any request until IAM is reachable. + self.auth = IamAuth( + backend=self.pubsub_backend, + id=config.get("id", "api-gateway"), + ) self.config_receiver = ConfigReceiver(self.pubsub_backend) @@ -118,6 +118,7 @@ class Api: config_receiver = self.config_receiver, prefix = "gateway", queue_overrides = queue_overrides, + auth = self.auth, ) self.endpoint_manager = EndpointManager( @@ -132,12 +133,18 @@ class Api: ] async def app_factory(self): - + self.app = web.Application( middlewares=[], client_max_size=256 * 1024 * 1024 ) + # Fetch IAM signing public key before accepting traffic. + # Blocks for a bounded retry window; the gateway starts even + # if IAM is still unreachable (JWT validation will 401 until + # the key is available). + await self.auth.start() + await self.config_receiver.start() for ep in self.endpoints: @@ -189,12 +196,6 @@ def run(): help=f'API request timeout in seconds (default: {default_timeout})', ) - parser.add_argument( - '--api-token', - default=default_api_token, - help=f'Secret API token (default: no auth)', - ) - add_logging_args(parser) parser.add_argument( diff --git a/trustgraph-flow/trustgraph/iam/__init__.py b/trustgraph-flow/trustgraph/iam/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/trustgraph-flow/trustgraph/iam/service/__init__.py b/trustgraph-flow/trustgraph/iam/service/__init__.py new file mode 100644 index 00000000..98f4d9da --- /dev/null +++ b/trustgraph-flow/trustgraph/iam/service/__init__.py @@ -0,0 +1 @@ +from . service import * diff --git a/trustgraph-flow/trustgraph/iam/service/__main__.py b/trustgraph-flow/trustgraph/iam/service/__main__.py new file mode 100644 index 00000000..a731dd63 --- /dev/null +++ b/trustgraph-flow/trustgraph/iam/service/__main__.py @@ -0,0 +1,4 @@ + +from . service import run + +run() diff --git a/trustgraph-flow/trustgraph/iam/service/iam.py b/trustgraph-flow/trustgraph/iam/service/iam.py new file mode 100644 index 00000000..6e7c7aa5 --- /dev/null +++ b/trustgraph-flow/trustgraph/iam/service/iam.py @@ -0,0 +1,1132 @@ +""" +IAM business logic. Handles ``IamRequest`` messages and builds +``IamResponse`` messages. Does not concern itself with transport. + +See docs/tech-specs/iam-protocol.md for the wire-level contract and +docs/tech-specs/iam.md for the surrounding architecture. +""" + +import asyncio +import base64 +import datetime +import hashlib +import json +import logging +import os +import secrets +import uuid + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 + +from trustgraph.schema import ( + IamResponse, Error, + UserRecord, WorkspaceRecord, ApiKeyRecord, +) + +from ... tables.iam import IamTableStore + +logger = logging.getLogger(__name__) + + +DEFAULT_WORKSPACE = "default" +BOOTSTRAP_ADMIN_USERNAME = "admin" +BOOTSTRAP_ADMIN_NAME = "Administrator" + +PBKDF2_ITERATIONS = 600_000 +API_KEY_PREFIX = "tg_" +API_KEY_RANDOM_BYTES = 24 + +JWT_ISSUER = "trustgraph-iam" +JWT_TTL_SECONDS = 3600 + + +def _now_iso(): + return datetime.datetime.now(datetime.timezone.utc).isoformat() + + +def _now_dt(): + return datetime.datetime.now(datetime.timezone.utc) + + +def _iso(dt): + if dt is None: + return "" + if isinstance(dt, str): + return dt + if dt.tzinfo is None: + dt = dt.replace(tzinfo=datetime.timezone.utc) + return dt.isoformat() + + +def _hash_password(password): + """Return an encoded PBKDF2-SHA-256 hash of ``password``. + + Format: ``pbkdf2-sha256$$$``. Stored + verbatim in the password_hash column so the algorithm and cost + can be evolved later (new rows get a new prefix; old rows are + verified with their own parameters). + """ + salt = os.urandom(16) + dk = hashlib.pbkdf2_hmac( + "sha256", password.encode("utf-8"), salt, PBKDF2_ITERATIONS, + ) + return ( + f"pbkdf2-sha256${PBKDF2_ITERATIONS}" + f"${base64.b64encode(salt).decode('ascii')}" + f"${base64.b64encode(dk).decode('ascii')}" + ) + + +def _verify_password(password, encoded): + """Constant-time verify ``password`` against an encoded hash.""" + try: + algo, iters, b64_salt, b64_hash = encoded.split("$") + except ValueError: + return False + if algo != "pbkdf2-sha256": + return False + try: + iters = int(iters) + salt = base64.b64decode(b64_salt) + target = base64.b64decode(b64_hash) + except Exception: + return False + dk = hashlib.pbkdf2_hmac( + "sha256", password.encode("utf-8"), salt, iters, + ) + return secrets.compare_digest(dk, target) + + +def _generate_api_key(): + """Return a fresh API-key plaintext of the form ``tg_``.""" + return API_KEY_PREFIX + secrets.token_urlsafe(API_KEY_RANDOM_BYTES) + + +def _hash_api_key(plaintext): + """SHA-256 hex digest of an API key plaintext. Used as the + primary key in ``iam_api_keys`` so ``resolve-api-key`` is O(1).""" + return hashlib.sha256(plaintext.encode("utf-8")).hexdigest() + + +def _err(type, message): + return IamResponse(error=Error(type=type, message=message)) + + +def _parse_expires(s): + if not s: + return None + try: + return datetime.datetime.fromisoformat(s) + except Exception: + return None + + +def _b64url(data): + """URL-safe base64 encode without padding, as required by JWT.""" + return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") + + +def _generate_signing_keypair(): + """Return (kid, private_pem, public_pem) for a fresh Ed25519 + keypair. Ed25519 / EdDSA: small (32-byte public key), fast, + deterministic, side-channel-resistant by construction, free of + NIST-curve baggage.""" + key = ed25519.Ed25519PrivateKey.generate() + private_pem = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode("ascii") + public_pem = key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode("ascii") + kid = uuid.uuid4().hex[:16] + return kid, private_pem, public_pem + + +def _sign_jwt(kid, private_pem, claims): + """Produce a compact-serialisation EdDSA (Ed25519) JWT for + ``claims``.""" + key = serialization.load_pem_private_key( + private_pem.encode("ascii"), password=None, + ) + if not isinstance(key, ed25519.Ed25519PrivateKey): + raise RuntimeError( + f"signing key is not Ed25519: {type(key).__name__}" + ) + + header = {"alg": "EdDSA", "typ": "JWT", "kid": kid} + header_b = _b64url(json.dumps( + header, separators=(",", ":"), sort_keys=True, + ).encode("utf-8")) + payload_b = _b64url(json.dumps( + claims, separators=(",", ":"), sort_keys=True, + ).encode("utf-8")) + signing_input = f"{header_b}.{payload_b}".encode("ascii") + signature = key.sign(signing_input) + + return f"{header_b}.{payload_b}.{_b64url(signature)}" + + +class IamService: + + def __init__(self, host, username, password, keyspace, + bootstrap_mode, bootstrap_token=None): + self.table_store = IamTableStore( + host, username, password, keyspace, + ) + # bootstrap_mode: "token" or "bootstrap". In "token" mode the + # service auto-seeds on first start using the provided + # bootstrap_token and the ``bootstrap`` operation is refused + # thereafter (indistinguishable from an already-bootstrapped + # deployment per the error policy). In "bootstrap" mode the + # ``bootstrap`` operation is live until tables are populated. + if bootstrap_mode not in ("token", "bootstrap"): + raise ValueError( + f"bootstrap_mode must be 'token' or 'bootstrap', " + f"got {bootstrap_mode!r}" + ) + if bootstrap_mode == "token" and not bootstrap_token: + raise ValueError( + "bootstrap_mode='token' requires bootstrap_token" + ) + self.bootstrap_mode = bootstrap_mode + self.bootstrap_token = bootstrap_token + + self._signing_key = None + self._signing_key_lock = asyncio.Lock() + + # ------------------------------------------------------------------ + # Dispatch + # ------------------------------------------------------------------ + + async def handle(self, v): + op = v.operation + + try: + if op == "bootstrap": + return await self.handle_bootstrap(v) + if op == "resolve-api-key": + return await self.handle_resolve_api_key(v) + if op == "create-user": + return await self.handle_create_user(v) + if op == "list-users": + return await self.handle_list_users(v) + if op == "create-api-key": + return await self.handle_create_api_key(v) + if op == "list-api-keys": + return await self.handle_list_api_keys(v) + if op == "revoke-api-key": + return await self.handle_revoke_api_key(v) + if op == "login": + return await self.handle_login(v) + if op == "get-signing-key-public": + return await self.handle_get_signing_key_public(v) + if op == "change-password": + return await self.handle_change_password(v) + if op == "reset-password": + return await self.handle_reset_password(v) + if op == "get-user": + return await self.handle_get_user(v) + if op == "update-user": + return await self.handle_update_user(v) + if op == "disable-user": + return await self.handle_disable_user(v) + if op == "enable-user": + return await self.handle_enable_user(v) + if op == "delete-user": + return await self.handle_delete_user(v) + if op == "create-workspace": + return await self.handle_create_workspace(v) + if op == "list-workspaces": + return await self.handle_list_workspaces(v) + if op == "get-workspace": + return await self.handle_get_workspace(v) + if op == "update-workspace": + return await self.handle_update_workspace(v) + if op == "disable-workspace": + return await self.handle_disable_workspace(v) + if op == "rotate-signing-key": + return await self.handle_rotate_signing_key(v) + + return _err( + "invalid-argument", + f"unknown or not-yet-implemented operation: {op!r}", + ) + + except Exception as e: + logger.error( + f"IAM {op} failed: {type(e).__name__}: {e}", + exc_info=True, + ) + return _err("internal-error", str(e)) + + # ------------------------------------------------------------------ + # Record conversion + # ------------------------------------------------------------------ + + def _row_to_user_record(self, row): + ( + id, workspace, username, name, email, _password_hash, + roles, enabled, must_change_password, created, + ) = row + return UserRecord( + id=id or "", + workspace=workspace or "", + username=username or "", + name=name or "", + email=email or "", + roles=sorted(roles) if roles else [], + enabled=bool(enabled), + must_change_password=bool(must_change_password), + created=_iso(created), + ) + + def _row_to_api_key_record(self, row): + ( + _key_hash, id, user_id, name, prefix, expires, + created, last_used, + ) = row + return ApiKeyRecord( + id=id or "", + user_id=user_id or "", + name=name or "", + prefix=prefix or "", + expires=_iso(expires), + created=_iso(created), + last_used=_iso(last_used), + ) + + # ------------------------------------------------------------------ + # bootstrap + # ------------------------------------------------------------------ + + async def auto_bootstrap_if_token_mode(self): + """Called from the service processor at startup. In + ``token`` mode, if tables are empty, seeds the default + workspace / admin / signing key using the operator-provided + bootstrap token. The admin's API key plaintext is *the* + ``bootstrap_token`` — the operator already knows it, nothing + needs to be returned or logged. + + In ``bootstrap`` mode this is a no-op; seeding happens on + explicit ``bootstrap`` operation invocation.""" + if self.bootstrap_mode != "token": + return + + if await self.table_store.any_workspace_exists(): + logger.info( + "IAM: token mode, tables already populated; skipping " + "auto-bootstrap" + ) + return + + logger.info("IAM: token mode, empty tables; auto-bootstrapping") + await self._seed_tables(self.bootstrap_token) + logger.info( + "IAM: auto-bootstrap complete using operator-provided token" + ) + + async def _seed_tables(self, api_key_plaintext): + """Shared seeding logic used by token-mode auto-bootstrap and + bootstrap-mode handle_bootstrap. Creates the default + workspace, admin user, admin API key (using the given + plaintext), and an initial signing key. Returns the admin + user id.""" + now = _now_dt() + + await self.table_store.put_workspace( + id=DEFAULT_WORKSPACE, + name="Default", + enabled=True, + created=now, + ) + + admin_user_id = str(uuid.uuid4()) + admin_password = secrets.token_urlsafe(32) + await self.table_store.put_user( + id=admin_user_id, + workspace=DEFAULT_WORKSPACE, + username=BOOTSTRAP_ADMIN_USERNAME, + name=BOOTSTRAP_ADMIN_NAME, + email="", + password_hash=_hash_password(admin_password), + roles=["admin"], + enabled=True, + must_change_password=True, + created=now, + ) + + key_id = str(uuid.uuid4()) + await self.table_store.put_api_key( + key_hash=_hash_api_key(api_key_plaintext), + id=key_id, + user_id=admin_user_id, + name="bootstrap", + prefix=api_key_plaintext[:len(API_KEY_PREFIX) + 4], + expires=None, + created=now, + last_used=None, + ) + + kid, private_pem, public_pem = _generate_signing_keypair() + await self.table_store.put_signing_key( + kid=kid, + private_pem=private_pem, + public_pem=public_pem, + created=now, + retired=None, + ) + self._signing_key = (kid, private_pem, public_pem) + + logger.info( + f"IAM seeded: workspace={DEFAULT_WORKSPACE!r}, " + f"admin user_id={admin_user_id}, signing key kid={kid}" + ) + return admin_user_id + + async def handle_bootstrap(self, v): + """Explicit bootstrap op. Only available in ``bootstrap`` + mode and only when tables are empty. Every other case is + masked to a generic auth failure — the caller cannot + distinguish 'not in bootstrap mode' from 'already + bootstrapped' from 'operation forbidden'.""" + if self.bootstrap_mode != "bootstrap": + return _err("auth-failed", "auth failure") + + if await self.table_store.any_workspace_exists(): + return _err("auth-failed", "auth failure") + + plaintext = _generate_api_key() + admin_user_id = await self._seed_tables(plaintext) + + return IamResponse( + bootstrap_admin_user_id=admin_user_id, + bootstrap_admin_api_key=plaintext, + ) + + # ------------------------------------------------------------------ + # Signing key helpers + # ------------------------------------------------------------------ + + async def _get_active_signing_key(self): + """Return ``(kid, private_pem, public_pem)`` for the active + signing key. Loads from Cassandra on first call. Generates + and persists a new key if none exists — covers the case where + ``login`` is called before ``bootstrap`` (shouldn't happen in + practice but keeps the service internally consistent).""" + if self._signing_key is not None: + return self._signing_key + + async with self._signing_key_lock: + if self._signing_key is not None: + return self._signing_key + + rows = await self.table_store.list_signing_keys() + active = [r for r in rows if r[4] is None] + + if active: + row = active[0] + self._signing_key = (row[0], row[1], row[2]) + logger.info( + f"IAM: loaded active signing key kid={row[0]}" + ) + return self._signing_key + + kid, private_pem, public_pem = _generate_signing_keypair() + await self.table_store.put_signing_key( + kid=kid, + private_pem=private_pem, + public_pem=public_pem, + created=_now_dt(), + retired=None, + ) + self._signing_key = (kid, private_pem, public_pem) + logger.info( + f"IAM: generated active signing key kid={kid} " + f"(no existing key found)" + ) + return self._signing_key + + # ------------------------------------------------------------------ + # login + # ------------------------------------------------------------------ + + async def handle_login(self, v): + if not v.username: + return _err("auth-failed", "username required") + if not v.password: + return _err("auth-failed", "password required") + + # Login accepts an optional workspace parameter. If omitted + # we use the default workspace (OSS single-workspace + # assumption). Multi-workspace enterprise editions swap in a + # resolver that looks across the caller's permitted set. + workspace = v.workspace or DEFAULT_WORKSPACE + + user_id = await self.table_store.get_user_id_by_username( + workspace, v.username, + ) + if not user_id: + return _err("auth-failed", "no such user") + + user_row = await self.table_store.get_user(user_id) + if user_row is None: + return _err("auth-failed", "user disappeared") + + ( + id, ws, _username, _name, _email, password_hash, + roles, enabled, _mcp, _created, + ) = user_row + + if not enabled: + return _err("auth-failed", "user disabled") + if not password_hash or not _verify_password( + v.password, password_hash, + ): + return _err("auth-failed", "bad credentials") + + ws_row = await self.table_store.get_workspace(ws) + if ws_row is None or not ws_row[2]: + return _err("auth-failed", "workspace disabled") + + kid, private_pem, _ = await self._get_active_signing_key() + + now_ts = int(_now_dt().timestamp()) + exp_ts = now_ts + JWT_TTL_SECONDS + claims = { + "iss": JWT_ISSUER, + "sub": id, + "workspace": ws, + "roles": sorted(roles) if roles else [], + "iat": now_ts, + "exp": exp_ts, + } + token = _sign_jwt(kid, private_pem, claims) + + expires_iso = datetime.datetime.fromtimestamp( + exp_ts, tz=datetime.timezone.utc, + ).isoformat() + + return IamResponse(jwt=token, jwt_expires=expires_iso) + + # ------------------------------------------------------------------ + # get-signing-key-public + # ------------------------------------------------------------------ + + async def handle_get_signing_key_public(self, v): + _, _, public_pem = await self._get_active_signing_key() + return IamResponse(signing_key_public=public_pem) + + # ------------------------------------------------------------------ + # Record-conversion helper for workspaces + # ------------------------------------------------------------------ + + def _row_to_workspace_record(self, row): + id, name, enabled, created = row + return WorkspaceRecord( + id=id or "", + name=name or "", + enabled=bool(enabled), + created=_iso(created), + ) + + async def _user_in_workspace(self, user_id, workspace): + """Return (user_row, error_response_or_None). Loads the user + record, verifies it exists, is enabled, and belongs to + ``workspace``. The workspace scope check rejects cross- + workspace admin attempts.""" + user_row = await self.table_store.get_user(user_id) + if user_row is None: + return None, _err("not-found", "user not found") + if user_row[1] != workspace: + return None, _err( + "operation-not-permitted", + "user is in a different workspace", + ) + return user_row, None + + # ------------------------------------------------------------------ + # change-password + # ------------------------------------------------------------------ + + async def handle_change_password(self, v): + if not v.user_id: + return _err("invalid-argument", "user_id required") + if not v.password: + return _err("invalid-argument", "password (current) required") + if not v.new_password: + return _err("invalid-argument", "new_password required") + + user_row = await self.table_store.get_user(v.user_id) + if user_row is None: + return _err("auth-failed", "no such user") + + _id, _ws, _un, _name, _email, password_hash, _r, enabled, _mcp, _c = ( + user_row + ) + if not enabled: + return _err("auth-failed", "user disabled") + if not password_hash or not _verify_password( + v.password, password_hash, + ): + return _err("auth-failed", "bad credentials") + + await self.table_store.update_user_password( + id=v.user_id, + password_hash=_hash_password(v.new_password), + must_change_password=False, + ) + return IamResponse() + + # ------------------------------------------------------------------ + # reset-password + # ------------------------------------------------------------------ + + async def handle_reset_password(self, v): + if not v.workspace: + return _err( + "invalid-argument", + "workspace required for reset-password", + ) + if not v.user_id: + return _err("invalid-argument", "user_id required") + + _, err = await self._user_in_workspace(v.user_id, v.workspace) + if err is not None: + return err + + temporary = secrets.token_urlsafe(12) + await self.table_store.update_user_password( + id=v.user_id, + password_hash=_hash_password(temporary), + must_change_password=True, + ) + return IamResponse(temporary_password=temporary) + + # ------------------------------------------------------------------ + # get-user / update-user / disable-user + # ------------------------------------------------------------------ + + async def handle_get_user(self, v): + if not v.workspace: + return _err("invalid-argument", "workspace required") + if not v.user_id: + return _err("invalid-argument", "user_id required") + + user_row, err = await self._user_in_workspace( + v.user_id, v.workspace, + ) + if err is not None: + return err + return IamResponse(user=self._row_to_user_record(user_row)) + + async def handle_update_user(self, v): + """Update user profile fields: name, email, roles, enabled, + must_change_password. Username is immutable — change it by + creating a new user and disabling the old one. Password + changes go through change-password / reset-password.""" + if not v.workspace: + return _err("invalid-argument", "workspace required") + if not v.user_id: + return _err("invalid-argument", "user_id required") + if v.user is None: + return _err("invalid-argument", "user field required") + if v.user.password: + return _err( + "invalid-argument", + "password cannot be changed via update-user; " + "use change-password or reset-password", + ) + if v.user.username and v.user.username != "": + # Compare to existing. Username-change not allowed. + existing, err = await self._user_in_workspace( + v.user_id, v.workspace, + ) + if err is not None: + return err + if v.user.username != existing[2]: + return _err( + "invalid-argument", + "username is immutable; create a new user " + "instead", + ) + else: + existing, err = await self._user_in_workspace( + v.user_id, v.workspace, + ) + if err is not None: + return err + + # Carry forward fields the caller didn't provide. + ( + _id, _ws, _username, cur_name, cur_email, _pw, + cur_roles, cur_enabled, cur_mcp, _created, + ) = existing + + new_name = v.user.name if v.user.name else cur_name + new_email = v.user.email if v.user.email else cur_email + new_roles = list(v.user.roles) if v.user.roles else list( + cur_roles or [], + ) + new_enabled = v.user.enabled if v.user.enabled is not None else ( + cur_enabled + ) + new_mcp = ( + v.user.must_change_password + if v.user.must_change_password is not None + else cur_mcp + ) + + await self.table_store.update_user_profile( + id=v.user_id, + name=new_name, + email=new_email, + roles=new_roles, + enabled=new_enabled, + must_change_password=new_mcp, + ) + + updated = await self.table_store.get_user(v.user_id) + return IamResponse(user=self._row_to_user_record(updated)) + + async def handle_disable_user(self, v): + """Soft-delete: set enabled=false and revoke every API key + belonging to the user.""" + if not v.workspace: + return _err("invalid-argument", "workspace required") + if not v.user_id: + return _err("invalid-argument", "user_id required") + + _, err = await self._user_in_workspace(v.user_id, v.workspace) + if err is not None: + return err + + await self.table_store.update_user_enabled( + id=v.user_id, enabled=False, + ) + + # Revoke all their API keys. + key_rows = await self.table_store.list_api_keys_by_user(v.user_id) + for kr in key_rows: + await self.table_store.delete_api_key(kr[0]) + + return IamResponse() + + async def handle_enable_user(self, v): + """Re-enable a previously disabled user. Does not restore + API keys — those have to be re-issued by the admin.""" + if not v.workspace: + return _err("invalid-argument", "workspace required") + if not v.user_id: + return _err("invalid-argument", "user_id required") + + _, err = await self._user_in_workspace(v.user_id, v.workspace) + if err is not None: + return err + + await self.table_store.update_user_enabled( + id=v.user_id, enabled=True, + ) + return IamResponse() + + async def handle_delete_user(self, v): + """Hard-delete a user. Removes the ``iam_users`` row, the + ``iam_users_by_username`` lookup row, and every API key + belonging to the user. + + Unlike disable, this frees the username for re-use and + removes the user's personal data from storage (intended to + cover GDPR erasure-style requirements). When audit logging + lands, the decision to delete vs. anonymise referenced audit + rows will need to be revisited.""" + if not v.workspace: + return _err("invalid-argument", "workspace required") + if not v.user_id: + return _err("invalid-argument", "user_id required") + + user_row, err = await self._user_in_workspace( + v.user_id, v.workspace, + ) + if err is not None: + return err + + # user_row indices match get_user columns. Username is [2]. + username = user_row[2] + + # Revoke all API keys. + key_rows = await self.table_store.list_api_keys_by_user(v.user_id) + for kr in key_rows: + await self.table_store.delete_api_key(kr[0]) + + # Remove username lookup. + if username: + await self.table_store.delete_username_lookup( + v.workspace, username, + ) + + # Remove user record. + await self.table_store.delete_user(v.user_id) + + return IamResponse() + + # ------------------------------------------------------------------ + # Workspace CRUD + # ------------------------------------------------------------------ + + async def handle_create_workspace(self, v): + if v.workspace_record is None or not v.workspace_record.id: + return _err( + "invalid-argument", + "workspace_record.id required for create-workspace", + ) + if v.workspace_record.id.startswith("_"): + return _err( + "invalid-argument", + "workspace ids beginning with '_' are reserved", + ) + + existing = await self.table_store.get_workspace( + v.workspace_record.id, + ) + if existing is not None: + return _err("duplicate", "workspace already exists") + + now = _now_dt() + await self.table_store.put_workspace( + id=v.workspace_record.id, + name=v.workspace_record.name or v.workspace_record.id, + enabled=v.workspace_record.enabled, + created=now, + ) + row = await self.table_store.get_workspace(v.workspace_record.id) + return IamResponse(workspace=self._row_to_workspace_record(row)) + + async def handle_list_workspaces(self, v): + rows = await self.table_store.list_workspaces() + return IamResponse( + workspaces=[ + self._row_to_workspace_record(r) for r in rows + ], + ) + + async def handle_get_workspace(self, v): + if v.workspace_record is None or not v.workspace_record.id: + return _err("invalid-argument", "workspace_record.id required") + row = await self.table_store.get_workspace(v.workspace_record.id) + if row is None: + return _err("not-found", "workspace not found") + return IamResponse(workspace=self._row_to_workspace_record(row)) + + async def handle_update_workspace(self, v): + """Update workspace name / enabled. The id is immutable.""" + if v.workspace_record is None or not v.workspace_record.id: + return _err("invalid-argument", "workspace_record.id required") + row = await self.table_store.get_workspace(v.workspace_record.id) + if row is None: + return _err("not-found", "workspace not found") + + _, cur_name, cur_enabled, _created = row + new_name = ( + v.workspace_record.name + if v.workspace_record.name else cur_name + ) + new_enabled = ( + v.workspace_record.enabled + if v.workspace_record.enabled is not None + else cur_enabled + ) + + await self.table_store.update_workspace( + id=v.workspace_record.id, + name=new_name, + enabled=new_enabled, + ) + updated = await self.table_store.get_workspace( + v.workspace_record.id, + ) + return IamResponse( + workspace=self._row_to_workspace_record(updated), + ) + + async def handle_disable_workspace(self, v): + """Set enabled=false, disable every user in the workspace, + revoke every API key belonging to those users.""" + if v.workspace_record is None or not v.workspace_record.id: + return _err("invalid-argument", "workspace_record.id required") + + row = await self.table_store.get_workspace(v.workspace_record.id) + if row is None: + return _err("not-found", "workspace not found") + + await self.table_store.update_workspace( + id=v.workspace_record.id, + name=row[1] or v.workspace_record.id, + enabled=False, + ) + + user_rows = await self.table_store.list_users_by_workspace( + v.workspace_record.id, + ) + for ur in user_rows: + user_id = ur[0] + await self.table_store.update_user_enabled( + id=user_id, enabled=False, + ) + key_rows = await self.table_store.list_api_keys_by_user(user_id) + for kr in key_rows: + await self.table_store.delete_api_key(kr[0]) + + return IamResponse() + + # ------------------------------------------------------------------ + # rotate-signing-key + # ------------------------------------------------------------------ + + async def handle_rotate_signing_key(self, v): + """Create a new Ed25519 signing key, retire the current + active key, switch the in-memory cache over. + + The retired key row is kept in ``iam_signing_keys`` so the + gateway's JWT validator can continue to validate previously- + issued tokens during the grace period. Actual grace-period + enforcement (time-window acceptance at the validator) lands + with the gateway auth middleware work.""" + + # Retire the currently-active key, if any. + current = await self._get_active_signing_key() + now = _now_dt() + if current is not None: + cur_kid, _cur_priv, _cur_pub = current + await self.table_store.retire_signing_key( + kid=cur_kid, retired=now, + ) + + new_kid, new_priv, new_pub = _generate_signing_keypair() + await self.table_store.put_signing_key( + kid=new_kid, + private_pem=new_priv, + public_pem=new_pub, + created=now, + retired=None, + ) + self._signing_key = (new_kid, new_priv, new_pub) + logger.info( + f"IAM: rotated signing key. " + f"New kid={new_kid}, retired kid={(current or (None,))[0]}" + ) + return IamResponse() + + # ------------------------------------------------------------------ + # resolve-api-key + # ------------------------------------------------------------------ + + async def handle_resolve_api_key(self, v): + if not v.api_key: + return _err("auth-failed", "no api key") + + row = await self.table_store.get_api_key_by_hash( + _hash_api_key(v.api_key), + ) + if row is None: + return _err("auth-failed", "unknown api key") + + ( + _key_hash, _id, user_id, _name, _prefix, expires, + _created, _last_used, + ) = row + + if expires is not None: + exp_dt = expires + if isinstance(exp_dt, str): + exp_dt = datetime.datetime.fromisoformat(exp_dt) + if exp_dt.tzinfo is None: + exp_dt = exp_dt.replace(tzinfo=datetime.timezone.utc) + if exp_dt < _now_dt(): + return _err("auth-failed", "api key expired") + + user_row = await self.table_store.get_user(user_id) + if user_row is None: + return _err("auth-failed", "owning user missing") + user = self._row_to_user_record(user_row) + if not user.enabled: + return _err("auth-failed", "owning user disabled") + + # Workspace-disabled check. + ws_row = await self.table_store.get_workspace(user.workspace) + if ws_row is None or not ws_row[2]: + return _err("auth-failed", "owning workspace disabled") + + return IamResponse( + resolved_user_id=user.id, + resolved_workspace=user.workspace, + resolved_roles=list(user.roles), + ) + + # ------------------------------------------------------------------ + # create-user + # ------------------------------------------------------------------ + + async def handle_create_user(self, v): + if not v.workspace: + return _err( + "invalid-argument", "workspace required for create-user", + ) + if v.user is None: + return _err( + "invalid-argument", "user field required for create-user", + ) + if not v.user.username: + return _err("invalid-argument", "user.username required") + if not v.user.password: + return _err("invalid-argument", "user.password required") + + # Workspace must exist and be enabled. + ws = await self.table_store.get_workspace(v.workspace) + if ws is None or not ws[2]: + return _err("not-found", "workspace not found or disabled") + + # Uniqueness on username within workspace. + existing = await self.table_store.get_user_id_by_username( + v.workspace, v.user.username, + ) + if existing: + return _err("duplicate", "username already exists") + + user_id = str(uuid.uuid4()) + now = _now_dt() + + await self.table_store.put_user( + id=user_id, + workspace=v.workspace, + username=v.user.username, + name=v.user.name or v.user.username, + email=v.user.email or "", + password_hash=_hash_password(v.user.password), + roles=list(v.user.roles or []), + enabled=v.user.enabled, + must_change_password=v.user.must_change_password, + created=now, + ) + + row = await self.table_store.get_user(user_id) + return IamResponse(user=self._row_to_user_record(row)) + + # ------------------------------------------------------------------ + # list-users + # ------------------------------------------------------------------ + + async def handle_list_users(self, v): + if not v.workspace: + return _err( + "invalid-argument", "workspace required for list-users", + ) + + rows = await self.table_store.list_users_by_workspace(v.workspace) + return IamResponse( + users=[self._row_to_user_record(r) for r in rows], + ) + + # ------------------------------------------------------------------ + # create-api-key + # ------------------------------------------------------------------ + + async def handle_create_api_key(self, v): + if not v.workspace: + return _err( + "invalid-argument", "workspace required for create-api-key", + ) + if v.key is None or not v.key.user_id: + return _err("invalid-argument", "key.user_id required") + if not v.key.name: + return _err("invalid-argument", "key.name required") + + # Target user must exist and belong to the caller's workspace. + user_row = await self.table_store.get_user(v.key.user_id) + if user_row is None: + return _err("not-found", "user not found") + if user_row[1] != v.workspace: + return _err( + "operation-not-permitted", + "target user is in a different workspace", + ) + + plaintext = _generate_api_key() + key_id = str(uuid.uuid4()) + now = _now_dt() + expires_dt = _parse_expires(v.key.expires) + + await self.table_store.put_api_key( + key_hash=_hash_api_key(plaintext), + id=key_id, + user_id=v.key.user_id, + name=v.key.name, + prefix=plaintext[:len(API_KEY_PREFIX) + 4], + expires=expires_dt, + created=now, + last_used=None, + ) + + row = await self.table_store.get_api_key_by_hash( + _hash_api_key(plaintext), + ) + return IamResponse( + api_key_plaintext=plaintext, + api_key=self._row_to_api_key_record(row), + ) + + # ------------------------------------------------------------------ + # list-api-keys + # ------------------------------------------------------------------ + + async def handle_list_api_keys(self, v): + if not v.workspace: + return _err( + "invalid-argument", + "workspace required for list-api-keys", + ) + if not v.user_id: + return _err( + "invalid-argument", "user_id required for list-api-keys", + ) + + # Workspace-scope check: user must live in this workspace. + user_row = await self.table_store.get_user(v.user_id) + if user_row is None or user_row[1] != v.workspace: + return _err("not-found", "user not found in workspace") + + rows = await self.table_store.list_api_keys_by_user(v.user_id) + return IamResponse( + api_keys=[self._row_to_api_key_record(r) for r in rows], + ) + + # ------------------------------------------------------------------ + # revoke-api-key + # ------------------------------------------------------------------ + + async def handle_revoke_api_key(self, v): + if not v.workspace: + return _err( + "invalid-argument", + "workspace required for revoke-api-key", + ) + if not v.key_id: + return _err("invalid-argument", "key_id required") + + row = await self.table_store.get_api_key_by_id(v.key_id) + if row is None: + return _err("not-found", "api key not found") + + key_hash, _id, user_id, _name, _prefix, _expires, _c, _lu = row + # Workspace-scope check via the owning user. + user_row = await self.table_store.get_user(user_id) + if user_row is None or user_row[1] != v.workspace: + return _err( + "operation-not-permitted", + "key belongs to a different workspace", + ) + + await self.table_store.delete_api_key(key_hash) + return IamResponse() diff --git a/trustgraph-flow/trustgraph/iam/service/service.py b/trustgraph-flow/trustgraph/iam/service/service.py new file mode 100644 index 00000000..8ea31cf0 --- /dev/null +++ b/trustgraph-flow/trustgraph/iam/service/service.py @@ -0,0 +1,210 @@ +""" +IAM service processor. Terminates the IAM request queue and forwards +each request to the IamService business logic, then returns the +response on the IAM response queue. + +Shape mirrors trustgraph.config.service. +""" + +import logging + +from trustgraph.schema import Error +from trustgraph.schema import IamRequest, IamResponse +from trustgraph.schema import iam_request_queue, iam_response_queue + +from trustgraph.base import AsyncProcessor, Consumer, Producer +from trustgraph.base import ConsumerMetrics, ProducerMetrics +from trustgraph.base.cassandra_config import ( + add_cassandra_args, resolve_cassandra_config, +) + +from . iam import IamService + +logger = logging.getLogger(__name__) + +default_ident = "iam-svc" + +default_iam_request_queue = iam_request_queue +default_iam_response_queue = iam_response_queue + + +class Processor(AsyncProcessor): + + def __init__(self, **params): + + iam_req_q = params.get( + "iam_request_queue", default_iam_request_queue, + ) + iam_resp_q = params.get( + "iam_response_queue", default_iam_response_queue, + ) + + bootstrap_mode = params.get("bootstrap_mode") + bootstrap_token = params.get("bootstrap_token") + + if bootstrap_mode not in ("token", "bootstrap"): + raise RuntimeError( + "iam-svc: --bootstrap-mode is required. Set to 'token' " + "(with --bootstrap-token) for production, or 'bootstrap' " + "to enable the explicit bootstrap operation over the " + "pub/sub bus (dev / quick-start only, not safe under " + "public exposure). Refusing to start." + ) + if bootstrap_mode == "token" and not bootstrap_token: + raise RuntimeError( + "iam-svc: --bootstrap-mode=token requires " + "--bootstrap-token. Refusing to start." + ) + if bootstrap_mode == "bootstrap" and bootstrap_token: + raise RuntimeError( + "iam-svc: --bootstrap-token is not accepted when " + "--bootstrap-mode=bootstrap. Ambiguous intent. " + "Refusing to start." + ) + + self.bootstrap_mode = bootstrap_mode + self.bootstrap_token = bootstrap_token + + cassandra_host = params.get("cassandra_host") + cassandra_username = params.get("cassandra_username") + cassandra_password = params.get("cassandra_password") + + hosts, username, password, keyspace = resolve_cassandra_config( + host=cassandra_host, + username=cassandra_username, + password=cassandra_password, + default_keyspace="iam", + ) + + self.cassandra_host = hosts + self.cassandra_username = username + self.cassandra_password = password + + super().__init__( + **params | { + "iam_request_schema": IamRequest.__name__, + "iam_response_schema": IamResponse.__name__, + "cassandra_host": self.cassandra_host, + "cassandra_username": self.cassandra_username, + "cassandra_password": self.cassandra_password, + } + ) + + iam_request_metrics = ConsumerMetrics( + processor=self.id, flow=None, name="iam-request", + ) + iam_response_metrics = ProducerMetrics( + processor=self.id, flow=None, name="iam-response", + ) + + self.iam_request_topic = iam_req_q + + self.iam_request_consumer = Consumer( + taskgroup=self.taskgroup, + backend=self.pubsub, + flow=None, + topic=iam_req_q, + subscriber=self.id, + schema=IamRequest, + handler=self.on_iam_request, + metrics=iam_request_metrics, + ) + + self.iam_response_producer = Producer( + backend=self.pubsub, + topic=iam_resp_q, + schema=IamResponse, + metrics=iam_response_metrics, + ) + + self.iam = IamService( + host=self.cassandra_host, + username=self.cassandra_username, + password=self.cassandra_password, + keyspace=keyspace, + bootstrap_mode=self.bootstrap_mode, + bootstrap_token=self.bootstrap_token, + ) + + logger.info( + f"IAM service initialised (bootstrap-mode={self.bootstrap_mode})" + ) + + async def start(self): + await self.pubsub.ensure_topic(self.iam_request_topic) + # Token-mode auto-bootstrap runs before we accept requests so + # the first inbound call always sees a populated table. + await self.iam.auto_bootstrap_if_token_mode() + await self.iam_request_consumer.start() + + async def on_iam_request(self, msg, consumer, flow): + + id = None + try: + v = msg.value() + id = msg.properties()["id"] + logger.debug( + f"Handling IAM request {id} op={v.operation!r}" + ) + resp = await self.iam.handle(v) + await self.iam_response_producer.send( + resp, properties={"id": id}, + ) + except Exception as e: + logger.error( + f"IAM request failed: {type(e).__name__}: {e}", + exc_info=True, + ) + resp = IamResponse( + error=Error(type="internal-error", message=str(e)), + ) + if id is not None: + await self.iam_response_producer.send( + resp, properties={"id": id}, + ) + + @staticmethod + def add_args(parser): + AsyncProcessor.add_args(parser) + + parser.add_argument( + "--iam-request-queue", + default=default_iam_request_queue, + help=f"IAM request queue (default: {default_iam_request_queue})", + ) + parser.add_argument( + "--iam-response-queue", + default=default_iam_response_queue, + help=f"IAM response queue (default: {default_iam_response_queue})", + ) + parser.add_argument( + "--bootstrap-mode", + default=None, + choices=["token", "bootstrap"], + help=( + "IAM bootstrap mode (required). " + "'token' = operator supplies the initial admin API " + "key via --bootstrap-token; auto-seeds on first start, " + "bootstrap operation refused. " + "'bootstrap' = bootstrap operation is live over the " + "bus until tables are populated; a token is generated " + "and returned by tg-bootstrap-iam. Unsafe to run " + "'bootstrap' mode with public exposure." + ), + ) + parser.add_argument( + "--bootstrap-token", + default=None, + help=( + "Initial admin API key plaintext, required when " + "--bootstrap-mode=token. Treat as a one-time " + "credential: the operator should rotate to a new key " + "and revoke this one after first use." + ), + ) + + add_cassandra_args(parser) + + +def run(): + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/tables/iam.py b/trustgraph-flow/trustgraph/tables/iam.py new file mode 100644 index 00000000..3d41ebbd --- /dev/null +++ b/trustgraph-flow/trustgraph/tables/iam.py @@ -0,0 +1,422 @@ +""" +IAM Cassandra table store. + +Tables: + - iam_workspaces (id primary key) + - iam_users (id primary key) + iam_users_by_username lookup table + (workspace, username) -> id + - iam_api_keys (key_hash primary key) with secondary index on user_id + - iam_signing_keys (kid primary key) — RSA keypairs for JWT signing + +See docs/tech-specs/iam-protocol.md for the wire-level context. +""" + +import logging + +from cassandra.cluster import Cluster +from cassandra.auth import PlainTextAuthProvider +from ssl import SSLContext, PROTOCOL_TLSv1_2 + +from . cassandra_async import async_execute + +logger = logging.getLogger(__name__) + + +class IamTableStore: + + def __init__( + self, + cassandra_host, cassandra_username, cassandra_password, + keyspace, + ): + self.keyspace = keyspace + + logger.info("IAM: connecting to Cassandra...") + + if isinstance(cassandra_host, str): + cassandra_host = [h.strip() for h in cassandra_host.split(",")] + + if cassandra_username and cassandra_password: + ssl_context = SSLContext(PROTOCOL_TLSv1_2) + auth_provider = PlainTextAuthProvider( + username=cassandra_username, password=cassandra_password, + ) + self.cluster = Cluster( + cassandra_host, + auth_provider=auth_provider, + ssl_context=ssl_context, + ) + else: + self.cluster = Cluster(cassandra_host) + + self.cassandra = self.cluster.connect() + + logger.info("IAM: connected.") + + self._ensure_schema() + self._prepare_statements() + + def _ensure_schema(self): + # FIXME: Replication factor should be configurable. + self.cassandra.execute(f""" + create keyspace if not exists {self.keyspace} + with replication = {{ + 'class' : 'SimpleStrategy', + 'replication_factor' : 1 + }}; + """) + self.cassandra.set_keyspace(self.keyspace) + + self.cassandra.execute(""" + CREATE TABLE IF NOT EXISTS iam_workspaces ( + id text PRIMARY KEY, + name text, + enabled boolean, + created timestamp + ); + """) + + self.cassandra.execute(""" + CREATE TABLE IF NOT EXISTS iam_users ( + id text PRIMARY KEY, + workspace text, + username text, + name text, + email text, + password_hash text, + roles set, + enabled boolean, + must_change_password boolean, + created timestamp + ); + """) + + self.cassandra.execute(""" + CREATE TABLE IF NOT EXISTS iam_users_by_username ( + workspace text, + username text, + user_id text, + PRIMARY KEY ((workspace), username) + ); + """) + + self.cassandra.execute(""" + CREATE TABLE IF NOT EXISTS iam_api_keys ( + key_hash text PRIMARY KEY, + id text, + user_id text, + name text, + prefix text, + expires timestamp, + created timestamp, + last_used timestamp + ); + """) + + self.cassandra.execute(""" + CREATE INDEX IF NOT EXISTS iam_api_keys_user_id_idx + ON iam_api_keys (user_id); + """) + + self.cassandra.execute(""" + CREATE INDEX IF NOT EXISTS iam_api_keys_id_idx + ON iam_api_keys (id); + """) + + self.cassandra.execute(""" + CREATE TABLE IF NOT EXISTS iam_signing_keys ( + kid text PRIMARY KEY, + private_pem text, + public_pem text, + created timestamp, + retired timestamp + ); + """) + + logger.info("IAM: Cassandra schema OK.") + + def _prepare_statements(self): + c = self.cassandra + + self.put_workspace_stmt = c.prepare(""" + INSERT INTO iam_workspaces (id, name, enabled, created) + VALUES (?, ?, ?, ?) + """) + self.get_workspace_stmt = c.prepare(""" + SELECT id, name, enabled, created FROM iam_workspaces + WHERE id = ? + """) + self.list_workspaces_stmt = c.prepare(""" + SELECT id, name, enabled, created FROM iam_workspaces + """) + + self.put_user_stmt = c.prepare(""" + INSERT INTO iam_users ( + id, workspace, username, name, email, password_hash, + roles, enabled, must_change_password, created + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """) + self.get_user_stmt = c.prepare(""" + SELECT id, workspace, username, name, email, password_hash, + roles, enabled, must_change_password, created + FROM iam_users WHERE id = ? + """) + self.list_users_by_workspace_stmt = c.prepare(""" + SELECT id, workspace, username, name, email, password_hash, + roles, enabled, must_change_password, created + FROM iam_users WHERE workspace = ? ALLOW FILTERING + """) + + self.put_username_lookup_stmt = c.prepare(""" + INSERT INTO iam_users_by_username (workspace, username, user_id) + VALUES (?, ?, ?) + """) + self.get_user_id_by_username_stmt = c.prepare(""" + SELECT user_id FROM iam_users_by_username + WHERE workspace = ? AND username = ? + """) + self.delete_username_lookup_stmt = c.prepare(""" + DELETE FROM iam_users_by_username + WHERE workspace = ? AND username = ? + """) + self.delete_user_stmt = c.prepare(""" + DELETE FROM iam_users WHERE id = ? + """) + + self.put_api_key_stmt = c.prepare(""" + INSERT INTO iam_api_keys ( + key_hash, id, user_id, name, prefix, expires, + created, last_used + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """) + self.get_api_key_by_hash_stmt = c.prepare(""" + SELECT key_hash, id, user_id, name, prefix, expires, + created, last_used + FROM iam_api_keys WHERE key_hash = ? + """) + self.get_api_key_by_id_stmt = c.prepare(""" + SELECT key_hash, id, user_id, name, prefix, expires, + created, last_used + FROM iam_api_keys WHERE id = ? + """) + self.list_api_keys_by_user_stmt = c.prepare(""" + SELECT key_hash, id, user_id, name, prefix, expires, + created, last_used + FROM iam_api_keys WHERE user_id = ? + """) + self.delete_api_key_stmt = c.prepare(""" + DELETE FROM iam_api_keys WHERE key_hash = ? + """) + + self.put_signing_key_stmt = c.prepare(""" + INSERT INTO iam_signing_keys ( + kid, private_pem, public_pem, created, retired + ) + VALUES (?, ?, ?, ?, ?) + """) + self.list_signing_keys_stmt = c.prepare(""" + SELECT kid, private_pem, public_pem, created, retired + FROM iam_signing_keys + """) + self.retire_signing_key_stmt = c.prepare(""" + UPDATE iam_signing_keys SET retired = ? WHERE kid = ? + """) + + self.update_user_profile_stmt = c.prepare(""" + UPDATE iam_users + SET name = ?, email = ?, roles = ?, enabled = ?, + must_change_password = ? + WHERE id = ? + """) + self.update_user_password_stmt = c.prepare(""" + UPDATE iam_users + SET password_hash = ?, must_change_password = ? + WHERE id = ? + """) + self.update_user_enabled_stmt = c.prepare(""" + UPDATE iam_users SET enabled = ? WHERE id = ? + """) + + self.update_workspace_stmt = c.prepare(""" + UPDATE iam_workspaces SET name = ?, enabled = ? + WHERE id = ? + """) + + # ------------------------------------------------------------------ + # Workspaces + # ------------------------------------------------------------------ + + async def put_workspace(self, id, name, enabled, created): + await async_execute( + self.cassandra, self.put_workspace_stmt, + (id, name, enabled, created), + ) + + async def get_workspace(self, id): + rows = await async_execute( + self.cassandra, self.get_workspace_stmt, (id,), + ) + return rows[0] if rows else None + + async def list_workspaces(self): + return await async_execute( + self.cassandra, self.list_workspaces_stmt, + ) + + # ------------------------------------------------------------------ + # Users + # ------------------------------------------------------------------ + + async def put_user( + self, id, workspace, username, name, email, password_hash, + roles, enabled, must_change_password, created, + ): + await async_execute( + self.cassandra, self.put_user_stmt, + ( + id, workspace, username, name, email, password_hash, + set(roles) if roles else set(), + enabled, must_change_password, created, + ), + ) + await async_execute( + self.cassandra, self.put_username_lookup_stmt, + (workspace, username, id), + ) + + async def get_user(self, id): + rows = await async_execute( + self.cassandra, self.get_user_stmt, (id,), + ) + return rows[0] if rows else None + + async def get_user_id_by_username(self, workspace, username): + rows = await async_execute( + self.cassandra, self.get_user_id_by_username_stmt, + (workspace, username), + ) + return rows[0][0] if rows else None + + async def list_users_by_workspace(self, workspace): + return await async_execute( + self.cassandra, self.list_users_by_workspace_stmt, (workspace,), + ) + + async def delete_user(self, id): + await async_execute( + self.cassandra, self.delete_user_stmt, (id,), + ) + + async def delete_username_lookup(self, workspace, username): + await async_execute( + self.cassandra, self.delete_username_lookup_stmt, + (workspace, username), + ) + + # ------------------------------------------------------------------ + # API keys + # ------------------------------------------------------------------ + + async def put_api_key( + self, key_hash, id, user_id, name, prefix, expires, + created, last_used, + ): + await async_execute( + self.cassandra, self.put_api_key_stmt, + (key_hash, id, user_id, name, prefix, expires, + created, last_used), + ) + + async def get_api_key_by_hash(self, key_hash): + rows = await async_execute( + self.cassandra, self.get_api_key_by_hash_stmt, (key_hash,), + ) + return rows[0] if rows else None + + async def get_api_key_by_id(self, id): + rows = await async_execute( + self.cassandra, self.get_api_key_by_id_stmt, (id,), + ) + return rows[0] if rows else None + + async def list_api_keys_by_user(self, user_id): + return await async_execute( + self.cassandra, self.list_api_keys_by_user_stmt, (user_id,), + ) + + async def delete_api_key(self, key_hash): + await async_execute( + self.cassandra, self.delete_api_key_stmt, (key_hash,), + ) + + # ------------------------------------------------------------------ + # Signing keys + # ------------------------------------------------------------------ + + async def put_signing_key(self, kid, private_pem, public_pem, + created, retired): + await async_execute( + self.cassandra, self.put_signing_key_stmt, + (kid, private_pem, public_pem, created, retired), + ) + + async def list_signing_keys(self): + return await async_execute( + self.cassandra, self.list_signing_keys_stmt, + ) + + async def retire_signing_key(self, kid, retired): + await async_execute( + self.cassandra, self.retire_signing_key_stmt, + (retired, kid), + ) + + # ------------------------------------------------------------------ + # User partial updates + # ------------------------------------------------------------------ + + async def update_user_profile( + self, id, name, email, roles, enabled, must_change_password, + ): + await async_execute( + self.cassandra, self.update_user_profile_stmt, + ( + name, email, + set(roles) if roles else set(), + enabled, must_change_password, id, + ), + ) + + async def update_user_password( + self, id, password_hash, must_change_password, + ): + await async_execute( + self.cassandra, self.update_user_password_stmt, + (password_hash, must_change_password, id), + ) + + async def update_user_enabled(self, id, enabled): + await async_execute( + self.cassandra, self.update_user_enabled_stmt, + (enabled, id), + ) + + # ------------------------------------------------------------------ + # Workspace updates + # ------------------------------------------------------------------ + + async def update_workspace(self, id, name, enabled): + await async_execute( + self.cassandra, self.update_workspace_stmt, + (name, enabled, id), + ) + + # ------------------------------------------------------------------ + # Bootstrap helpers + # ------------------------------------------------------------------ + + async def any_workspace_exists(self): + rows = await self.list_workspaces() + return bool(rows) From 666af1c4b39aa25474442d3884bed724795b9199 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 28 Apr 2026 15:00:33 +0100 Subject: [PATCH 15/21] feat(iam): allow bootstrap mode and token to be sourced from env vars (#851) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds an environment-variable fallback for the iam-svc bootstrap configuration so the token can be injected from a Kubernetes Secret (or any equivalent secret store) without ever appearing in the processor-group YAML — which is typically version-controlled. Resolution order is fixed and per-setting: bootstrap_mode = params["bootstrap_mode"] or $IAM_BOOTSTRAP_MODE bootstrap_token = params["bootstrap_token"] or $IAM_BOOTSTRAP_TOKEN If neither source supplies a value, the service refuses to start with a clear message naming both options. The two settings are resolved independently, which lets operators commit the mode in YAML (it is not a secret) while pulling the token from a Secret-backed ``IAM_BOOTSTRAP_TOKEN`` env var. Validation invariants are unchanged: * mode must be 'token' or 'bootstrap' * mode='token' requires a token (from any source) * mode='bootstrap' must NOT have a token (ambiguous intent) There is no permissive fallback — the service fails closed in every branch where configuration is incomplete. docs/tech-specs/iam-protocol.md gains a 'Configuration sources' subsection under 'Bootstrap modes' that documents the precedence table and the K8s injection pattern. The 'Bootstrap-token lifecycle' step about removing the token after rotation now applies to whichever source was used (Secret, env var, or YAML field). --- docs/tech-specs/iam-protocol.md | 22 +++++++++- .../trustgraph/iam/service/service.py | 41 +++++++++++++++---- 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/docs/tech-specs/iam-protocol.md b/docs/tech-specs/iam-protocol.md index 8638e7e9..8049ebfe 100644 --- a/docs/tech-specs/iam-protocol.md +++ b/docs/tech-specs/iam-protocol.md @@ -273,6 +273,25 @@ cannot distinguish: This matches the general IAM error-policy stance (see `iam.md`) and prevents externally enumerating IAM's state. +### Configuration sources + +The mode and token can be supplied two ways. Resolution order is +fixed; there is no permissive fallback. + +| Source | Field | +|---|---| +| Processor-group YAML / CLI argument | `bootstrap_mode`, `bootstrap_token` | +| Environment variable | `IAM_BOOTSTRAP_MODE`, `IAM_BOOTSTRAP_TOKEN` | + +For each setting the service uses the explicit param value if +present; otherwise the environment variable; otherwise the service +refuses to start. The env-var path is intended for the K8s +deployment pattern where the token is injected from a `Secret` via +`secretKeyRef`, so the plaintext never has to live in YAML or git. +A typical production manifest holds `bootstrap_mode: "token"` in +the YAML and pulls `IAM_BOOTSTRAP_TOKEN` from the Secret; the YAML +is then safe to version-control. + ### Bootstrap-token lifecycle The bootstrap token — whether operator-supplied (`token` mode) or @@ -283,7 +302,8 @@ operator's first admin action after bootstrap should be: 1. Create a durable admin user and API key (or issue a durable API key to the bootstrap admin). 2. Revoke the bootstrap key via `revoke-api-key`. -3. Remove the bootstrap token from any deployment configuration. +3. Remove the bootstrap token from any deployment configuration + (Secret, env var, or YAML field — wherever it was sourced). The `name="bootstrap"` marker makes bootstrap keys easy to detect in tooling (e.g. a `tg-list-api-keys` filter). diff --git a/trustgraph-flow/trustgraph/iam/service/service.py b/trustgraph-flow/trustgraph/iam/service/service.py index 8ea31cf0..147bd56a 100644 --- a/trustgraph-flow/trustgraph/iam/service/service.py +++ b/trustgraph-flow/trustgraph/iam/service/service.py @@ -7,6 +7,7 @@ Shape mirrors trustgraph.config.service. """ import logging +import os from trustgraph.schema import Error from trustgraph.schema import IamRequest, IamResponse @@ -27,6 +28,13 @@ default_ident = "iam-svc" default_iam_request_queue = iam_request_queue default_iam_response_queue = iam_response_queue +# Environment variables consulted as a fallback when the +# corresponding params field is not set in the processor-group YAML +# or via CLI. Intended for K8s Secret / env-var injection so the +# bootstrap token never has to live in the YAML (and thus in git). +ENV_BOOTSTRAP_MODE = "IAM_BOOTSTRAP_MODE" +ENV_BOOTSTRAP_TOKEN = "IAM_BOOTSTRAP_TOKEN" + class Processor(AsyncProcessor): @@ -39,26 +47,41 @@ class Processor(AsyncProcessor): "iam_response_queue", default_iam_response_queue, ) - bootstrap_mode = params.get("bootstrap_mode") - bootstrap_token = params.get("bootstrap_token") + # Resolve bootstrap mode + token. Precedence: explicit + # params (CLI / processor-group YAML) → environment variable + # → unset (fail-closed). The env-var path is the K8s-native + # injection point: an `IAM_BOOTSTRAP_TOKEN` from a Secret + # never has to land in the YAML, and therefore never enters + # git history. + bootstrap_mode = ( + params.get("bootstrap_mode") + or os.environ.get(ENV_BOOTSTRAP_MODE) + ) + bootstrap_token = ( + params.get("bootstrap_token") + or os.environ.get(ENV_BOOTSTRAP_TOKEN) + ) if bootstrap_mode not in ("token", "bootstrap"): raise RuntimeError( - "iam-svc: --bootstrap-mode is required. Set to 'token' " - "(with --bootstrap-token) for production, or 'bootstrap' " + "iam-svc: bootstrap-mode is required. Set to 'token' " + "(with bootstrap-token) for production, or 'bootstrap' " "to enable the explicit bootstrap operation over the " "pub/sub bus (dev / quick-start only, not safe under " - "public exposure). Refusing to start." + "public exposure). Configurable via processor-group " + f"params or the {ENV_BOOTSTRAP_MODE} environment " + "variable. Refusing to start." ) if bootstrap_mode == "token" and not bootstrap_token: raise RuntimeError( - "iam-svc: --bootstrap-mode=token requires " - "--bootstrap-token. Refusing to start." + "iam-svc: bootstrap-mode=token requires bootstrap-token " + f"(or the {ENV_BOOTSTRAP_TOKEN} environment " + "variable). Refusing to start." ) if bootstrap_mode == "bootstrap" and bootstrap_token: raise RuntimeError( - "iam-svc: --bootstrap-token is not accepted when " - "--bootstrap-mode=bootstrap. Ambiguous intent. " + "iam-svc: bootstrap-token is not accepted when " + "bootstrap-mode=bootstrap. Ambiguous intent. " "Refusing to start." ) From b15f1a167c1d411eeb59a62e1fbf90d924b51498 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 28 Apr 2026 15:05:35 +0100 Subject: [PATCH 16/21] Smoke-test websocket tool (#852) --- dev-tools/tests/smoke/smoke_ws_queries.py | 475 ++++++++++++++++++++++ 1 file changed, 475 insertions(+) create mode 100755 dev-tools/tests/smoke/smoke_ws_queries.py diff --git a/dev-tools/tests/smoke/smoke_ws_queries.py b/dev-tools/tests/smoke/smoke_ws_queries.py new file mode 100755 index 00000000..c6a4dfb6 --- /dev/null +++ b/dev-tools/tests/smoke/smoke_ws_queries.py @@ -0,0 +1,475 @@ +#!/usr/bin/env python3 +""" +WebSocket smoke / load test that hammers a TrustGraph gateway with a +mix of `embeddings`, `graph-embeddings`, and `triples` queries while +keeping a target number of in-flight requests at all times. + +Useful for reproducing the "worker hangs after a while, all subsequent +requests time out" failure mode — leaves enough load on the system to +saturate worker concurrency and reports per-service success/timeout +rates and latency distributions over time. + +Usage: + smoke_ws_queries.py --flow onto-rag --duration 120 --concurrency 20 + +Connects via /api/v1/socket using the first-frame auth protocol. +""" + +import argparse +import asyncio +import json +import os +import random +import statistics +import sys +import time +import uuid +from collections import defaultdict +from typing import Any + +import websockets + + +DEFAULT_TEXT = ( + "What caused the space shuttle to explode and what were the " + "main factors leading to the disaster?" +) + + +class Stats: + """Per-service rolling counters and latency samples.""" + + def __init__(self) -> None: + self.sent = 0 + self.ok = 0 + self.err = 0 + self.timeout = 0 + self.latencies_ms: list[float] = [] + + def record_ok(self, latency_ms: float) -> None: + self.ok += 1 + self.latencies_ms.append(latency_ms) + + def record_err(self) -> None: + self.err += 1 + + def record_timeout(self) -> None: + self.timeout += 1 + + def percentile(self, p: float) -> float: + if not self.latencies_ms: + return 0.0 + s = sorted(self.latencies_ms) + idx = min(len(s) - 1, int(len(s) * p)) + return s[idx] + + def summary(self) -> str: + if self.latencies_ms: + mn = min(self.latencies_ms) + mx = max(self.latencies_ms) + mean = statistics.mean(self.latencies_ms) + p50 = self.percentile(0.50) + p95 = self.percentile(0.95) + p99 = self.percentile(0.99) + lat = ( + f"min={mn:.0f} mean={mean:.0f} p50={p50:.0f} " + f"p95={p95:.0f} p99={p99:.0f} max={mx:.0f} ms" + ) + else: + lat = "no successful samples" + return ( + f"sent={self.sent} ok={self.ok} err={self.err} " + f"timeout={self.timeout} | {lat}" + ) + + +class WSClient: + """Thin async websocket client with first-frame auth and a shared + reader task that demuxes responses to per-request asyncio queues.""" + + def __init__( + self, url: str, token: str | None, workspace: str, + ping_timeout: int, + ) -> None: + self.url = url + self.token = token + self.workspace = workspace + self.ping_timeout = ping_timeout + self._ws: Any = None + self._pending: dict[str, asyncio.Queue] = {} + self._reader_task: asyncio.Task | None = None + self._closed = asyncio.Event() + + async def connect(self) -> None: + ws_url = self.url.rstrip("/") + "/api/v1/socket" + if ws_url.startswith("http://"): + ws_url = "ws://" + ws_url[len("http://"):] + elif ws_url.startswith("https://"): + ws_url = "wss://" + ws_url[len("https://"):] + elif not ( + ws_url.startswith("ws://") or ws_url.startswith("wss://") + ): + ws_url = "ws://" + ws_url + + self._ws = await websockets.connect( + ws_url, + ping_interval=20, + ping_timeout=self.ping_timeout, + max_size=64 * 1024 * 1024, + ) + + if self.token: + # First-frame auth handshake. + await self._ws.send(json.dumps({ + "type": "auth", "token": self.token, + })) + raw = await asyncio.wait_for(self._ws.recv(), timeout=10) + resp = json.loads(raw) + if resp.get("type") != "auth-ok": + await self._ws.close() + raise RuntimeError(f"auth failed: {resp}") + if "workspace" in resp: + # Server-resolved workspace overrides the user-supplied + # one, mirroring AsyncSocketClient behaviour. + self.workspace = resp["workspace"] + else: + print( + "WARNING: no token provided — skipping auth handshake. " + "Requests will be rejected unless the gateway is " + "running without IAM enforcement.", + file=sys.stderr, + ) + + self._reader_task = asyncio.create_task(self._reader()) + + async def _reader(self) -> None: + try: + async for raw in self._ws: + msg = json.loads(raw) + rid = msg.get("id") + if rid and rid in self._pending: + await self._pending[rid].put(msg) + except websockets.exceptions.ConnectionClosed: + pass + except Exception as e: + for q in list(self._pending.values()): + try: + q.put_nowait({"error": {"message": str(e)}}) + except Exception: + pass + finally: + self._closed.set() + + async def request( + self, service: str, flow: str | None, body: dict, timeout: float, + ) -> tuple[dict | None, str | None, float]: + """Send one request, await final response. + + Returns ``(response, error, latency_ms)``. ``response`` is None + on error/timeout. ``error`` describes the failure category. + """ + rid = str(uuid.uuid4()) + q: asyncio.Queue = asyncio.Queue() + self._pending[rid] = q + env = { + "id": rid, + "workspace": self.workspace, + "service": service, + "request": body, + } + if flow: + env["flow"] = flow + + t0 = time.monotonic() + try: + await self._ws.send(json.dumps(env)) + while True: + try: + msg = await asyncio.wait_for(q.get(), timeout=timeout) + except asyncio.TimeoutError: + return None, "timeout", (time.monotonic() - t0) * 1000 + if "error" in msg and msg["error"]: + err = msg["error"] + err_msg = ( + err.get("message") if isinstance(err, dict) else str(err) + ) + return None, f"error: {err_msg}", (time.monotonic() - t0) * 1000 + if msg.get("complete"): + return msg.get("response"), None, (time.monotonic() - t0) * 1000 + # Otherwise an intermediate streaming chunk — keep waiting. + finally: + self._pending.pop(rid, None) + + async def close(self) -> None: + if self._ws is not None: + await self._ws.close() + if self._reader_task is not None: + try: + await asyncio.wait_for(self._reader_task, timeout=2) + except (asyncio.TimeoutError, asyncio.CancelledError): + pass + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description=__doc__) + p.add_argument( + "--url", + default=os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/"), + help="Gateway URL (http or ws). Default: %(default)s", + ) + p.add_argument( + "--token", + default=os.getenv("TRUSTGRAPH_TOKEN"), + help="Auth token (or set TRUSTGRAPH_TOKEN). Optional — if " + "omitted, the auth handshake is skipped (only works " + "when the gateway is running without IAM enforcement).", + ) + p.add_argument( + "--workspace", default="default", + help="Workspace. Default: %(default)s", + ) + p.add_argument( + "--flow", required=True, + help="Flow id. Comma-separated for round-robin across flows " + "(e.g. onto-rag,doc-rag).", + ) + p.add_argument( + "--duration", type=int, default=60, + help="Test duration in seconds. Default: %(default)s", + ) + p.add_argument( + "--concurrency", type=int, default=15, + help="Target in-flight request count. Default: %(default)s", + ) + p.add_argument( + "--services", + default="embeddings,graph-embeddings,triples", + help="Comma-separated services to exercise. " + "Default: %(default)s", + ) + p.add_argument( + "--limit", type=int, default=3, + help="limit for triples / graph-embeddings queries. " + "Default: %(default)s", + ) + p.add_argument( + "--collection", default="default", + help="Collection. Default: %(default)s", + ) + p.add_argument( + "--text", default=DEFAULT_TEXT, + help="Text to embed for embeddings/seed.", + ) + p.add_argument( + "--vector-dim", type=int, default=384, + help="Dimension of synthetic vector when --no-seed is used. " + "Default: %(default)s", + ) + p.add_argument( + "--no-seed", action="store_true", + help="Skip the embeddings warm-up call. Use a random vector " + "for graph-embeddings queries instead.", + ) + p.add_argument( + "--request-timeout", type=float, default=30.0, + help="Per-request timeout (seconds). Default: %(default)s", + ) + p.add_argument( + "--report-interval", type=float, default=5.0, + help="How often to print stats (seconds). Default: %(default)s", + ) + p.add_argument( + "--ping-timeout", type=int, default=120, + help="Websocket ping timeout. Default: %(default)s", + ) + p.add_argument( + "--seed", type=int, default=None, + help="Random seed (for reproducibility).", + ) + return p.parse_args() + + +async def seed_vector( + client: WSClient, flow: str, text: str, timeout: float, +) -> list[float]: + """Issue one embeddings request to obtain a real vector that + later graph-embeddings calls can reuse.""" + resp, err, _ = await client.request( + "embeddings", flow, {"texts": [text]}, timeout, + ) + if err or not resp: + raise RuntimeError(f"seed embeddings failed: {err or resp}") + vectors = resp.get("vectors") + if not vectors: + raise RuntimeError(f"seed embeddings: no vectors in response: {resp}") + return vectors[0] + + +def make_request_body( + service: str, args: argparse.Namespace, vector: list[float], +) -> dict: + if service == "embeddings": + return {"texts": [args.text]} + if service == "graph-embeddings": + return { + "vector": vector, + "limit": args.limit, + "collection": args.collection, + } + if service == "triples": + return { + "limit": args.limit, + "collection": args.collection, + } + raise ValueError(f"Unknown service: {service}") + + +async def worker( + name: int, + client: WSClient, + flows: list[str], + services: list[str], + args: argparse.Namespace, + vector: list[float], + stats: dict[str, Stats], + in_flight: dict[str, int], + stop_at: float, +) -> None: + rng = random.Random((args.seed or 0) + name) + while time.monotonic() < stop_at: + svc = rng.choice(services) + flow = rng.choice(flows) + body = make_request_body(svc, args, vector) + + stats[svc].sent += 1 + in_flight[svc] += 1 + try: + resp, err, lat = await client.request( + svc, flow, body, args.request_timeout, + ) + if err == "timeout": + stats[svc].record_timeout() + elif err: + stats[svc].record_err() + else: + stats[svc].record_ok(lat) + except Exception as e: + stats[svc].record_err() + print(f"worker {name}: unexpected {svc} exception: {e}", + file=sys.stderr) + finally: + in_flight[svc] -= 1 + + +async def reporter( + services: list[str], + stats: dict[str, Stats], + in_flight: dict[str, int], + stop_at: float, + interval: float, +) -> None: + started = time.monotonic() + last_sent = {s: 0 for s in services} + while time.monotonic() < stop_at: + await asyncio.sleep(interval) + now = time.monotonic() + elapsed = now - started + total_inflight = sum(in_flight.values()) + print( + f"\n[{elapsed:6.1f}s] in-flight={total_inflight} " + f"per-svc={dict(in_flight)}" + ) + for svc in services: + s = stats[svc] + delta = s.sent - last_sent[svc] + rate = delta / interval + last_sent[svc] = s.sent + print(f" {svc:20s} {rate:6.1f}/s | {s.summary()}") + + +async def run(args: argparse.Namespace) -> int: + if args.seed is not None: + random.seed(args.seed) + + services = [s.strip() for s in args.services.split(",") if s.strip()] + flows = [f.strip() for f in args.flow.split(",") if f.strip()] + valid = {"embeddings", "graph-embeddings", "triples"} + bad = [s for s in services if s not in valid] + if bad: + print(f"ERROR: unknown service(s): {bad}. " + f"Supported: {sorted(valid)}", file=sys.stderr) + return 2 + + client = WSClient( + args.url, args.token, args.workspace, args.ping_timeout, + ) + print(f"Connecting to {args.url} ...") + await client.connect() + print(f"Connected. workspace={client.workspace} flows={flows} " + f"services={services} concurrency={args.concurrency} " + f"duration={args.duration}s") + + if "graph-embeddings" in services and not args.no_seed: + print("Seeding embedding vector ...") + vector = await seed_vector( + client, flows[0], args.text, args.request_timeout, + ) + print(f"Got vector of length {len(vector)}") + else: + vector = [random.uniform(-1.0, 1.0) for _ in range(args.vector_dim)] + + stats: dict[str, Stats] = defaultdict(Stats) + in_flight: dict[str, int] = defaultdict(int) + for svc in services: + stats[svc] # initialise + in_flight[svc] = 0 + + stop_at = time.monotonic() + args.duration + print(f"Starting load: {args.concurrency} workers for " + f"{args.duration}s ...") + + workers = [ + asyncio.create_task( + worker( + i, client, flows, services, args, vector, + stats, in_flight, stop_at, + ) + ) + for i in range(args.concurrency) + ] + rep = asyncio.create_task( + reporter(services, stats, in_flight, stop_at, args.report_interval) + ) + + try: + await asyncio.gather(*workers) + finally: + rep.cancel() + try: + await rep + except asyncio.CancelledError: + pass + + print("\n=== Final results ===") + any_failures = False + for svc in services: + s = stats[svc] + print(f" {svc:20s} {s.summary()}") + if s.timeout > 0 or s.err > 0: + any_failures = True + + await client.close() + + return 1 if any_failures else 0 + + +def main() -> int: + args = parse_args() + try: + return asyncio.run(run(args)) + except KeyboardInterrupt: + return 130 + + +if __name__ == "__main__": + sys.exit(main()) From 9f2d9adcb140aae8c3413b6e16c1239881a79d74 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 28 Apr 2026 15:43:04 +0100 Subject: [PATCH 17/21] Fix Ollama async issue (#854) * Fix Ollama sync issues - replaced with async * Fix tests --- .../test_ollama_dynamic_model.py | 24 +++++----- .../test_ollama_processor.py | 48 +++++++++---------- .../trustgraph/embeddings/ollama/processor.py | 14 +++--- .../model/text_completion/ollama/llm.py | 22 ++++----- 4 files changed, 54 insertions(+), 54 deletions(-) diff --git a/tests/unit/test_embeddings/test_ollama_dynamic_model.py b/tests/unit/test_embeddings/test_ollama_dynamic_model.py index d52a58c6..cfbc4d6e 100644 --- a/tests/unit/test_embeddings/test_ollama_dynamic_model.py +++ b/tests/unit/test_embeddings/test_ollama_dynamic_model.py @@ -14,13 +14,13 @@ from trustgraph.embeddings.ollama.processor import Processor class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase): """Test Ollama dynamic model selection""" - @patch('trustgraph.embeddings.ollama.processor.Client') + @patch('trustgraph.embeddings.ollama.processor.AsyncClient') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') async def test_client_initialized_with_host(self, mock_embeddings_init, mock_async_init, mock_client_class): """Test that Ollama client is initialized with correct host""" # Arrange - mock_ollama_client = Mock() + mock_ollama_client = AsyncMock() mock_response = Mock() mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]] mock_ollama_client.embed.return_value = mock_response @@ -36,13 +36,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase): mock_client_class.assert_called_once_with(host="http://localhost:11434") assert processor.default_model == "test-model" - @patch('trustgraph.embeddings.ollama.processor.Client') + @patch('trustgraph.embeddings.ollama.processor.AsyncClient') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_client_class): """Test that on_embeddings uses default model when no model specified""" # Arrange - mock_ollama_client = Mock() + mock_ollama_client = AsyncMock() mock_response = Mock() mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]] mock_ollama_client.embed.return_value = mock_response @@ -62,13 +62,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase): ) assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] - @patch('trustgraph.embeddings.ollama.processor.Client') + @patch('trustgraph.embeddings.ollama.processor.AsyncClient') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_client_class): """Test that on_embeddings uses specified model when provided""" # Arrange - mock_ollama_client = Mock() + mock_ollama_client = AsyncMock() mock_response = Mock() mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]] mock_ollama_client.embed.return_value = mock_response @@ -88,13 +88,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase): ) assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] - @patch('trustgraph.embeddings.ollama.processor.Client') + @patch('trustgraph.embeddings.ollama.processor.AsyncClient') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_client_class): """Test switching between multiple models""" # Arrange - mock_ollama_client = Mock() + mock_ollama_client = AsyncMock() mock_response = Mock() mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]] mock_ollama_client.embed.return_value = mock_response @@ -118,13 +118,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase): assert calls[2][1]['model'] == "model-a" assert calls[3][1]['model'] == "test-model" # Default - @patch('trustgraph.embeddings.ollama.processor.Client') + @patch('trustgraph.embeddings.ollama.processor.AsyncClient') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class): """Test that None model parameter falls back to default""" # Arrange - mock_ollama_client = Mock() + mock_ollama_client = AsyncMock() mock_response = Mock() mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]] mock_ollama_client.embed.return_value = mock_response @@ -143,13 +143,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase): input=["test text"] ) - @patch('trustgraph.embeddings.ollama.processor.Client') + @patch('trustgraph.embeddings.ollama.processor.AsyncClient') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class): """Test initialization without model parameter uses module default""" # Arrange - mock_ollama_client = Mock() + mock_ollama_client = AsyncMock() mock_client_class.return_value = mock_ollama_client mock_async_init.return_value = None mock_embeddings_init.return_value = None diff --git a/tests/unit/test_text_completion/test_ollama_processor.py b/tests/unit/test_text_completion/test_ollama_processor.py index 69baf85f..35bf182a 100644 --- a/tests/unit/test_text_completion/test_ollama_processor.py +++ b/tests/unit/test_text_completion/test_ollama_processor.py @@ -15,13 +15,13 @@ from trustgraph.base import LlmResult class TestOllamaProcessorSimple(IsolatedAsyncioTestCase): """Test Ollama processor functionality""" - @patch('trustgraph.model.text_completion.ollama.llm.Client') + @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_client_class): """Test basic processor initialization""" # Arrange - mock_client = MagicMock() + mock_client = AsyncMock() mock_client_class.return_value = mock_client # Mock the parent class initialization @@ -44,13 +44,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase): assert hasattr(processor, 'llm') mock_client_class.assert_called_once_with(host='http://localhost:11434') - @patch('trustgraph.model.text_completion.ollama.llm.Client') + @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_client_class): """Test successful content generation""" # Arrange - mock_client = MagicMock() + mock_client = AsyncMock() mock_response = { 'response': 'Generated response from Ollama', 'prompt_eval_count': 15, @@ -83,13 +83,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase): assert result.model == 'llama2' mock_client.generate.assert_called_once_with('llama2', "System prompt\n\nUser prompt", options={'temperature': 0.0}) - @patch('trustgraph.model.text_completion.ollama.llm.Client') + @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_client_class): """Test handling of generic exceptions""" # Arrange - mock_client = MagicMock() + mock_client = AsyncMock() mock_client.generate.side_effect = Exception("Connection error") mock_client_class.return_value = mock_client @@ -110,13 +110,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase): with pytest.raises(Exception, match="Connection error"): await processor.generate_content("System prompt", "User prompt") - @patch('trustgraph.model.text_completion.ollama.llm.Client') + @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_client_class): """Test processor initialization with custom parameters""" # Arrange - mock_client = MagicMock() + mock_client = AsyncMock() mock_client_class.return_value = mock_client mock_async_init.return_value = None @@ -137,13 +137,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase): assert processor.default_model == 'mistral' mock_client_class.assert_called_once_with(host='http://192.168.1.100:11434') - @patch('trustgraph.model.text_completion.ollama.llm.Client') + @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_client_class): """Test processor initialization with default values""" # Arrange - mock_client = MagicMock() + mock_client = AsyncMock() mock_client_class.return_value = mock_client mock_async_init.return_value = None @@ -164,13 +164,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase): # Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env) mock_client_class.assert_called_once() - @patch('trustgraph.model.text_completion.ollama.llm.Client') + @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_client_class): """Test content generation with empty prompts""" # Arrange - mock_client = MagicMock() + mock_client = AsyncMock() mock_response = { 'response': 'Default response', 'prompt_eval_count': 2, @@ -205,13 +205,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase): # The prompt should be "" + "\n\n" + "" = "\n\n" mock_client.generate.assert_called_once_with('llama2', "\n\n", options={'temperature': 0.0}) - @patch('trustgraph.model.text_completion.ollama.llm.Client') + @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') async def test_generate_content_token_counting(self, mock_llm_init, mock_async_init, mock_client_class): """Test token counting from Ollama response""" # Arrange - mock_client = MagicMock() + mock_client = AsyncMock() mock_response = { 'response': 'Test response', 'prompt_eval_count': 50, @@ -243,13 +243,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase): assert result.out_token == 25 assert result.model == 'llama2' - @patch('trustgraph.model.text_completion.ollama.llm.Client') + @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') async def test_ollama_client_initialization(self, mock_llm_init, mock_async_init, mock_client_class): """Test that Ollama client is initialized correctly""" # Arrange - mock_client = MagicMock() + mock_client = AsyncMock() mock_client_class.return_value = mock_client mock_async_init.return_value = None @@ -273,13 +273,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase): # Verify processor has the client assert processor.llm == mock_client - @patch('trustgraph.model.text_completion.ollama.llm.Client') + @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_client_class): """Test prompt construction with system and user prompts""" # Arrange - mock_client = MagicMock() + mock_client = AsyncMock() mock_response = { 'response': 'Response with system instructions', 'prompt_eval_count': 25, @@ -312,13 +312,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase): # Verify the combined prompt mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?", options={'temperature': 0.0}) - @patch('trustgraph.model.text_completion.ollama.llm.Client') + @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_client_class): """Test temperature parameter override functionality""" # Arrange - mock_client = MagicMock() + mock_client = AsyncMock() mock_response = { 'response': 'Response with custom temperature', 'prompt_eval_count': 20, @@ -360,13 +360,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase): options={'temperature': 0.8} # Should use runtime override ) - @patch('trustgraph.model.text_completion.ollama.llm.Client') + @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_client_class): """Test model parameter override functionality""" # Arrange - mock_client = MagicMock() + mock_client = AsyncMock() mock_response = { 'response': 'Response with custom model', 'prompt_eval_count': 18, @@ -408,13 +408,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase): options={'temperature': 0.1} # Should use processor default ) - @patch('trustgraph.model.text_completion.ollama.llm.Client') + @patch('trustgraph.model.text_completion.ollama.llm.AsyncClient') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_client_class): """Test overriding both model and temperature parameters simultaneously""" # Arrange - mock_client = MagicMock() + mock_client = AsyncMock() mock_response = { 'response': 'Response with both overrides', 'prompt_eval_count': 22, diff --git a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py index c63db33c..5fa74054 100755 --- a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py +++ b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py @@ -5,7 +5,7 @@ Input is text, output is embeddings vector. """ from ... base import EmbeddingsService -from ollama import Client +from ollama import AsyncClient import os import logging @@ -30,24 +30,24 @@ class Processor(EmbeddingsService): } ) - self.client = Client(host=ollama) + self.client = AsyncClient(host=ollama) self.default_model = model self._checked_models = set() - def _ensure_model(self, model_name): + async def _ensure_model(self, model_name): """Check if model exists locally, pull it if not.""" if model_name in self._checked_models: return try: - self.client.show(model_name) + await self.client.show(model_name) self._checked_models.add(model_name) except Exception as e: status_code = getattr(e, 'status_code', None) if status_code == 404 or "not found" in str(e).lower(): logger.info(f"Ollama model '{model_name}' not found locally. Pulling, this may take a while...") try: - self.client.pull(model_name) + await self.client.pull(model_name) self._checked_models.add(model_name) logger.info(f"Successfully pulled Ollama model '{model_name}'.") except Exception as pull_e: @@ -63,10 +63,10 @@ class Processor(EmbeddingsService): use_model = model or self.default_model # Ensure the model exists/is pulled - self._ensure_model(use_model) + await self._ensure_model(use_model) # Ollama handles batch input efficiently - embeds = self.client.embed( + embeds = await self.client.embed( model = use_model, input = texts ) diff --git a/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py b/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py index f6c5dcb8..2e537fde 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py @@ -4,7 +4,7 @@ Simple LLM service, performs text prompt completion using an Ollama service. Input is prompt, output is response. """ -from ollama import Client +from ollama import AsyncClient import os import logging @@ -38,23 +38,23 @@ class Processor(LlmService): self.default_model = model self.temperature = temperature - self.llm = Client(host=ollama) + self.llm = AsyncClient(host=ollama) self._checked_models = set() - def _ensure_model(self, model_name): + async def _ensure_model(self, model_name): """Check if model exists locally, pull it if not.""" if model_name in self._checked_models: return try: - self.llm.show(model_name) + await self.llm.show(model_name) self._checked_models.add(model_name) except Exception as e: status_code = getattr(e, 'status_code', None) if status_code == 404 or "not found" in str(e).lower(): logger.info(f"Ollama model '{model_name}' not found locally. Pulling, this may take a while...") try: - self.llm.pull(model_name) + await self.llm.pull(model_name) self._checked_models.add(model_name) logger.info(f"Successfully pulled Ollama model '{model_name}'.") except Exception as pull_e: @@ -66,9 +66,9 @@ class Processor(LlmService): # Use provided model or fall back to default model_name = model or self.default_model - + # Ensure the model exists/is pulled - self._ensure_model(model_name) + await self._ensure_model(model_name) # Use provided temperature or fall back to default effective_temperature = temperature if temperature is not None else self.temperature @@ -79,7 +79,7 @@ class Processor(LlmService): try: - response = self.llm.generate(model_name, prompt, options={'temperature': effective_temperature}) + response = await self.llm.generate(model_name, prompt, options={'temperature': effective_temperature}) response_text = response['response'] logger.debug("Sending response...") @@ -113,7 +113,7 @@ class Processor(LlmService): model_name = model or self.default_model # Ensure the model exists/is pulled - self._ensure_model(model_name) + await self._ensure_model(model_name) effective_temperature = temperature if temperature is not None else self.temperature @@ -123,7 +123,7 @@ class Processor(LlmService): prompt = system + "\n\n" + prompt try: - stream = self.llm.generate( + stream = await self.llm.generate( model_name, prompt, options={'temperature': effective_temperature}, @@ -133,7 +133,7 @@ class Processor(LlmService): total_input_tokens = 0 total_output_tokens = 0 - for chunk in stream: + async for chunk in stream: if 'response' in chunk and chunk['response']: yield LlmChunk( text=chunk['response'], From 5e28d3cce0231afa24930aa9e2b339e8ef5de153 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 28 Apr 2026 16:19:41 +0100 Subject: [PATCH 18/21] refactor(iam): pluggable IAM regime via authenticate/authorise contract (#853) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The gateway no longer holds any policy state — capability sets, role definitions, workspace scope rules. Per the IAM contract it asks the regime "may this identity perform this capability on this resource?" per request. That moves the OSS role-based regime entirely into iam-svc, which can be replaced (SSO, ABAC, ReBAC) without changing the gateway, the wire protocol, or backend services. Contract: - authenticate(credential) -> Identity (handle, workspace, principal_id, source). No roles, claims, or policy state surface to the gateway. - authorise(identity, capability, resource, parameters) -> (allow, ttl). Cached per-decision (regime TTL clamped above; fail-closed on regime errors). - authorise_many available as a fan-out variant. Operation registry drives every authorisation decision: - /api/v1/iam -> IamEndpoint, looks up bare op name (create-user, list-workspaces, ...). - /api/v1/{kind} -> RegistryRoutedVariableEndpoint, : (config:get, flow:list-blueprints, librarian:add-document, ...). - /api/v1/flow/{flow}/service/{kind} -> flow-service:. - /api/v1/flow/{flow}/{import,export}/{kind} -> flow-{import,export}:. - WS Mux per-frame -> flow-service:; closes a gap where authenticated users could hit any service kind. 85 operations registered across the surface. JWT carries identity only — sub + workspace. The roles claim is gone; the gateway never reads policy state from a credential. The three coarse *_KIND_CAPABILITY maps are removed. The registry is the only source of truth for the capability + resource shape of an operation. Tests migrated to the new Identity shape and to authorise()-mocked auth doubles. Specs updated: docs/tech-specs/iam-contract.md (Identity surface, caching, registry-naming conventions), iam.md (JWT shape, gateway flow, role section reframed as OSS-regime detail), iam-protocol.md (positioned as one implementation of the contract). --- docs/tech-specs/capabilities.md | 95 +++- docs/tech-specs/data-ownership-model.md | 32 +- ...nition.md => flow-blueprint-definition.md} | 0 docs/tech-specs/iam-contract.md | 366 +++++++++++++ docs/tech-specs/iam-protocol.md | 48 +- docs/tech-specs/iam.md | 323 ++++++++--- tests/unit/test_gateway/test_auth.py | 165 +++++- tests/unit/test_gateway/test_capabilities.py | 236 ++++---- .../test_gateway/test_endpoint_manager.py | 16 +- .../test_socket_graceful_shutdown.py | 17 +- trustgraph-base/trustgraph/base/iam_client.py | 44 +- .../trustgraph/schema/services/iam.py | 27 + trustgraph-flow/trustgraph/gateway/auth.py | 141 ++++- .../trustgraph/gateway/capabilities.py | 234 ++------ .../trustgraph/gateway/dispatch/mux.py | 51 +- .../gateway/endpoint/auth_endpoints.py | 2 +- .../gateway/endpoint/constant_endpoint.py | 2 +- .../gateway/endpoint/iam_endpoint.py | 106 ++++ .../trustgraph/gateway/endpoint/manager.py | 177 +++--- .../gateway/endpoint/registry_endpoint.py | 123 +++++ .../trustgraph/gateway/endpoint/socket.py | 10 +- .../gateway/endpoint/variable_endpoint.py | 2 +- .../trustgraph/gateway/registry.py | 515 ++++++++++++++++++ trustgraph-flow/trustgraph/iam/service/iam.py | 214 +++++++- 24 files changed, 2359 insertions(+), 587 deletions(-) rename docs/tech-specs/{flow-class-definition.md => flow-blueprint-definition.md} (100%) create mode 100644 docs/tech-specs/iam-contract.md create mode 100644 trustgraph-flow/trustgraph/gateway/endpoint/iam_endpoint.py create mode 100644 trustgraph-flow/trustgraph/gateway/endpoint/registry_endpoint.py create mode 100644 trustgraph-flow/trustgraph/gateway/registry.py diff --git a/docs/tech-specs/capabilities.md b/docs/tech-specs/capabilities.md index 60f5acbf..7717cbc9 100644 --- a/docs/tech-specs/capabilities.md +++ b/docs/tech-specs/capabilities.md @@ -8,22 +8,41 @@ parent: "Tech Specs" ## Overview -Authorisation in TrustGraph is **capability-based**. Every gateway -endpoint maps to exactly one *capability*; a user's roles each grant -a set of capabilities; an authenticated request is permitted when -the required capability is a member of the union of the caller's -role capability sets. +Every gateway endpoint maps to exactly one *capability* — a string +from a closed vocabulary defined in this document. When the +gateway authorises a request, it hands the IAM regime four things: +the authenticated identity, the required capability, the +operation's resource (the structured identifier of what's being +operated on), and the operation's parameters. The IAM regime +decides allow or deny; see the [IAM contract](iam-contract.md) for +the full abstraction. -This document defines the capability vocabulary — the closed list -of capability strings that the gateway recognises — and the -open-source edition's role bundles. +A capability is a **permission**, not a structural classification. +`graph:read` says "the caller may read graphs"; it does not say +where graphs live or how they are addressed. The shape of a +request — whether workspace appears in the URL, the envelope, or +the body, and whether it is a resource address component or an +operation parameter — is determined by what the operation operates +on, not by what permission it requires. Permission and structure +are orthogonal; the contract takes both. -The capability mechanism is shared between open-source and potential -3rd party enterprise capability. The open-source edition ships a -fixed three-role bundle (`reader`, `writer`, `admin`). Enterprise -capability may define additional roles by composing their own -capability bundles from the same vocabulary; no protocol, gateway, -or backend-service change is required. +This document defines: + +- The **capability vocabulary** — the closed list of capability + strings the gateway uses as input to `authorise`. All IAM + regimes share this vocabulary; that's the only schema the + gateway and the IAM regime have to agree on. +- The **open-source role bundles** — the role-and-scope table the + OSS IAM regime uses to answer `authorise` calls. Other regimes + answer the same call differently; the bundles below are an + OSS-specific implementation detail, not a contract assertion. + +A regime may evaluate `authorise` using role bundles (OSS), IdP +group memberships, attribute-based policies, relationship tuples, +or any other mechanism. The gateway is unaware of which. The +capability strings — and the resource component vocabulary the +gateway populates alongside them — are the only thing both sides +have to agree on. ## Motivation @@ -113,19 +132,50 @@ granting `llm` expresses exactly that. An administrator granting `agent` should treat it as a grant of everything the agent composes at deployment time. -### Authorisation evaluation +### Authorisation evaluation (OSS regime) + +This section describes how the OSS IAM regime answers +`authorise(identity, capability, resource, parameters)`. Other +regimes answer the same contract differently; only the inputs (the +capability vocabulary, the resource components, the parameter +shape) are shared. For a request bearing a resolved set of roles -`R = {r1, r2, ...}` against an endpoint that requires capability -`c`: +`R = {r1, r2, ...}`, a required capability `c`, a resource, and +parameters: ``` -allow if c IN union(bundle(r) for r in R) +let target_workspace = + resource.workspace (workspace-/flow-level resources) + or parameters.workspace (system-level resources whose + parameters reference a workspace) + or unset (system-level operations with no + workspace context) + +allow if some role r in R has c in its capability bundle + and (target_workspace is unset + or r's workspace_scope permits target_workspace) ``` -No hierarchy, no precedence, no role-order sensitivity. A user +The OSS regime considers workspace from whichever role it plays in +the operation: + +- For workspace-level and flow-level resources, the workspace lives + in `resource.workspace` and that is what the role's scope is + checked against. +- For system-level resources whose operation parameters reference a + workspace (e.g. `create-user with workspace association W`), + workspace lives in `parameters.workspace` and that is what the + role's scope is checked against. The resource is system-level + (`resource = {}`) but the workspace constraint still bites. +- For system-level operations with no workspace context (e.g. + `bootstrap`, `rotate-signing-key`), the workspace-scope check + collapses — only capability-bundle membership matters. + +No hierarchy, no precedence, no role-order sensitivity. A user with a single role is the common case; a user with multiple roles -gets the union of their bundles. +is allowed if any role independently grants both the capability +and the relevant workspace scope. ### Enforcement boundary @@ -214,5 +264,10 @@ ships that feature. ## References +- [IAM Contract Specification](iam-contract.md) — the abstract + gateway↔IAM regime contract; capability strings are inputs to + `authorise`. - [Identity and Access Management Specification](iam.md) +- [IAM Service Protocol Specification](iam-protocol.md) — the OSS + regime's wire-level protocol. - [Architecture Principles](architecture-principles.md) diff --git a/docs/tech-specs/data-ownership-model.md b/docs/tech-specs/data-ownership-model.md index ea94ec46..b112d195 100644 --- a/docs/tech-specs/data-ownership-model.md +++ b/docs/tech-specs/data-ownership-model.md @@ -22,8 +22,16 @@ are the boundaries around data, and who owns what? A workspace is the primary isolation boundary. It represents an organisation, team, or independent operating unit. All data belongs to -exactly one workspace. Cross-workspace access is never permitted through -the API. +exactly one workspace. + +Cross-workspace access through the API is gated by the IAM regime +(see [`iam-contract.md`](iam-contract.md)). In the OSS distribution, +the role table defined in [`capabilities.md`](capabilities.md) +permits cross-workspace operation only to the `admin` role; the +`reader` and `writer` roles are constrained to a single assigned +workspace per credential. Other regimes can model the relationship +between identity and workspace differently — the gateway makes no +assumption. A workspace owns: - Source documents @@ -279,9 +287,18 @@ A typical workflow: The current codebase uses a `user` field in message metadata and storage partition keys to identify the workspace. The `collection` field -identifies the collection within that workspace. The IAM spec describes -how the gateway maps authenticated credentials to a workspace identity -and sets these fields. +identifies the collection within that workspace. + +The gateway is the single point at which workspace gets stamped onto +outbound pub/sub messages. An incoming credential authenticates to a +workspace (the credential's binding, not a user-to-workspace lookup — +see [`iam-contract.md`](iam-contract.md) and the *Identity surface* +section of [`iam.md`](iam.md)); any caller-supplied workspace on the +request is reconciled against the authenticated identity by the IAM +regime; the resolved value is what the gateway writes into outgoing +messages and the storage layers' partition keys. Backend services +trust the workspace they receive — defense-in-depth happens at the +gateway, not at the bus. For details on how each storage backend implements this scoping, see: @@ -302,7 +319,10 @@ For details on how each storage backend implements this scoping, see: ## References -- [Identity and Access Management](iam.md) +- [IAM Contract](iam-contract.md) — gateway↔IAM regime abstraction. +- [Identity and Access Management](iam.md) — gateway-side framing. +- [Capability Vocabulary](capabilities.md) — capability strings and + the OSS role bundles that decide cross-workspace eligibility. - [Collection Management](collection-management.md) - [Entity-Centric Graph](entity-centric-graph.md) - [Neo4j User Collection Isolation](neo4j-user-collection-isolation.md) diff --git a/docs/tech-specs/flow-class-definition.md b/docs/tech-specs/flow-blueprint-definition.md similarity index 100% rename from docs/tech-specs/flow-class-definition.md rename to docs/tech-specs/flow-blueprint-definition.md diff --git a/docs/tech-specs/iam-contract.md b/docs/tech-specs/iam-contract.md new file mode 100644 index 00000000..3289add1 --- /dev/null +++ b/docs/tech-specs/iam-contract.md @@ -0,0 +1,366 @@ +--- +layout: default +title: "IAM Contract Technical Specification" +parent: "Tech Specs" +--- + +# IAM Contract Technical Specification + +## Overview + +The IAM contract is the abstraction between the API gateway and any +identity / access management regime that fronts it. The gateway +treats IAM as a black box behind two operations — *authenticate* and +*authorise* — plus a small surface of management operations. No +regime-specific concept (roles, scopes, groups, claims, policy +languages) is visible to the gateway, and no gateway-specific +concept (capability vocabulary, request anatomy) is visible to +backend services. + +The TrustGraph open-source distribution ships one IAM regime — a +role-based implementation defined in +[`iam-protocol.md`](iam-protocol.md) — that is one implementation of +this contract. Enterprise editions can replace it with a different +regime (OIDC / SSO, ABAC, ReBAC, external policy engine) without +changing the gateway, the wire protocol, or the backends. + +## Motivation + +Authorisation models vary by deployment. A small team might be +happy with three predefined roles; an enterprise might need group- +mapping from an upstream IdP, attribute-based policies, or +relationship-based access control. Hard-wiring any one of those +into the gateway forces every other regime to either compromise its +model or be re-implemented. + +A narrow contract — "authenticate this credential" and "may this +identity perform this operation on this resource" — captures what +the gateway actually needs to know without committing to a policy +shape. The IAM regime owns the policy decision; the gateway is a +generic enforcement point. + +## Operations + +### `authenticate` + +``` +authenticate(credential: bytes) → Identity | AuthFailure +``` + +Validates a credential the client presented. The gateway treats +the credential as opaque bytes — for the OSS regime today that's +either an API key plaintext or a JWT, but the gateway does not +parse them; the IAM regime decides. + +On success, returns an `Identity`. On any failure the IAM regime +returns the same opaque `AuthFailure` — never a description of which +condition failed. This is the spec's masked-error rule: an +attacker probing the endpoint cannot distinguish "no such key", +"expired", "wrong signature", "revoked", "user disabled", etc. + +### `authorise` + +``` +authorise(identity: Identity, + capability: str, + resource: Resource, + parameters: dict) + → Decision +``` + +Asks whether the identity is permitted to perform the named +capability on the named resource, given the operation's +parameters. Returns `allow` or `deny`. `identity` is whatever +`authenticate` returned for this caller; the gateway never +decomposes it. + +The four arguments separate concerns: + +- **`identity`** — who is asking. +- **`capability`** — what permission they are exercising (e.g. + `users:write`, `graph:read`). Permission, not structure. +- **`resource`** — what is being operated on, as a structured + identifier. See *The Resource model* below. +- **`parameters`** — operation-specific data that the regime may + need to consider beyond the resource identifier. Used when a + decision depends on attributes the request supplies — e.g. an + admin scoped to one workspace creating a user *with workspace + association W*: the resource is the system-level user registry, + and W is a parameter the regime checks against the admin's + scope. + +Different regimes use the four arguments differently — the OSS +regime checks role bundles against the capability and the role's +workspace scope against parameters; an SSO regime might consult an +upstream IdP's group memberships; an ABAC regime evaluates a +policy with all four as inputs. The contract is unchanged. + +### `authorise_many` + +``` +authorise_many(identity: Identity, + checks: list[(str, Resource, dict)]) + → list[Decision] +``` + +Bulk variant of `authorise`. Same semantics, one round-trip for +many decisions. Used when an operation fans out to multiple +resources (e.g. an agent that touches several workspaces) and a +single permission check isn't sufficient. + +`authorise_many` is not just a performance optimisation; it pins +the contract for fan-out operations early, before clients (or +internal callers) build patterns that assume one-permission-check- +per-request. Regimes implement it as a loop over `authorise` +unless they have a more efficient path. + +### Management operations + +Beyond the request-time `authenticate` / `authorise`, the contract +also covers identity-lifecycle and credential-lifecycle operations +that are invoked by administrative requests rather than by the +authentication path. These are regime-specific in detail (an SSO +regime that delegates user management to the IdP may not implement +most of them) but the operation set the gateway can forward is: + +- User management: `create-user`, `list-users`, `get-user`, + `update-user`, `disable-user`, `enable-user`, `delete-user` +- Credential management: `create-api-key`, `list-api-keys`, + `revoke-api-key`, `change-password`, `reset-password` +- Workspace management: `create-workspace`, `list-workspaces`, + `get-workspace`, `update-workspace`, `disable-workspace` +- Session management: `login` +- Key management: `get-signing-key-public`, `rotate-signing-key` +- Bootstrap: `bootstrap` + +A regime that does not support one of these (e.g. an SSO regime +where users are managed in the IdP) returns a defined "not +supported" error; the gateway surfaces it as a 501. + +## The `Identity` surface + +`Identity` is *mostly* opaque. The gateway holds the value as a +token to quote back when calling `authorise`, never decomposing it. +But there are a few gateway-side concerns that need a small +surface: + +| Field | Purpose | +|---|---| +| `handle` | Opaque reference passed back to `authorise`. Regime-defined; gateway treats as a string. | +| `workspace` | The workspace this credential authenticates to. Used by the gateway only as a default-fill-in for operations that omit a workspace. Never used as policy input — when authorisation needs to know which workspace the operation acts on, the operation places it in the resource address (or a parameter), and the regime decides. | +| `principal_id` | Stable identifier the gateway logs for audit (a user id, a sub claim, a service account id). Never used for authorisation — that's `authorise`'s job. | +| `source` | How the credential was presented (`api-key`, `jwt`, …). Non-policy; useful for logs and metrics only. | + +Anything else — roles, claims, group memberships, policy attributes +— stays inside the regime and is reachable only via `authorise`. + +## The `Resource` model + +A `Resource` is a structured value identifying *what is being +operated on*. Resources live at one of three levels in TrustGraph, +based on where the resource exists in the deployment: + +### Resource levels + +| Level | What lives there | Resource shape | +|---|---|---| +| **System** | The user registry, the workspace registry, the signing key, the audit log — anything that exists once per deployment. | `{}` | +| **Workspace** | A workspace's config, flow definitions, library (documents), knowledge cores, collections — things that exist *within* a workspace. | `{workspace: "..."}` | +| **Flow** | A flow's knowledge graph, agent state, LLM context, embedding state, MCP context — things that exist *within* a flow within a workspace. | `{workspace: "...", flow: "..."}` | + +Note carefully: + +- **Users are a system-level resource.** A user record exists at + the deployment level; the fact that a user has a *workspace + association* (one in OSS, possibly many in other regimes) is a + property of the user record, not a containment. Operations on + the user registry have `resource = {}`; the workspace + association appears as a *parameter*, not as a resource address + component. +- **Workspaces themselves are a system-level resource.** The + workspace registry exists at the deployment level. `create- + workspace` and `list-workspaces` are system-level operations; + the workspace identifier in their bodies is a parameter, not an + address. +- **A workspace's contents are workspace-level resources.** A + workspace's config, flows, library, etc. live within a + workspace. Their resource address is `{workspace: ...}`. +- **A flow's contents are flow-level resources.** Knowledge + graphs, agents, etc. live within a flow. Their resource + address is `{workspace: ..., flow: ...}`. + +### Component vocabulary + +| Component | Type | Meaning | Used by | +|---|---|---|---| +| `workspace` | string | Identifier of the workspace whose contents are being operated on | workspace-level and flow-level resource addresses | +| `flow` | string | Identifier of a flow within a workspace; always paired with `workspace` | flow-level resource addresses | +| `collection` | string | Reserved for finer-grained scoping within a workspace | future / enterprise | +| `document` | string | Reserved for per-document scoping | future / enterprise | + +A `Resource` is a partial mapping of these components to values. +The level of the resource (system / workspace / flow) determines +which components must be present. An empty `{}` is the +system-level resource. + +### Workspace as parameter vs. address + +Workspace plays two distinct roles in operations and shows up in +two distinct places: + +- **As a resource address component** — workspace identifies the + thing being operated on. Lives in `resource.workspace`. Example: + `config:read` reads the config *of* workspace W. +- **As an operation parameter** — workspace is data the operation + acts on or filters by, while the resource itself is system-level. + Lives in `parameters.workspace`. Example: `users:write` + creates a user *with workspace association* W; the resource is + the user registry (system), and W is a parameter. + +These are not interchangeable. The IAM regime considers each role +separately; the OSS role table, for instance, applies workspace- +scope to the address component when checking workspace-level +operations, and to a parameter when checking +"create-user-with-workspace-W". Both end up enforcing the admin's +scope, but through different code paths. + +### Extension rules + +The vocabulary is closed but extensible. Adding a new component: + +1. The component is added to the vocabulary in this spec, with a + defined name, type, and meaning. +2. Existing IAM regimes ignore unknown components (forward + compatibility — adding a new component does not break older + regimes that don't understand it). +3. Older gateways that don't populate a new component leave it + unset; regimes that need it for a decision treat "unset" as + "absent" and decide accordingly (typically: cannot grant + permission scoped to a component the gateway didn't supply). + +A regime that wants stricter behaviour (e.g. fail-closed on +unknown components rather than ignoring them) declares so as part +of its own configuration; the contract default is "ignore unknown". + +## Operation registry (gateway-side) + +Mapping a request onto `(capability, resource, parameters)` is +service-specific — it cannot be inferred from the capability +alone. The gateway maintains an **operation registry** that +declares, per operation: + +- The required capability. +- The resource level (system / workspace / flow) — determines the + shape of the resource identifier. +- How to extract the resource address components (workspace, + flow) from the request — from URL path, WebSocket envelope, or + body. +- Which body fields are operation parameters (and which of those + the IAM regime should see in the `parameters` argument). + +This registry is part of the gateway's endpoint declarations, not +part of the IAM contract. The contract specifies what arguments +`authorise` receives; how the gateway populates them is its own +concern. + +In the OSS gateway, registry keys follow these conventions: + +| Pattern | Used by | Resource level | +|---|---|---| +| bare op name (`create-user`, `list-users`, `login`, …) | `/api/v1/iam` and the auth surface | system / workspace, per op | +| `:` (`config:get`, `flow:list-blueprints`, `librarian:add-document`, …) | `/api/v1/{kind}` (workspace-scoped global services) | workspace | +| `flow-service:` (`flow-service:agent`, `flow-service:graph-rag`, …) | `/api/v1/flow/{flow}/service/{kind}` and the WS Mux | flow | +| `flow-import:` / `flow-export:` | `/api/v1/flow/{flow}/{import,export}/{kind}` streaming sockets | flow | + +Keys are an OSS-gateway implementation detail — the contract does +not constrain naming. The conventions above exist so the registry +key is uniquely derivable from the request path and (where +applicable) body without ambiguity. + +## Caching + +Both `authenticate` and `authorise` results are cached at the +gateway, on different policies: + +- **`authenticate`** — cached by a hash of the credential. The OSS + gateway uses a fixed short TTL (currently 60 s) so that revoked + API keys and disabled users stop working within the TTL window + without any push mechanism. Regimes that want a different + behaviour can return an `expires` hint with the identity; the + gateway honours the smaller of `expires` and its own ceiling. + +- **`authorise`** — cached by a hash of `(handle, capability, + resource, parameters)`. The regime returns a suggested TTL with + the decision; the gateway clamps it above by a deployment-set + ceiling (currently 60 s). Both allow and deny decisions are + cached; denies briefly, to avoid hammering the regime with + repeated rejected attempts. + +The TTL ceiling caps the revocation latency window — a role +revoked at the regime takes effect at the gateway no later than +the ceiling. Operators that need stricter revocation can lower +the ceiling. + +## Failure modes + +| Condition | Behaviour | +|---|---| +| `authenticate` returns AuthFailure | Gateway responds 401 with the masked `auth failure` body. | +| `authorise` returns deny | Gateway responds 403 with the masked `access denied` body. | +| IAM regime unreachable | Gateway responds 401 / 503 (deployment-defined). No fail-open. | +| `authorise_many` partial deny | Gateway treats the request as denied; the operation is rejected. Partial-success semantics are not part of the contract. | +| Regime returns "not supported" for a management operation | Gateway responds 501. | + +There is no fallback or "soft" decision path. An IAM regime that +is unavailable, slow, or returning errors causes requests to fail +closed. + +## Implementations + +### Open-source role-based regime + +Defined in [`iam-protocol.md`](iam-protocol.md). Implements the +contract via: + +- A pub/sub request/response service (`iam-svc`) reached only by + the gateway over the message bus. +- Credentials are API keys (opaque) or JWTs (Ed25519, locally + validated by the gateway against the regime's published public + key). +- `authorise` reduces to a role-and-workspace-scope check against + the role table defined in [`capabilities.md`](capabilities.md). +- Identity, user, and workspace records live in Cassandra. + +The OSS regime is deliberately simple — three roles, single +home-workspace per user (a regime data-model decision, not a +contract assertion), no policy language. + +### Future regimes + +The contract is shaped to admit, without code change in the +gateway: + +- **OIDC / SSO** — `authenticate` validates an OIDC ID token via + the IdP's JWKS; `Identity.handle` carries the verified subject + and group claims; `authorise` evaluates against group-to- + capability mappings configured at the regime. +- **ABAC / Policy engine** — `authorise` calls out to a policy + engine (Rego, Cedar, custom DSL) with the identity's attributes + and the resource as the policy input. +- **ReBAC (Zanzibar-style)** — `authorise` translates `(identity, + capability, resource)` into a relationship-tuple lookup against + a tuple store. +- **Hybrid** — multiple regimes composed: e.g. authenticate via + SSO, authorise via local policy. + +None of these require gateway changes. The contract surface is +the same; the regime is what differs. + +## References + +- [Identity and Access Management Specification](iam.md) — overall + design and the gateway-side framing. +- [IAM Service Protocol Specification](iam-protocol.md) — the OSS + regime's wire-level protocol. +- [Capability Vocabulary Specification](capabilities.md) — the + capability strings the gateway uses as `authorise` input. diff --git a/docs/tech-specs/iam-protocol.md b/docs/tech-specs/iam-protocol.md index 8049ebfe..603d1c06 100644 --- a/docs/tech-specs/iam-protocol.md +++ b/docs/tech-specs/iam-protocol.md @@ -8,21 +8,41 @@ parent: "Tech Specs" ## Overview -The IAM service is a backend processor, reached over the standard -request/response pub/sub pattern. It is the authority for users, -workspaces, API keys, and login credentials. The API gateway -delegates to it for authentication resolution and for all user / -workspace / key management. +This document specifies the wire protocol of the **open-source IAM +regime** — one implementation of the abstract IAM contract defined +in [`iam-contract.md`](iam-contract.md). Other regimes (OIDC / SSO, +ABAC, ReBAC, external policy engines) implement the same contract +with different transports, data models, and policy semantics; the +gateway is unaware of which regime it's wired against. -This document defines the wire protocol: the `IamRequest` and -`IamResponse` dataclasses, the operation set, the per-operation -input and output fields, the error taxonomy, and the initial HTTP -forwarding endpoint used while IAM is being integrated into the -gateway. +The OSS regime is a backend processor (`iam-svc`) reached over the +standard request/response pub/sub pattern. It owns users, +workspaces, API keys, login credentials, and JWT signing keys, all +backed by Cassandra. The API gateway is its only caller. -Architectural context — roles, capabilities, workspace scoping, -enforcement boundary — lives in [`iam.md`](iam.md) and -[`capabilities.md`](capabilities.md). +This document defines: + +- the `IamRequest` and `IamResponse` dataclasses on the bus, +- the operation set the OSS regime implements, +- per-operation input and output fields, +- the error taxonomy, +- the bootstrap modes, +- the initial HTTP forwarding endpoint used while the protocol is + being exercised. + +The mapping from this regime onto the abstract contract is direct: + +| Contract operation | OSS regime operation | +|---|---| +| `authenticate(credential)` | `resolve-api-key` (for API keys); local JWT validation against `get-signing-key-public` (for JWTs) | +| `authorise(identity, capability, resource, parameters)` | Role-table lookup against the OSS role bundles defined in [`capabilities.md`](capabilities.md), gated by workspace scope. Workspace can come from the resource address (workspace- and flow-level resources) or from a parameter (system-level resources whose parameters reference a workspace, e.g. `create-user with workspace association W`). | +| `authorise_many` | Loop over `authorise` | +| Identity / credential / workspace management | `create-user`, `create-api-key`, etc. as listed below. These are operations on system-level resources (the user / workspace / credential registries); workspace, where it appears in the body, is a parameter. | + +Architectural context — roles, capabilities, workspace as resource +scope, enforcement boundary — lives in [`iam.md`](iam.md) and +[`capabilities.md`](capabilities.md). The contract abstraction +lives in [`iam-contract.md`](iam-contract.md). ## Transport @@ -345,5 +365,7 @@ lands in the subsequent middleware work. ## References +- [IAM Contract Specification](iam-contract.md) — the abstract + gateway↔IAM regime contract this protocol implements. - [Identity and Access Management Specification](iam.md) - [Capability Vocabulary Specification](capabilities.md) diff --git a/docs/tech-specs/iam.md b/docs/tech-specs/iam.md index 50b64444..a764535e 100644 --- a/docs/tech-specs/iam.md +++ b/docs/tech-specs/iam.md @@ -199,9 +199,9 @@ The server rejects all non-auth messages until authentication succeeds. The socket remains open on auth failure, allowing the client to retry with a different token without reconnecting. The client can also send a new auth message at any time to re-authenticate — for example, to -refresh an expiring JWT or to switch workspace. The -resolved identity (user, workspace, roles) is updated on each -successful auth. +refresh an expiring JWT or to switch workspace. The resolved +identity (handle, workspace, principal_id, source) is updated on +each successful auth. #### API keys @@ -219,7 +219,7 @@ For programmatic access: CLI tools, scripts, and integrations. On each request, the gateway resolves an API key by: 1. Hashing the token. -2. Checking a local cache (hash → user/workspace/roles). +2. Checking a local cache (hash → identity). 3. On cache miss, calling the IAM service to resolve. 4. Caching the result with a short TTL (e.g. 60 seconds). @@ -233,9 +233,15 @@ For interactive access via the UI or WebSocket connections. - A user logs in with username and password. The gateway forwards the request to the IAM service, which validates the credentials and returns a signed JWT. -- The JWT carries the user ID, workspace, and roles as claims. +- The JWT carries identity-binding claims only — user id (`sub`) + and the workspace this credential authenticates to. No roles, + no policy state. Per the IAM contract, all policy decisions go + through `authorise`; the gateway never reads roles or other + regime-internal state from the credential. - The gateway validates JWTs locally using the IAM service's public - signing key — no service call needed on subsequent requests. + signing key — no service call needed for the authentication step; + authorisation calls remain per-request (cached per the contract's + caching rules). - Token expiry is enforced by standard JWT validation at the time the request (or WebSocket connection) is made. - For long-lived WebSocket connections, the JWT is validated at connect @@ -285,35 +291,82 @@ authentication uses API keys or JWTs. On first start, the bootstrap process creates a default workspace and admin user with an initial API key. -### User identity +### Identity, credentials, and workspace binding -A user belongs to exactly one workspace. The design supports extending -this to multi-workspace access in the future (see -[Extension points](#extension-points)). +The gateway never asks "which workspace does *this user* belong to?". +That question forces every IAM regime to expose a user-to-workspace +mapping, which prevents regimes where the relationship is many-to-many +or doesn't exist (e.g. SSO with IdP-driven workspace selection). +Instead, the gateway asks "which workspace does *this credential* +authenticate to?" — a question every regime can answer in its own +terms. -A user record contains: +A credential (API key, JWT, OIDC token, etc.) is **bound to a +workspace at issue time**. The IAM regime decides what binding +means: + +- **OSS regime** — each user has a home workspace; credentials + issued to that user are bound to that workspace. A 1:1 + user-to-workspace constraint is an internal data-model decision, + not a contract assertion. +- **Multi-workspace regime** (future / enterprise) — a user with + access to several workspaces gets a different credential per + workspace. Each credential authenticates to exactly one + workspace; the relationship between user and workspace is a + regime-internal detail the gateway does not see. + +When the gateway authenticates a credential, the IAM regime returns +an `Identity` whose `workspace` is the workspace this credential is +for. That value — not "the user's workspace" — is what the gateway +uses for default-fill-in and as input to the IAM `authorise` call. + +#### Identity surface + +What the gateway holds after `authenticate`: + +| Field | Purpose | +|-------|---------| +| `handle` | Opaque token quoted back when calling `authorise`. Regime-defined. | +| `workspace` | The workspace this credential authenticates to. Used as the default if a request omits workspace. | +| `principal_id` | Stable identifier for audit logging (a user id, sub claim, service account id). Never used for authorisation. | +| `source` | How the credential was presented (`api-key`, `jwt`). Logged with audit events; not policy input. | + +Anything else — roles, claims, group memberships, policy attributes +— stays inside the regime and is reachable only via `authorise`. +See [`iam-contract.md`](iam-contract.md) for the full contract. + +#### OSS user record + +The OSS regime stores the following per user. These fields are +**OSS-implementation specifics**, not part of the contract. | Field | Type | Description | |-------|------|-------------| | `id` | string | Unique user identifier (UUID) | | `name` | string | Display name | | `email` | string | Email address (optional) | -| `workspace` | string | Workspace the user belongs to | +| `workspace` | string | Home workspace; default binding for issued credentials | | `roles` | list[string] | Assigned roles (e.g. `["reader"]`) | | `enabled` | bool | Whether the user can authenticate | | `created` | datetime | Account creation timestamp | -The `workspace` field maps to the existing `user` field in `Metadata`. -This means the storage-layer isolation (Cassandra, Neo4j, Qdrant -filtering by `user` + `collection`) works without changes — the gateway -sets the `user` metadata field to the authenticated user's workspace. +The `workspace` field on a user record is the **default binding** +used when issuing credentials, not a constraint visible to the +gateway. An enterprise regime may have no user records at all +(authentication delegated to an IdP). ### Workspaces -A workspace is an isolated data boundary. Users belong to a workspace, -and all data operations are scoped to it. Workspaces map to the existing -`user` field in `Metadata` and the corresponding Cassandra keyspace, -Qdrant collection prefix, and Neo4j property filters. +A workspace is an isolated data boundary — a tenancy scope in which +users, flows, configuration, documents, and knowledge graphs live. +Workspaces map to storage-layer isolation: the `user` field in +`Metadata`, the corresponding Cassandra keyspace, the Qdrant +collection prefix, the Neo4j property filter. + +Workspace is the most prominent component of an operation's +**resource scope**: when a request says "do X to Y", workspace is +part of "Y". Listing users, creating flows, querying the graph — +all of these target a specific workspace. | Field | Type | Description | |-------|------|-------------| @@ -322,57 +375,164 @@ Qdrant collection prefix, and Neo4j property filters. | `enabled` | bool | Whether the workspace is active | | `created` | datetime | Creation timestamp | -All data operations are scoped to a workspace. The gateway determines -the effective workspace for each request as follows: +#### Default-fill-in -1. If the request includes a `workspace` parameter, validate it against - the user's assigned workspace. - - If it matches, use it. - - If it does not match, return 403. (This could be extended to - check a workspace access grant list.) -2. If no `workspace` parameter is provided, use the user's assigned - workspace. +If a request omits workspace, the gateway fills it in from the +authenticated identity's bound workspace (`identity.workspace`) +before any IAM check runs. IAM never receives an unresolved +workspace; every `authorise` call sees a concrete value. -The gateway sets the `user` field in `Metadata` to the effective -workspace ID, replacing the caller-supplied `?user=` query parameter. +#### Authorisation -This design ensures forward compatibility. Clients that pass a -workspace parameter will work unchanged if multi-workspace support is -added later. Requests for an unassigned workspace get a clear 403 -rather than silent misbehaviour. +Whether the resolved workspace is permitted to be operated on by +this caller is an **IAM decision**, not a gateway one. The gateway +calls `authorise(identity, capability, {workspace: ..., ...})` and +relays the answer. In the OSS regime, the answer comes from the +caller's role × workspace-scope — see [`capabilities.md`](capabilities.md). +In other regimes it could come from group mappings, policies, +relationship tuples, or anything else the regime models. + +### Request anatomy + +The shape of a request — where workspace appears, where flow +appears, where parameters live — follows from **the level of the +resource being operated on**, not from any single property of the +request like its URL or its required capability. + +Resources live at one of three levels (see also the resource model +in [`iam-contract.md`](iam-contract.md)): + +| Resource level | Examples | Resource address | +|---|---|---| +| **System** | The user registry, the workspace registry, the IAM signing key, the audit log | empty `{}` | +| **Workspace** | A workspace's config, flow definitions, library, knowledge cores, collections | `{workspace: ...}` | +| **Flow** | A flow's knowledge graph, agent state, LLM context, embeddings, MCP context | `{workspace: ..., flow: ...}` | + +For the gateway-to-bus mapping this dictates **where workspace +lives in the message**, but only when workspace is part of the +*resource address*. Workspace can also appear as an *operation +parameter* on system-level resources (see below). + +#### Workspace as address vs. parameter + +Two distinct roles, two distinct locations: + +- **Workspace as address component.** Workspace identifies the + thing being operated on. Used for workspace-level and flow-level + resources. Lives in the addressing layer of the message — the + URL path for HTTP, or the WebSocket envelope alongside `flow` for + flow-scoped operations sent through the Mux. +- **Workspace as operation parameter.** Workspace is data the + operation acts on, while the resource itself is system-level. + Used for operations on the user registry (`create-user with + workspace association W`), the workspace registry (`create- + workspace W`), and other system-level operations that happen to + reference a workspace. Lives in the request body or inner WS + payload alongside the operation's other parameters. + +The two roles never coexist on the same operation. Either the +operation addresses something within a workspace (workspace is in +the address) or it operates on a system-level resource with +workspace as a parameter (workspace is in the body) or workspace +is irrelevant (system-level operations like `bootstrap`, +`rotate-signing-key`, `login` itself). + +#### Where workspace lives, by request type + +| Request type | Resource level | Workspace lives in | +|---|---|---| +| Flow-scoped data plane (`agent`, `graph-rag`, `llm`, `embeddings`, `mcp`, etc.) | Flow | Envelope alongside `flow` (WS) or URL path (HTTP) — part of the address | +| Workspace-scoped control plane (`config`, `library`, `knowledge`, `collection-management`, flow lifecycle) | Workspace | Body / inner request — part of the address | +| User registry ops (`create-user`, `list-users`, `disable-user`, etc.) | System | Body — as a *parameter* (the user's workspace association or a list filter) | +| Workspace registry ops (`create-workspace`, `list-workspaces`, etc.) | System | Body — as a *parameter* (the workspace identifier in `workspace_record`) | +| Credential ops (`create-api-key`, `revoke-api-key`, `change-password`, `reset-password`) | System | Body — as a *parameter* on ops that have one; absent on `change-password` (target is the caller's identity) | +| System ops (`bootstrap`, `login`, `rotate-signing-key`, `get-signing-key-public`) | System | Not present at all | + +The classification is deliberate. Users are a global concept that +*have* a workspace; they don't *live* in one. An OSS regime has +1:1 user-to-workspace; a multi-workspace regime maps a user to many +workspaces; an SSO regime might delegate workspace membership to an +IdP entirely. The gateway treats user-registry operations as +system-level so the contract is the same across regimes — the +workspace association is a parameter the regime interprets in its +own terms. + +#### HTTP + +HTTP routes by URL path, so the address lives in the URL. +Per-operation REST shape: + +- Flow-level: `POST /api/v1/workspaces/{w}/flows/{f}/services/{kind}` + — `workspace` and `flow` are URL components. +- Workspace-level: `POST /api/v1/workspaces/{w}/config`, + `/api/v1/workspaces/{w}/library`, etc. — `workspace` is a URL + component. +- System-level: `POST /api/v1/users`, `/api/v1/workspaces`, etc. — + no workspace in URL; if the operation references one, it's a + field in the body. + +`/api/v1/iam` is itself registry-driven: the body's `operation` +field is looked up against the registry to obtain the capability, +resource shape, and parameter shape per operation, rather than +gating the whole endpoint with a single coarse capability. + +#### WebSocket Mux + +The Mux envelope is the addressing layer for flow-scoped +operations. For workspace-level and system-level operations the +envelope routes by `service` only, and the inner request payload +carries the address components or parameters as appropriate. See +[`iam-contract.md`](iam-contract.md) for the operation-registry +mechanism the Mux uses to know which fields to read. ### Roles and access control -Three roles with fixed permissions: +Roles are an OSS-regime concept and live entirely in the IAM +service. The gateway does not enumerate or check them — it asks +`authorise(identity, capability, resource, parameters)` per +request and the regime maps the caller's roles to a decision. -| Role | Data operations | Admin operations | System | -|------|----------------|-----------------|--------| -| `reader` | Query knowledge graph, embeddings, RAG | None | None | -| `writer` | All reader operations + load documents, manage collections | None | None | -| `admin` | All writer operations | Config, flows, collection management, user management | Metrics | +The OSS regime ships three roles: -Role checks happen at the gateway before dispatching to backend -services. Each endpoint declares the minimum role required: +| Role | Capabilities granted | +|------|----------------------| +| `reader` | Read capabilities on data and config (`graph:read`, `documents:read`, `rows:read`, `config:read`, `flows:read`, `knowledge:read`, `collections:read`, `keys:self`, plus the per-service caps `agent`, `llm`, `embeddings`, `mcp`). | +| `writer` | All reader capabilities, plus `graph:write`, `documents:write`, `rows:write`, `knowledge:write`, `collections:write`. | +| `admin` | All writer capabilities, plus `config:write`, `flows:write`, `users:read`, `users:write`, `users:admin`, `keys:admin`, `workspaces:admin`, `iam:admin`, `metrics:read`. | -| Endpoint pattern | Minimum role | -|-----------------|--------------| -| `GET /api/v1/socket` (queries) | `reader` | -| `POST /api/v1/librarian` | `writer` | -| `POST /api/v1/flow/*/import/*` | `writer` | -| `POST /api/v1/config` | `admin` | -| `GET /api/v1/flow/*` | `admin` | -| `GET /api/metrics` | `admin` | +Workspace scope: `reader` and `writer` are active only in the +caller's bound workspace; `admin` is active across all workspaces. -Roles are hierarchical: `admin` implies `writer`, which implies -`reader`. +The gateway gates each endpoint by *capability*, not by role. +Capabilities are declared per operation in the gateway's operation +registry; see [`iam-contract.md`](iam-contract.md) for the +registry mechanism and [`capabilities.md`](capabilities.md) for +the capability vocabulary. ### IAM service -The IAM service is a new backend service that manages all identity and -access data. It is the authority for users, workspaces, API keys, and -credentials. The gateway delegates to it. +The IAM service is a backend service that implements the +[IAM contract](iam-contract.md) — `authenticate`, `authorise`, and +the management operations the gateway forwards. It is the +authority for identity, credential validation, and access decisions. +The gateway treats it as a black box behind the contract; nothing +in the gateway is regime-specific. -#### Data model +The OSS distribution ships one IAM regime: a role-based service +backed by Cassandra, described in +[`iam-protocol.md`](iam-protocol.md). Enterprise / future regimes +can replace this implementation without changing the gateway, the +wire protocol between gateway and backends, or the capability +vocabulary — see the contract spec for the abstraction the gateway +is wired against and the implementation notes for what other +regimes look like. + +#### OSS data model + +The OSS regime stores users, workspaces, API keys, and signing +keys in Cassandra. This is an **OSS regime implementation +detail**; it is not part of the contract. Other regimes will have +different (or no) data models. ``` iam_workspaces ( @@ -456,42 +616,53 @@ surface — e.g. `"missing required field 'workspace'"` or ### Gateway changes -The current `Authenticator` class is replaced with a thin authentication -middleware that delegates to the IAM service: +The current `Authenticator` class is replaced with a thin +authentication+authorisation middleware that delegates to the IAM +service per the IAM contract. The gateway performs no role check +itself — authorisation is asked of the regime via `authorise`. For HTTP requests: 1. Extract Bearer token from the `Authorization` header. 2. If the token has JWT format (dotted structure): - Validate signature locally using the cached public key. - - Extract user ID, workspace, and roles from claims. + - Build an `Identity` from `sub` and `workspace` claims (no + other claims are consulted). 3. Otherwise, treat as an API key: - Hash the token and check the local cache. - - On cache miss, call the IAM service to resolve. - - Cache the result (user/workspace/roles) with a short TTL. + - On cache miss, call the IAM service to resolve to an + `Identity` (handle, workspace, principal_id, source). + - Cache the result with a short TTL. 4. If neither succeeds, return 401. -5. If the user or workspace is disabled, return 403. -6. Check the user's role against the endpoint's minimum role. If - insufficient, return 403. -7. Resolve the effective workspace: - - If the request includes a `workspace` parameter, validate it - against the user's assigned workspace. Return 403 on mismatch. - - If no `workspace` parameter, use the user's assigned workspace. -8. Set the `user` field in the request context to the effective - workspace ID. This propagates through `Metadata` to all downstream - services. +5. Look up the operation in the gateway's operation registry to get + `(capability, resource_level, extractors)`. Build the resource + address (system / workspace / flow level) and parameters from + the request. +6. Default-fill the workspace into the body when the operation is + workspace- or flow-level (so downstream code sees a single + canonical address); the resource address keeps its supplied + value. +7. Call `authorise(identity, capability, resource, parameters)`. + On allow, forward the request; on deny, return 403. On regime + error, fail closed (401 / 503 per deployment). +8. Cache the decision per the contract's caching rules (clamped + above by a deployment-set ceiling). For WebSocket connections: 1. Accept the connection in an unauthenticated state. 2. Wait for an auth message (`{"type": "auth", "token": "..."}`). -3. Validate the token using the same logic as steps 2-7 above. +3. Validate the token using the same logic as steps 1-3 above. 4. On success, attach the resolved identity to the connection and send `{"type": "auth-ok", ...}`. 5. On failure, send `{"type": "auth-failed", ...}` but keep the socket open. 6. Reject all non-auth messages until authentication succeeds. 7. Accept new auth messages at any time to re-authenticate. +8. For each subsequent request frame, look up + `flow-service:` in the registry and call `authorise` + against the `{workspace, flow}` resource — same authority + gateway HTTP callers see, evaluated per-frame. ### CLI changes @@ -892,6 +1063,12 @@ service, not in the config service. Reasons: ## References +- [IAM Contract Specification](iam-contract.md) — the gateway↔IAM + regime abstraction this design is wired against. +- [IAM Service Protocol Specification](iam-protocol.md) — the OSS + regime's wire-level protocol. +- [Capability Vocabulary Specification](capabilities.md) — the + capability strings the gateway uses as `authorise` input. - [Data Ownership and Information Separation](data-ownership-model.md) - [MCP Tool Bearer Token Specification](mcp-tool-bearer-token.md) - [Multi-Tenant Support Specification](multi-tenant-support.md) diff --git a/tests/unit/test_gateway/test_auth.py b/tests/unit/test_gateway/test_auth.py index ba2b9bc2..26e93fd9 100644 --- a/tests/unit/test_gateway/test_auth.py +++ b/tests/unit/test_gateway/test_auth.py @@ -87,7 +87,6 @@ class TestVerifyJwtEddsa: priv, pub = make_keypair() claims = { "sub": "user-1", "workspace": "default", - "roles": ["reader"], "iat": int(time.time()), "exp": int(time.time()) + 60, } @@ -99,7 +98,7 @@ class TestVerifyJwtEddsa: def test_expired_jwt_rejected(self): priv, pub = make_keypair() claims = { - "sub": "user-1", "workspace": "default", "roles": [], + "sub": "user-1", "workspace": "default", "iat": int(time.time()) - 3600, "exp": int(time.time()) - 1, } @@ -111,7 +110,7 @@ class TestVerifyJwtEddsa: priv_a, _ = make_keypair() _, pub_b = make_keypair() claims = { - "sub": "user-1", "workspace": "default", "roles": [], + "sub": "user-1", "workspace": "default", "iat": int(time.time()), "exp": int(time.time()) + 60, } @@ -131,7 +130,7 @@ class TestVerifyJwtEddsa: # since we expect it to bail before verifying. header = {"alg": "HS256", "typ": "JWT", "kid": "x"} payload = { - "sub": "user-1", "workspace": "default", "roles": [], + "sub": "user-1", "workspace": "default", "iat": int(time.time()), "exp": int(time.time()) + 60, } h = _b64url(json.dumps(header, separators=(",", ":")).encode()) @@ -149,11 +148,12 @@ class TestIdentity: def test_fields(self): i = Identity( - user_id="u", workspace="w", roles=["reader"], source="api-key", + handle="u", workspace="w", + principal_id="u", source="api-key", ) - assert i.user_id == "u" + assert i.handle == "u" assert i.workspace == "w" - assert i.roles == ["reader"] + assert i.principal_id == "u" assert i.source == "api-key" @@ -194,7 +194,6 @@ class TestIamAuthDispatch: priv, pub = make_keypair() claims = { "sub": "user-1", "workspace": "default", - "roles": ["writer"], "iat": int(time.time()), "exp": int(time.time()) + 60, } @@ -206,9 +205,9 @@ class TestIamAuthDispatch: ident = await auth.authenticate( make_request(f"Bearer {token}") ) - assert ident.user_id == "user-1" + assert ident.handle == "user-1" assert ident.workspace == "default" - assert ident.roles == ["writer"] + assert ident.principal_id == "user-1" assert ident.source == "jwt" @pytest.mark.asyncio @@ -217,7 +216,7 @@ class TestIamAuthDispatch: # must not validate — even ones that would otherwise pass. priv, _ = make_keypair() claims = { - "sub": "user-1", "workspace": "default", "roles": [], + "sub": "user-1", "workspace": "default", "iat": int(time.time()), "exp": int(time.time()) + 60, } token = sign_jwt(priv, claims) @@ -232,6 +231,9 @@ class TestIamAuthDispatch: async def fake_resolve(api_key): assert api_key == "tg_testkey" + # Roles are returned by the regime as a hint but the + # gateway ignores them — kept here so the resolve + # protocol shape is exercised. return ("user-xyz", "default", ["admin"]) async def fake_with_client(op): @@ -241,9 +243,9 @@ class TestIamAuthDispatch: ident = await auth.authenticate( make_request("Bearer tg_testkey") ) - assert ident.user_id == "user-xyz" + assert ident.handle == "user-xyz" assert ident.workspace == "default" - assert ident.roles == ["admin"] + assert ident.principal_id == "user-xyz" assert ident.source == "api-key" @pytest.mark.asyncio @@ -301,8 +303,8 @@ class TestApiKeyCache: a = await auth.authenticate(make_request("Bearer tg_a")) b = await auth.authenticate(make_request("Bearer tg_b")) - assert a.user_id == "u-tg_a" - assert b.user_id == "u-tg_b" + assert a.handle == "u-tg_a" + assert b.handle == "u-tg_b" assert seen == ["tg_a", "tg_b"] @pytest.mark.asyncio @@ -310,3 +312,136 @@ class TestApiKeyCache: # Not a behaviour test — just ensures we don't accidentally # set TTL to 0 (which would defeat the cache) or to a week. assert 10 <= API_KEY_CACHE_TTL <= 3600 + + +# -- IamAuth.authorise ----------------------------------------------------- + + +class TestAuthorise: + """``authorise()`` is the gateway's only authorisation entry + point under the IAM contract. It calls iam-svc, caches the + decision for the regime's TTL (clamped above), and raises 403 + on deny / 401 on regime error (fail closed).""" + + def _make_identity(self, handle="u-1", workspace="default"): + return Identity( + handle=handle, workspace=workspace, + principal_id=handle, source="api-key", + ) + + @pytest.mark.asyncio + async def test_allow_returns_no_exception(self): + auth = IamAuth(backend=Mock()) + + async def fake_with_client(op): + return await op(Mock( + authorise=AsyncMock(return_value=(True, 30)), + )) + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + await auth.authorise( + self._make_identity(), + "graph:read", + {"workspace": "default"}, + {}, + ) + + @pytest.mark.asyncio + async def test_deny_raises_403(self): + auth = IamAuth(backend=Mock()) + + async def fake_with_client(op): + return await op(Mock( + authorise=AsyncMock(return_value=(False, 30)), + )) + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + with pytest.raises(web.HTTPForbidden): + await auth.authorise( + self._make_identity(), + "users:admin", + {}, + {"workspace": "acme"}, + ) + + @pytest.mark.asyncio + async def test_regime_error_fails_closed_as_401(self): + # If iam-svc errors, the gateway must NOT silently allow. + auth = IamAuth(backend=Mock()) + + async def fake_with_client(op): + raise RuntimeError("iam-svc down") + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + with pytest.raises(web.HTTPUnauthorized): + await auth.authorise( + self._make_identity(), + "graph:read", + {"workspace": "default"}, + {}, + ) + + @pytest.mark.asyncio + async def test_allow_decision_is_cached(self): + auth = IamAuth(backend=Mock()) + calls = {"n": 0} + + async def fake_with_client(op): + calls["n"] += 1 + return await op(Mock( + authorise=AsyncMock(return_value=(True, 30)), + )) + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + ident = self._make_identity() + for _ in range(5): + await auth.authorise( + ident, "graph:read", {"workspace": "default"}, {}, + ) + + assert calls["n"] == 1 + + @pytest.mark.asyncio + async def test_deny_decision_is_cached(self): + auth = IamAuth(backend=Mock()) + calls = {"n": 0} + + async def fake_with_client(op): + calls["n"] += 1 + return await op(Mock( + authorise=AsyncMock(return_value=(False, 30)), + )) + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + ident = self._make_identity() + for _ in range(5): + with pytest.raises(web.HTTPForbidden): + await auth.authorise( + ident, "users:admin", {}, {"workspace": "acme"}, + ) + + # Denies are cached too — repeated attempts don't re-hit IAM. + assert calls["n"] == 1 + + @pytest.mark.asyncio + async def test_different_resources_cached_separately(self): + auth = IamAuth(backend=Mock()) + calls = {"n": 0} + + async def fake_with_client(op): + calls["n"] += 1 + return await op(Mock( + authorise=AsyncMock(return_value=(True, 30)), + )) + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + ident = self._make_identity() + await auth.authorise( + ident, "graph:read", {"workspace": "a"}, {}, + ) + await auth.authorise( + ident, "graph:read", {"workspace": "b"}, {}, + ) + + # Different resource → different cache key → two IAM calls. + assert calls["n"] == 2 diff --git a/tests/unit/test_gateway/test_capabilities.py b/tests/unit/test_gateway/test_capabilities.py index 063e9ea4..102e381e 100644 --- a/tests/unit/test_gateway/test_capabilities.py +++ b/tests/unit/test_gateway/test_capabilities.py @@ -1,15 +1,22 @@ """ -Tests for gateway/capabilities.py — the capability + role + workspace -model that underpins all gateway authorisation. +Tests for gateway/capabilities.py — the thin authorisation surface +under the IAM contract. + +The gateway no longer holds policy state (roles, capability sets, +workspace scopes); those live in iam-svc. These tests cover only +what the gateway shim does itself: PUBLIC / AUTHENTICATED short- +circuiting, default-fill of workspace, and forwarding of capability +checks to ``auth.authorise``. """ import pytest from aiohttp import web +from unittest.mock import AsyncMock, MagicMock from trustgraph.gateway.capabilities import ( PUBLIC, AUTHENTICATED, - KNOWN_CAPABILITIES, ROLE_DEFINITIONS, - check, enforce_workspace, access_denied, auth_failure, + enforce, enforce_workspace, + access_denied, auth_failure, ) @@ -17,109 +24,74 @@ from trustgraph.gateway.capabilities import ( class _Identity: - """Minimal stand-in for auth.Identity — the capability module - accesses ``.workspace`` and ``.roles``.""" - def __init__(self, workspace, roles): - self.user_id = "user-1" + """Stand-in for auth.Identity — under the IAM contract it has + just ``handle``, ``workspace``, ``principal_id``, ``source``.""" + + def __init__(self, handle="user-1", workspace="default"): + self.handle = handle self.workspace = workspace - self.roles = list(roles) + self.principal_id = handle + self.source = "api-key" -def reader_in(ws): - return _Identity(ws, ["reader"]) +def _allow_auth(identity=None): + """Build an Auth double that authenticates to ``identity`` and + allows every authorise() call.""" + auth = MagicMock() + auth.authenticate = AsyncMock( + return_value=identity or _Identity(), + ) + auth.authorise = AsyncMock(return_value=None) + return auth -def writer_in(ws): - return _Identity(ws, ["writer"]) +def _deny_auth(identity=None): + """Build an Auth double that authenticates but denies authorise.""" + auth = MagicMock() + auth.authenticate = AsyncMock( + return_value=identity or _Identity(), + ) + auth.authorise = AsyncMock(side_effect=access_denied()) + return auth -def admin_in(ws): - return _Identity(ws, ["admin"]) +# -- enforce() ------------------------------------------------------------- -# -- role table sanity ----------------------------------------------------- +class TestEnforce: + @pytest.mark.asyncio + async def test_public_returns_none_no_auth(self): + auth = _allow_auth() + result = await enforce(MagicMock(), auth, PUBLIC) + assert result is None + auth.authenticate.assert_not_called() + auth.authorise.assert_not_called() -class TestRoleTable: + @pytest.mark.asyncio + async def test_authenticated_skips_authorise(self): + identity = _Identity() + auth = _allow_auth(identity) + result = await enforce(MagicMock(), auth, AUTHENTICATED) + assert result is identity + auth.authenticate.assert_awaited_once() + auth.authorise.assert_not_called() - def test_oss_roles_present(self): - assert set(ROLE_DEFINITIONS.keys()) == {"reader", "writer", "admin"} + @pytest.mark.asyncio + async def test_capability_calls_authorise_system_level(self): + identity = _Identity() + auth = _allow_auth(identity) + result = await enforce(MagicMock(), auth, "graph:read") + assert result is identity + auth.authorise.assert_awaited_once_with( + identity, "graph:read", {}, {}, + ) - def test_admin_is_cross_workspace(self): - assert ROLE_DEFINITIONS["admin"]["workspace_scope"] == "*" - - def test_reader_writer_are_assigned_scope(self): - assert ROLE_DEFINITIONS["reader"]["workspace_scope"] == "assigned" - assert ROLE_DEFINITIONS["writer"]["workspace_scope"] == "assigned" - - def test_admin_superset_of_writer(self): - admin = ROLE_DEFINITIONS["admin"]["capabilities"] - writer = ROLE_DEFINITIONS["writer"]["capabilities"] - assert writer.issubset(admin) - - def test_writer_superset_of_reader(self): - writer = ROLE_DEFINITIONS["writer"]["capabilities"] - reader = ROLE_DEFINITIONS["reader"]["capabilities"] - assert reader.issubset(writer) - - def test_admin_has_users_admin(self): - assert "users:admin" in ROLE_DEFINITIONS["admin"]["capabilities"] - - def test_writer_does_not_have_users_admin(self): - assert "users:admin" not in ROLE_DEFINITIONS["writer"]["capabilities"] - - def test_every_bundled_capability_is_known(self): - for role in ROLE_DEFINITIONS.values(): - for cap in role["capabilities"]: - assert cap in KNOWN_CAPABILITIES - - -# -- check() --------------------------------------------------------------- - - -class TestCheck: - - def test_reader_has_reader_cap_in_own_workspace(self): - assert check(reader_in("default"), "graph:read", "default") - - def test_reader_does_not_have_writer_cap(self): - assert not check(reader_in("default"), "graph:write", "default") - - def test_reader_cannot_act_in_other_workspace(self): - assert not check(reader_in("default"), "graph:read", "acme") - - def test_writer_has_write_in_own_workspace(self): - assert check(writer_in("default"), "graph:write", "default") - - def test_writer_cannot_act_in_other_workspace(self): - assert not check(writer_in("default"), "graph:write", "acme") - - def test_admin_has_everything_everywhere(self): - for cap in ("graph:read", "graph:write", "config:write", - "users:admin", "metrics:read"): - assert check(admin_in("default"), cap, "acme"), ( - f"admin should have {cap} in acme" - ) - - def test_admin_has_caps_without_explicit_workspace(self): - assert check(admin_in("default"), "users:admin") - - def test_default_target_is_identity_workspace(self): - # Reader with no target workspace → should check against own - assert check(reader_in("default"), "graph:read") - - def test_unknown_capability_returns_false(self): - assert not check(admin_in("default"), "nonsense:cap", "default") - - def test_unknown_role_contributes_nothing(self): - ident = _Identity("default", ["made-up-role"]) - assert not check(ident, "graph:read", "default") - - def test_multi_role_union(self): - # If a user is both reader and admin, they inherit admin's - # cross-workspace powers. - ident = _Identity("default", ["reader", "admin"]) - assert check(ident, "users:admin", "acme") + @pytest.mark.asyncio + async def test_capability_denied_raises_forbidden(self): + auth = _deny_auth() + with pytest.raises(web.HTTPForbidden): + await enforce(MagicMock(), auth, "users:admin") # -- enforce_workspace() --------------------------------------------------- @@ -127,56 +99,54 @@ class TestCheck: class TestEnforceWorkspace: - def test_reader_in_own_workspace_allowed(self): - data = {"workspace": "default", "operation": "x"} - enforce_workspace(data, reader_in("default")) - assert data["workspace"] == "default" - - def test_reader_no_workspace_injects_assigned(self): + @pytest.mark.asyncio + async def test_default_fills_from_identity(self): data = {"operation": "x"} - enforce_workspace(data, reader_in("default")) + auth = _allow_auth() + await enforce_workspace(data, _Identity(workspace="default"), auth) assert data["workspace"] == "default" - def test_reader_mismatched_workspace_denied(self): + @pytest.mark.asyncio + async def test_caller_supplied_workspace_kept(self): data = {"workspace": "acme", "operation": "x"} - with pytest.raises(web.HTTPForbidden): - enforce_workspace(data, reader_in("default")) - - def test_admin_can_target_any_workspace(self): - data = {"workspace": "acme", "operation": "x"} - enforce_workspace(data, admin_in("default")) + auth = _allow_auth() + await enforce_workspace(data, _Identity(workspace="default"), auth) assert data["workspace"] == "acme" - def test_admin_no_workspace_defaults_to_assigned(self): - data = {"operation": "x"} - enforce_workspace(data, admin_in("default")) - assert data["workspace"] == "default" - - def test_writer_same_workspace_specified_allowed(self): + @pytest.mark.asyncio + async def test_no_capability_skips_authorise(self): data = {"workspace": "default"} - enforce_workspace(data, writer_in("default")) - assert data["workspace"] == "default" + auth = _allow_auth() + await enforce_workspace(data, _Identity(), auth) + auth.authorise.assert_not_called() - def test_non_dict_passthrough(self): - # Non-dict bodies are returned unchanged (e.g. streaming). - result = enforce_workspace("not-a-dict", reader_in("default")) - assert result == "not-a-dict" + @pytest.mark.asyncio + async def test_capability_calls_authorise_with_resource(self): + data = {"workspace": "acme"} + identity = _Identity() + auth = _allow_auth(identity) + await enforce_workspace( + data, identity, auth, capability="graph:read", + ) + auth.authorise.assert_awaited_once_with( + identity, "graph:read", {"workspace": "acme"}, {}, + ) - def test_with_capability_tightens_check(self): - # Reader lacks graph:write; workspace-only check would pass - # (scope is fine), but combined check must reject. - data = {"workspace": "default"} + @pytest.mark.asyncio + async def test_capability_denied_propagates(self): + data = {"workspace": "acme"} + auth = _deny_auth() with pytest.raises(web.HTTPForbidden): - enforce_workspace( - data, reader_in("default"), capability="graph:write", + await enforce_workspace( + data, _Identity(), auth, capability="users:admin", ) - def test_with_capability_passes_when_granted(self): - data = {"workspace": "default"} - enforce_workspace( - data, reader_in("default"), capability="graph:read", - ) - assert data["workspace"] == "default" + @pytest.mark.asyncio + async def test_non_dict_passthrough(self): + auth = _allow_auth() + result = await enforce_workspace("not-a-dict", _Identity(), auth) + assert result == "not-a-dict" + auth.authorise.assert_not_called() # -- helpers --------------------------------------------------------------- @@ -199,5 +169,3 @@ class TestSentinels: def test_public_and_authenticated_are_distinct(self): assert PUBLIC != AUTHENTICATED - assert PUBLIC not in KNOWN_CAPABILITIES - assert AUTHENTICATED not in KNOWN_CAPABILITIES diff --git a/tests/unit/test_gateway/test_endpoint_manager.py b/tests/unit/test_gateway/test_endpoint_manager.py index cf12565c..8f659b71 100644 --- a/tests/unit/test_gateway/test_endpoint_manager.py +++ b/tests/unit/test_gateway/test_endpoint_manager.py @@ -73,14 +73,16 @@ class TestEndpointManager: prometheus_url="http://test:9090" ) - # Each dispatcher factory is invoked exactly once during - # construction — one per endpoint that needs a dedicated - # wire. dispatch_auth_iam is the dedicated factory for the - # AuthEndpoints forwarder (login / bootstrap / - # change-password), distinct from dispatch_global_service - # (the generic /api/v1/{kind} route). + # Each dispatcher factory is invoked once per endpoint that + # needs a dedicated wire. dispatch_auth_iam is shared by + # two endpoints — AuthEndpoints (login / bootstrap / + # change-password) and IamEndpoint (registry-driven + # /api/v1/iam) — so it's expected to be called twice. + # Both forwarders pin the dispatcher to kind=iam and reuse + # the same factory; they're distinct from + # dispatch_global_service (the generic /api/v1/{kind} route). mock_dispatcher_manager.dispatch_global_service.assert_called_once() - mock_dispatcher_manager.dispatch_auth_iam.assert_called_once() + assert mock_dispatcher_manager.dispatch_auth_iam.call_count == 2 mock_dispatcher_manager.dispatch_socket.assert_called_once() mock_dispatcher_manager.dispatch_flow_service.assert_called_once() mock_dispatcher_manager.dispatch_flow_import.assert_called_once() diff --git a/tests/unit/test_gateway/test_socket_graceful_shutdown.py b/tests/unit/test_gateway/test_socket_graceful_shutdown.py index 23f22d30..6c3e323b 100644 --- a/tests/unit/test_gateway/test_socket_graceful_shutdown.py +++ b/tests/unit/test_gateway/test_socket_graceful_shutdown.py @@ -25,11 +25,11 @@ from trustgraph.gateway.auth import Identity TEST_CAP = "graph:write" -def _valid_identity(roles=("admin",)): +def _valid_identity(): return Identity( - user_id="test-user", + handle="test-user", workspace="default", - roles=list(roles), + principal_id="test-user", source="api-key", ) @@ -37,11 +37,12 @@ def _valid_identity(roles=("admin",)): @pytest.fixture def mock_auth(): """Mock IAM-backed authenticator. Successful by default — - ``authenticate`` returns a valid admin identity. Tests that - need the auth failure path override the ``authenticate`` - attribute locally.""" + ``authenticate`` returns a valid identity and ``authorise`` + allows everything. Tests that need the failure paths override + the relevant attribute locally.""" auth = MagicMock() auth.authenticate = AsyncMock(return_value=_valid_identity()) + auth.authorise = AsyncMock(return_value=None) return auth @@ -135,6 +136,7 @@ async def test_handle_normal_flow(): """Valid bearer → handshake accepted, dispatcher created.""" mock_auth = MagicMock() mock_auth.authenticate = AsyncMock(return_value=_valid_identity()) + mock_auth.authorise = AsyncMock(return_value=None) dispatcher_created = False async def mock_dispatcher_factory(ws, running, match_info): @@ -192,6 +194,7 @@ async def test_handle_exception_group_cleanup(): """Test exception group triggers dispatcher cleanup.""" mock_auth = MagicMock() mock_auth.authenticate = AsyncMock(return_value=_valid_identity()) + mock_auth.authorise = AsyncMock(return_value=None) mock_dispatcher = AsyncMock() mock_dispatcher.destroy = AsyncMock() @@ -262,6 +265,7 @@ async def test_handle_dispatcher_cleanup_timeout(): """Test dispatcher cleanup with timeout.""" mock_auth = MagicMock() mock_auth.authenticate = AsyncMock(return_value=_valid_identity()) + mock_auth.authorise = AsyncMock(return_value=None) # Mock dispatcher that takes long to destroy mock_dispatcher = AsyncMock() @@ -388,6 +392,7 @@ async def test_handle_websocket_already_closed(): """Test handling when websocket is already closed.""" mock_auth = MagicMock() mock_auth.authenticate = AsyncMock(return_value=_valid_identity()) + mock_auth.authorise = AsyncMock(return_value=None) mock_dispatcher = AsyncMock() mock_dispatcher.destroy = AsyncMock() diff --git a/trustgraph-base/trustgraph/base/iam_client.py b/trustgraph-base/trustgraph/base/iam_client.py index 5cfda7c8..f90694fc 100644 --- a/trustgraph-base/trustgraph/base/iam_client.py +++ b/trustgraph-base/trustgraph/base/iam_client.py @@ -1,4 +1,6 @@ +import json + from . request_response_spec import RequestResponse, RequestResponseSpec from .. schema import ( IamRequest, IamResponse, @@ -44,7 +46,13 @@ class IamClient(RequestResponse): Returns ``(user_id, workspace, roles)`` or raises ``RuntimeError`` with error type ``auth-failed`` if the key is - unknown / expired / revoked.""" + unknown / expired / revoked. + + Note: the ``roles`` value is a regime-internal hint and is + not used by the gateway directly under the IAM contract; + all authorisation decisions go through ``authorise()``. + Returned here only for backward compatibility with callers + that haven't migrated.""" resp = await self._request( operation="resolve-api-key", api_key=api_key, @@ -56,6 +64,40 @@ class IamClient(RequestResponse): list(resp.resolved_roles), ) + async def authorise(self, identity_handle, capability, + resource, parameters, timeout=IAM_TIMEOUT): + """Ask the IAM regime whether ``identity_handle`` may perform + ``capability`` on ``resource`` given ``parameters``. + + Implements the contract ``authorise(identity, capability, + resource, parameters) → (decision, ttl)``. Returns a tuple + ``(allow: bool, ttl_seconds: int)``. The TTL is the + regime's suggested cache lifetime for this decision; the + gateway honours it (clamped above by gateway-side policy).""" + resp = await self._request( + operation="authorise", + user_id=identity_handle, + capability=capability, + resource_json=json.dumps(resource or {}, sort_keys=True), + parameters_json=json.dumps(parameters or {}, sort_keys=True), + timeout=timeout, + ) + return resp.decision_allow, resp.decision_ttl_seconds + + async def authorise_many(self, identity_handle, checks, + timeout=IAM_TIMEOUT): + """Bulk authorise. ``checks`` is a list of dicts each + carrying ``capability``, ``resource``, and ``parameters``. + Returns a list of ``(allow, ttl)`` tuples in the same order.""" + resp = await self._request( + operation="authorise-many", + user_id=identity_handle, + authorise_checks=json.dumps(list(checks), sort_keys=True), + timeout=timeout, + ) + decisions = json.loads(resp.decisions_json or "[]") + return [(d.get("allow", False), d.get("ttl", 0)) for d in decisions] + async def create_user(self, workspace, user, actor="", timeout=IAM_TIMEOUT): """Create a user. ``user`` is a ``UserInput``.""" diff --git a/trustgraph-base/trustgraph/schema/services/iam.py b/trustgraph-base/trustgraph/schema/services/iam.py index 1e3ab1ab..4b5685a5 100644 --- a/trustgraph-base/trustgraph/schema/services/iam.py +++ b/trustgraph-base/trustgraph/schema/services/iam.py @@ -99,6 +99,21 @@ class IamRequest: workspace_record: WorkspaceInput | None = None key: ApiKeyInput | None = None + # ---- authorise / authorise-many inputs ---- + # Capability string from the vocabulary in capabilities.md. + capability: str = "" + # Resource identifier as JSON. See the IAM contract spec for + # the resource-component vocabulary. An empty dict denotes a + # system-level resource. + resource_json: str = "" + # Operation parameters as JSON. Decision-relevant fields the + # operation supplied that are not part of the resource address + # (e.g. workspace association on create-user). + parameters_json: str = "" + # For authorise-many: a JSON-serialised list of + # {"capability": str, "resource": dict, "parameters": dict}. + authorise_checks: str = "" + @dataclass class IamResponse: @@ -133,6 +148,18 @@ class IamResponse: bootstrap_admin_user_id: str = "" bootstrap_admin_api_key: str = "" + # ---- authorise / authorise-many outputs ---- + # authorise: the regime's allow / deny verdict. + decision_allow: bool = False + # Cache TTL the regime suggests, in seconds. Gateway respects + # this for both allow and deny decisions; bounded above by + # gateway-side policy (typically <= 60s). + decision_ttl_seconds: int = 0 + # authorise-many: a JSON-serialised list of {"allow": bool, + # "ttl": int} in the same order as the request's + # authorise_checks. + decisions_json: str = "" + error: Error | None = None diff --git a/trustgraph-flow/trustgraph/gateway/auth.py b/trustgraph-flow/trustgraph/gateway/auth.py index 95743261..6abcbe15 100644 --- a/trustgraph-flow/trustgraph/gateway/auth.py +++ b/trustgraph-flow/trustgraph/gateway/auth.py @@ -1,15 +1,16 @@ """ -IAM-backed authentication for the API gateway. +IAM-backed authentication and authorisation for the API gateway. -Replaces the legacy GATEWAY_SECRET shared-token Authenticator. The -gateway is now stateless with respect to credentials: it either -verifies a JWT locally using the active IAM signing public key, or -resolves an API key by hash with a short local cache backed by the -IAM service. +The gateway delegates both authentication ("who is this caller?") +and authorisation ("may they do this?") to the IAM regime via the +contract specified in docs/tech-specs/iam-contract.md. No regime- +specific policy (roles, scopes, claims) lives in the gateway. -Identity returned by authenticate() is the (user_id, workspace, -roles) triple the rest of the gateway — capability checks, workspace -resolver, audit logging — needs. +- Authentication: API keys are resolved by IAM; JWTs are validated + locally against the cached signing public key. +- Authorisation: every per-request decision is asked of IAM via + ``authorise(identity, capability, resource, parameters)``, with + results cached for the TTL the regime returns. """ import asyncio @@ -19,7 +20,7 @@ import json import logging import time import uuid -from dataclasses import dataclass +from dataclasses import dataclass, field from aiohttp import web @@ -37,12 +38,34 @@ logger = logging.getLogger("auth") API_KEY_CACHE_TTL = 60 # seconds +# Upper bound on cache TTL the gateway honours for an authorisation +# decision, regardless of what the regime suggested. Caps the +# revocation latency window. +AUTHZ_CACHE_TTL_MAX = 60 # seconds + @dataclass class Identity: - user_id: str + """The gateway-side surface of an authenticated caller. + + Per the IAM contract this is a small fixed shape; regime-internal + state (roles, claims, group memberships) is reachable only via + the regime's ``authorise`` operation. The gateway itself never + reads policy from this object. + """ + # Opaque handle, quoted back when calling ``authorise``. For + # the OSS regime this is the user record's id; the gateway + # treats it as a string with no semantic content. + handle: str + # The workspace this credential authenticates to. Used by the + # gateway as the default-fill-in for operations that omit a + # workspace. Never used as policy input. workspace: str - roles: list + # Stable identifier for audit logs. In OSS this is the same + # value as ``handle``; not assumed equal in the contract. + principal_id: str + # How the credential was presented. Non-policy; useful for + # logs / metrics only. source: str # "api-key" | "jwt" @@ -111,6 +134,13 @@ class IamAuth: self._key_cache = {} self._key_cache_lock = asyncio.Lock() + # Authorisation decision cache: hash(handle, capability, + # resource, parameters) -> (allow_bool, expires_ts). Holds + # both allows and denies — denies cached briefly to avoid + # hammering iam-svc with repeated rejected attempts. + self._authz_cache: dict[str, tuple[bool, float]] = {} + self._authz_cache_lock = asyncio.Lock() + # ------------------------------------------------------------------ # Short-lived client helper. Mirrors the pattern used by the # bootstrap framework and AsyncProcessor: a fresh uuid suffix per @@ -221,12 +251,13 @@ class IamAuth: sub = claims.get("sub", "") ws = claims.get("workspace", "") - roles = list(claims.get("roles", [])) if not sub or not ws: raise _auth_failure() + # JWT carries no policy state under the IAM contract; + # any roles / claims field is ignored here. return Identity( - user_id=sub, workspace=ws, roles=roles, source="jwt", + handle=sub, workspace=ws, principal_id=sub, source="jwt", ) async def _resolve_api_key(self, plaintext): @@ -245,7 +276,10 @@ class IamAuth: try: async def _call(client): return await client.resolve_api_key(plaintext) - user_id, workspace, roles = await self._with_client(_call) + # ``roles`` is returned by the OSS regime as a hint + # but is not consulted by the gateway; all policy + # decisions go through ``authorise``. + user_id, workspace, _roles = await self._with_client(_call) except Exception as e: logger.debug( f"API key resolution failed: " @@ -257,8 +291,81 @@ class IamAuth: raise _auth_failure() identity = Identity( - user_id=user_id, workspace=workspace, - roles=list(roles), source="api-key", + handle=user_id, workspace=workspace, + principal_id=user_id, source="api-key", ) self._key_cache[h] = (identity, now + API_KEY_CACHE_TTL) return identity + + # ------------------------------------------------------------------ + # Authorisation + # ------------------------------------------------------------------ + + @staticmethod + def _authz_cache_key(handle, capability, resource, parameters): + payload = json.dumps( + { + "h": handle, + "c": capability, + "r": resource or {}, + "p": parameters or {}, + }, + sort_keys=True, + separators=(",", ":"), + ) + return hashlib.sha256(payload.encode("utf-8")).hexdigest() + + async def authorise(self, identity, capability, resource, parameters): + """Ask the IAM regime whether ``identity`` may perform + ``capability`` on ``resource`` given ``parameters``. + + Caches the decision for the regime's suggested TTL, clamped + above by ``AUTHZ_CACHE_TTL_MAX``. Both allow and deny + decisions are cached (denies briefly, to avoid hammering + iam-svc with repeated rejected attempts). + + Raises ``HTTPForbidden`` (403 / "access denied") on a deny + decision. Raises ``HTTPUnauthorized`` (401 / "auth failure") + if the IAM service errors out — failing closed.""" + + key = self._authz_cache_key( + identity.handle, capability, resource, parameters, + ) + now = time.time() + + cached = self._authz_cache.get(key) + if cached and cached[1] > now: + allow, _ = cached + if not allow: + raise _access_denied() + return + + async with self._authz_cache_lock: + cached = self._authz_cache.get(key) + if cached and cached[1] > now: + allow, _ = cached + if not allow: + raise _access_denied() + return + + try: + async def _call(client): + return await client.authorise( + identity.handle, capability, + resource or {}, parameters or {}, + ) + allow, ttl = await self._with_client(_call) + except Exception as e: + logger.warning( + f"authorise failed: {type(e).__name__}: {e}; " + f"failing closed for " + f"{identity.principal_id!r} cap={capability!r}" + ) + raise _auth_failure() + + ttl = max(0, min(int(ttl or 0), AUTHZ_CACHE_TTL_MAX)) + self._authz_cache[key] = (bool(allow), now + ttl) + + if not allow: + raise _access_denied() + return diff --git a/trustgraph-flow/trustgraph/gateway/capabilities.py b/trustgraph-flow/trustgraph/gateway/capabilities.py index 15e25684..72ca51c7 100644 --- a/trustgraph-flow/trustgraph/gateway/capabilities.py +++ b/trustgraph-flow/trustgraph/gateway/capabilities.py @@ -1,36 +1,23 @@ """ -Capability vocabulary, role definitions, and authorisation helpers. +Gateway-side authorisation entry points. -See docs/tech-specs/capabilities.md for the authoritative description. -The data here is the OSS bundle table in that spec. Enterprise -editions may replace this module with their own role table; the -vocabulary (capability strings) is shared. +Under the IAM contract (see docs/tech-specs/iam-contract.md) the +gateway holds *no* policy state. Roles, capability sets, and +workspace-scope rules all live in the IAM regime (iam-svc for OSS). +This module is the thin surface the gateway uses to ask the regime +for a decision: -Role model ----------- -A role has two dimensions: +- ``PUBLIC`` / ``AUTHENTICATED`` sentinels for endpoints that don't + go through capability-based authorisation. +- :func:`enforce` — authenticate-only, then ask the regime. +- :func:`enforce_workspace` — default-fill the workspace from the + caller's bound workspace and ask the regime, with the workspace + treated as the resource address. - 1. **capability set** — which operations the role grants. - 2. **workspace scope** — which workspaces the role is active in. - -The authorisation question is: *given the caller's roles, a required -capability, and a target workspace, does any role grant the -capability AND apply to the target workspace?* - -Workspace scope values recognised here: - - - ``"assigned"`` — the role applies only to the caller's own - assigned workspace (stored on their user record). - - ``"*"`` — the role applies to every workspace. - -Enterprise editions can add richer scopes (explicit permitted-set, -patterns, etc.) without changing the wire protocol. - -Sentinels ---------- -- ``PUBLIC`` — endpoint requires no authentication. -- ``AUTHENTICATED`` — endpoint requires a valid identity, no - specific capability. +The capability strings themselves are an open vocabulary — see +docs/tech-specs/capabilities.md. The gateway does not validate them +beyond passing them through; an unknown capability simply produces a +deny verdict from the regime. """ from aiohttp import web @@ -40,125 +27,6 @@ PUBLIC = "__public__" AUTHENTICATED = "__authenticated__" -# Capability vocabulary. Mirrors the "Capability list" tables in -# capabilities.md. Kept as a set so the gateway can fail-closed on -# an endpoint that declares an unknown capability. -KNOWN_CAPABILITIES = { - # Data plane - "agent", - "graph:read", "graph:write", - "documents:read", "documents:write", - "rows:read", "rows:write", - "llm", - "embeddings", - "mcp", - # Control plane - "config:read", "config:write", - "flows:read", "flows:write", - "users:read", "users:write", "users:admin", - "keys:self", "keys:admin", - "workspaces:admin", - "iam:admin", - "metrics:read", - "collections:read", "collections:write", - "knowledge:read", "knowledge:write", -} - - -# Capability sets used below. -_READER_CAPS = { - "agent", - "graph:read", - "documents:read", - "rows:read", - "llm", - "embeddings", - "mcp", - "config:read", - "flows:read", - "collections:read", - "knowledge:read", - "keys:self", -} - -_WRITER_CAPS = _READER_CAPS | { - "graph:write", - "documents:write", - "rows:write", - "collections:write", - "knowledge:write", -} - -_ADMIN_CAPS = _WRITER_CAPS | { - "config:write", - "flows:write", - "users:read", "users:write", "users:admin", - "keys:admin", - "workspaces:admin", - "iam:admin", - "metrics:read", -} - - -# Role definitions. Each role has a capability set and a workspace -# scope. Enterprise overrides this mapping. -ROLE_DEFINITIONS = { - "reader": { - "capabilities": _READER_CAPS, - "workspace_scope": "assigned", - }, - "writer": { - "capabilities": _WRITER_CAPS, - "workspace_scope": "assigned", - }, - "admin": { - "capabilities": _ADMIN_CAPS, - "workspace_scope": "*", - }, -} - - -def _scope_permits(role_name, target_workspace, assigned_workspace): - """Does the given role apply to ``target_workspace``?""" - role = ROLE_DEFINITIONS.get(role_name) - if role is None: - return False - scope = role["workspace_scope"] - if scope == "*": - return True - if scope == "assigned": - return target_workspace == assigned_workspace - # Future scope types (lists, patterns) extend here. - return False - - -def check(identity, capability, target_workspace=None): - """Is ``identity`` permitted to invoke ``capability`` on - ``target_workspace``? - - Passes iff some role held by the caller both (a) grants - ``capability`` and (b) is active in ``target_workspace``. - - ``target_workspace`` defaults to the caller's assigned workspace, - which makes this function usable for system-level operations and - for authenticated endpoints that don't take a workspace argument - (the call collapses to "do any of my roles grant this cap?").""" - if capability not in KNOWN_CAPABILITIES: - return False - - target = target_workspace or identity.workspace - - for role_name in identity.roles: - role = ROLE_DEFINITIONS.get(role_name) - if role is None: - continue - if capability not in role["capabilities"]: - continue - if _scope_permits(role_name, target, identity.workspace): - return True - return False - - def access_denied(): return web.HTTPForbidden( text='{"error":"access denied"}', @@ -174,21 +42,19 @@ def auth_failure(): async def enforce(request, auth, capability): - """Authenticate + capability-check for endpoints that carry no - workspace dimension on the request (metrics, i18n, etc.). + """Authenticate the caller and (for non-sentinel capabilities) + ask the IAM regime whether they may invoke ``capability``. - For endpoints that carry a workspace field on the body, call - :func:`enforce_workspace` *after* parsing the body to validate - the workspace and re-check the capability in that scope. Most - endpoints do both. + The resource is system-level (``{}``) and parameters are empty — + use :func:`enforce_workspace` for workspace-scoped endpoints, or + drive authorisation through the operation registry for richer + cases. - - ``PUBLIC``: no authentication, returns ``None``. - - ``AUTHENTICATED``: any valid identity. - - capability string: identity must have it, checked against the - caller's assigned workspace (adequate for endpoints whose - capability is system-level, e.g. ``metrics:read``, or where - the real workspace-aware check happens in - :func:`enforce_workspace` after body parsing).""" + - ``PUBLIC``: returns ``None`` — no authentication. + - ``AUTHENTICATED``: returns the ``Identity`` — no authorisation. + - capability string: returns the ``Identity`` if the regime + allows; raises ``HTTPForbidden`` otherwise. + """ if capability == PUBLIC: return None @@ -197,42 +63,38 @@ async def enforce(request, auth, capability): if capability == AUTHENTICATED: return identity - if not check(identity, capability): - raise access_denied() - + await auth.authorise(identity, capability, {}, {}) return identity -def enforce_workspace(data, identity, capability=None): - """Resolve + validate the workspace on a request body. +async def enforce_workspace(data, identity, auth, capability=None): + """Default-fill the workspace on a request body and (optionally) + authorise the caller for ``capability`` against that workspace. - Target workspace = ``data["workspace"]`` if supplied, else the - caller's assigned workspace. - - At least one of the caller's roles must (a) be active in the - target workspace and, if ``capability`` is given, (b) grant - ``capability``. Otherwise 403. + caller's bound workspace. - On success, ``data["workspace"]`` is overwritten with the - resolved value — callers can rely on the outgoing message - having the gateway's chosen workspace rather than any - caller-supplied value. + resolved value so downstream code sees a single canonical + address. + - When ``capability`` is given, the regime is asked whether the + caller may invoke ``capability`` on ``{workspace: target}``. + Raises ``HTTPForbidden`` on a deny. - For ``capability=None`` the workspace scope alone is checked — - useful when the body has a workspace but the endpoint already - passed its capability check (e.g. via :func:`enforce`).""" + For ``capability=None`` no authorisation call is made — the + caller has presumably already authorised via :func:`enforce` + (handy for endpoints that authorise once then resolve workspace + on the body before forwarding). + """ if not isinstance(data, dict): return data requested = data.get("workspace", "") target = requested or identity.workspace + data["workspace"] = target - for role_name in identity.roles: - role = ROLE_DEFINITIONS.get(role_name) - if role is None: - continue - if capability is not None and capability not in role["capabilities"]: - continue - if _scope_permits(role_name, target, identity.workspace): - data["workspace"] = target - return data + if capability is not None: + await auth.authorise( + identity, capability, {"workspace": target}, {}, + ) - raise access_denied() + return data diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py index 013cd1ea..37a72f11 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py @@ -121,20 +121,45 @@ class Mux: }) return - # Workspace resolution. Role workspace scope determines - # which target workspaces are permitted. The resolved - # value is written to both the envelope and the inner - # request payload so clients don't have to repeat it - # per-message (same convenience HTTP callers get via - # enforce_workspace). + # Per-service capability gating. Resolved through the + # operation registry so the WS path matches what HTTP + # callers see — same authority, same caps. Service + # kinds that aren't registered are refused. + from ..registry import lookup as _registry_lookup from ..capabilities import enforce_workspace from aiohttp import web as _web + service = data.get("service", "") + op = _registry_lookup(f"flow-service:{service}") + if op is None: + await self.ws.send_json({ + "id": request_id, + "error": { + "message": "unknown service", + "type": "unknown-service", + }, + "complete": True, + }) + return + + # Workspace + flow form the resource address for a + # flow-level service call. Resolve workspace first + # (default-fill from the caller's bound workspace), + # then ask the regime to authorise the service-level + # capability against that {workspace, flow} resource. try: - enforce_workspace(data, self.identity) + await enforce_workspace(data, self.identity, self.auth) inner = data.get("request") if isinstance(inner, dict): - enforce_workspace(inner, self.identity) + await enforce_workspace(inner, self.identity, self.auth) + + resource = { + "workspace": data.get("workspace", ""), + "flow": data.get("flow", ""), + } + await self.auth.authorise( + self.identity, op.capability, resource, {}, + ) except _web.HTTPForbidden: await self.ws.send_json({ "id": request_id, @@ -145,6 +170,16 @@ class Mux: "complete": True, }) return + except _web.HTTPUnauthorized: + await self.ws.send_json({ + "id": request_id, + "error": { + "message": "auth failure", + "type": "auth-required", + }, + "complete": True, + }) + return workspace = data["workspace"] diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/auth_endpoints.py b/trustgraph-flow/trustgraph/gateway/endpoint/auth_endpoints.py index 6037fc4b..0b476b7b 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/auth_endpoints.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/auth_endpoints.py @@ -97,7 +97,7 @@ class AuthEndpoints: ) req = { "operation": "change-password", - "user_id": identity.user_id, + "user_id": identity.handle, "password": body.get("current_password", ""), "new_password": body.get("new_password", ""), } diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/constant_endpoint.py b/trustgraph-flow/trustgraph/gateway/endpoint/constant_endpoint.py index ee9c0447..920b02ca 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/constant_endpoint.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/constant_endpoint.py @@ -36,7 +36,7 @@ class ConstantEndpoint: data = await request.json() if identity is not None: - enforce_workspace(data, identity) + await enforce_workspace(data, identity, self.auth) async def responder(x, fin): pass diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/iam_endpoint.py b/trustgraph-flow/trustgraph/gateway/endpoint/iam_endpoint.py new file mode 100644 index 00000000..70fa33f7 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/endpoint/iam_endpoint.py @@ -0,0 +1,106 @@ +""" +Registry-driven /api/v1/iam endpoint. + +The gateway no longer gates IAM management with a single coarse +``users:admin`` capability. Instead, each operation declares its +own capability + resource shape in the registry (``registry.py``); +this endpoint reads the body's ``operation`` field, looks up the +declaration, and asks the IAM regime to authorise the call. + +Operations not in the registry produce a 400 ``unknown operation``. +This is the gateway's primary mechanism for fail-closed gating of +the IAM surface — the registry is the source of truth. +""" + +import logging + +from aiohttp import web + +from .. capabilities import ( + PUBLIC, AUTHENTICATED, auth_failure, +) +from .. registry import lookup, RequestContext + +logger = logging.getLogger("iam-endpoint") +logger.setLevel(logging.INFO) + + +class IamEndpoint: + """POST /api/v1/iam — generic forwarder gated by the operation + registry. The IAM dispatcher (``iam_dispatcher``) forwards the + body verbatim to iam-svc once authorisation succeeds.""" + + def __init__(self, endpoint_path, auth, dispatcher): + self.path = endpoint_path + self.auth = auth + self.dispatcher = dispatcher + + async def start(self): + pass + + def add_routes(self, app): + app.add_routes([web.post(self.path, self.handle)]) + + async def handle(self, request): + try: + body = await request.json() + except Exception: + return web.json_response( + {"error": "invalid json"}, status=400, + ) + if not isinstance(body, dict): + return web.json_response( + {"error": "body must be an object"}, status=400, + ) + + op_name = body.get("operation", "") + op = lookup(op_name) + if op is None: + return web.json_response( + {"error": "unknown operation"}, status=400, + ) + + # Authentication: required for everything except PUBLIC. + identity = None + if op.capability != PUBLIC: + try: + identity = await self.auth.authenticate(request) + except web.HTTPException: + raise + + # Authorisation: capability sentinels short-circuit the + # regime call; capability strings go through authorise(). + if op.capability not in (PUBLIC, AUTHENTICATED): + ctx = RequestContext( + body=body, + match_info=dict(request.match_info), + identity=identity, + ) + try: + resource = op.extract_resource(ctx) + parameters = op.extract_parameters(ctx) + except Exception as e: + logger.warning( + f"extractor failed for {op_name!r}: " + f"{type(e).__name__}: {e}" + ) + return web.json_response( + {"error": "bad request"}, status=400, + ) + + await self.auth.authorise( + identity, op.capability, resource, parameters, + ) + + async def responder(x, fin): + pass + + try: + resp = await self.dispatcher.process(body, responder) + except web.HTTPException: + raise + except Exception as e: + logger.error(f"Exception: {e}", exc_info=True) + return web.json_response({"error": str(e)}) + + return web.json_response(resp) diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/manager.py b/trustgraph-flow/trustgraph/gateway/endpoint/manager.py index 69b11e07..ed5ef4b5 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/manager.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/manager.py @@ -9,90 +9,44 @@ from . socket import SocketEndpoint from . metrics import MetricsEndpoint from . i18n import I18nPackEndpoint from . auth_endpoints import AuthEndpoints +from . iam_endpoint import IamEndpoint +from . registry_endpoint import RegistryRoutedVariableEndpoint -from .. capabilities import PUBLIC, AUTHENTICATED +from .. capabilities import PUBLIC, AUTHENTICATED, auth_failure +from .. registry import lookup as _registry_lookup, RequestContext from .. dispatch.manager import DispatcherManager -# Capability required for each kind on the /api/v1/{kind} generic -# endpoint (global services). Coarse gating — the IAM bundle split -# of "read vs write" per admin subsystem is not applied here because -# this endpoint forwards an opaque operation in the body. Writes -# are the upper bound on what the endpoint can do, so we gate on -# the write/admin capability. -GLOBAL_KIND_CAPABILITY = { - "config": "config:write", - "flow": "flows:write", - "librarian": "documents:write", - "knowledge": "knowledge:write", - "collection-management": "collections:write", - # IAM endpoints land on /api/v1/iam and require the admin bundle. - # Login / bootstrap / change-password are served by - # AuthEndpoints, which handle their own gating (PUBLIC / - # AUTHENTICATED). - "iam": "users:admin", -} +# /api/v1/{kind} (config / flow / librarian / knowledge / +# collection-management), /api/v1/iam, and /api/v1/flow/{flow}/... +# routes are all gated per-operation by the registry, not by a +# per-kind capability map. Login / bootstrap / change-password are +# served by AuthEndpoints with their own PUBLIC / AUTHENTICATED +# sentinels. -# Capability required for each kind on the -# /api/v1/flow/{flow}/service/{kind} endpoint (per-flow data-plane). -FLOW_KIND_CAPABILITY = { - "agent": "agent", - "text-completion": "llm", - "prompt": "llm", - "mcp-tool": "mcp", - "graph-rag": "graph:read", - "document-rag": "documents:read", - "embeddings": "embeddings", - "graph-embeddings": "graph:read", - "document-embeddings": "documents:read", - "triples": "graph:read", - "rows": "rows:read", - "nlp-query": "rows:read", - "structured-query": "rows:read", - "structured-diag": "rows:read", - "row-embeddings": "rows:read", - "sparql": "graph:read", -} - - -# Capability for the streaming flow import/export endpoints, -# keyed by the "kind" URL segment. -FLOW_IMPORT_CAPABILITY = { - "triples": "graph:write", - "graph-embeddings": "graph:write", - "document-embeddings": "documents:write", - "entity-contexts": "documents:write", - "rows": "rows:write", -} - -FLOW_EXPORT_CAPABILITY = { - "triples": "graph:read", - "graph-embeddings": "graph:read", - "document-embeddings": "documents:read", - "entity-contexts": "documents:read", -} - - -from .. capabilities import enforce, enforce_workspace import logging as _mgr_logging _mgr_logger = _mgr_logging.getLogger("endpoint") class _RoutedVariableEndpoint: - """HTTP endpoint whose required capability is looked up per - request from the URL's ``kind`` parameter. Used for the two - generic dispatch paths (``/api/v1/{kind}`` and - ``/api/v1/flow/{flow}/service/{kind}``). Self-contained rather - than subclassing ``VariableEndpoint`` to avoid mutating shared - state across concurrent requests.""" + """HTTP endpoint that gates per request via the operation + registry. The URL's ``kind`` parameter combined with a fixed + ``registry_prefix`` yields the registry key — e.g. prefix + ``flow-service`` and kind ``agent`` looks up + ``flow-service:agent``. - def __init__(self, endpoint_path, auth, dispatcher, capability_map): + Used for ``/api/v1/flow/{flow}/service/{kind}`` (per-flow + data-plane services). ``/api/v1/{kind}`` (workspace-level + global services) goes through ``RegistryRoutedVariableEndpoint`` + which discriminates on body operation as well as URL kind.""" + + def __init__(self, endpoint_path, auth, dispatcher, registry_prefix): self.path = endpoint_path self.auth = auth self.dispatcher = dispatcher - self._capability_map = capability_map + self._registry_prefix = registry_prefix async def start(self): pass @@ -102,18 +56,26 @@ class _RoutedVariableEndpoint: async def handle(self, request): kind = request.match_info.get("kind", "") - cap = self._capability_map.get(kind) - if cap is None: + op = _registry_lookup(f"{self._registry_prefix}:{kind}") + if op is None: return web.json_response( {"error": "unknown kind"}, status=404, ) - identity = await enforce(request, self.auth, cap) + identity = await self.auth.authenticate(request) try: data = await request.json() - if identity is not None: - enforce_workspace(data, identity) + ctx = RequestContext( + body=data if isinstance(data, dict) else {}, + match_info=dict(request.match_info), + identity=identity, + ) + resource = op.extract_resource(ctx) + parameters = op.extract_parameters(ctx) + await self.auth.authorise( + identity, op.capability, resource, parameters, + ) async def responder(x, fin): pass @@ -131,15 +93,15 @@ class _RoutedVariableEndpoint: class _RoutedSocketEndpoint: - """WebSocket endpoint whose required capability is looked up per - request from the URL's ``kind`` parameter. Used for the flow - import/export streaming endpoints.""" + """WebSocket endpoint gated per request via the operation + registry. Like ``_RoutedVariableEndpoint`` but for the + streaming flow import / export socket paths.""" - def __init__(self, endpoint_path, auth, dispatcher, capability_map): + def __init__(self, endpoint_path, auth, dispatcher, registry_prefix): self.path = endpoint_path self.auth = auth self.dispatcher = dispatcher - self._capability_map = capability_map + self._registry_prefix = registry_prefix async def start(self): pass @@ -148,11 +110,9 @@ class _RoutedSocketEndpoint: app.add_routes([web.get(self.path, self.handle)]) async def handle(self, request): - from .. capabilities import check, auth_failure, access_denied - kind = request.match_info.get("kind", "") - cap = self._capability_map.get(kind) - if cap is None: + op = _registry_lookup(f"{self._registry_prefix}:{kind}") + if op is None: return web.json_response( {"error": "unknown kind"}, status=404, ) @@ -168,8 +128,20 @@ class _RoutedSocketEndpoint: ) except web.HTTPException as e: return e - if not check(identity, cap): - return access_denied() + + ctx = RequestContext( + body={}, + match_info=dict(request.match_info), + identity=identity, + ) + try: + resource = op.extract_resource(ctx) + parameters = op.extract_parameters(ctx) + await self.auth.authorise( + identity, op.capability, resource, parameters, + ) + except web.HTTPException as e: + return e # Delegate the websocket handling to a standalone SocketEndpoint # with the resolved capability, bypassing the per-request mutation @@ -178,7 +150,7 @@ class _RoutedSocketEndpoint: endpoint_path=self.path, auth=self.auth, dispatcher=self.dispatcher, - capability=cap, + capability=op.capability, ) return await ws_ep.handle(request) @@ -203,6 +175,18 @@ class EndpointManager: auth=auth, ), + # /api/v1/iam — registry-driven IAM management. Per + # operation gating happens inside IamEndpoint via the + # operation registry; the dispatcher forwards verbatim + # to iam-svc once authorisation has succeeded. Listed + # before the generic /api/v1/{kind} route so it wins + # the match for "iam". + IamEndpoint( + endpoint_path="/api/v1/iam", + auth=auth, + dispatcher=dispatcher_manager.dispatch_auth_iam(), + ), + I18nPackEndpoint( endpoint_path="/api/v1/i18n/packs/{lang}", auth=auth, @@ -215,12 +199,16 @@ class EndpointManager: capability="metrics:read", ), - # Global services: capability chosen per-kind. - _RoutedVariableEndpoint( + # Global services: registry-driven per-operation gating. + # Each kind+op combination has a registry entry that + # declares its capability and resource shape. Listed + # after the IAM and auth-surface routes; aiohttp's + # path matcher prefers the more-specific path so this + # variable route doesn't shadow them. + RegistryRoutedVariableEndpoint( endpoint_path="/api/v1/{kind}", auth=auth, dispatcher=dispatcher_manager.dispatch_global_service(), - capability_map=GLOBAL_KIND_CAPABILITY, ), # /api/v1/socket: WebSocket handshake accepts @@ -236,26 +224,29 @@ class EndpointManager: in_band_auth=True, ), - # Per-flow request/response services — capability per kind. + # Per-flow request/response services — gated per + # ``flow-service:`` registry entry. _RoutedVariableEndpoint( endpoint_path="/api/v1/flow/{flow}/service/{kind}", auth=auth, dispatcher=dispatcher_manager.dispatch_flow_service(), - capability_map=FLOW_KIND_CAPABILITY, + registry_prefix="flow-service", ), - # Per-flow streaming import/export — capability per kind. + # Per-flow streaming import/export — gated per + # ``flow-import:`` / ``flow-export:`` registry + # entry. _RoutedSocketEndpoint( endpoint_path="/api/v1/flow/{flow}/import/{kind}", auth=auth, dispatcher=dispatcher_manager.dispatch_flow_import(), - capability_map=FLOW_IMPORT_CAPABILITY, + registry_prefix="flow-import", ), _RoutedSocketEndpoint( endpoint_path="/api/v1/flow/{flow}/export/{kind}", auth=auth, dispatcher=dispatcher_manager.dispatch_flow_export(), - capability_map=FLOW_EXPORT_CAPABILITY, + registry_prefix="flow-export", ), StreamEndpoint( diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/registry_endpoint.py b/trustgraph-flow/trustgraph/gateway/endpoint/registry_endpoint.py new file mode 100644 index 00000000..296376fa --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/endpoint/registry_endpoint.py @@ -0,0 +1,123 @@ +""" +Registry-driven dispatch for ``/api/v1/{kind}`` global services. + +The body's ``operation`` field plus the URL's ``{kind}`` together +form the canonical operation name (``:``) that the +gateway looks up in ``registry.py``. The matched operation +declares its capability and resource shape; this endpoint asks the +IAM regime to authorise the call before forwarding the body +verbatim to the backend dispatcher. + +The dispatcher is the same ``dispatch_global_service()`` factory the +old coarse path used; only the gating layer has changed. + +Operations not present in the registry are rejected with 400 +``unknown operation`` — fail closed. +""" + +import logging + +from aiohttp import web + +from .. capabilities import ( + PUBLIC, AUTHENTICATED, auth_failure, +) +from .. registry import lookup, RequestContext + +logger = logging.getLogger("registry-endpoint") +logger.setLevel(logging.INFO) + + +class RegistryRoutedVariableEndpoint: + """POST /api/v1/{kind} — kind comes from the URL, operation comes + from the body, both are joined as the registry key.""" + + def __init__(self, endpoint_path, auth, dispatcher): + self.path = endpoint_path + self.auth = auth + self.dispatcher = dispatcher + + async def start(self): + pass + + def add_routes(self, app): + app.add_routes([web.post(self.path, self.handle)]) + + async def handle(self, request): + kind = request.match_info.get("kind", "") + if not kind: + return web.json_response( + {"error": "missing kind"}, status=404, + ) + + try: + body = await request.json() + except Exception: + return web.json_response( + {"error": "invalid json"}, status=400, + ) + if not isinstance(body, dict): + return web.json_response( + {"error": "body must be an object"}, status=400, + ) + + op_name = body.get("operation", "") + if not op_name: + return web.json_response( + {"error": "missing operation"}, status=400, + ) + + registry_key = f"{kind}:{op_name}" + op = lookup(registry_key) + if op is None: + return web.json_response( + {"error": "unknown operation"}, status=400, + ) + + identity = None + if op.capability != PUBLIC: + identity = await self.auth.authenticate(request) + + if op.capability not in (PUBLIC, AUTHENTICATED): + ctx = RequestContext( + body=body, + match_info=dict(request.match_info), + identity=identity, + ) + try: + resource = op.extract_resource(ctx) + parameters = op.extract_parameters(ctx) + except Exception as e: + logger.warning( + f"extractor failed for {registry_key!r}: " + f"{type(e).__name__}: {e}" + ) + return web.json_response( + {"error": "bad request"}, status=400, + ) + + await self.auth.authorise( + identity, op.capability, resource, parameters, + ) + + # Default-fill workspace into the body so downstream + # dispatchers see the canonical resolved value. The + # extractor has already pulled the workspace out; + # mirror it back to the body for the verbatim forward. + if "workspace" in resource: + body["workspace"] = resource["workspace"] + + async def responder(x, fin): + pass + + try: + resp = await self.dispatcher.process( + body, responder, request.match_info, + ) + except web.HTTPException: + raise + except Exception as e: + logger.error(f"Exception: {e}", exc_info=True) + return web.json_response({"error": str(e)}) + + return web.json_response(resp) diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py index 08629ea2..f53ad73b 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py @@ -5,7 +5,7 @@ import logging from .. running import Running from .. capabilities import ( - PUBLIC, AUTHENTICATED, check, auth_failure, access_denied, + PUBLIC, AUTHENTICATED, auth_failure, ) logger = logging.getLogger("socket") @@ -97,8 +97,12 @@ class SocketEndpoint: except web.HTTPException as e: return e if self.capability != AUTHENTICATED: - if not check(identity, self.capability): - return access_denied() + try: + await self.auth.authorise( + identity, self.capability, {}, {}, + ) + except web.HTTPException as e: + return e # 50MB max message size ws = web.WebSocketResponse(max_msg_size=52428800) diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/variable_endpoint.py b/trustgraph-flow/trustgraph/gateway/endpoint/variable_endpoint.py index 5e0d9d21..6a336f42 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/variable_endpoint.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/variable_endpoint.py @@ -36,7 +36,7 @@ class VariableEndpoint: data = await request.json() if identity is not None: - enforce_workspace(data, identity) + await enforce_workspace(data, identity, self.auth) async def responder(x, fin): pass diff --git a/trustgraph-flow/trustgraph/gateway/registry.py b/trustgraph-flow/trustgraph/gateway/registry.py new file mode 100644 index 00000000..32a517a9 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/registry.py @@ -0,0 +1,515 @@ +""" +Gateway operation registry. + +Single declarative table mapping each operation the gateway +recognises to: + +- The capability the IAM regime is asked to authorise against. +- The resource level (system / workspace / flow) — determines the + shape of the resource identifier handed to ``authorise``. +- Extractors that build the resource and parameters from the + request context. + +This is a gateway-internal concept. It is not part of the IAM +contract — the contract specifies what arguments ``authorise`` +receives; the registry is how the gateway populates them. + +See docs/tech-specs/iam-contract.md for the contract and +docs/tech-specs/iam.md for the request anatomy. +""" + +from dataclasses import dataclass, field +from typing import Any, Callable + + +# Sentinels for operations that don't go through capability-based +# authorisation. Mirror the values used in capabilities.py so the +# gateway endpoint layer can recognise them uniformly. +PUBLIC = "__public__" +AUTHENTICATED = "__authenticated__" + + +class ResourceLevel: + """Where the operation's resource lives. + + ``SYSTEM`` — operation acts on a deployment-level resource + (the user registry, the workspace registry, + the signing key). resource = {}. Workspace, + if relevant, is a parameter, not an address. + + ``WORKSPACE`` — operation acts on something within a workspace + (config, library, knowledge, collections, flow + lifecycle). resource = {workspace}. + + ``FLOW`` — operation acts on something within a flow + within a workspace (graph, agent, llm, etc.). + resource = {workspace, flow}. + """ + SYSTEM = "system" + WORKSPACE = "workspace" + FLOW = "flow" + + +@dataclass +class RequestContext: + """The bundle of inputs the registry's extractors operate on. + Assembled by the gateway from the incoming request after + authentication.""" + + # Parsed JSON body (HTTP) or inner request payload (WebSocket). + body: dict = field(default_factory=dict) + + # URL path components (HTTP) or WebSocket envelope routing + # fields (id, service, workspace, flow). + match_info: dict = field(default_factory=dict) + + # Authenticated identity for default-fill-in. Always present + # by the time extractors run, except for PUBLIC operations + # where it is None. + identity: Any = None + + +@dataclass +class Operation: + """Declared operation the gateway can dispatch + authorise.""" + + # Canonical operation name (used for registry lookup, audit, + # debug logs). Mirrors the operation strings in the IAM + # service and other backends where applicable. + name: str + + # Capability required to invoke this operation. Either a + # string from the capability vocabulary in capabilities.md, or + # the PUBLIC / AUTHENTICATED sentinel for operations that + # don't go through capability-based authorisation. + capability: str + + # Where the operation's resource lives. Determines the + # shape of the resource argument passed to authorise. + resource_level: str + + # Build the resource identifier from the request context. + # Returns a dict with the appropriate components for the + # resource level: {} for SYSTEM, {workspace} for WORKSPACE, + # {workspace, flow} for FLOW. Default-fill-in of workspace + # from identity.workspace happens here when applicable. + extract_resource: Callable[[RequestContext], dict] + + # Build the parameters dict — decision-relevant fields the + # operation supplied that are not part of the resource + # address. E.g. workspace association on a system-level + # user-registry operation. + extract_parameters: Callable[[RequestContext], dict] + + +# --------------------------------------------------------------------------- +# Registry storage. +# --------------------------------------------------------------------------- + + +_REGISTRY: dict[str, Operation] = {} + + +def register(op: Operation) -> None: + if op.name in _REGISTRY: + raise RuntimeError( + f"operation {op.name!r} already registered" + ) + _REGISTRY[op.name] = op + + +def lookup(name: str) -> Operation | None: + return _REGISTRY.get(name) + + +def all_operations() -> list[Operation]: + return list(_REGISTRY.values()) + + +# --------------------------------------------------------------------------- +# Common extractor helpers. +# --------------------------------------------------------------------------- + + +def _empty_resource(_ctx: RequestContext) -> dict: + """System-level resource: empty dict.""" + return {} + + +def _workspace_from_body(ctx: RequestContext) -> dict: + """Workspace-level resource sourced from the request body's + workspace field, defaulting to the caller's bound workspace.""" + ws = (ctx.body.get("workspace") if isinstance(ctx.body, dict) else "") + if not ws and ctx.identity is not None: + ws = ctx.identity.workspace + return {"workspace": ws} + + +def _flow_from_match_info(ctx: RequestContext) -> dict: + """Flow-level resource sourced from URL path components or WS + envelope fields. Both ``workspace`` and ``flow`` are required; + no default-fill-in (the address is the operation's identity).""" + return { + "workspace": ctx.match_info.get("workspace", ""), + "flow": ctx.match_info.get("flow", ""), + } + + +def _no_parameters(_ctx: RequestContext) -> dict: + return {} + + +def _body_as_parameters(ctx: RequestContext) -> dict: + """All body fields are parameters — used when the operation's + body is small and uniformly decision-relevant (e.g. user- + registry ops where the body's user.workspace is what the + regime checks against the admin's scope).""" + return dict(ctx.body) if isinstance(ctx.body, dict) else {} + + +def _workspace_param_only(ctx: RequestContext) -> dict: + """Parameters dict carrying only the workspace association. + Used by system-level operations (e.g. user-registry ops) where + the workspace isn't part of the resource address but is the + field the regime uses to scope the admin's authority. + + Pulls the workspace from the inner ``user`` / ``workspace_record`` + body field if present (create-user, create-workspace), then from + the top-level body, then from the caller's bound workspace.""" + body = ctx.body if isinstance(ctx.body, dict) else {} + inner_user = body.get("user") if isinstance(body.get("user"), dict) else {} + inner_ws = ( + body.get("workspace_record") + if isinstance(body.get("workspace_record"), dict) else {} + ) + ws = ( + inner_user.get("workspace") + or inner_ws.get("id") + or body.get("workspace") + ) + if not ws and ctx.identity is not None: + ws = ctx.identity.workspace + return {"workspace": ws or ""} + + +# --------------------------------------------------------------------------- +# Operation registrations. +# +# The gateway looks operations up by their canonical name (the same +# string the request body / WS envelope carries in its ``operation`` +# field where applicable). Auth-surface operations (login, bootstrap, +# change-password) are not listed here — they have their own routes +# in auth_endpoints.py and use PUBLIC / AUTHENTICATED sentinels +# directly. Pure gateway↔IAM internal operations (resolve-api-key, +# authorise, authorise-many, get-signing-key-public) are likewise +# excluded; they are never invoked over the public API. +# --------------------------------------------------------------------------- + + +# IAM management operations. All routed through /api/v1/iam, body +# carries ``operation`` plus operation-specific fields. + +# User registry: SYSTEM-level resource (users are global, identified +# by handle). The admin's authority is scoped per workspace via the +# parameters {workspace} field — that's what the regime checks +# against the admin's role workspace_scope. +register(Operation( + name="create-user", + capability="users:admin", + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_workspace_param_only, +)) +register(Operation( + name="list-users", + capability="users:read", + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_workspace_param_only, +)) +register(Operation( + name="get-user", + capability="users:read", + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_workspace_param_only, +)) +register(Operation( + name="update-user", + capability="users:write", + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_workspace_param_only, +)) +register(Operation( + name="disable-user", + capability="users:admin", + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_workspace_param_only, +)) +register(Operation( + name="enable-user", + capability="users:admin", + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_workspace_param_only, +)) +register(Operation( + name="delete-user", + capability="users:admin", + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_workspace_param_only, +)) +register(Operation( + name="reset-password", + capability="users:admin", + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_workspace_param_only, +)) + + +# API keys: workspace-level resource — keys live within a workspace. +register(Operation( + name="create-api-key", + capability="keys:admin", + resource_level=ResourceLevel.WORKSPACE, + extract_resource=_workspace_from_body, + extract_parameters=_no_parameters, +)) +register(Operation( + name="list-api-keys", + capability="keys:admin", + resource_level=ResourceLevel.WORKSPACE, + extract_resource=_workspace_from_body, + extract_parameters=_no_parameters, +)) +register(Operation( + name="revoke-api-key", + capability="keys:admin", + resource_level=ResourceLevel.WORKSPACE, + extract_resource=_workspace_from_body, + extract_parameters=_no_parameters, +)) + + +# Workspace registry: SYSTEM-level resource (workspaces are the +# top-level addressable unit). No parameters — the workspace being +# acted on is identified by the body, not used as a scope cue. +register(Operation( + name="create-workspace", + capability="workspaces:admin", + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_no_parameters, +)) +register(Operation( + name="list-workspaces", + capability="workspaces:admin", + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_no_parameters, +)) +register(Operation( + name="get-workspace", + capability="workspaces:admin", + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_no_parameters, +)) +register(Operation( + name="update-workspace", + capability="workspaces:admin", + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_no_parameters, +)) +register(Operation( + name="disable-workspace", + capability="workspaces:admin", + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_no_parameters, +)) + + +# Signing key: SYSTEM-level operational op. +register(Operation( + name="rotate-signing-key", + capability="iam:admin", + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_no_parameters, +)) + + +# --------------------------------------------------------------------------- +# Auth-surface entries. +# +# Listed here so the registry is the one place the gateway looks for +# operation→capability mappings — including the sentinels for paths +# that don't go through capability-based authorisation. The actual +# routing is in auth_endpoints.py; these entries let the registry- +# driven dispatcher recognise the operation if it sees it on a +# generic path. +# --------------------------------------------------------------------------- + +register(Operation( + name="login", + capability=PUBLIC, + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_no_parameters, +)) +register(Operation( + name="bootstrap", + capability=PUBLIC, + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_no_parameters, +)) +register(Operation( + name="change-password", + capability=AUTHENTICATED, + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_no_parameters, +)) + + +# --------------------------------------------------------------------------- +# Generic kind/operation entries. +# +# Names are ``:`` so the registry key is unique +# across dispatchers. All entries below are workspace-level +# resources (workspace defaulted from the caller's bound workspace +# if absent). Read/write distinction maps to the existing +# ``:read`` / ``:write`` capability vocabulary +# defined in capabilities.md. +# --------------------------------------------------------------------------- + + +def _register_kind_op(kind: str, op: str, capability: str) -> None: + """Helper: register a workspace-level kind:op with the standard + extractors (workspace from body, no extra parameters).""" + register(Operation( + name=f"{kind}:{op}", + capability=capability, + resource_level=ResourceLevel.WORKSPACE, + extract_resource=_workspace_from_body, + extract_parameters=_no_parameters, + )) + + +# config: KV-style workspace config service. +for _op in ("get", "list", "getvalues", "getvalues-all-ws", "config"): + _register_kind_op("config", _op, "config:read") +for _op in ("put", "delete"): + _register_kind_op("config", _op, "config:write") + + +# flow: flow-blueprint and flow-lifecycle service. +for _op in ("list-blueprints", "get-blueprint", "list-flows", "get-flow"): + _register_kind_op("flow", _op, "flows:read") +for _op in ("put-blueprint", "delete-blueprint", "start-flow", "stop-flow"): + _register_kind_op("flow", _op, "flows:write") + + +# librarian: document storage and processing service. +for _op in ( + "get-document-metadata", "get-document-content", + "stream-document", "list-documents", "list-processing", + "get-upload-status", "list-uploads", +): + _register_kind_op("librarian", _op, "documents:read") +for _op in ( + "add-document", "remove-document", "update-document", + "add-processing", "remove-processing", + "begin-upload", "upload-chunk", "complete-upload", "abort-upload", +): + _register_kind_op("librarian", _op, "documents:write") + + +# knowledge: knowledge-graph core service. +for _op in ("get-kg-core", "list-kg-cores"): + _register_kind_op("knowledge", _op, "knowledge:read") +for _op in ("put-kg-core", "delete-kg-core", + "load-kg-core", "unload-kg-core"): + _register_kind_op("knowledge", _op, "knowledge:write") + + +# collection-management: workspace collection lifecycle. +_register_kind_op("collection-management", "list-collections", "collections:read") +for _op in ("update-collection", "delete-collection"): + _register_kind_op("collection-management", _op, "collections:write") + + +# --------------------------------------------------------------------------- +# Per-flow data-plane services. +# +# /api/v1/flow/{flow}/service/{kind} and the streaming +# /api/v1/flow/{flow}/{import,export}/{kind} paths. No body-level +# ``operation`` discriminator — the URL kind is the operation +# identity. Resource is FLOW level (workspace + flow). +# +# Names: ``flow-service:``, ``flow-import:``, +# ``flow-export:``. +# --------------------------------------------------------------------------- + + +def _register_flow_kind(prefix: str, kind: str, capability: str) -> None: + register(Operation( + name=f"{prefix}:{kind}", + capability=capability, + resource_level=ResourceLevel.FLOW, + extract_resource=_flow_from_match_info, + extract_parameters=_no_parameters, + )) + + +# Request/response services on /api/v1/flow/{flow}/service/{kind}. +_FLOW_SERVICES = { + "agent": "agent", + "text-completion": "llm", + "prompt": "llm", + "mcp-tool": "mcp", + "graph-rag": "graph:read", + "document-rag": "documents:read", + "embeddings": "embeddings", + "graph-embeddings": "graph:read", + "document-embeddings": "documents:read", + "triples": "graph:read", + "rows": "rows:read", + "nlp-query": "rows:read", + "structured-query": "rows:read", + "structured-diag": "rows:read", + "row-embeddings": "rows:read", + "sparql": "graph:read", +} +for _kind, _cap in _FLOW_SERVICES.items(): + _register_flow_kind("flow-service", _kind, _cap) + + +# Streaming import socket endpoints. +_FLOW_IMPORTS = { + "triples": "graph:write", + "graph-embeddings": "graph:write", + "document-embeddings": "documents:write", + "entity-contexts": "documents:write", + "rows": "rows:write", +} +for _kind, _cap in _FLOW_IMPORTS.items(): + _register_flow_kind("flow-import", _kind, _cap) + + +# Streaming export socket endpoints. +_FLOW_EXPORTS = { + "triples": "graph:read", + "graph-embeddings": "graph:read", + "document-embeddings": "documents:read", + "entity-contexts": "documents:read", +} +for _kind, _cap in _FLOW_EXPORTS.items(): + _register_flow_kind("flow-export", _kind, _cap) diff --git a/trustgraph-flow/trustgraph/iam/service/iam.py b/trustgraph-flow/trustgraph/iam/service/iam.py index 6e7c7aa5..44c7df23 100644 --- a/trustgraph-flow/trustgraph/iam/service/iam.py +++ b/trustgraph-flow/trustgraph/iam/service/iam.py @@ -40,6 +40,78 @@ API_KEY_RANDOM_BYTES = 24 JWT_ISSUER = "trustgraph-iam" JWT_TTL_SECONDS = 3600 +# Default authorisation cache TTL the regime tells the gateway to +# observe. 60s is the OSS-spec maximum revocation latency: a role +# change, workspace disable, or key revoke takes effect within at +# most this much time. +AUTHZ_CACHE_TTL_SECONDS = 60 + + +# OSS regime role table. Lives here, not in the gateway — the +# gateway is regime-agnostic and must not encode policy. +# +# Each role has a capability set and a workspace scope. The +# evaluator (handle_authorise below) checks (a) that some role +# held by the caller grants the requested capability, and (b) +# that role's workspace scope permits the target workspace. + +_READER_CAPS = { + "agent", + "graph:read", + "documents:read", + "rows:read", + "llm", + "embeddings", + "mcp", + "config:read", + "flows:read", + "collections:read", + "knowledge:read", + "keys:self", +} + +_WRITER_CAPS = _READER_CAPS | { + "graph:write", + "documents:write", + "rows:write", + "collections:write", + "knowledge:write", +} + +_ADMIN_CAPS = _WRITER_CAPS | { + "config:write", + "flows:write", + "users:read", "users:write", "users:admin", + "keys:admin", + "workspaces:admin", + "iam:admin", + "metrics:read", +} + +ROLE_DEFINITIONS = { + "reader": { + "capabilities": _READER_CAPS, + "workspace_scope": "assigned", + }, + "writer": { + "capabilities": _WRITER_CAPS, + "workspace_scope": "assigned", + }, + "admin": { + "capabilities": _ADMIN_CAPS, + "workspace_scope": "*", + }, +} + + +def _scope_permits(role_scope, target_workspace, assigned_workspace): + """Does the given role apply to ``target_workspace``?""" + if role_scope == "*": + return True + if role_scope == "assigned": + return target_workspace == assigned_workspace + return False + def _now_iso(): return datetime.datetime.now(datetime.timezone.utc).isoformat() @@ -250,6 +322,10 @@ class IamService: return await self.handle_disable_workspace(v) if op == "rotate-signing-key": return await self.handle_rotate_signing_key(v) + if op == "authorise": + return await self.handle_authorise(v) + if op == "authorise-many": + return await self.handle_authorise_many(v) return _err( "invalid-argument", @@ -478,7 +554,7 @@ class IamService: ( id, ws, _username, _name, _email, password_hash, - roles, enabled, _mcp, _created, + _roles, enabled, _mcp, _created, ) = user_row if not enabled: @@ -496,11 +572,14 @@ class IamService: now_ts = int(_now_dt().timestamp()) exp_ts = now_ts + JWT_TTL_SECONDS + # Per the IAM contract the gateway never reads policy state + # from the credential — roles stay server-side, reachable + # only via authorise(). JWT carries identity + workspace + # binding only. claims = { "iss": JWT_ISSUER, "sub": id, "workspace": ws, - "roles": sorted(roles) if roles else [], "iat": now_ts, "exp": exp_ts, } @@ -1130,3 +1209,134 @@ class IamService: await self.table_store.delete_api_key(key_hash) return IamResponse() + + # ------------------------------------------------------------------ + # authorise / authorise-many + # + # The IAM contract (see docs/tech-specs/iam-contract.md) calls + # for the regime — not the gateway — to decide whether an + # identity may perform a capability on a resource given the + # operation's parameters. These two operations are the OSS + # regime's implementation of that contract. + # + # Inputs (on IamRequest): + # user_id — the identity handle (the gateway's + # opaque reference). For OSS this is the + # user record's id. + # capability — the capability string from the + # capabilities.md vocabulary. + # resource_json — JSON dict, the resource address + # ({} for system, {workspace} for + # workspace, {workspace, flow} for flow). + # parameters_json — JSON dict, decision-relevant operation + # parameters (e.g. workspace association + # on user-registry operations). + # authorise_checks — for authorise-many, a JSON list of + # {capability, resource, parameters}. + # + # Outputs (on IamResponse): + # decision_allow — single allow / deny verdict. + # decision_ttl_seconds — gateway cache TTL for this + # decision. + # decisions_json — for authorise-many, list of + # {allow, ttl} in request order. + # ------------------------------------------------------------------ + + def _decide(self, user_row, capability, resource, parameters): + """Single authorisation decision. Returns (allow, ttl).""" + + if user_row is None: + return False, AUTHZ_CACHE_TTL_SECONDS + + # user_row layout: + # 0:id 1:workspace 2:username 3:name 4:email 5:password_hash + # 6:roles 7:enabled 8:must_change_password 9:created + if not user_row[7]: # disabled + return False, AUTHZ_CACHE_TTL_SECONDS + + # Disabled workspace check (defense in depth — credentials + # bound to a disabled workspace shouldn't be able to act). + # Cheap; one row read. + # We do this only when a target workspace is involved, to + # avoid an extra read for system-level operations that + # bypass workspace altogether. + target_workspace = ( + (resource or {}).get("workspace") + or (parameters or {}).get("workspace") + ) + + roles = user_row[6] or set() + assigned_workspace = user_row[1] + + for role_name in roles: + defn = ROLE_DEFINITIONS.get(role_name) + if defn is None: + continue + if capability not in defn["capabilities"]: + continue + if target_workspace is None or _scope_permits( + defn["workspace_scope"], + target_workspace, + assigned_workspace, + ): + return True, AUTHZ_CACHE_TTL_SECONDS + + return False, AUTHZ_CACHE_TTL_SECONDS + + async def handle_authorise(self, v): + if not v.capability: + return _err("invalid-argument", "capability required") + if not v.user_id: + return _err("invalid-argument", "user_id (handle) required") + + try: + resource = json.loads(v.resource_json or "{}") + parameters = json.loads(v.parameters_json or "{}") + except json.JSONDecodeError as e: + return _err("invalid-argument", f"bad json: {e}") + + user_row = await self.table_store.get_user(v.user_id) + allow, ttl = self._decide( + user_row, v.capability, resource, parameters, + ) + return IamResponse( + decision_allow=allow, + decision_ttl_seconds=ttl, + ) + + async def handle_authorise_many(self, v): + if not v.user_id: + return _err("invalid-argument", "user_id (handle) required") + if not v.authorise_checks: + return _err("invalid-argument", "authorise_checks required") + + try: + checks = json.loads(v.authorise_checks) + except json.JSONDecodeError as e: + return _err("invalid-argument", f"bad json: {e}") + if not isinstance(checks, list): + return _err( + "invalid-argument", + "authorise_checks must be a JSON list", + ) + + # One user lookup for the whole batch. + user_row = await self.table_store.get_user(v.user_id) + + decisions = [] + for c in checks: + if not isinstance(c, dict): + decisions.append({ + "allow": False, + "ttl": AUTHZ_CACHE_TTL_SECONDS, + }) + continue + allow, ttl = self._decide( + user_row, + c.get("capability", ""), + c.get("resource") or {}, + c.get("parameters") or {}, + ) + decisions.append({"allow": allow, "ttl": ttl}) + + return IamResponse(decisions_json=json.dumps(decisions)) From 6302eb8c97a77c32217d98b6ded2c5beee657a01 Mon Sep 17 00:00:00 2001 From: Trevin Chow Date: Tue, 28 Apr 2026 08:33:49 -0700 Subject: [PATCH 19/21] test(ontology): harden domain/range validation + add missing tests (#848) Fixes #826. Addresses all five points the maintainer called out in the follow-up to #825. Source change (trustgraph-flow/trustgraph/extract/kg/ontology/extract.py): - Added `_is_subclass_of(cls, target, ontology_subset, max_depth=100)` helper with visited-set cycle detection + a defensive depth cap. LLM-generated ontologies may emit cycles (A subclass_of B, B subclass_of A); the prior while-loop would infinite-loop on that. - Replaced both near-identical domain and range subclass walks in `is_valid_triple` with a single call to the new helper. Net is -20 duplicated lines + 26-line helper. Tests (tests/unit/test_extract/test_ontology/test_prompt_and_extraction.py): - test_is_valid_triple_subclass_is_accepted: domain expects Recipe, actual type is Cake (subclass), validates. - test_is_valid_triple_handles_subclass_cycle_without_infinite_loop: A subclass_of B, B subclass_of A; call returns False within the depth cap rather than hanging. - test_parse_and_validate_triples_collects_entity_types_from_rdf_type: end-to-end path: rdf:type triples build the entity_types dict, subsequent domain-check triples validate against it. - test_is_valid_triple_entity_types_none_default: the None default path now has explicit coverage. 156 existing tests in tests/unit/test_extract/test_ontology still pass. --- .../test_prompt_and_extraction.py | 72 +++++++++++++++++++ .../trustgraph/extract/kg/ontology/extract.py | 60 +++++++++------- 2 files changed, 107 insertions(+), 25 deletions(-) diff --git a/tests/unit/test_extract/test_ontology/test_prompt_and_extraction.py b/tests/unit/test_extract/test_ontology/test_prompt_and_extraction.py index bae6bdbd..6a2048a5 100644 --- a/tests/unit/test_extract/test_ontology/test_prompt_and_extraction.py +++ b/tests/unit/test_extract/test_ontology/test_prompt_and_extraction.py @@ -277,6 +277,60 @@ class TestTripleValidation: is_invalid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset, entity_types_invalid) assert not is_invalid, "Invalid range should be rejected" + def test_is_valid_triple_subclass_is_accepted(self, extractor, sample_ontology_subset): + """Domain check passes when actual type is a subclass of expected.""" + sample_ontology_subset.classes["Cake"] = { + "uri": "http://purl.org/ontology/fo/Cake", + "type": "owl:Class", + "subclass_of": "Recipe", + } + sample_ontology_subset.object_properties["has_ingredient"] = { + "domain": "Recipe", + "range": "Ingredient", + } + + result = extractor.is_valid_triple( + subject="cake:lemon-drizzle", + predicate="has_ingredient", + object_val="ingredient:lemon", + ontology_subset=sample_ontology_subset, + entity_types={"cake:lemon-drizzle": "Cake", "ingredient:lemon": "Ingredient"}, + ) + + assert result is True + + def test_is_valid_triple_handles_subclass_cycle_without_infinite_loop(self, extractor, sample_ontology_subset): + """A cycle in subclass_of must return False instead of hanging.""" + sample_ontology_subset.classes["A"] = {"subclass_of": "B"} + sample_ontology_subset.classes["B"] = {"subclass_of": "A"} + sample_ontology_subset.object_properties["p"] = {"domain": "Recipe", "range": "Ingredient"} + + result = extractor.is_valid_triple( + subject="entity:x", + predicate="p", + object_val="ingredient:y", + ontology_subset=sample_ontology_subset, + entity_types={"entity:x": "A", "ingredient:y": "Ingredient"}, + ) + + assert result is False + + def test_is_valid_triple_entity_types_none_default(self, extractor, sample_ontology_subset): + """entity_types=None should not raise; domain/range checks skip if type unknown.""" + sample_ontology_subset.object_properties["has_ingredient"] = { + "domain": "Recipe", + "range": "Ingredient", + } + + result = extractor.is_valid_triple( + subject="recipe:x", + predicate="has_ingredient", + object_val="ingredient:y", + ontology_subset=sample_ontology_subset, + ) + + assert result is True + class TestTripleParsing: """Test suite for parsing triples from LLM responses.""" @@ -377,6 +431,24 @@ class TestTripleParsing: assert triple.p.type == IRI, "Predicate should be IRI type" assert triple.o.type == LITERAL, "Object literal should be LITERAL type" + def test_parse_and_validate_triples_collects_entity_types_from_rdf_type(self, extractor, sample_ontology_subset): + """entity_types should be built from rdf:type triples in the same batch.""" + sample_ontology_subset.object_properties["has_ingredient"] = { + "domain": "Recipe", + "range": "Ingredient", + } + triples_response = [ + {"subject": "recipe:cornish-pasty", "predicate": "rdf:type", "object": "Recipe"}, + {"subject": "ingredient:beef", "predicate": "rdf:type", "object": "Ingredient"}, + {"subject": "recipe:cornish-pasty", "predicate": "has_ingredient", "object": "ingredient:beef"}, + ] + + valid_triples = extractor.parse_and_validate_triples( + triples_response, sample_ontology_subset + ) + + assert len(valid_triples) == 3 + class TestURIExpansionInExtraction: """Test suite for URI expansion during triple extraction.""" diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py index ef9a7331..1d45d3f9 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py @@ -540,6 +540,32 @@ class Processor(FlowProcessor): return True return False + def _is_subclass_of(self, cls, target, ontology_subset, max_depth=100): + """Return True if cls is a subclass of target via subclass_of chain. + + Defends against cycles in ontology data (LLM-generated ontologies may + emit A subclass_of B, B subclass_of A) with a visited set. A depth cap + acts as a second line of defense against unbounded chains. + """ + if cls == target: + return True + visited = set() + curr = cls + depth = 0 + while curr in ontology_subset.classes and depth < max_depth: + if curr in visited: + return False # cycle detected + visited.add(curr) + cls_def = ontology_subset.classes[curr] + parent = cls_def.get('subclass_of') if isinstance(cls_def, dict) else None + if parent is None: + return False + if parent == target: + return True + curr = parent + depth += 1 + return False + def is_valid_triple(self, subject: str, predicate: str, object_val: str, ontology_subset: OntologySubset, entity_types: dict = None) -> bool: """Validate triple against ontology constraints.""" @@ -570,36 +596,20 @@ class Processor(FlowProcessor): expected_domain = prop_def.get('domain') if expected_domain and subject in entity_types: actual_domain = entity_types[subject] - if actual_domain != expected_domain: - is_subclass = False - curr_class = actual_domain - while curr_class in ontology_subset.classes: - cls_def = ontology_subset.classes[curr_class] - parent = cls_def.get('subclass_of') if isinstance(cls_def, dict) else None - if parent == expected_domain: - is_subclass = True - break - curr_class = parent - if not is_subclass: - return False + if actual_domain != expected_domain and not self._is_subclass_of( + actual_domain, expected_domain, ontology_subset + ): + return False # Range validation if is_obj_prop: expected_range = prop_def.get('range') if expected_range and object_val in entity_types: actual_range = entity_types[object_val] - if actual_range != expected_range: - is_subclass = False - curr_class = actual_range - while curr_class in ontology_subset.classes: - cls_def = ontology_subset.classes[curr_class] - parent = cls_def.get('subclass_of') if isinstance(cls_def, dict) else None - if parent == expected_range: - is_subclass = True - break - curr_class = parent - if not is_subclass: - return False + if actual_range != expected_range and not self._is_subclass_of( + actual_range, expected_range, ontology_subset + ): + return False return True @@ -988,4 +998,4 @@ class Processor(FlowProcessor): def run(): """Launch the OntoRAG extraction service.""" - Processor.launch(default_ident, __doc__) \ No newline at end of file + Processor.launch(default_ident, __doc__) From 9fc1d4527b913df734bac90c38e8089f7c77eb48 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 28 Apr 2026 22:13:12 +0100 Subject: [PATCH 20/21] iam: self-service ops, optional workspace filters, Mux service routing (#855) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three threads, all reinforcing the contract's system-level vs. workspace-association distinction. WS Mux service routing - tg-show-flows (and any workspace-level service over the WS) was failing with "unknown service" because the post-refactor Mux unconditionally looked up flow-service:. Now branches on the envelope's flow field: with flow → flow-service:; without flow → : from the inner body; with bare op lookup for service=iam. Resource and parameters come from the matched op's own extractors — same path the HTTP endpoints take. Optional workspace on system-level user/key ops - list-users returns the deployment-wide list when no workspace is supplied, filters when one is. get-user, update-user, disable-user, enable-user, delete-user, reset-password, create-api-key, list-api-keys, revoke-api-key all treat workspace as an optional integrity check rather than a required argument. - create-user keeps workspace required — there it's the new user's home-workspace binding, a parameter rather than an address. - API keys reclassified as SYSTEM-level resources. By the same reasoning that makes users system-level, an API key is a credential record on a deployment-wide registry; the workspace it authenticates to is a property, not a containment. Self-service surface - whoami: returns the caller's own user record. AUTHENTICATED-only; no users:read capability required. Foundation for UI affordances that depend on the caller's permissions. - bootstrap-status: POST /api/v1/auth/bootstrap-status, PUBLIC, side-effect-free. Returns {bootstrap_available: bool} so a first-run UI can decide whether to render setup without consuming the bootstrap op. - Gateway now injects actor=identity.handle on every authenticated forward to iam-svc (IamEndpoint and WS Mux iam path), overwriting any caller-supplied value. Underpins whoami, audit logging, and future regime-side decisions that need actor identity. - tg-whoami and tg-update-user CLIs. Spec polish - iam-contract.md: actor-injection rule documented; whoami / bootstrap-status added to operations list; permission-scope framing tightened (workspace scope is a property of the grant, not the user or role). - iam.md: self-service section; gateway flow gains the actor- injection step; role section reframed so iam-svc constraints don't leak into contract-level prose. - iam-protocol.md: ops table updated for whoami, bootstrap-status, optional-workspace pattern; bootstrap_available added to the IamResponse listing. --- docs/tech-specs/iam-contract.md | 71 +++++-- docs/tech-specs/iam-protocol.md | 43 ++-- docs/tech-specs/iam.md | 45 ++++- trustgraph-base/trustgraph/base/iam_client.py | 21 ++ .../trustgraph/messaging/translators/iam.py | 4 + .../trustgraph/schema/services/iam.py | 4 + trustgraph-cli/pyproject.toml | 2 + trustgraph-cli/trustgraph/cli/update_user.py | 125 ++++++++++++ trustgraph-cli/trustgraph/cli/whoami.py | 52 +++++ .../trustgraph/gateway/dispatch/mux.py | 69 +++++-- .../gateway/endpoint/auth_endpoints.py | 16 ++ .../gateway/endpoint/iam_endpoint.py | 8 + .../trustgraph/gateway/registry.py | 38 +++- trustgraph-flow/trustgraph/iam/service/iam.py | 190 ++++++++++-------- trustgraph-flow/trustgraph/tables/iam.py | 14 ++ 15 files changed, 555 insertions(+), 147 deletions(-) create mode 100644 trustgraph-cli/trustgraph/cli/update_user.py create mode 100644 trustgraph-cli/trustgraph/cli/whoami.py diff --git a/docs/tech-specs/iam-contract.md b/docs/tech-specs/iam-contract.md index 3289add1..da23fb31 100644 --- a/docs/tech-specs/iam-contract.md +++ b/docs/tech-specs/iam-contract.md @@ -83,17 +83,16 @@ The four arguments separate concerns: identifier. See *The Resource model* below. - **`parameters`** — operation-specific data that the regime may need to consider beyond the resource identifier. Used when a - decision depends on attributes the request supplies — e.g. an - admin scoped to one workspace creating a user *with workspace - association W*: the resource is the system-level user registry, - and W is a parameter the regime checks against the admin's - scope. + decision depends on attributes the request supplies — e.g. + creating a user *with workspace association W*: the resource is + the system-level user registry, and W is a parameter the regime + checks against the caller's permissions for `users:write`. -Different regimes use the four arguments differently — the OSS -regime checks role bundles against the capability and the role's -workspace scope against parameters; an SSO regime might consult an -upstream IdP's group memberships; an ABAC regime evaluates a -policy with all four as inputs. The contract is unchanged. +Different regimes use the four arguments differently — one regime +might evaluate role bundles whose grants carry workspace scope; +another might consult upstream IdP group memberships; an ABAC +regime evaluates a policy with all four as inputs. The contract +is unchanged. ### `authorise_many` @@ -129,14 +128,49 @@ most of them) but the operation set the gateway can forward is: `revoke-api-key`, `change-password`, `reset-password` - Workspace management: `create-workspace`, `list-workspaces`, `get-workspace`, `update-workspace`, `disable-workspace` -- Session management: `login` +- Session management: `login`, `whoami` - Key management: `get-signing-key-public`, `rotate-signing-key` -- Bootstrap: `bootstrap` +- Bootstrap: `bootstrap`, `bootstrap-status` + +`whoami` is the self-read counterpart to `get-user`: any +authenticated caller can read their own identity record without +holding a user-management capability. It is the gating-free probe +a UI uses to render affordances appropriate to the caller's role. + +`bootstrap-status` is a side-effect-free probe of whether an +unconsumed `bootstrap` call would currently succeed. It exists so +a first-run UI can decide whether to render setup without invoking +the consuming `bootstrap` op. Public — no authentication. A regime that does not support one of these (e.g. an SSO regime where users are managed in the IdP) returns a defined "not supported" error; the gateway surfaces it as a 501. +### Actor injection + +For any management operation forwarded by the gateway after +authentication, the gateway injects the authenticated caller's +`handle` as an `actor` field on the request. Regimes use `actor` +to identify *who is making the request* — distinct from the +operation's target (which lives in `user_id` / `key_id` / +`workspace_record` / etc.) — for purposes such as: + +- Self-service operations (`whoami`, `change-password`) that + resolve "the caller" without taking a target argument. +- Audit logging, where the actor is recorded against the change. +- Decisions that depend on the resolved resource state. The + gateway authorises against the parameters on the request, but it + cannot know the resolved resource's actual properties (e.g. the + workspace association of a target user) before the regime loads + it. When that matters, the regime can re-decide using the + actor's permissions and the resolved record — closing a class + of cases the gateway-side check can't see. + +Caller-supplied `actor` values on the request body are overwritten +by the gateway — the gateway is the only authority for actor +identity, and a regime that consults `actor` can rely on it being +authentic. + ## The `Identity` surface `Identity` is *mostly* opaque. The gateway holds the value as a @@ -327,13 +361,16 @@ contract via: - Credentials are API keys (opaque) or JWTs (Ed25519, locally validated by the gateway against the regime's published public key). -- `authorise` reduces to a role-and-workspace-scope check against - the role table defined in [`capabilities.md`](capabilities.md). +- `authorise` reduces to a lookup against the role bundles in + [`capabilities.md`](capabilities.md), with each grant's workspace + scope checked against the operation's workspace component. - Identity, user, and workspace records live in Cassandra. -The OSS regime is deliberately simple — three roles, single -home-workspace per user (a regime data-model decision, not a -contract assertion), no policy language. +The OSS regime is deliberately simple — three roles, a single +workspace association per user (a regime data-model decision, not +a contract assertion), no policy language. Other regimes can +grant the same user different permissions in different workspaces +without changing anything outside the regime. ### Future regimes diff --git a/docs/tech-specs/iam-protocol.md b/docs/tech-specs/iam-protocol.md index 603d1c06..e7e7984e 100644 --- a/docs/tech-specs/iam-protocol.md +++ b/docs/tech-specs/iam-protocol.md @@ -72,10 +72,16 @@ class IamRequest: # login). workspace: str = "" - # Acting user id, for audit. Set by the gateway to the - # authenticated caller's id on user-initiated operations. - # Empty for internal-origin (bootstrap, reconcilers) and for - # resolve-api-key / login (no actor yet). + # Acting user id. Set by the gateway to the authenticated + # caller's identity handle for every authenticated request + # (overwrites any caller-supplied value — the gateway is the + # only authority for actor identity, so handlers can rely on it + # being authentic). Used for audit logging, self-service ops + # like ``whoami`` that resolve "the caller", and future actor- + # scoped policy checks. Empty for unauthenticated ops + # (``login``, ``bootstrap``, ``bootstrap-status``, + # ``get-signing-key-public``, ``resolve-api-key``). See the + # actor-injection rule in the IAM contract spec. actor: str = "" # --- identity selectors --- @@ -135,6 +141,11 @@ class IamResponse: bootstrap_admin_user_id: str = "" bootstrap_admin_api_key: str = "" + # bootstrap-status: true iff an unconsumed ``bootstrap`` call + # would currently succeed. Always emitted by the response + # translator (the false case is meaningful for first-run UIs). + bootstrap_available: bool = False + # Present on any failed operation. error: Error | None = None ``` @@ -201,25 +212,29 @@ class ApiKeyRecord: | Operation | Request fields | Response fields | Notes | |---|---|---|---| | `login` | `username`, `password`, `workspace` (optional) | `jwt`, `jwt_expires` | If `workspace` omitted, IAM resolves to the user's assigned workspace. | +| `whoami` | `actor` (gateway-injected) | `user` | Returns the calling user's own record. AUTHENTICATED-only; no `users:read` capability required. | | `resolve-api-key` | `api_key` (plaintext) | `resolved_user_id`, `resolved_workspace`, `resolved_roles` | Gateway-internal. Service returns `auth-failed` for unknown / expired / revoked keys. | | `change-password` | `user_id`, `password` (current), `new_password` | — | Self-service. IAM validates `password` against stored hash. | -| `reset-password` | `user_id` | `temporary_password` | Admin-initiated. IAM generates a random password, sets `must_change_password=true` on the user, returns the plaintext once. | -| `create-user` | `workspace`, `user` | `user` | Admin-only. `user.password` is hashed and stored; `user.roles` must be subset of known roles. | -| `list-users` | `workspace` | `users` | | -| `get-user` | `workspace`, `user_id` | `user` | | -| `update-user` | `workspace`, `user_id`, `user` | `user` | `password` field on `user` is rejected; use `change-password` / `reset-password`. | -| `disable-user` | `workspace`, `user_id` | — | Soft-delete; sets `enabled=false`. Revokes all the user's API keys. | +| `reset-password` | `user_id`, `workspace` (optional integrity check) | `temporary_password` | Admin-initiated. IAM generates a random password, sets `must_change_password=true` on the user, returns the plaintext once. | +| `create-user` | `workspace`, `user` | `user` | `user.password` is hashed and stored; `user.roles` must be subset of known roles. `workspace` is the new user's home-workspace binding (a required *parameter*, not an address). | +| `list-users` | `workspace` (optional filter) | `users` | If `workspace` omitted, returns the deployment-wide list. | +| `get-user` | `user_id`, `workspace` (optional integrity check) | `user` | | +| `update-user` | `user_id`, `user`, `workspace` (optional integrity check) | `user` | `password` field on `user` is rejected; use `change-password` / `reset-password`. Username is immutable. | +| `disable-user` | `user_id`, `workspace` (optional integrity check) | — | Soft-delete; sets `enabled=false`. Revokes all the user's API keys. | +| `enable-user` | `user_id`, `workspace` (optional integrity check) | — | Re-enables a previously disabled user; does not restore API keys. | +| `delete-user` | `user_id`, `workspace` (optional integrity check) | — | Hard-delete; removes user record, username lookup, and all the user's API keys. | | `create-workspace` | `workspace_record` | `workspace` | System-level. | | `list-workspaces` | — | `workspaces` | System-level. | | `get-workspace` | `workspace_record` (id only) | `workspace` | System-level. | | `update-workspace` | `workspace_record` | `workspace` | System-level. | | `disable-workspace` | `workspace_record` (id only) | — | System-level. Sets `enabled=false`; revokes all workspace API keys; disables all users in the workspace. | -| `create-api-key` | `workspace`, `key` | `api_key_plaintext`, `api_key` | Plaintext returned **once**; only hash stored. `key.name` required. | -| `list-api-keys` | `workspace`, `user_id` | `api_keys` | | -| `revoke-api-key` | `workspace`, `key_id` | — | Deletes the key record. | +| `create-api-key` | `key`, `workspace` (optional integrity check) | `api_key_plaintext`, `api_key` | Plaintext returned **once**; only hash stored. `key.name` required. | +| `list-api-keys` | `user_id`, `workspace` (optional integrity check) | `api_keys` | | +| `revoke-api-key` | `key_id`, `workspace` (optional integrity check) | — | Deletes the key record. | | `get-signing-key-public` | — | `signing_key_public` | Gateway fetches this at startup. | | `rotate-signing-key` | — | — | System-level. Introduces a new signing key; old key continues to validate JWTs for a grace period (implementation-defined, minimum 1h). | -| `bootstrap` | — | `bootstrap_admin_user_id`, `bootstrap_admin_api_key` | If IAM tables are empty, creates the initial `default` workspace, an `admin` user, an initial API key, and an initial signing key; returns them once. No-op on subsequent calls (returns empty fields). | +| `bootstrap` | — | `bootstrap_admin_user_id`, `bootstrap_admin_api_key` | If IAM tables are empty and the service is in `bootstrap` mode, creates the initial `default` workspace, an `admin` user, an initial API key, and an initial signing key; returns them once. Otherwise returns a masked auth failure. | +| `bootstrap-status` | — | `bootstrap_available` | Side-effect-free probe; `true` iff iam-svc is in `bootstrap` mode and tables are empty. Intended for first-run UX. | ## Error taxonomy diff --git a/docs/tech-specs/iam.md b/docs/tech-specs/iam.md index a764535e..dd0e12f5 100644 --- a/docs/tech-specs/iam.md +++ b/docs/tech-specs/iam.md @@ -268,6 +268,26 @@ The gateway forwards this to the IAM service, which validates credentials and returns a signed JWT. The gateway returns the JWT to the caller. +#### Self-service: `whoami` and `bootstrap-status` + +Two side-effect-free probes that exist to support UI affordances +without giving the caller broad read access: + +- `POST /api/v1/iam` with `{"operation": "whoami"}` — authenticated + only. Returns the caller's own user record (id, username, name, + email, workspace, roles, enabled, must_change_password, + created). No `users:read` capability is required, because every + authenticated caller can read themselves. The gateway populates + `actor` on the request from the authenticated identity, so the + regime resolves "the caller" without taking a target argument. + +- `POST /api/v1/auth/bootstrap-status` — public, side-effect-free. + Returns `{"bootstrap_available": true|false}`. `true` iff + iam-svc is in `bootstrap` mode and its tables are empty (i.e. an + unconsumed `bootstrap` call would currently succeed). Exists so + a first-run UI can decide whether to render the setup flow + without invoking the consuming `bootstrap` op. + #### IAM service delegation The gateway stays thin. Its authentication logic is: @@ -387,9 +407,10 @@ workspace; every `authorise` call sees a concrete value. Whether the resolved workspace is permitted to be operated on by this caller is an **IAM decision**, not a gateway one. The gateway calls `authorise(identity, capability, {workspace: ..., ...})` and -relays the answer. In the OSS regime, the answer comes from the -caller's role × workspace-scope — see [`capabilities.md`](capabilities.md). -In other regimes it could come from group mappings, policies, +relays the answer. In the OSS regime, the regime checks whether +the caller's permission grants for `` include this +workspace — see [`capabilities.md`](capabilities.md). In other +regimes the decision could come from group mappings, policies, relationship tuples, or anything else the regime models. ### Request anatomy @@ -500,8 +521,19 @@ The OSS regime ships three roles: | `writer` | All reader capabilities, plus `graph:write`, `documents:write`, `rows:write`, `knowledge:write`, `collections:write`. | | `admin` | All writer capabilities, plus `config:write`, `flows:write`, `users:read`, `users:write`, `users:admin`, `keys:admin`, `workspaces:admin`, `iam:admin`, `metrics:read`. | -Workspace scope: `reader` and `writer` are active only in the -caller's bound workspace; `admin` is active across all workspaces. +Workspace scope is a property of the *grant*, not of the user or +role. In the OSS regime each capability granted by `reader` / +`writer` is scoped to the workspace the user record is associated +with; capabilities granted by `admin` are scoped to `*` (every +workspace). A user is a system-level object — they don't "live +in" a workspace, they hold permissions whose scope happens to +reference one. + +The OSS regime is deliberately limited to one workspace association +per user; future regimes are free to grant the same user different +permissions in different workspaces, or use a non-workspace scope +entirely. This is regime-internal — neither the contract nor the +gateway carries an assumption either way. The gateway gates each endpoint by *capability*, not by role. Capabilities are declared per operation in the gateway's operation @@ -647,6 +679,9 @@ For HTTP requests: error, fail closed (401 / 503 per deployment). 8. Cache the decision per the contract's caching rules (clamped above by a deployment-set ceiling). +9. For requests forwarded to iam-svc, set `actor` on the body + from `identity.handle`, overwriting any caller-supplied value. + See [`iam-contract.md`](iam-contract.md#actor-injection). For WebSocket connections: diff --git a/trustgraph-base/trustgraph/base/iam_client.py b/trustgraph-base/trustgraph/base/iam_client.py index f90694fc..4be59de1 100644 --- a/trustgraph-base/trustgraph/base/iam_client.py +++ b/trustgraph-base/trustgraph/base/iam_client.py @@ -41,6 +41,27 @@ class IamClient(RequestResponse): ) return resp.bootstrap_admin_user_id, resp.bootstrap_admin_api_key + async def bootstrap_status(self, timeout=IAM_TIMEOUT): + """Returns whether an unconsumed ``bootstrap`` call would + currently succeed (i.e. iam-svc is in ``bootstrap`` mode and + its tables are empty). Side-effect-free; intended for first- + run UX so a UI can decide whether to render setup.""" + resp = await self._request( + operation="bootstrap-status", timeout=timeout, + ) + return resp.bootstrap_available + + async def whoami(self, actor, timeout=IAM_TIMEOUT): + """Return the user record for ``actor`` (the authenticated + caller's handle). AUTHENTICATED-only; no capability check — + every authenticated user can read themselves.""" + resp = await self._request( + operation="whoami", + actor=actor, + timeout=timeout, + ) + return resp.user + async def resolve_api_key(self, api_key, timeout=IAM_TIMEOUT): """Resolve a plaintext API key to its identity triple. diff --git a/trustgraph-base/trustgraph/messaging/translators/iam.py b/trustgraph-base/trustgraph/messaging/translators/iam.py index 4a717bba..1d7bf21c 100644 --- a/trustgraph-base/trustgraph/messaging/translators/iam.py +++ b/trustgraph-base/trustgraph/messaging/translators/iam.py @@ -185,6 +185,10 @@ class IamResponseTranslator(MessageTranslator): result["bootstrap_admin_user_id"] = obj.bootstrap_admin_user_id if obj.bootstrap_admin_api_key: result["bootstrap_admin_api_key"] = obj.bootstrap_admin_api_key + # bootstrap-status: emit unconditionally — the false case is + # meaningful for UIs deciding whether to render first-run + # setup, so it can't be dropped by a truthy-only filter. + result["bootstrap_available"] = bool(obj.bootstrap_available) return result diff --git a/trustgraph-base/trustgraph/schema/services/iam.py b/trustgraph-base/trustgraph/schema/services/iam.py index 4b5685a5..797d6203 100644 --- a/trustgraph-base/trustgraph/schema/services/iam.py +++ b/trustgraph-base/trustgraph/schema/services/iam.py @@ -148,6 +148,10 @@ class IamResponse: bootstrap_admin_user_id: str = "" bootstrap_admin_api_key: str = "" + # bootstrap-status — true iff iam-svc is in 'bootstrap' mode with + # empty tables, i.e. an unconsumed bootstrap call would succeed. + bootstrap_available: bool = False + # ---- authorise / authorise-many outputs ---- # authorise: the regime's allow / deny verdict. decision_allow: bool = False diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index 728079c8..e8062fba 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -44,6 +44,8 @@ tg-bootstrap-iam = "trustgraph.cli.bootstrap_iam:main" tg-login = "trustgraph.cli.login:main" tg-create-user = "trustgraph.cli.create_user:main" tg-list-users = "trustgraph.cli.list_users:main" +tg-whoami = "trustgraph.cli.whoami:main" +tg-update-user = "trustgraph.cli.update_user:main" tg-disable-user = "trustgraph.cli.disable_user:main" tg-enable-user = "trustgraph.cli.enable_user:main" tg-delete-user = "trustgraph.cli.delete_user:main" diff --git a/trustgraph-cli/trustgraph/cli/update_user.py b/trustgraph-cli/trustgraph/cli/update_user.py new file mode 100644 index 00000000..5c1dc4d7 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/update_user.py @@ -0,0 +1,125 @@ +""" +Update a user's profile fields: name, email, roles, enabled flag, +must-change-password flag. + +Username is immutable — create a new user and disable the old one +to effect a username change. Password changes go through +``tg-change-password`` (self-service) or ``tg-reset-password`` +(admin-driven). + +Only the fields you supply are changed; omitted fields are left +untouched on the user record. An empty ``--roles`` is rejected by +iam-svc (a user must have at least one role); to demote a user use +``tg-disable-user``. +""" + +import argparse +import sys + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def _parse_bool(s): + if s is None: + return None + s = s.strip().lower() + if s in ("yes", "y", "true", "t", "1"): + return True + if s in ("no", "n", "false", "f", "0"): + return False + raise argparse.ArgumentTypeError( + f"expected yes/no, got {s!r}" + ) + + +def do_update_user(args): + user = {} + if args.name is not None: + user["name"] = args.name + if args.email is not None: + user["email"] = args.email + if args.roles is not None: + user["roles"] = args.roles + if args.enabled is not None: + user["enabled"] = args.enabled + if args.must_change_password is not None: + user["must_change_password"] = args.must_change_password + + if not user: + print( + "tg-update-user: nothing to change — supply at least " + "one of --name / --email / --roles / --enabled / " + "--must-change-password", + file=sys.stderr, + ) + sys.exit(2) + + req = { + "operation": "update-user", + "user_id": args.user_id, + "user": user, + } + if args.workspace: + req["workspace"] = args.workspace + resp = call_iam(args.api_url, args.token, req) + + rec = resp.get("user", {}) + print(f"id : {rec.get('id', '')}") + print(f"username : {rec.get('username', '')}") + print(f"name : {rec.get('name', '')}") + print(f"email : {rec.get('email', '')}") + print(f"workspace : {rec.get('workspace', '')}") + print(f"roles : {', '.join(rec.get('roles', []))}") + print(f"enabled : {'yes' if rec.get('enabled') else 'no'}") + print( + f"must-change-pw: " + f"{'yes' if rec.get('must_change_password') else 'no'}" + ) + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-update-user", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "--user-id", required=True, help="Target user id", + ) + parser.add_argument( + "--name", default=None, help="New display name", + ) + parser.add_argument( + "--email", default=None, help="New email", + ) + parser.add_argument( + "--roles", nargs="+", default=None, + help="Replacement role list (e.g. --roles reader writer)", + ) + parser.add_argument( + "--enabled", type=_parse_bool, default=None, + help="Set enabled flag (yes/no)", + ) + parser.add_argument( + "--must-change-password", type=_parse_bool, default=None, + help="Set must-change-password flag (yes/no)", + ) + parser.add_argument( + "-w", "--workspace", default=None, + help=( + "Optional workspace integrity check — when supplied, " + "iam-svc verifies the target user's home workspace " + "matches" + ), + ) + run_main(do_update_user, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/whoami.py b/trustgraph-cli/trustgraph/cli/whoami.py new file mode 100644 index 00000000..1799685d --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/whoami.py @@ -0,0 +1,52 @@ +""" +Show the authenticated caller's own user record. +""" + +import argparse + +import tabulate + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_whoami(args): + resp = call_iam(args.api_url, args.token, {"operation": "whoami"}) + user = resp.get("user") + if not user: + print("(no user record returned)") + return + + rows = [ + ["id", user.get("id", "")], + ["username", user.get("username", "")], + ["name", user.get("name", "")], + ["email", user.get("email", "")], + ["workspace", user.get("workspace", "")], + ["roles", ", ".join(user.get("roles", []))], + ["enabled", "yes" if user.get("enabled") else "no"], + [ + "must change password", + "yes" if user.get("must_change_password") else "no", + ], + ["created", user.get("created", "")], + ] + print(tabulate.tabulate(rows, tablefmt="plain")) + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-whoami", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + run_main(do_whoami, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py index 37a72f11..03cd748b 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py @@ -123,14 +123,31 @@ class Mux: # Per-service capability gating. Resolved through the # operation registry so the WS path matches what HTTP - # callers see — same authority, same caps. Service - # kinds that aren't registered are refused. + # callers see — same authority, same caps. + # + # Lookup mirrors the HTTP routing decision in + # ``request_task``: presence of ``flow`` on the envelope + # means a flow-level data-plane service (graph-rag, + # agent, …); absence means a workspace-level service + # (config, flow management, librarian, …) whose specific + # operation is in the inner request body. ``iam`` is + # treated as workspace-level too — its operations are + # registered with bare names, no kind prefix. from ..registry import lookup as _registry_lookup from ..capabilities import enforce_workspace from aiohttp import web as _web service = data.get("service", "") - op = _registry_lookup(f"flow-service:{service}") + inner = data.get("request") or {} + inner_op = inner.get("operation", "") if isinstance(inner, dict) else "" + + if data.get("flow"): + op = _registry_lookup(f"flow-service:{service}") + elif service == "iam": + op = _registry_lookup(inner_op) if inner_op else None + else: + op = _registry_lookup(f"{service}:{inner_op}") if inner_op else None + if op is None: await self.ws.send_json({ "id": request_id, @@ -142,23 +159,36 @@ class Mux: }) return - # Workspace + flow form the resource address for a - # flow-level service call. Resolve workspace first - # (default-fill from the caller's bound workspace), - # then ask the regime to authorise the service-level - # capability against that {workspace, flow} resource. + # Resolve workspace first (default-fill from the caller's + # bound workspace), then ask the regime to authorise the + # service-level capability against the matched + # operation's resource shape. try: await enforce_workspace(data, self.identity, self.auth) - inner = data.get("request") if isinstance(inner, dict): await enforce_workspace(inner, self.identity, self.auth) - resource = { - "workspace": data.get("workspace", ""), - "flow": data.get("flow", ""), - } + if data.get("flow"): + resource = { + "workspace": data.get("workspace", ""), + "flow": data.get("flow", ""), + } + parameters = {} + else: + # Build a minimal RequestContext so the matched + # operation's own extractors decide resource and + # parameters — same path the HTTP endpoints take. + from ..registry import RequestContext + ctx = RequestContext( + body=inner if isinstance(inner, dict) else {}, + match_info={}, + identity=self.identity, + ) + resource = op.extract_resource(ctx) + parameters = op.extract_parameters(ctx) + await self.auth.authorise( - self.identity, op.capability, resource, {}, + self.identity, op.capability, resource, parameters, ) except _web.HTTPForbidden: await self.ws.send_json({ @@ -183,6 +213,17 @@ class Mux: workspace = data["workspace"] + # Plumb authenticated caller's handle as ``actor`` so + # iam-svc handlers (whoami, future actor-scoped checks) + # know who is calling. Overwrite any caller-supplied + # value so it can't be spoofed over the WS. + if ( + service == "iam" + and isinstance(data.get("request"), dict) + and self.identity is not None + ): + data["request"]["actor"] = self.identity.handle + await self.q.put(( data["id"], workspace, diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/auth_endpoints.py b/trustgraph-flow/trustgraph/gateway/endpoint/auth_endpoints.py index 0b476b7b..44bbc03e 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/auth_endpoints.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/auth_endpoints.py @@ -36,6 +36,10 @@ class AuthEndpoints: app.add_routes([ web.post("/api/v1/auth/login", self.login), web.post("/api/v1/auth/bootstrap", self.bootstrap), + web.post( + "/api/v1/auth/bootstrap-status", + self.bootstrap_status, + ), web.post( "/api/v1/auth/change-password", self.change_password, @@ -83,6 +87,18 @@ class AuthEndpoints: ) return web.json_response(resp) + async def bootstrap_status(self, request): + """Public, side-effect-free. Returns ``{"bootstrap_available": + bool}`` so a UI can decide whether to render first-run setup + without invoking the consuming ``bootstrap`` op.""" + await enforce(request, self.auth, PUBLIC) + resp = await self._forward({"operation": "bootstrap-status"}) + if "error" in resp: + return web.json_response( + {"error": "auth failure"}, status=401, + ) + return web.json_response(resp) + async def change_password(self, request): """Authenticated (any role). Accepts {current_password, new_password}; user_id is taken from the authenticated diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/iam_endpoint.py b/trustgraph-flow/trustgraph/gateway/endpoint/iam_endpoint.py index 70fa33f7..749eacd3 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/iam_endpoint.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/iam_endpoint.py @@ -92,6 +92,14 @@ class IamEndpoint: identity, op.capability, resource, parameters, ) + # Plumb the authenticated caller's handle through as ``actor`` + # so iam-svc handlers (e.g. whoami, future actor-scoped + # checks) know who is making the request. The gateway is + # the only authority for this — body-supplied ``actor`` + # values are overwritten so callers can't impersonate. + if identity is not None: + body["actor"] = identity.handle + async def responder(x, fin): pass diff --git a/trustgraph-flow/trustgraph/gateway/registry.py b/trustgraph-flow/trustgraph/gateway/registry.py index 32a517a9..5e3344f4 100644 --- a/trustgraph-flow/trustgraph/gateway/registry.py +++ b/trustgraph-flow/trustgraph/gateway/registry.py @@ -271,27 +271,31 @@ register(Operation( )) -# API keys: workspace-level resource — keys live within a workspace. +# API keys: SYSTEM-level resource — like users, a key record exists +# in the deployment-wide keys registry. The workspace the key +# authenticates to is a property of the record, not a containment; +# it appears as a parameter so the regime can scope the admin's +# authority to issue / list / revoke against it. register(Operation( name="create-api-key", capability="keys:admin", - resource_level=ResourceLevel.WORKSPACE, - extract_resource=_workspace_from_body, - extract_parameters=_no_parameters, + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_workspace_param_only, )) register(Operation( name="list-api-keys", capability="keys:admin", - resource_level=ResourceLevel.WORKSPACE, - extract_resource=_workspace_from_body, - extract_parameters=_no_parameters, + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_workspace_param_only, )) register(Operation( name="revoke-api-key", capability="keys:admin", - resource_level=ResourceLevel.WORKSPACE, - extract_resource=_workspace_from_body, - extract_parameters=_no_parameters, + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_workspace_param_only, )) @@ -370,6 +374,13 @@ register(Operation( extract_resource=_empty_resource, extract_parameters=_no_parameters, )) +register(Operation( + name="bootstrap-status", + capability=PUBLIC, + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_no_parameters, +)) register(Operation( name="change-password", capability=AUTHENTICATED, @@ -377,6 +388,13 @@ register(Operation( extract_resource=_empty_resource, extract_parameters=_no_parameters, )) +register(Operation( + name="whoami", + capability=AUTHENTICATED, + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_no_parameters, +)) # --------------------------------------------------------------------------- diff --git a/trustgraph-flow/trustgraph/iam/service/iam.py b/trustgraph-flow/trustgraph/iam/service/iam.py index 44c7df23..c89f65b0 100644 --- a/trustgraph-flow/trustgraph/iam/service/iam.py +++ b/trustgraph-flow/trustgraph/iam/service/iam.py @@ -280,6 +280,10 @@ class IamService: try: if op == "bootstrap": return await self.handle_bootstrap(v) + if op == "bootstrap-status": + return await self.handle_bootstrap_status(v) + if op == "whoami": + return await self.handle_whoami(v) if op == "resolve-api-key": return await self.handle_resolve_api_key(v) if op == "create-user": @@ -483,6 +487,39 @@ class IamService: bootstrap_admin_api_key=plaintext, ) + async def handle_whoami(self, v): + """Return the caller's own user record. ``v.actor`` is the + authenticated identity's handle (the gateway populates it + from ``identity.handle``). No ``users:read`` capability + required — every authenticated user can read themselves.""" + if not v.actor: + return _err( + "invalid-argument", + "actor required (gateway should populate this)", + ) + user_row = await self.table_store.get_user(v.actor) + if user_row is None: + return _err("not-found", "user not found") + return IamResponse(user=self._row_to_user_record(user_row)) + + async def handle_bootstrap_status(self, v): + """Probe op: returns whether the deployment is currently in + the unconsumed-bootstrap state (i.e. ``bootstrap`` mode with + empty tables, where an explicit ``bootstrap`` call would + succeed). PUBLIC so a UI can decide whether to render the + first-run setup flow without invoking the side-effectful + ``bootstrap`` op. + + The information leaked is intentionally narrow: an empty + deployment in bootstrap mode is already inferable (no users, + no logins succeed); this just makes the answer explicit + instead of forcing callers to probe the masked-failure path.""" + available = ( + self.bootstrap_mode == "bootstrap" + and not await self.table_store.any_workspace_exists() + ) + return IamResponse(bootstrap_available=available) + # ------------------------------------------------------------------ # Signing key helpers # ------------------------------------------------------------------ @@ -612,15 +649,22 @@ class IamService: created=_iso(created), ) - async def _user_in_workspace(self, user_id, workspace): + async def _resolve_user(self, user_id, workspace=None): """Return (user_row, error_response_or_None). Loads the user - record, verifies it exists, is enabled, and belongs to - ``workspace``. The workspace scope check rejects cross- - workspace admin attempts.""" + record by id and (when ``workspace`` is supplied) verifies the + record's home workspace matches. + + Workspace is an *optional integrity check* — the user record + is system-level, identified by id alone. If the caller asserts + a workspace, we verify; if they omit it, we just return the + record. Authorisation (whether the caller is permitted to + operate on this user) is the gateway's responsibility via the + contract's ``authorise`` call before the handler is reached. + """ user_row = await self.table_store.get_user(user_id) if user_row is None: return None, _err("not-found", "user not found") - if user_row[1] != workspace: + if workspace and user_row[1] != workspace: return None, _err( "operation-not-permitted", "user is in a different workspace", @@ -665,15 +709,10 @@ class IamService: # ------------------------------------------------------------------ async def handle_reset_password(self, v): - if not v.workspace: - return _err( - "invalid-argument", - "workspace required for reset-password", - ) if not v.user_id: return _err("invalid-argument", "user_id required") - _, err = await self._user_in_workspace(v.user_id, v.workspace) + _, err = await self._resolve_user(v.user_id, v.workspace or None) if err is not None: return err @@ -690,13 +729,11 @@ class IamService: # ------------------------------------------------------------------ async def handle_get_user(self, v): - if not v.workspace: - return _err("invalid-argument", "workspace required") if not v.user_id: return _err("invalid-argument", "user_id required") - user_row, err = await self._user_in_workspace( - v.user_id, v.workspace, + user_row, err = await self._resolve_user( + v.user_id, v.workspace or None, ) if err is not None: return err @@ -707,8 +744,6 @@ class IamService: must_change_password. Username is immutable — change it by creating a new user and disabling the old one. Password changes go through change-password / reset-password.""" - if not v.workspace: - return _err("invalid-argument", "workspace required") if not v.user_id: return _err("invalid-argument", "user_id required") if v.user is None: @@ -719,25 +754,17 @@ class IamService: "password cannot be changed via update-user; " "use change-password or reset-password", ) - if v.user.username and v.user.username != "": - # Compare to existing. Username-change not allowed. - existing, err = await self._user_in_workspace( - v.user_id, v.workspace, + + existing, err = await self._resolve_user( + v.user_id, v.workspace or None, + ) + if err is not None: + return err + if v.user.username and v.user.username != existing[2]: + return _err( + "invalid-argument", + "username is immutable; create a new user instead", ) - if err is not None: - return err - if v.user.username != existing[2]: - return _err( - "invalid-argument", - "username is immutable; create a new user " - "instead", - ) - else: - existing, err = await self._user_in_workspace( - v.user_id, v.workspace, - ) - if err is not None: - return err # Carry forward fields the caller didn't provide. ( @@ -774,12 +801,10 @@ class IamService: async def handle_disable_user(self, v): """Soft-delete: set enabled=false and revoke every API key belonging to the user.""" - if not v.workspace: - return _err("invalid-argument", "workspace required") if not v.user_id: return _err("invalid-argument", "user_id required") - _, err = await self._user_in_workspace(v.user_id, v.workspace) + _, err = await self._resolve_user(v.user_id, v.workspace or None) if err is not None: return err @@ -797,12 +822,10 @@ class IamService: async def handle_enable_user(self, v): """Re-enable a previously disabled user. Does not restore API keys — those have to be re-issued by the admin.""" - if not v.workspace: - return _err("invalid-argument", "workspace required") if not v.user_id: return _err("invalid-argument", "user_id required") - _, err = await self._user_in_workspace(v.user_id, v.workspace) + _, err = await self._resolve_user(v.user_id, v.workspace or None) if err is not None: return err @@ -821,29 +844,30 @@ class IamService: cover GDPR erasure-style requirements). When audit logging lands, the decision to delete vs. anonymise referenced audit rows will need to be revisited.""" - if not v.workspace: - return _err("invalid-argument", "workspace required") if not v.user_id: return _err("invalid-argument", "user_id required") - user_row, err = await self._user_in_workspace( - v.user_id, v.workspace, + user_row, err = await self._resolve_user( + v.user_id, v.workspace or None, ) if err is not None: return err # user_row indices match get_user columns. Username is [2]. username = user_row[2] + record_workspace = user_row[1] # Revoke all API keys. key_rows = await self.table_store.list_api_keys_by_user(v.user_id) for kr in key_rows: await self.table_store.delete_api_key(kr[0]) - # Remove username lookup. + # Remove username lookup — keyed on (workspace, username), + # so use the resolved workspace from the user record rather + # than relying on the caller-supplied filter. if username: await self.table_store.delete_username_lookup( - v.workspace, username, + record_workspace, username, ) # Remove user record. @@ -1098,12 +1122,15 @@ class IamService: # ------------------------------------------------------------------ async def handle_list_users(self, v): - if not v.workspace: - return _err( - "invalid-argument", "workspace required for list-users", - ) - - rows = await self.table_store.list_users_by_workspace(v.workspace) + # System-level operation: workspace, when supplied, is a + # filter on the user record's home-workspace association. + # Empty workspace returns the deployment-wide list — the + # gateway has already authorised the caller's authority to + # see that scope. + if v.workspace: + rows = await self.table_store.list_users_by_workspace(v.workspace) + else: + rows = await self.table_store.list_users() return IamResponse( users=[self._row_to_user_record(r) for r in rows], ) @@ -1113,24 +1140,21 @@ class IamService: # ------------------------------------------------------------------ async def handle_create_api_key(self, v): - if not v.workspace: - return _err( - "invalid-argument", "workspace required for create-api-key", - ) if v.key is None or not v.key.user_id: return _err("invalid-argument", "key.user_id required") if not v.key.name: return _err("invalid-argument", "key.name required") - # Target user must exist and belong to the caller's workspace. - user_row = await self.table_store.get_user(v.key.user_id) - if user_row is None: - return _err("not-found", "user not found") - if user_row[1] != v.workspace: - return _err( - "operation-not-permitted", - "target user is in a different workspace", - ) + # API keys are system-level records with a workspace + # association (the user's home workspace). Workspace is an + # optional integrity check on the caller's request — when + # supplied it must match the target user's home workspace; + # when omitted, the user's home workspace is used. + user_row, err = await self._resolve_user( + v.key.user_id, v.workspace or None, + ) + if err is not None: + return err plaintext = _generate_api_key() key_id = str(uuid.uuid4()) @@ -1161,20 +1185,15 @@ class IamService: # ------------------------------------------------------------------ async def handle_list_api_keys(self, v): - if not v.workspace: - return _err( - "invalid-argument", - "workspace required for list-api-keys", - ) if not v.user_id: return _err( "invalid-argument", "user_id required for list-api-keys", ) - # Workspace-scope check: user must live in this workspace. - user_row = await self.table_store.get_user(v.user_id) - if user_row is None or user_row[1] != v.workspace: - return _err("not-found", "user not found in workspace") + # Workspace is an optional integrity check. + _, err = await self._resolve_user(v.user_id, v.workspace or None) + if err is not None: + return err rows = await self.table_store.list_api_keys_by_user(v.user_id) return IamResponse( @@ -1186,11 +1205,6 @@ class IamService: # ------------------------------------------------------------------ async def handle_revoke_api_key(self, v): - if not v.workspace: - return _err( - "invalid-argument", - "workspace required for revoke-api-key", - ) if not v.key_id: return _err("invalid-argument", "key_id required") @@ -1199,13 +1213,15 @@ class IamService: return _err("not-found", "api key not found") key_hash, _id, user_id, _name, _prefix, _expires, _c, _lu = row - # Workspace-scope check via the owning user. - user_row = await self.table_store.get_user(user_id) - if user_row is None or user_row[1] != v.workspace: - return _err( - "operation-not-permitted", - "key belongs to a different workspace", - ) + + # Workspace is an optional integrity check via the owning user. + if v.workspace: + user_row = await self.table_store.get_user(user_id) + if user_row is None or user_row[1] != v.workspace: + return _err( + "operation-not-permitted", + "key belongs to a different workspace", + ) await self.table_store.delete_api_key(key_hash) return IamResponse() diff --git a/trustgraph-flow/trustgraph/tables/iam.py b/trustgraph-flow/trustgraph/tables/iam.py index 3d41ebbd..f1a0734f 100644 --- a/trustgraph-flow/trustgraph/tables/iam.py +++ b/trustgraph-flow/trustgraph/tables/iam.py @@ -167,6 +167,11 @@ class IamTableStore: roles, enabled, must_change_password, created FROM iam_users WHERE workspace = ? ALLOW FILTERING """) + self.list_users_stmt = c.prepare(""" + SELECT id, workspace, username, name, email, password_hash, + roles, enabled, must_change_password, created + FROM iam_users + """) self.put_username_lookup_stmt = c.prepare(""" INSERT INTO iam_users_by_username (workspace, username, user_id) @@ -304,6 +309,15 @@ class IamTableStore: self.cassandra, self.list_users_by_workspace_stmt, (workspace,), ) + async def list_users(self): + """List every user across the deployment. Used by the + system-level list-users handler when no workspace filter is + supplied; the gateway has already authorised the call against + the caller's authority.""" + return await async_execute( + self.cassandra, self.list_users_stmt, (), + ) + async def delete_user(self, id): await async_execute( self.cassandra, self.delete_user_stmt, (id,), From d0850ff381b00c1fdebc1911f92981a51e275641 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 28 Apr 2026 22:46:02 +0100 Subject: [PATCH 21/21] Delete some stuff to free up disk space (#856) --- .github/workflows/release.yaml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 07af8db9..02c546df 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -75,6 +75,13 @@ jobs: - name: Checkout uses: actions/checkout@v4 + - name: "Free up some disk space" + run: | + sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc + sudo rm -rf /opt/hostedtoolcache/CodeQL + podman image prune --all --force + podman builder prune -a -f + - name: Docker Hub token run: echo ${{ secrets.DOCKER_SECRET }} > docker-token.txt