From d35473f7f7c96cb1517fcb2877e8b2e303d74795 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 21 Apr 2026 23:23:01 +0100 Subject: [PATCH] feat: workspace-based multi-tenancy, replacing user as tenancy axis (#840) Introduces `workspace` as the isolation boundary for config, flows, library, and knowledge data. Removes `user` as a schema-level field throughout the code, API specs, and tests; workspace provides the same separation more cleanly at the trusted flow.workspace layer rather than through client-supplied message fields. Design ------ - IAM tech spec (docs/tech-specs/iam.md) documents current state, proposed auth/access model, and migration direction. - Data ownership model (docs/tech-specs/data-ownership-model.md) captures the workspace/collection/flow hierarchy. Schema + messaging ------------------ - Drop `user` field from AgentRequest/Step, GraphRagQuery, DocumentRagQuery, Triples/Graph/Document/Row EmbeddingsRequest, Sparql/Rows/Structured QueryRequest, ToolServiceRequest. - Keep collection/workspace routing via flow.workspace at the service layer. - Translators updated to not serialise/deserialise user. API specs --------- - OpenAPI schemas and path examples cleaned of user fields. - Websocket async-api messages updated. - Removed the unused parameters/User.yaml. Services + base --------------- - Librarian, collection manager, knowledge, config: all operations scoped by workspace. Config client API takes workspace as first positional arg. - `flow.workspace` set at flow start time by the infrastructure; no longer pass-through from clients. - Tool service drops user-personalisation passthrough. CLI + SDK --------- - tg-init-workspace and workspace-aware import/export. - All tg-* commands drop user args; accept --workspace. - Python API/SDK (flow, socket_client, async_*, explainability, library) drop user kwargs from every method signature. MCP server ---------- - All tool endpoints drop user parameters; socket_manager no longer keyed per user. 
Flow service ------------ - Closure-based topic cleanup on flow stop: only delete topics whose blueprint template was parameterised AND no remaining live flow (across all workspaces) still resolves to that topic. Four scopes fall out naturally from template analysis: * {id} -> per-flow, deleted on stop * {blueprint} -> per-blueprint, kept while any flow of the same blueprint exists * {workspace} -> per-workspace, kept while any flow in the workspace exists * literal -> global, never deleted (e.g. tg.request.librarian) Fixes a bug where stopping a flow silently destroyed the global librarian exchange, wedging all library operations until manual restart. RabbitMQ backend ---------------- - heartbeat=60, blocked_connection_timeout=300. Catches silently dead connections (broker restart, orphaned channels, network partitions) within ~2 heartbeat windows, so the consumer reconnects and re-binds its queue rather than sitting forever on a zombie connection. Tests ----- - Full test refresh: unit, integration, contract, provenance. - Dropped user-field assertions and constructor kwargs across ~100 test files. - Renamed user-collection isolation tests to workspace-collection. 
--- .gitignore | 1 + Makefile | 2 +- docs/tech-specs/data-ownership-model.md | 309 +++++++ docs/tech-specs/flow-class-definition.md | 36 +- docs/tech-specs/iam.md | 858 ++++++++++++++++++ specs/api/components/parameters/User.yaml | 8 - .../schemas/agent/AgentRequest.yaml | 9 - .../schemas/collection/CollectionRequest.yaml | 7 +- .../collection/CollectionResponse.yaml | 5 - .../DocumentEmbeddingsQueryRequest.yaml | 5 - .../GraphEmbeddingsQueryRequest.yaml | 5 - .../RowEmbeddingsQueryRequest.yaml | 5 - .../schemas/knowledge/KnowledgeRequest.yaml | 17 +- .../schemas/knowledge/KnowledgeResponse.yaml | 10 - .../schemas/librarian/LibrarianRequest.yaml | 5 - .../schemas/loading/DocumentLoadRequest.yaml | 5 - .../schemas/loading/TextLoadRequest.yaml | 5 - .../schemas/query/RowsQueryRequest.yaml | 5 - .../schemas/query/StructuredQueryRequest.yaml | 5 - .../schemas/query/TriplesQueryRequest.yaml | 5 - .../schemas/rag/DocumentRagRequest.yaml | 5 - .../schemas/rag/GraphRagRequest.yaml | 5 - specs/api/paths/collection-management.yaml | 21 +- specs/api/paths/document-stream.yaml | 8 - specs/api/paths/export-core.yaml | 10 - specs/api/paths/flow/agent.yaml | 4 - specs/api/paths/flow/document-embeddings.yaml | 1 - specs/api/paths/flow/document-load.yaml | 2 - specs/api/paths/flow/document-rag.yaml | 3 - specs/api/paths/flow/graph-embeddings.yaml | 1 - specs/api/paths/flow/graph-rag.yaml | 2 - specs/api/paths/flow/row-embeddings.yaml | 1 - specs/api/paths/flow/rows.yaml | 1 - specs/api/paths/flow/sparql-query.yaml | 5 - specs/api/paths/flow/structured-query.yaml | 2 - specs/api/paths/flow/text-load.yaml | 2 - specs/api/paths/flow/triples.yaml | 2 - specs/api/paths/import-core.yaml | 10 - specs/api/paths/knowledge.yaml | 12 +- .../requests/RowEmbeddingsRequest.yaml | 1 - .../messages/requests/SparqlQueryRequest.yaml | 5 - tests/contract/conftest.py | 7 +- .../test_document_embeddings_contract.py | 14 +- tests/contract/test_message_contracts.py | 6 +- 
tests/contract/test_orchestrator_contracts.py | 4 - .../contract/test_rows_cassandra_contracts.py | 17 +- .../test_rows_graphql_query_contracts.py | 52 +- tests/contract/test_schema_field_contracts.py | 3 +- .../test_structured_data_contracts.py | 13 +- ...test_agent_structured_query_integration.py | 24 +- .../test_cassandra_config_end_to_end.py | 18 +- .../integration/test_cassandra_integration.py | 10 +- .../test_document_rag_integration.py | 5 - ...test_document_rag_streaming_integration.py | 10 - .../integration/test_graph_rag_integration.py | 30 +- .../test_graph_rag_streaming_integration.py | 8 - .../test_import_export_graceful_shutdown.py | 1 - .../test_kg_extract_store_integration.py | 26 +- .../integration/test_nlp_query_integration.py | 18 +- .../test_object_extraction_integration.py | 60 +- .../test_prompt_streaming_integration.py | 5 +- .../test_rag_streaming_protocol.py | 8 - .../test_rows_cassandra_integration.py | 115 ++- .../test_rows_graphql_query_integration.py | 22 +- .../test_structured_query_integration.py | 2 - .../test_agent_service_non_streaming.py | 10 +- tests/unit/test_agent/test_aggregator.py | 10 +- .../test_agent/test_completion_dispatch.py | 4 +- ...est_orchestrator_provenance_integration.py | 1 - .../test_agent/test_pattern_base_subagent.py | 1 - tests/unit/test_agent/test_tool_service.py | 30 +- .../test_agent/test_tool_service_lifecycle.py | 141 +-- .../test_base/test_async_processor_config.py | 257 +++--- .../test_document_embeddings_client.py | 3 - .../unit/test_base/test_flow_base_modules.py | 7 +- .../test_base/test_flow_parameter_specs.py | 4 +- tests/unit/test_base/test_flow_processor.py | 32 +- tests/unit/test_chunking/conftest.py | 4 - .../test_chunking/test_recursive_chunker.py | 1 - .../unit/test_chunking/test_token_chunker.py | 1 - tests/unit/test_cli/test_config_commands.py | 18 +- tests/unit/test_cli/test_load_knowledge.py | 22 +- tests/unit/test_cli/test_tool_commands.py | 8 +- 
.../test_sync_document_embeddings_client.py | 3 - .../test_graph_rag_concurrency.py | 2 - .../unit/test_cores/test_knowledge_manager.py | 28 +- .../test_decoding/test_universal_processor.py | 6 +- .../test_milvus_collection_naming.py | 30 +- .../test_document_embeddings_processor.py | 8 +- .../test_graph_embeddings_processor.py | 7 +- .../test_row_embeddings_processor.py | 35 +- .../test_definitions_batching.py | 12 +- .../test_relationships_batching.py | 9 +- .../unit/test_gateway/test_config_receiver.py | 232 +++-- .../test_core_import_export_roundtrip.py | 32 +- .../test_gateway/test_dispatch_manager.py | 96 +- .../test_entity_contexts_import_dispatcher.py | 1 - ...test_graph_embeddings_import_dispatcher.py | 1 - .../test_rows_import_dispatcher.py | 1 - .../test_text_document_translator.py | 1 - tests/unit/test_knowledge_graph/conftest.py | 5 +- .../test_agent_extraction.py | 2 - .../test_object_extraction_logic.py | 4 +- .../test_triple_construction.py | 2 - .../test_librarian/test_chunked_upload.py | 68 +- .../test_provenance/test_dag_structure.py | 12 +- .../test_doc_embeddings_milvus_query.py | 40 +- .../test_doc_embeddings_pinecone_query.py | 28 +- .../test_doc_embeddings_qdrant_query.py | 20 +- .../test_graph_embeddings_milvus_query.py | 34 +- .../test_graph_embeddings_pinecone_query.py | 22 +- .../test_graph_embeddings_qdrant_query.py | 16 +- ...st_memgraph_workspace_collection_query.py} | 119 ++- ... 
test_neo4j_workspace_collection_query.py} | 127 ++- .../test_query/test_rows_cassandra_query.py | 50 +- .../test_triples_cassandra_query.py | 50 +- .../test_query/test_triples_falkordb_query.py | 27 +- .../test_query/test_triples_memgraph_query.py | 27 +- .../test_query/test_triples_neo4j_query.py | 12 +- .../test_metadata_preservation.py | 24 +- .../test_null_embedding_protection.py | 36 +- .../unit/test_retrieval/test_document_rag.py | 27 +- .../test_document_rag_service.py | 24 +- tests/unit/test_retrieval/test_graph_rag.py | 32 +- .../test_graph_rag_explain_forwarding.py | 1 - .../test_retrieval/test_graph_rag_service.py | 3 - tests/unit/test_retrieval/test_nlp_query.py | 10 +- .../test_schema_selection.py | 3 +- .../test_retrieval/test_structured_query.py | 2 - .../test_doc_embeddings_milvus_storage.py | 73 +- .../test_doc_embeddings_pinecone_storage.py | 35 +- .../test_doc_embeddings_qdrant_storage.py | 30 +- .../test_graph_embeddings_milvus_storage.py | 33 +- .../test_graph_embeddings_pinecone_storage.py | 28 +- .../test_graph_embeddings_qdrant_storage.py | 12 +- ...emgraph_workspace_collection_isolation.py} | 269 +++--- ...t_neo4j_workspace_collection_isolation.py} | 394 ++++---- .../test_row_embeddings_qdrant_storage.py | 53 +- .../test_rows_cassandra_storage.py | 121 +-- .../test_triples_cassandra_storage.py | 41 +- .../test_triples_falkordb_storage.py | 93 +- .../test_triples_memgraph_storage.py | 89 +- .../test_triples_neo4j_storage.py | 101 +-- .../test_row_embeddings_query.py | 54 +- .../test_tables/test_knowledge_table_store.py | 17 +- ...ocument_embeddings_translator_roundtrip.py | 2 - .../test_knowledge_translator_roundtrip.py | 9 +- trustgraph-base/trustgraph/api/__init__.py | 3 +- trustgraph-base/trustgraph/api/api.py | 33 +- trustgraph-base/trustgraph/api/async_flow.py | 60 +- .../trustgraph/api/async_socket_client.py | 30 +- trustgraph-base/trustgraph/api/bulk_client.py | 12 +- trustgraph-base/trustgraph/api/collection.py | 101 +-- 
trustgraph-base/trustgraph/api/config.py | 58 +- .../trustgraph/api/explainability.py | 89 +- trustgraph-base/trustgraph/api/flow.py | 242 +---- trustgraph-base/trustgraph/api/knowledge.py | 92 +- trustgraph-base/trustgraph/api/library.py | 168 ++-- .../trustgraph/api/socket_client.py | 37 +- trustgraph-base/trustgraph/api/types.py | 17 +- .../trustgraph/base/async_processor.py | 172 ++-- .../trustgraph/base/chunking_service.py | 5 +- .../base/collection_config_handler.py | 129 +-- .../trustgraph/base/config_client.py | 38 +- .../trustgraph/base/consumer_spec.py | 5 +- .../base/document_embeddings_client.py | 4 +- .../base/document_embeddings_query_service.py | 4 +- .../base/document_embeddings_store_service.py | 3 +- .../trustgraph/base/dynamic_tool_service.py | 11 +- trustgraph-base/trustgraph/base/flow.py | 7 +- .../trustgraph/base/flow_processor.py | 69 +- .../base/graph_embeddings_client.py | 4 +- .../base/graph_embeddings_query_service.py | 4 +- .../base/graph_embeddings_store_service.py | 3 +- .../trustgraph/base/graph_rag_client.py | 4 +- .../trustgraph/base/librarian_client.py | 24 +- .../trustgraph/base/request_response_spec.py | 5 +- .../base/row_embeddings_query_client.py | 3 +- .../base/structured_query_client.py | 3 +- .../trustgraph/base/subscriber_spec.py | 2 +- .../trustgraph/base/tool_service.py | 1 + .../trustgraph/base/tool_service_client.py | 8 +- .../trustgraph/base/triples_client.py | 7 +- .../trustgraph/base/triples_query_service.py | 12 +- .../trustgraph/base/triples_store_service.py | 5 +- .../trustgraph/clients/config_client.py | 29 + .../clients/document_embeddings_client.py | 4 +- .../trustgraph/clients/document_rag_client.py | 5 +- .../clients/graph_embeddings_client.py | 4 +- .../trustgraph/clients/graph_rag_client.py | 5 +- .../clients/row_embeddings_client.py | 4 +- .../clients/triples_query_client.py | 5 +- .../trustgraph/messaging/translators/agent.py | 2 - .../messaging/translators/collection.py | 8 +- 
.../messaging/translators/config.py | 19 +- .../messaging/translators/document_loading.py | 12 - .../messaging/translators/embeddings_query.py | 6 - .../trustgraph/messaging/translators/flow.py | 5 +- .../messaging/translators/knowledge.py | 12 +- .../messaging/translators/library.py | 6 +- .../messaging/translators/metadata.py | 12 +- .../messaging/translators/retrieval.py | 4 - .../messaging/translators/rows_query.py | 2 - .../messaging/translators/sparql_query.py | 2 - .../messaging/translators/structured_query.py | 6 +- .../messaging/translators/triples.py | 4 +- .../trustgraph/schema/core/metadata.py | 5 +- .../trustgraph/schema/knowledge/knowledge.py | 6 +- .../trustgraph/schema/services/agent.py | 2 - .../trustgraph/schema/services/collection.py | 13 +- .../trustgraph/schema/services/config.py | 40 +- .../trustgraph/schema/services/flow.py | 4 +- .../trustgraph/schema/services/library.py | 14 +- .../trustgraph/schema/services/query.py | 4 - .../trustgraph/schema/services/retrieval.py | 2 - .../trustgraph/schema/services/rows_query.py | 1 - .../schema/services/sparql_query.py | 1 - .../schema/services/structured_query.py | 1 - .../schema/services/tool_service.py | 2 - trustgraph-cli/pyproject.toml | 2 + .../trustgraph/cli/add_library_document.py | 30 +- .../trustgraph/cli/delete_collection.py | 35 +- .../trustgraph/cli/delete_config_item.py | 13 +- .../trustgraph/cli/delete_flow_blueprint.py | 21 +- .../trustgraph/cli/delete_kg_core.py | 33 +- .../trustgraph/cli/delete_mcp_tool.py | 24 +- trustgraph-cli/trustgraph/cli/delete_tool.py | 24 +- .../trustgraph/cli/export_workspace_config.py | 114 +++ .../trustgraph/cli/get_config_item.py | 13 +- .../trustgraph/cli/get_document_content.py | 18 +- .../trustgraph/cli/get_flow_blueprint.py | 20 +- trustgraph-cli/trustgraph/cli/get_kg_core.py | 33 +- .../trustgraph/cli/graph_to_turtle.py | 23 +- .../trustgraph/cli/import_workspace_config.py | 143 +++ .../trustgraph/cli/init_trustgraph.py | 22 +- 
trustgraph-cli/trustgraph/cli/invoke_agent.py | 36 +- .../cli/invoke_document_embeddings.py | 20 +- .../trustgraph/cli/invoke_document_rag.py | 45 +- .../trustgraph/cli/invoke_embeddings.py | 13 +- .../trustgraph/cli/invoke_graph_embeddings.py | 20 +- .../trustgraph/cli/invoke_graph_rag.py | 102 +-- trustgraph-cli/trustgraph/cli/invoke_llm.py | 12 +- .../trustgraph/cli/invoke_mcp_tool.py | 20 +- .../trustgraph/cli/invoke_nlp_query.py | 22 +- .../trustgraph/cli/invoke_prompt.py | 12 +- .../trustgraph/cli/invoke_row_embeddings.py | 20 +- .../trustgraph/cli/invoke_rows_query.py | 31 +- .../trustgraph/cli/invoke_sparql_query.py | 28 +- .../trustgraph/cli/invoke_structured_query.py | 32 +- .../trustgraph/cli/list_collections.py | 38 +- .../trustgraph/cli/list_config_items.py | 13 +- .../trustgraph/cli/list_explain_traces.py | 12 +- .../trustgraph/cli/load_doc_embeds.py | 16 +- trustgraph-cli/trustgraph/cli/load_kg_core.py | 36 +- .../trustgraph/cli/load_knowledge.py | 38 +- .../trustgraph/cli/load_sample_documents.py | 29 +- .../trustgraph/cli/load_structured_data.py | 63 +- trustgraph-cli/trustgraph/cli/load_turtle.py | 29 +- .../trustgraph/cli/put_config_item.py | 13 +- .../trustgraph/cli/put_flow_blueprint.py | 13 +- trustgraph-cli/trustgraph/cli/put_kg_core.py | 43 +- trustgraph-cli/trustgraph/cli/query_graph.py | 26 +- .../trustgraph/cli/remove_library_document.py | 28 +- .../trustgraph/cli/save_doc_embeds.py | 15 +- .../trustgraph/cli/set_collection.py | 25 +- trustgraph-cli/trustgraph/cli/set_mcp_tool.py | 12 +- trustgraph-cli/trustgraph/cli/set_prompt.py | 15 +- .../trustgraph/cli/set_token_costs.py | 15 +- trustgraph-cli/trustgraph/cli/set_tool.py | 12 +- trustgraph-cli/trustgraph/cli/show_config.py | 12 +- .../trustgraph/cli/show_explain_trace.py | 59 +- .../cli/show_extraction_provenance.py | 40 +- .../trustgraph/cli/show_flow_blueprints.py | 23 +- .../trustgraph/cli/show_flow_state.py | 17 +- trustgraph-cli/trustgraph/cli/show_flows.py | 28 +- 
trustgraph-cli/trustgraph/cli/show_graph.py | 21 +- .../trustgraph/cli/show_kg_cores.py | 21 +- .../trustgraph/cli/show_library_documents.py | 16 +- .../trustgraph/cli/show_library_processing.py | 29 +- .../trustgraph/cli/show_mcp_tools.py | 13 +- trustgraph-cli/trustgraph/cli/show_prompts.py | 13 +- .../trustgraph/cli/show_token_costs.py | 13 +- trustgraph-cli/trustgraph/cli/show_tools.py | 13 +- trustgraph-cli/trustgraph/cli/start_flow.py | 13 +- .../cli/start_library_processing.py | 38 +- trustgraph-cli/trustgraph/cli/stop_flow.py | 12 +- .../trustgraph/cli/stop_library_processing.py | 30 +- .../trustgraph/cli/unload_kg_core.py | 24 +- .../trustgraph/cli/verify_system_status.py | 31 +- .../trustgraph/agent/mcp_tool/service.py | 34 +- .../agent/orchestrator/aggregator.py | 3 +- .../agent/orchestrator/pattern_base.py | 67 +- .../agent/orchestrator/plan_pattern.py | 32 +- .../agent/orchestrator/react_pattern.py | 15 +- .../trustgraph/agent/orchestrator/service.py | 48 +- .../agent/orchestrator/supervisor_pattern.py | 19 +- .../trustgraph/agent/react/service.py | 79 +- .../trustgraph/agent/react/tools.py | 33 +- .../trustgraph/chunking/recursive/chunker.py | 6 +- .../trustgraph/chunking/token/chunker.py | 6 +- .../trustgraph/config/service/config.py | 154 ++-- .../trustgraph/config/service/service.py | 9 +- trustgraph-flow/trustgraph/cores/knowledge.py | 29 +- trustgraph-flow/trustgraph/cores/service.py | 12 +- .../decoding/mistral_ocr/processor.py | 8 +- .../trustgraph/decoding/pdf/pdf_decoder.py | 8 +- .../direct/milvus_doc_embeddings.py | 52 +- .../direct/milvus_graph_embeddings.py | 52 +- .../embeddings/row_embeddings/embeddings.py | 59 +- .../trustgraph/extract/kg/agent/extract.py | 45 +- .../extract/kg/definitions/extract.py | 2 - .../trustgraph/extract/kg/ontology/extract.py | 130 ++- .../extract/kg/relationships/extract.py | 1 - .../trustgraph/extract/kg/rows/processor.py | 64 +- .../trustgraph/flow/service/flow.py | 349 ++++--- 
.../trustgraph/flow/service/service.py | 7 +- .../trustgraph/gateway/config/receiver.py | 125 ++- .../gateway/dispatch/core_export.py | 6 +- .../gateway/dispatch/core_import.py | 8 +- .../gateway/dispatch/document_stream.py | 8 +- .../dispatch/entity_contexts_import.py | 1 - .../dispatch/graph_embeddings_import.py | 1 - .../trustgraph/gateway/dispatch/manager.py | 64 +- .../trustgraph/gateway/dispatch/mux.py | 24 +- .../gateway/dispatch/rows_import.py | 1 - .../trustgraph/gateway/dispatch/serialize.py | 16 +- .../gateway/dispatch/triples_import.py | 1 - .../librarian/collection_manager.py | 55 +- .../trustgraph/librarian/librarian.py | 62 +- .../trustgraph/librarian/service.py | 30 +- .../trustgraph/metering/counter.py | 33 +- .../trustgraph/prompt/template/service.py | 56 +- .../query/doc_embeddings/milvus/service.py | 4 +- .../query/doc_embeddings/pinecone/service.py | 4 +- .../query/doc_embeddings/qdrant/service.py | 4 +- .../query/graph_embeddings/milvus/service.py | 4 +- .../graph_embeddings/pinecone/service.py | 4 +- .../query/graph_embeddings/qdrant/service.py | 4 +- .../trustgraph/query/graphql/schema.py | 6 +- .../query/ontology/query_explanation.py | 4 +- .../query/ontology/query_service.py | 2 +- .../query/ontology/question_analyzer.py | 2 +- .../query/row_embeddings/qdrant/service.py | 20 +- .../query/rows/cassandra/service.py | 72 +- .../trustgraph/query/sparql/algebra.py | 86 +- .../trustgraph/query/sparql/service.py | 2 +- .../query/triples/cassandra/service.py | 26 +- .../query/triples/falkordb/service.py | 2 +- .../query/triples/memgraph/service.py | 133 ++- .../trustgraph/query/triples/neo4j/service.py | 136 ++- .../retrieval/document_rag/document_rag.py | 15 +- .../trustgraph/retrieval/document_rag/rag.py | 21 +- .../retrieval/graph_rag/graph_rag.py | 49 +- .../trustgraph/retrieval/graph_rag/rag.py | 33 +- .../trustgraph/retrieval/nlp_query/service.py | 67 +- .../retrieval/structured_diag/service.py | 67 +- 
.../retrieval/structured_query/service.py | 4 +- .../storage/doc_embeddings/milvus/write.py | 20 +- .../storage/doc_embeddings/pinecone/write.py | 22 +- .../storage/doc_embeddings/qdrant/write.py | 22 +- .../storage/graph_embeddings/milvus/write.py | 20 +- .../graph_embeddings/pinecone/write.py | 22 +- .../storage/graph_embeddings/qdrant/write.py | 22 +- .../trustgraph/storage/knowledge/store.py | 4 +- .../storage/row_embeddings/qdrant/write.py | 41 +- .../storage/rows/cassandra/write.py | 95 +- .../storage/triples/cassandra/write.py | 48 +- .../storage/triples/falkordb/write.py | 106 ++- .../storage/triples/memgraph/write.py | 140 ++- .../trustgraph/storage/triples/neo4j/write.py | 119 ++- trustgraph-flow/trustgraph/tables/config.py | 79 +- .../trustgraph/tables/knowledge.py | 74 +- trustgraph-flow/trustgraph/tables/library.py | 139 ++- .../trustgraph/tool_service/joke/service.py | 11 +- trustgraph-mcp/trustgraph/mcp_server/mcp.py | 175 ++-- .../trustgraph/decoding/ocr/pdf_decoder.py | 8 +- .../decoding/universal/processor.py | 11 +- 377 files changed, 6868 insertions(+), 5785 deletions(-) create mode 100644 docs/tech-specs/data-ownership-model.md create mode 100644 docs/tech-specs/iam.md delete mode 100644 specs/api/components/parameters/User.yaml rename tests/unit/test_query/{test_memgraph_user_collection_query.py => test_memgraph_workspace_collection_query.py} (76%) rename tests/unit/test_query/{test_neo4j_user_collection_query.py => test_neo4j_workspace_collection_query.py} (75%) rename tests/unit/test_storage/{test_memgraph_user_collection_isolation.py => test_memgraph_workspace_collection_isolation.py} (53%) rename tests/unit/test_storage/{test_neo4j_user_collection_isolation.py => test_neo4j_workspace_collection_isolation.py} (51%) create mode 100644 trustgraph-cli/trustgraph/cli/export_workspace_config.py create mode 100644 trustgraph-cli/trustgraph/cli/import_workspace_config.py diff --git a/.gitignore b/.gitignore index daeba074..32942156 100644 --- 
a/.gitignore +++ b/.gitignore @@ -15,4 +15,5 @@ trustgraph-parquet/trustgraph/parquet_version.py trustgraph-vertexai/trustgraph/vertexai_version.py trustgraph-unstructured/trustgraph/unstructured_version.py trustgraph-mcp/trustgraph/mcp_version.py +trustgraph/trustgraph/trustgraph_version.py vertexai/ \ No newline at end of file diff --git a/Makefile b/Makefile index 85f10fdd..0f0f37b2 100644 --- a/Makefile +++ b/Makefile @@ -57,7 +57,7 @@ container-bedrock container-vertexai \ container-hf container-ocr \ container-unstructured container-mcp -some-containers: container-base container-flow +some-containers: container-base container-flow container-unstructured push: ${DOCKER} push ${CONTAINER_BASE}/trustgraph-base:${VERSION} diff --git a/docs/tech-specs/data-ownership-model.md b/docs/tech-specs/data-ownership-model.md new file mode 100644 index 00000000..ea94ec46 --- /dev/null +++ b/docs/tech-specs/data-ownership-model.md @@ -0,0 +1,309 @@ +--- +layout: default +title: "Data Ownership and Information Separation" +parent: "Tech Specs" +--- + +# Data Ownership and Information Separation + +## Purpose + +This document defines the logical ownership model for data in +TrustGraph: what the artefacts are, who owns them, and how they relate +to each other. + +The IAM spec ([iam.md](iam.md)) describes authentication and +authorisation mechanics. This spec addresses the prior question: what +are the boundaries around data, and who owns what? + +## Concepts + +### Workspace + +A workspace is the primary isolation boundary. It represents an +organisation, team, or independent operating unit. All data belongs to +exactly one workspace. Cross-workspace access is never permitted through +the API. + +A workspace owns: +- Source documents +- Flows (processing pipeline definitions) +- Knowledge cores (stored extraction output) +- Collections (organisational units for extracted knowledge) + +### Collection + +A collection is an organisational unit within a workspace. 
It groups +extracted knowledge produced from source documents. A workspace can +have multiple collections, allowing: + +- Processing the same documents with different parameters or models. +- Maintaining separate knowledge bases for different purposes. +- Deleting extracted knowledge without deleting source documents. + +Collections do not own source documents. A source document exists at the +workspace level and can be processed into multiple collections. + +### Source document + +A source document (PDF, text file, etc.) is raw input uploaded to the +system. Documents belong to the workspace, not to a specific collection. + +This is intentional. A document is an asset that exists independently +of how it is processed. The same PDF might be processed into multiple +collections with different chunking parameters or extraction models. +Tying a document to a single collection would force re-upload for each +collection. + +### Flow + +A flow defines a processing pipeline: which models to use, what +parameters to apply (chunk size, temperature, etc.), and how processing +services are connected. Flows belong to a workspace. + +The processing services themselves (document-decoder, chunker, +embeddings, LLM completion, etc.) are shared infrastructure — they serve +all workspaces. Each flow has its own queues, keeping data from +different workspaces and flows separate as it moves through the +pipeline. + +Different workspaces can define different flows. Workspace A might use +GPT-5.2 with a chunk size of 2000, while workspace B uses Claude with a +chunk size of 1000. + +### Prompts + +Prompts are templates that control how the LLM behaves during knowledge +extraction and query answering. They belong to a workspace, allowing +different workspaces to have different extraction strategies, response +styles, or domain-specific instructions. 
+ +### Ontology + +An ontology defines the concepts, entities, and relationships that the +extraction pipeline looks for in source documents. Ontologies belong to +a workspace. A medical workspace might define ontologies around diseases, +symptoms, and treatments, while a legal workspace defines ontologies +around statutes, precedents, and obligations. + +### Schemas + +Schemas define structured data types for extraction. They specify what +fields to extract, their types, and how they relate. Schemas belong to +a workspace, as different workspaces extract different structured +information from their documents. + +### Tools, tool services, and MCP tools + +Tools define capabilities available to agents: what actions they can +take, what external services they can call. Tool services configure how +tools connect to backend services. MCP tools configure connections to +remote MCP servers, including authentication tokens. All belong to a +workspace. + +### Agent patterns and agent task types + +Agent patterns define agent behaviour strategies (how an agent reasons, +what steps it follows). Agent task types define the kinds of tasks +agents can perform. Both belong to a workspace, as different workspaces +may have different agent configurations. + +### Token costs + +Token cost definitions specify pricing for LLM token usage per model. +These belong to a workspace since different workspaces may use different +models or have different billing arrangements. + +### Flow blueprints + +Flow blueprints are templates for creating flows. They define the +default pipeline structure and parameters. Blueprints belong to a +workspace, allowing workspaces to define custom processing templates. + +### Parameter types + +Parameter types define the kinds of parameters that flows accept (e.g. +"llm-model", "temperature"), including their defaults and validation +rules. 
They belong to a workspace since workspaces that define custom +flows need to define the parameter types those flows use. + +### Interface descriptions + +Interface descriptions define the connection points of a flow — what +queues and topics it uses. They belong to a workspace since they +describe workspace-owned flows. + +### Knowledge core + +A knowledge core is a stored snapshot of extracted knowledge (triples +and graph embeddings). Knowledge cores belong to a workspace and can be +loaded into any collection within that workspace. + +Knowledge cores serve as a portable extraction output. You process +documents through a flow, the pipeline produces triples and embeddings, +and the results can be stored as a knowledge core. That core can later +be loaded into a different collection or reloaded after a collection is +cleared. + +### Extracted knowledge + +Extracted knowledge is the live, queryable content within a collection: +triples in the knowledge graph, graph embeddings, and document +embeddings. It is the product of processing source documents through a +flow into a specific collection. + +Extracted knowledge is scoped to a workspace and a collection. It +cannot exist without both. + +### Processing record + +A processing record tracks which source document was processed, through +which flow, into which collection. It links the source document +(workspace-scoped) to the extracted knowledge (workspace + collection +scoped). + +## Ownership summary + +| Artefact | Owned by | Shared across collections? 
| +|----------|----------|---------------------------| +| Workspaces | Global (platform) | N/A | +| User accounts | Global (platform) | N/A | +| API keys | Global (platform) | N/A | +| Source documents | Workspace | Yes | +| Flows | Workspace | N/A | +| Flow blueprints | Workspace | N/A | +| Prompts | Workspace | N/A | +| Ontologies | Workspace | N/A | +| Schemas | Workspace | N/A | +| Tools | Workspace | N/A | +| Tool services | Workspace | N/A | +| MCP tools | Workspace | N/A | +| Agent patterns | Workspace | N/A | +| Agent task types | Workspace | N/A | +| Token costs | Workspace | N/A | +| Parameter types | Workspace | N/A | +| Interface descriptions | Workspace | N/A | +| Knowledge cores | Workspace | Yes — can be loaded into any collection | +| Collections | Workspace | N/A | +| Extracted knowledge | Workspace + collection | No | +| Processing records | Workspace + collection | No | + +## Scoping summary + +### Global (system-level) + +A small number of artefacts exist outside any workspace: + +- **Workspace registry** — the list of workspaces itself +- **User accounts** — users reference a workspace but are not owned by + one +- **API keys** — belong to users, not workspaces + +These are managed by the IAM layer and exist at the platform level. 
+ +### Workspace-owned + +All other configuration and data is workspace-owned: + +- Flow definitions and parameters +- Flow blueprints +- Prompts +- Ontologies +- Schemas +- Tools, tool services, and MCP tools +- Agent patterns and agent task types +- Token costs +- Parameter types +- Interface descriptions +- Collection definitions +- Knowledge cores +- Source documents +- Collections and their extracted knowledge + +## Relationship between artefacts + +``` +Platform (global) + | + +-- Workspaces + | | + +-- User accounts (each assigned to a workspace) + | | + +-- API keys (belong to users) + +Workspace + | + +-- Source documents (uploaded, unprocessed) + | + +-- Flows (pipeline definitions: models, parameters, queues) + | + +-- Flow blueprints (templates for creating flows) + | + +-- Prompts (LLM instruction templates) + | + +-- Ontologies (entity and relationship definitions) + | + +-- Schemas (structured data type definitions) + | + +-- Tools, tool services, MCP tools (agent capabilities) + | + +-- Agent patterns and agent task types (agent behaviour) + | + +-- Token costs (LLM pricing per model) + | + +-- Parameter types (flow parameter definitions) + | + +-- Interface descriptions (flow connection points) + | + +-- Knowledge cores (stored extraction snapshots) + | + +-- Collections + | + +-- Extracted knowledge (triples, embeddings) + | + +-- Processing records (links documents to collections) +``` + +A typical workflow: + +1. A source document is uploaded to the workspace. +2. A flow defines how to process it (which models, what parameters). +3. The document is processed through the flow into a collection. +4. Processing records track what was processed. +5. Extracted knowledge (triples, embeddings) is queryable within the + collection. +6. Optionally, the extracted knowledge is stored as a knowledge core + for later reuse. 
+ +## Implementation notes + +The current codebase uses a `user` field in message metadata and storage +partition keys to identify the workspace. The `collection` field +identifies the collection within that workspace. The IAM spec describes +how the gateway maps authenticated credentials to a workspace identity +and sets these fields. + +For details on how each storage backend implements this scoping, see: + +- [Entity-Centric Graph](entity-centric-graph.md) — Cassandra KG schema +- [Neo4j User Collection Isolation](neo4j-user-collection-isolation.md) +- [Collection Management](collection-management.md) + +### Known inconsistencies in current implementation + +- **Pipeline intermediate tables** do not include collection in their + partition keys. Re-processing the same document into a different + collection may overwrite intermediate state. +- **Processing metadata** stores collection in the row payload but not + in the partition key, making collection-based queries inefficient. +- **Upload sessions** are keyed by upload ID, not workspace. The + gateway should validate workspace ownership before allowing + operations on upload sessions. + +## References + +- [Identity and Access Management](iam.md) +- [Collection Management](collection-management.md) +- [Entity-Centric Graph](entity-centric-graph.md) +- [Neo4j User Collection Isolation](neo4j-user-collection-isolation.md) +- [Multi-Tenant Support](multi-tenant-support.md) diff --git a/docs/tech-specs/flow-class-definition.md b/docs/tech-specs/flow-class-definition.md index 94229b72..3a81bf71 100644 --- a/docs/tech-specs/flow-class-definition.md +++ b/docs/tech-specs/flow-class-definition.md @@ -20,8 +20,8 @@ Defines shared service processors that are instantiated once per flow blueprint. 
```json "class": { "service-name:{class}": { - "request": "queue-pattern:{class}", - "response": "queue-pattern:{class}", + "request": "queue-pattern:{workspace}:{class}", + "response": "queue-pattern:{workspace}:{class}", "settings": { "setting-name": "fixed-value", "parameterized-setting": "{parameter-name}" @@ -31,11 +31,11 @@ Defines shared service processors that are instantiated once per flow blueprint. ``` **Characteristics:** -- Shared across all flow instances of the same class +- Shared across all flow instances of the same class within a workspace - Typically expensive or stateless services (LLMs, embedding models) -- Use `{class}` template variable for queue naming +- Use `{workspace}` and `{class}` template variables for queue naming - Settings can be fixed values or parameterized with `{parameter-name}` syntax -- Examples: `embeddings:{class}`, `text-completion:{class}`, `graph-rag:{class}` +- Examples: `embeddings:{workspace}:{class}`, `text-completion:{workspace}:{class}` ### 2. Flow Section Defines flow-specific processors that are instantiated for each individual flow instance. Each flow gets its own isolated set of these processors. 
@@ -43,8 +43,8 @@ Defines flow-specific processors that are instantiated for each individual flow ```json "flow": { "processor-name:{id}": { - "input": "queue-pattern:{id}", - "output": "queue-pattern:{id}", + "input": "queue-pattern:{workspace}:{id}", + "output": "queue-pattern:{workspace}:{id}", "settings": { "setting-name": "fixed-value", "parameterized-setting": "{parameter-name}" @@ -56,9 +56,9 @@ Defines flow-specific processors that are instantiated for each individual flow **Characteristics:** - Unique instance per flow - Handle flow-specific data and state -- Use `{id}` template variable for queue naming +- Use `{workspace}` and `{id}` template variables for queue naming - Settings can be fixed values or parameterized with `{parameter-name}` syntax -- Examples: `chunker:{id}`, `pdf-decoder:{id}`, `kg-extract-relationships:{id}` +- Examples: `chunker:{workspace}:{id}`, `pdf-decoder:{workspace}:{id}` ### 3. Interfaces Section Defines the entry points and interaction contracts for the flow. These form the API surface for external systems and internal component communication. 
@@ -68,8 +68,8 @@ Interfaces can take two forms: **Fire-and-Forget Pattern** (single queue): ```json "interfaces": { - "document-load": "persistent://tg/flow/document-load:{id}", - "triples-store": "persistent://tg/flow/triples-store:{id}" + "document-load": "persistent://tg/flow/{workspace}:document-load:{id}", + "triples-store": "persistent://tg/flow/{workspace}:triples-store:{id}" } ``` @@ -77,8 +77,8 @@ Interfaces can take two forms: ```json "interfaces": { "embeddings": { - "request": "non-persistent://tg/request/embeddings:{class}", - "response": "non-persistent://tg/response/embeddings:{class}" + "request": "non-persistent://tg/request/{workspace}:embeddings:{class}", + "response": "non-persistent://tg/response/{workspace}:embeddings:{class}" } } ``` @@ -117,6 +117,16 @@ Additional information about the flow blueprint: ### System Variables +#### {workspace} +- Replaced with the workspace identifier +- Isolates queue names between workspaces so that two workspaces + starting the same flow do not share queues +- Must be included in all queue name patterns to ensure workspace + isolation +- Example: `ws-acme`, `ws-globex` +- All blueprint templates must include `{workspace}` in queue name + patterns + #### {id} - Replaced with the unique flow instance identifier - Creates isolated resources for each flow diff --git a/docs/tech-specs/iam.md b/docs/tech-specs/iam.md new file mode 100644 index 00000000..5de50749 --- /dev/null +++ b/docs/tech-specs/iam.md @@ -0,0 +1,858 @@ +--- +layout: default +title: "Identity and Access Management" +parent: "Tech Specs" +--- + +# Identity and Access Management + +## Problem Statement + +TrustGraph has no meaningful identity or access management. The system +relies on a single shared gateway token for authentication and an +honour-system `user` query parameter for data isolation. This creates +several problems: + +- **No user identity.** There are no user accounts, no login, and no way + to know who is making a request. 
The `user` field in message metadata
+  is a caller-supplied string with no validation — any client can claim
+  to be any user.
+
+- **No access control.** A valid gateway token grants unrestricted access
+  to every endpoint, every user's data, every collection, and every
+  administrative operation. There is no way to limit what an
+  authenticated caller can do.
+
+- **No credential isolation.** All callers share one static token. There
+  is no per-user credential, no token expiration, and no rotation
+  mechanism. Revoking access means changing the shared token, which
+  affects all callers.
+
+- **Data isolation is unenforced.** Storage backends (Cassandra, Neo4j,
+  Qdrant) filter queries by `user` and `collection`, but the gateway
+  does not prevent a caller from specifying another user's identity.
+  Cross-user data access is trivial.
+
+- **No audit trail.** There is no logging of who accessed what. Without
+  user identity, audit logging is impossible.
+
+These gaps make the system unsuitable for multi-user deployments,
+multi-tenant SaaS, or any environment where access needs to be
+controlled or audited.
+
+## Current State
+
+### Authentication
+
+The API gateway supports a single shared token configured via the
+`GATEWAY_SECRET` environment variable or `--api-token` CLI argument. If
+unset, authentication is disabled entirely. When enabled, every HTTP
+endpoint requires an `Authorization: Bearer <token>` header. WebSocket
+connections pass the token as a query parameter.
+
+Implementation: `trustgraph-flow/trustgraph/gateway/auth.py`
+
+```python
+class Authenticator:
+    def __init__(self, token=None, allow_all=False):
+        self.token = token
+        self.allow_all = allow_all
+
+    def permitted(self, token, roles):
+        if self.allow_all: return True
+        if self.token != token: return False
+        return True
+```
+
+The `roles` parameter is accepted but never evaluated. All authenticated
+requests have identical privileges.
+ +MCP tool configurations support an optional per-tool `auth-token` for +service-to-service authentication with remote MCP servers. These are +static, system-wide tokens — not per-user credentials. See +[mcp-tool-bearer-token.md](mcp-tool-bearer-token.md) for details. + +### User identity + +The `user` field is passed explicitly by the caller as a query parameter +(e.g. `?user=trustgraph`) or set by CLI tools. It flows through the +system in the core `Metadata` dataclass: + +```python +@dataclass +class Metadata: + id: str = "" + root: str = "" + user: str = "" + collection: str = "" +``` + +There is no user registration, login, user database, or session +management. + +### Data isolation + +The `user` + `collection` pair is used at the storage layer to partition +data: + +- **Cassandra**: queries filter by `user` and `collection` columns +- **Neo4j**: queries filter by `user` and `collection` properties +- **Qdrant**: vector search filters by `user` and `collection` metadata + +| Layer | Isolation mechanism | Enforced by | +|-------|-------------------|-------------| +| Gateway | Single shared token | `Authenticator` class | +| Message metadata | `user` + `collection` fields | Caller (honour system) | +| Cassandra | Column filters on `user`, `collection` | Query layer | +| Neo4j | Property filters on `user`, `collection` | Query layer | +| Qdrant | Metadata filters on `user`, `collection` | Query layer | +| Pub/sub topics | Per-flow topic namespacing | Flow service | + +The storage-layer isolation depends on all queries correctly filtering by +`user` and `collection`. There is no gateway-level enforcement preventing +a caller from querying another user's data by passing a different `user` +parameter. 
+ +### Configuration and secrets + +| Setting | Source | Default | Purpose | +|---------|--------|---------|---------| +| `GATEWAY_SECRET` | Env var | Empty (auth disabled) | Gateway bearer token | +| `--api-token` | CLI arg | None | Gateway bearer token (overrides env) | +| `PULSAR_API_KEY` | Env var | None | Pub/sub broker auth | +| MCP `auth-token` | Config service | None | Per-tool MCP server auth | + +No secrets are encrypted at rest. The gateway token and MCP tokens are +stored and transmitted in plaintext (aside from any transport-layer +encryption such as TLS). + +### Capabilities that do not exist + +- Per-user authentication (JWT, OAuth, SAML, API keys per user) +- User accounts or user management +- Role-based access control (RBAC) +- Attribute-based access control (ABAC) +- Per-user or per-workspace API keys +- Token expiration or rotation +- Session management +- Per-user rate limiting +- Audit logging of user actions +- Permission checks preventing cross-user data access +- Multi-workspace credential isolation + +### Key files + +| File | Purpose | +|------|---------| +| `trustgraph-flow/trustgraph/gateway/auth.py` | Authenticator class | +| `trustgraph-flow/trustgraph/gateway/service.py` | Gateway init, token config | +| `trustgraph-flow/trustgraph/gateway/endpoint/*.py` | Per-endpoint auth checks | +| `trustgraph-base/trustgraph/schema/core/metadata.py` | `Metadata` dataclass with `user` field | + +## Technical Design + +### Design principles + +- **Auth at the edge.** The gateway is the single enforcement point. + Internal services trust the gateway and do not re-authenticate. + This avoids distributing credential validation across dozens of + microservices. + +- **Identity from credentials, not from callers.** The gateway derives + user identity from authentication credentials. Callers can no longer + self-declare their identity via query parameters. + +- **Workspace isolation by default.** Every authenticated user belongs to + a workspace. 
All data operations are scoped to that workspace. + Cross-workspace access is not possible through the API. + +- **Extensible API contract.** The API accepts an optional workspace + parameter on every request. This allows the same protocol to support + single-workspace deployments today and multi-workspace extensions in + the future without breaking changes. + +- **Simple roles, not fine-grained permissions.** A small number of + predefined roles controls what operations a user can perform. This is + sufficient for the current API surface and avoids the complexity of + per-resource permission management. + +### Authentication + +The gateway supports two credential types. Both are carried as a Bearer +token in the `Authorization` header for HTTP requests. The gateway +distinguishes them by format. + +For WebSocket connections, credentials are not passed in the URL or +headers. Instead, the client authenticates after connecting by sending +an auth message as the first frame: + +``` +Client: opens WebSocket to /api/v1/socket +Server: accepts connection (unauthenticated state) +Client: sends {"type": "auth", "token": "tg_abc123..."} +Server: validates token + success → {"type": "auth-ok", "workspace": "acme"} + failure → {"type": "auth-failed", "error": "invalid token"} +``` + +The server rejects all non-auth messages until authentication succeeds. +The socket remains open on auth failure, allowing the client to retry +with a different token without reconnecting. The client can also send +a new auth message at any time to re-authenticate — for example, to +refresh an expiring JWT or to switch workspace. The +resolved identity (user, workspace, roles) is updated on each +successful auth. + +#### API keys + +For programmatic access: CLI tools, scripts, and integrations. + +- Opaque tokens (e.g. `tg_a1b2c3d4e5f6...`). Not JWTs — short, + simple, easy to paste into CLI tools and headers. +- Each user has one or more API keys. 
+
+- Keys are stored hashed (SHA-256 with salt) in the IAM service. The
+  plaintext key is returned once at creation time and cannot be
+  retrieved afterwards.
+- Keys can be revoked individually without affecting other users.
+- Keys optionally have an expiry date. Expired keys are rejected.
+
+On each request, the gateway resolves an API key by:
+
+1. Hashing the token.
+2. Checking a local cache (hash → user/workspace/roles).
+3. On cache miss, calling the IAM service to resolve.
+4. Caching the result with a short TTL (e.g. 60 seconds).
+
+Revoked keys stop working when the cache entry expires. No push
+invalidation is needed.
+
+#### JWTs (login sessions)
+
+For interactive access via the UI or WebSocket connections.
+
+- A user logs in with username and password. The gateway forwards the
+  request to the IAM service, which validates the credentials and
+  returns a signed JWT.
+- The JWT carries the user ID, workspace, and roles as claims.
+- The gateway validates JWTs locally using the IAM service's public
+  signing key — no service call needed on subsequent requests.
+- Token expiry is enforced by standard JWT validation at the time the
+  request (or WebSocket connection) is made.
+- For long-lived WebSocket connections, the JWT is validated at connect
+  time only. The connection remains authenticated for its lifetime.
+
+The IAM service manages the signing key. The gateway fetches the public
+key at startup (or on first JWT encounter) and caches it.
+
+#### Login endpoint
+
+```
+POST /api/v1/auth/login
+{
+  "username": "alice",
+  "password": "..."
+}
+→ {
+  "token": "eyJ...",
+  "expires": "2026-04-20T19:00:00Z"
+}
+```
+
+The gateway forwards this to the IAM service, which validates
+credentials and returns a signed JWT. The gateway returns the JWT to
+the caller.
+
+#### IAM service delegation
+
+The gateway stays thin. Its authentication logic is:
+
+1. Extract the Bearer token from the `Authorization` header (HTTP) or
+   from the first-frame auth message (WebSocket).
+2. 
If the token has JWT format (dotted structure), validate the + signature locally and extract claims. +3. Otherwise, treat as an API key: hash it and check the local cache. + On cache miss, call the IAM service to resolve. +4. If neither succeeds, return 401. + +All user management, key management, credential validation, and token +signing logic lives in the IAM service. The gateway is a generic +enforcement point that can be replaced without changing the IAM +service. + +#### No legacy token support + +The existing `GATEWAY_SECRET` shared token is removed. All +authentication uses API keys or JWTs. On first start, the bootstrap +process creates a default workspace and admin user with an initial API +key. + +### User identity + +A user belongs to exactly one workspace. The design supports extending +this to multi-workspace access in the future (see +[Extension points](#extension-points)). + +A user record contains: + +| Field | Type | Description | +|-------|------|-------------| +| `id` | string | Unique user identifier (UUID) | +| `name` | string | Display name | +| `email` | string | Email address (optional) | +| `workspace` | string | Workspace the user belongs to | +| `roles` | list[string] | Assigned roles (e.g. `["reader"]`) | +| `enabled` | bool | Whether the user can authenticate | +| `created` | datetime | Account creation timestamp | + +The `workspace` field maps to the existing `user` field in `Metadata`. +This means the storage-layer isolation (Cassandra, Neo4j, Qdrant +filtering by `user` + `collection`) works without changes — the gateway +sets the `user` metadata field to the authenticated user's workspace. + +### Workspaces + +A workspace is an isolated data boundary. Users belong to a workspace, +and all data operations are scoped to it. Workspaces map to the existing +`user` field in `Metadata` and the corresponding Cassandra keyspace, +Qdrant collection prefix, and Neo4j property filters. 
+ +| Field | Type | Description | +|-------|------|-------------| +| `id` | string | Unique workspace identifier | +| `name` | string | Display name | +| `enabled` | bool | Whether the workspace is active | +| `created` | datetime | Creation timestamp | + +All data operations are scoped to a workspace. The gateway determines +the effective workspace for each request as follows: + +1. If the request includes a `workspace` parameter, validate it against + the user's assigned workspace. + - If it matches, use it. + - If it does not match, return 403. (This could be extended to + check a workspace access grant list.) +2. If no `workspace` parameter is provided, use the user's assigned + workspace. + +The gateway sets the `user` field in `Metadata` to the effective +workspace ID, replacing the caller-supplied `?user=` query parameter. + +This design ensures forward compatibility. Clients that pass a +workspace parameter will work unchanged if multi-workspace support is +added later. Requests for an unassigned workspace get a clear 403 +rather than silent misbehaviour. + +### Roles and access control + +Three roles with fixed permissions: + +| Role | Data operations | Admin operations | System | +|------|----------------|-----------------|--------| +| `reader` | Query knowledge graph, embeddings, RAG | None | None | +| `writer` | All reader operations + load documents, manage collections | None | None | +| `admin` | All writer operations | Config, flows, collection management, user management | Metrics | + +Role checks happen at the gateway before dispatching to backend +services. 
Each endpoint declares the minimum role required:
+
+| Endpoint pattern | Minimum role |
+|-----------------|--------------|
+| `GET /api/v1/socket` (queries) | `reader` |
+| `POST /api/v1/librarian` | `writer` |
+| `POST /api/v1/flow/*/import/*` | `writer` |
+| `POST /api/v1/config` | `admin` |
+| `GET /api/v1/flow/*` | `admin` |
+| `GET /api/metrics` | `admin` |
+
+Roles are hierarchical: `admin` implies `writer`, which implies
+`reader`.
+
+### IAM service
+
+The IAM service is a new backend service that manages all identity and
+access data. It is the authority for users, workspaces, API keys, and
+credentials. The gateway delegates to it.
+
+#### Data model
+
+```
+iam_workspaces (
+  id text PRIMARY KEY,
+  name text,
+  enabled boolean,
+  created timestamp
+)
+
+iam_users (
+  id text PRIMARY KEY,
+  workspace text,
+  name text,
+  email text,
+  password_hash text,
+  roles set<text>,
+  enabled boolean,
+  created timestamp
+)
+
+iam_api_keys (
+  key_hash text PRIMARY KEY,
+  user_id text,
+  name text,
+  expires timestamp,
+  created timestamp
+)
+```
+
+A secondary index on `iam_api_keys.user_id` supports listing a user's
+keys.
+
+#### Responsibilities
+
+- User CRUD (create, list, update, disable)
+- Workspace CRUD (create, list, update, disable)
+- API key management (create, revoke, list)
+- API key resolution (hash → user/workspace/roles)
+- Credential validation (username/password → signed JWT)
+- JWT signing key management (initialise, rotate)
+- Bootstrap (create default workspace and admin user on first start)
+
+#### Communication
+
+The IAM service communicates via the standard request/response pub/sub
+pattern, the same as the config service. The gateway calls it to
+resolve API keys and to handle login requests. User management
+operations (create user, revoke key, etc.) also go through the IAM
+service.
+ +### Gateway changes + +The current `Authenticator` class is replaced with a thin authentication +middleware that delegates to the IAM service: + +For HTTP requests: + +1. Extract Bearer token from the `Authorization` header. +2. If the token has JWT format (dotted structure): + - Validate signature locally using the cached public key. + - Extract user ID, workspace, and roles from claims. +3. Otherwise, treat as an API key: + - Hash the token and check the local cache. + - On cache miss, call the IAM service to resolve. + - Cache the result (user/workspace/roles) with a short TTL. +4. If neither succeeds, return 401. +5. If the user or workspace is disabled, return 403. +6. Check the user's role against the endpoint's minimum role. If + insufficient, return 403. +7. Resolve the effective workspace: + - If the request includes a `workspace` parameter, validate it + against the user's assigned workspace. Return 403 on mismatch. + - If no `workspace` parameter, use the user's assigned workspace. +8. Set the `user` field in the request context to the effective + workspace ID. This propagates through `Metadata` to all downstream + services. + +For WebSocket connections: + +1. Accept the connection in an unauthenticated state. +2. Wait for an auth message (`{"type": "auth", "token": "..."}`). +3. Validate the token using the same logic as steps 2-7 above. +4. On success, attach the resolved identity to the connection and + send `{"type": "auth-ok", ...}`. +5. On failure, send `{"type": "auth-failed", ...}` but keep the + socket open. +6. Reject all non-auth messages until authentication succeeds. +7. Accept new auth messages at any time to re-authenticate. + +### CLI changes + +CLI tools authenticate with API keys: + +- `--api-key` argument on all CLI tools, replacing `--api-token`. +- `tg-create-workspace`, `tg-list-workspaces` for workspace management. +- `tg-create-user`, `tg-list-users`, `tg-disable-user` for user + management. 
+- `tg-create-api-key`, `tg-list-api-keys`, `tg-revoke-api-key` for + key management. +- `--workspace` argument on tools that operate on workspace-scoped + data. +- The API key is passed as a Bearer token in the same way as the + current shared token, so the transport protocol is unchanged. + +### Audit logging + +With user identity established, the gateway logs: + +- Timestamp, user ID, workspace, endpoint, HTTP method, response status. +- Audit logs are written to the standard logging output (structured + JSON). Integration with external log aggregation (Loki, ELK) is a + deployment concern, not an application concern. + +### Config service changes + +All configuration is workspace-scoped (see +[data-ownership-model.md](data-ownership-model.md)). The config service +needs to support this. + +#### Schema change + +The config table adds workspace as a key dimension: + +``` +config ( + workspace text, + class text, + key text, + value text, + PRIMARY KEY ((workspace, class), key) +) +``` + +#### Request format + +Config requests add a `workspace` field at the request level. The +existing `(type, key)` structure is unchanged within each workspace. + +**Get:** +```json +{ + "operation": "get", + "workspace": "workspace-a", + "keys": [{"type": "prompt", "key": "rag-prompt"}] +} +``` + +**Put:** +```json +{ + "operation": "put", + "workspace": "workspace-a", + "values": [{"type": "prompt", "key": "rag-prompt", "value": "..."}] +} +``` + +**List (all keys of a type within a workspace):** +```json +{ + "operation": "list", + "workspace": "workspace-a", + "type": "prompt" +} +``` + +**Delete:** +```json +{ + "operation": "delete", + "workspace": "workspace-a", + "keys": [{"type": "prompt", "key": "rag-prompt"}] +} +``` + +The workspace is set by: + +- **Gateway** — from the authenticated user's workspace for API-facing + requests. +- **Internal services** — explicitly, based on `Metadata.user` from + the message being processed, or `_system` for operational config. 
+ +#### System config namespace + +Processor-level operational config (logging levels, connection strings, +resource limits) is not workspace-specific. This stays in a reserved +`_system` workspace that is not associated with any user workspace. +Services read system config at startup without needing a workspace +context. + +#### Config change notifications + +The config notify mechanism pushes change notifications via pub/sub +when config is updated. A single update may affect multiple workspaces +and multiple config types. The notification message carries a dict of +changes keyed by config type, with each value being the list of +affected workspaces: + +```json +{ + "version": 42, + "changes": { + "prompt": ["workspace-a", "workspace-b"], + "schema": ["workspace-a"] + } +} +``` + +System config changes use the reserved `_system` workspace: + +```json +{ + "version": 43, + "changes": { + "logging": ["_system"] + } +} +``` + +This structure is keyed by type because handlers register by type. A +handler registered for `prompt` looks up `"prompt"` directly and gets +the list of affected workspaces — no iteration over unrelated types. + +#### Config change handlers + +The current `on_config` hook mechanism needs two modes to support shared +processing services: + +- **Workspace-scoped handlers** — notify when a config type changes in a + specific workspace. The handler looks up its registered type in the + changes dict and checks if its workspace is in the list. Used by the + gateway and by services that serve a single workspace. + +- **Global handlers** — notify when a config type changes in any + workspace. The handler looks up its registered type in the changes + dict and gets the full list of affected workspaces. Used by shared + processing services (prompt-rag, agent manager, etc.) that serve all + workspaces. Each workspace in the list tells the handler which cache + entry to update rather than reloading everything. 
+ +#### Per-workspace config caching + +Shared services that handle messages from multiple workspaces maintain a +per-workspace config cache. When a message arrives, the service looks up +the config for the workspace identified in `Metadata.user`. If the +workspace is not yet cached, the service fetches its config on demand. +Config change notifications update the relevant cache entry. + +### Flow and queue isolation + +Flows are workspace-owned. When two workspaces start flows with the same +name and blueprint, their queues must be separate to prevent data +mixing. + +Flow blueprint templates currently use `{id}` (flow instance ID) and +`{class}` (blueprint name) as template variables in queue names. A new +`{workspace}` variable is added so queue names include the workspace: + +**Current queue names (no workspace isolation):** +``` +flow:tg:document-load:{id} → flow:tg:document-load:default +request:tg:embeddings:{class} → request:tg:embeddings:everything +``` + +**With workspace isolation:** +``` +flow:tg:{workspace}:document-load:{id} → flow:tg:ws-a:document-load:default +request:tg:{workspace}:embeddings:{class} → request:tg:ws-a:embeddings:everything +``` + +The flow service substitutes `{workspace}` from the authenticated +workspace when starting a flow, the same way it substitutes `{id}` and +`{class}` today. + +Processing services are shared infrastructure — they consume from +workspace-specific queues but are not themselves workspace-aware. The +workspace is carried in `Metadata.user` on every message, so services +know which workspace's data they are processing. + +Blueprint templates need updating to include `{workspace}` in all queue +name patterns. For migration, the flow service can inject the workspace +into queue names automatically if the template does not include +`{workspace}`, defaulting to the legacy behaviour for existing +blueprints. + +See [flow-class-definition.md](flow-class-definition.md) for the full +blueprint template specification. 
+ +### What changes and what doesn't + +**Changes:** + +| Component | Change | +|-----------|--------| +| `gateway/auth.py` | Replace `Authenticator` with new auth middleware | +| `gateway/service.py` | Initialise IAM client, configure JWT validation | +| `gateway/endpoint/*.py` | Add role requirement per endpoint | +| Metadata propagation | Gateway sets `user` from workspace, ignores query param | +| Config service | Add workspace dimension to config schema | +| Config table | `PRIMARY KEY ((workspace, class), key)` | +| Config request/response schema | Add `workspace` field | +| Config notify messages | Include workspace ID in change notifications | +| `on_config` handlers | Support workspace-scoped and global modes | +| Shared services | Per-workspace config caching | +| Flow blueprints | Add `{workspace}` template variable to queue names | +| Flow service | Substitute `{workspace}` when starting flows | +| CLI tools | New user management commands, `--api-key` argument | +| Cassandra schema | New `iam_workspaces`, `iam_users`, `iam_api_keys` tables | + +**Does not change:** + +| Component | Reason | +|-----------|--------| +| Internal service-to-service pub/sub | Services trust the gateway | +| `Metadata` dataclass | `user` field continues to carry workspace identity | +| Storage-layer isolation | Same `user` + `collection` filtering | +| Message serialisation | No schema changes | + +### Migration + +This is a breaking change. Existing deployments must be reconfigured: + +1. `GATEWAY_SECRET` is removed. Authentication requires API keys or + JWT login tokens. +2. The `?user=` query parameter is removed. Workspace identity comes + from authentication. +3. On first start, the IAM service bootstraps a default workspace and + admin user. The initial API key is output to the service log. +4. Operators create additional workspaces and users via CLI tools. +5. Flow blueprints must be updated to include `{workspace}` in queue + name patterns. +6. 
Config data must be migrated to include the workspace dimension. + +## Extension points + +The design includes deliberate extension points for future capabilities. +These are not implemented but the architecture does not preclude them: + +- **Multi-workspace access.** Users could be granted access to + additional workspaces beyond their primary assignment. The workspace + validation step checks a grant list instead of a single assignment. +- **Rules-based access control.** A separate access control service + could evaluate fine-grained policies (per-collection permissions, + operation-level restrictions, time-based access). The gateway + delegates authorisation decisions to this service. +- **External identity provider integration.** SAML, LDAP, and OIDC + flows (group mapping, claims-based role assignment) could be added + to the IAM service. +- **Cross-workspace administration.** A `superadmin` role for platform + operators who manage multiple workspaces. +- **Delegated workspace provisioning.** APIs for programmatic workspace + creation and user onboarding. + +These extensions are additive — they extend the validation logic +without changing the request/response protocol. The gateway can be +replaced with an alternative implementation that supports these +capabilities while the IAM service and backend services remain +unchanged. + +## Implementation plan + +Workspace support is a prerequisite for auth — users are assigned to +workspaces, config is workspace-scoped, and flows use workspace in +queue names. Implementing workspaces first allows the structural changes +to be tested end-to-end without auth complicating debugging. + +### Phase 1: Workspace support (no auth) + +All workspace-scoped data and processing changes. The system works with +workspaces but no authentication — callers pass workspace as a +parameter, honour system. This allows full end-to-end testing: multiple +workspaces with separate flows, config, queues, and data. 
+ +#### Config service + +- Update config client API to accept a workspace parameter on all + requests +- Update config storage schema to add workspace as a key dimension +- Update config notification API to report changes as a dict of + type → workspace list +- Update the processor base class to understand workspaces in config + notifications (workspace-scoped and global handler modes) +- Update all processors to implement workspace-aware config handling + (per-workspace config caching, on-demand fetch) + +#### Flow and queue isolation + +- Update flow blueprints to include `{workspace}` in all queue name + patterns +- Update the flow service to substitute `{workspace}` when starting + flows +- Update all built-in blueprints to include `{workspace}` + +#### CLI tools (workspace support) + +- Add `--workspace` argument to CLI tools that operate on + workspace-scoped data +- Add `tg-create-workspace`, `tg-list-workspaces` commands + +### Phase 2: Authentication and access control + +With workspaces working, add the IAM service and lock down the gateway. 
+ +#### IAM service + +A new service handling identity and access management on behalf of the +API gateway: + +- Add workspace table support (CRUD, enable/disable) +- Add user table support (CRUD, enable/disable, workspace assignment) +- Add roles support (role assignment, role validation) +- Add API key support (create, revoke, list, hash storage) +- Add ability to initialise a JWT signing key for token grants +- Add token grant endpoint: user/password login returns a signed JWT +- Add bootstrap/initialisation mechanism: ability to set the signing + key and create the initial workspace + admin user on first start + +#### API gateway integration + +- Add IAM middleware to the API gateway replacing the current + `Authenticator` +- Add local JWT validation (public key from IAM service) +- Add API key resolution with local cache (hash → user/workspace/roles, + cache miss calls IAM service, short TTL) +- Add login endpoint forwarding to IAM service +- Add workspace resolution: validate requested workspace against user + assignment +- Add role-based endpoint access checks +- Add user management API endpoints (forwarded to IAM service) +- Add audit logging (user ID, workspace, endpoint, method, status) +- WebSocket auth via first-message protocol (auth message after + connect, socket stays open on failure, re-auth supported) + +#### CLI tools (auth support) + +- Add `tg-create-user`, `tg-list-users`, `tg-disable-user` commands +- Add `tg-create-api-key`, `tg-list-api-keys`, `tg-revoke-api-key` + commands +- Replace `--api-token` with `--api-key` on existing CLI tools + +#### Bootstrap and cutover + +- Create default workspace and admin user on first start if IAM tables + are empty +- Remove `GATEWAY_SECRET` and `?user=` query parameter support + +## Design Decisions + +### IAM data store + +IAM data is stored in dedicated Cassandra tables owned by the IAM +service, not in the config service. 
Reasons: + +- **Security isolation.** The config service has a broad, generic + protocol. An access control failure on the config service could + expose credentials. A dedicated IAM service with a purpose-built + protocol limits the attack surface and makes security auditing + clearer. +- **Data model fit.** IAM needs indexed lookups (API key hash → user, + list keys by user). The config service's `(workspace, type, key) → + value` model stores opaque JSON strings with no secondary indexes. +- **Scope.** IAM data is global (workspaces, users, keys). Config is + workspace-scoped. Mixing global and workspace-scoped data in the + same store adds complexity. +- **Audit.** IAM operations (key creation, revocation, login attempts) + are security events that should be logged separately from general + config changes. + +## Deferred to future design + +- **OIDC integration.** External identity provider support (SAML, LDAP, + OIDC) is left for future implementation. The extension points section + describes where this fits architecturally. +- **API key scoping.** API keys could be scoped to specific collections + within a workspace rather than granting workspace-wide access. To be + designed when the need arises. +- **tg-init-trustgraph** only initialises a single workspace. 
+ +## References + +- [Data Ownership and Information Separation](data-ownership-model.md) +- [MCP Tool Bearer Token Specification](mcp-tool-bearer-token.md) +- [Multi-Tenant Support Specification](multi-tenant-support.md) +- [Neo4j User Collection Isolation](neo4j-user-collection-isolation.md) diff --git a/specs/api/components/parameters/User.yaml b/specs/api/components/parameters/User.yaml deleted file mode 100644 index ad0657ca..00000000 --- a/specs/api/components/parameters/User.yaml +++ /dev/null @@ -1,8 +0,0 @@ -name: user -in: query -required: false -schema: - type: string - default: trustgraph -description: User identifier -example: alice diff --git a/specs/api/components/schemas/agent/AgentRequest.yaml b/specs/api/components/schemas/agent/AgentRequest.yaml index ddf2019a..26703402 100644 --- a/specs/api/components/schemas/agent/AgentRequest.yaml +++ b/specs/api/components/schemas/agent/AgentRequest.yaml @@ -43,15 +43,6 @@ properties: type: string description: Result of the action example: "Paris is the capital of France" - user: - type: string - description: User context for this step - example: alice - user: - type: string - description: User identifier for multi-tenancy - default: trustgraph - example: alice streaming: type: boolean description: Enable streaming response delivery diff --git a/specs/api/components/schemas/collection/CollectionRequest.yaml b/specs/api/components/schemas/collection/CollectionRequest.yaml index bf3ab7d4..e1dc8338 100644 --- a/specs/api/components/schemas/collection/CollectionRequest.yaml +++ b/specs/api/components/schemas/collection/CollectionRequest.yaml @@ -14,14 +14,9 @@ properties: - delete-collection description: | Collection operation: - - `list-collections`: List collections for user + - `list-collections`: List collections in workspace - `update-collection`: Create or update collection metadata - `delete-collection`: Delete collection - user: - type: string - description: User identifier - default: trustgraph - 
example: alice collection: type: string description: Collection identifier (for update, delete) diff --git a/specs/api/components/schemas/collection/CollectionResponse.yaml b/specs/api/components/schemas/collection/CollectionResponse.yaml index f924cbf5..d65a7274 100644 --- a/specs/api/components/schemas/collection/CollectionResponse.yaml +++ b/specs/api/components/schemas/collection/CollectionResponse.yaml @@ -12,13 +12,8 @@ properties: items: type: object required: - - user - collection properties: - user: - type: string - description: User identifier - example: alice collection: type: string description: Collection identifier diff --git a/specs/api/components/schemas/embeddings-query/DocumentEmbeddingsQueryRequest.yaml b/specs/api/components/schemas/embeddings-query/DocumentEmbeddingsQueryRequest.yaml index f2d0aec2..b6e9dcb3 100644 --- a/specs/api/components/schemas/embeddings-query/DocumentEmbeddingsQueryRequest.yaml +++ b/specs/api/components/schemas/embeddings-query/DocumentEmbeddingsQueryRequest.yaml @@ -17,11 +17,6 @@ properties: minimum: 1 maximum: 1000 example: 20 - user: - type: string - description: User identifier - default: trustgraph - example: alice collection: type: string description: Collection to search diff --git a/specs/api/components/schemas/embeddings-query/GraphEmbeddingsQueryRequest.yaml b/specs/api/components/schemas/embeddings-query/GraphEmbeddingsQueryRequest.yaml index 6cf60bbd..212eb3e2 100644 --- a/specs/api/components/schemas/embeddings-query/GraphEmbeddingsQueryRequest.yaml +++ b/specs/api/components/schemas/embeddings-query/GraphEmbeddingsQueryRequest.yaml @@ -17,11 +17,6 @@ properties: minimum: 1 maximum: 1000 example: 20 - user: - type: string - description: User identifier - default: trustgraph - example: alice collection: type: string description: Collection to search diff --git a/specs/api/components/schemas/embeddings-query/RowEmbeddingsQueryRequest.yaml 
b/specs/api/components/schemas/embeddings-query/RowEmbeddingsQueryRequest.yaml index 916b4beb..51111a94 100644 --- a/specs/api/components/schemas/embeddings-query/RowEmbeddingsQueryRequest.yaml +++ b/specs/api/components/schemas/embeddings-query/RowEmbeddingsQueryRequest.yaml @@ -27,11 +27,6 @@ properties: minimum: 1 maximum: 1000 example: 20 - user: - type: string - description: User identifier - default: trustgraph - example: alice collection: type: string description: Collection to search diff --git a/specs/api/components/schemas/knowledge/KnowledgeRequest.yaml b/specs/api/components/schemas/knowledge/KnowledgeRequest.yaml index 5c40e118..8be57dd6 100644 --- a/specs/api/components/schemas/knowledge/KnowledgeRequest.yaml +++ b/specs/api/components/schemas/knowledge/KnowledgeRequest.yaml @@ -18,17 +18,12 @@ properties: - unload-kg-core description: | Knowledge core operation: - - `list-kg-cores`: List knowledge cores for user + - `list-kg-cores`: List knowledge cores in workspace - `get-kg-core`: Get knowledge core by ID - `put-kg-core`: Store triples and/or embeddings - `delete-kg-core`: Delete knowledge core by ID - `load-kg-core`: Load knowledge core into flow - `unload-kg-core`: Unload knowledge core from flow - user: - type: string - description: User identifier (for list-kg-cores, put-kg-core, delete-kg-core) - default: trustgraph - example: alice id: type: string description: Knowledge core ID (for get, put, delete, load, unload) @@ -53,17 +48,12 @@ properties: type: object required: - id - - user - collection properties: id: type: string description: Knowledge core ID example: core-123 - user: - type: string - description: User identifier - example: alice collection: type: string description: Collection identifier @@ -89,17 +79,12 @@ properties: type: object required: - id - - user - collection properties: id: type: string description: Knowledge core ID example: core-123 - user: - type: string - description: User identifier - example: alice collection: 
type: string description: Collection identifier diff --git a/specs/api/components/schemas/knowledge/KnowledgeResponse.yaml b/specs/api/components/schemas/knowledge/KnowledgeResponse.yaml index 229233ca..b0e4d6bb 100644 --- a/specs/api/components/schemas/knowledge/KnowledgeResponse.yaml +++ b/specs/api/components/schemas/knowledge/KnowledgeResponse.yaml @@ -15,17 +15,12 @@ properties: type: object required: - id - - user - collection properties: id: type: string description: Knowledge core ID example: core-123 - user: - type: string - description: User identifier - example: alice collection: type: string description: Collection identifier @@ -48,17 +43,12 @@ properties: type: object required: - id - - user - collection properties: id: type: string description: Knowledge core ID example: core-123 - user: - type: string - description: User identifier - example: alice collection: type: string description: Collection identifier diff --git a/specs/api/components/schemas/librarian/LibrarianRequest.yaml b/specs/api/components/schemas/librarian/LibrarianRequest.yaml index eed999f0..25dca7e2 100644 --- a/specs/api/components/schemas/librarian/LibrarianRequest.yaml +++ b/specs/api/components/schemas/librarian/LibrarianRequest.yaml @@ -62,11 +62,6 @@ properties: description: Collection identifier default: default example: default - user: - type: string - description: User identifier - default: trustgraph - example: alice document-id: type: string description: Document identifier diff --git a/specs/api/components/schemas/loading/DocumentLoadRequest.yaml b/specs/api/components/schemas/loading/DocumentLoadRequest.yaml index 45bbe428..8d9a996f 100644 --- a/specs/api/components/schemas/loading/DocumentLoadRequest.yaml +++ b/specs/api/components/schemas/loading/DocumentLoadRequest.yaml @@ -15,11 +15,6 @@ properties: type: string description: Document identifier example: doc-456 - user: - type: string - description: User identifier - default: trustgraph - example: alice collection: 
type: string description: Collection for document diff --git a/specs/api/components/schemas/loading/TextLoadRequest.yaml b/specs/api/components/schemas/loading/TextLoadRequest.yaml index 447308d4..57f7ecc3 100644 --- a/specs/api/components/schemas/loading/TextLoadRequest.yaml +++ b/specs/api/components/schemas/loading/TextLoadRequest.yaml @@ -14,11 +14,6 @@ properties: type: string description: Document identifier example: doc-123 - user: - type: string - description: User identifier - default: trustgraph - example: alice collection: type: string description: Collection for document diff --git a/specs/api/components/schemas/query/RowsQueryRequest.yaml b/specs/api/components/schemas/query/RowsQueryRequest.yaml index 08f03ad3..611864e8 100644 --- a/specs/api/components/schemas/query/RowsQueryRequest.yaml +++ b/specs/api/components/schemas/query/RowsQueryRequest.yaml @@ -28,11 +28,6 @@ properties: type: string description: Operation name (for multi-operation documents) example: GetPerson - user: - type: string - description: User identifier - default: trustgraph - example: alice collection: type: string description: Collection to query diff --git a/specs/api/components/schemas/query/StructuredQueryRequest.yaml b/specs/api/components/schemas/query/StructuredQueryRequest.yaml index ae564c0a..00bc75cb 100644 --- a/specs/api/components/schemas/query/StructuredQueryRequest.yaml +++ b/specs/api/components/schemas/query/StructuredQueryRequest.yaml @@ -10,11 +10,6 @@ properties: type: string description: Natural language question example: Who does Alice know that works in engineering? 
- user: - type: string - description: User identifier - default: trustgraph - example: alice collection: type: string description: Collection to query diff --git a/specs/api/components/schemas/query/TriplesQueryRequest.yaml b/specs/api/components/schemas/query/TriplesQueryRequest.yaml index d49e0300..0efb1452 100644 --- a/specs/api/components/schemas/query/TriplesQueryRequest.yaml +++ b/specs/api/components/schemas/query/TriplesQueryRequest.yaml @@ -18,11 +18,6 @@ properties: minimum: 1 maximum: 100000 example: 100 - user: - type: string - description: User identifier - default: trustgraph - example: alice collection: type: string description: Collection to query diff --git a/specs/api/components/schemas/rag/DocumentRagRequest.yaml b/specs/api/components/schemas/rag/DocumentRagRequest.yaml index 97a9d2ff..92bc383b 100644 --- a/specs/api/components/schemas/rag/DocumentRagRequest.yaml +++ b/specs/api/components/schemas/rag/DocumentRagRequest.yaml @@ -9,11 +9,6 @@ properties: type: string description: User query or question example: What are the key findings in the research papers? - user: - type: string - description: User identifier for multi-tenancy - default: trustgraph - example: alice collection: type: string description: Collection to search within diff --git a/specs/api/components/schemas/rag/GraphRagRequest.yaml b/specs/api/components/schemas/rag/GraphRagRequest.yaml index 733dd7c1..754dcc92 100644 --- a/specs/api/components/schemas/rag/GraphRagRequest.yaml +++ b/specs/api/components/schemas/rag/GraphRagRequest.yaml @@ -9,11 +9,6 @@ properties: type: string description: User query or question example: What connections exist between quantum physics and computer science? 
- user: - type: string - description: User identifier for multi-tenancy - default: trustgraph - example: alice collection: type: string description: Collection to search within diff --git a/specs/api/paths/collection-management.yaml b/specs/api/paths/collection-management.yaml index 7dffd4e0..acd7ff9f 100644 --- a/specs/api/paths/collection-management.yaml +++ b/specs/api/paths/collection-management.yaml @@ -10,11 +10,10 @@ post: Collections are organizational units for grouping: - Documents in the librarian - Knowledge cores - - User data + - Workspace data Each collection has: - - **user**: Owner identifier - - **collection**: Unique collection ID + - **collection**: Unique collection ID (within the workspace) - **name**: Human-readable display name - **description**: Purpose and contents - **tags**: Labels for filtering and organization @@ -22,7 +21,7 @@ post: ## Operations ### list-collections - List all collections for a user. Optionally filter by tags and limit results. + List all collections in the workspace. Optionally filter by tags and limit results. Returns array of collection metadata. ### update-collection @@ -30,7 +29,7 @@ post: If it exists, metadata is updated. Allows setting name, description, and tags. ### delete-collection - Delete a collection by user and collection ID. This removes the metadata but + Delete a collection by collection ID. This removes the metadata but typically does not delete the associated data (documents, knowledge cores). 
operationId: collectionManagementService @@ -44,22 +43,19 @@ post: $ref: '../components/schemas/collection/CollectionRequest.yaml' examples: listCollections: - summary: List all collections for user + summary: List all collections in workspace value: operation: list-collections - user: alice listCollectionsFiltered: summary: List collections filtered by tags value: operation: list-collections - user: alice tag-filter: ["research", "AI"] limit: 50 updateCollection: summary: Create/update collection value: operation: update-collection - user: alice collection: research name: Research Papers description: Academic research papers on AI and ML @@ -69,7 +65,6 @@ post: summary: Delete collection value: operation: delete-collection - user: alice collection: research responses: '200': @@ -84,13 +79,11 @@ post: value: timestamp: "2024-01-15T10:30:00Z" collections: - - user: alice - collection: research + - collection: research name: Research Papers description: Academic research papers on AI and ML tags: ["research", "AI", "academic"] - - user: alice - collection: personal + - collection: personal name: Personal Documents description: Personal notes and documents tags: ["personal"] diff --git a/specs/api/paths/document-stream.yaml b/specs/api/paths/document-stream.yaml index 5f6a11a7..67aea0e1 100644 --- a/specs/api/paths/document-stream.yaml +++ b/specs/api/paths/document-stream.yaml @@ -8,7 +8,6 @@ get: ## Parameters - - `user`: User identifier (required) - `document-id`: Document IRI to retrieve (required) - `chunk-size`: Size of each response chunk in bytes (optional, default: 1MB) @@ -16,13 +15,6 @@ get: security: - bearerAuth: [] parameters: - - name: user - in: query - required: true - schema: - type: string - description: User identifier - example: trustgraph - name: document-id in: query required: true diff --git a/specs/api/paths/export-core.yaml b/specs/api/paths/export-core.yaml index e7dc06b0..7fddd024 100644 --- a/specs/api/paths/export-core.yaml +++ 
b/specs/api/paths/export-core.yaml @@ -23,7 +23,6 @@ get: "m": { // Metadata "i": "core-id", // Knowledge core ID "m": [...], // Metadata triples array - "u": "user", // User "c": "collection" // Collection }, "t": [...] // Triples array @@ -36,7 +35,6 @@ get: "m": { // Metadata "i": "core-id", "m": [...], - "u": "user", "c": "collection" }, "e": [ // Entities array @@ -56,7 +54,6 @@ get: ## Query Parameters - **id**: Knowledge core ID to export - - **user**: User identifier ## Streaming @@ -86,13 +83,6 @@ get: type: string description: Knowledge core ID to export example: core-123 - - name: user - in: query - required: true - schema: - type: string - description: User identifier - example: alice responses: '200': description: Export stream diff --git a/specs/api/paths/flow/agent.yaml b/specs/api/paths/flow/agent.yaml index 2cecf89c..a38b6a82 100644 --- a/specs/api/paths/flow/agent.yaml +++ b/specs/api/paths/flow/agent.yaml @@ -69,25 +69,21 @@ post: summary: Simple question value: question: What is the capital of France? - user: alice streamingQuestion: summary: Question with streaming enabled value: question: Explain quantum computing - user: alice streaming: true conversationWithHistory: summary: Multi-turn conversation value: question: And what about its population? 
- user: alice history: - thought: User is asking about the capital of France action: search arguments: query: "capital of France" observation: "Paris is the capital of France" - user: alice responses: '200': description: Successful response diff --git a/specs/api/paths/flow/document-embeddings.yaml b/specs/api/paths/flow/document-embeddings.yaml index dbab2f92..ba7344fe 100644 --- a/specs/api/paths/flow/document-embeddings.yaml +++ b/specs/api/paths/flow/document-embeddings.yaml @@ -75,7 +75,6 @@ post: value: vectors: [0.023, -0.142, 0.089, 0.234, -0.067, 0.156, 0.201, -0.178] limit: 10 - user: alice collection: research largeQuery: summary: Larger result set diff --git a/specs/api/paths/flow/document-load.yaml b/specs/api/paths/flow/document-load.yaml index 09ddc09f..97ca3f3f 100644 --- a/specs/api/paths/flow/document-load.yaml +++ b/specs/api/paths/flow/document-load.yaml @@ -88,14 +88,12 @@ post: value: data: JVBERi0xLjQKJeLjz9MKMSAwIG9iago8PC9UeXBlL0NhdGFsb2cvUGFnZXMgMiAwIFI+PmVuZG9iagoyIDAgb2JqCjw8L1R5cGUvUGFnZXMvS2lkc1szIDAgUl0vQ291bnQgMT4+ZW5kb2JqCg== id: doc-789 - user: alice collection: research withMetadata: summary: Load with metadata value: data: JVBERi0xLjQKJeLjz9MK... id: doc-101112 - user: bob collection: papers metadata: - s: {v: "doc-101112", e: false} diff --git a/specs/api/paths/flow/document-rag.yaml b/specs/api/paths/flow/document-rag.yaml index f91bfc27..891868a5 100644 --- a/specs/api/paths/flow/document-rag.yaml +++ b/specs/api/paths/flow/document-rag.yaml @@ -40,7 +40,6 @@ post: - Higher = more context but slower - Lower = faster but may miss relevant info - **collection**: Target specific document collection - - **user**: Multi-tenant isolation operationId: documentRagService security: @@ -64,13 +63,11 @@ post: summary: Basic document query value: query: What are the key findings in the research papers? 
- user: alice collection: research streamingQuery: summary: Streaming query value: query: Summarize the main conclusions - user: alice collection: research doc-limit: 15 streaming: true diff --git a/specs/api/paths/flow/graph-embeddings.yaml b/specs/api/paths/flow/graph-embeddings.yaml index 277659de..16c1f4b3 100644 --- a/specs/api/paths/flow/graph-embeddings.yaml +++ b/specs/api/paths/flow/graph-embeddings.yaml @@ -66,7 +66,6 @@ post: value: vectors: [0.023, -0.142, 0.089, 0.234, -0.067, 0.156, 0.201, -0.178] limit: 10 - user: alice collection: research largeQuery: summary: Larger result set diff --git a/specs/api/paths/flow/graph-rag.yaml b/specs/api/paths/flow/graph-rag.yaml index b3f087c6..e9ffadea 100644 --- a/specs/api/paths/flow/graph-rag.yaml +++ b/specs/api/paths/flow/graph-rag.yaml @@ -77,13 +77,11 @@ post: summary: Basic graph query value: query: What connections exist between quantum physics and computer science? - user: alice collection: research streamingQuery: summary: Streaming query with custom limits value: query: Trace the historical development of AI from Turing to modern LLMs - user: alice collection: research entity-limit: 40 triple-limit: 25 diff --git a/specs/api/paths/flow/row-embeddings.yaml b/specs/api/paths/flow/row-embeddings.yaml index 05837c06..9f9f5c4f 100644 --- a/specs/api/paths/flow/row-embeddings.yaml +++ b/specs/api/paths/flow/row-embeddings.yaml @@ -62,7 +62,6 @@ post: vectors: [0.023, -0.142, 0.089, 0.234, -0.067, 0.156, 0.201, -0.178] schema_name: customers limit: 10 - user: alice collection: sales filteredQuery: summary: Search specific index diff --git a/specs/api/paths/flow/rows.yaml b/specs/api/paths/flow/rows.yaml index d648c9db..e8b1573a 100644 --- a/specs/api/paths/flow/rows.yaml +++ b/specs/api/paths/flow/rows.yaml @@ -89,7 +89,6 @@ post: email } } - user: alice collection: research queryWithVariables: summary: Query with variables diff --git a/specs/api/paths/flow/sparql-query.yaml 
b/specs/api/paths/flow/sparql-query.yaml index 2f343488..b7970bd0 100644 --- a/specs/api/paths/flow/sparql-query.yaml +++ b/specs/api/paths/flow/sparql-query.yaml @@ -61,10 +61,6 @@ post: query: type: string description: SPARQL 1.1 query string - user: - type: string - default: trustgraph - description: User/keyspace identifier collection: type: string default: default @@ -78,7 +74,6 @@ post: summary: SELECT query value: query: "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10" - user: trustgraph collection: default askQuery: summary: ASK query diff --git a/specs/api/paths/flow/structured-query.yaml b/specs/api/paths/flow/structured-query.yaml index 6d4dfe87..7e963b5b 100644 --- a/specs/api/paths/flow/structured-query.yaml +++ b/specs/api/paths/flow/structured-query.yaml @@ -79,13 +79,11 @@ post: summary: Simple relationship question value: question: Who does Alice know? - user: alice collection: research complexQuestion: summary: Complex multi-hop question value: question: What companies employ engineers that Bob collaborates with? - user: bob collection: work filterQuestion: summary: Question with implicit filters diff --git a/specs/api/paths/flow/text-load.yaml b/specs/api/paths/flow/text-load.yaml index 08bfe47b..b97d249f 100644 --- a/specs/api/paths/flow/text-load.yaml +++ b/specs/api/paths/flow/text-load.yaml @@ -87,14 +87,12 @@ post: value: text: This is the document text... 
id: doc-123 - user: alice collection: research withMetadata: summary: Load with RDF metadata using base64 text value: text: UXVhbnR1bSBjb21wdXRpbmcgdXNlcyBxdWFudHVtIG1lY2hhbmljcyBwcmluY2lwbGVzLi4u id: doc-456 - user: alice collection: research metadata: - s: {v: "doc-456", e: false} diff --git a/specs/api/paths/flow/triples.yaml b/specs/api/paths/flow/triples.yaml index 5557ea5a..9683d9f7 100644 --- a/specs/api/paths/flow/triples.yaml +++ b/specs/api/paths/flow/triples.yaml @@ -81,7 +81,6 @@ post: s: v: https://example.com/person/alice e: true - user: alice collection: research limit: 100 allInstancesOfType: @@ -100,7 +99,6 @@ post: p: v: https://example.com/knows e: true - user: alice limit: 200 responses: '200': diff --git a/specs/api/paths/import-core.yaml b/specs/api/paths/import-core.yaml index 38c99bf0..633f5477 100644 --- a/specs/api/paths/import-core.yaml +++ b/specs/api/paths/import-core.yaml @@ -23,7 +23,6 @@ post: "m": { // Metadata "i": "core-id", // Knowledge core ID "m": [...], // Metadata triples array - "u": "user", // User "c": "collection" // Collection }, "t": [...] 
// Triples array @@ -36,7 +35,6 @@ post: "m": { // Metadata "i": "core-id", "m": [...], - "u": "user", "c": "collection" }, "e": [ // Entities array @@ -51,7 +49,6 @@ post: ## Query Parameters - **id**: Knowledge core ID - - **user**: User identifier ## Streaming @@ -77,13 +74,6 @@ post: type: string description: Knowledge core ID to import example: core-123 - - name: user - in: query - required: true - schema: - type: string - description: User identifier - example: alice requestBody: required: true content: diff --git a/specs/api/paths/knowledge.yaml b/specs/api/paths/knowledge.yaml index 71bba496..0fe1976f 100644 --- a/specs/api/paths/knowledge.yaml +++ b/specs/api/paths/knowledge.yaml @@ -12,12 +12,12 @@ post: - **Graph Embeddings**: Vector embeddings for entities - **Metadata**: Descriptive information about the knowledge - Each core has an ID, user, and collection for organization. + Each core has an ID and collection for organization (within the workspace). ## Operations ### list-kg-cores - List all knowledge cores for a user. Returns array of core IDs. + List all knowledge cores in the workspace. Returns array of core IDs. ### get-kg-core Retrieve a knowledge core by ID. Returns triples and/or graph embeddings. 
@@ -58,7 +58,6 @@ post: summary: List knowledge cores value: operation: list-kg-cores - user: alice getKnowledgeCore: summary: Get knowledge core value: @@ -71,7 +70,6 @@ post: triples: metadata: id: core-123 - user: alice collection: default metadata: - s: {v: "https://example.com/core-123", e: true} @@ -91,7 +89,6 @@ post: graph-embeddings: metadata: id: core-123 - user: alice collection: default metadata: [] entities: @@ -106,7 +103,6 @@ post: triples: metadata: id: core-456 - user: bob collection: research metadata: [] triples: @@ -116,7 +112,6 @@ post: graph-embeddings: metadata: id: core-456 - user: bob collection: research metadata: [] entities: @@ -127,7 +122,6 @@ post: value: operation: delete-kg-core id: core-123 - user: alice loadKnowledgeCore: summary: Load core into flow value: @@ -161,7 +155,6 @@ post: triples: metadata: id: core-123 - user: alice collection: default metadata: - s: {v: "https://example.com/core-123", e: true} @@ -177,7 +170,6 @@ post: graph-embeddings: metadata: id: core-123 - user: alice collection: default metadata: [] entities: diff --git a/specs/websocket/components/messages/requests/RowEmbeddingsRequest.yaml b/specs/websocket/components/messages/requests/RowEmbeddingsRequest.yaml index 8010417d..efe6412b 100644 --- a/specs/websocket/components/messages/requests/RowEmbeddingsRequest.yaml +++ b/specs/websocket/components/messages/requests/RowEmbeddingsRequest.yaml @@ -26,5 +26,4 @@ examples: vectors: [0.023, -0.142, 0.089, 0.234] schema_name: customers limit: 10 - user: trustgraph collection: default diff --git a/specs/websocket/components/messages/requests/SparqlQueryRequest.yaml b/specs/websocket/components/messages/requests/SparqlQueryRequest.yaml index 6954b539..bc32f33a 100644 --- a/specs/websocket/components/messages/requests/SparqlQueryRequest.yaml +++ b/specs/websocket/components/messages/requests/SparqlQueryRequest.yaml @@ -24,10 +24,6 @@ properties: query: type: string description: SPARQL 1.1 query string - user: - type: 
string - default: trustgraph - description: User/keyspace identifier collection: type: string default: default @@ -42,5 +38,4 @@ examples: flow: my-flow request: query: "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10" - user: trustgraph collection: default diff --git a/tests/contract/conftest.py b/tests/contract/conftest.py index 4fdfe83b..e93866e7 100644 --- a/tests/contract/conftest.py +++ b/tests/contract/conftest.py @@ -72,7 +72,6 @@ def sample_message_data(): }, "DocumentRagQuery": { "query": "What is artificial intelligence?", - "user": "test_user", "collection": "test_collection", "doc_limit": 10 }, @@ -95,7 +94,6 @@ def sample_message_data(): }, "Metadata": { "id": "test-doc-123", - "user": "test_user", "collection": "test_collection" }, "Term": { @@ -130,9 +128,8 @@ def invalid_message_data(): {}, # Missing required fields ], "DocumentRagQuery": [ - {"query": None, "user": "test", "collection": "test", "doc_limit": 10}, # Invalid query - {"query": "test", "user": None, "collection": "test", "doc_limit": 10}, # Invalid user - {"query": "test", "user": "test", "collection": "test", "doc_limit": -1}, # Invalid doc_limit + {"query": None, "collection": "test", "doc_limit": 10}, # Invalid query + {"query": "test", "collection": "test", "doc_limit": -1}, # Invalid doc_limit {"query": "test"}, # Missing required fields ], "Term": [ diff --git a/tests/contract/test_document_embeddings_contract.py b/tests/contract/test_document_embeddings_contract.py index b6d14124..c56e7b93 100644 --- a/tests/contract/test_document_embeddings_contract.py +++ b/tests/contract/test_document_embeddings_contract.py @@ -18,24 +18,18 @@ class TestDocumentEmbeddingsRequestContract: def test_request_schema_fields(self): """Test that DocumentEmbeddingsRequest has expected fields""" - # Create a request request = DocumentEmbeddingsRequest( vector=[0.1, 0.2, 0.3], limit=10, - user="test_user", collection="test_collection" ) - # Verify all expected fields exist assert hasattr(request, 'vector') 
assert hasattr(request, 'limit') - assert hasattr(request, 'user') assert hasattr(request, 'collection') - # Verify field values assert request.vector == [0.1, 0.2, 0.3] assert request.limit == 10 - assert request.user == "test_user" assert request.collection == "test_collection" def test_request_translator_decode(self): @@ -45,7 +39,6 @@ class TestDocumentEmbeddingsRequestContract: data = { "vector": [0.1, 0.2, 0.3, 0.4], "limit": 5, - "user": "custom_user", "collection": "custom_collection" } @@ -54,7 +47,6 @@ class TestDocumentEmbeddingsRequestContract: assert isinstance(result, DocumentEmbeddingsRequest) assert result.vector == [0.1, 0.2, 0.3, 0.4] assert result.limit == 5 - assert result.user == "custom_user" assert result.collection == "custom_collection" def test_request_translator_decode_with_defaults(self): @@ -63,7 +55,7 @@ class TestDocumentEmbeddingsRequestContract: data = { "vector": [0.1, 0.2] - # No limit, user, or collection provided + # No limit or collection provided } result = translator.decode(data) @@ -71,7 +63,6 @@ class TestDocumentEmbeddingsRequestContract: assert isinstance(result, DocumentEmbeddingsRequest) assert result.vector == [0.1, 0.2] assert result.limit == 10 # Default - assert result.user == "trustgraph" # Default assert result.collection == "default" # Default def test_request_translator_encode(self): @@ -81,7 +72,6 @@ class TestDocumentEmbeddingsRequestContract: request = DocumentEmbeddingsRequest( vector=[0.5, 0.6], limit=20, - user="test_user", collection="test_collection" ) @@ -90,7 +80,6 @@ class TestDocumentEmbeddingsRequestContract: assert isinstance(result, dict) assert result["vector"] == [0.5, 0.6] assert result["limit"] == 20 - assert result["user"] == "test_user" assert result["collection"] == "test_collection" @@ -219,7 +208,6 @@ class TestDocumentEmbeddingsMessageCompatibility: request_data = { "vector": [0.1, 0.2, 0.3], "limit": 5, - "user": "test_user", "collection": "test_collection" } diff --git 
a/tests/contract/test_message_contracts.py b/tests/contract/test_message_contracts.py index 6b7f82e7..59db99f6 100644 --- a/tests/contract/test_message_contracts.py +++ b/tests/contract/test_message_contracts.py @@ -132,7 +132,6 @@ class TestDocumentRagMessageContracts: # Test required fields query = DocumentRagQuery(**query_data) assert hasattr(query, 'query') - assert hasattr(query, 'user') assert hasattr(query, 'collection') assert hasattr(query, 'doc_limit') @@ -154,12 +153,10 @@ class TestDocumentRagMessageContracts: # Test valid query valid_query = DocumentRagQuery( query="What is AI?", - user="test_user", collection="test_collection", doc_limit=5 ) assert valid_query.query == "What is AI?" - assert valid_query.user == "test_user" assert valid_query.collection == "test_collection" assert valid_query.doc_limit == 5 @@ -400,7 +397,6 @@ class TestMetadataMessageContracts: metadata = Metadata(**metadata_data) assert metadata.id == "test-doc-123" - assert metadata.user == "test_user" assert metadata.collection == "test_collection" def test_error_schema_contract(self): @@ -491,7 +487,7 @@ class TestSchemaEvolutionContracts: required_fields = { "TextCompletionRequest": ["system", "prompt"], "TextCompletionResponse": ["error", "response", "model"], - "DocumentRagQuery": ["query", "user", "collection"], + "DocumentRagQuery": ["query", "collection"], "DocumentRagResponse": ["error", "response"], "AgentRequest": ["question", "history"], "AgentResponse": ["error"], diff --git a/tests/contract/test_orchestrator_contracts.py b/tests/contract/test_orchestrator_contracts.py index ab168ece..d0833297 100644 --- a/tests/contract/test_orchestrator_contracts.py +++ b/tests/contract/test_orchestrator_contracts.py @@ -18,7 +18,6 @@ class TestOrchestrationFieldContracts: def test_agent_request_orchestration_fields_roundtrip(self): req = AgentRequest( question="Test question", - user="testuser", collection="default", correlation_id="corr-123", parent_session_id="parent-sess", @@ 
-42,7 +41,6 @@ class TestOrchestrationFieldContracts: def test_agent_request_orchestration_fields_default_empty(self): req = AgentRequest( question="Test question", - user="testuser", ) assert req.correlation_id == "" @@ -82,7 +80,6 @@ class TestSubagentCompletionStepContract: ) req = AgentRequest( question="goal", - user="testuser", correlation_id="corr-123", history=[step], ) @@ -126,7 +123,6 @@ class TestSynthesisStepContract: req = AgentRequest( question="Original question", - user="testuser", pattern="supervisor", correlation_id="", session_id="parent-sess", diff --git a/tests/contract/test_rows_cassandra_contracts.py b/tests/contract/test_rows_cassandra_contracts.py index bf85b9fb..55a06751 100644 --- a/tests/contract/test_rows_cassandra_contracts.py +++ b/tests/contract/test_rows_cassandra_contracts.py @@ -22,7 +22,6 @@ class TestRowsCassandraContracts: # Create test object with all required fields test_metadata = Metadata( id="test-doc-001", - user="test_user", collection="test_collection", ) @@ -47,7 +46,6 @@ class TestRowsCassandraContracts: # Verify metadata structure assert hasattr(test_object.metadata, 'id') - assert hasattr(test_object.metadata, 'user') assert hasattr(test_object.metadata, 'collection') # Verify types @@ -150,7 +148,6 @@ class TestRowsCassandraContracts: original = ExtractedObject( metadata=Metadata( id="serial-001", - user="test_user", collection="test_coll", ), schema_name="test_schema", @@ -168,7 +165,6 @@ class TestRowsCassandraContracts: # Verify round-trip assert decoded.metadata.id == original.metadata.id - assert decoded.metadata.user == original.metadata.user assert decoded.metadata.collection == original.metadata.collection assert decoded.schema_name == original.schema_name assert decoded.values == original.values @@ -228,8 +224,7 @@ class TestRowsCassandraContracts: # Create test object test_obj = ExtractedObject( metadata=Metadata( - id="meta-001", - user="user123", # -> keyspace + id="meta-001", # -> keyspace 
collection="coll456", # -> partition key ), schema_name="table789", # -> table name @@ -242,7 +237,6 @@ class TestRowsCassandraContracts: # - metadata.user -> Cassandra keyspace # - schema_name -> Cassandra table # - metadata.collection -> Part of primary key - assert test_obj.metadata.user # Required for keyspace assert test_obj.schema_name # Required for table assert test_obj.metadata.collection # Required for partition key @@ -256,7 +250,6 @@ class TestRowsCassandraContractsBatch: # Create test object with multiple values in batch test_metadata = Metadata( id="batch-doc-001", - user="test_user", collection="test_collection", ) @@ -302,7 +295,6 @@ class TestRowsCassandraContractsBatch: """Test empty batch ExtractedObject contract""" test_metadata = Metadata( id="empty-batch-001", - user="test_user", collection="test_collection", ) @@ -324,7 +316,6 @@ class TestRowsCassandraContractsBatch: """Test single-item batch (backward compatibility) contract""" test_metadata = Metadata( id="single-batch-001", - user="test_user", collection="test_collection", ) @@ -353,7 +344,6 @@ class TestRowsCassandraContractsBatch: original = ExtractedObject( metadata=Metadata( id="batch-serial-001", - user="test_user", collection="test_coll", ), schema_name="test_schema", @@ -375,7 +365,6 @@ class TestRowsCassandraContractsBatch: # Verify round-trip for batch assert decoded.metadata.id == original.metadata.id - assert decoded.metadata.user == original.metadata.user assert decoded.metadata.collection == original.metadata.collection assert decoded.schema_name == original.schema_name assert len(decoded.values) == len(original.values) @@ -425,8 +414,7 @@ class TestRowsCassandraContractsBatch: # 3. 
Be stored in the same keyspace (user) test_metadata = Metadata( - id="partition-test-001", - user="consistent_user", # Same keyspace + id="partition-test-001", # Same keyspace collection="consistent_collection", # Same partition ) @@ -443,7 +431,6 @@ class TestRowsCassandraContractsBatch: ) # Verify consistency contract - assert batch_object.metadata.user # Must have user for keyspace assert batch_object.metadata.collection # Must have collection for partition key # Verify unique primary keys in batch diff --git a/tests/contract/test_rows_graphql_query_contracts.py b/tests/contract/test_rows_graphql_query_contracts.py index db796306..e5baa6e2 100644 --- a/tests/contract/test_rows_graphql_query_contracts.py +++ b/tests/contract/test_rows_graphql_query_contracts.py @@ -21,29 +21,25 @@ class TestRowsGraphQLQueryContracts: """Test RowsQueryRequest schema structure and required fields""" # Create test request with all required fields test_request = RowsQueryRequest( - user="test_user", collection="test_collection", query='{ customers { id name email } }', variables={"status": "active", "limit": "10"}, operation_name="GetCustomers" ) - + # Verify all required fields are present - assert hasattr(test_request, 'user') - assert hasattr(test_request, 'collection') + assert hasattr(test_request, 'collection') assert hasattr(test_request, 'query') assert hasattr(test_request, 'variables') assert hasattr(test_request, 'operation_name') - + # Verify field types - assert isinstance(test_request.user, str) assert isinstance(test_request.collection, str) assert isinstance(test_request.query, str) assert isinstance(test_request.variables, dict) assert isinstance(test_request.operation_name, str) - + # Verify content - assert test_request.user == "test_user" assert test_request.collection == "test_collection" assert "customers" in test_request.query assert test_request.variables["status"] == "active" @@ -53,15 +49,13 @@ class TestRowsGraphQLQueryContracts: """Test RowsQueryRequest 
with minimal required fields""" # Create request with only essential fields minimal_request = RowsQueryRequest( - user="user", collection="collection", query='{ test }', variables={}, operation_name="" ) - + # Verify minimal request is valid - assert minimal_request.user == "user" assert minimal_request.collection == "collection" assert minimal_request.query == '{ test }' assert minimal_request.variables == {} @@ -187,22 +181,20 @@ class TestRowsGraphQLQueryContracts: """Test that request/response can be serialized/deserialized correctly""" # Create original request original_request = RowsQueryRequest( - user="serialization_test", collection="test_data", query='{ orders(limit: 5) { id total customer { name } } }', variables={"limit": "5", "status": "active"}, operation_name="GetRecentOrders" ) - + # Test request serialization using Pulsar schema request_schema = AvroSchema(RowsQueryRequest) - + # Encode and decode request encoded_request = request_schema.encode(original_request) decoded_request = request_schema.decode(encoded_request) - + # Verify request round-trip - assert decoded_request.user == original_request.user assert decoded_request.collection == original_request.collection assert decoded_request.query == original_request.query assert decoded_request.variables == original_request.variables @@ -245,7 +237,7 @@ class TestRowsGraphQLQueryContracts: """Test supported GraphQL query formats""" # Test basic query basic_query = RowsQueryRequest( - user="test", collection="test", query='{ customers { id } }', + collection="test", query='{ customers { id } }', variables={}, operation_name="" ) assert "customers" in basic_query.query @@ -254,7 +246,7 @@ class TestRowsGraphQLQueryContracts: # Test query with variables parameterized_query = RowsQueryRequest( - user="test", collection="test", + collection="test", query='query GetCustomers($status: String, $limit: Int) { customers(status: $status, limit: $limit) { id name } }', variables={"status": "active", "limit": 
"10"}, operation_name="GetCustomers" @@ -266,7 +258,7 @@ class TestRowsGraphQLQueryContracts: # Test complex nested query nested_query = RowsQueryRequest( - user="test", collection="test", + collection="test", query=''' { customers(limit: 10) { @@ -297,7 +289,7 @@ class TestRowsGraphQLQueryContracts: # This test verifies the current contract, though ideally we'd support all JSON types variables_test = RowsQueryRequest( - user="test", collection="test", query='{ test }', + collection="test", query='{ test }', variables={ "string_var": "test_value", "numeric_var": "123", # Numbers as strings due to Map(String()) limitation @@ -318,22 +310,18 @@ class TestRowsGraphQLQueryContracts: def test_cassandra_context_fields_contract(self): """Test that request contains necessary fields for Cassandra operations""" - # Verify request has fields needed for Cassandra keyspace/table targeting + # Verify request has fields needed for partition key targeting request = RowsQueryRequest( - user="keyspace_name", # Maps to Cassandra keyspace collection="partition_collection", # Used in partition key query='{ objects { id } }', variables={}, operation_name="" ) - - # These fields are required for proper Cassandra operations - assert request.user # Required for keyspace identification - assert request.collection # Required for partition key - + + # Required for partition key + assert request.collection + # Verify field naming follows TrustGraph patterns (matching other query services) - # This matches TriplesQueryRequest, DocumentEmbeddingsRequest patterns - assert hasattr(request, 'user') # Same as TriplesQueryRequest.user - assert hasattr(request, 'collection') # Same as TriplesQueryRequest.collection + assert hasattr(request, 'collection') def test_graphql_extensions_contract(self): """Test GraphQL extensions field format and usage""" @@ -405,7 +393,7 @@ class TestRowsGraphQLQueryContracts: # Request to execute specific operation multi_op_request = RowsQueryRequest( - user="test", 
collection="test", + collection="test", query=multi_op_query, variables={}, operation_name="GetCustomers" @@ -418,7 +406,7 @@ class TestRowsGraphQLQueryContracts: # Test single operation (operation_name optional) single_op_request = RowsQueryRequest( - user="test", collection="test", + collection="test", query='{ customers { id } }', variables={}, operation_name="" ) diff --git a/tests/contract/test_schema_field_contracts.py b/tests/contract/test_schema_field_contracts.py index 4b7c3da5..5be745b8 100644 --- a/tests/contract/test_schema_field_contracts.py +++ b/tests/contract/test_schema_field_contracts.py @@ -41,10 +41,11 @@ class TestSchemaFieldContracts: def test_metadata_fields(self): # NOTE: there is no `metadata` field. A previous regression # constructed Metadata(metadata=...) and crashed at runtime. + # `user` was also dropped in the workspace refactor — workspace + # now flows via flow.workspace, not via message payload. assert _field_names(Metadata) == { "id", "root", - "user", "collection", } diff --git a/tests/contract/test_structured_data_contracts.py b/tests/contract/test_structured_data_contracts.py index d8f4c5cb..8208ef1b 100644 --- a/tests/contract/test_structured_data_contracts.py +++ b/tests/contract/test_structured_data_contracts.py @@ -93,7 +93,6 @@ class TestStructuredDataSchemaContracts: # Arrange metadata = Metadata( id="structured-data-001", - user="test_user", collection="test_collection", ) @@ -118,7 +117,6 @@ class TestStructuredDataSchemaContracts: # Arrange metadata = Metadata( id="extracted-obj-001", - user="test_user", collection="test_collection", ) @@ -143,7 +141,6 @@ class TestStructuredDataSchemaContracts: # Arrange metadata = Metadata( id="extracted-batch-001", - user="test_user", collection="test_collection", ) @@ -177,7 +174,6 @@ class TestStructuredDataSchemaContracts: # Arrange metadata = Metadata( id="extracted-empty-001", - user="test_user", collection="test_collection", ) @@ -277,7 +273,6 @@ class 
TestStructuredEmbeddingsContracts: # Arrange metadata = Metadata( id="struct-embed-001", - user="test_user", collection="test_collection", ) @@ -308,7 +303,7 @@ class TestStructuredDataSerializationContracts: def test_structured_data_submission_serialization(self): """Test StructuredDataSubmission serialization contract""" # Arrange - metadata = Metadata(id="test", user="user", collection="col") + metadata = Metadata(id="test", collection="col") submission_data = { "metadata": metadata, "format": "json", @@ -323,7 +318,7 @@ class TestStructuredDataSerializationContracts: def test_extracted_object_serialization(self): """Test ExtractedObject serialization contract""" # Arrange - metadata = Metadata(id="test", user="user", collection="col") + metadata = Metadata(id="test", collection="col") object_data = { "metadata": metadata, "schema_name": "test_schema", @@ -373,7 +368,7 @@ class TestStructuredDataSerializationContracts: def test_extracted_object_batch_serialization(self): """Test ExtractedObject batch serialization contract""" # Arrange - metadata = Metadata(id="test", user="user", collection="col") + metadata = Metadata(id="test", collection="col") batch_object_data = { "metadata": metadata, "schema_name": "test_schema", @@ -392,7 +387,7 @@ class TestStructuredDataSerializationContracts: def test_extracted_object_empty_batch_serialization(self): """Test ExtractedObject empty batch serialization contract""" # Arrange - metadata = Metadata(id="test", user="user", collection="col") + metadata = Metadata(id="test", collection="col") empty_batch_data = { "metadata": metadata, "schema_name": "test_schema", diff --git a/tests/integration/test_agent_structured_query_integration.py b/tests/integration/test_agent_structured_query_integration.py index 2442bf10..883e5261 100644 --- a/tests/integration/test_agent_structured_query_integration.py +++ b/tests/integration/test_agent_structured_query_integration.py @@ -58,7 +58,7 @@ class TestAgentStructuredQueryIntegration: 
async def test_agent_structured_query_basic_integration(self, agent_processor, structured_query_tool_config): """Test basic agent integration with structured query tool""" # Arrange - Load tool configuration - await agent_processor.on_tools_config(structured_query_tool_config, "v1") + await agent_processor.on_tools_config("default", structured_query_tool_config, "v1") # Create agent request request = AgentRequest( @@ -66,7 +66,6 @@ class TestAgentStructuredQueryIntegration: state="", group=None, history=[], - user="test_user" ) msg = MagicMock() @@ -119,6 +118,7 @@ Args: { # Mock flow parameter in agent_processor.on_request flow = MagicMock() flow.side_effect = flow_context + flow.workspace = "default" # Act await agent_processor.on_request(msg, consumer, flow) @@ -146,14 +146,13 @@ Args: { async def test_agent_structured_query_error_handling(self, agent_processor, structured_query_tool_config): """Test agent handling of structured query errors""" # Arrange - await agent_processor.on_tools_config(structured_query_tool_config, "v1") + await agent_processor.on_tools_config("default", structured_query_tool_config, "v1") request = AgentRequest( question="Find data from a table that doesn't exist using structured query.", state="", group=None, history=[], - user="test_user" ) msg = MagicMock() @@ -199,6 +198,7 @@ Args: { flow = MagicMock() flow.side_effect = flow_context + flow.workspace = "default" # Act await agent_processor.on_request(msg, consumer, flow) @@ -221,14 +221,13 @@ Args: { async def test_agent_multi_step_structured_query_reasoning(self, agent_processor, structured_query_tool_config): """Test agent using structured query in multi-step reasoning""" # Arrange - await agent_processor.on_tools_config(structured_query_tool_config, "v1") + await agent_processor.on_tools_config("default", structured_query_tool_config, "v1") request = AgentRequest( question="First find all customers from California, then tell me how many orders they have made.", state="", 
group=None, history=[], - user="test_user" ) msg = MagicMock() @@ -279,6 +278,7 @@ Args: { flow = MagicMock() flow.side_effect = flow_context + flow.workspace = "default" # Act await agent_processor.on_request(msg, consumer, flow) @@ -313,14 +313,13 @@ Args: { } } - await agent_processor.on_tools_config(tool_config_with_collection, "v1") + await agent_processor.on_tools_config("default", tool_config_with_collection, "v1") request = AgentRequest( question="Query the sales data for recent transactions.", state="", group=None, history=[], - user="test_user" ) msg = MagicMock() @@ -371,6 +370,7 @@ Args: { flow = MagicMock() flow.side_effect = flow_context + flow.workspace = "default" # Act await agent_processor.on_request(msg, consumer, flow) @@ -394,10 +394,10 @@ Args: { async def test_agent_structured_query_tool_argument_validation(self, agent_processor, structured_query_tool_config): """Test that structured query tool arguments are properly validated""" # Arrange - await agent_processor.on_tools_config(structured_query_tool_config, "v1") + await agent_processor.on_tools_config("default", structured_query_tool_config, "v1") # Check that the tool was registered with correct arguments - tools = agent_processor.agent.tools + tools = agent_processor.agents["default"].tools assert "structured-query" in tools structured_tool = tools["structured-query"] @@ -414,14 +414,13 @@ Args: { async def test_agent_structured_query_json_formatting(self, agent_processor, structured_query_tool_config): """Test that structured query results are properly formatted for agent consumption""" # Arrange - await agent_processor.on_tools_config(structured_query_tool_config, "v1") + await agent_processor.on_tools_config("default", structured_query_tool_config, "v1") request = AgentRequest( question="Get customer information and format it nicely.", state="", group=None, history=[], - user="test_user" ) msg = MagicMock() @@ -482,6 +481,7 @@ Args: { flow = MagicMock() flow.side_effect = flow_context 
+ flow.workspace = "default" # Act await agent_processor.on_request(msg, consumer, flow) diff --git a/tests/integration/test_cassandra_config_end_to_end.py b/tests/integration/test_cassandra_config_end_to_end.py index 6c83fb05..1e4276fe 100644 --- a/tests/integration/test_cassandra_config_end_to_end.py +++ b/tests/integration/test_cassandra_config_end_to_end.py @@ -40,14 +40,13 @@ class TestEndToEndConfigurationFlow: # Create a mock message to trigger TrustGraph creation mock_message = MagicMock() - mock_message.metadata.user = 'test_user' mock_message.metadata.collection = 'test_collection' mock_message.triples = [] # Mock collection_exists to return True with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True): # This should create TrustGraph with environment config - await processor.store_triples(mock_message) + await processor.store_triples('test_user', mock_message) # Verify Cluster was created with correct hosts mock_cluster.assert_called_once() @@ -144,13 +143,12 @@ class TestConfigurationPriorityEndToEnd: # Trigger TrustGraph creation mock_message = MagicMock() - mock_message.metadata.user = 'test_user' mock_message.metadata.collection = 'test_collection' mock_message.triples = [] # Mock collection_exists to return True with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples('test_user', mock_message) # Should use CLI parameters, not environment mock_cluster.assert_called_once() @@ -201,7 +199,6 @@ class TestConfigurationPriorityEndToEnd: # Mock query to trigger TrustGraph creation mock_query = MagicMock() - mock_query.user = 'default_user' mock_query.collection = 'default_collection' mock_query.s = None mock_query.p = None @@ -213,7 +210,7 @@ class TestConfigurationPriorityEndToEnd: mock_tg_instance.get_all.return_value = [] processor.tg = mock_tg_instance - await processor.query_triples(mock_query) 
+ await processor.query_triples('default_user', mock_query) # Should use defaults mock_cluster.assert_called_once() @@ -244,13 +241,12 @@ class TestNoBackwardCompatibilityEndToEnd: # Trigger TrustGraph creation mock_message = MagicMock() - mock_message.metadata.user = 'legacy_user' mock_message.metadata.collection = 'legacy_collection' mock_message.triples = [] # Mock collection_exists to return True with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples('legacy_user', mock_message) # Should use defaults since old parameters are not recognized mock_cluster.assert_called_once() @@ -302,13 +298,12 @@ class TestNoBackwardCompatibilityEndToEnd: # Trigger TrustGraph creation mock_message = MagicMock() - mock_message.metadata.user = 'precedence_user' mock_message.metadata.collection = 'precedence_collection' mock_message.triples = [] # Mock collection_exists to return True with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples('precedence_user', mock_message) # Should use new parameters, not old ones mock_cluster.assert_called_once() @@ -354,13 +349,12 @@ class TestMultipleHostsHandling: # Trigger TrustGraph creation mock_message = MagicMock() - mock_message.metadata.user = 'single_user' mock_message.metadata.collection = 'single_collection' mock_message.triples = [] # Mock collection_exists to return True with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples('single_user', mock_message) # Single host should be converted to list mock_cluster.assert_called_once() diff --git a/tests/integration/test_cassandra_integration.py b/tests/integration/test_cassandra_integration.py index 2f5a4195..42273d58 100644 --- 
a/tests/integration/test_cassandra_integration.py +++ b/tests/integration/test_cassandra_integration.py @@ -115,7 +115,7 @@ class TestCassandraIntegration: # Create test message storage_message = Triples( - metadata=Metadata(user="testuser", collection="testcol"), + metadata=Metadata(collection="testcol"), triples=[ Triple( s=Term(type=IRI, iri="http://example.org/person1"), @@ -178,7 +178,7 @@ class TestCassandraIntegration: # Store test data for querying query_test_message = Triples( - metadata=Metadata(user="testuser", collection="testcol"), + metadata=Metadata(collection="testcol"), triples=[ Triple( s=Term(type=IRI, iri="http://example.org/alice"), @@ -212,7 +212,6 @@ class TestCassandraIntegration: p=None, # None for wildcard o=None, # None for wildcard limit=10, - user="testuser", collection="testcol" ) s_results = await query_processor.query_triples(s_query) @@ -232,7 +231,6 @@ class TestCassandraIntegration: p=Term(type=IRI, iri="http://example.org/knows"), o=None, # None for wildcard limit=10, - user="testuser", collection="testcol" ) p_results = await query_processor.query_triples(p_query) @@ -259,7 +257,7 @@ class TestCassandraIntegration: # Create multiple coroutines for concurrent storage async def store_person_data(person_id, name, age, department): message = Triples( - metadata=Metadata(user="concurrent_test", collection="people"), + metadata=Metadata(collection="people"), triples=[ Triple( s=Term(type=IRI, iri=f"http://example.org/{person_id}"), @@ -329,7 +327,7 @@ class TestCassandraIntegration: # Create a knowledge graph about a company company_graph = Triples( - metadata=Metadata(user="integration_test", collection="company"), + metadata=Metadata(collection="company"), triples=[ # People and their types Triple( diff --git a/tests/integration/test_document_rag_integration.py b/tests/integration/test_document_rag_integration.py index 8c165385..78a85acf 100644 --- a/tests/integration/test_document_rag_integration.py +++ 
b/tests/integration/test_document_rag_integration.py @@ -99,7 +99,6 @@ class TestDocumentRagIntegration: # Act result = await document_rag.query( query=query, - user=user, collection=collection, doc_limit=doc_limit ) @@ -110,7 +109,6 @@ class TestDocumentRagIntegration: mock_doc_embeddings_client.query.assert_called_once_with( vector=[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]], limit=doc_limit, - user=user, collection=collection ) @@ -278,14 +276,12 @@ class TestDocumentRagIntegration: # Act await document_rag.query( f"query from {user} in {collection}", - user=user, collection=collection ) # Assert mock_doc_embeddings_client.query.assert_called_once() call_args = mock_doc_embeddings_client.query.call_args - assert call_args.kwargs['user'] == user assert call_args.kwargs['collection'] == collection @pytest.mark.asyncio @@ -353,6 +349,5 @@ class TestDocumentRagIntegration: # Assert mock_doc_embeddings_client.query.assert_called_once() call_args = mock_doc_embeddings_client.query.call_args - assert call_args.kwargs['user'] == "trustgraph" assert call_args.kwargs['collection'] == "default" assert call_args.kwargs['limit'] == 20 diff --git a/tests/integration/test_document_rag_streaming_integration.py b/tests/integration/test_document_rag_streaming_integration.py index e2c032ad..49ddc3a2 100644 --- a/tests/integration/test_document_rag_streaming_integration.py +++ b/tests/integration/test_document_rag_streaming_integration.py @@ -107,7 +107,6 @@ class TestDocumentRagStreaming: # Act result = await document_rag_streaming.query( query=query, - user="test_user", collection="test_collection", doc_limit=10, streaming=True, @@ -141,7 +140,6 @@ class TestDocumentRagStreaming: # Act - Non-streaming non_streaming_result = await document_rag_streaming.query( query=query, - user=user, collection=collection, doc_limit=doc_limit, streaming=False @@ -155,7 +153,6 @@ class TestDocumentRagStreaming: streaming_result = await document_rag_streaming.query( query=query, - 
user=user, collection=collection, doc_limit=doc_limit, streaming=True, @@ -178,7 +175,6 @@ class TestDocumentRagStreaming: # Act result = await document_rag_streaming.query( query="test query", - user="test_user", collection="test_collection", doc_limit=5, streaming=True, @@ -200,7 +196,6 @@ class TestDocumentRagStreaming: # Arrange & Act result = await document_rag_streaming.query( query="test query", - user="test_user", collection="test_collection", doc_limit=5, streaming=True, @@ -223,7 +218,6 @@ class TestDocumentRagStreaming: # Act result = await document_rag_streaming.query( query="unknown topic", - user="test_user", collection="test_collection", doc_limit=10, streaming=True, @@ -247,7 +241,6 @@ class TestDocumentRagStreaming: with pytest.raises(Exception) as exc_info: await document_rag_streaming.query( query="test query", - user="test_user", collection="test_collection", doc_limit=5, streaming=True, @@ -272,7 +265,6 @@ class TestDocumentRagStreaming: # Act result = await document_rag_streaming.query( query="test query", - user="test_user", collection="test_collection", doc_limit=limit, streaming=True, @@ -300,7 +292,6 @@ class TestDocumentRagStreaming: # Act await document_rag_streaming.query( query="test query", - user=user, collection=collection, doc_limit=10, streaming=True, @@ -309,5 +300,4 @@ class TestDocumentRagStreaming: # Assert - Verify user/collection were passed to document embeddings client call_args = mock_doc_embeddings_client.query.call_args - assert call_args.kwargs['user'] == user assert call_args.kwargs['collection'] == collection diff --git a/tests/integration/test_graph_rag_integration.py b/tests/integration/test_graph_rag_integration.py index 9c3cdf45..696df7ec 100644 --- a/tests/integration/test_graph_rag_integration.py +++ b/tests/integration/test_graph_rag_integration.py @@ -146,7 +146,6 @@ class TestGraphRagIntegration: # Act response = await graph_rag.query( query=query, - user=user, collection=collection, 
entity_limit=entity_limit, triple_limit=triple_limit, @@ -163,7 +162,6 @@ class TestGraphRagIntegration: call_args = mock_graph_embeddings_client.query.call_args assert call_args.kwargs['vector'] == [[0.1, 0.2, 0.3, 0.4, 0.5]] assert call_args.kwargs['limit'] == entity_limit - assert call_args.kwargs['user'] == user assert call_args.kwargs['collection'] == collection # 3. Should query triples to build knowledge subgraph @@ -204,7 +202,6 @@ class TestGraphRagIntegration: # Act await graph_rag.query( query=query, - user="test_user", collection="test_collection", entity_limit=config["entity_limit"], triple_limit=config["triple_limit"] @@ -224,7 +221,6 @@ class TestGraphRagIntegration: with pytest.raises(Exception) as exc_info: await graph_rag.query( query="test query", - user="test_user", collection="test_collection" ) @@ -247,7 +243,6 @@ class TestGraphRagIntegration: # Act response = await graph_rag.query( query="unknown topic", - user="test_user", collection="test_collection", explain_callback=collect_provenance ) @@ -267,7 +262,6 @@ class TestGraphRagIntegration: # First query await graph_rag.query( query=query, - user="test_user", collection="test_collection" ) @@ -277,7 +271,6 @@ class TestGraphRagIntegration: # Second identical query await graph_rag.query( query=query, - user="test_user", collection="test_collection" ) @@ -289,26 +282,27 @@ class TestGraphRagIntegration: assert second_call_count >= 0 # Should complete without errors @pytest.mark.asyncio - async def test_graph_rag_multi_user_isolation(self, graph_rag, mock_graph_embeddings_client): - """Test that different users/collections are properly isolated""" + async def test_graph_rag_multi_collection_isolation(self, graph_rag, mock_graph_embeddings_client): + """Test that different collections propagate through to the embeddings query. + + Workspace isolation is enforced by flow.workspace at the service + boundary — not by parameters on GraphRag.query — so this test + verifies collection routing only. 
+ """ # Arrange query = "test query" - user1, collection1 = "user1", "collection1" - user2, collection2 = "user2", "collection2" + collection1 = "collection1" + collection2 = "collection2" # Act - await graph_rag.query(query=query, user=user1, collection=collection1) - await graph_rag.query(query=query, user=user2, collection=collection2) + await graph_rag.query(query=query, collection=collection1) + await graph_rag.query(query=query, collection=collection2) - # Assert - Both users should have separate queries + # Assert - Each call propagated its collection assert mock_graph_embeddings_client.query.call_count == 2 - # Verify first call first_call = mock_graph_embeddings_client.query.call_args_list[0] - assert first_call.kwargs['user'] == user1 assert first_call.kwargs['collection'] == collection1 - # Verify second call second_call = mock_graph_embeddings_client.query.call_args_list[1] - assert second_call.kwargs['user'] == user2 assert second_call.kwargs['collection'] == collection2 diff --git a/tests/integration/test_graph_rag_streaming_integration.py b/tests/integration/test_graph_rag_streaming_integration.py index 95c494bb..48e26618 100644 --- a/tests/integration/test_graph_rag_streaming_integration.py +++ b/tests/integration/test_graph_rag_streaming_integration.py @@ -116,7 +116,6 @@ class TestGraphRagStreaming: # Act - query() returns response, provenance via callback response = await graph_rag_streaming.query( query=query, - user="test_user", collection="test_collection", streaming=True, chunk_callback=collector.collect, @@ -154,7 +153,6 @@ class TestGraphRagStreaming: # Act - Non-streaming non_streaming_response = await graph_rag_streaming.query( query=query, - user=user, collection=collection, streaming=False ) @@ -167,7 +165,6 @@ class TestGraphRagStreaming: streaming_response = await graph_rag_streaming.query( query=query, - user=user, collection=collection, streaming=True, chunk_callback=collect @@ -189,7 +186,6 @@ class TestGraphRagStreaming: # Act 
response = await graph_rag_streaming.query( query="test query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=callback @@ -209,7 +205,6 @@ class TestGraphRagStreaming: # Arrange & Act response = await graph_rag_streaming.query( query="test query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=None # No callback provided @@ -231,7 +226,6 @@ class TestGraphRagStreaming: # Act response = await graph_rag_streaming.query( query="unknown topic", - user="test_user", collection="test_collection", streaming=True, chunk_callback=callback @@ -253,7 +247,6 @@ class TestGraphRagStreaming: with pytest.raises(Exception) as exc_info: await graph_rag_streaming.query( query="test query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=callback @@ -273,7 +266,6 @@ class TestGraphRagStreaming: # Act await graph_rag_streaming.query( query="test query", - user="test_user", collection="test_collection", entity_limit=entity_limit, triple_limit=triple_limit, diff --git a/tests/integration/test_import_export_graceful_shutdown.py b/tests/integration/test_import_export_graceful_shutdown.py index a3771b80..2fcd2683 100644 --- a/tests/integration/test_import_export_graceful_shutdown.py +++ b/tests/integration/test_import_export_graceful_shutdown.py @@ -171,7 +171,6 @@ async def test_export_no_message_loss_integration(mock_backend): triples_obj = Triples( metadata=Metadata( id=f"export-msg-{i}", - user=msg_data["metadata"]["user"], collection=msg_data["metadata"]["collection"], ), triples=to_subgraph(msg_data["triples"]), diff --git a/tests/integration/test_kg_extract_store_integration.py b/tests/integration/test_kg_extract_store_integration.py index 84c0905d..48878a00 100644 --- a/tests/integration/test_kg_extract_store_integration.py +++ b/tests/integration/test_kg_extract_store_integration.py @@ -97,7 +97,6 @@ class TestKnowledgeGraphPipelineIntegration: return Chunk( metadata=Metadata( 
id="doc-123", - user="test_user", collection="test_collection", ), chunk=b"Machine Learning is a subset of Artificial Intelligence. Neural Networks are used in Machine Learning to process complex patterns." @@ -247,7 +246,6 @@ class TestKnowledgeGraphPipelineIntegration: # Arrange metadata = Metadata( id="test-doc", - user="test_user", collection="test_collection", ) @@ -305,7 +303,6 @@ class TestKnowledgeGraphPipelineIntegration: # Arrange metadata = Metadata( id="test-doc", - user="test_user", collection="test_collection", ) @@ -375,7 +372,6 @@ class TestKnowledgeGraphPipelineIntegration: sample_triples = Triples( metadata=Metadata( id="test-doc", - user="test_user", collection="test_collection", ), triples=[ @@ -390,11 +386,14 @@ class TestKnowledgeGraphPipelineIntegration: mock_msg = MagicMock() mock_msg.value.return_value = sample_triples + mock_flow = MagicMock() + mock_flow.workspace = "test_workspace" + # Act - await processor.on_triples(mock_msg, None, None) + await processor.on_triples(mock_msg, None, mock_flow) # Assert - mock_cassandra_store.add_triples.assert_called_once_with(sample_triples) + mock_cassandra_store.add_triples.assert_called_once_with("test_workspace", sample_triples) @pytest.mark.asyncio async def test_knowledge_store_graph_embeddings_storage(self, mock_cassandra_store): @@ -407,7 +406,6 @@ class TestKnowledgeGraphPipelineIntegration: sample_embeddings = GraphEmbeddings( metadata=Metadata( id="test-doc", - user="test_user", collection="test_collection", ), entities=[ @@ -421,11 +419,14 @@ class TestKnowledgeGraphPipelineIntegration: mock_msg = MagicMock() mock_msg.value.return_value = sample_embeddings + mock_flow = MagicMock() + mock_flow.workspace = "test_workspace" + # Act - await processor.on_graph_embeddings(mock_msg, None, None) + await processor.on_graph_embeddings(mock_msg, None, mock_flow) # Assert - mock_cassandra_store.add_graph_embeddings.assert_called_once_with(sample_embeddings) + 
mock_cassandra_store.add_graph_embeddings.assert_called_once_with("test_workspace", sample_embeddings) @pytest.mark.asyncio async def test_end_to_end_pipeline_coordination(self, definitions_processor, relationships_processor, @@ -553,7 +554,7 @@ class TestKnowledgeGraphPipelineIntegration: ) sample_chunk = Chunk( - metadata=Metadata(id="test", user="user", collection="collection"), + metadata=Metadata(id="test", collection="collection"), chunk=b"Test chunk" ) @@ -580,7 +581,7 @@ class TestKnowledgeGraphPipelineIntegration: # Arrange large_chunk_batch = [ Chunk( - metadata=Metadata(id=f"doc-{i}", user="user", collection="collection"), + metadata=Metadata(id=f"doc-{i}", collection="collection"), chunk=f"Document {i} contains machine learning and AI content.".encode("utf-8") ) for i in range(100) # Large batch @@ -617,7 +618,6 @@ class TestKnowledgeGraphPipelineIntegration: # Arrange original_metadata = Metadata( id="test-doc-123", - user="test_user", collection="test_collection", ) @@ -646,9 +646,7 @@ class TestKnowledgeGraphPipelineIntegration: entity_contexts_call = entity_contexts_producer.send.call_args[0][0] assert triples_call.metadata.id == "test-doc-123" - assert triples_call.metadata.user == "test_user" assert triples_call.metadata.collection == "test_collection" assert entity_contexts_call.metadata.id == "test-doc-123" - assert entity_contexts_call.metadata.user == "test_user" assert entity_contexts_call.metadata.collection == "test_collection" \ No newline at end of file diff --git a/tests/integration/test_nlp_query_integration.py b/tests/integration/test_nlp_query_integration.py index 16c4543e..08bf1e77 100644 --- a/tests/integration/test_nlp_query_integration.py +++ b/tests/integration/test_nlp_query_integration.py @@ -72,7 +72,7 @@ class TestNLPQueryServiceIntegration: ) # Set up schemas - proc.schemas = sample_schemas + proc.schemas = {"default": dict(sample_schemas)} # Mock the client method proc.client = MagicMock() @@ -94,6 +94,7 @@ class 
TestNLPQueryServiceIntegration: consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response @@ -173,6 +174,7 @@ class TestNLPQueryServiceIntegration: consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response @@ -229,7 +231,7 @@ class TestNLPQueryServiceIntegration: } # Act - Update configuration - await integration_processor.on_schema_config(new_schema_config, "v2") + await integration_processor.on_schema_config("default", new_schema_config, "v2") # Arrange - Test query using new schema request = QuestionToStructuredQueryRequest( @@ -243,6 +245,7 @@ class TestNLPQueryServiceIntegration: consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response @@ -272,7 +275,7 @@ class TestNLPQueryServiceIntegration: await integration_processor.on_message(msg, consumer, flow) # Assert - assert "inventory" in integration_processor.schemas + assert "inventory" in integration_processor.schemas["default"] response_call = flow_response.send.call_args response = response_call[0][0] assert response.detected_schemas == ["inventory"] @@ -293,6 +296,7 @@ class TestNLPQueryServiceIntegration: consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response @@ -334,7 +338,7 @@ class TestNLPQueryServiceIntegration: graphql_generation_template="custom-graphql-generator" ) - custom_processor.schemas = sample_schemas + custom_processor.schemas = {"default": dict(sample_schemas)} custom_processor.client = MagicMock() request = QuestionToStructuredQueryRequest( @@ -348,6 +352,7 @@ class TestNLPQueryServiceIntegration: consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response @@ -394,7 +399,7 @@ class TestNLPQueryServiceIntegration: ] + 
[SchemaField(name=f"field_{j}", type="string") for j in range(5)] ) - integration_processor.schemas.update(large_schema_set) + integration_processor.schemas["default"].update(large_schema_set) request = QuestionToStructuredQueryRequest( question="Show me data from table_05 and table_12", @@ -407,6 +412,7 @@ class TestNLPQueryServiceIntegration: consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response @@ -462,6 +468,7 @@ class TestNLPQueryServiceIntegration: msg.properties.return_value = {"id": f"concurrent-test-{i}"} flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response @@ -532,6 +539,7 @@ class TestNLPQueryServiceIntegration: consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" flow_response = AsyncMock() flow.return_value = flow_response diff --git a/tests/integration/test_object_extraction_integration.py b/tests/integration/test_object_extraction_integration.py index 22ba9a3f..8d58c764 100644 --- a/tests/integration/test_object_extraction_integration.py +++ b/tests/integration/test_object_extraction_integration.py @@ -185,6 +185,7 @@ class TestObjectExtractionServiceIntegration: return AsyncMock() context.side_effect = context_router + context.workspace = "default" return context @pytest.mark.asyncio @@ -197,20 +198,21 @@ class TestObjectExtractionServiceIntegration: processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) # Act - await processor.on_schema_config(integration_config, version=1) + await processor.on_schema_config("default", integration_config, version=1) # Assert - assert len(processor.schemas) == 2 - assert "customer_records" in processor.schemas - assert "product_catalog" in processor.schemas - + ws_schemas = processor.schemas["default"] + assert len(ws_schemas) == 2 + assert "customer_records" in ws_schemas + assert "product_catalog" in ws_schemas + # Verify customer 
schema - customer_schema = processor.schemas["customer_records"] + customer_schema = ws_schemas["customer_records"] assert customer_schema.name == "customer_records" assert len(customer_schema.fields) == 4 - + # Verify product schema - product_schema = processor.schemas["product_catalog"] + product_schema = ws_schemas["product_catalog"] assert product_schema.name == "product_catalog" assert len(product_schema.fields) == 4 @@ -237,12 +239,11 @@ class TestObjectExtractionServiceIntegration: processor.convert_values_to_strings = convert_values_to_strings # Load configuration - await processor.on_schema_config(integration_config, version=1) + await processor.on_schema_config("default", integration_config, version=1) # Create realistic customer data chunk metadata = Metadata( id="customer-doc-001", - user="integration_test", collection="test_documents", ) @@ -304,12 +305,11 @@ class TestObjectExtractionServiceIntegration: processor.convert_values_to_strings = convert_values_to_strings # Load configuration - await processor.on_schema_config(integration_config, version=1) + await processor.on_schema_config("default", integration_config, version=1) # Create realistic product data chunk metadata = Metadata( id="product-doc-001", - user="integration_test", collection="test_documents", ) @@ -368,7 +368,7 @@ class TestObjectExtractionServiceIntegration: processor.convert_values_to_strings = convert_values_to_strings # Load configuration - await processor.on_schema_config(integration_config, version=1) + await processor.on_schema_config("default", integration_config, version=1) # Create multiple test chunks chunks_data = [ @@ -382,7 +382,6 @@ class TestObjectExtractionServiceIntegration: for chunk_id, text in chunks_data: metadata = Metadata( id=chunk_id, - user="concurrent_test", collection="test_collection", ) chunk = Chunk(metadata=metadata, chunk=text.encode('utf-8')) @@ -431,19 +430,21 @@ class TestObjectExtractionServiceIntegration: "customer_records": 
integration_config["schema"]["customer_records"] } } - await processor.on_schema_config(initial_config, version=1) - - assert len(processor.schemas) == 1 - assert "customer_records" in processor.schemas - assert "product_catalog" not in processor.schemas - + await processor.on_schema_config("default", initial_config, version=1) + + ws_schemas = processor.schemas["default"] + assert len(ws_schemas) == 1 + assert "customer_records" in ws_schemas + assert "product_catalog" not in ws_schemas + # Act - Reload with full configuration - await processor.on_schema_config(integration_config, version=2) - + await processor.on_schema_config("default", integration_config, version=2) + # Assert - assert len(processor.schemas) == 2 - assert "customer_records" in processor.schemas - assert "product_catalog" in processor.schemas + ws_schemas = processor.schemas["default"] + assert len(ws_schemas) == 2 + assert "customer_records" in ws_schemas + assert "product_catalog" in ws_schemas @pytest.mark.asyncio async def test_error_resilience_integration(self, integration_config): @@ -474,13 +475,14 @@ class TestObjectExtractionServiceIntegration: return AsyncMock() failing_flow.side_effect = failing_context_router + failing_flow.workspace = "default" processor.flow = failing_flow # Load configuration - await processor.on_schema_config(integration_config, version=1) + await processor.on_schema_config("default", integration_config, version=1) # Create test chunk - metadata = Metadata(id="error-test", user="test", collection="test") + metadata = Metadata(id="error-test", collection="test") chunk = Chunk(metadata=metadata, chunk=b"Some text that will fail to process") mock_msg = MagicMock() @@ -510,12 +512,11 @@ class TestObjectExtractionServiceIntegration: processor.convert_values_to_strings = convert_values_to_strings # Load configuration - await processor.on_schema_config(integration_config, version=1) + await processor.on_schema_config("default", integration_config, version=1) # Create 
chunk with rich metadata original_metadata = Metadata( id="metadata-test-chunk", - user="test_user", collection="test_collection", ) @@ -544,6 +545,5 @@ class TestObjectExtractionServiceIntegration: assert extracted_obj is not None # Verify metadata propagation - assert extracted_obj.metadata.user == "test_user" assert extracted_obj.metadata.collection == "test_collection" assert "metadata-test-chunk" in extracted_obj.metadata.id # Should include source reference \ No newline at end of file diff --git a/tests/integration/test_prompt_streaming_integration.py b/tests/integration/test_prompt_streaming_integration.py index a1414e2d..84a3cdec 100644 --- a/tests/integration/test_prompt_streaming_integration.py +++ b/tests/integration/test_prompt_streaming_integration.py @@ -87,6 +87,7 @@ class TestPromptStreaming: return AsyncMock() context.side_effect = context_router + context.workspace = "default" return context @pytest.fixture @@ -109,7 +110,7 @@ class TestPromptStreaming: def prompt_processor_streaming(self, mock_prompt_manager): """Create Prompt processor with streaming support""" processor = MagicMock() - processor.manager = mock_prompt_manager + processor.managers = {"default": mock_prompt_manager} processor.config_key = "prompt" # Bind the actual on_request method @@ -248,6 +249,7 @@ class TestPromptStreaming: return AsyncMock() context.side_effect = context_router + context.workspace = "default" request = PromptRequest( id="test_prompt", @@ -341,6 +343,7 @@ class TestPromptStreaming: return AsyncMock() context.side_effect = context_router + context.workspace = "default" request = PromptRequest( id="test_prompt", diff --git a/tests/integration/test_rag_streaming_protocol.py b/tests/integration/test_rag_streaming_protocol.py index 83a90412..279c81ef 100644 --- a/tests/integration/test_rag_streaming_protocol.py +++ b/tests/integration/test_rag_streaming_protocol.py @@ -84,7 +84,6 @@ class TestGraphRagStreamingProtocol: # Act await graph_rag.query( query="test 
query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=callback @@ -108,7 +107,6 @@ class TestGraphRagStreamingProtocol: # Act await graph_rag.query( query="test query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=collect @@ -137,7 +135,6 @@ class TestGraphRagStreamingProtocol: # Act await graph_rag.query( query="test query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=collect @@ -162,7 +159,6 @@ class TestGraphRagStreamingProtocol: # Act await graph_rag.query( query="test query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=collect @@ -188,7 +184,6 @@ class TestGraphRagStreamingProtocol: # Act await graph_rag.query( query="test query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=collect @@ -267,7 +262,6 @@ class TestDocumentRagStreamingProtocol: # Act await document_rag.query( query="test query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=callback @@ -290,7 +284,6 @@ class TestDocumentRagStreamingProtocol: # Act await document_rag.query( query="test query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=collect @@ -314,7 +307,6 @@ class TestDocumentRagStreamingProtocol: # Act await document_rag.query( query="test query", - user="test_user", collection="test_collection", streaming=True, chunk_callback=collect diff --git a/tests/integration/test_rows_cassandra_integration.py b/tests/integration/test_rows_cassandra_integration.py index a2b8ae08..1358d420 100644 --- a/tests/integration/test_rows_cassandra_integration.py +++ b/tests/integration/test_rows_cassandra_integration.py @@ -14,6 +14,17 @@ from trustgraph.storage.rows.cassandra.write import Processor from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field + + +class _MockFlowDefault: + """Mock Flow with default workspace for testing.""" + 
workspace = "default" + name = "default" + id = "test-processor" + + +mock_flow_default = _MockFlowDefault() + @pytest.mark.integration class TestRowsCassandraIntegration: """Integration tests for Cassandra row storage with unified table""" @@ -125,14 +136,13 @@ class TestRowsCassandraIntegration: } } - await processor.on_schema_config(config, version=1) - assert "customer_records" in processor.schemas + await processor.on_schema_config("default", config, version=1) + assert "customer_records" in processor.schemas["default"] # Step 2: Process an ExtractedObject test_obj = ExtractedObject( metadata=Metadata( id="doc-001", - user="test_user", collection="import_2024", ), schema_name="customer_records", @@ -149,7 +159,7 @@ class TestRowsCassandraIntegration: msg = MagicMock() msg.value.return_value = test_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Verify Cassandra interactions assert mock_cluster.connect.called @@ -158,7 +168,7 @@ class TestRowsCassandraIntegration: keyspace_calls = [call for call in mock_session.execute.call_args_list if "CREATE KEYSPACE" in str(call)] assert len(keyspace_calls) == 1 - assert "test_user" in str(keyspace_calls[0]) + assert "default" in str(keyspace_calls[0]) # Verify unified table creation (rows table, not per-schema table) table_calls = [call for call in mock_session.execute.call_args_list @@ -209,12 +219,12 @@ class TestRowsCassandraIntegration: } } - await processor.on_schema_config(config, version=1) - assert len(processor.schemas) == 2 + await processor.on_schema_config("default", config, version=1) + assert len(processor.schemas["default"]) == 2 # Process objects for different schemas product_obj = ExtractedObject( - metadata=Metadata(id="p1", user="shop", collection="catalog"), + metadata=Metadata(id="p1", collection="catalog"), schema_name="products", values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}], confidence=0.9, @@ -222,7 +232,7 @@ class 
TestRowsCassandraIntegration: ) order_obj = ExtractedObject( - metadata=Metadata(id="o1", user="shop", collection="sales"), + metadata=Metadata(id="o1", collection="sales"), schema_name="orders", values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}], confidence=0.85, @@ -233,7 +243,7 @@ class TestRowsCassandraIntegration: for obj in [product_obj, order_obj]: msg = MagicMock() msg.value.return_value = obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # All data goes into the same unified rows table table_calls = [call for call in mock_session.execute.call_args_list @@ -256,18 +266,20 @@ class TestRowsCassandraIntegration: with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): # Schema with multiple indexed fields - processor.schemas["indexed_data"] = RowSchema( - name="indexed_data", - fields=[ - Field(name="id", type="string", size=50, primary=True), - Field(name="category", type="string", size=50, indexed=True), - Field(name="status", type="string", size=50, indexed=True), - Field(name="description", type="string", size=200) # Not indexed - ] - ) + processor.schemas["default"] = { + "indexed_data": RowSchema( + name="indexed_data", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="category", type="string", size=50, indexed=True), + Field(name="status", type="string", size=50, indexed=True), + Field(name="description", type="string", size=200) # Not indexed + ] + ) + } test_obj = ExtractedObject( - metadata=Metadata(id="t1", user="test", collection="test"), + metadata=Metadata(id="t1", collection="test"), schema_name="indexed_data", values=[{ "id": "123", @@ -282,7 +294,7 @@ class TestRowsCassandraIntegration: msg = MagicMock() msg.value.return_value = test_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Should have 3 data inserts (one per indexed field: id, 
category, status) rows_insert_calls = [call for call in mock_session.execute.call_args_list @@ -342,13 +354,12 @@ class TestRowsCassandraIntegration: } } - await processor.on_schema_config(config, version=1) + await processor.on_schema_config("default", config, version=1) # Process batch object with multiple values batch_obj = ExtractedObject( metadata=Metadata( id="batch-001", - user="test_user", collection="batch_import", ), schema_name="batch_customers", @@ -376,7 +387,7 @@ class TestRowsCassandraIntegration: msg = MagicMock() msg.value.return_value = batch_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Verify unified table creation table_calls = [call for call in mock_session.execute.call_args_list @@ -396,14 +407,16 @@ class TestRowsCassandraIntegration: processor, mock_cluster, mock_session = processor_with_mocks with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): - processor.schemas["empty_test"] = RowSchema( - name="empty_test", - fields=[Field(name="id", type="string", size=50, primary=True)] - ) + processor.schemas["default"] = { + "empty_test": RowSchema( + name="empty_test", + fields=[Field(name="id", type="string", size=50, primary=True)] + ) + } # Process empty batch object empty_obj = ExtractedObject( - metadata=Metadata(id="empty-1", user="test", collection="empty"), + metadata=Metadata(id="empty-1", collection="empty"), schema_name="empty_test", values=[], # Empty batch confidence=1.0, @@ -413,7 +426,7 @@ class TestRowsCassandraIntegration: msg = MagicMock() msg.value.return_value = empty_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Should not create any data insert statements for empty batch # (partition registration may still happen) @@ -428,17 +441,19 @@ class TestRowsCassandraIntegration: processor, mock_cluster, mock_session = processor_with_mocks with 
patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): - processor.schemas["map_test"] = RowSchema( - name="map_test", - fields=[ - Field(name="id", type="string", size=50, primary=True), - Field(name="name", type="string", size=100), - Field(name="count", type="integer", size=0) - ] - ) + processor.schemas["default"] = { + "map_test": RowSchema( + name="map_test", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="name", type="string", size=100), + Field(name="count", type="integer", size=0) + ] + ) + } test_obj = ExtractedObject( - metadata=Metadata(id="t1", user="test", collection="test"), + metadata=Metadata(id="t1", collection="test"), schema_name="map_test", values=[{"id": "123", "name": "Test Item", "count": "42"}], confidence=0.9, @@ -448,7 +463,7 @@ class TestRowsCassandraIntegration: msg = MagicMock() msg.value.return_value = test_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Verify insert uses map for data rows_insert_calls = [call for call in mock_session.execute.call_args_list @@ -473,16 +488,18 @@ class TestRowsCassandraIntegration: processor, mock_cluster, mock_session = processor_with_mocks with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): - processor.schemas["partition_test"] = RowSchema( - name="partition_test", - fields=[ - Field(name="id", type="string", size=50, primary=True), - Field(name="category", type="string", size=50, indexed=True) - ] - ) + processor.schemas["default"] = { + "partition_test": RowSchema( + name="partition_test", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="category", type="string", size=50, indexed=True) + ] + ) + } test_obj = ExtractedObject( - metadata=Metadata(id="t1", user="test", collection="my_collection"), + metadata=Metadata(id="t1", collection="my_collection"), schema_name="partition_test", values=[{"id": "123", 
"category": "test"}], confidence=0.9, @@ -492,7 +509,7 @@ class TestRowsCassandraIntegration: msg = MagicMock() msg.value.return_value = test_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Verify partition registration partition_inserts = [call for call in mock_session.execute.call_args_list diff --git a/tests/integration/test_rows_graphql_query_integration.py b/tests/integration/test_rows_graphql_query_integration.py index a717901b..29b4464d 100644 --- a/tests/integration/test_rows_graphql_query_integration.py +++ b/tests/integration/test_rows_graphql_query_integration.py @@ -154,7 +154,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_schema_configuration_and_generation(self, processor, sample_schema_config): """Test schema configuration loading and GraphQL schema generation""" # Load schema configuration - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) # Verify schemas were loaded assert len(processor.schemas) == 2 @@ -181,7 +181,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_cassandra_connection_and_table_creation(self, processor, sample_schema_config): """Test Cassandra connection and dynamic table creation""" # Load schema configuration - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) # Connect to Cassandra processor.connect_cassandra() @@ -218,7 +218,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_data_insertion_and_graphql_query(self, processor, sample_schema_config): """Test inserting data and querying via GraphQL""" # Load schema and connect - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) processor.connect_cassandra() # Setup test data @@ -292,7 +292,7 @@ class 
TestObjectsGraphQLQueryIntegration: async def test_graphql_query_with_filters(self, processor, sample_schema_config): """Test GraphQL queries with filtering on indexed fields""" # Setup (reuse previous setup) - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) processor.connect_cassandra() keyspace = "test_user" @@ -353,7 +353,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_graphql_error_handling(self, processor, sample_schema_config): """Test GraphQL error handling for invalid queries""" # Setup - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) # Test invalid field query invalid_query = ''' @@ -386,7 +386,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_message_processing_integration(self, processor, sample_schema_config): """Test full message processing workflow""" # Setup - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) processor.connect_cassandra() # Create mock message @@ -432,7 +432,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_concurrent_queries(self, processor, sample_schema_config): """Test handling multiple concurrent GraphQL queries""" # Setup - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) processor.connect_cassandra() # Create multiple query tasks @@ -476,7 +476,7 @@ class TestObjectsGraphQLQueryIntegration: } } - await processor.on_schema_config(initial_config, version=1) + await processor.on_schema_config("default", initial_config, version=1) assert len(processor.schemas) == 1 assert "simple" in processor.schemas @@ -500,7 +500,7 @@ class TestObjectsGraphQLQueryIntegration: } } - await 
processor.on_schema_config(updated_config, version=2) + await processor.on_schema_config("default", updated_config, version=2) # Verify updated schemas assert len(processor.schemas) == 2 @@ -518,7 +518,7 @@ class TestObjectsGraphQLQueryIntegration: async def test_large_result_set_handling(self, processor, sample_schema_config): """Test handling of large query result sets""" # Setup - await processor.on_schema_config(sample_schema_config, version=1) + await processor.on_schema_config("default", sample_schema_config, version=1) processor.connect_cassandra() keyspace = "large_test_user" @@ -601,7 +601,7 @@ class TestObjectsGraphQLQueryPerformance: } } - await processor.on_schema_config(schema_config, version=1) + await processor.on_schema_config("default", schema_config, version=1) # Measure query execution time start_time = time.time() diff --git a/tests/integration/test_structured_query_integration.py b/tests/integration/test_structured_query_integration.py index d5fb5672..67c85406 100644 --- a/tests/integration/test_structured_query_integration.py +++ b/tests/integration/test_structured_query_integration.py @@ -42,7 +42,6 @@ class TestStructuredQueryServiceIntegration: # Arrange - Create realistic query request request = StructuredQueryRequest( question="Show me all customers from California who have made purchases over $500", - user="trustgraph", collection="default" ) @@ -126,7 +125,6 @@ class TestStructuredQueryServiceIntegration: assert "orders" in objects_call_args.query assert objects_call_args.variables["minAmount"] == "500.0" # Converted to string assert objects_call_args.variables["state"] == "California" - assert objects_call_args.user == "trustgraph" assert objects_call_args.collection == "default" # Verify response diff --git a/tests/unit/test_agent/test_agent_service_non_streaming.py b/tests/unit/test_agent/test_agent_service_non_streaming.py index bb58e5ee..0a27b118 100644 --- a/tests/unit/test_agent/test_agent_service_non_streaming.py +++ 
b/tests/unit/test_agent/test_agent_service_non_streaming.py @@ -37,6 +37,9 @@ class TestAgentServiceNonStreaming: # Setup mock agent manager mock_agent_instance = AsyncMock() mock_agent_manager_class.return_value = mock_agent_instance + mock_agent_instance.tools = {} + mock_agent_instance.additional_context = "" + processor.agents["default"] = mock_agent_instance # Mock react to call think and observe callbacks async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None): @@ -50,7 +53,6 @@ class TestAgentServiceNonStreaming: msg = MagicMock() msg.value.return_value = AgentRequest( question="What is 2 + 2?", - user="trustgraph", streaming=False # Non-streaming mode ) msg.properties.return_value = {"id": "test-id"} @@ -58,6 +60,7 @@ class TestAgentServiceNonStreaming: # Setup flow mock consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" mock_producer = AsyncMock() @@ -129,6 +132,9 @@ class TestAgentServiceNonStreaming: # Setup mock agent manager mock_agent_instance = AsyncMock() mock_agent_manager_class.return_value = mock_agent_instance + mock_agent_instance.tools = {} + mock_agent_instance.additional_context = "" + processor.agents["default"] = mock_agent_instance # Mock react to return Final directly async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None): @@ -140,7 +146,6 @@ class TestAgentServiceNonStreaming: msg = MagicMock() msg.value.return_value = AgentRequest( question="What is 2 + 2?", - user="trustgraph", streaming=False # Non-streaming mode ) msg.properties.return_value = {"id": "test-id"} @@ -148,6 +153,7 @@ class TestAgentServiceNonStreaming: # Setup flow mock consumer = MagicMock() flow = MagicMock() + flow.workspace = "default" mock_producer = AsyncMock() diff --git a/tests/unit/test_agent/test_aggregator.py b/tests/unit/test_agent/test_aggregator.py index afb19499..87a9e3bc 100644 --- a/tests/unit/test_agent/test_aggregator.py +++ 
b/tests/unit/test_agent/test_aggregator.py @@ -11,13 +11,12 @@ from trustgraph.schema import AgentRequest, AgentStep from trustgraph.agent.orchestrator.aggregator import Aggregator -def _make_request(question="Test question", user="testuser", +def _make_request(question="Test question", collection="default", streaming=False, session_id="parent-session", task_type="research", framing="test framing", conversation_id="conv-1"): return AgentRequest( question=question, - user=user, collection=collection, streaming=streaming, session_id=session_id, @@ -127,7 +126,6 @@ class TestBuildSynthesisRequest: req = agg.build_synthesis_request( "corr-1", original_question="Original question", - user="testuser", collection="default", ) @@ -148,7 +146,7 @@ class TestBuildSynthesisRequest: agg.record_completion("corr-1", "goal-b", "answer-b") req = agg.build_synthesis_request( - "corr-1", "question", "user", "default", + "corr-1", "question", "default", ) # Last history step should be the synthesis step @@ -168,7 +166,7 @@ class TestBuildSynthesisRequest: agg.record_completion("corr-1", "goal-a", "answer-a") agg.build_synthesis_request( - "corr-1", "question", "user", "default", + "corr-1", "question", "default", ) # Entry should be removed @@ -178,7 +176,7 @@ class TestBuildSynthesisRequest: agg = Aggregator() with pytest.raises(RuntimeError, match="No results"): agg.build_synthesis_request( - "unknown", "question", "user", "default", + "unknown", "question", "default", ) diff --git a/tests/unit/test_agent/test_completion_dispatch.py b/tests/unit/test_agent/test_completion_dispatch.py index 8c01f126..0d28d168 100644 --- a/tests/unit/test_agent/test_completion_dispatch.py +++ b/tests/unit/test_agent/test_completion_dispatch.py @@ -15,7 +15,6 @@ from trustgraph.agent.orchestrator.aggregator import Aggregator def _make_request(**kwargs): defaults = dict( question="Test question", - user="testuser", collection="default", ) defaults.update(kwargs) @@ -130,7 +129,6 @@ class 
TestAggregatorIntegration: synth = agg.build_synthesis_request( "corr-1", original_question="Original question", - user="testuser", collection="default", ) @@ -160,7 +158,7 @@ class TestAggregatorIntegration: agg.record_completion("corr-1", "goal", "answer") synth = agg.build_synthesis_request( - "corr-1", "question", "user", "default", + "corr-1", "question", "default", ) # correlation_id must be empty so it's not intercepted diff --git a/tests/unit/test_agent/test_orchestrator_provenance_integration.py b/tests/unit/test_agent/test_orchestrator_provenance_integration.py index 63d87ba1..7a1ec4c1 100644 --- a/tests/unit/test_agent/test_orchestrator_provenance_integration.py +++ b/tests/unit/test_agent/test_orchestrator_provenance_integration.py @@ -126,7 +126,6 @@ def make_base_request(**kwargs): state="", group=[], history=[], - user="testuser", collection="default", streaming=False, session_id="test-session-123", diff --git a/tests/unit/test_agent/test_pattern_base_subagent.py b/tests/unit/test_agent/test_pattern_base_subagent.py index 1523b592..bb176ba4 100644 --- a/tests/unit/test_agent/test_pattern_base_subagent.py +++ b/tests/unit/test_agent/test_pattern_base_subagent.py @@ -21,7 +21,6 @@ class MockProcessor: def _make_request(**kwargs): defaults = dict( question="Test question", - user="testuser", collection="default", ) defaults.update(kwargs) diff --git a/tests/unit/test_agent/test_tool_service.py b/tests/unit/test_agent/test_tool_service.py index 8bcf39ce..369a3c73 100644 --- a/tests/unit/test_agent/test_tool_service.py +++ b/tests/unit/test_agent/test_tool_service.py @@ -167,39 +167,28 @@ class TestToolServiceRequest: """Test cases for tool service request format""" def test_request_format(self): - """Test that request is properly formatted with user, config, and arguments""" - # Arrange - user = "alice" + """Test that request is properly formatted with config and arguments""" config_values = {"style": "pun", "collection": "jokes"} arguments = {"topic": 
"programming"} - # Act - simulate request building request = { - "user": user, "config": json.dumps(config_values), "arguments": json.dumps(arguments) } - # Assert - assert request["user"] == "alice" assert json.loads(request["config"]) == {"style": "pun", "collection": "jokes"} assert json.loads(request["arguments"]) == {"topic": "programming"} def test_request_with_empty_config(self): """Test request when no config values are provided""" - # Arrange - user = "bob" config_values = {} arguments = {"query": "test"} - # Act request = { - "user": user, "config": json.dumps(config_values) if config_values else "{}", "arguments": json.dumps(arguments) if arguments else "{}" } - # Assert assert request["config"] == "{}" assert json.loads(request["arguments"]) == {"query": "test"} @@ -386,18 +375,13 @@ class TestJokeServiceLogic: assert map_topic_to_category("random topic") == "default" assert map_topic_to_category("") == "default" - def test_joke_response_personalization(self): - """Test that joke responses include user personalization""" - # Arrange - user = "alice" + def test_joke_response_format(self): + """Test that joke response is formatted as expected""" style = "pun" joke = "Why do programmers prefer dark mode? Because light attracts bugs!" - # Act - response = f"Hey {user}! Here's a {style} for you:\n\n{joke}" + response = f"Here's a {style} for you:\n\n{joke}" - # Assert - assert "Hey alice!" 
in response assert "pun" in response assert joke in response @@ -439,20 +423,14 @@ class TestDynamicToolServiceBase: def test_request_parsing(self): """Test parsing of incoming request""" - # Arrange request_data = { - "user": "alice", "config": '{"style": "pun"}', "arguments": '{"topic": "programming"}' } - # Act - user = request_data.get("user", "trustgraph") config = json.loads(request_data["config"]) if request_data["config"] else {} arguments = json.loads(request_data["arguments"]) if request_data["arguments"] else {} - # Assert - assert user == "alice" assert config == {"style": "pun"} assert arguments == {"topic": "programming"} diff --git a/tests/unit/test_agent/test_tool_service_lifecycle.py b/tests/unit/test_agent/test_tool_service_lifecycle.py index 65cdb542..874ef0e6 100644 --- a/tests/unit/test_agent/test_tool_service_lifecycle.py +++ b/tests/unit/test_agent/test_tool_service_lifecycle.py @@ -1,6 +1,6 @@ """ Tests for tool service lifecycle, invoke contract, streaming responses, -multi-tenancy, and error propagation. +and error propagation. Tests the actual DynamicToolService, ToolService, and ToolServiceClient classes rather than plain dicts. 
@@ -31,7 +31,7 @@ class TestDynamicToolServiceInvokeContract: svc = DynamicToolService.__new__(DynamicToolService) with pytest.raises(NotImplementedError): - await svc.invoke("user", {}, {}) + await svc.invoke({}, {}) @pytest.mark.asyncio async def test_on_request_calls_invoke_with_parsed_args(self): @@ -44,8 +44,8 @@ class TestDynamicToolServiceInvokeContract: calls = [] - async def tracking_invoke(user, config, arguments): - calls.append({"user": user, "config": config, "arguments": arguments}) + async def tracking_invoke(config, arguments): + calls.append({"config": config, "arguments": arguments}) return "ok" svc.invoke = tracking_invoke @@ -56,7 +56,6 @@ class TestDynamicToolServiceInvokeContract: msg = MagicMock() msg.value.return_value = ToolServiceRequest( - user="alice", config='{"style": "pun"}', arguments='{"topic": "cats"}', ) @@ -65,39 +64,9 @@ class TestDynamicToolServiceInvokeContract: await svc.on_request(msg, MagicMock(), None) assert len(calls) == 1 - assert calls[0]["user"] == "alice" assert calls[0]["config"] == {"style": "pun"} assert calls[0]["arguments"] == {"topic": "cats"} - @pytest.mark.asyncio - async def test_on_request_empty_user_defaults_to_trustgraph(self): - """Empty user field should default to 'trustgraph'.""" - from trustgraph.base.dynamic_tool_service import DynamicToolService - - svc = DynamicToolService.__new__(DynamicToolService) - svc.id = "test-svc" - svc.producer = AsyncMock() - - received_user = None - - async def capture_invoke(user, config, arguments): - nonlocal received_user - received_user = user - return "ok" - - svc.invoke = capture_invoke - - if not hasattr(DynamicToolService, "tool_service_metric"): - DynamicToolService.tool_service_metric = MagicMock() - - msg = MagicMock() - msg.value.return_value = ToolServiceRequest(user="", config="", arguments="") - msg.properties.return_value = {"id": "req-2"} - - await svc.on_request(msg, MagicMock(), None) - - assert received_user == "trustgraph" - @pytest.mark.asyncio 
async def test_on_request_string_response_sent_directly(self): """String return from invoke → response field is the string.""" @@ -107,7 +76,7 @@ class TestDynamicToolServiceInvokeContract: svc.id = "test-svc" svc.producer = AsyncMock() - async def string_invoke(user, config, arguments): + async def string_invoke(config, arguments): return "hello world" svc.invoke = string_invoke @@ -116,7 +85,7 @@ class TestDynamicToolServiceInvokeContract: DynamicToolService.tool_service_metric = MagicMock() msg = MagicMock() - msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") + msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}") msg.properties.return_value = {"id": "r1"} await svc.on_request(msg, MagicMock(), None) @@ -136,7 +105,7 @@ class TestDynamicToolServiceInvokeContract: svc.id = "test-svc" svc.producer = AsyncMock() - async def dict_invoke(user, config, arguments): + async def dict_invoke(config, arguments): return {"result": 42} svc.invoke = dict_invoke @@ -145,7 +114,7 @@ class TestDynamicToolServiceInvokeContract: DynamicToolService.tool_service_metric = MagicMock() msg = MagicMock() - msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") + msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}") msg.properties.return_value = {"id": "r2"} await svc.on_request(msg, MagicMock(), None) @@ -162,13 +131,13 @@ class TestDynamicToolServiceInvokeContract: svc.id = "test-svc" svc.producer = AsyncMock() - async def failing_invoke(user, config, arguments): + async def failing_invoke(config, arguments): raise ValueError("bad input") svc.invoke = failing_invoke msg = MagicMock() - msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") + msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}") msg.properties.return_value = {"id": "r3"} await svc.on_request(msg, MagicMock(), None) @@ -188,13 +157,13 @@ class 
TestDynamicToolServiceInvokeContract: svc.id = "test-svc" svc.producer = AsyncMock() - async def rate_limited_invoke(user, config, arguments): + async def rate_limited_invoke(config, arguments): raise TooManyRequests("rate limited") svc.invoke = rate_limited_invoke msg = MagicMock() - msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") + msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}") msg.properties.return_value = {"id": "r4"} with pytest.raises(TooManyRequests): @@ -209,7 +178,7 @@ class TestDynamicToolServiceInvokeContract: svc.id = "test-svc" svc.producer = AsyncMock() - async def ok_invoke(user, config, arguments): + async def ok_invoke(config, arguments): return "ok" svc.invoke = ok_invoke @@ -218,7 +187,7 @@ class TestDynamicToolServiceInvokeContract: DynamicToolService.tool_service_metric = MagicMock() msg = MagicMock() - msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") + msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}") msg.properties.return_value = {"id": "unique-42"} await svc.on_request(msg, MagicMock(), None) @@ -241,7 +210,7 @@ class TestToolServiceOnRequest: svc = ToolService.__new__(ToolService) svc.id = "test-tool" - async def mock_invoke(name, params): + async def mock_invoke(workspace, name, params): return "tool result" svc.invoke_tool = mock_invoke @@ -260,6 +229,7 @@ class TestToolServiceOnRequest: flow_callable.producer = {"response": mock_response_pub} flow_callable.name = "test-flow" + flow_callable.workspace = "default" msg = MagicMock() msg.value.return_value = ToolRequest(name="my-tool", parameters='{"key": "val"}') @@ -280,7 +250,7 @@ class TestToolServiceOnRequest: svc = ToolService.__new__(ToolService) svc.id = "test-tool" - async def mock_invoke(name, params): + async def mock_invoke(workspace, name, params): return {"data": [1, 2, 3]} svc.invoke_tool = mock_invoke @@ -298,6 +268,7 @@ class TestToolServiceOnRequest: 
flow_callable.producer = {"response": mock_response_pub} flow_callable.name = "test-flow" + flow_callable.workspace = "default" msg = MagicMock() msg.value.return_value = ToolRequest(name="my-tool", parameters="{}") @@ -317,7 +288,7 @@ class TestToolServiceOnRequest: svc = ToolService.__new__(ToolService) svc.id = "test-tool" - async def failing_invoke(name, params): + async def failing_invoke(workspace, name, params): raise RuntimeError("tool broke") svc.invoke_tool = failing_invoke @@ -330,6 +301,7 @@ class TestToolServiceOnRequest: flow_callable.producer = {"response": mock_response_pub} flow_callable.name = "test-flow" + flow_callable.workspace = "default" msg = MagicMock() msg.value.return_value = ToolRequest(name="my-tool", parameters="{}") @@ -350,7 +322,7 @@ class TestToolServiceOnRequest: svc = ToolService.__new__(ToolService) svc.id = "test-tool" - async def rate_limited(name, params): + async def rate_limited(workspace, name, params): raise TooManyRequests("slow down") svc.invoke_tool = rate_limited @@ -362,6 +334,7 @@ class TestToolServiceOnRequest: flow = MagicMock() flow.producer = {"response": AsyncMock()} flow.name = "test-flow" + flow.workspace = "default" with pytest.raises(TooManyRequests): await svc.on_request(msg, MagicMock(), flow) @@ -376,7 +349,8 @@ class TestToolServiceOnRequest: received = {} - async def capture_invoke(name, params): + async def capture_invoke(workspace, name, params): + received["workspace"] = workspace received["name"] = name received["params"] = params return "ok" @@ -390,6 +364,7 @@ class TestToolServiceOnRequest: flow = lambda name: mock_pub flow.producer = {"response": mock_pub} flow.name = "f" + flow.workspace = "default" msg = MagicMock() msg.value.return_value = ToolRequest( @@ -421,7 +396,6 @@ class TestToolServiceClientCall: )) result = await client.call( - user="alice", config={"style": "pun"}, arguments={"topic": "cats"}, ) @@ -430,7 +404,6 @@ class TestToolServiceClientCall: req = 
client.request.call_args[0][0] assert isinstance(req, ToolServiceRequest) - assert req.user == "alice" assert json.loads(req.config) == {"style": "pun"} assert json.loads(req.arguments) == {"topic": "cats"} @@ -446,7 +419,7 @@ class TestToolServiceClientCall: )) with pytest.raises(RuntimeError, match="service down"): - await client.call(user="u", config={}, arguments={}) + await client.call(config={}, arguments={}) @pytest.mark.asyncio async def test_call_empty_config_sends_empty_json(self): @@ -458,7 +431,7 @@ class TestToolServiceClientCall: error=None, response="ok", )) - await client.call(user="u", config=None, arguments=None) + await client.call(config=None, arguments=None) req = client.request.call_args[0][0] assert req.config == "{}" @@ -474,7 +447,7 @@ class TestToolServiceClientCall: error=None, response="ok", )) - await client.call(user="u", config={}, arguments={}, timeout=30) + await client.call(config={}, arguments={}, timeout=30) _, kwargs = client.request.call_args assert kwargs["timeout"] == 30 @@ -509,7 +482,7 @@ class TestToolServiceClientStreaming: received.append(text) result = await client.call_streaming( - user="u", config={}, arguments={}, callback=callback, + config={}, arguments={}, callback=callback, ) assert result == "chunk1chunk2" @@ -534,7 +507,7 @@ class TestToolServiceClientStreaming: with pytest.raises(RuntimeError, match="stream failed"): await client.call_streaming( - user="u", config={}, arguments={}, + config={}, arguments={}, callback=AsyncMock(), ) @@ -564,61 +537,9 @@ class TestToolServiceClientStreaming: received.append(text) result = await client.call_streaming( - user="u", config={}, arguments={}, callback=callback, + config={}, arguments={}, callback=callback, ) # Empty response is falsy, so callback shouldn't be called for it assert result == "data" assert received == ["data"] - - -# --------------------------------------------------------------------------- -# Multi-tenancy -# 
--------------------------------------------------------------------------- - -class TestMultiTenancy: - - @pytest.mark.asyncio - async def test_user_propagated_to_invoke(self): - """User from request should reach the invoke method.""" - from trustgraph.base.dynamic_tool_service import DynamicToolService - - svc = DynamicToolService.__new__(DynamicToolService) - svc.id = "test" - svc.producer = AsyncMock() - - users_seen = [] - - async def tracking(user, config, arguments): - users_seen.append(user) - return "ok" - - svc.invoke = tracking - - if not hasattr(DynamicToolService, "tool_service_metric"): - DynamicToolService.tool_service_metric = MagicMock() - - for u in ["tenant-a", "tenant-b", "tenant-c"]: - msg = MagicMock() - msg.value.return_value = ToolServiceRequest( - user=u, config="{}", arguments="{}", - ) - msg.properties.return_value = {"id": f"req-{u}"} - await svc.on_request(msg, MagicMock(), None) - - assert users_seen == ["tenant-a", "tenant-b", "tenant-c"] - - @pytest.mark.asyncio - async def test_client_sends_user_in_request(self): - """ToolServiceClient.call should include user in request.""" - from trustgraph.base.tool_service_client import ToolServiceClient - - client = ToolServiceClient.__new__(ToolServiceClient) - client.request = AsyncMock(return_value=ToolServiceResponse( - error=None, response="ok", - )) - - await client.call(user="isolated-tenant", config={}, arguments={}) - - req = client.request.call_args[0][0] - assert req.user == "isolated-tenant" diff --git a/tests/unit/test_base/test_async_processor_config.py b/tests/unit/test_base/test_async_processor_config.py index f1a83fef..3dffd775 100644 --- a/tests/unit/test_base/test_async_processor_config.py +++ b/tests/unit/test_base/test_async_processor_config.py @@ -1,17 +1,14 @@ """ Tests for AsyncProcessor config notify pattern: - register_config_handler with types filtering -- on_config_notify version comparison and type matching -- fetch_config with short-lived client -- 
fetch_and_apply_config retry logic +- on_config_notify version comparison, type/workspace matching +- fetch_and_apply_config retry logic over per-workspace fetches """ import pytest from unittest.mock import AsyncMock, MagicMock, patch, Mock -from trustgraph.schema import Term, IRI, LITERAL -# Patch heavy dependencies before importing AsyncProcessor @pytest.fixture def processor(): """Create an AsyncProcessor with mocked dependencies.""" @@ -68,6 +65,13 @@ class TestRegisterConfigHandler: assert len(processor.config_handlers) == 2 +def _notify_msg(version, changes): + """Build a Mock config-notify message with given version and changes dict.""" + msg = Mock() + msg.value.return_value = Mock(version=version, changes=changes) + return msg + + class TestOnConfigNotify: @pytest.mark.asyncio @@ -77,9 +81,7 @@ class TestOnConfigNotify: handler = AsyncMock() processor.register_config_handler(handler, types=["prompt"]) - msg = Mock() - msg.value.return_value = Mock(version=3, types=["prompt"]) - + msg = _notify_msg(3, {"prompt": ["default"]}) await processor.on_config_notify(msg, None, None) handler.assert_not_called() @@ -91,9 +93,7 @@ class TestOnConfigNotify: handler = AsyncMock() processor.register_config_handler(handler, types=["prompt"]) - msg = Mock() - msg.value.return_value = Mock(version=5, types=["prompt"]) - + msg = _notify_msg(5, {"prompt": ["default"]}) await processor.on_config_notify(msg, None, None) handler.assert_not_called() @@ -105,9 +105,7 @@ class TestOnConfigNotify: handler = AsyncMock() processor.register_config_handler(handler, types=["prompt"]) - msg = Mock() - msg.value.return_value = Mock(version=2, types=["schema"]) - + msg = _notify_msg(2, {"schema": ["default"]}) await processor.on_config_notify(msg, None, None) handler.assert_not_called() @@ -121,40 +119,36 @@ class TestOnConfigNotify: handler = AsyncMock() processor.register_config_handler(handler, types=["prompt"]) - # Mock fetch_config - mock_config = {"prompt": {"key": "value"}} + 
mock_client = AsyncMock() with patch.object( - processor, 'fetch_config', + processor, '_create_config_client', return_value=mock_client + ), patch.object( + processor, '_fetch_type_workspace', new_callable=AsyncMock, - return_value=(mock_config, 2) + return_value={"key": "value"}, ): - msg = Mock() - msg.value.return_value = Mock(version=2, types=["prompt"]) - + msg = _notify_msg(2, {"prompt": ["default"]}) await processor.on_config_notify(msg, None, None) - handler.assert_called_once_with(mock_config, 2) + handler.assert_called_once_with( + "default", {"prompt": {"key": "value"}}, 2 + ) assert processor.config_version == 2 @pytest.mark.asyncio - async def test_handler_without_types_always_called(self, processor): + async def test_handler_without_types_ignored_on_notify(self, processor): + """Handlers registered without types never fire on notifications.""" processor.config_version = 1 handler = AsyncMock() - processor.register_config_handler(handler) # No types = all + processor.register_config_handler(handler) # No types - mock_config = {"anything": {}} - with patch.object( - processor, 'fetch_config', - new_callable=AsyncMock, - return_value=(mock_config, 2) - ): - msg = Mock() - msg.value.return_value = Mock(version=2, types=["whatever"]) + msg = _notify_msg(2, {"whatever": ["default"]}) + await processor.on_config_notify(msg, None, None) - await processor.on_config_notify(msg, None, None) - - handler.assert_called_once_with(mock_config, 2) + handler.assert_not_called() + # Version still advances past the notify + assert processor.config_version == 2 @pytest.mark.asyncio async def test_mixed_handlers_type_filtering(self, processor): @@ -168,156 +162,149 @@ class TestOnConfigNotify: processor.register_config_handler(schema_handler, types=["schema"]) processor.register_config_handler(all_handler) - mock_config = {"prompt": {}} + mock_client = AsyncMock() with patch.object( - processor, 'fetch_config', + processor, '_create_config_client', 
return_value=mock_client + ), patch.object( + processor, '_fetch_type_workspace', new_callable=AsyncMock, - return_value=(mock_config, 2) + return_value={}, ): - msg = Mock() - msg.value.return_value = Mock(version=2, types=["prompt"]) - + msg = _notify_msg(2, {"prompt": ["default"]}) await processor.on_config_notify(msg, None, None) - prompt_handler.assert_called_once() + prompt_handler.assert_called_once_with( + "default", {"prompt": {}}, 2 + ) schema_handler.assert_not_called() - all_handler.assert_called_once() + all_handler.assert_not_called() @pytest.mark.asyncio - async def test_empty_types_invokes_all(self, processor): - """Empty types list (startup signal) should invoke all handlers.""" + async def test_multi_workspace_notify_invokes_handler_per_ws( + self, processor + ): + """Notify affecting multiple workspaces invokes handler once per workspace.""" processor.config_version = 1 - h1 = AsyncMock() - h2 = AsyncMock() - processor.register_config_handler(h1, types=["prompt"]) - processor.register_config_handler(h2, types=["schema"]) + handler = AsyncMock() + processor.register_config_handler(handler, types=["prompt"]) - mock_config = {} + mock_client = AsyncMock() with patch.object( - processor, 'fetch_config', + processor, '_create_config_client', return_value=mock_client + ), patch.object( + processor, '_fetch_type_workspace', new_callable=AsyncMock, - return_value=(mock_config, 2) + return_value={}, ): - msg = Mock() - msg.value.return_value = Mock(version=2, types=[]) - + msg = _notify_msg(2, {"prompt": ["ws1", "ws2"]}) await processor.on_config_notify(msg, None, None) - h1.assert_called_once() - h2.assert_called_once() + assert handler.call_count == 2 + called_workspaces = {c.args[0] for c in handler.call_args_list} + assert called_workspaces == {"ws1", "ws2"} @pytest.mark.asyncio async def test_fetch_failure_handled(self, processor): processor.config_version = 1 handler = AsyncMock() - processor.register_config_handler(handler) + 
processor.register_config_handler(handler, types=["prompt"]) + mock_client = AsyncMock() with patch.object( - processor, 'fetch_config', + processor, '_create_config_client', return_value=mock_client + ), patch.object( + processor, '_fetch_type_workspace', new_callable=AsyncMock, - side_effect=RuntimeError("Connection failed") + side_effect=RuntimeError("Connection failed"), ): - msg = Mock() - msg.value.return_value = Mock(version=2, types=["prompt"]) - + msg = _notify_msg(2, {"prompt": ["default"]}) # Should not raise await processor.on_config_notify(msg, None, None) handler.assert_not_called() -class TestFetchConfig: - - @pytest.mark.asyncio - async def test_fetch_returns_config_and_version(self, processor): - mock_resp = Mock() - mock_resp.error = None - mock_resp.config = {"prompt": {"key": "val"}} - mock_resp.version = 42 - - mock_client = AsyncMock() - mock_client.request.return_value = mock_resp - - with patch.object( - processor, '_create_config_client', return_value=mock_client - ): - config, version = await processor.fetch_config() - - assert config == {"prompt": {"key": "val"}} - assert version == 42 - mock_client.stop.assert_called_once() - - @pytest.mark.asyncio - async def test_fetch_raises_on_error_response(self, processor): - mock_resp = Mock() - mock_resp.error = Mock(message="not found") - mock_resp.config = {} - mock_resp.version = 0 - - mock_client = AsyncMock() - mock_client.request.return_value = mock_resp - - with patch.object( - processor, '_create_config_client', return_value=mock_client - ): - with pytest.raises(RuntimeError, match="Config error"): - await processor.fetch_config() - - mock_client.stop.assert_called_once() - - @pytest.mark.asyncio - async def test_fetch_stops_client_on_exception(self, processor): - mock_client = AsyncMock() - mock_client.request.side_effect = TimeoutError("timeout") - - with patch.object( - processor, '_create_config_client', return_value=mock_client - ): - with pytest.raises(TimeoutError): - await 
processor.fetch_config() - - mock_client.stop.assert_called_once() - - class TestFetchAndApplyConfig: @pytest.mark.asyncio - async def test_applies_config_to_all_handlers(self, processor): - h1 = AsyncMock() - h2 = AsyncMock() - processor.register_config_handler(h1, types=["prompt"]) - processor.register_config_handler(h2, types=["schema"]) + async def test_applies_config_per_workspace(self, processor): + """Startup fetch invokes handler once per workspace affected.""" + h = AsyncMock() + processor.register_config_handler(h, types=["prompt"]) + + mock_client = AsyncMock() + + async def fake_fetch_all(client, config_type): + return { + "ws1": {"k": "v1"}, + "ws2": {"k": "v2"}, + }, 10 - mock_config = {"prompt": {}, "schema": {}} with patch.object( - processor, 'fetch_config', - new_callable=AsyncMock, - return_value=(mock_config, 10) + processor, '_create_config_client', return_value=mock_client + ), patch.object( + processor, '_fetch_type_all_workspaces', + new=fake_fetch_all, ): await processor.fetch_and_apply_config() - # On startup, all handlers are invoked regardless of type - h1.assert_called_once_with(mock_config, 10) - h2.assert_called_once_with(mock_config, 10) + assert h.call_count == 2 + call_map = {c.args[0]: c.args[1] for c in h.call_args_list} + assert call_map["ws1"] == {"prompt": {"k": "v1"}} + assert call_map["ws2"] == {"prompt": {"k": "v2"}} assert processor.config_version == 10 @pytest.mark.asyncio - async def test_retries_on_failure(self, processor): - call_count = 0 - mock_config = {"prompt": {}} + async def test_handler_without_types_skipped_at_startup(self, processor): + """Handlers registered without types fetch nothing at startup.""" + typed = AsyncMock() + untyped = AsyncMock() + processor.register_config_handler(typed, types=["prompt"]) + processor.register_config_handler(untyped) - async def mock_fetch(): + mock_client = AsyncMock() + + async def fake_fetch_all(client, config_type): + return {"default": {}}, 1 + + with patch.object( + 
processor, '_create_config_client', return_value=mock_client + ), patch.object( + processor, '_fetch_type_all_workspaces', + new=fake_fetch_all, + ): + await processor.fetch_and_apply_config() + + typed.assert_called_once() + untyped.assert_not_called() + + @pytest.mark.asyncio + async def test_retries_on_failure(self, processor): + h = AsyncMock() + processor.register_config_handler(h, types=["prompt"]) + + call_count = 0 + + async def fake_fetch_all(client, config_type): nonlocal call_count call_count += 1 if call_count < 3: raise RuntimeError("not ready") - return mock_config, 5 + return {"default": {"k": "v"}}, 5 - with patch.object(processor, 'fetch_config', side_effect=mock_fetch), \ - patch('asyncio.sleep', new_callable=AsyncMock): + mock_client = AsyncMock() + with patch.object( + processor, '_create_config_client', return_value=mock_client + ), patch.object( + processor, '_fetch_type_all_workspaces', + new=fake_fetch_all, + ), patch('asyncio.sleep', new_callable=AsyncMock): await processor.fetch_and_apply_config() assert call_count == 3 assert processor.config_version == 5 + h.assert_called_once_with( + "default", {"prompt": {"k": "v"}}, 5 + ) diff --git a/tests/unit/test_base/test_document_embeddings_client.py b/tests/unit/test_base/test_document_embeddings_client.py index 705f2bd1..ff9e67e9 100644 --- a/tests/unit/test_base/test_document_embeddings_client.py +++ b/tests/unit/test_base/test_document_embeddings_client.py @@ -33,7 +33,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): result = await client.query( vector=vector, limit=10, - user="test_user", collection="test_collection", timeout=30 ) @@ -45,7 +44,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): assert isinstance(call_args, DocumentEmbeddingsRequest) assert call_args.vector == vector assert call_args.limit == 10 - assert call_args.user == "test_user" assert call_args.collection == "test_collection" 
@patch('trustgraph.base.request_response_spec.RequestResponse.__init__') @@ -104,7 +102,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): client.request.assert_called_once() call_args = client.request.call_args[0][0] assert call_args.limit == 20 # Default limit - assert call_args.user == "trustgraph" # Default user assert call_args.collection == "default" # Default collection @patch('trustgraph.base.request_response_spec.RequestResponse.__init__') diff --git a/tests/unit/test_base/test_flow_base_modules.py b/tests/unit/test_base/test_flow_base_modules.py index 5bbd7a18..758edcff 100644 --- a/tests/unit/test_base/test_flow_base_modules.py +++ b/tests/unit/test_base/test_flow_base_modules.py @@ -40,10 +40,11 @@ def test_flow_initialization_calls_registered_specs(): spec_two = MagicMock() processor = MagicMock(specifications=[spec_one, spec_two]) - flow = Flow("processor-1", "flow-a", processor, {"answer": 42}) + flow = Flow("processor-1", "flow-a", "default", processor, {"answer": 42}) assert flow.id == "processor-1" assert flow.name == "flow-a" + assert flow.workspace == "default" assert flow.producer == {} assert flow.consumer == {} assert flow.parameter == {} @@ -54,7 +55,7 @@ def test_flow_initialization_calls_registered_specs(): def test_flow_start_and_stop_visit_all_consumers(): consumer_one = AsyncMock() consumer_two = AsyncMock() - flow = Flow("processor-1", "flow-a", MagicMock(specifications=[]), {}) + flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {}) flow.consumer = {"one": consumer_one, "two": consumer_two} asyncio.run(flow.start()) @@ -67,7 +68,7 @@ def test_flow_start_and_stop_visit_all_consumers(): def test_flow_call_returns_values_in_priority_order(): - flow = Flow("processor-1", "flow-a", MagicMock(specifications=[]), {}) + flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {}) flow.producer["shared"] = "producer-value" flow.consumer["consumer-only"] = "consumer-value" 
flow.consumer["shared"] = "consumer-value" diff --git a/tests/unit/test_base/test_flow_parameter_specs.py b/tests/unit/test_base/test_flow_parameter_specs.py index c813d66c..da7e9736 100644 --- a/tests/unit/test_base/test_flow_parameter_specs.py +++ b/tests/unit/test_base/test_flow_parameter_specs.py @@ -172,10 +172,10 @@ class TestFlowParameterSpecs(IsolatedAsyncioTestCase): flow_defn = {'config': 'test-config'} # Act - await processor.start_flow(flow_name, flow_defn) + await processor.start_flow("default", flow_name, flow_defn) # Assert - Flow should be created with access to processor specifications - mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn) + mock_flow_class.assert_called_once_with('test-processor', flow_name, "default", processor, flow_defn) # The flow should have access to the processor's specifications # (The exact mechanism depends on Flow implementation) diff --git a/tests/unit/test_base/test_flow_processor.py b/tests/unit/test_base/test_flow_processor.py index 36a05ec2..350a8b43 100644 --- a/tests/unit/test_base/test_flow_processor.py +++ b/tests/unit/test_base/test_flow_processor.py @@ -78,11 +78,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): flow_name = 'test-flow' flow_defn = {'config': 'test-config'} - await processor.start_flow(flow_name, flow_defn) + await processor.start_flow("default", flow_name, flow_defn) - assert flow_name in processor.flows + assert ("default", flow_name) in processor.flows mock_flow_class.assert_called_once_with( - 'test-processor', flow_name, processor, flow_defn + 'test-processor', flow_name, "default", processor, flow_defn ) mock_flow.start.assert_called_once() @@ -103,11 +103,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): mock_flow_class.return_value = mock_flow flow_name = 'test-flow' - await processor.start_flow(flow_name, {'config': 'test-config'}) + await processor.start_flow("default", flow_name, {'config': 'test-config'}) - await 
processor.stop_flow(flow_name) + await processor.stop_flow("default", flow_name) - assert flow_name not in processor.flows + assert ("default", flow_name) not in processor.flows mock_flow.stop.assert_called_once() @with_async_processor_patches @@ -120,7 +120,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): processor = FlowProcessor(**config) - await processor.stop_flow('non-existent-flow') + await processor.stop_flow("default", 'non-existent-flow') assert processor.flows == {} @@ -146,11 +146,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): } } - await processor.on_configure_flows(config_data, version=1) + await processor.on_configure_flows("default", config_data, version=1) - assert 'test-flow' in processor.flows + assert ("default", 'test-flow') in processor.flows mock_flow_class.assert_called_once_with( - 'test-processor', 'test-flow', processor, + 'test-processor', 'test-flow', "default", processor, {'config': 'test-config'} ) mock_flow.start.assert_called_once() @@ -171,7 +171,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): } } - await processor.on_configure_flows(config_data, version=1) + await processor.on_configure_flows("default", config_data, version=1) assert processor.flows == {} @@ -189,7 +189,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): 'other-data': 'some-value' } - await processor.on_configure_flows(config_data, version=1) + await processor.on_configure_flows("default", config_data, version=1) assert processor.flows == {} @@ -216,7 +216,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): } } - await processor.on_configure_flows(config_data1, version=1) + await processor.on_configure_flows("default", config_data1, version=1) config_data2 = { 'processor:test-processor': { @@ -224,12 +224,12 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): } } - await processor.on_configure_flows(config_data2, version=2) + await processor.on_configure_flows("default", config_data2, version=2) - 
assert 'flow1' not in processor.flows + assert ("default", 'flow1') not in processor.flows mock_flow1.stop.assert_called_once() - assert 'flow2' in processor.flows + assert ("default", 'flow2') in processor.flows mock_flow2.start.assert_called_once() @with_async_processor_patches diff --git a/tests/unit/test_chunking/conftest.py b/tests/unit/test_chunking/conftest.py index 31dab77d..c1f9ae33 100644 --- a/tests/unit/test_chunking/conftest.py +++ b/tests/unit/test_chunking/conftest.py @@ -28,7 +28,6 @@ def sample_text_document(): """Sample document with moderate length text.""" metadata = Metadata( id="test-doc-1", - user="test-user", collection="test-collection" ) text = "The quick brown fox jumps over the lazy dog. " * 20 @@ -43,7 +42,6 @@ def long_text_document(): """Long document for testing multiple chunks.""" metadata = Metadata( id="test-doc-long", - user="test-user", collection="test-collection" ) # Create a long text that will definitely be chunked @@ -59,7 +57,6 @@ def unicode_text_document(): """Document with various unicode characters.""" metadata = Metadata( id="test-doc-unicode", - user="test-user", collection="test-collection" ) text = """ @@ -84,7 +81,6 @@ def empty_text_document(): """Empty document for edge case testing.""" metadata = Metadata( id="test-doc-empty", - user="test-user", collection="test-collection" ) return TextDocument( diff --git a/tests/unit/test_chunking/test_recursive_chunker.py b/tests/unit/test_chunking/test_recursive_chunker.py index d1a5d247..74178ab4 100644 --- a/tests/unit/test_chunking/test_recursive_chunker.py +++ b/tests/unit/test_chunking/test_recursive_chunker.py @@ -185,7 +185,6 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): mock_text_doc = MagicMock() mock_text_doc.metadata = Metadata( id="test-doc-123", - user="test-user", collection="test-collection" ) mock_text_doc.text = b"This is test document content" diff --git a/tests/unit/test_chunking/test_token_chunker.py 
b/tests/unit/test_chunking/test_token_chunker.py index dba4ca94..568b335f 100644 --- a/tests/unit/test_chunking/test_token_chunker.py +++ b/tests/unit/test_chunking/test_token_chunker.py @@ -185,7 +185,6 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): mock_text_doc = MagicMock() mock_text_doc.metadata = Metadata( id="test-doc-456", - user="test-user", collection="test-collection" ) mock_text_doc.text = b"This is test document content for token chunking" diff --git a/tests/unit/test_cli/test_config_commands.py b/tests/unit/test_cli/test_config_commands.py index 68ae1a54..b5b74688 100644 --- a/tests/unit/test_cli/test_config_commands.py +++ b/tests/unit/test_cli/test_config_commands.py @@ -109,7 +109,8 @@ class TestListConfigItems: url='http://custom.com', config_type='prompt', format_type='json', - token=None + token=None, + workspace='default' ) def test_list_main_uses_defaults(self): @@ -128,7 +129,8 @@ class TestListConfigItems: url='http://localhost:8088/', config_type='prompt', format_type='text', - token=None + token=None, + workspace='default' ) @@ -196,7 +198,8 @@ class TestGetConfigItem: config_type='prompt', key='template-1', format_type='json', - token=None + token=None, + workspace='default' ) @@ -253,7 +256,8 @@ class TestPutConfigItem: config_type='prompt', key='new-template', value='Custom prompt: {input}', - token=None + token=None, + workspace='default' ) def test_put_main_with_stdin_arg(self): @@ -278,7 +282,8 @@ class TestPutConfigItem: config_type='prompt', key='stdin-template', value=stdin_content, - token=None + token=None, + workspace='default' ) def test_put_main_mutually_exclusive_args(self): @@ -334,7 +339,8 @@ class TestDeleteConfigItem: url='http://custom.com', config_type='prompt', key='old-template', - token=None + token=None, + workspace='default' ) diff --git a/tests/unit/test_cli/test_load_knowledge.py b/tests/unit/test_cli/test_load_knowledge.py index 63045ef9..a1588e85 100644 --- 
a/tests/unit/test_cli/test_load_knowledge.py +++ b/tests/unit/test_cli/test_load_knowledge.py @@ -48,7 +48,7 @@ def knowledge_loader(): return KnowledgeLoader( files=["test.ttl"], flow="test-flow", - user="test-user", + workspace="test-user", collection="test-collection", document_id="test-doc-123", url="http://test.example.com/", @@ -64,7 +64,7 @@ class TestKnowledgeLoader: loader = KnowledgeLoader( files=["file1.ttl", "file2.ttl"], flow="my-flow", - user="user1", + workspace="user1", collection="col1", document_id="doc1", url="http://example.com/", @@ -73,7 +73,7 @@ class TestKnowledgeLoader: assert loader.files == ["file1.ttl", "file2.ttl"] assert loader.flow == "my-flow" - assert loader.user == "user1" + assert loader.workspace == "user1" assert loader.collection == "col1" assert loader.document_id == "doc1" assert loader.url == "http://example.com/" @@ -126,7 +126,7 @@ ex:mary ex:knows ex:bob . loader = KnowledgeLoader( files=[f.name], flow="test-flow", - user="test-user", + workspace="test-user", collection="test-collection", document_id="test-doc", url="http://test.example.com/" @@ -151,7 +151,7 @@ ex:mary ex:knows ex:bob . loader = KnowledgeLoader( files=[temp_turtle_file], flow="test-flow", - user="test-user", + workspace="test-user", collection="test-collection", document_id="test-doc", url="http://test.example.com/", @@ -163,7 +163,8 @@ ex:mary ex:knows ex:bob . # Verify Api was created with correct parameters mock_api_class.assert_called_once_with( url="http://test.example.com/", - token="test-token" + token="test-token", + workspace="test-user" ) # Verify bulk client was obtained @@ -174,7 +175,6 @@ ex:mary ex:knows ex:bob . 
call_args = mock_bulk.import_triples.call_args assert call_args[1]['flow'] == "test-flow" assert call_args[1]['metadata']['id'] == "test-doc" - assert call_args[1]['metadata']['user'] == "test-user" assert call_args[1]['metadata']['collection'] == "test-collection" # Verify import_entity_contexts was called @@ -198,7 +198,7 @@ class TestCLIArgumentParsing: 'tg-load-knowledge', '-i', 'doc-123', '-f', 'my-flow', - '-U', 'my-user', + '-w', 'my-user', '-C', 'my-collection', '-u', 'http://custom.example.com/', '-t', 'my-token', @@ -216,7 +216,7 @@ class TestCLIArgumentParsing: token='my-token', flow='my-flow', files=['file1.ttl', 'file2.ttl'], - user='my-user', + workspace='my-user', collection='my-collection' ) @@ -242,7 +242,7 @@ class TestCLIArgumentParsing: # Verify defaults were used call_args = mock_loader_class.call_args[1] assert call_args['flow'] == 'default' - assert call_args['user'] == 'trustgraph' + assert call_args['workspace'] == 'default' assert call_args['collection'] == 'default' assert call_args['url'] == 'http://localhost:8088/' assert call_args['token'] is None @@ -287,7 +287,7 @@ class TestErrorHandling: loader = KnowledgeLoader( files=[temp_turtle_file], flow="test-flow", - user="test-user", + workspace="test-user", collection="test-collection", document_id="test-doc", url="http://test.example.com/" diff --git a/tests/unit/test_cli/test_tool_commands.py b/tests/unit/test_cli/test_tool_commands.py index 9c204614..72624d27 100644 --- a/tests/unit/test_cli/test_tool_commands.py +++ b/tests/unit/test_cli/test_tool_commands.py @@ -145,7 +145,8 @@ class TestSetToolStructuredQuery: group=None, state=None, applicable_states=None, - token=None + token=None, + workspace='default' ) def test_set_main_structured_query_no_arguments_needed(self): @@ -326,7 +327,8 @@ class TestSetToolRowEmbeddingsQuery: group=None, state=None, applicable_states=None, - token=None + token=None, + workspace='default' ) def test_valid_types_includes_row_embeddings_query(self): @@ 
-471,7 +473,7 @@ class TestShowToolsStructuredQuery: show_main() - mock_show.assert_called_once_with(url='http://custom.com', token=None) + mock_show.assert_called_once_with(url='http://custom.com', token=None, workspace='default') class TestShowToolsRowEmbeddingsQuery: diff --git a/tests/unit/test_clients/test_sync_document_embeddings_client.py b/tests/unit/test_clients/test_sync_document_embeddings_client.py index edf4ac81..6c466877 100644 --- a/tests/unit/test_clients/test_sync_document_embeddings_client.py +++ b/tests/unit/test_clients/test_sync_document_embeddings_client.py @@ -73,7 +73,6 @@ class TestSyncDocumentEmbeddingsClient: # Act result = client.request( vector=vector, - user="test_user", collection="test_collection", limit=10, timeout=300 @@ -82,7 +81,6 @@ class TestSyncDocumentEmbeddingsClient: # Assert assert result == ["chunk1", "chunk2", "chunk3"] client.call.assert_called_once_with( - user="test_user", collection="test_collection", vector=vector, limit=10, @@ -108,7 +106,6 @@ class TestSyncDocumentEmbeddingsClient: # Assert assert result == ["test_chunk"] client.call.assert_called_once_with( - user="trustgraph", collection="default", vector=vector, limit=10, diff --git a/tests/unit/test_concurrency/test_graph_rag_concurrency.py b/tests/unit/test_concurrency/test_graph_rag_concurrency.py index 8287427b..1b35a238 100644 --- a/tests/unit/test_concurrency/test_graph_rag_concurrency.py +++ b/tests/unit/test_concurrency/test_graph_rag_concurrency.py @@ -31,7 +31,6 @@ def _make_query( query = Query( rag=rag, - user="test-user", collection="test-collection", verbose=False, entity_limit=entity_limit, @@ -208,7 +207,6 @@ class TestBatchTripleQueries: assert calls[0].kwargs["p"] is None assert calls[0].kwargs["o"] is None assert calls[0].kwargs["limit"] == 15 - assert calls[0].kwargs["user"] == "test-user" assert calls[0].kwargs["collection"] == "test-collection" assert calls[0].kwargs["batch_size"] == 20 diff --git 
a/tests/unit/test_cores/test_knowledge_manager.py b/tests/unit/test_cores/test_knowledge_manager.py index 80c27fe8..d677b82f 100644 --- a/tests/unit/test_cores/test_knowledge_manager.py +++ b/tests/unit/test_cores/test_knowledge_manager.py @@ -28,10 +28,12 @@ def mock_flow_config(): """Mock flow configuration.""" mock_config = Mock() mock_config.flows = { - "test-flow": { - "interfaces": { - "triples-store": {"flow": "test-triples-queue"}, - "graph-embeddings-store": {"flow": "test-ge-queue"} + "test-user": { + "test-flow": { + "interfaces": { + "triples-store": {"flow": "test-triples-queue"}, + "graph-embeddings-store": {"flow": "test-ge-queue"} + } } } } @@ -43,7 +45,7 @@ def mock_flow_config(): def mock_request(): """Mock knowledge load request.""" request = Mock() - request.user = "test-user" + request.workspace = "test-user" request.id = "test-doc-id" request.collection = "test-collection" request.flow = "test-flow" @@ -71,7 +73,6 @@ def sample_triples(): return Triples( metadata=Metadata( id="test-doc-id", - user="test-user", collection="default", # This should be overridden ), triples=[ @@ -90,7 +91,6 @@ def sample_graph_embeddings(): return GraphEmbeddings( metadata=Metadata( id="test-doc-id", - user="test-user", collection="default", # This should be overridden ), entities=[ @@ -146,7 +146,6 @@ class TestKnowledgeManagerLoadCore: mock_triples_pub.send.assert_called_once() sent_triples = mock_triples_pub.send.call_args[0][1] assert sent_triples.metadata.collection == "test-collection" - assert sent_triples.metadata.user == "test-user" assert sent_triples.metadata.id == "test-doc-id" @pytest.mark.asyncio @@ -185,7 +184,6 @@ class TestKnowledgeManagerLoadCore: mock_ge_pub.send.assert_called_once() sent_ge = mock_ge_pub.send.call_args[0][1] assert sent_ge.metadata.collection == "test-collection" - assert sent_ge.metadata.user == "test-user" assert sent_ge.metadata.id == "test-doc-id" @pytest.mark.asyncio @@ -193,7 +191,7 @@ class TestKnowledgeManagerLoadCore: 
"""Test that load_kg_core falls back to 'default' when request.collection is None.""" # Create request with None collection mock_request = Mock() - mock_request.user = "test-user" + mock_request.workspace = "test-user" mock_request.id = "test-doc-id" mock_request.collection = None # Should fall back to "default" mock_request.flow = "test-flow" @@ -269,7 +267,7 @@ class TestKnowledgeManagerLoadCore: """Test that load_kg_core validates flow configuration before processing.""" # Request with invalid flow mock_request = Mock() - mock_request.user = "test-user" + mock_request.workspace = "test-user" mock_request.id = "test-doc-id" mock_request.collection = "test-collection" mock_request.flow = "invalid-flow" # Not in mock_flow_config.flows @@ -297,7 +295,7 @@ class TestKnowledgeManagerLoadCore: # Test missing ID mock_request = Mock() - mock_request.user = "test-user" + mock_request.workspace = "test-user" mock_request.id = None # Missing mock_request.collection = "test-collection" mock_request.flow = "test-flow" @@ -323,7 +321,7 @@ class TestKnowledgeManagerOtherMethods: async def test_get_kg_core_preserves_collection_from_store(self, knowledge_manager, sample_triples): """Test that get_kg_core preserves collection field from stored data.""" mock_request = Mock() - mock_request.user = "test-user" + mock_request.workspace = "test-user" mock_request.id = "test-doc-id" mock_respond = AsyncMock() @@ -354,7 +352,7 @@ class TestKnowledgeManagerOtherMethods: async def test_list_kg_cores(self, knowledge_manager): """Test listing knowledge cores.""" mock_request = Mock() - mock_request.user = "test-user" + mock_request.workspace = "test-user" mock_respond = AsyncMock() @@ -376,7 +374,7 @@ class TestKnowledgeManagerOtherMethods: async def test_delete_kg_core(self, knowledge_manager): """Test deleting knowledge cores.""" mock_request = Mock() - mock_request.user = "test-user" + mock_request.workspace = "test-user" mock_request.id = "test-doc-id" mock_respond = AsyncMock() diff 
--git a/tests/unit/test_decoding/test_universal_processor.py b/tests/unit/test_decoding/test_universal_processor.py index 4daa9b68..36804860 100644 --- a/tests/unit/test_decoding/test_universal_processor.py +++ b/tests/unit/test_decoding/test_universal_processor.py @@ -237,7 +237,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): # Mock message with inline data content = b"# Document Title\nBody text content." - mock_metadata = Metadata(id="test-doc", user="testuser", + mock_metadata = Metadata(id="test-doc", collection="default") mock_document = Document( metadata=mock_metadata, @@ -294,7 +294,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): # Mock message content = b"fake pdf" - mock_metadata = Metadata(id="test-doc", user="testuser", + mock_metadata = Metadata(id="test-doc", collection="default") mock_document = Document( metadata=mock_metadata, @@ -345,7 +345,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): ] content = b"fake pdf" - mock_metadata = Metadata(id="test-doc", user="testuser", + mock_metadata = Metadata(id="test-doc", collection="default") mock_document = Document( metadata=mock_metadata, diff --git a/tests/unit/test_direct/test_milvus_collection_naming.py b/tests/unit/test_direct/test_milvus_collection_naming.py index d948caff..57c00a54 100644 --- a/tests/unit/test_direct/test_milvus_collection_naming.py +++ b/tests/unit/test_direct/test_milvus_collection_naming.py @@ -12,7 +12,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_basic(self): """Test basic collection name creation""" result = make_safe_collection_name( - user="test_user", + workspace="test_user", collection="test_collection", prefix="doc" ) @@ -21,7 +21,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_with_special_characters(self): """Test collection name creation with special characters that need sanitization""" result = make_safe_collection_name( - user="user@domain.com", + workspace="user@domain.com", 
collection="test-collection.v2", prefix="entity" ) @@ -30,7 +30,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_with_unicode(self): """Test collection name creation with Unicode characters""" result = make_safe_collection_name( - user="测试用户", + workspace="测试用户", collection="colección_española", prefix="doc" ) @@ -39,7 +39,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_with_spaces(self): """Test collection name creation with spaces""" result = make_safe_collection_name( - user="test user", + workspace="test user", collection="my test collection", prefix="entity" ) @@ -48,7 +48,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_with_multiple_consecutive_special_chars(self): """Test collection name creation with multiple consecutive special characters""" result = make_safe_collection_name( - user="user@@@domain!!!", + workspace="user@@@domain!!!", collection="test---collection...v2", prefix="doc" ) @@ -57,7 +57,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_with_leading_trailing_underscores(self): """Test collection name creation with leading/trailing special characters""" result = make_safe_collection_name( - user="__test_user__", + workspace="__test_user__", collection="@@test_collection##", prefix="entity" ) @@ -66,7 +66,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_empty_user(self): """Test collection name creation with empty user (should fallback to 'default')""" result = make_safe_collection_name( - user="", + workspace="", collection="test_collection", prefix="doc" ) @@ -75,7 +75,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_empty_collection(self): """Test collection name creation with empty collection (should fallback to 'default')""" result = make_safe_collection_name( - user="test_user", + workspace="test_user", collection="", prefix="doc" ) @@ -84,7 +84,7 @@ class TestMilvusCollectionNaming: def 
test_make_safe_collection_name_both_empty(self): """Test collection name creation with both user and collection empty""" result = make_safe_collection_name( - user="", + workspace="", collection="", prefix="doc" ) @@ -93,7 +93,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_only_special_characters(self): """Test collection name creation with only special characters (should fallback to 'default')""" result = make_safe_collection_name( - user="@@@!!!", + workspace="@@@!!!", collection="---###", prefix="entity" ) @@ -102,7 +102,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_whitespace_only(self): """Test collection name creation with whitespace-only strings""" result = make_safe_collection_name( - user=" \n\t ", + workspace=" \n\t ", collection=" \r\n ", prefix="doc" ) @@ -111,7 +111,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_mixed_valid_invalid_chars(self): """Test collection name creation with mixed valid and invalid characters""" result = make_safe_collection_name( - user="user123@test", + workspace="user123@test", collection="coll_2023.v1", prefix="entity" ) @@ -147,7 +147,7 @@ class TestMilvusCollectionNaming: long_collection = "b" * 100 result = make_safe_collection_name( - user=long_user, + workspace=long_user, collection=long_collection, prefix="doc" ) @@ -159,7 +159,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_numeric_values(self): """Test collection name creation with numeric user/collection values""" result = make_safe_collection_name( - user="user123", + workspace="user123", collection="collection456", prefix="doc" ) @@ -168,7 +168,7 @@ class TestMilvusCollectionNaming: def test_make_safe_collection_name_case_sensitivity(self): """Test that collection name creation preserves case""" result = make_safe_collection_name( - user="TestUser", + workspace="TestUser", collection="TestCollection", prefix="Doc" ) diff --git 
a/tests/unit/test_embeddings/test_document_embeddings_processor.py b/tests/unit/test_embeddings/test_document_embeddings_processor.py index 9cd93c4f..314d81c3 100644 --- a/tests/unit/test_embeddings/test_document_embeddings_processor.py +++ b/tests/unit/test_embeddings/test_document_embeddings_processor.py @@ -20,9 +20,8 @@ def processor(): ) -def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1", - user="test", collection="default"): - metadata = Metadata(id=doc_id, user=user, collection=collection) +def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1", collection="default"): + metadata = Metadata(id=doc_id, collection=collection) value = Chunk(metadata=metadata, chunk=chunk_text, document_id=doc_id) msg = MagicMock() msg.value.return_value = value @@ -127,7 +126,7 @@ class TestDocumentEmbeddingsProcessor: @pytest.mark.asyncio async def test_metadata_preserved(self, processor): """Output should carry the original metadata.""" - msg = _make_chunk_message(user="alice", collection="reports", doc_id="d1") + msg = _make_chunk_message(collection="reports", doc_id="d1") mock_request = AsyncMock(return_value=EmbeddingsResponse( error=None, vectors=[[0.0]] @@ -144,7 +143,6 @@ class TestDocumentEmbeddingsProcessor: await processor.on_message(msg, MagicMock(), flow) result = mock_output.send.call_args[0][0] - assert result.metadata.user == "alice" assert result.metadata.collection == "reports" assert result.metadata.id == "d1" diff --git a/tests/unit/test_embeddings/test_graph_embeddings_processor.py b/tests/unit/test_embeddings/test_graph_embeddings_processor.py index 5d535349..f3cec4d2 100644 --- a/tests/unit/test_embeddings/test_graph_embeddings_processor.py +++ b/tests/unit/test_embeddings/test_graph_embeddings_processor.py @@ -27,8 +27,8 @@ def _make_entity_context(name, context, chunk_id="chunk-1"): return MagicMock(entity=entity, context=context, chunk_id=chunk_id) -def _make_message(entities, doc_id="doc-1", user="test", collection="default"): 
- metadata = Metadata(id=doc_id, user=user, collection=collection) +def _make_message(entities, doc_id="doc-1", collection="default"): + metadata = Metadata(id=doc_id, collection=collection) value = EntityContexts(metadata=metadata, entities=entities) msg = MagicMock() msg.value.return_value = value @@ -151,7 +151,7 @@ class TestGraphEmbeddingsBatchProcessing: _make_entity_context(f"E{i}", f"ctx {i}") for i in range(5) ] - msg = _make_message(entities, doc_id="doc-42", user="alice", collection="main") + msg = _make_message(entities, doc_id="doc-42", collection="main") mock_embed = AsyncMock(return_value=[[0.0]] * 5) mock_output = AsyncMock() @@ -168,7 +168,6 @@ class TestGraphEmbeddingsBatchProcessing: for call in mock_output.send.call_args_list: result = call[0][0] assert result.metadata.id == "doc-42" - assert result.metadata.user == "alice" assert result.metadata.collection == "main" @pytest.mark.asyncio diff --git a/tests/unit/test_embeddings/test_row_embeddings_processor.py b/tests/unit/test_embeddings/test_row_embeddings_processor.py index 45a22e48..36ecd013 100644 --- a/tests/unit/test_embeddings/test_row_embeddings_processor.py +++ b/tests/unit/test_embeddings/test_row_embeddings_processor.py @@ -214,11 +214,11 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase): } } - await processor.on_schema_config(config_data, 1) + await processor.on_schema_config("default", config_data, 1) - assert 'customers' in processor.schemas - assert processor.schemas['customers'].name == 'customers' - assert len(processor.schemas['customers'].fields) == 3 + assert 'customers' in processor.schemas["default"] + assert processor.schemas["default"]['customers'].name == 'customers' + assert len(processor.schemas["default"]['customers'].fields) == 3 async def test_on_schema_config_handles_missing_type(self): """Test that missing schema type is handled gracefully""" @@ -236,9 +236,9 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase): 'other_type': {} } - await 
processor.on_schema_config(config_data, 1) + await processor.on_schema_config("default", config_data, 1) - assert processor.schemas == {} + assert processor.schemas.get("default", {}) == {} async def test_on_message_drops_unknown_collection(self): """Test that messages for unknown collections are dropped""" @@ -285,7 +285,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase): } processor = Processor(**config) - processor.known_collections[('test_user', 'test_collection')] = {} + processor.known_collections[('default', 'test_collection')] = {} # No schemas registered metadata = MagicMock() @@ -322,17 +322,19 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase): } processor = Processor(**config) - processor.known_collections[('test_user', 'test_collection')] = {} + processor.known_collections[('default', 'test_collection')] = {} # Set up schema - processor.schemas['customers'] = RowSchema( - name='customers', - description='Customer records', - fields=[ - Field(name='id', type='text', primary=True), - Field(name='name', type='text', indexed=True), - ] - ) + processor.schemas["default"] = { + 'customers': RowSchema( + name='customers', + description='Customer records', + fields=[ + Field(name='id', type='text', primary=True), + Field(name='name', type='text', indexed=True), + ] + ) + } metadata = MagicMock() metadata.user = 'test_user' @@ -372,6 +374,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase): return MagicMock() mock_flow = MagicMock(side_effect=flow_factory) + mock_flow.workspace = "default" await processor.on_message(mock_msg, MagicMock(), mock_flow) diff --git a/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py b/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py index cbc9a05a..09bb7988 100644 --- a/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py +++ b/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py @@ -34,11 +34,10 @@ def 
_make_defn(entity, definition): return {"entity": entity, "definition": definition} -def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", - user="user-1", collection="col-1", document_id=""): +def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""): chunk = Chunk( metadata=Metadata( - id=meta_id, root=root, user=user, collection=collection, + id=meta_id, root=root, collection=collection, ), chunk=text.encode("utf-8"), document_id=document_id, @@ -229,8 +228,7 @@ class TestMetadataPreservation: defs = [_make_defn("X", "def X")] flow, triples_pub, _, _ = _make_flow(defs) msg = _make_chunk_msg( - "text", meta_id="c-1", root="r-1", - user="u-1", collection="coll-1", + "text", meta_id="c-1", root="r-1", collection="coll-1", ) await proc.on_message(msg, MagicMock(), flow) @@ -238,7 +236,6 @@ class TestMetadataPreservation: for triples_msg in _sent_triples(triples_pub): assert triples_msg.metadata.id == "c-1" assert triples_msg.metadata.root == "r-1" - assert triples_msg.metadata.user == "u-1" assert triples_msg.metadata.collection == "coll-1" @pytest.mark.asyncio @@ -247,8 +244,7 @@ class TestMetadataPreservation: defs = [_make_defn("X", "def X")] flow, _, ecs_pub, _ = _make_flow(defs) msg = _make_chunk_msg( - "text", meta_id="c-2", root="r-2", - user="u-2", collection="coll-2", + "text", meta_id="c-2", root="r-2", collection="coll-2", ) await proc.on_message(msg, MagicMock(), flow) diff --git a/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py b/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py index d9861cf3..b85c9e00 100644 --- a/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py +++ b/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py @@ -38,12 +38,11 @@ def _make_rel(subject, predicate, obj, object_entity=True): } -def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", - user="user-1", collection="col-1", 
document_id=""): +def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""): """Build a mock message wrapping a Chunk.""" chunk = Chunk( metadata=Metadata( - id=meta_id, root=root, user=user, collection=collection, + id=meta_id, root=root, collection=collection, ), chunk=text.encode("utf-8"), document_id=document_id, @@ -189,8 +188,7 @@ class TestMetadataPreservation: rels = [_make_rel("X", "rel", "Y")] flow, pub, _ = _make_flow(rels) msg = _make_chunk_msg( - "text", meta_id="c-1", root="r-1", - user="u-1", collection="coll-1", + "text", meta_id="c-1", root="r-1", collection="coll-1", ) await proc.on_message(msg, MagicMock(), flow) @@ -198,7 +196,6 @@ class TestMetadataPreservation: for triples_msg in _sent_triples(pub): assert triples_msg.metadata.id == "c-1" assert triples_msg.metadata.root == "r-1" - assert triples_msg.metadata.user == "u-1" assert triples_msg.metadata.collection == "coll-1" diff --git a/tests/unit/test_gateway/test_config_receiver.py b/tests/unit/test_gateway/test_config_receiver.py index 90ba8d33..56e96178 100644 --- a/tests/unit/test_gateway/test_config_receiver.py +++ b/tests/unit/test_gateway/test_config_receiver.py @@ -17,6 +17,12 @@ _real_config_loader = ConfigReceiver.config_loader ConfigReceiver.config_loader = Mock() +def _notify(version, changes): + msg = Mock() + msg.value.return_value = Mock(version=version, changes=changes) + return msg + + class TestConfigReceiver: """Test cases for ConfigReceiver class""" @@ -47,98 +53,70 @@ class TestConfigReceiver: assert handler2 in config_receiver.flow_handlers @pytest.mark.asyncio - async def test_on_config_notify_new_version(self): - """Test on_config_notify triggers fetch for newer version""" + async def test_on_config_notify_new_version_fetches_per_workspace(self): + """Notify with newer version fetches each affected workspace.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) config_receiver.config_version = 1 - # Mock 
fetch_and_apply fetch_calls = [] - async def mock_fetch(**kwargs): - fetch_calls.append(kwargs) - config_receiver.fetch_and_apply = mock_fetch - # Create notify message with newer version - mock_msg = Mock() - mock_msg.value.return_value = Mock(version=2, types=["flow"]) + async def mock_fetch(workspace, retry=False): + fetch_calls.append(workspace) - await config_receiver.on_config_notify(mock_msg, None, None) + config_receiver.fetch_and_apply_workspace = mock_fetch - assert len(fetch_calls) == 1 + msg = _notify(2, {"flow": ["ws1", "ws2"]}) + await config_receiver.on_config_notify(msg, None, None) + + assert set(fetch_calls) == {"ws1", "ws2"} + assert config_receiver.config_version == 2 @pytest.mark.asyncio async def test_on_config_notify_old_version_ignored(self): - """Test on_config_notify ignores older versions""" + """Older-version notifies are ignored.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) config_receiver.config_version = 5 fetch_calls = [] - async def mock_fetch(**kwargs): - fetch_calls.append(kwargs) - config_receiver.fetch_and_apply = mock_fetch - # Create notify message with older version - mock_msg = Mock() - mock_msg.value.return_value = Mock(version=3, types=["flow"]) + async def mock_fetch(workspace, retry=False): + fetch_calls.append(workspace) - await config_receiver.on_config_notify(mock_msg, None, None) + config_receiver.fetch_and_apply_workspace = mock_fetch - assert len(fetch_calls) == 0 + msg = _notify(3, {"flow": ["ws1"]}) + await config_receiver.on_config_notify(msg, None, None) + + assert fetch_calls == [] @pytest.mark.asyncio async def test_on_config_notify_irrelevant_types_ignored(self): - """Test on_config_notify ignores types the gateway doesn't care about""" + """Notifies without flow changes advance version but skip fetch.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) config_receiver.config_version = 1 fetch_calls = [] - async def mock_fetch(**kwargs): - 
fetch_calls.append(kwargs) - config_receiver.fetch_and_apply = mock_fetch - # Create notify message with non-flow type - mock_msg = Mock() - mock_msg.value.return_value = Mock(version=2, types=["prompt"]) + async def mock_fetch(workspace, retry=False): + fetch_calls.append(workspace) - await config_receiver.on_config_notify(mock_msg, None, None) + config_receiver.fetch_and_apply_workspace = mock_fetch - # Version should be updated but no fetch - assert len(fetch_calls) == 0 + msg = _notify(2, {"prompt": ["ws1"]}) + await config_receiver.on_config_notify(msg, None, None) + + assert fetch_calls == [] assert config_receiver.config_version == 2 - @pytest.mark.asyncio - async def test_on_config_notify_flow_type_triggers_fetch(self): - """Test on_config_notify fetches for flow-related types""" - mock_backend = Mock() - config_receiver = ConfigReceiver(mock_backend) - config_receiver.config_version = 1 - - fetch_calls = [] - async def mock_fetch(**kwargs): - fetch_calls.append(kwargs) - config_receiver.fetch_and_apply = mock_fetch - - for type_name in ["flow"]: - fetch_calls.clear() - config_receiver.config_version = 1 - - mock_msg = Mock() - mock_msg.value.return_value = Mock(version=2, types=[type_name]) - - await config_receiver.on_config_notify(mock_msg, None, None) - - assert len(fetch_calls) == 1, f"Expected fetch for type {type_name}" - @pytest.mark.asyncio async def test_on_config_notify_exception_handling(self): - """Test on_config_notify handles exceptions gracefully""" + """on_config_notify swallows exceptions from message decode.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - # Create notify message that causes an exception mock_msg = Mock() mock_msg.value.side_effect = Exception("Test exception") @@ -146,19 +124,18 @@ class TestConfigReceiver: await config_receiver.on_config_notify(mock_msg, None, None) @pytest.mark.asyncio - async def test_fetch_and_apply_with_new_flows(self): - """Test fetch_and_apply starts new flows""" + async 
def test_fetch_and_apply_workspace_starts_new_flows(self): + """fetch_and_apply_workspace starts newly-configured flows.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - # Mock _create_config_client to return a mock client mock_resp = Mock() mock_resp.error = None mock_resp.version = 5 mock_resp.config = { "flow": { "flow1": '{"name": "test_flow_1"}', - "flow2": '{"name": "test_flow_2"}' + "flow2": '{"name": "test_flow_2"}', } } @@ -167,36 +144,39 @@ class TestConfigReceiver: config_receiver._create_config_client = Mock(return_value=mock_client) start_flow_calls = [] - async def mock_start_flow(id, flow): - start_flow_calls.append((id, flow)) + + async def mock_start_flow(workspace, id, flow): + start_flow_calls.append((workspace, id, flow)) + config_receiver.start_flow = mock_start_flow - await config_receiver.fetch_and_apply() + await config_receiver.fetch_and_apply_workspace("default") assert config_receiver.config_version == 5 - assert "flow1" in config_receiver.flows - assert "flow2" in config_receiver.flows + assert "flow1" in config_receiver.flows["default"] + assert "flow2" in config_receiver.flows["default"] assert len(start_flow_calls) == 2 + assert all(c[0] == "default" for c in start_flow_calls) @pytest.mark.asyncio - async def test_fetch_and_apply_with_removed_flows(self): - """Test fetch_and_apply stops removed flows""" + async def test_fetch_and_apply_workspace_stops_removed_flows(self): + """fetch_and_apply_workspace stops flows no longer configured.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - # Pre-populate with existing flows config_receiver.flows = { - "flow1": {"name": "test_flow_1"}, - "flow2": {"name": "test_flow_2"} + "default": { + "flow1": {"name": "test_flow_1"}, + "flow2": {"name": "test_flow_2"}, + } } - # Config now only has flow1 mock_resp = Mock() mock_resp.error = None mock_resp.version = 5 mock_resp.config = { "flow": { - "flow1": '{"name": "test_flow_1"}' + "flow1": '{"name": 
"test_flow_1"}', } } @@ -205,20 +185,22 @@ class TestConfigReceiver: config_receiver._create_config_client = Mock(return_value=mock_client) stop_flow_calls = [] - async def mock_stop_flow(id, flow): - stop_flow_calls.append((id, flow)) + + async def mock_stop_flow(workspace, id, flow): + stop_flow_calls.append((workspace, id, flow)) + config_receiver.stop_flow = mock_stop_flow - await config_receiver.fetch_and_apply() + await config_receiver.fetch_and_apply_workspace("default") - assert "flow1" in config_receiver.flows - assert "flow2" not in config_receiver.flows + assert "flow1" in config_receiver.flows["default"] + assert "flow2" not in config_receiver.flows["default"] assert len(stop_flow_calls) == 1 - assert stop_flow_calls[0][0] == "flow2" + assert stop_flow_calls[0][:2] == ("default", "flow2") @pytest.mark.asyncio - async def test_fetch_and_apply_with_no_flows(self): - """Test fetch_and_apply with empty config""" + async def test_fetch_and_apply_workspace_with_no_flows(self): + """Empty workspace config clears any local flow state.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) @@ -231,88 +213,100 @@ class TestConfigReceiver: mock_client.request.return_value = mock_resp config_receiver._create_config_client = Mock(return_value=mock_client) - await config_receiver.fetch_and_apply() + await config_receiver.fetch_and_apply_workspace("default") - assert config_receiver.flows == {} + assert config_receiver.flows.get("default", {}) == {} assert config_receiver.config_version == 1 @pytest.mark.asyncio async def test_start_flow_with_handlers(self): - """Test start_flow method with multiple handlers""" + """start_flow fans out to every registered flow handler.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) handler1 = Mock() - handler1.start_flow = Mock() + handler1.start_flow = AsyncMock() handler2 = Mock() - handler2.start_flow = Mock() + handler2.start_flow = AsyncMock() config_receiver.add_handler(handler1) 
config_receiver.add_handler(handler2) flow_data = {"name": "test_flow", "steps": []} - await config_receiver.start_flow("flow1", flow_data) + await config_receiver.start_flow("default", "flow1", flow_data) - handler1.start_flow.assert_called_once_with("flow1", flow_data) - handler2.start_flow.assert_called_once_with("flow1", flow_data) + handler1.start_flow.assert_awaited_once_with( + "default", "flow1", flow_data + ) + handler2.start_flow.assert_awaited_once_with( + "default", "flow1", flow_data + ) @pytest.mark.asyncio async def test_start_flow_with_handler_exception(self): - """Test start_flow method handles handler exceptions""" + """Handler exceptions in start_flow do not propagate.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) handler = Mock() - handler.start_flow = Mock(side_effect=Exception("Handler error")) + handler.start_flow = AsyncMock(side_effect=Exception("Handler error")) config_receiver.add_handler(handler) flow_data = {"name": "test_flow", "steps": []} # Should not raise - await config_receiver.start_flow("flow1", flow_data) + await config_receiver.start_flow("default", "flow1", flow_data) - handler.start_flow.assert_called_once_with("flow1", flow_data) + handler.start_flow.assert_awaited_once_with( + "default", "flow1", flow_data + ) @pytest.mark.asyncio async def test_stop_flow_with_handlers(self): - """Test stop_flow method with multiple handlers""" + """stop_flow fans out to every registered flow handler.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) handler1 = Mock() - handler1.stop_flow = Mock() + handler1.stop_flow = AsyncMock() handler2 = Mock() - handler2.stop_flow = Mock() + handler2.stop_flow = AsyncMock() config_receiver.add_handler(handler1) config_receiver.add_handler(handler2) flow_data = {"name": "test_flow", "steps": []} - await config_receiver.stop_flow("flow1", flow_data) + await config_receiver.stop_flow("default", "flow1", flow_data) - 
handler1.stop_flow.assert_called_once_with("flow1", flow_data) - handler2.stop_flow.assert_called_once_with("flow1", flow_data) + handler1.stop_flow.assert_awaited_once_with( + "default", "flow1", flow_data + ) + handler2.stop_flow.assert_awaited_once_with( + "default", "flow1", flow_data + ) @pytest.mark.asyncio async def test_stop_flow_with_handler_exception(self): - """Test stop_flow method handles handler exceptions""" + """Handler exceptions in stop_flow do not propagate.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) handler = Mock() - handler.stop_flow = Mock(side_effect=Exception("Handler error")) + handler.stop_flow = AsyncMock(side_effect=Exception("Handler error")) config_receiver.add_handler(handler) flow_data = {"name": "test_flow", "steps": []} # Should not raise - await config_receiver.stop_flow("flow1", flow_data) + await config_receiver.stop_flow("default", "flow1", flow_data) - handler.stop_flow.assert_called_once_with("flow1", flow_data) + handler.stop_flow.assert_awaited_once_with( + "default", "flow1", flow_data + ) @patch('asyncio.create_task') @pytest.mark.asyncio @@ -329,25 +323,25 @@ class TestConfigReceiver: mock_create_task.assert_called_once() @pytest.mark.asyncio - async def test_fetch_and_apply_mixed_flow_operations(self): - """Test fetch_and_apply with mixed add/remove operations""" + async def test_fetch_and_apply_workspace_mixed_flow_operations(self): + """fetch_and_apply_workspace adds, keeps and removes flows in one pass.""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - # Pre-populate config_receiver.flows = { - "flow1": {"name": "test_flow_1"}, - "flow2": {"name": "test_flow_2"} + "default": { + "flow1": {"name": "test_flow_1"}, + "flow2": {"name": "test_flow_2"}, + } } - # Config removes flow1, keeps flow2, adds flow3 mock_resp = Mock() mock_resp.error = None mock_resp.version = 5 mock_resp.config = { "flow": { "flow2": '{"name": "test_flow_2"}', - "flow3": '{"name": 
"test_flow_3"}' + "flow3": '{"name": "test_flow_3"}', } } @@ -358,20 +352,22 @@ class TestConfigReceiver: start_calls = [] stop_calls = [] - async def mock_start_flow(id, flow): - start_calls.append((id, flow)) - async def mock_stop_flow(id, flow): - stop_calls.append((id, flow)) + async def mock_start_flow(workspace, id, flow): + start_calls.append((workspace, id, flow)) + + async def mock_stop_flow(workspace, id, flow): + stop_calls.append((workspace, id, flow)) config_receiver.start_flow = mock_start_flow config_receiver.stop_flow = mock_stop_flow - await config_receiver.fetch_and_apply() + await config_receiver.fetch_and_apply_workspace("default") - assert "flow1" not in config_receiver.flows - assert "flow2" in config_receiver.flows - assert "flow3" in config_receiver.flows + ws_flows = config_receiver.flows["default"] + assert "flow1" not in ws_flows + assert "flow2" in ws_flows + assert "flow3" in ws_flows assert len(start_calls) == 1 - assert start_calls[0][0] == "flow3" + assert start_calls[0][:2] == ("default", "flow3") assert len(stop_calls) == 1 - assert stop_calls[0][0] == "flow1" + assert stop_calls[0][:2] == ("default", "flow1") diff --git a/tests/unit/test_gateway/test_core_import_export_roundtrip.py b/tests/unit/test_gateway/test_core_import_export_roundtrip.py index 843a2b7b..cb2554ee 100644 --- a/tests/unit/test_gateway/test_core_import_export_roundtrip.py +++ b/tests/unit/test_gateway/test_core_import_export_roundtrip.py @@ -36,7 +36,6 @@ def _ge_response_dict(): "metadata": { "id": "doc-1", "root": "", - "user": "alice", "collection": "testcoll", }, "entities": [ @@ -59,7 +58,6 @@ def _triples_response_dict(): "metadata": { "id": "doc-1", "root": "", - "user": "alice", "collection": "testcoll", }, "triples": [ @@ -73,9 +71,9 @@ def _triples_response_dict(): } -def _make_request(id_="doc-1", user="alice"): +def _make_request(id_="doc-1", workspace="alice"): request = Mock() - request.query = {"id": id_, "user": user} + request.query = {"id": 
id_, "workspace": workspace} return request @@ -149,12 +147,8 @@ class TestCoreExportWireFormat: msg_type, payload = items[0] assert msg_type == "ge" - # Metadata envelope: only id/user/collection — no stale `m["m"]`. - assert payload["m"] == { - "i": "doc-1", - "u": "alice", - "c": "testcoll", - } + # Metadata envelope: only id/collection — no stale `m["m"]`. + assert payload["m"] == {"i": "doc-1", "c": "testcoll"} # Entities: each carries the *singular* `v` and the term envelope assert len(payload["e"]) == 2 @@ -202,11 +196,7 @@ class TestCoreExportWireFormat: msg_type, payload = items[0] assert msg_type == "t" - assert payload["m"] == { - "i": "doc-1", - "u": "alice", - "c": "testcoll", - } + assert payload["m"] == {"i": "doc-1", "c": "testcoll"} assert len(payload["t"]) == 1 @@ -240,7 +230,7 @@ class TestCoreImportWireFormat: payload = msgpack.packb(( "ge", { - "m": {"i": "doc-1", "u": "alice", "c": "testcoll"}, + "m": {"i": "doc-1", "c": "testcoll"}, "e": [ { "e": {"t": "i", "i": "http://example.org/alice"}, @@ -266,7 +256,7 @@ class TestCoreImportWireFormat: req = captured[0] assert req["operation"] == "put-kg-core" - assert req["user"] == "alice" + assert req["workspace"] == "alice" assert req["id"] == "doc-1" ge = req["graph-embeddings"] @@ -275,7 +265,6 @@ class TestCoreImportWireFormat: assert "metadata" not in ge["metadata"] assert ge["metadata"] == { "id": "doc-1", - "user": "alice", "collection": "default", } @@ -302,7 +291,7 @@ class TestCoreImportWireFormat: payload = msgpack.packb(( "t", { - "m": {"i": "doc-1", "u": "alice", "c": "testcoll"}, + "m": {"i": "doc-1", "c": "testcoll"}, "t": [ { "s": {"t": "i", "i": "http://example.org/alice"}, @@ -407,11 +396,10 @@ class TestCoreImportExportRoundTrip: original = _ge_response_dict()["graph-embeddings"] ge = req["graph-embeddings"] - # The import side overrides id/user from the URL query (intentional), + # The import side overrides id from the URL query (intentional), # so we only round-trip the entity 
payload itself. assert ge["metadata"]["id"] == original["metadata"]["id"] - assert ge["metadata"]["user"] == original["metadata"]["user"] - + assert len(ge["entities"]) == len(original["entities"]) for got, want in zip(ge["entities"], original["entities"]): assert got["vector"] == want["vector"] diff --git a/tests/unit/test_gateway/test_dispatch_manager.py b/tests/unit/test_gateway/test_dispatch_manager.py index 4ebcb5b9..f091a46d 100644 --- a/tests/unit/test_gateway/test_dispatch_manager.py +++ b/tests/unit/test_gateway/test_dispatch_manager.py @@ -72,10 +72,10 @@ class TestDispatcherManager: flow_data = {"name": "test_flow", "steps": []} - await manager.start_flow("flow1", flow_data) - - assert "flow1" in manager.flows - assert manager.flows["flow1"] == flow_data + await manager.start_flow("default", "flow1", flow_data) + + assert ("default", "flow1") in manager.flows + assert manager.flows[("default", "flow1")] == flow_data @pytest.mark.asyncio async def test_stop_flow(self): @@ -86,11 +86,11 @@ class TestDispatcherManager: # Pre-populate with a flow flow_data = {"name": "test_flow", "steps": []} - manager.flows["flow1"] = flow_data - - await manager.stop_flow("flow1", flow_data) - - assert "flow1" not in manager.flows + manager.flows[("default", "flow1")] = flow_data + + await manager.stop_flow("default", "flow1", flow_data) + + assert ("default", "flow1") not in manager.flows def test_dispatch_global_service_returns_wrapper(self): """Test dispatch_global_service returns DispatcherWrapper""" @@ -275,12 +275,12 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "triples-store": {"flow": "test_queue"} } } - + with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers, \ patch('uuid.uuid4') as mock_uuid: mock_uuid.return_value = "test-uuid" @@ -290,7 +290,7 @@ class 
TestDispatcherManager: mock_dispatcher_class.return_value = mock_dispatcher mock_dispatchers.__getitem__.return_value = mock_dispatcher_class mock_dispatchers.__contains__.return_value = True - + params = {"flow": "test_flow", "kind": "triples"} result = await manager.process_flow_import("ws", "running", params) @@ -326,12 +326,12 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "triples-store": {"flow": "test_queue"} } } - + with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers: mock_dispatchers.__contains__.return_value = False @@ -348,12 +348,12 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "triples-store": {"flow": "test_queue"} } } - + with patch('trustgraph.gateway.dispatch.manager.export_dispatchers') as mock_dispatchers, \ patch('uuid.uuid4') as mock_uuid: mock_uuid.return_value = "test-uuid" @@ -404,7 +404,7 @@ class TestDispatcherManager: params = {"flow": "test_flow", "kind": "agent"} result = await manager.process_flow_service("data", "responder", params) - manager.invoke_flow_service.assert_called_once_with("data", "responder", "test_flow", "agent") + manager.invoke_flow_service.assert_called_once_with("data", "responder", "default", "test_flow", "agent") assert result == "flow_result" @pytest.mark.asyncio @@ -415,14 +415,14 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Add flow to the flows dictionary - manager.flows["test_flow"] = {"services": {"agent": {}}} - + manager.flows[("default", "test_flow")] = {"services": {"agent": {}}} + # Pre-populate with existing dispatcher mock_dispatcher = Mock() mock_dispatcher.process = 
AsyncMock(return_value="cached_result") - manager.dispatchers[("test_flow", "agent")] = mock_dispatcher - - result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent") + manager.dispatchers[("default", "test_flow", "agent")] = mock_dispatcher + + result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent") mock_dispatcher.process.assert_called_once_with("data", "responder") assert result == "cached_result" @@ -435,7 +435,7 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "agent": { "request": "agent_request_queue", @@ -443,7 +443,7 @@ class TestDispatcherManager: } } } - + with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_dispatchers: mock_dispatcher_class = Mock() mock_dispatcher = Mock() @@ -452,23 +452,23 @@ class TestDispatcherManager: mock_dispatcher_class.return_value = mock_dispatcher mock_dispatchers.__getitem__.return_value = mock_dispatcher_class mock_dispatchers.__contains__.return_value = True - - result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent") - + + result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent") + # Verify dispatcher was created with correct parameters mock_dispatcher_class.assert_called_once_with( backend=mock_backend, request_queue="agent_request_queue", response_queue="agent_response_queue", timeout=120, - consumer="api-gateway-test_flow-agent-request", - subscriber="api-gateway-test_flow-agent-request" + consumer="api-gateway-default-test_flow-agent-request", + subscriber="api-gateway-default-test_flow-agent-request" ) mock_dispatcher.start.assert_called_once() mock_dispatcher.process.assert_called_once_with("data", "responder") - + # Verify dispatcher was cached - assert manager.dispatchers[("test_flow", 
"agent")] == mock_dispatcher + assert manager.dispatchers[("default", "test_flow", "agent")] == mock_dispatcher assert result == "new_result" @pytest.mark.asyncio @@ -479,26 +479,26 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "text-load": {"flow": "text_load_queue"} } } - + with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers, \ patch('trustgraph.gateway.dispatch.manager.sender_dispatchers') as mock_sender_dispatchers: mock_rr_dispatchers.__contains__.return_value = False mock_sender_dispatchers.__contains__.return_value = True - + mock_dispatcher_class = Mock() mock_dispatcher = Mock() mock_dispatcher.start = AsyncMock() mock_dispatcher.process = AsyncMock(return_value="sender_result") mock_dispatcher_class.return_value = mock_dispatcher mock_sender_dispatchers.__getitem__.return_value = mock_dispatcher_class - - result = await manager.invoke_flow_service("data", "responder", "test_flow", "text-load") - + + result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "text-load") + # Verify dispatcher was created with correct parameters mock_dispatcher_class.assert_called_once_with( backend=mock_backend, @@ -506,9 +506,9 @@ class TestDispatcherManager: ) mock_dispatcher.start.assert_called_once() mock_dispatcher.process.assert_called_once_with("data", "responder") - + # Verify dispatcher was cached - assert manager.dispatchers[("test_flow", "text-load")] == mock_dispatcher + assert manager.dispatchers[("default", "test_flow", "text-load")] == mock_dispatcher assert result == "sender_result" @pytest.mark.asyncio @@ -519,7 +519,7 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) with pytest.raises(RuntimeError, match="Invalid flow"): - await manager.invoke_flow_service("data", 
"responder", "invalid_flow", "agent") + await manager.invoke_flow_service("data", "responder", "default", "invalid_flow", "agent") @pytest.mark.asyncio async def test_invoke_flow_service_unsupported_kind_by_flow(self): @@ -529,14 +529,14 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow without agent interface - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "text-completion": {"request": "req", "response": "resp"} } } - + with pytest.raises(RuntimeError, match="This kind not supported by flow"): - await manager.invoke_flow_service("data", "responder", "test_flow", "agent") + await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent") @pytest.mark.asyncio async def test_invoke_flow_service_invalid_kind(self): @@ -546,7 +546,7 @@ class TestDispatcherManager: manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow with interface but unsupported kind - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "invalid-kind": {"request": "req", "response": "resp"} } @@ -558,7 +558,7 @@ class TestDispatcherManager: mock_sender_dispatchers.__contains__.return_value = False with pytest.raises(RuntimeError, match="Invalid kind"): - await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind") + await manager.invoke_flow_service("data", "responder", "default", "test_flow", "invalid-kind") @pytest.mark.asyncio async def test_invoke_global_service_concurrent_calls_create_single_dispatcher(self): @@ -608,7 +608,7 @@ class TestDispatcherManager: mock_config_receiver = Mock() manager = DispatcherManager(mock_backend, mock_config_receiver) - manager.flows["test_flow"] = { + manager.flows[("default", "test_flow")] = { "interfaces": { "agent": { "request": "agent_request_queue", @@ -630,7 +630,7 @@ class TestDispatcherManager: 
mock_rr_dispatchers.__contains__.return_value = True results = await asyncio.gather(*[ - manager.invoke_flow_service("data", "responder", "test_flow", "agent") + manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent") for _ in range(5) ]) @@ -638,5 +638,5 @@ class TestDispatcherManager: "Dispatcher class instantiated more than once — duplicate consumer bug" ) assert mock_dispatcher.start.call_count == 1 - assert manager.dispatchers[("test_flow", "agent")] is mock_dispatcher + assert manager.dispatchers[("default", "test_flow", "agent")] is mock_dispatcher assert all(r == "result" for r in results) \ No newline at end of file diff --git a/tests/unit/test_gateway/test_entity_contexts_import_dispatcher.py b/tests/unit/test_gateway/test_entity_contexts_import_dispatcher.py index 8eddeba9..4ecfce08 100644 --- a/tests/unit/test_gateway/test_entity_contexts_import_dispatcher.py +++ b/tests/unit/test_gateway/test_entity_contexts_import_dispatcher.py @@ -186,7 +186,6 @@ class TestEntityContextsImportMessageProcessing: assert isinstance(sent, EntityContexts) assert isinstance(sent.metadata, Metadata) assert sent.metadata.id == "doc-123" - assert sent.metadata.user == "testuser" assert sent.metadata.collection == "testcollection" assert len(sent.entities) == 2 diff --git a/tests/unit/test_gateway/test_graph_embeddings_import_dispatcher.py b/tests/unit/test_gateway/test_graph_embeddings_import_dispatcher.py index fa277178..09a2d510 100644 --- a/tests/unit/test_gateway/test_graph_embeddings_import_dispatcher.py +++ b/tests/unit/test_gateway/test_graph_embeddings_import_dispatcher.py @@ -188,7 +188,6 @@ class TestGraphEmbeddingsImportMessageProcessing: assert isinstance(sent, GraphEmbeddings) assert isinstance(sent.metadata, Metadata) assert sent.metadata.id == "doc-123" - assert sent.metadata.user == "testuser" assert sent.metadata.collection == "testcollection" assert len(sent.entities) == 2 diff --git 
a/tests/unit/test_gateway/test_rows_import_dispatcher.py b/tests/unit/test_gateway/test_rows_import_dispatcher.py index f029e9a2..0134dd39 100644 --- a/tests/unit/test_gateway/test_rows_import_dispatcher.py +++ b/tests/unit/test_gateway/test_rows_import_dispatcher.py @@ -235,7 +235,6 @@ class TestRowsImportMessageProcessing: # Check metadata assert sent_object.metadata.id == "obj-123" - assert sent_object.metadata.user == "testuser" assert sent_object.metadata.collection == "testcollection" @patch('trustgraph.gateway.dispatch.rows_import.Publisher') diff --git a/tests/unit/test_gateway/test_text_document_translator.py b/tests/unit/test_gateway/test_text_document_translator.py index 84eedefc..da44e798 100644 --- a/tests/unit/test_gateway/test_text_document_translator.py +++ b/tests/unit/test_gateway/test_text_document_translator.py @@ -23,7 +23,6 @@ class TestTextDocumentTranslator: ) assert msg.metadata.id == "doc-1" - assert msg.metadata.user == "alice" assert msg.metadata.collection == "research" assert msg.text == payload.encode("utf-8") diff --git a/tests/unit/test_knowledge_graph/conftest.py b/tests/unit/test_knowledge_graph/conftest.py index 8e8d9e43..d0c47784 100644 --- a/tests/unit/test_knowledge_graph/conftest.py +++ b/tests/unit/test_knowledge_graph/conftest.py @@ -29,10 +29,9 @@ class Triple: self.o = o class Metadata: - def __init__(self, id, user, collection, root=""): + def __init__(self, id, collection, root=""): self.id = id self.root = root - self.user = user self.collection = collection class Triples: @@ -108,7 +107,6 @@ def sample_triples(sample_triple): """Sample Triples batch object""" metadata = Metadata( id="test-doc-123", - user="test_user", collection="test_collection", ) @@ -123,7 +121,6 @@ def sample_chunk(): """Sample text chunk for processing""" metadata = Metadata( id="test-chunk-456", - user="test_user", collection="test_collection", ) diff --git a/tests/unit/test_knowledge_graph/test_agent_extraction.py 
b/tests/unit/test_knowledge_graph/test_agent_extraction.py index ec985e3b..3c40a2a2 100644 --- a/tests/unit/test_knowledge_graph/test_agent_extraction.py +++ b/tests/unit/test_knowledge_graph/test_agent_extraction.py @@ -322,7 +322,6 @@ This is not JSON at all assert isinstance(sent_triples, Triples) # Check metadata fields individually since implementation creates new Metadata object assert sent_triples.metadata.id == sample_metadata.id - assert sent_triples.metadata.user == sample_metadata.user assert sent_triples.metadata.collection == sample_metadata.collection assert len(sent_triples.triples) == 1 assert sent_triples.triples[0].s.iri == "test:subject" @@ -346,7 +345,6 @@ This is not JSON at all assert isinstance(sent_contexts, EntityContexts) # Check metadata fields individually since implementation creates new Metadata object assert sent_contexts.metadata.id == sample_metadata.id - assert sent_contexts.metadata.user == sample_metadata.user assert sent_contexts.metadata.collection == sample_metadata.collection assert len(sent_contexts.entities) == 1 assert sent_contexts.entities[0].entity.iri == "test:entity" diff --git a/tests/unit/test_knowledge_graph/test_object_extraction_logic.py b/tests/unit/test_knowledge_graph/test_object_extraction_logic.py index f82e4cc8..2d758481 100644 --- a/tests/unit/test_knowledge_graph/test_object_extraction_logic.py +++ b/tests/unit/test_knowledge_graph/test_object_extraction_logic.py @@ -311,8 +311,7 @@ class TestObjectExtractionBusinessLogic: """Test ExtractedObject creation and properties""" # Arrange metadata = Metadata( - id="test-extraction-001", - user="test_user", + id="test-extraction-001", collection="test_collection", ) @@ -337,7 +336,6 @@ class TestObjectExtractionBusinessLogic: assert extracted_obj.values[0]["customer_id"] == "CUST001" assert extracted_obj.confidence == 0.95 assert "John Doe" in extracted_obj.source_span - assert extracted_obj.metadata.user == "test_user" def 
test_config_parsing_error_handling(self): """Test configuration parsing with invalid JSON""" diff --git a/tests/unit/test_knowledge_graph/test_triple_construction.py b/tests/unit/test_knowledge_graph/test_triple_construction.py index e45c69aa..db13a7c1 100644 --- a/tests/unit/test_knowledge_graph/test_triple_construction.py +++ b/tests/unit/test_knowledge_graph/test_triple_construction.py @@ -371,7 +371,6 @@ class TestTripleConstructionLogic: metadata = Metadata( id="test-doc-123", - user="test_user", collection="test_collection", ) @@ -384,7 +383,6 @@ class TestTripleConstructionLogic: # Assert assert isinstance(triples_batch, Triples) assert triples_batch.metadata.id == "test-doc-123" - assert triples_batch.metadata.user == "test_user" assert triples_batch.metadata.collection == "test_collection" assert len(triples_batch.triples) == 2 diff --git a/tests/unit/test_librarian/test_chunked_upload.py b/tests/unit/test_librarian/test_chunked_upload.py index eef83e1e..7e7be480 100644 --- a/tests/unit/test_librarian/test_chunked_upload.py +++ b/tests/unit/test_librarian/test_chunked_upload.py @@ -33,12 +33,12 @@ def _make_librarian(min_chunk_size=1): def _make_doc_metadata( - doc_id="doc-1", kind="application/pdf", user="alice", title="Test Doc" + doc_id="doc-1", kind="application/pdf", workspace="alice", title="Test Doc" ): meta = MagicMock() meta.id = doc_id meta.kind = kind - meta.user = user + meta.workspace = workspace meta.title = title meta.time = 1700000000 meta.comments = "" @@ -47,27 +47,27 @@ def _make_doc_metadata( def _make_begin_request( - doc_id="doc-1", kind="application/pdf", user="alice", + doc_id="doc-1", kind="application/pdf", workspace="alice", total_size=10_000_000, chunk_size=0 ): req = MagicMock() - req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, user=user) + req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, workspace=workspace) req.total_size = total_size req.chunk_size = chunk_size return req -def 
_make_upload_chunk_request(upload_id="up-1", chunk_index=0, user="alice", content=b"data"): +def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, workspace="alice", content=b"data"): req = MagicMock() req.upload_id = upload_id req.chunk_index = chunk_index - req.user = user + req.workspace = workspace req.content = base64.b64encode(content) return req def _make_session( - user="alice", total_chunks=5, chunk_size=2_000_000, + workspace="alice", total_chunks=5, chunk_size=2_000_000, total_size=10_000_000, chunks_received=None, object_id="obj-1", s3_upload_id="s3-up-1", document_metadata=None, document_id="doc-1", ): @@ -76,11 +76,11 @@ def _make_session( if document_metadata is None: document_metadata = json.dumps({ "id": document_id, "kind": "application/pdf", - "user": user, "title": "Test", "time": 1700000000, + "workspace": workspace, "title": "Test", "time": 1700000000, "comments": "", "tags": [], }) return { - "user": user, + "workspace": workspace, "total_chunks": total_chunks, "chunk_size": chunk_size, "total_size": total_size, @@ -259,10 +259,10 @@ class TestUploadChunk: @pytest.mark.asyncio async def test_rejects_wrong_user(self): lib = _make_librarian() - session = _make_session(user="alice") + session = _make_session(workspace="alice") lib.table_store.get_upload_session.return_value = session - req = _make_upload_chunk_request(user="bob") + req = _make_upload_chunk_request(workspace="bob") with pytest.raises(RequestError, match="Not authorized"): await lib.upload_chunk(req) @@ -353,7 +353,7 @@ class TestCompleteUpload: req = MagicMock() req.upload_id = "up-1" - req.user = "alice" + req.workspace = "alice" resp = await lib.complete_upload(req) @@ -375,7 +375,7 @@ class TestCompleteUpload: req = MagicMock() req.upload_id = "up-1" - req.user = "alice" + req.workspace = "alice" await lib.complete_upload(req) @@ -394,7 +394,7 @@ class TestCompleteUpload: req = MagicMock() req.upload_id = "up-1" - req.user = "alice" + req.workspace = "alice" with 
pytest.raises(RequestError, match="Missing chunks"): await lib.complete_upload(req) @@ -406,7 +406,7 @@ class TestCompleteUpload: req = MagicMock() req.upload_id = "up-gone" - req.user = "alice" + req.workspace = "alice" with pytest.raises(RequestError, match="not found"): await lib.complete_upload(req) @@ -414,12 +414,12 @@ class TestCompleteUpload: @pytest.mark.asyncio async def test_rejects_wrong_user(self): lib = _make_librarian() - session = _make_session(user="alice") + session = _make_session(workspace="alice") lib.table_store.get_upload_session.return_value = session req = MagicMock() req.upload_id = "up-1" - req.user = "bob" + req.workspace = "bob" with pytest.raises(RequestError, match="Not authorized"): await lib.complete_upload(req) @@ -439,7 +439,7 @@ class TestAbortUpload: req = MagicMock() req.upload_id = "up-1" - req.user = "alice" + req.workspace = "alice" resp = await lib.abort_upload(req) @@ -456,7 +456,7 @@ class TestAbortUpload: req = MagicMock() req.upload_id = "up-gone" - req.user = "alice" + req.workspace = "alice" with pytest.raises(RequestError, match="not found"): await lib.abort_upload(req) @@ -464,12 +464,12 @@ class TestAbortUpload: @pytest.mark.asyncio async def test_rejects_wrong_user(self): lib = _make_librarian() - session = _make_session(user="alice") + session = _make_session(workspace="alice") lib.table_store.get_upload_session.return_value = session req = MagicMock() req.upload_id = "up-1" - req.user = "bob" + req.workspace = "bob" with pytest.raises(RequestError, match="Not authorized"): await lib.abort_upload(req) @@ -492,7 +492,7 @@ class TestGetUploadStatus: req = MagicMock() req.upload_id = "up-1" - req.user = "alice" + req.workspace = "alice" resp = await lib.get_upload_status(req) @@ -510,7 +510,7 @@ class TestGetUploadStatus: req = MagicMock() req.upload_id = "up-expired" - req.user = "alice" + req.workspace = "alice" resp = await lib.get_upload_status(req) @@ -527,7 +527,7 @@ class TestGetUploadStatus: req = 
MagicMock() req.upload_id = "up-1" - req.user = "alice" + req.workspace = "alice" resp = await lib.get_upload_status(req) @@ -539,12 +539,12 @@ class TestGetUploadStatus: @pytest.mark.asyncio async def test_rejects_wrong_user(self): lib = _make_librarian() - session = _make_session(user="alice") + session = _make_session(workspace="alice") lib.table_store.get_upload_session.return_value = session req = MagicMock() req.upload_id = "up-1" - req.user = "bob" + req.workspace = "bob" with pytest.raises(RequestError, match="Not authorized"): await lib.get_upload_status(req) @@ -564,7 +564,7 @@ class TestStreamDocument: lib.blob_store.get_range = AsyncMock(return_value=b"x" * 2000) req = MagicMock() - req.user = "alice" + req.workspace = "alice" req.document_id = "doc-1" req.chunk_size = 2000 @@ -587,7 +587,7 @@ class TestStreamDocument: lib.blob_store.get_range = AsyncMock(return_value=b"x" * 500) req = MagicMock() - req.user = "alice" + req.workspace = "alice" req.document_id = "doc-1" req.chunk_size = 2000 @@ -608,7 +608,7 @@ class TestStreamDocument: lib.blob_store.get_range = AsyncMock(return_value=b"x" * 100) req = MagicMock() - req.user = "alice" + req.workspace = "alice" req.document_id = "doc-1" req.chunk_size = 2000 @@ -630,7 +630,7 @@ class TestStreamDocument: lib.blob_store.get_range = AsyncMock(return_value=b"x") req = MagicMock() - req.user = "alice" + req.workspace = "alice" req.document_id = "doc-1" req.chunk_size = 0 # Should use default 1MB @@ -649,7 +649,7 @@ class TestStreamDocument: lib.blob_store.get_range = AsyncMock(return_value=raw) req = MagicMock() - req.user = "alice" + req.workspace = "alice" req.document_id = "doc-1" req.chunk_size = 1000 @@ -666,7 +666,7 @@ class TestStreamDocument: lib.blob_store.get_size = AsyncMock(return_value=5000) req = MagicMock() - req.user = "alice" + req.workspace = "alice" req.document_id = "doc-1" req.chunk_size = 512 @@ -698,7 +698,7 @@ class TestListUploads: ] req = MagicMock() - req.user = "alice" + 
req.workspace = "alice" resp = await lib.list_uploads(req) @@ -713,7 +713,7 @@ class TestListUploads: lib.table_store.list_upload_sessions.return_value = [] req = MagicMock() - req.user = "alice" + req.workspace = "alice" resp = await lib.list_uploads(req) diff --git a/tests/unit/test_provenance/test_dag_structure.py b/tests/unit/test_provenance/test_dag_structure.py index 184560f0..e65ef2e3 100644 --- a/tests/unit/test_provenance/test_dag_structure.py +++ b/tests/unit/test_provenance/test_dag_structure.py @@ -239,7 +239,7 @@ def _make_processor(tools=None): agent = MagicMock() agent.tools = tools or {} agent.additional_context = "" - processor.agent = agent + processor.agents = {"default": agent} processor.aggregator = MagicMock() return processor @@ -254,6 +254,7 @@ def _make_flow(): return producers[name] flow = MagicMock(side_effect=factory) + flow.workspace = "default" return flow @@ -299,7 +300,7 @@ class TestAgentReactDagStructure: service.max_iterations = 10 service.save_answer_content = AsyncMock() service.provenance_session_uri = processor.provenance_session_uri - service.agent = processor.agent + service.agents = processor.agents service.aggregator = processor.aggregator service.react_pattern = ReactPattern(service) @@ -344,7 +345,6 @@ class TestAgentReactDagStructure: request1 = AgentRequest( question="What is 6x7?", - user="testuser", collection="default", streaming=False, session_id=session_id, @@ -433,7 +433,7 @@ class TestAgentPlanDagStructure: service.max_iterations = 10 service.save_answer_content = AsyncMock() service.provenance_session_uri = processor.provenance_session_uri - service.agent = processor.agent + service.agents = processor.agents service.aggregator = processor.aggregator service.react_pattern = ReactPattern(service) @@ -480,7 +480,6 @@ class TestAgentPlanDagStructure: # Iteration 1: planning request1 = AgentRequest( question="Test?", - user="testuser", collection="default", streaming=False, session_id=session_id, @@ -537,7 +536,7 @@ 
class TestAgentSupervisorDagStructure: service.max_iterations = 10 service.save_answer_content = AsyncMock() service.provenance_session_uri = processor.provenance_session_uri - service.agent = processor.agent + service.agents = processor.agents service.aggregator = processor.aggregator service.react_pattern = ReactPattern(service) @@ -563,7 +562,6 @@ class TestAgentSupervisorDagStructure: request = AgentRequest( question="Research quantum computing", - user="testuser", collection="default", streaming=False, session_id=str(uuid.uuid4()), diff --git a/tests/unit/test_query/test_doc_embeddings_milvus_query.py b/tests/unit/test_query/test_doc_embeddings_milvus_query.py index 1cddce97..56ccc398 100644 --- a/tests/unit/test_query/test_doc_embeddings_milvus_query.py +++ b/tests/unit/test_query/test_doc_embeddings_milvus_query.py @@ -31,7 +31,6 @@ class TestMilvusDocEmbeddingsQueryProcessor: def mock_query_request(self): """Create a mock query request for testing""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=10 @@ -69,7 +68,6 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_single_vector(self, processor): """Test querying document embeddings with a single vector""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=5 @@ -83,7 +81,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify search was called with correct parameters including user/collection processor.vecstore.search.assert_called_once_with( @@ -101,7 +99,6 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_longer_vector(self, processor): """Test querying document embeddings with a longer vector""" 
query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=3 @@ -115,7 +112,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify search was called once with the full vector processor.vecstore.search.assert_called_once_with( @@ -133,7 +130,6 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_with_limit(self, processor): """Test querying document embeddings respects limit parameter""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=2 @@ -148,7 +144,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify search was called with the specified limit processor.vecstore.search.assert_called_once_with( @@ -162,13 +158,12 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_empty_vectors(self, processor): """Test querying document embeddings with empty vectors list""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[], limit=5 ) - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify no search was called processor.vecstore.search.assert_not_called() @@ -180,7 +175,6 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_empty_search_results(self, processor): """Test querying document embeddings with empty search results""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 
0.3], limit=5 @@ -189,7 +183,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: # Mock empty search results processor.vecstore.search.return_value = [] - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify search was called processor.vecstore.search.assert_called_once_with( @@ -203,7 +197,6 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_unicode_documents(self, processor): """Test querying document embeddings with Unicode document content""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=5 @@ -217,7 +210,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify Unicode content is preserved in ChunkMatch objects assert len(result) == 3 @@ -230,7 +223,6 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_large_documents(self, processor): """Test querying document embeddings with large document content""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=5 @@ -244,7 +236,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify large content is preserved in ChunkMatch objects assert len(result) == 2 @@ -256,7 +248,6 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_special_characters(self, processor): """Test querying document embeddings with special characters in documents""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', 
vector=[0.1, 0.2, 0.3], limit=5 @@ -270,7 +261,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify special characters are preserved in ChunkMatch objects assert len(result) == 3 @@ -283,13 +274,12 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_zero_limit(self, processor): """Test querying document embeddings with zero limit""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=0 ) - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify no search was called (optimization for zero limit) processor.vecstore.search.assert_not_called() @@ -301,13 +291,12 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_negative_limit(self, processor): """Test querying document embeddings with negative limit""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=-1 ) - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify no search was called (optimization for negative limit) processor.vecstore.search.assert_not_called() @@ -319,7 +308,6 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_exception_handling(self, processor): """Test exception handling during query processing""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=5 @@ -330,13 +318,12 @@ class TestMilvusDocEmbeddingsQueryProcessor: # Should raise the exception with pytest.raises(Exception, match="Milvus connection failed"): - await 
processor.query_document_embeddings(query) + await processor.query_document_embeddings('test_user', query) @pytest.mark.asyncio async def test_query_document_embeddings_different_vector_dimensions(self, processor): """Test querying document embeddings with different vector dimensions""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3, 0.4, 0.5], # 5D vector limit=5 @@ -349,7 +336,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify search was called with the vector processor.vecstore.search.assert_called_once() @@ -364,7 +351,6 @@ class TestMilvusDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_multiple_results(self, processor): """Test querying document embeddings with multiple results""" query = DocumentEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=5 @@ -378,7 +364,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_document_embeddings(query) + result = await processor.query_document_embeddings('test_user', query) # Verify results are ChunkMatch objects assert len(result) == 3 diff --git a/tests/unit/test_query/test_doc_embeddings_pinecone_query.py b/tests/unit/test_query/test_doc_embeddings_pinecone_query.py index 397bdf1b..b50a95b8 100644 --- a/tests/unit/test_query/test_doc_embeddings_pinecone_query.py +++ b/tests/unit/test_query/test_doc_embeddings_pinecone_query.py @@ -117,7 +117,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify index was 
accessed correctly (with dimension suffix) expected_index_name = "d-test_user-test_collection-3" # 3 dimensions @@ -160,7 +160,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_index.query.side_effect = [mock_results1, mock_results2] - chunks = await processor.query_document_embeddings(mock_query_message) + chunks = await processor.query_document_embeddings('default', mock_query_message) # Verify both queries were made assert mock_index.query.call_count == 2 @@ -191,7 +191,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify limit is passed to query mock_index.query.assert_called_once() @@ -213,7 +213,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify no query was made and empty result returned mock_index.query.assert_not_called() @@ -231,7 +231,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify no query was made and empty result returned mock_index.query.assert_not_called() @@ -259,7 +259,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_index.query.side_effect = [mock_results_2d, mock_results_4d] - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify different indexes used for different dimensions assert processor.pinecone.Index.call_count == 2 @@ -287,7 +287,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_index = MagicMock() 
processor.pinecone.Index.return_value = mock_index - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify no queries were made and empty result returned processor.pinecone.Index.assert_not_called() @@ -310,7 +310,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_results.matches = [] mock_index.query.return_value = mock_results - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify empty results assert chunks == [] @@ -334,7 +334,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify Unicode content is properly handled assert len(chunks) == 2 @@ -361,7 +361,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify large content is properly handled assert len(chunks) == 1 @@ -389,7 +389,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify all content types are properly handled assert len(chunks) == 5 @@ -413,7 +413,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_index.query.side_effect = Exception("Query failed") with pytest.raises(Exception, match="Query failed"): - await processor.query_document_embeddings(message) + await processor.query_document_embeddings('test_user', message) @pytest.mark.asyncio async def test_query_document_embeddings_index_access_failure(self, processor): @@ -427,7 +427,7 @@ class 
TestPineconeDocEmbeddingsQueryProcessor: processor.pinecone.Index.side_effect = Exception("Index access failed") with pytest.raises(Exception, match="Index access failed"): - await processor.query_document_embeddings(message) + await processor.query_document_embeddings('test_user', message) @pytest.mark.asyncio async def test_query_document_embeddings_vector_accumulation(self, processor): @@ -462,7 +462,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: mock_index.query.side_effect = [mock_results1, mock_results2, mock_results3] - chunks = await processor.query_document_embeddings(message) + chunks = await processor.query_document_embeddings('test_user', message) # Verify all queries were made assert mock_index.query.call_count == 3 diff --git a/tests/unit/test_query/test_doc_embeddings_qdrant_query.py b/tests/unit/test_query/test_doc_embeddings_qdrant_query.py index 1d2f0e6d..3602ad51 100644 --- a/tests/unit/test_query/test_doc_embeddings_qdrant_query.py +++ b/tests/unit/test_query/test_doc_embeddings_qdrant_query.py @@ -101,7 +101,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'test_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('test_user', mock_message) # Assert # Verify query was called with correct parameters (with dimension suffix) @@ -158,7 +158,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'multi_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('multi_user', mock_message) # Assert # Verify query was called once @@ -212,7 +212,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'limit_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('limit_user', mock_message) # 
Assert # Verify query was called with exact limit (no multiplication) @@ -252,7 +252,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'empty_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('empty_user', mock_message) # Assert assert result == [] @@ -291,7 +291,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'dim_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('dim_user', mock_message) # Assert # Verify query was called once with correct collection @@ -342,7 +342,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'utf8_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('utf8_user', mock_message) # Assert assert len(result) == 2 @@ -380,7 +380,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Act & Assert with pytest.raises(Exception, match="Qdrant connection failed"): - await processor.query_document_embeddings(mock_message) + await processor.query_document_embeddings('error_user', mock_message) @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @@ -413,7 +413,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'zero_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await processor.query_document_embeddings('zero_user', mock_message) # Assert # Should still query (with limit 0) @@ -460,7 +460,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'large_collection' # Act - result = await processor.query_document_embeddings(mock_message) + result = await 
processor.query_document_embeddings('large_user', mock_message) # Assert # Should query with full limit @@ -512,7 +512,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Act & Assert # This should raise a KeyError when trying to access payload['chunk_id'] with pytest.raises(KeyError): - await processor.query_document_embeddings(mock_message) + await processor.query_document_embeddings('payload_user', mock_message) @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') diff --git a/tests/unit/test_query/test_graph_embeddings_milvus_query.py b/tests/unit/test_query/test_graph_embeddings_milvus_query.py index f2b8be7e..7e5c4df3 100644 --- a/tests/unit/test_query/test_graph_embeddings_milvus_query.py +++ b/tests/unit/test_query/test_graph_embeddings_milvus_query.py @@ -31,7 +31,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor: def mock_query_request(self): """Create a mock query request for testing""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=10 @@ -117,7 +116,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_single_vector(self, processor): """Test querying graph embeddings with a single vector""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=5 @@ -131,7 +129,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify search was called with correct parameters including user/collection processor.vecstore.search.assert_called_once_with( @@ -154,7 +152,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_multiple_results(self, processor): """Test querying graph 
embeddings returns multiple results""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=5 @@ -168,7 +165,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify search was called once with the full vector processor.vecstore.search.assert_called_once_with( @@ -186,7 +183,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_with_limit(self, processor): """Test querying graph embeddings respects limit parameter""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=2 @@ -201,7 +197,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify search was called with 2*limit for better deduplication processor.vecstore.search.assert_called_once_with( @@ -215,7 +211,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_preserves_order(self, processor): """Test that query results preserve order from the vector store""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=5 @@ -229,7 +224,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify results are in the same order as returned by the store assert len(result) == 3 @@ -241,7 +236,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_results_limited(self, 
processor): """Test that results are properly limited when store returns more than requested""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=2 @@ -255,7 +249,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify search was called with the full vector processor.vecstore.search.assert_called_once_with( @@ -269,13 +263,12 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_empty_vectors(self, processor): """Test querying graph embeddings with empty vectors list""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[], limit=5 ) - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify no search was called processor.vecstore.search.assert_not_called() @@ -287,7 +280,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_empty_search_results(self, processor): """Test querying graph embeddings with empty search results""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=5 @@ -296,7 +288,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: # Mock empty search results processor.vecstore.search.return_value = [] - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify search was called processor.vecstore.search.assert_called_once_with( @@ -310,7 +302,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_mixed_uri_literal_results(self, processor): """Test querying graph embeddings with mixed URI and literal results""" query = GraphEmbeddingsRequest( - 
user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=5 @@ -325,7 +316,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify all results are properly typed assert len(result) == 4 @@ -348,7 +339,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_exception_handling(self, processor): """Test exception handling during query processing""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=5 @@ -359,7 +349,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: # Should raise the exception with pytest.raises(Exception, match="Milvus connection failed"): - await processor.query_graph_embeddings(query) + await processor.query_graph_embeddings('test_user', query) def test_add_args_method(self): """Test that add_args properly configures argument parser""" @@ -430,13 +420,12 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_zero_limit(self, processor): """Test querying graph embeddings with zero limit""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3], limit=0 ) - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify no search was called (optimization for zero limit) processor.vecstore.search.assert_not_called() @@ -448,7 +437,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_longer_vector(self, processor): """Test querying graph embeddings with a longer vector""" query = GraphEmbeddingsRequest( - user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], limit=5 @@ -461,7 +449,7 @@ class 
TestMilvusGraphEmbeddingsQueryProcessor: ] processor.vecstore.search.return_value = mock_results - result = await processor.query_graph_embeddings(query) + result = await processor.query_graph_embeddings('test_user', query) # Verify search was called once with the full vector processor.vecstore.search.assert_called_once() diff --git a/tests/unit/test_query/test_graph_embeddings_pinecone_query.py b/tests/unit/test_query/test_graph_embeddings_pinecone_query.py index 2c1a673a..0fc8f7c0 100644 --- a/tests/unit/test_query/test_graph_embeddings_pinecone_query.py +++ b/tests/unit/test_query/test_graph_embeddings_pinecone_query.py @@ -145,7 +145,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('test_user', message) # Verify index was accessed correctly (with dimension suffix) expected_index_name = "t-test_user-test_collection-3" # 3 dimensions @@ -185,7 +185,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_index.query.return_value = mock_results - entities = await processor.query_graph_embeddings(mock_query_message) + entities = await processor.query_graph_embeddings('default', mock_query_message) # Verify query was made once assert mock_index.query.call_count == 1 @@ -216,7 +216,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('test_user', message) # Verify limit is respected assert len(entities) == 2 @@ -233,7 +233,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('test_user', message) # Verify no query was made and empty result returned 
mock_index.query.assert_not_called() @@ -251,7 +251,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('test_user', message) # Verify no query was made and empty result returned mock_index.query.assert_not_called() @@ -276,7 +276,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_index.query.return_value = mock_results - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('test_user', message) # Verify correct index used for 2D vector processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2") @@ -300,7 +300,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('test_user', message) # Verify no queries were made and empty result returned processor.pinecone.Index.assert_not_called() @@ -323,7 +323,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_results.matches = [] mock_index.query.return_value = mock_results - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('test_user', message) # Verify empty results assert entities == [] @@ -352,7 +352,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_index.query.return_value = mock_results - entities = await processor.query_graph_embeddings(message) + entities = await processor.query_graph_embeddings('test_user', message) # Should get exactly 3 unique entities (respecting limit) assert len(entities) == 3 @@ -380,7 +380,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: ] mock_index.query.return_value = mock_results - entities = await processor.query_graph_embeddings(message) + entities = 
await processor.query_graph_embeddings('test_user', message) # Should only return 2 entities (respecting limit) mock_index.query.assert_called_once() @@ -400,7 +400,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: mock_index.query.side_effect = Exception("Query failed") with pytest.raises(Exception, match="Query failed"): - await processor.query_graph_embeddings(message) + await processor.query_graph_embeddings('test_user', message) def test_add_args_method(self): """Test that add_args properly configures argument parser""" diff --git a/tests/unit/test_query/test_graph_embeddings_qdrant_query.py b/tests/unit/test_query/test_graph_embeddings_qdrant_query.py index 9362a8dd..41b6c8a4 100644 --- a/tests/unit/test_query/test_graph_embeddings_qdrant_query.py +++ b/tests/unit/test_query/test_graph_embeddings_qdrant_query.py @@ -173,7 +173,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'test_collection' # Act - result = await processor.query_graph_embeddings(mock_message) + result = await processor.query_graph_embeddings('test_user', mock_message) # Assert # Verify query was called with correct parameters (with dimension suffix) @@ -230,7 +230,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'multi_collection' # Act - result = await processor.query_graph_embeddings(mock_message) + result = await processor.query_graph_embeddings('multi_user', mock_message) # Assert # Verify query was called once @@ -283,7 +283,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'limit_collection' # Act - result = await processor.query_graph_embeddings(mock_message) + result = await processor.query_graph_embeddings('limit_user', mock_message) # Assert # Verify query was called with limit * 2 @@ -323,7 +323,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'empty_collection' # Act - result = await 
processor.query_graph_embeddings(mock_message) + result = await processor.query_graph_embeddings('empty_user', mock_message) # Assert assert result == [] @@ -364,7 +364,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'dim_collection' # Act - result = await processor.query_graph_embeddings(mock_message) + result = await processor.query_graph_embeddings('dim_user', mock_message) # Assert # Verify query was called once @@ -415,7 +415,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'uri_collection' # Act - result = await processor.query_graph_embeddings(mock_message) + result = await processor.query_graph_embeddings('uri_user', mock_message) # Assert assert len(result) == 3 @@ -460,7 +460,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Act & Assert with pytest.raises(Exception, match="Qdrant connection failed"): - await processor.query_graph_embeddings(mock_message) + await processor.query_graph_embeddings('error_user', mock_message) @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__') @@ -493,7 +493,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): mock_message.collection = 'zero_collection' # Act - result = await processor.query_graph_embeddings(mock_message) + result = await processor.query_graph_embeddings('zero_user', mock_message) # Assert # Should still query (with limit 0) diff --git a/tests/unit/test_query/test_memgraph_user_collection_query.py b/tests/unit/test_query/test_memgraph_workspace_collection_query.py similarity index 76% rename from tests/unit/test_query/test_memgraph_user_collection_query.py rename to tests/unit/test_query/test_memgraph_workspace_collection_query.py index 038fb438..d0ab242e 100644 --- a/tests/unit/test_query/test_memgraph_user_collection_query.py +++ b/tests/unit/test_query/test_memgraph_workspace_collection_query.py @@ 
-9,12 +9,12 @@ from trustgraph.query.triples.memgraph.service import Processor from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL -class TestMemgraphQueryUserCollectionIsolation: +class TestMemgraphQueryWorkspaceCollectionIsolation: """Test cases for Memgraph query service with user/collection isolation""" @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_spo_query_with_user_collection(self, mock_graph_db): + async def test_spo_query_with_workspace_collection(self, mock_graph_db): """Test SPO query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -22,7 +22,6 @@ class TestMemgraphQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/s"), p=Term(type=IRI, iri="http://example.com/p"), @@ -32,13 +31,13 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify SPO query for literal includes user/collection expected_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN $src as src " "LIMIT 1000" ) @@ -48,14 +47,14 @@ class TestMemgraphQueryUserCollectionIsolation: src="http://example.com/s", rel="http://example.com/p", value="test_object", - user="test_user", + workspace="test_user", collection="test_collection", 
database_='memgraph' ) @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_sp_query_with_user_collection(self, mock_graph_db): + async def test_sp_query_with_workspace_collection(self, mock_graph_db): """Test SP query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -63,7 +62,6 @@ class TestMemgraphQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/s"), p=Term(type=IRI, iri="http://example.com/p"), @@ -73,13 +71,13 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify SP query for literals includes user/collection expected_literal_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN dest.value as dest " "LIMIT 1000" ) @@ -88,14 +86,14 @@ class TestMemgraphQueryUserCollectionIsolation: expected_literal_query, src="http://example.com/s", rel="http://example.com/p", - user="test_user", + workspace="test_user", collection="test_collection", database_='memgraph' ) @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_so_query_with_user_collection(self, mock_graph_db): + async def test_so_query_with_workspace_collection(self, mock_graph_db): """Test SO query pattern includes user/collection filtering""" 
mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -103,7 +101,6 @@ class TestMemgraphQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/s"), p=None, @@ -113,13 +110,13 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify SO query for nodes includes user/collection expected_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {uri: $uri, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel " "LIMIT 1000" ) @@ -128,14 +125,14 @@ class TestMemgraphQueryUserCollectionIsolation: expected_query, src="http://example.com/s", uri="http://example.com/o", - user="test_user", + workspace="test_user", collection="test_collection", database_='memgraph' ) @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_s_only_query_with_user_collection(self, mock_graph_db): + async def test_s_only_query_with_workspace_collection(self, mock_graph_db): """Test S-only query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -143,7 +140,6 @@ class TestMemgraphQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/s"), p=None, @@ -153,13 +149,13 @@ class 
TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify S query includes user/collection expected_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel, dest.value as dest " "LIMIT 1000" ) @@ -167,14 +163,14 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.assert_any_call( expected_query, src="http://example.com/s", - user="test_user", + workspace="test_user", collection="test_collection", database_='memgraph' ) @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_po_query_with_user_collection(self, mock_graph_db): + async def test_po_query_with_workspace_collection(self, mock_graph_db): """Test PO query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -182,7 +178,6 @@ class TestMemgraphQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=None, p=Term(type=IRI, iri="http://example.com/p"), @@ -192,13 +187,13 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify PO query for literals includes user/collection expected_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: 
$user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src " "LIMIT 1000" ) @@ -207,14 +202,14 @@ class TestMemgraphQueryUserCollectionIsolation: expected_query, uri="http://example.com/p", value="literal", - user="test_user", + workspace="test_user", collection="test_collection", database_='memgraph' ) @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_p_only_query_with_user_collection(self, mock_graph_db): + async def test_p_only_query_with_workspace_collection(self, mock_graph_db): """Test P-only query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -222,7 +217,6 @@ class TestMemgraphQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=None, p=Term(type=IRI, iri="http://example.com/p"), @@ -232,13 +226,13 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify P query includes user/collection expected_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, dest.value as dest " "LIMIT 1000" ) @@ 
-246,14 +240,14 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.assert_any_call( expected_query, uri="http://example.com/p", - user="test_user", + workspace="test_user", collection="test_collection", database_='memgraph' ) @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def test_o_only_query_with_user_collection(self, mock_graph_db): + async def test_o_only_query_with_workspace_collection(self, mock_graph_db): """Test O-only query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -261,7 +255,6 @@ class TestMemgraphQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=None, p=None, @@ -271,13 +264,13 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify O query for literals includes user/collection expected_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel " "LIMIT 1000" ) @@ -285,14 +278,14 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.assert_any_call( expected_query, value="test_value", - user="test_user", + workspace="test_user", collection="test_collection", database_='memgraph' ) @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') @pytest.mark.asyncio - async def 
test_wildcard_query_with_user_collection(self, mock_graph_db): + async def test_wildcard_query_with_workspace_collection(self, mock_graph_db): """Test wildcard query (all None) includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -300,7 +293,6 @@ class TestMemgraphQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=None, p=None, @@ -310,36 +302,36 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify wildcard query for literals includes user/collection expected_literal_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.value as dest " "LIMIT 1000" ) mock_driver.execute_query.assert_any_call( expected_literal_query, - user="test_user", + workspace="test_user", collection="test_collection", database_='memgraph' ) # Verify wildcard query for nodes includes user/collection expected_node_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.uri as dest " "LIMIT 1000" ) 
mock_driver.execute_query.assert_any_call( expected_node_query, - user="test_user", + workspace="test_user", collection="test_collection", database_='memgraph' ) @@ -363,7 +355,7 @@ class TestMemgraphQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples('default', query) # Verify defaults were used calls = mock_driver.execute_query.call_args_list @@ -383,7 +375,6 @@ class TestMemgraphQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/s"), p=None, @@ -410,7 +401,7 @@ class TestMemgraphQueryUserCollectionIsolation: ([mock_record2], MagicMock(), MagicMock()) # Node query ] - result = await processor.query_triples(query) + result = await processor.query_triples("test_user", query) # Verify results are proper Triple objects assert len(result) == 2 diff --git a/tests/unit/test_query/test_neo4j_user_collection_query.py b/tests/unit/test_query/test_neo4j_workspace_collection_query.py similarity index 75% rename from tests/unit/test_query/test_neo4j_user_collection_query.py rename to tests/unit/test_query/test_neo4j_workspace_collection_query.py index d9cf1eb4..029de617 100644 --- a/tests/unit/test_query/test_neo4j_user_collection_query.py +++ b/tests/unit/test_query/test_neo4j_workspace_collection_query.py @@ -9,12 +9,12 @@ from trustgraph.query.triples.neo4j.service import Processor from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL -class TestNeo4jQueryUserCollectionIsolation: +class TestNeo4jQueryWorkspaceCollectionIsolation: """Test cases for Neo4j query service with user/collection isolation""" @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_spo_query_with_user_collection(self, mock_graph_db): + async def 
test_spo_query_with_workspace_collection(self, mock_graph_db): """Test SPO query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -22,7 +22,6 @@ class TestNeo4jQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/s"), p=Term(type=IRI, iri="http://example.com/p"), @@ -32,13 +31,13 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify SPO query for literal includes user/collection expected_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN $src as src " "LIMIT 10" ) @@ -48,14 +47,14 @@ class TestNeo4jQueryUserCollectionIsolation: src="http://example.com/s", rel="http://example.com/p", value="test_object", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_sp_query_with_user_collection(self, mock_graph_db): + async def test_sp_query_with_workspace_collection(self, mock_graph_db): """Test SP query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -63,7 +62,6 @@ class TestNeo4jQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = 
TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/s"), p=Term(type=IRI, iri="http://example.com/p"), @@ -73,13 +71,13 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify SP query for literals includes user/collection expected_literal_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN dest.value as dest " "LIMIT 10" ) @@ -88,16 +86,16 @@ class TestNeo4jQueryUserCollectionIsolation: expected_literal_query, src="http://example.com/s", rel="http://example.com/p", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) # Verify SP query for nodes includes user/collection expected_node_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN dest.uri as dest " "LIMIT 10" ) @@ -106,14 +104,14 @@ class TestNeo4jQueryUserCollectionIsolation: expected_node_query, src="http://example.com/s", rel="http://example.com/p", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) 
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_so_query_with_user_collection(self, mock_graph_db): + async def test_so_query_with_workspace_collection(self, mock_graph_db): """Test SO query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -121,7 +119,6 @@ class TestNeo4jQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/s"), p=None, @@ -131,13 +128,13 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify SO query for nodes includes user/collection expected_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {uri: $uri, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel " "LIMIT 10" ) @@ -146,14 +143,14 @@ class TestNeo4jQueryUserCollectionIsolation: expected_query, src="http://example.com/s", uri="http://example.com/o", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_s_only_query_with_user_collection(self, mock_graph_db): + async def test_s_only_query_with_workspace_collection(self, mock_graph_db): """Test S-only query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -161,7 +158,6 @@ 
class TestNeo4jQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/s"), p=None, @@ -171,13 +167,13 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify S query includes user/collection expected_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel, dest.value as dest " "LIMIT 10" ) @@ -185,14 +181,14 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.assert_any_call( expected_query, src="http://example.com/s", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_po_query_with_user_collection(self, mock_graph_db): + async def test_po_query_with_workspace_collection(self, mock_graph_db): """Test PO query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -200,7 +196,6 @@ class TestNeo4jQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=None, p=Term(type=IRI, iri="http://example.com/p"), @@ -210,13 +205,13 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await 
processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify PO query for literals includes user/collection expected_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src " "LIMIT 10" ) @@ -225,14 +220,14 @@ class TestNeo4jQueryUserCollectionIsolation: expected_query, uri="http://example.com/p", value="literal", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_p_only_query_with_user_collection(self, mock_graph_db): + async def test_p_only_query_with_workspace_collection(self, mock_graph_db): """Test P-only query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -240,7 +235,6 @@ class TestNeo4jQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=None, p=Term(type=IRI, iri="http://example.com/p"), @@ -250,13 +244,13 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify P query includes user/collection expected_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, 
collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, dest.value as dest " "LIMIT 10" ) @@ -264,14 +258,14 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.assert_any_call( expected_query, uri="http://example.com/p", - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_o_only_query_with_user_collection(self, mock_graph_db): + async def test_o_only_query_with_workspace_collection(self, mock_graph_db): """Test O-only query pattern includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -279,7 +273,6 @@ class TestNeo4jQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=None, p=None, @@ -289,13 +282,13 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify O query for literals includes user/collection expected_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel " "LIMIT 10" ) @@ -303,14 +296,14 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.assert_any_call( expected_query, value="test_value", - user="test_user", + 
workspace="test_user", collection="test_collection", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_wildcard_query_with_user_collection(self, mock_graph_db): + async def test_wildcard_query_with_workspace_collection(self, mock_graph_db): """Test wildcard query (all None) includes user/collection filtering""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver @@ -318,7 +311,6 @@ class TestNeo4jQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=None, p=None, @@ -328,36 +320,36 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples("test_user", query) # Verify wildcard query for literals includes user/collection expected_literal_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.value as dest " "LIMIT 10" ) mock_driver.execute_query.assert_any_call( expected_literal_query, - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) # Verify wildcard query for nodes includes user/collection expected_node_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + 
"(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.uri as dest " "LIMIT 10" ) mock_driver.execute_query.assert_any_call( expected_node_query, - user="test_user", + workspace="test_user", collection="test_collection", database_='neo4j' ) @@ -381,7 +373,7 @@ class TestNeo4jQueryUserCollectionIsolation: mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - await processor.query_triples(query) + await processor.query_triples('default', query) # Verify defaults were used calls = mock_driver.execute_query.call_args_list @@ -401,7 +393,6 @@ class TestNeo4jQueryUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/s"), p=None, @@ -428,7 +419,7 @@ class TestNeo4jQueryUserCollectionIsolation: ([mock_record2], MagicMock(), MagicMock()) # Node query ] - result = await processor.query_triples(query) + result = await processor.query_triples("test_user", query) # Verify results are proper Triple objects assert len(result) == 2 diff --git a/tests/unit/test_query/test_rows_cassandra_query.py b/tests/unit/test_query/test_rows_cassandra_query.py index c0d399c3..bb6bbe84 100644 --- a/tests/unit/test_query/test_rows_cassandra_query.py +++ b/tests/unit/test_query/test_rows_cassandra_query.py @@ -91,11 +91,10 @@ class TestRowsGraphQLQueryLogic: """Test parsing of schema configuration""" processor = MagicMock() processor.schemas = {} + processor.schema_builders = {} + processor.graphql_schemas = {} processor.config_key = "schema" - processor.schema_builder = MagicMock() - processor.schema_builder.clear = MagicMock() - processor.schema_builder.add_schema = MagicMock() - processor.schema_builder.build = MagicMock(return_value=MagicMock()) + processor.query_cassandra = MagicMock() processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) # Create test 
config @@ -129,11 +128,11 @@ class TestRowsGraphQLQueryLogic: } # Process config - await processor.on_schema_config(schema_config, version=1) + await processor.on_schema_config("default", schema_config, version=1) # Verify schema was loaded - assert "customer" in processor.schemas - schema = processor.schemas["customer"] + assert "customer" in processor.schemas["default"] + schema = processor.schemas["default"]["customer"] assert schema.name == "customer" assert len(schema.fields) == 3 @@ -147,39 +146,40 @@ class TestRowsGraphQLQueryLogic: status_field = next(f for f in schema.fields if f.name == "status") assert status_field.enum_values == ["active", "inactive"] - # Verify schema builder was called - processor.schema_builder.add_schema.assert_called_once() - processor.schema_builder.build.assert_called_once() + # Verify per-workspace schema builder was created and graphql schema built + assert "default" in processor.schema_builders + assert "default" in processor.graphql_schemas @pytest.mark.asyncio async def test_graphql_context_handling(self): """Test GraphQL execution context setup""" processor = MagicMock() - processor.graphql_schema = AsyncMock() + graphql_schema = AsyncMock() + processor.graphql_schemas = {"default": graphql_schema} processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor) # Mock schema execution mock_result = MagicMock() mock_result.data = {"customers": [{"id": "1", "name": "Test"}]} mock_result.errors = None - processor.graphql_schema.execute.return_value = mock_result + graphql_schema.execute.return_value = mock_result result = await processor.execute_graphql_query( + workspace="default", query='{ customers { id name } }', variables={}, operation_name=None, - user="test_user", collection="test_collection" ) # Verify schema.execute was called with correct context - processor.graphql_schema.execute.assert_called_once() - call_args = processor.graphql_schema.execute.call_args + 
graphql_schema.execute.assert_called_once() + call_args = graphql_schema.execute.call_args # Verify context was passed context = call_args[1]['context_value'] assert context["processor"] == processor - assert context["user"] == "test_user" + assert context["workspace"] == "default" assert context["collection"] == "test_collection" # Verify result structure @@ -190,7 +190,8 @@ class TestRowsGraphQLQueryLogic: async def test_error_handling_graphql_errors(self): """Test GraphQL error handling and conversion""" processor = MagicMock() - processor.graphql_schema = AsyncMock() + graphql_schema = AsyncMock() + processor.graphql_schemas = {"default": graphql_schema} processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor) # Create a simple object to simulate GraphQL error @@ -212,13 +213,13 @@ class TestRowsGraphQLQueryLogic: mock_result = MagicMock() mock_result.data = None mock_result.errors = [mock_error] - processor.graphql_schema.execute.return_value = mock_result + graphql_schema.execute.return_value = mock_result result = await processor.execute_graphql_query( + workspace="default", query='{ customers { invalid_field } }', variables={}, operation_name=None, - user="test_user", collection="test_collection" ) @@ -248,7 +249,6 @@ class TestRowsGraphQLQueryLogic: # Create mock message mock_msg = MagicMock() mock_request = RowsQueryRequest( - user="test_user", collection="test_collection", query='{ customers { id name } }', variables={}, @@ -259,6 +259,7 @@ class TestRowsGraphQLQueryLogic: # Mock flow mock_flow = MagicMock() + mock_flow.workspace = "default" mock_response_flow = AsyncMock() mock_flow.return_value = mock_response_flow @@ -267,10 +268,10 @@ class TestRowsGraphQLQueryLogic: # Verify query was executed processor.execute_graphql_query.assert_called_once_with( + workspace="default", query='{ customers { id name } }', variables={}, operation_name=None, - user="test_user", collection="test_collection" ) @@ -297,7 +298,6 @@ 
class TestRowsGraphQLQueryLogic: # Create mock message mock_msg = MagicMock() mock_request = RowsQueryRequest( - user="test_user", collection="test_collection", query='{ invalid_query }', variables={}, @@ -357,7 +357,7 @@ class TestUnifiedTableQueries: # Query with filter on indexed field results = await processor.query_cassandra( - user="test_user", + workspace="test_workspace", collection="test_collection", schema_name="products", row_schema=schema, @@ -374,7 +374,7 @@ class TestUnifiedTableQueries: query = call_args[0][1] params = call_args[0][2] - assert "SELECT data, source FROM test_user.rows" in query + assert "SELECT data, source FROM test_workspace.rows" in query assert "collection = %s" in query assert "schema_name = %s" in query assert "index_name = %s" in query @@ -421,7 +421,7 @@ class TestUnifiedTableQueries: # Query with filter on non-indexed field results = await processor.query_cassandra( - user="test_user", + workspace="test_workspace", collection="test_collection", schema_name="products", row_schema=schema, diff --git a/tests/unit/test_query/test_triples_cassandra_query.py b/tests/unit/test_query/test_triples_cassandra_query.py index b620df7e..09681214 100644 --- a/tests/unit/test_query/test_triples_cassandra_query.py +++ b/tests/unit/test_query/test_triples_cassandra_query.py @@ -95,7 +95,6 @@ class TestCassandraQueryProcessor: # Create query request with all SPO values query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=LITERAL, value='test_subject'), p=Term(type=LITERAL, value='test_predicate'), @@ -103,7 +102,7 @@ class TestCassandraQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify KnowledgeGraph was created with correct parameters mock_kg_class.assert_called_once_with( @@ -170,7 +169,6 @@ class TestCassandraQueryProcessor: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - 
user='test_user', collection='test_collection', s=Term(type=LITERAL, value='test_subject'), p=Term(type=LITERAL, value='test_predicate'), @@ -178,7 +176,7 @@ class TestCassandraQueryProcessor: limit=50 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50) assert len(result) == 1 @@ -207,7 +205,6 @@ class TestCassandraQueryProcessor: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=LITERAL, value='test_subject'), p=None, @@ -215,7 +212,7 @@ class TestCassandraQueryProcessor: limit=25 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25) assert len(result) == 1 @@ -244,7 +241,6 @@ class TestCassandraQueryProcessor: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=Term(type=LITERAL, value='test_predicate'), @@ -252,7 +248,7 @@ class TestCassandraQueryProcessor: limit=10 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10) assert len(result) == 1 @@ -281,7 +277,6 @@ class TestCassandraQueryProcessor: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=None, @@ -289,7 +284,7 @@ class TestCassandraQueryProcessor: limit=75 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75) assert 
len(result) == 1 @@ -319,7 +314,6 @@ class TestCassandraQueryProcessor: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=None, @@ -327,7 +321,7 @@ class TestCassandraQueryProcessor: limit=1000 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000) assert len(result) == 1 @@ -425,7 +419,6 @@ class TestCassandraQueryProcessor: ) query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=LITERAL, value='test_subject'), p=Term(type=LITERAL, value='test_predicate'), @@ -433,7 +426,7 @@ class TestCassandraQueryProcessor: limit=100 ) - await processor.query_triples(query) + await processor.query_triples('test_user', query) # Verify KnowledgeGraph was created with authentication mock_kg_class.assert_called_once_with( @@ -463,7 +456,6 @@ class TestCassandraQueryProcessor: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=LITERAL, value='test_subject'), p=Term(type=LITERAL, value='test_predicate'), @@ -472,11 +464,11 @@ class TestCassandraQueryProcessor: ) # First query should create TrustGraph - await processor.query_triples(query) + await processor.query_triples('test_user', query) assert mock_kg_class.call_count == 1 # Second query with same table should reuse TrustGraph - await processor.query_triples(query) + await processor.query_triples('test_user', query) assert mock_kg_class.call_count == 1 # Should not increase @pytest.mark.asyncio @@ -504,7 +496,6 @@ class TestCassandraQueryProcessor: # First query query1 = TriplesQueryRequest( - user='user1', collection='collection1', s=Term(type=LITERAL, value='test_subject'), p=None, @@ -512,12 +503,11 @@ class TestCassandraQueryProcessor: limit=100 ) - await processor.query_triples(query1) + 
await processor.query_triples('user1', query1) assert processor.table == 'user1' # Second query with different table query2 = TriplesQueryRequest( - user='user2', collection='collection2', s=Term(type=LITERAL, value='test_subject'), p=None, @@ -525,7 +515,7 @@ class TestCassandraQueryProcessor: limit=100 ) - await processor.query_triples(query2) + await processor.query_triples('user2', query2) assert processor.table == 'user2' # Verify TrustGraph was created twice @@ -544,7 +534,6 @@ class TestCassandraQueryProcessor: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=LITERAL, value='test_subject'), p=Term(type=LITERAL, value='test_predicate'), @@ -553,7 +542,7 @@ class TestCassandraQueryProcessor: ) with pytest.raises(Exception, match="Query failed"): - await processor.query_triples(query) + await processor.query_triples('test_user', query) @pytest.mark.asyncio @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') @@ -582,7 +571,6 @@ class TestCassandraQueryProcessor: processor = Processor(taskgroup=MagicMock()) query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=LITERAL, value='test_subject'), p=Term(type=LITERAL, value='test_predicate'), @@ -590,7 +578,7 @@ class TestCassandraQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) assert len(result) == 2 assert result[0].o.value == 'object1' @@ -621,7 +609,6 @@ class TestCassandraQueryPerformanceOptimizations: # PO query pattern (predicate + object, find subjects) query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=Term(type=LITERAL, value='test_predicate'), @@ -629,7 +616,7 @@ class TestCassandraQueryPerformanceOptimizations: limit=50 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # 
Verify get_po was called (should use optimized po_table) mock_tg_instance.get_po.assert_called_once_with( @@ -662,7 +649,6 @@ class TestCassandraQueryPerformanceOptimizations: # OS query pattern (object + subject, find predicates) query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=LITERAL, value='test_subject'), p=None, @@ -670,7 +656,7 @@ class TestCassandraQueryPerformanceOptimizations: limit=25 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify get_os was called (should use optimized subject_table with clustering) mock_tg_instance.get_os.assert_called_once_with( @@ -721,7 +707,6 @@ class TestCassandraQueryPerformanceOptimizations: mock_tg_instance.reset_mock() query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=LITERAL, value=s) if s else None, p=Term(type=LITERAL, value=p) if p else None, @@ -729,7 +714,7 @@ class TestCassandraQueryPerformanceOptimizations: limit=10 ) - await processor.query_triples(query) + await processor.query_triples('test_user', query) # Verify the correct method was called method = getattr(mock_tg_instance, expected_method) @@ -780,7 +765,6 @@ class TestCassandraQueryPerformanceOptimizations: # This is the query pattern that was slow with ALLOW FILTERING query = TriplesQueryRequest( - user='large_dataset_user', collection='massive_collection', s=None, p=Term(type=IRI, iri='http://www.w3.org/1999/02/22-rdf-syntax-ns#type'), @@ -788,7 +772,7 @@ class TestCassandraQueryPerformanceOptimizations: limit=1000 ) - result = await processor.query_triples(query) + result = await processor.query_triples('large_dataset_user', query) # Verify optimized get_po was used (no ALLOW FILTERING needed!) 
mock_tg_instance.get_po.assert_called_once_with( diff --git a/tests/unit/test_query/test_triples_falkordb_query.py b/tests/unit/test_query/test_triples_falkordb_query.py index d5c047d7..3d7270c6 100644 --- a/tests/unit/test_query/test_triples_falkordb_query.py +++ b/tests/unit/test_query/test_triples_falkordb_query.py @@ -123,7 +123,6 @@ class TestFalkorDBQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=Term(type=IRI, iri="http://example.com/predicate"), @@ -131,7 +130,7 @@ class TestFalkorDBQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_graph.query.call_count == 2 @@ -164,7 +163,6 @@ class TestFalkorDBQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=Term(type=IRI, iri="http://example.com/predicate"), @@ -172,7 +170,7 @@ class TestFalkorDBQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_graph.query.call_count == 2 @@ -209,7 +207,6 @@ class TestFalkorDBQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=None, @@ -217,7 +214,7 @@ class TestFalkorDBQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_graph.query.call_count == 2 @@ -254,7 +251,6 @@ class TestFalkorDBQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', 
collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=None, @@ -262,7 +258,7 @@ class TestFalkorDBQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_graph.query.call_count == 2 @@ -299,7 +295,6 @@ class TestFalkorDBQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=Term(type=IRI, iri="http://example.com/predicate"), @@ -307,7 +302,7 @@ class TestFalkorDBQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_graph.query.call_count == 2 @@ -344,7 +339,6 @@ class TestFalkorDBQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=Term(type=IRI, iri="http://example.com/predicate"), @@ -352,7 +346,7 @@ class TestFalkorDBQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_graph.query.call_count == 2 @@ -389,7 +383,6 @@ class TestFalkorDBQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=None, @@ -397,7 +390,7 @@ class TestFalkorDBQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_graph.query.call_count == 2 @@ -434,7 +427,6 @@ class TestFalkorDBQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=None, @@ -442,7 +434,7 @@ class TestFalkorDBQueryProcessor: 
limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_graph.query.call_count == 2 @@ -474,7 +466,6 @@ class TestFalkorDBQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=None, @@ -484,7 +475,7 @@ class TestFalkorDBQueryProcessor: # Should raise the exception with pytest.raises(Exception, match="Database connection failed"): - await processor.query_triples(query) + await processor.query_triples('test_user', query) def test_add_args_method(self): """Test that add_args properly configures argument parser""" diff --git a/tests/unit/test_query/test_triples_memgraph_query.py b/tests/unit/test_query/test_triples_memgraph_query.py index f4222af1..a21d9008 100644 --- a/tests/unit/test_query/test_triples_memgraph_query.py +++ b/tests/unit/test_query/test_triples_memgraph_query.py @@ -122,7 +122,6 @@ class TestMemgraphQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=Term(type=IRI, iri="http://example.com/predicate"), @@ -130,7 +129,7 @@ class TestMemgraphQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -164,7 +163,6 @@ class TestMemgraphQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=Term(type=IRI, iri="http://example.com/predicate"), @@ -172,7 +170,7 @@ class TestMemgraphQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await 
processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -210,7 +208,6 @@ class TestMemgraphQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=None, @@ -218,7 +215,7 @@ class TestMemgraphQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -256,7 +253,6 @@ class TestMemgraphQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=None, @@ -264,7 +260,7 @@ class TestMemgraphQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -302,7 +298,6 @@ class TestMemgraphQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=Term(type=IRI, iri="http://example.com/predicate"), @@ -310,7 +305,7 @@ class TestMemgraphQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -348,7 +343,6 @@ class TestMemgraphQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=Term(type=IRI, iri="http://example.com/predicate"), @@ -356,7 +350,7 @@ class TestMemgraphQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await 
processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -394,7 +388,6 @@ class TestMemgraphQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=None, @@ -402,7 +395,7 @@ class TestMemgraphQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -440,7 +433,6 @@ class TestMemgraphQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=None, @@ -448,7 +440,7 @@ class TestMemgraphQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -478,7 +470,6 @@ class TestMemgraphQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=None, @@ -488,7 +479,7 @@ class TestMemgraphQueryProcessor: # Should raise the exception with pytest.raises(Exception, match="Database connection failed"): - await processor.query_triples(query) + await processor.query_triples('test_user', query) def test_add_args_method(self): """Test that add_args properly configures argument parser""" diff --git a/tests/unit/test_query/test_triples_neo4j_query.py b/tests/unit/test_query/test_triples_neo4j_query.py index e379ed21..3751a858 100644 --- a/tests/unit/test_query/test_triples_neo4j_query.py +++ b/tests/unit/test_query/test_triples_neo4j_query.py @@ -122,7 +122,6 @@ class TestNeo4jQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', 
collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=Term(type=IRI, iri="http://example.com/predicate"), @@ -130,7 +129,7 @@ class TestNeo4jQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -164,7 +163,6 @@ class TestNeo4jQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=Term(type=IRI, iri="http://example.com/predicate"), @@ -172,7 +170,7 @@ class TestNeo4jQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -210,7 +208,6 @@ class TestNeo4jQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=None, p=None, @@ -218,7 +215,7 @@ class TestNeo4jQueryProcessor: limit=100 ) - result = await processor.query_triples(query) + result = await processor.query_triples('test_user', query) # Verify both literal and URI queries were executed assert mock_driver.execute_query.call_count == 2 @@ -248,7 +245,6 @@ class TestNeo4jQueryProcessor: # Create query request query = TriplesQueryRequest( - user='test_user', collection='test_collection', s=Term(type=IRI, iri="http://example.com/subject"), p=None, @@ -258,7 +254,7 @@ class TestNeo4jQueryProcessor: # Should raise the exception with pytest.raises(Exception, match="Database connection failed"): - await processor.query_triples(query) + await processor.query_triples('test_user', query) def test_add_args_method(self): """Test that add_args properly configures argument parser""" diff --git a/tests/unit/test_reliability/test_metadata_preservation.py 
b/tests/unit/test_reliability/test_metadata_preservation.py index aded7253..2170c763 100644 --- a/tests/unit/test_reliability/test_metadata_preservation.py +++ b/tests/unit/test_reliability/test_metadata_preservation.py @@ -30,7 +30,7 @@ class TestDocumentMetadataTranslator: "title": "Test Document", "comments": "No comments", "metadata": [], - "user": "alice", + "workspace": "alice", "tags": ["finance", "q4"], "parent-id": "doc-100", "document-type": "page", @@ -40,14 +40,14 @@ class TestDocumentMetadataTranslator: assert obj.time == 1710000000 assert obj.kind == "application/pdf" assert obj.title == "Test Document" - assert obj.user == "alice" + assert obj.workspace == "alice" assert obj.tags == ["finance", "q4"] assert obj.parent_id == "doc-100" assert obj.document_type == "page" wire = self.tx.encode(obj) assert wire["id"] == "doc-123" - assert wire["user"] == "alice" + assert wire["workspace"] == "alice" assert wire["parent-id"] == "doc-100" assert wire["document-type"] == "page" @@ -80,10 +80,10 @@ class TestDocumentMetadataTranslator: def test_falsy_fields_omitted_from_wire(self): """Empty string fields should be omitted from wire format.""" - obj = DocumentMetadata(id="", time=0, user="") + obj = DocumentMetadata(id="", time=0, workspace="") wire = self.tx.encode(obj) assert "id" not in wire - assert "user" not in wire + assert "workspace" not in wire # --------------------------------------------------------------------------- @@ -101,7 +101,7 @@ class TestProcessingMetadataTranslator: "document-id": "doc-123", "time": 1710000000, "flow": "default", - "user": "alice", + "workspace": "alice", "collection": "my-collection", "tags": ["tag1"], } @@ -109,20 +109,20 @@ class TestProcessingMetadataTranslator: assert obj.id == "proc-1" assert obj.document_id == "doc-123" assert obj.flow == "default" - assert obj.user == "alice" + assert obj.workspace == "alice" assert obj.collection == "my-collection" assert obj.tags == ["tag1"] wire = self.tx.encode(obj) assert 
wire["id"] == "proc-1" assert wire["document-id"] == "doc-123" - assert wire["user"] == "alice" + assert wire["workspace"] == "alice" assert wire["collection"] == "my-collection" def test_missing_fields_use_defaults(self): obj = self.tx.decode({}) assert obj.id is None - assert obj.user is None + assert obj.workspace is None assert obj.collection is None def test_tags_none_omitted(self): @@ -135,10 +135,10 @@ class TestProcessingMetadataTranslator: wire = self.tx.encode(obj) assert wire["tags"] == [] - def test_user_and_collection_preserved(self): + def test_workspace_and_collection_preserved(self): """Core pipeline routing fields must survive round-trip.""" - data = {"user": "bob", "collection": "research"} + data = {"workspace": "bob", "collection": "research"} obj = self.tx.decode(data) wire = self.tx.encode(obj) - assert wire["user"] == "bob" + assert wire["workspace"] == "bob" assert wire["collection"] == "research" diff --git a/tests/unit/test_reliability/test_null_embedding_protection.py b/tests/unit/test_reliability/test_null_embedding_protection.py index 41a5c621..2296e961 100644 --- a/tests/unit/test_reliability/test_null_embedding_protection.py +++ b/tests/unit/test_reliability/test_null_embedding_protection.py @@ -61,7 +61,6 @@ class TestDocEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" emb = MagicMock() @@ -69,7 +68,7 @@ class TestDocEmbeddingsNullProtection: emb.vector = [] # Empty vector msg.chunks = [emb] - await proc.store_document_embeddings(msg) + await proc.store_document_embeddings("user1", msg) # No upsert should be called proc.qdrant.upsert.assert_not_called() @@ -83,7 +82,6 @@ class TestDocEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" emb = MagicMock() @@ -91,7 +89,7 @@ class TestDocEmbeddingsNullProtection: 
emb.vector = None # None vector msg.chunks = [emb] - await proc.store_document_embeddings(msg) + await proc.store_document_embeddings("user1", msg) proc.qdrant.upsert.assert_not_called() @pytest.mark.asyncio @@ -103,7 +101,6 @@ class TestDocEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" emb = MagicMock() @@ -111,7 +108,7 @@ class TestDocEmbeddingsNullProtection: emb.vector = [0.1, 0.2, 0.3] msg.chunks = [emb] - await proc.store_document_embeddings(msg) + await proc.store_document_embeddings("user1", msg) proc.qdrant.upsert.assert_not_called() @pytest.mark.asyncio @@ -124,7 +121,6 @@ class TestDocEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" emb = MagicMock() @@ -132,7 +128,7 @@ class TestDocEmbeddingsNullProtection: emb.vector = [0.1, 0.2, 0.3] msg.chunks = [emb] - await proc.store_document_embeddings(msg) + await proc.store_document_embeddings("user1", msg) proc.qdrant.upsert.assert_called_once() @pytest.mark.asyncio @@ -146,7 +142,6 @@ class TestDocEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "alice" msg.metadata.collection = "docs" emb = MagicMock() @@ -154,7 +149,7 @@ class TestDocEmbeddingsNullProtection: emb.vector = [0.0] * 384 # 384-dim vector msg.chunks = [emb] - await proc.store_document_embeddings(msg) + await proc.store_document_embeddings("alice", msg) call_args = proc.qdrant.upsert.call_args assert "d_alice_docs_384" in call_args[1]["collection_name"] @@ -175,7 +170,6 @@ class TestGraphEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" entity = MagicMock() @@ -183,7 +177,7 @@ class TestGraphEmbeddingsNullProtection: entity.vector = [0.1, 0.2, 
0.3] msg.entities = [entity] - await proc.store_graph_embeddings(msg) + await proc.store_graph_embeddings("user1", msg) proc.qdrant.upsert.assert_not_called() @pytest.mark.asyncio @@ -195,7 +189,6 @@ class TestGraphEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" entity = MagicMock() @@ -203,7 +196,7 @@ class TestGraphEmbeddingsNullProtection: entity.vector = [0.1, 0.2, 0.3] msg.entities = [entity] - await proc.store_graph_embeddings(msg) + await proc.store_graph_embeddings("user1", msg) proc.qdrant.upsert.assert_not_called() @pytest.mark.asyncio @@ -215,7 +208,6 @@ class TestGraphEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" entity = MagicMock() @@ -223,7 +215,7 @@ class TestGraphEmbeddingsNullProtection: entity.vector = [] # Empty vector msg.entities = [entity] - await proc.store_graph_embeddings(msg) + await proc.store_graph_embeddings("user1", msg) proc.qdrant.upsert.assert_not_called() @pytest.mark.asyncio @@ -236,7 +228,6 @@ class TestGraphEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" entity = MagicMock() @@ -245,7 +236,7 @@ class TestGraphEmbeddingsNullProtection: entity.chunk_id = "c1" msg.entities = [entity] - await proc.store_graph_embeddings(msg) + await proc.store_graph_embeddings("user1", msg) proc.qdrant.upsert.assert_called_once() @pytest.mark.asyncio @@ -258,7 +249,6 @@ class TestGraphEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "alice" msg.metadata.collection = "graphs" entity = MagicMock() @@ -267,7 +257,7 @@ class TestGraphEmbeddingsNullProtection: entity.chunk_id = "" msg.entities = [entity] - await proc.store_graph_embeddings(msg) + 
await proc.store_graph_embeddings("alice", msg) # Collection should be created with correct dimension proc.qdrant.create_collection.assert_called_once() @@ -290,11 +280,10 @@ class TestCollectionValidation: proc.collection_exists = MagicMock(return_value=False) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "deleted-col" msg.chunks = [MagicMock()] - await proc.store_document_embeddings(msg) + await proc.store_document_embeddings("user1", msg) proc.qdrant.upsert.assert_not_called() @pytest.mark.asyncio @@ -306,9 +295,8 @@ class TestCollectionValidation: proc.collection_exists = MagicMock(return_value=False) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "deleted-col" msg.entities = [MagicMock()] - await proc.store_graph_embeddings(msg) + await proc.store_graph_embeddings("user1", msg) proc.qdrant.upsert.assert_not_called() diff --git a/tests/unit/test_retrieval/test_document_rag.py b/tests/unit/test_retrieval/test_document_rag.py index 1ff85f5a..fd140b95 100644 --- a/tests/unit/test_retrieval/test_document_rag.py +++ b/tests/unit/test_retrieval/test_document_rag.py @@ -92,14 +92,13 @@ class TestQuery: # Initialize Query with defaults query = Query( rag=mock_rag, - user="test_user", + workspace="test_workspace", collection="test_collection", verbose=False ) # Verify initialization assert query.rag == mock_rag - assert query.user == "test_user" assert query.collection == "test_collection" assert query.verbose is False assert query.doc_limit == 20 # Default value @@ -112,7 +111,7 @@ class TestQuery: # Initialize Query with custom doc_limit query = Query( rag=mock_rag, - user="custom_user", + workspace="test_workspace", collection="custom_collection", verbose=True, doc_limit=50 @@ -120,7 +119,6 @@ class TestQuery: # Verify initialization assert query.rag == mock_rag - assert query.user == "custom_user" assert query.collection == "custom_collection" assert query.verbose is True assert query.doc_limit == 50 @@ -137,7 
+135,7 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", + workspace="test_workspace", collection="test_collection", verbose=False ) @@ -162,7 +160,7 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", + workspace="test_workspace", collection="test_collection", verbose=False ) @@ -184,7 +182,7 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", + workspace="test_workspace", collection="test_collection", verbose=False ) @@ -223,7 +221,7 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", + workspace="test_workspace", collection="test_collection", verbose=False, doc_limit=15 @@ -240,7 +238,6 @@ class TestQuery: mock_doc_embeddings_client.query.assert_called_once_with( vector=[0.1, 0.2, 0.3], limit=15, - user="test_user", collection="test_collection" ) @@ -286,7 +283,6 @@ class TestQuery: result = await document_rag.query( query="test query", - user="test_user", collection="test_collection", doc_limit=10 ) @@ -304,7 +300,6 @@ class TestQuery: mock_doc_embeddings_client.query.assert_called_once_with( vector=[0.1, 0.2, 0.3], limit=10, - user="test_user", collection="test_collection" ) @@ -350,7 +345,6 @@ class TestQuery: mock_doc_embeddings_client.query.assert_called_once_with( vector=[[0.1, 0.2]], limit=20, # Default doc_limit - user="trustgraph", # Default user collection="default" # Default collection ) @@ -380,7 +374,7 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", + workspace="test_workspace", collection="test_collection", verbose=True, doc_limit=5 @@ -453,7 +447,7 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", + workspace="test_workspace", collection="test_collection", verbose=False ) @@ -509,7 +503,7 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", + workspace="test_workspace", collection="test_collection", verbose=True ) @@ -558,7 +552,6 @@ class TestQuery: result = await document_rag.query( query=query_text, - 
user="research_user", collection="ml_knowledge", doc_limit=25 ) @@ -619,7 +612,7 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", + workspace="test_workspace", collection="test_collection", verbose=False, doc_limit=10 diff --git a/tests/unit/test_retrieval/test_document_rag_service.py b/tests/unit/test_retrieval/test_document_rag_service.py index a5d42f3a..dde3acc1 100644 --- a/tests/unit/test_retrieval/test_document_rag_service.py +++ b/tests/unit/test_retrieval/test_document_rag_service.py @@ -1,6 +1,6 @@ """ Unit test for DocumentRAG service parameter passing fix. -Tests that user and collection parameters from the message are correctly +Tests that the collection parameter from the message is correctly passed to the DocumentRag.query() method. """ @@ -16,13 +16,13 @@ class TestDocumentRagService: @patch('trustgraph.retrieval.document_rag.rag.DocumentRag') @pytest.mark.asyncio - async def test_user_and_collection_parameters_passed_to_query(self, mock_document_rag_class): + async def test_collection_parameter_passed_to_query(self, mock_document_rag_class): """ - Test that user and collection from message are passed to DocumentRag.query(). - - This is a regression test for the bug where user/collection parameters - were ignored, causing wrong collection names like 'd_trustgraph_default_384' - instead of 'd_my_user_test_coll_1_384'. + Test that collection from message is passed to DocumentRag.query(). + + This is a regression test for the bug where the collection parameter + was ignored, causing wrong collection names like 'd_trustgraph_default_384' + instead of one that reflects the requested collection. 
""" # Setup processor processor = Processor( @@ -30,17 +30,16 @@ class TestDocumentRagService: id="test-processor", doc_limit=10 ) - + # Setup mock DocumentRag instance mock_rag_instance = AsyncMock() mock_document_rag_class.return_value = mock_rag_instance mock_rag_instance.query.return_value = ("test response", {"in_token": None, "out_token": None, "model": None}) - - # Setup message with custom user/collection + + # Setup message with custom collection msg = MagicMock() msg.value.return_value = DocumentRagQuery( query="test query", - user="my_user", # Custom user (not default "trustgraph") collection="test_coll_1", # Custom collection (not default "default") doc_limit=5 ) @@ -64,7 +63,7 @@ class TestDocumentRagService: # Verify: DocumentRag.query was called with correct parameters mock_rag_instance.query.assert_called_once_with( "test query", - user="my_user", # Must be from message, not hardcoded default + workspace=ANY, # Workspace comes from flow.workspace (mock) collection="test_coll_1", # Must be from message, not hardcoded default doc_limit=5, explain_callback=ANY, # Explainability callback is always passed @@ -103,7 +102,6 @@ class TestDocumentRagService: msg = MagicMock() msg.value.return_value = DocumentRagQuery( query="What is a cat?", - user="trustgraph", collection="default", doc_limit=10, streaming=False # Non-streaming mode diff --git a/tests/unit/test_retrieval/test_graph_rag.py b/tests/unit/test_retrieval/test_graph_rag.py index 00a9551f..e0f41357 100644 --- a/tests/unit/test_retrieval/test_graph_rag.py +++ b/tests/unit/test_retrieval/test_graph_rag.py @@ -78,14 +78,12 @@ class TestQuery: # Initialize Query with defaults query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False ) # Verify initialization assert query.rag == mock_rag - assert query.user == "test_user" assert query.collection == "test_collection" assert query.verbose is False assert query.entity_limit == 50 # Default value @@ -101,7 +99,6 @@ class 
TestQuery: # Initialize Query with custom parameters query = Query( rag=mock_rag, - user="custom_user", collection="custom_collection", verbose=True, entity_limit=100, @@ -112,7 +109,6 @@ class TestQuery: # Verify initialization assert query.rag == mock_rag - assert query.user == "custom_user" assert query.collection == "custom_collection" assert query.verbose is True assert query.entity_limit == 100 @@ -133,7 +129,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False ) @@ -156,7 +151,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=True ) @@ -177,7 +171,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False ) @@ -201,7 +194,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False ) @@ -244,7 +236,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False, entity_limit=25 @@ -269,7 +260,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False ) @@ -277,7 +267,7 @@ class TestQuery: result = await query.maybe_label("entity1") assert result == "Entity One Label" - mock_cache.get.assert_called_once_with("test_user:test_collection:entity1") + mock_cache.get.assert_called_once_with("test_collection:entity1") @pytest.mark.asyncio async def test_maybe_label_with_label_lookup(self): @@ -295,7 +285,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False ) @@ -307,13 +296,12 @@ class TestQuery: p="http://www.w3.org/2000/01/rdf-schema#label", o=None, limit=1, - user="test_user", collection="test_collection", g="" ) assert result == "Human Readable Label" - cache_key = "test_user:test_collection:http://example.com/entity" + cache_key = 
"test_collection:http://example.com/entity" mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label") @pytest.mark.asyncio @@ -330,7 +318,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False ) @@ -342,13 +329,12 @@ class TestQuery: p="http://www.w3.org/2000/01/rdf-schema#label", o=None, limit=1, - user="test_user", collection="test_collection", g="" ) assert result == "unlabeled_entity" - cache_key = "test_user:test_collection:unlabeled_entity" + cache_key = "test_collection:unlabeled_entity" mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity") @pytest.mark.asyncio @@ -375,7 +361,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False, triple_limit=10 @@ -388,15 +373,15 @@ class TestQuery: mock_triples_client.query_stream.assert_any_call( s="entity1", p=None, o=None, limit=10, - user="test_user", collection="test_collection", batch_size=20, g="" + collection="test_collection", batch_size=20, g="" ) mock_triples_client.query_stream.assert_any_call( s=None, p="entity1", o=None, limit=10, - user="test_user", collection="test_collection", batch_size=20, g="" + collection="test_collection", batch_size=20, g="" ) mock_triples_client.query_stream.assert_any_call( s=None, p=None, o="entity1", limit=10, - user="test_user", collection="test_collection", batch_size=20, g="" + collection="test_collection", batch_size=20, g="" ) expected_subgraph = { @@ -415,7 +400,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False ) @@ -435,7 +419,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False, max_subgraph_size=2 @@ -455,7 +438,6 @@ class TestQuery: query = Query( rag=mock_rag, - user="test_user", collection="test_collection", verbose=False, max_path_length=1 @@ -493,7 +475,6 @@ class TestQuery: query = Query( 
rag=mock_rag, - user="test_user", collection="test_collection", verbose=False, max_subgraph_size=100 @@ -601,7 +582,6 @@ class TestQuery: try: response = await graph_rag.query( query="test query", - user="test_user", collection="test_collection", entity_limit=25, triple_limit=15, diff --git a/tests/unit/test_retrieval/test_graph_rag_explain_forwarding.py b/tests/unit/test_retrieval/test_graph_rag_explain_forwarding.py index 603bd204..5208bf7f 100644 --- a/tests/unit/test_retrieval/test_graph_rag_explain_forwarding.py +++ b/tests/unit/test_retrieval/test_graph_rag_explain_forwarding.py @@ -120,7 +120,6 @@ class TestGraphRagServiceExplainTriples: msg = MagicMock() msg.value.return_value = GraphRagQuery( query="What is quantum computing?", - user="trustgraph", collection="default", streaming=False, ) diff --git a/tests/unit/test_retrieval/test_graph_rag_service.py b/tests/unit/test_retrieval/test_graph_rag_service.py index 606aa7fe..a637a350 100644 --- a/tests/unit/test_retrieval/test_graph_rag_service.py +++ b/tests/unit/test_retrieval/test_graph_rag_service.py @@ -52,7 +52,6 @@ class TestGraphRagService: msg = MagicMock() msg.value.return_value = GraphRagQuery( query="What is a cat?", - user="trustgraph", collection="default", entity_limit=50, triple_limit=30, @@ -123,7 +122,6 @@ class TestGraphRagService: msg = MagicMock() msg.value.return_value = GraphRagQuery( query="What is a cat?", - user="trustgraph", collection="default", entity_limit=50, triple_limit=30, @@ -190,7 +188,6 @@ class TestGraphRagService: msg = MagicMock() msg.value.return_value = GraphRagQuery( query="Test query", - user="trustgraph", collection="default", streaming=False ) diff --git a/tests/unit/test_retrieval/test_nlp_query.py b/tests/unit/test_retrieval/test_nlp_query.py index 1fd35c2e..cc285aea 100644 --- a/tests/unit/test_retrieval/test_nlp_query.py +++ b/tests/unit/test_retrieval/test_nlp_query.py @@ -286,11 +286,11 @@ class TestNLPQueryProcessor: } # Act - await 
processor.on_schema_config(config, "v1") + await processor.on_schema_config("default", config, "v1") # Assert - assert "test_schema" in processor.schemas - schema = processor.schemas["test_schema"] + assert "test_schema" in processor.schemas["default"] + schema = processor.schemas["default"]["test_schema"] assert schema.name == "test_schema" assert schema.description == "Test schema" assert len(schema.fields) == 2 @@ -308,10 +308,10 @@ class TestNLPQueryProcessor: } # Act - await processor.on_schema_config(config, "v1") + await processor.on_schema_config("default", config, "v1") # Assert - bad schema should be ignored - assert "bad_schema" not in processor.schemas + assert "bad_schema" not in processor.schemas.get("default", {}) def test_processor_initialization(self, mock_pulsar_client): """Test processor initialization with correct specifications""" diff --git a/tests/unit/test_retrieval/test_structured_diag/test_schema_selection.py b/tests/unit/test_retrieval/test_structured_diag/test_schema_selection.py index 8ce1b97e..45ba9fda 100644 --- a/tests/unit/test_retrieval/test_structured_diag/test_schema_selection.py +++ b/tests/unit/test_retrieval/test_structured_diag/test_schema_selection.py @@ -101,7 +101,7 @@ def service(mock_schemas): taskgroup=MagicMock(), id="test-processor" ) - service.schemas = mock_schemas + service.schemas = {"default": dict(mock_schemas)} return service @@ -109,6 +109,7 @@ def service(mock_schemas): def mock_flow(): """Create mock flow with prompt service""" flow = MagicMock() + flow.workspace = "default" prompt_request_flow = AsyncMock() flow.return_value.request = prompt_request_flow return flow, prompt_request_flow diff --git a/tests/unit/test_retrieval/test_structured_query.py b/tests/unit/test_retrieval/test_structured_query.py index 9a183f45..20056c2a 100644 --- a/tests/unit/test_retrieval/test_structured_query.py +++ b/tests/unit/test_retrieval/test_structured_query.py @@ -44,7 +44,6 @@ class TestStructuredQueryProcessor: # Arrange 
request = StructuredQueryRequest( question="Show me all customers from New York", - user="trustgraph", collection="default" ) @@ -110,7 +109,6 @@ class TestStructuredQueryProcessor: assert isinstance(objects_call_args, RowsQueryRequest) assert objects_call_args.query == 'query { customers(where: {state: {eq: "NY"}}) { id name email } }' assert objects_call_args.variables == {"state": "NY"} - assert objects_call_args.user == "trustgraph" assert objects_call_args.collection == "default" # Verify response diff --git a/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py b/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py index f9d60541..830da334 100644 --- a/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py @@ -17,7 +17,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Create a mock message for testing""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Create test document embeddings @@ -80,7 +79,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings for a single chunk""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -89,7 +87,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: ) message.chunks = [chunk] - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify insert was called once for the single chunk with its vector processor.vecstore.insert.assert_called_once_with( @@ -99,14 +97,14 @@ class TestMilvusDocEmbeddingsStorageProcessor: @pytest.mark.asyncio async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message): """Test storing document embeddings for multiple chunks""" - await 
processor.store_document_embeddings(mock_message) + await processor.store_document_embeddings('test_workspace', mock_message) - # Verify insert was called once per chunk with user/collection parameters + # Verify insert was called once per chunk with workspace/collection parameters expected_calls = [ # Chunk 1 - single vector - ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'), + ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "This is the first document chunk", 'test_workspace', 'test_collection'), # Chunk 2 - single vector - ([0.7, 0.8, 0.9], "This is the second document chunk", 'test_user', 'test_collection'), + ([0.7, 0.8, 0.9], "This is the second document chunk", 'test_workspace', 'test_collection'), ] assert processor.vecstore.insert.call_count == 2 @@ -122,7 +120,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings with empty chunk (should be skipped)""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -131,7 +128,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: ) message.chunks = [chunk] - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify no insert was called for empty chunk processor.vecstore.insert.assert_not_called() @@ -141,7 +138,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings with None chunk_id""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -150,7 +146,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: ) message.chunks = [chunk] - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Note: Implementation passes through None chunk_ids (only skips empty string "") 
processor.vecstore.insert.assert_called_once_with( @@ -162,7 +158,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings with mix of valid and empty chunks""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' valid_chunk = ChunkEmbeddings( @@ -179,7 +174,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: ) message.chunks = [valid_chunk, empty_chunk, another_valid] - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify valid chunks were inserted, empty string chunk was skipped expected_calls = [ @@ -200,11 +195,10 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings with empty chunks list""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' message.chunks = [] - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify no insert was called processor.vecstore.insert.assert_not_called() @@ -214,7 +208,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings for chunk with no vectors""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -223,7 +216,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: ) message.chunks = [chunk] - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify no insert was called (no vectors to insert) processor.vecstore.insert.assert_not_called() @@ -233,7 +226,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings with different vector dimensions""" message = MagicMock() message.metadata = MagicMock() - 
message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Each chunk has a single vector of different dimensions @@ -251,7 +243,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: ) message.chunks = [chunk1, chunk2, chunk3] - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify all vectors were inserted regardless of dimension with user/collection parameters expected_calls = [ @@ -273,7 +265,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings with Unicode content in chunk_id""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -282,7 +273,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: ) message.chunks = [chunk] - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify Unicode chunk_id was stored correctly with user/collection parameters processor.vecstore.insert.assert_called_once_with( @@ -294,7 +285,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings with long chunk_id""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Create a long chunk_id @@ -305,7 +295,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: ) message.chunks = [chunk] - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify long chunk_id was inserted with user/collection parameters processor.vecstore.insert.assert_called_once_with( @@ -317,7 +307,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings with whitespace-only chunk""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection 
= 'test_collection' chunk = ChunkEmbeddings( @@ -326,7 +315,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: ) message.chunks = [chunk] - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify whitespace content was inserted (not filtered out) with user/collection parameters processor.vecstore.insert.assert_called_once_with( @@ -343,25 +332,24 @@ class TestMilvusDocEmbeddingsStorageProcessor: ('test@domain.com', 'test-collection.v1'), ] - for user, collection in test_cases: + for workspace, collection in test_cases: processor.vecstore.reset_mock() # Reset mock for each test case - + message = MagicMock() message.metadata = MagicMock() - message.metadata.user = user message.metadata.collection = collection - + chunk = ChunkEmbeddings( chunk_id="Test content", vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] - - await processor.store_document_embeddings(message) - - # Verify insert was called with the correct user/collection + + await processor.store_document_embeddings(workspace, message) + + # Verify insert was called with the correct workspace/collection processor.vecstore.insert.assert_called_once_with( - [0.1, 0.2, 0.3], "Test content", user, collection + [0.1, 0.2, 0.3], "Test content", workspace, collection ) @pytest.mark.asyncio @@ -370,7 +358,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: # Store embeddings for user1/collection1 message1 = MagicMock() message1.metadata = MagicMock() - message1.metadata.user = 'user1' message1.metadata.collection = 'collection1' chunk1 = ChunkEmbeddings( chunk_id="User1 content", @@ -381,7 +368,6 @@ class TestMilvusDocEmbeddingsStorageProcessor: # Store embeddings for user2/collection2 message2 = MagicMock() message2.metadata = MagicMock() - message2.metadata.user = 'user2' message2.metadata.collection = 'collection2' chunk2 = ChunkEmbeddings( chunk_id="User2 content", @@ -389,8 +375,8 @@ class TestMilvusDocEmbeddingsStorageProcessor: ) 
message2.chunks = [chunk2] - await processor.store_document_embeddings(message1) - await processor.store_document_embeddings(message2) + await processor.store_document_embeddings('user1', message1) + await processor.store_document_embeddings('user2', message2) # Verify both calls were made with correct parameters expected_calls = [ @@ -411,18 +397,17 @@ class TestMilvusDocEmbeddingsStorageProcessor: """Test storing document embeddings with special characters in user/collection names""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'user@domain.com' # Email-like user message.metadata.collection = 'test-collection.v1' # Collection with special chars - + chunk = ChunkEmbeddings( chunk_id="Special chars test", vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] - - await processor.store_document_embeddings(message) - - # Verify the exact user/collection strings are passed (sanitization happens in DocVectors) + + await processor.store_document_embeddings('user@domain.com', message) + + # Verify the exact workspace/collection strings are passed (sanitization happens in DocVectors) processor.vecstore.insert.assert_called_once_with( [0.1, 0.2, 0.3], "Special chars test", 'user@domain.com', 'test-collection.v1' ) diff --git a/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py b/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py index fec4f87e..011780ed 100644 --- a/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py @@ -21,7 +21,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Create a mock message for testing""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Create test document embeddings @@ -120,7 +119,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test storing document embeddings for a single chunk""" message = MagicMock() 
message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -135,7 +133,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: processor.pinecone.has_index.return_value = True with patch('uuid.uuid4', side_effect=['id1', 'id2']): - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify index name and operations (with dimension suffix) expected_index_name = "d-test_user-test_collection-3" # 3 dimensions @@ -185,7 +183,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test that writing to non-existent index creates it lazily""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -200,7 +197,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: processor.pinecone.Index.return_value = mock_index with patch('uuid.uuid4', return_value='test-id'): - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify index was created with correct dimension expected_index_name = "d-test_user-test_collection-3" # 3 dimensions @@ -217,7 +214,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test storing document embeddings with empty chunk (should be skipped)""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -229,7 +225,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify no upsert was called for empty chunk mock_index.upsert.assert_not_called() @@ -239,7 +235,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test 
storing document embeddings with None chunk (should be skipped)""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -251,7 +246,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify no upsert was called for None chunk mock_index.upsert.assert_not_called() @@ -261,7 +256,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test storing document embeddings with chunk that decodes to empty string""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -273,7 +267,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify no upsert was called for empty decoded chunk mock_index.upsert.assert_not_called() @@ -283,7 +277,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test storing document embeddings with different vector dimensions""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Each chunk has a single vector of different dimensions @@ -325,14 +318,13 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test storing document embeddings with empty chunks list""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' message.chunks = [] mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - await 
processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify no operations were performed processor.pinecone.Index.assert_not_called() @@ -343,7 +335,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test storing document embeddings for chunk with no vectors""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -355,7 +346,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify no upsert was called (no vectors to insert) mock_index.upsert.assert_not_called() @@ -365,7 +356,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test that lazy creation happens when index doesn't exist""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -380,7 +370,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: processor.pinecone.Index.return_value = mock_index with patch('uuid.uuid4', return_value='test-id'): - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify index was created processor.pinecone.create_index.assert_called_once() @@ -390,7 +380,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test that lazy creation works correctly""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -405,7 +394,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: processor.pinecone.Index.return_value = mock_index with patch('uuid.uuid4', return_value='test-id'): - await 
processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify index was created and used processor.pinecone.create_index.assert_called_once() @@ -416,7 +405,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test storing document embeddings with Unicode content""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' chunk = ChunkEmbeddings( @@ -430,7 +418,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: processor.pinecone.has_index.return_value = True with patch('uuid.uuid4', return_value='test-id'): - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify Unicode content was properly decoded and stored call_args = mock_index.upsert.call_args @@ -442,7 +430,6 @@ class TestPineconeDocEmbeddingsStorageProcessor: """Test storing document embeddings with large document chunks""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Create a large document chunk @@ -458,7 +445,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: processor.pinecone.has_index.return_value = True with patch('uuid.uuid4', return_value='test-id'): - await processor.store_document_embeddings(message) + await processor.store_document_embeddings('test_user', message) # Verify large content was stored call_args = mock_index.upsert.call_args diff --git a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py index 98d2dab2..ce6e6b3d 100644 --- a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py @@ -84,7 +84,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message with chunks and vectors mock_message = 
MagicMock() - mock_message.metadata.user = 'test_user' mock_message.metadata.collection = 'test_collection' mock_chunk = MagicMock() @@ -94,7 +93,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.chunks = [mock_chunk] # Act - await processor.store_document_embeddings(mock_message) + await processor.store_document_embeddings('test_user', mock_message) # Assert # Verify collection existence was checked (with dimension suffix) @@ -138,7 +137,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message with multiple chunks mock_message = MagicMock() - mock_message.metadata.user = 'multi_user' mock_message.metadata.collection = 'multi_collection' mock_chunk1 = MagicMock() @@ -152,7 +150,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.chunks = [mock_chunk1, mock_chunk2] # Act - await processor.store_document_embeddings(mock_message) + await processor.store_document_embeddings('multi_user', mock_message) # Assert # Should be called twice (once per chunk) @@ -198,7 +196,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message with multiple chunks, each having a single vector mock_message = MagicMock() - mock_message.metadata.user = 'vector_user' mock_message.metadata.collection = 'vector_collection' mock_chunk1 = MagicMock() @@ -216,7 +213,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.chunks = [mock_chunk1, mock_chunk2, mock_chunk3] # Act - await processor.store_document_embeddings(mock_message) + await processor.store_document_embeddings('vector_user', mock_message) # Assert # Should be called 3 times (once per chunk) @@ -255,7 +252,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message with empty chunk_id mock_message = MagicMock() - mock_message.metadata.user = 'empty_user' mock_message.metadata.collection = 'empty_collection' mock_chunk_empty = MagicMock() @@ -265,7 +261,7 @@ 
class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.chunks = [mock_chunk_empty] # Act - await processor.store_document_embeddings(mock_message) + await processor.store_document_embeddings('empty_user', mock_message) # Assert # Should not call upsert for empty chunk_ids @@ -298,7 +294,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.metadata.user = 'new_user' mock_message.metadata.collection = 'new_collection' mock_chunk = MagicMock() @@ -308,7 +303,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.chunks = [mock_chunk] # Act - await processor.store_document_embeddings(mock_message) + await processor.store_document_embeddings('new_user', mock_message) # Assert - collection should be lazily created expected_collection = 'd_new_user_new_collection_5' # 5 dimensions @@ -350,7 +345,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.metadata.user = 'error_user' mock_message.metadata.collection = 'error_collection' mock_chunk = MagicMock() @@ -361,7 +355,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Act & Assert - should propagate the creation error with pytest.raises(Exception, match="Connection error"): - await processor.store_document_embeddings(mock_message) + await processor.store_document_embeddings('error_user', mock_message) @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') @@ -388,7 +382,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Create first mock message mock_message1 = MagicMock() - mock_message1.metadata.user = 'cache_user' mock_message1.metadata.collection = 'cache_collection' mock_chunk1 = MagicMock() @@ -398,7 +391,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message1.chunks = 
[mock_chunk1] # First call - await processor.store_document_embeddings(mock_message1) + await processor.store_document_embeddings('cache_user', mock_message1) # Reset mock to track second call mock_qdrant_instance.reset_mock() @@ -406,7 +399,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Create second mock message with same dimensions mock_message2 = MagicMock() - mock_message2.metadata.user = 'cache_user' mock_message2.metadata.collection = 'cache_collection' mock_chunk2 = MagicMock() @@ -416,7 +408,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message2.chunks = [mock_chunk2] # Act - Second call with same collection - await processor.store_document_embeddings(mock_message2) + await processor.store_document_embeddings('cache_user', mock_message2) # Assert expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions @@ -452,7 +444,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message with chunks of different dimensions mock_message = MagicMock() - mock_message.metadata.user = 'dim_user' mock_message.metadata.collection = 'dim_collection' mock_chunk1 = MagicMock() @@ -466,7 +457,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.chunks = [mock_chunk1, mock_chunk2] # Act - await processor.store_document_embeddings(mock_message) + await processor.store_document_embeddings('dim_user', mock_message) # Assert # Should check existence of DIFFERENT collections for each dimension @@ -526,7 +517,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message with URI-style chunk_id mock_message = MagicMock() - mock_message.metadata.user = 'uri_user' mock_message.metadata.collection = 'uri_collection' mock_chunk = MagicMock() @@ -536,7 +526,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.chunks = [mock_chunk] # Act - await processor.store_document_embeddings(mock_message) + await 
processor.store_document_embeddings('uri_user', mock_message) # Assert # Verify the chunk_id was stored correctly diff --git a/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py b/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py index e4d60adf..7f3e7469 100644 --- a/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py +++ b/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py @@ -17,7 +17,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor: """Create a mock message for testing""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Create test entities with embeddings @@ -80,7 +79,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor: """Test storing graph embeddings for a single entity""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -89,7 +87,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: ) message.entities = [entity] - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify insert was called once with the full vector processor.vecstore.insert.assert_called_once() @@ -102,14 +100,14 @@ class TestMilvusGraphEmbeddingsStorageProcessor: @pytest.mark.asyncio async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message): """Test storing graph embeddings for multiple entities""" - await processor.store_graph_embeddings(mock_message) + await processor.store_graph_embeddings('test_workspace', mock_message) - # Verify insert was called once per entity with user/collection parameters + # Verify insert was called once per entity with workspace/collection parameters expected_calls = [ # Entity 1 - single vector - ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'), + ([0.1, 0.2, 
0.3, 0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_workspace', 'test_collection'), # Entity 2 - single vector - ([0.7, 0.8, 0.9], 'literal entity', 'test_user', 'test_collection'), + ([0.7, 0.8, 0.9], 'literal entity', 'test_workspace', 'test_collection'), ] assert processor.vecstore.insert.call_count == 2 @@ -125,7 +123,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor: """Test storing graph embeddings with empty entity value (should be skipped)""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -134,7 +131,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: ) message.entities = [entity] - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify no insert was called for empty entity processor.vecstore.insert.assert_not_called() @@ -144,7 +141,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor: """Test storing graph embeddings with None entity value (should be skipped)""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -153,7 +149,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: ) message.entities = [entity] - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify no insert was called for None entity processor.vecstore.insert.assert_not_called() @@ -163,7 +159,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor: """Test storing graph embeddings with mix of valid and invalid entities""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' valid_entity = EntityEmbeddings( @@ -183,7 +178,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: ) message.entities = [valid_entity, empty_entity, none_entity] 
- await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify only valid entity was inserted with user/collection/chunk_id parameters processor.vecstore.insert.assert_called_once_with( @@ -196,11 +191,10 @@ class TestMilvusGraphEmbeddingsStorageProcessor: """Test storing graph embeddings with empty entities list""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' message.entities = [] - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify no insert was called processor.vecstore.insert.assert_not_called() @@ -210,7 +204,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor: """Test storing graph embeddings for entity with no vectors""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -219,7 +212,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: ) message.entities = [entity] - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify no insert was called (no vectors to insert) processor.vecstore.insert.assert_not_called() @@ -229,7 +222,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor: """Test storing graph embeddings with different vector dimensions""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Each entity has a single vector of different dimensions @@ -247,7 +239,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: ) message.entities = [entity1, entity2, entity3] - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify all vectors were inserted regardless of dimension expected_calls = [ @@ -267,7 +259,6 @@ 
class TestMilvusGraphEmbeddingsStorageProcessor: """Test storing graph embeddings for both URI and literal entities""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' uri_entity = EntityEmbeddings( @@ -280,7 +271,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: ) message.entities = [uri_entity, literal_entity] - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify both entities were inserted expected_calls = [ diff --git a/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py b/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py index 9ff53f4e..e0e5ce26 100644 --- a/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py +++ b/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py @@ -21,7 +21,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor: """Create a mock message for testing""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Create test entity embeddings (each entity has a single vector) @@ -124,7 +123,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor: """Test storing graph embeddings for a single entity""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -139,7 +137,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: processor.pinecone.has_index.return_value = True with patch('uuid.uuid4', side_effect=['id1']): - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify index name and operations (with dimension suffix) expected_index_name = "t-test_user-test_collection-3" # 3 dimensions @@ -189,7 +187,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor: """Test that 
writing to non-existent index creates it lazily""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -204,7 +201,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: processor.pinecone.Index.return_value = mock_index with patch('uuid.uuid4', return_value='test-id'): - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify index was created with correct dimension expected_index_name = "t-test_user-test_collection-3" # 3 dimensions @@ -221,7 +218,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor: """Test storing graph embeddings with empty entity value (should be skipped)""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -233,7 +229,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify no upsert was called for empty entity mock_index.upsert.assert_not_called() @@ -243,7 +239,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor: """Test storing graph embeddings with None entity value (should be skipped)""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -255,7 +250,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify no upsert was called for None entity mock_index.upsert.assert_not_called() @@ -265,7 +260,6 @@ class 
TestPineconeGraphEmbeddingsStorageProcessor: """Test storing graph embeddings with different vector dimensions""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Each entity has a single vector of different dimensions @@ -288,7 +282,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: processor.pinecone.has_index.return_value = True with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']): - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify different indexes were used for different dimensions index_calls = processor.pinecone.Index.call_args_list @@ -307,14 +301,13 @@ class TestPineconeGraphEmbeddingsStorageProcessor: """Test storing graph embeddings with empty entities list""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' message.entities = [] mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify no operations were performed processor.pinecone.Index.assert_not_called() @@ -325,7 +318,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor: """Test storing graph embeddings for entity with no vectors""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -337,7 +329,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify no upsert was called (no vectors to insert) mock_index.upsert.assert_not_called() @@ -347,7 +339,6 @@ class 
TestPineconeGraphEmbeddingsStorageProcessor: """Test that lazy creation happens when index doesn't exist""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -362,7 +353,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: processor.pinecone.Index.return_value = mock_index with patch('uuid.uuid4', return_value='test-id'): - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify index was created processor.pinecone.create_index.assert_called_once() @@ -372,7 +363,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor: """Test that lazy creation works correctly""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' entity = EntityEmbeddings( @@ -387,7 +377,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: processor.pinecone.Index.return_value = mock_index with patch('uuid.uuid4', return_value='test-id'): - await processor.store_graph_embeddings(message) + await processor.store_graph_embeddings('test_user', message) # Verify index was created and used processor.pinecone.create_index.assert_called_once() diff --git a/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py index 3541ccd4..d636e093 100644 --- a/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py @@ -64,7 +64,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message with entities and vectors mock_message = MagicMock() - mock_message.metadata.user = 'test_user' mock_message.metadata.collection = 'test_collection' mock_entity = MagicMock() @@ -75,7 +74,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.entities = 
[mock_entity] # Act - await processor.store_graph_embeddings(mock_message) + await processor.store_graph_embeddings('test_user', mock_message) # Assert # Verify collection existence was checked (with dimension suffix) @@ -118,7 +117,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message with multiple entities mock_message = MagicMock() - mock_message.metadata.user = 'multi_user' mock_message.metadata.collection = 'multi_collection' mock_entity1 = MagicMock() @@ -134,7 +132,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.entities = [mock_entity1, mock_entity2] # Act - await processor.store_graph_embeddings(mock_message) + await processor.store_graph_embeddings('multi_user', mock_message) # Assert # Should be called twice (once per entity) @@ -179,7 +177,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message with three entities mock_message = MagicMock() - mock_message.metadata.user = 'vector_user' mock_message.metadata.collection = 'vector_collection' mock_entity1 = MagicMock() @@ -200,7 +197,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.entities = [mock_entity1, mock_entity2, mock_entity3] # Act - await processor.store_graph_embeddings(mock_message) + await processor.store_graph_embeddings('vector_user', mock_message) # Assert # Should be called 3 times (once per entity) @@ -238,7 +235,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): # Create mock message with empty entity value mock_message = MagicMock() - mock_message.metadata.user = 'empty_user' mock_message.metadata.collection = 'empty_collection' mock_entity_empty = MagicMock() @@ -253,7 +249,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.entities = [mock_entity_empty, mock_entity_none] # Act - await processor.store_graph_embeddings(mock_message) + await processor.store_graph_embeddings('empty_user', 
mock_message) # Assert # Should not call upsert for empty entities diff --git a/tests/unit/test_storage/test_memgraph_user_collection_isolation.py b/tests/unit/test_storage/test_memgraph_workspace_collection_isolation.py similarity index 53% rename from tests/unit/test_storage/test_memgraph_user_collection_isolation.py rename to tests/unit/test_storage/test_memgraph_workspace_collection_isolation.py index 9c330b77..ebc142f3 100644 --- a/tests/unit/test_storage/test_memgraph_user_collection_isolation.py +++ b/tests/unit/test_storage/test_memgraph_workspace_collection_isolation.py @@ -1,5 +1,5 @@ """ -Tests for Memgraph user/collection isolation in storage service +Tests for Memgraph workspace/collection isolation in storage service. """ import pytest @@ -8,47 +8,45 @@ from unittest.mock import MagicMock, patch from trustgraph.storage.triples.memgraph.write import Processor -class TestMemgraphUserCollectionIsolation: - """Test cases for Memgraph storage service with user/collection isolation""" +class TestMemgraphWorkspaceCollectionIsolation: + """Test cases for Memgraph storage service with workspace/collection isolation""" @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') - def test_storage_creates_indexes_with_user_collection(self, mock_graph_db): - """Test that storage creates both legacy and user/collection indexes""" + def test_storage_creates_indexes_with_workspace_collection(self, mock_graph_db): + """Test that storage creates both legacy and workspace/collection indexes""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - + processor = Processor(taskgroup=MagicMock()) - - # Verify all indexes were attempted (4 legacy + 4 user/collection = 8 total) + + # 4 legacy + 4 workspace/collection = 8 total assert mock_session.run.call_count == 8 - - # Check some specific index creation calls + expected_calls = [ "CREATE INDEX ON 
:Node", "CREATE INDEX ON :Node(uri)", "CREATE INDEX ON :Literal", "CREATE INDEX ON :Literal(value)", - "CREATE INDEX ON :Node(user)", + "CREATE INDEX ON :Node(workspace)", "CREATE INDEX ON :Node(collection)", - "CREATE INDEX ON :Literal(user)", + "CREATE INDEX ON :Literal(workspace)", "CREATE INDEX ON :Literal(collection)" ] - + for expected_call in expected_calls: mock_session.run.assert_any_call(expected_call) @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') @pytest.mark.asyncio - async def test_store_triples_with_user_collection(self, mock_graph_db): - """Test that store_triples includes user/collection in all operations""" + async def test_store_triples_with_workspace_collection(self, mock_graph_db): + """Test that store_triples includes workspace/collection in all operations""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - # Mock execute_query response mock_result = MagicMock() mock_summary = MagicMock() mock_summary.counters.nodes_created = 1 @@ -58,45 +56,39 @@ class TestMemgraphUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) - # Create mock triple with URI object + from trustgraph.schema import IRI triple = MagicMock() - triple.s.value = "http://example.com/subject" - triple.p.value = "http://example.com/predicate" - triple.o.value = "http://example.com/object" - triple.o.is_uri = True + triple.s.type = IRI + triple.s.iri = "http://example.com/subject" + triple.p.type = IRI + triple.p.iri = "http://example.com/predicate" + triple.o.type = IRI + triple.o.iri = "http://example.com/object" - # Create mock message with metadata mock_message = MagicMock() mock_message.triples = [triple] - mock_message.metadata.user = "test_user" mock_message.metadata.collection = "test_collection" - # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', 
return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples("test_workspace", mock_message) - # Verify user/collection parameters were passed to all operations - # Should have: create_node (subject), create_node (object), relate_node = 3 calls + # create_node (subject), create_node (object), relate_node = 3 calls assert mock_driver.execute_query.call_count == 3 - # Check that user and collection were included in all calls - for call in mock_driver.execute_query.call_args_list: - call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1] - assert 'user' in call_kwargs - assert 'collection' in call_kwargs - assert call_kwargs['user'] == "test_user" - assert call_kwargs['collection'] == "test_collection" + for c in mock_driver.execute_query.call_args_list: + kwargs = c.kwargs + assert kwargs['workspace'] == "test_workspace" + assert kwargs['collection'] == "test_collection" @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') @pytest.mark.asyncio - async def test_store_triples_with_default_user_collection(self, mock_graph_db): - """Test that defaults are used when user/collection not provided in metadata""" + async def test_store_triples_with_default_collection(self, mock_graph_db): + """Test that default collection is used when not provided in metadata""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - # Mock execute_query response mock_result = MagicMock() mock_summary = MagicMock() mock_summary.counters.nodes_created = 1 @@ -106,157 +98,151 @@ class TestMemgraphUserCollectionIsolation: processor = Processor(taskgroup=MagicMock()) - # Create mock triple + from trustgraph.schema import IRI, LITERAL triple = MagicMock() - triple.s.value = "http://example.com/subject" - triple.p.value = "http://example.com/predicate" + triple.s.type = IRI + triple.s.iri = "http://example.com/subject" + 
triple.p.type = IRI + triple.p.iri = "http://example.com/predicate" + triple.o.type = LITERAL triple.o.value = "literal_value" - triple.o.is_uri = False - # Create mock message without user/collection metadata mock_message = MagicMock() mock_message.triples = [triple] - mock_message.metadata.user = None mock_message.metadata.collection = None - # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples("default", mock_message) - # Verify defaults were used - for call in mock_driver.execute_query.call_args_list: - call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1] - assert call_kwargs['user'] == "default" - assert call_kwargs['collection'] == "default" + for c in mock_driver.execute_query.call_args_list: + kwargs = c.kwargs + assert kwargs['workspace'] == "default" + assert kwargs['collection'] == "default" @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') - def test_create_node_includes_user_collection(self, mock_graph_db): - """Test that create_node includes user/collection properties""" + def test_create_node_includes_workspace_collection(self, mock_graph_db): + """Test that create_node includes workspace/collection properties""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - - # Mock execute_query response + mock_result = MagicMock() mock_summary = MagicMock() mock_summary.counters.nodes_created = 1 mock_summary.result_available_after = 10 mock_result.summary = mock_summary mock_driver.execute_query.return_value = mock_result - + processor = Processor(taskgroup=MagicMock()) - - processor.create_node("http://example.com/node", "test_user", "test_collection") - + + processor.create_node("http://example.com/node", "test_workspace", "test_collection") + 
mock_driver.execute_query.assert_called_with( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", uri="http://example.com/node", - user="test_user", + workspace="test_workspace", collection="test_collection", database_="memgraph" ) @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') - def test_create_literal_includes_user_collection(self, mock_graph_db): - """Test that create_literal includes user/collection properties""" + def test_create_literal_includes_workspace_collection(self, mock_graph_db): + """Test that create_literal includes workspace/collection properties""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - - # Mock execute_query response + mock_result = MagicMock() mock_summary = MagicMock() mock_summary.counters.nodes_created = 1 mock_summary.result_available_after = 10 mock_result.summary = mock_summary mock_driver.execute_query.return_value = mock_result - + processor = Processor(taskgroup=MagicMock()) - - processor.create_literal("test_value", "test_user", "test_collection") - + + processor.create_literal("test_value", "test_workspace", "test_collection") + mock_driver.execute_query.assert_called_with( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", value="test_value", - user="test_user", + workspace="test_workspace", collection="test_collection", database_="memgraph" ) @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') - def test_relate_node_includes_user_collection(self, mock_graph_db): - """Test that relate_node includes user/collection properties""" + def test_relate_node_includes_workspace_collection(self, mock_graph_db): + """Test that relate_node includes workspace/collection 
properties""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - - # Mock execute_query response + mock_result = MagicMock() mock_summary = MagicMock() mock_summary.counters.nodes_created = 0 mock_summary.result_available_after = 10 mock_result.summary = mock_summary mock_driver.execute_query.return_value = mock_result - + processor = Processor(taskgroup=MagicMock()) - + processor.relate_node( "http://example.com/subject", - "http://example.com/predicate", + "http://example.com/predicate", "http://example.com/object", - "test_user", + "test_workspace", "test_collection" ) - + mock_driver.execute_query.assert_called_with( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", src="http://example.com/subject", dest="http://example.com/object", uri="http://example.com/predicate", - user="test_user", + workspace="test_workspace", collection="test_collection", database_="memgraph" ) @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') - def test_relate_literal_includes_user_collection(self, mock_graph_db): - """Test that relate_literal includes user/collection properties""" + def test_relate_literal_includes_workspace_collection(self, mock_graph_db): + """Test that relate_literal includes workspace/collection properties""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - - # Mock 
execute_query response + mock_result = MagicMock() mock_summary = MagicMock() mock_summary.counters.nodes_created = 0 mock_summary.result_available_after = 10 mock_result.summary = mock_summary mock_driver.execute_query.return_value = mock_result - + processor = Processor(taskgroup=MagicMock()) - + processor.relate_literal( "http://example.com/subject", "http://example.com/predicate", "literal_value", - "test_user", + "test_workspace", "test_collection" ) - + mock_driver.execute_query.assert_called_with( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", src="http://example.com/subject", dest="literal_value", uri="http://example.com/predicate", - user="test_user", + workspace="test_workspace", collection="test_collection", database_="memgraph" ) @@ -264,20 +250,15 @@ class TestMemgraphUserCollectionIsolation: def test_add_args_includes_memgraph_parameters(self): """Test that add_args properly configures Memgraph-specific parameters""" from argparse import ArgumentParser - from unittest.mock import patch - + parser = ArgumentParser() - - # Mock the parent class add_args method + with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args') as mock_parent_add_args: Processor.add_args(parser) - - # Verify parent add_args was called mock_parent_add_args.assert_called_once() - - # Verify our specific arguments were added with Memgraph defaults + args = parser.parse_args([]) - + assert hasattr(args, 'graph_host') assert args.graph_host == 'bolt://memgraph:7687' assert hasattr(args, 'username') @@ -288,19 +269,18 
@@ class TestMemgraphUserCollectionIsolation: assert args.database == 'memgraph' -class TestMemgraphUserCollectionRegression: - """Regression tests to ensure user/collection isolation prevents data leakage""" +class TestMemgraphWorkspaceCollectionRegression: + """Regression tests to ensure workspace/collection isolation prevents data leakage""" @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') @pytest.mark.asyncio - async def test_regression_no_cross_user_data_access(self, mock_graph_db): - """Regression test: Ensure users cannot access each other's data""" + async def test_regression_no_cross_workspace_data_access(self, mock_graph_db): + """Regression test: Ensure workspaces cannot access each other's data""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - # Mock execute_query response mock_result = MagicMock() mock_summary = MagicMock() mock_summary.counters.nodes_created = 1 @@ -310,60 +290,55 @@ class TestMemgraphUserCollectionRegression: processor = Processor(taskgroup=MagicMock()) - # Store data for user1 + from trustgraph.schema import IRI, LITERAL triple = MagicMock() - triple.s.value = "http://example.com/subject" - triple.p.value = "http://example.com/predicate" - triple.o.value = "user1_data" - triple.o.is_uri = False + triple.s.type = IRI + triple.s.iri = "http://example.com/subject" + triple.p.type = IRI + triple.p.iri = "http://example.com/predicate" + triple.o.type = LITERAL + triple.o.value = "ws1_data" - message_user1 = MagicMock() - message_user1.triples = [triple] - message_user1.metadata.user = "user1" - message_user1.metadata.collection = "collection1" + message_ws1 = MagicMock() + message_ws1.triples = [triple] + message_ws1.metadata.collection = "collection1" - # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - await 
processor.store_triples(message_user1) + await processor.store_triples("workspace1", message_ws1) - # Verify that all storage operations included user1/collection1 parameters - for call in mock_driver.execute_query.call_args_list: - call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1] - if 'user' in call_kwargs: - assert call_kwargs['user'] == "user1" - assert call_kwargs['collection'] == "collection1" + for c in mock_driver.execute_query.call_args_list: + kwargs = c.kwargs + if 'workspace' in kwargs: + assert kwargs['workspace'] == "workspace1" + assert kwargs['collection'] == "collection1" @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') @pytest.mark.asyncio - async def test_regression_same_uri_different_users(self, mock_graph_db): - """Regression test: Same URI can exist for different users without conflict""" + async def test_regression_same_uri_different_workspaces(self, mock_graph_db): + """Regression test: Same URI can exist in different workspaces without conflict""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - - # Mock execute_query response + mock_result = MagicMock() mock_summary = MagicMock() mock_summary.counters.nodes_created = 1 mock_summary.result_available_after = 10 mock_result.summary = mock_summary mock_driver.execute_query.return_value = mock_result - + processor = Processor(taskgroup=MagicMock()) - - # Same URI for different users should create separate nodes - processor.create_node("http://example.com/same-uri", "user1", "collection1") - processor.create_node("http://example.com/same-uri", "user2", "collection2") - - # Verify both calls were made with different user/collection parameters - calls = mock_driver.execute_query.call_args_list[-2:] # Get last 2 calls - - call1_kwargs = calls[0].kwargs if hasattr(calls[0], 'kwargs') else calls[0][1] - call2_kwargs = calls[1].kwargs if 
hasattr(calls[1], 'kwargs') else calls[1][1] - - assert call1_kwargs['user'] == "user1" and call1_kwargs['collection'] == "collection1" - assert call2_kwargs['user'] == "user2" and call2_kwargs['collection'] == "collection2" - - # Both should have the same URI but different user/collection - assert call1_kwargs['uri'] == call2_kwargs['uri'] == "http://example.com/same-uri" \ No newline at end of file + + processor.create_node("http://example.com/same-uri", "workspace1", "collection1") + processor.create_node("http://example.com/same-uri", "workspace2", "collection2") + + calls = mock_driver.execute_query.call_args_list[-2:] + + k1 = calls[0].kwargs + k2 = calls[1].kwargs + + assert k1['workspace'] == "workspace1" and k1['collection'] == "collection1" + assert k2['workspace'] == "workspace2" and k2['collection'] == "collection2" + + assert k1['uri'] == k2['uri'] == "http://example.com/same-uri" diff --git a/tests/unit/test_storage/test_neo4j_user_collection_isolation.py b/tests/unit/test_storage/test_neo4j_workspace_collection_isolation.py similarity index 51% rename from tests/unit/test_storage/test_neo4j_user_collection_isolation.py rename to tests/unit/test_storage/test_neo4j_workspace_collection_isolation.py index dce170a7..967c144d 100644 --- a/tests/unit/test_storage/test_neo4j_user_collection_isolation.py +++ b/tests/unit/test_storage/test_neo4j_workspace_collection_isolation.py @@ -1,5 +1,5 @@ """ -Tests for Neo4j user/collection isolation in triples storage and query +Tests for Neo4j workspace/collection isolation in triples storage and query. 
""" import pytest @@ -11,468 +11,406 @@ from trustgraph.schema import Triples, Triple, Term, Metadata, IRI, LITERAL from trustgraph.schema import TriplesQueryRequest -class TestNeo4jUserCollectionIsolation: - """Test cases for Neo4j user/collection isolation functionality""" +class TestNeo4jWorkspaceCollectionIsolation: + """Test cases for Neo4j workspace/collection isolation functionality""" @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') - def test_storage_creates_indexes_with_user_collection(self, mock_graph_db): - """Test that storage service creates compound indexes for user/collection""" + def test_storage_creates_indexes_with_workspace_collection(self, mock_graph_db): + """Test that storage service creates compound indexes for workspace/collection""" taskgroup_mock = MagicMock() mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - + processor = StorageProcessor(taskgroup=taskgroup_mock) - - # Verify both legacy and new compound indexes are created + expected_indexes = [ "CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)", "CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)", "CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)", - "CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)", - "CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)", - "CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)", + "CREATE INDEX node_workspace_collection_uri FOR (n:Node) ON (n.workspace, n.collection, n.uri)", + "CREATE INDEX literal_workspace_collection_value FOR (n:Literal) ON (n.workspace, n.collection, n.value)", + "CREATE INDEX rel_workspace FOR ()-[r:Rel]-() ON (r.workspace)", "CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)" ] - - # Check that all expected indexes were created + for expected_query in expected_indexes: 
mock_session.run.assert_any_call(expected_query) @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') @pytest.mark.asyncio - async def test_store_triples_with_user_collection(self, mock_graph_db): - """Test that triples are stored with user/collection properties""" + async def test_store_triples_with_workspace_collection(self, mock_graph_db): + """Test that triples are stored with workspace/collection properties""" taskgroup_mock = MagicMock() mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - + processor = StorageProcessor(taskgroup=taskgroup_mock) - - # Create test message with user/collection metadata - metadata = Metadata( - id="test-id", - user="test_user", - collection="test_collection" - ) - + + metadata = Metadata(id="test-id", collection="test_collection") + triple = Triple( s=Term(type=IRI, iri="http://example.com/subject"), p=Term(type=IRI, iri="http://example.com/predicate"), o=Term(type=LITERAL, value="literal_value") ) - - message = Triples( - metadata=metadata, - triples=[triple] - ) - - # Mock execute_query to return summaries + + message = Triples(metadata=metadata, triples=[triple]) + mock_summary = MagicMock() mock_summary.counters.nodes_created = 1 mock_summary.result_available_after = 10 mock_driver.execute_query.return_value.summary = mock_summary - # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(message) - - # Verify nodes and relationships were created with user/collection properties + await processor.store_triples("test_workspace", message) + expected_calls = [ call( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", uri="http://example.com/subject", - user="test_user", + workspace="test_workspace", 
collection="test_collection", database_='neo4j' ), call( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", value="literal_value", - user="test_user", + workspace="test_workspace", collection="test_collection", database_='neo4j' ), call( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", src="http://example.com/subject", dest="literal_value", uri="http://example.com/predicate", - user="test_user", + workspace="test_workspace", collection="test_collection", database_='neo4j' ) ] - + for expected_call in expected_calls: mock_driver.execute_query.assert_any_call(*expected_call.args, **expected_call.kwargs) @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') @pytest.mark.asyncio - async def test_store_triples_with_default_user_collection(self, mock_graph_db): - """Test that default user/collection are used when not provided""" + async def test_store_triples_with_default_collection(self, mock_graph_db): + """Test that default collection is used when not provided""" taskgroup_mock = MagicMock() mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - + processor = StorageProcessor(taskgroup=taskgroup_mock) - - # Create test message without user/collection + metadata = Metadata(id="test-id") - + triple = Triple( s=Term(type=IRI, iri="http://example.com/subject"), p=Term(type=IRI, 
iri="http://example.com/predicate"), o=Term(type=IRI, iri="http://example.com/object") ) - - message = Triples( - metadata=metadata, - triples=[triple] - ) - - # Mock execute_query + + message = Triples(metadata=metadata, triples=[triple]) + mock_summary = MagicMock() mock_summary.counters.nodes_created = 1 mock_summary.result_available_after = 10 mock_driver.execute_query.return_value.summary = mock_summary - # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(message) - - # Verify defaults were used + await processor.store_triples("default", message) + mock_driver.execute_query.assert_any_call( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", uri="http://example.com/subject", - user="default", + workspace="default", collection="default", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_query_triples_filters_by_user_collection(self, mock_graph_db): - """Test that query service filters results by user/collection""" + async def test_query_triples_filters_by_workspace_collection(self, mock_graph_db): + """Test that query service filters results by workspace/collection""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver - + processor = QueryProcessor(taskgroup=MagicMock()) - - # Create test query + query = TriplesQueryRequest( - user="test_user", collection="test_collection", s=Term(type=IRI, iri="http://example.com/subject"), p=Term(type=IRI, iri="http://example.com/predicate"), o=None ) - - # Mock query results + mock_records = [ MagicMock(data=lambda: {"dest": "http://example.com/object1"}), MagicMock(data=lambda: {"dest": "literal_value"}) ] - + mock_driver.execute_query.return_value = (mock_records, MagicMock(), MagicMock()) - - result = await 
processor.query_triples(query) - - # Verify queries include user/collection filters + + await processor.query_triples("test_workspace", query) + expected_literal_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN dest.value as dest" ) - - expected_node_query = ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " - "RETURN dest.uri as dest" - ) - - # Check that queries were executed with user/collection parameters + calls = mock_driver.execute_query.call_args_list assert any( - expected_literal_query in str(call) and - "user='test_user'" in str(call) and - "collection='test_collection'" in str(call) - for call in calls + expected_literal_query in str(c) and + "workspace='test_workspace'" in str(c) and + "collection='test_collection'" in str(c) + for c in calls ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_query_triples_with_default_user_collection(self, mock_graph_db): - """Test that query service uses defaults when user/collection not provided""" + async def test_query_triples_with_default_collection(self, mock_graph_db): + """Test that query service uses default collection when not provided""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver - + processor = QueryProcessor(taskgroup=MagicMock()) - - # Create test query without user/collection - query = TriplesQueryRequest( - s=None, - p=None, - o=None - ) - - # Mock empty results + + query = 
TriplesQueryRequest(s=None, p=None, o=None) + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - - result = await processor.query_triples(query) - - # Verify defaults were used in queries + + await processor.query_triples("default", query) + calls = mock_driver.execute_query.call_args_list assert any( - "user='default'" in str(call) and "collection='default'" in str(call) - for call in calls + "workspace='default'" in str(c) and "collection='default'" in str(c) + for c in calls ) @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') @pytest.mark.asyncio - async def test_data_isolation_between_users(self, mock_graph_db): - """Test that data from different users is properly isolated""" + async def test_data_isolation_between_workspaces(self, mock_graph_db): + """Test that data from different workspaces is properly isolated""" taskgroup_mock = MagicMock() mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - + processor = StorageProcessor(taskgroup=taskgroup_mock) - - # Create messages for different users - message_user1 = Triples( - metadata=Metadata(user="user1", collection="coll1"), + + message_ws1 = Triples( + metadata=Metadata(collection="coll1"), triples=[ Triple( - s=Term(type=IRI, iri="http://example.com/user1/subject"), + s=Term(type=IRI, iri="http://example.com/ws1/subject"), p=Term(type=IRI, iri="http://example.com/predicate"), - o=Term(type=LITERAL, value="user1_data") + o=Term(type=LITERAL, value="ws1_data") ) ] ) - - message_user2 = Triples( - metadata=Metadata(user="user2", collection="coll2"), + + message_ws2 = Triples( + metadata=Metadata(collection="coll2"), triples=[ Triple( - s=Term(type=IRI, iri="http://example.com/user2/subject"), + s=Term(type=IRI, iri="http://example.com/ws2/subject"), p=Term(type=IRI, iri="http://example.com/predicate"), - o=Term(type=LITERAL, value="user2_data") + 
o=Term(type=LITERAL, value="ws2_data") ) ] ) - - # Mock execute_query + mock_summary = MagicMock() mock_summary.counters.nodes_created = 1 mock_summary.result_available_after = 10 mock_driver.execute_query.return_value.summary = mock_summary - # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - # Store data for both users - await processor.store_triples(message_user1) - await processor.store_triples(message_user2) - - # Verify user1 data was stored with user1/coll1 + await processor.store_triples("workspace1", message_ws1) + await processor.store_triples("workspace2", message_ws2) + mock_driver.execute_query.assert_any_call( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", - value="user1_data", - user="user1", + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", + value="ws1_data", + workspace="workspace1", collection="coll1", database_='neo4j' ) - - # Verify user2 data was stored with user2/coll2 + mock_driver.execute_query.assert_any_call( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", - value="user2_data", - user="user2", + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", + value="ws2_data", + workspace="workspace2", collection="coll2", database_='neo4j' ) @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @pytest.mark.asyncio - async def test_wildcard_query_respects_user_collection(self, mock_graph_db): - """Test that wildcard queries still filter by user/collection""" + async def test_wildcard_query_respects_workspace_collection(self, mock_graph_db): + """Test that wildcard queries still filter by workspace/collection""" mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver - + processor = QueryProcessor(taskgroup=MagicMock()) - - # Create wildcard query (all nulls) with user/collection + query = TriplesQueryRequest( - 
user="test_user", collection="test_collection", - s=None, - p=None, - o=None + s=None, p=None, o=None, ) - - # Mock results + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - - result = await processor.query_triples(query) - - # Verify wildcard queries include user/collection filters + + await processor.query_triples("test_workspace", query) + wildcard_query = ( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.value as dest" ) - + calls = mock_driver.execute_query.call_args_list assert any( - wildcard_query in str(call) and - "user='test_user'" in str(call) and - "collection='test_collection'" in str(call) - for call in calls + wildcard_query in str(c) and + "workspace='test_workspace'" in str(c) and + "collection='test_collection'" in str(c) + for c in calls ) def test_add_args_includes_neo4j_parameters(self): """Test that add_args includes Neo4j-specific parameters""" from argparse import ArgumentParser - from unittest.mock import patch - + parser = ArgumentParser() - + with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args'): StorageProcessor.add_args(parser) - + args = parser.parse_args([]) - + assert hasattr(args, 'graph_host') assert hasattr(args, 'username') assert hasattr(args, 'password') assert hasattr(args, 'database') - - # Check defaults + assert args.graph_host == 'bolt://neo4j:7687' assert args.username == 'neo4j' assert args.password == 'password' assert args.database == 'neo4j' -class TestNeo4jUserCollectionRegression: - """Regression tests to ensure user/collection isolation prevents data leaks""" - +class 
TestNeo4jWorkspaceCollectionRegression: + """Regression tests to ensure workspace/collection isolation prevents data leaks""" + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') - @pytest.mark.asyncio - async def test_regression_no_cross_user_data_access(self, mock_graph_db): + @pytest.mark.asyncio + async def test_regression_no_cross_workspace_data_access(self, mock_graph_db): """ - Regression test: Ensure user1 cannot access user2's data - - This test guards against the bug where all users shared the same - Neo4j graph space, causing data contamination between users. + Regression test: Ensure workspace1 cannot access workspace2's data. + + Guards against a bug where all data shared the same Neo4j graph + space, causing data contamination between workspaces. """ mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver - + processor = QueryProcessor(taskgroup=MagicMock()) - - # User1 queries for all triples - query_user1 = TriplesQueryRequest( - user="user1", + + query_ws1 = TriplesQueryRequest( collection="collection1", s=None, p=None, o=None ) - - # Mock that the database has data but none matching user1/collection1 + mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) - - result = await processor.query_triples(query_user1) - - # Verify empty results (user1 cannot see other users' data) + + result = await processor.query_triples("workspace1", query_ws1) + assert len(result) == 0 - - # Verify the query included user/collection filters + calls = mock_driver.execute_query.call_args_list - for call in calls: - query_str = str(call) + for c in calls: + query_str = str(c) if "MATCH" in query_str: - assert "user: $user" in query_str or "user='user1'" in query_str + assert "workspace: $workspace" in query_str or "workspace='workspace1'" in query_str assert "collection: $collection" in query_str or "collection='collection1'" in query_str - + @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') 
@pytest.mark.asyncio - async def test_regression_same_uri_different_users(self, mock_graph_db): + async def test_regression_same_uri_different_workspaces(self, mock_graph_db): """ - Regression test: Same URI in different user contexts should create separate nodes - - This ensures that http://example.com/entity for user1 is completely separate - from http://example.com/entity for user2. + Regression test: Same URI in different workspace contexts should create separate nodes. + + Ensures http://example.com/entity in workspace1 is completely + separate from the same URI in workspace2. """ taskgroup_mock = MagicMock() mock_driver = MagicMock() mock_graph_db.driver.return_value = mock_driver mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - + processor = StorageProcessor(taskgroup=taskgroup_mock) - - # Same URI for different users + shared_uri = "http://example.com/shared_entity" - - message_user1 = Triples( - metadata=Metadata(user="user1", collection="coll1"), + + message_ws1 = Triples( + metadata=Metadata(collection="coll1"), triples=[ Triple( s=Term(type=IRI, iri=shared_uri), p=Term(type=IRI, iri="http://example.com/p"), - o=Term(type=LITERAL, value="user1_value") + o=Term(type=LITERAL, value="ws1_value") ) ] ) - - message_user2 = Triples( - metadata=Metadata(user="user2", collection="coll2"), + + message_ws2 = Triples( + metadata=Metadata(collection="coll2"), triples=[ Triple( s=Term(type=IRI, iri=shared_uri), p=Term(type=IRI, iri="http://example.com/p"), - o=Term(type=LITERAL, value="user2_value") + o=Term(type=LITERAL, value="ws2_value") ) ] ) - - # Mock execute_query + mock_summary = MagicMock() mock_summary.counters.nodes_created = 1 mock_summary.result_available_after = 10 mock_driver.execute_query.return_value.summary = mock_summary - # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(message_user1) - 
await processor.store_triples(message_user2) - - # Verify two separate nodes were created with same URI but different user/collection - user1_node_call = call( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + await processor.store_triples("workspace1", message_ws1) + await processor.store_triples("workspace2", message_ws2) + + ws1_node_call = call( + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", uri=shared_uri, - user="user1", + workspace="workspace1", collection="coll1", database_='neo4j' ) - - user2_node_call = call( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + + ws2_node_call = call( + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", uri=shared_uri, - user="user2", + workspace="workspace2", collection="coll2", database_='neo4j' ) - - mock_driver.execute_query.assert_has_calls([user1_node_call, user2_node_call], any_order=True) \ No newline at end of file + + mock_driver.execute_query.assert_has_calls([ws1_node_call, ws2_node_call], any_order=True) diff --git a/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py index e1c8f3b1..8754f47c 100644 --- a/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py @@ -1,3 +1,12 @@ + +def _flow_mock(workspace): + """Build a mock flow object that is callable and exposes .workspace.""" + from unittest.mock import MagicMock + f = MagicMock() + f.workspace = workspace + return f + + """ Unit tests for trustgraph.storage.row_embeddings.qdrant.write Tests the Stage 2 processor that stores pre-computed row embeddings in Qdrant. 
@@ -92,13 +101,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) collection_name = processor.get_collection_name( - user="test_user", + workspace="test_workspace", collection="test_collection", schema_name="customer_data", dimension=384 ) - assert collection_name == "rows_test_user_test_collection_customer_data_384" + assert collection_name == "rows_test_workspace_test_collection_customer_data_384" @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') async def test_ensure_collection_creates_new(self, mock_qdrant_client): @@ -185,11 +194,10 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - processor.known_collections[('test_user', 'test_collection')] = {} + processor.known_collections[('test_workspace', 'test_collection')] = {} # Create embeddings message metadata = MagicMock() - metadata.user = 'test_user' metadata.collection = 'test_collection' metadata.id = 'doc-123' @@ -210,14 +218,14 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): mock_msg = MagicMock() mock_msg.value.return_value = embeddings_msg - await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) + await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace')) # Verify upsert was called mock_qdrant_instance.upsert.assert_called_once() # Verify upsert parameters upsert_call_args = mock_qdrant_instance.upsert.call_args - assert upsert_call_args[1]['collection_name'] == 'rows_test_user_test_collection_customers_3' + assert upsert_call_args[1]['collection_name'] == 'rows_test_workspace_test_collection_customers_3' point = upsert_call_args[1]['points'][0] assert point.vector == [0.1, 0.2, 0.3] @@ -243,10 +251,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - processor.known_collections[('test_user', 'test_collection')] = {} + processor.known_collections[('test_workspace', 
'test_collection')] = {} metadata = MagicMock() - metadata.user = 'test_user' metadata.collection = 'test_collection' metadata.id = 'doc-123' @@ -267,7 +274,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): mock_msg = MagicMock() mock_msg.value.return_value = embeddings_msg - await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) + await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace')) # Should be called once for the single embedding assert mock_qdrant_instance.upsert.call_count == 1 @@ -287,10 +294,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - processor.known_collections[('test_user', 'test_collection')] = {} + processor.known_collections[('test_workspace', 'test_collection')] = {} metadata = MagicMock() - metadata.user = 'test_user' metadata.collection = 'test_collection' metadata.id = 'doc-123' @@ -311,7 +317,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): mock_msg = MagicMock() mock_msg.value.return_value = embeddings_msg - await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) + await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace')) # Should not call upsert for empty vectors mock_qdrant_instance.upsert.assert_not_called() @@ -334,7 +340,6 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): # No collections registered metadata = MagicMock() - metadata.user = 'unknown_user' metadata.collection = 'unknown_collection' metadata.id = 'doc-123' @@ -354,7 +359,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): mock_msg = MagicMock() mock_msg.value.return_value = embeddings_msg - await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) + await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace')) # Should not call upsert for unknown collection mock_qdrant_instance.upsert.assert_not_called() @@ -368,11 +373,11 @@ class 
TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): # Mock collections list mock_coll1 = MagicMock() - mock_coll1.name = 'rows_test_user_test_collection_schema1_384' + mock_coll1.name = 'rows_test_workspace_test_collection_schema1_384' mock_coll2 = MagicMock() - mock_coll2.name = 'rows_test_user_test_collection_schema2_384' + mock_coll2.name = 'rows_test_workspace_test_collection_schema2_384' mock_coll3 = MagicMock() - mock_coll3.name = 'rows_other_user_other_collection_schema_384' + mock_coll3.name = 'rows_other_workspace_other_collection_schema_384' mock_collections = MagicMock() mock_collections.collections = [mock_coll1, mock_coll2, mock_coll3] @@ -386,15 +391,15 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - processor.created_collections.add('rows_test_user_test_collection_schema1_384') + processor.created_collections.add('rows_test_workspace_test_collection_schema1_384') - await processor.delete_collection('test_user', 'test_collection') + await processor.delete_collection('test_workspace', 'test_collection') # Should delete only the matching collections assert mock_qdrant_instance.delete_collection.call_count == 2 # Verify the cached collection was removed - assert 'rows_test_user_test_collection_schema1_384' not in processor.created_collections + assert 'rows_test_workspace_test_collection_schema1_384' not in processor.created_collections @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') async def test_delete_collection_schema(self, mock_qdrant_client): @@ -404,9 +409,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): mock_qdrant_instance = MagicMock() mock_coll1 = MagicMock() - mock_coll1.name = 'rows_test_user_test_collection_customers_384' + mock_coll1.name = 'rows_test_workspace_test_collection_customers_384' mock_coll2 = MagicMock() - mock_coll2.name = 'rows_test_user_test_collection_orders_384' + mock_coll2.name = 
'rows_test_workspace_test_collection_orders_384' mock_collections = MagicMock() mock_collections.collections = [mock_coll1, mock_coll2] @@ -422,13 +427,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) await processor.delete_collection_schema( - 'test_user', 'test_collection', 'customers' + 'test_workspace', 'test_collection', 'customers' ) # Should only delete the customers schema collection mock_qdrant_instance.delete_collection.assert_called_once() call_args = mock_qdrant_instance.delete_collection.call_args[0] - assert call_args[0] == 'rows_test_user_test_collection_customers_384' + assert call_args[0] == 'rows_test_workspace_test_collection_customers_384' if __name__ == '__main__': diff --git a/tests/unit/test_storage/test_rows_cassandra_storage.py b/tests/unit/test_storage/test_rows_cassandra_storage.py index ccf193aa..852f01a1 100644 --- a/tests/unit/test_storage/test_rows_cassandra_storage.py +++ b/tests/unit/test_storage/test_rows_cassandra_storage.py @@ -17,6 +17,17 @@ from trustgraph.storage.rows.cassandra.write import Processor from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field + + +class _MockFlowDefault: + """Mock Flow with default workspace for testing.""" + workspace = "default" + name = "default" + id = "test-processor" + + +mock_flow_default = _MockFlowDefault() + class TestRowsCassandraStorageLogic: """Test business logic for unified table implementation""" @@ -145,11 +156,11 @@ class TestRowsCassandraStorageLogic: } # Process configuration - await processor.on_schema_config(config, version=1) + await processor.on_schema_config("default", config, version=1) # Verify schema was loaded - assert "customer_records" in processor.schemas - schema = processor.schemas["customer_records"] + assert "customer_records" in processor.schemas["default"] + schema = processor.schemas["default"]["customer_records"] assert schema.name == "customer_records" assert len(schema.fields) == 3 @@ 
-165,16 +176,18 @@ class TestRowsCassandraStorageLogic: """Test that row processing stores data as map""" processor = MagicMock() processor.schemas = { - "test_schema": RowSchema( - name="test_schema", - description="Test", - fields=[ - Field(name="id", type="string", size=50, primary=True), - Field(name="value", type="string", size=100) - ] - ) + "default": { + "test_schema": RowSchema( + name="test_schema", + description="Test", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="value", type="string", size=100) + ] + ) + } } - processor.tables_initialized = {"test_user"} + processor.tables_initialized = {"default"} processor.registered_partitions = set() processor.session = MagicMock() processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) @@ -191,7 +204,6 @@ class TestRowsCassandraStorageLogic: test_obj = ExtractedObject( metadata=Metadata( id="test-001", - user="test_user", collection="test_collection", ), schema_name="test_schema", @@ -205,7 +217,7 @@ class TestRowsCassandraStorageLogic: msg.value.return_value = test_obj # Process object - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Verify insert was executed mock_async_execute.assert_called() @@ -214,7 +226,7 @@ class TestRowsCassandraStorageLogic: values = insert_call[0][2] # Verify using unified rows table - assert "INSERT INTO test_user.rows" in insert_cql + assert "INSERT INTO default.rows" in insert_cql # Values should be: (collection, schema_name, index_name, index_value, data, source) assert values[0] == "test_collection" # collection @@ -230,16 +242,18 @@ class TestRowsCassandraStorageLogic: """Test that row is written once per indexed field""" processor = MagicMock() processor.schemas = { - "multi_index_schema": RowSchema( - name="multi_index_schema", - fields=[ - Field(name="id", type="string", primary=True), - Field(name="category", type="string", indexed=True), - Field(name="status", 
type="string", indexed=True) - ] - ) + "default": { + "multi_index_schema": RowSchema( + name="multi_index_schema", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="category", type="string", indexed=True), + Field(name="status", type="string", indexed=True) + ] + ) + } } - processor.tables_initialized = {"test_user"} + processor.tables_initialized = {"default"} processor.registered_partitions = set() processor.session = MagicMock() processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) @@ -255,7 +269,6 @@ class TestRowsCassandraStorageLogic: test_obj = ExtractedObject( metadata=Metadata( id="test-001", - user="test_user", collection="test_collection", ), schema_name="multi_index_schema", @@ -267,7 +280,7 @@ class TestRowsCassandraStorageLogic: msg = MagicMock() msg.value.return_value = test_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Should have 3 inserts (one per indexed field: id, category, status) assert mock_async_execute.call_count == 3 @@ -290,15 +303,17 @@ class TestRowsCassandraStorageBatchLogic: """Test processing of batch ExtractedObjects""" processor = MagicMock() processor.schemas = { - "batch_schema": RowSchema( - name="batch_schema", - fields=[ - Field(name="id", type="string", primary=True), - Field(name="name", type="string") - ] - ) + "default": { + "batch_schema": RowSchema( + name="batch_schema", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="name", type="string") + ] + ) + } } - processor.tables_initialized = {"test_user"} + processor.tables_initialized = {"default"} processor.registered_partitions = set() processor.session = MagicMock() processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) @@ -315,7 +330,6 @@ class TestRowsCassandraStorageBatchLogic: batch_obj = ExtractedObject( metadata=Metadata( id="batch-001", - user="test_user", collection="batch_collection", ), 
schema_name="batch_schema", @@ -331,7 +345,7 @@ class TestRowsCassandraStorageBatchLogic: msg = MagicMock() msg.value.return_value = batch_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Should have 3 inserts (one per row, one index per row since only primary key) assert mock_async_execute.call_count == 3 @@ -349,12 +363,14 @@ class TestRowsCassandraStorageBatchLogic: """Test processing of empty batch ExtractedObjects""" processor = MagicMock() processor.schemas = { - "empty_schema": RowSchema( - name="empty_schema", - fields=[Field(name="id", type="string", primary=True)] - ) + "default": { + "empty_schema": RowSchema( + name="empty_schema", + fields=[Field(name="id", type="string", primary=True)] + ) + } } - processor.tables_initialized = {"test_user"} + processor.tables_initialized = {"default"} processor.registered_partitions = set() processor.session = MagicMock() processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) @@ -369,7 +385,6 @@ class TestRowsCassandraStorageBatchLogic: empty_batch_obj = ExtractedObject( metadata=Metadata( id="empty-001", - user="test_user", collection="empty_collection", ), schema_name="empty_schema", @@ -381,7 +396,7 @@ class TestRowsCassandraStorageBatchLogic: msg = MagicMock() msg.value.return_value = empty_batch_obj - await processor.on_object(msg, None, None) + await processor.on_object(msg, None, mock_flow_default) # Verify no insert calls for empty batch processor.session.execute.assert_not_called() @@ -446,19 +461,21 @@ class TestPartitionRegistration: processor.registered_partitions = set() processor.session = MagicMock() processor.schemas = { - "test_schema": RowSchema( - name="test_schema", - fields=[ - Field(name="id", type="string", primary=True), - Field(name="category", type="string", indexed=True) - ] - ) + "default": { + "test_schema": RowSchema( + name="test_schema", + fields=[ + Field(name="id", type="string", primary=True), + 
Field(name="category", type="string", indexed=True) + ] + ) + } } processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) processor.register_partitions = Processor.register_partitions.__get__(processor, Processor) - processor.register_partitions("test_user", "test_collection", "test_schema") + processor.register_partitions("test_user", "test_collection", "test_schema", "default") # Should have 2 inserts (one per index: id, category) assert processor.session.execute.call_count == 2 @@ -473,7 +490,7 @@ class TestPartitionRegistration: processor.session = MagicMock() processor.register_partitions = Processor.register_partitions.__get__(processor, Processor) - processor.register_partitions("test_user", "test_collection", "test_schema") + processor.register_partitions("test_user", "test_collection", "test_schema", "default") # Should not execute any CQL since already registered processor.session.execute.assert_not_called() diff --git a/tests/unit/test_storage/test_triples_cassandra_storage.py b/tests/unit/test_storage/test_triples_cassandra_storage.py index 73272942..04acbb16 100644 --- a/tests/unit/test_storage/test_triples_cassandra_storage.py +++ b/tests/unit/test_storage/test_triples_cassandra_storage.py @@ -102,11 +102,10 @@ class TestCassandraStorageProcessor: # Create mock message mock_message = MagicMock() - mock_message.metadata.user = 'user1' mock_message.metadata.collection = 'collection1' mock_message.triples = [] - await processor.store_triples(mock_message) + await processor.store_triples('user1', mock_message) # Verify KnowledgeGraph was called with auth parameters mock_kg_class.assert_called_once_with( @@ -129,11 +128,10 @@ class TestCassandraStorageProcessor: # Create mock message mock_message = MagicMock() - mock_message.metadata.user = 'user2' mock_message.metadata.collection = 'collection2' mock_message.triples = [] - await 
processor.store_triples(mock_message) + await processor.store_triples('user2', mock_message) # Verify KnowledgeGraph was called without auth parameters mock_kg_class.assert_called_once_with( @@ -154,16 +152,15 @@ class TestCassandraStorageProcessor: # Create mock message mock_message = MagicMock() - mock_message.metadata.user = 'user1' mock_message.metadata.collection = 'collection1' mock_message.triples = [] # First call should create TrustGraph - await processor.store_triples(mock_message) + await processor.store_triples('user1', mock_message) assert mock_kg_class.call_count == 1 # Second call with same table should reuse TrustGraph - await processor.store_triples(mock_message) + await processor.store_triples('user1', mock_message) assert mock_kg_class.call_count == 1 # Should not increase @pytest.mark.asyncio @@ -205,11 +202,10 @@ class TestCassandraStorageProcessor: # Create mock message mock_message = MagicMock() - mock_message.metadata.user = 'user1' mock_message.metadata.collection = 'collection1' mock_message.triples = [triple1, triple2] - await processor.store_triples(mock_message) + await processor.store_triples('user1', mock_message) # Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters) assert mock_tg_instance.insert.call_count == 2 @@ -234,11 +230,10 @@ class TestCassandraStorageProcessor: # Create mock message with empty triples mock_message = MagicMock() - mock_message.metadata.user = 'user1' mock_message.metadata.collection = 'collection1' mock_message.triples = [] - await processor.store_triples(mock_message) + await processor.store_triples('user1', mock_message) # Verify no triples were inserted mock_tg_instance.insert.assert_not_called() @@ -255,12 +250,11 @@ class TestCassandraStorageProcessor: # Create mock message mock_message = MagicMock() - mock_message.metadata.user = 'user1' mock_message.metadata.collection = 'collection1' mock_message.triples = [] with pytest.raises(Exception, match="Connection failed"): - await 
processor.store_triples(mock_message) + await processor.store_triples('user1', mock_message) # Verify sleep was called before re-raising mock_sleep.assert_called_once_with(1) @@ -361,21 +355,19 @@ class TestCassandraStorageProcessor: # First message with table1 mock_message1 = MagicMock() - mock_message1.metadata.user = 'user1' mock_message1.metadata.collection = 'collection1' mock_message1.triples = [] - await processor.store_triples(mock_message1) + await processor.store_triples('user1', mock_message1) assert processor.table == 'user1' assert processor.tg == mock_tg_instance1 # Second message with different table mock_message2 = MagicMock() - mock_message2.metadata.user = 'user2' mock_message2.metadata.collection = 'collection2' mock_message2.triples = [] - await processor.store_triples(mock_message2) + await processor.store_triples('user2', mock_message2) assert processor.table == 'user2' assert processor.tg == mock_tg_instance2 @@ -407,11 +399,10 @@ class TestCassandraStorageProcessor: triple.g = None mock_message = MagicMock() - mock_message.metadata.user = 'test_user' mock_message.metadata.collection = 'test_collection' mock_message.triples = [triple] - await processor.store_triples(mock_message) + await processor.store_triples('test_workspace', mock_message) # Verify the triple was inserted with special characters preserved mock_tg_instance.insert.assert_called_once_with( @@ -440,12 +431,11 @@ class TestCassandraStorageProcessor: mock_kg_class.side_effect = Exception("Connection failed") mock_message = MagicMock() - mock_message.metadata.user = 'new_user' mock_message.metadata.collection = 'new_collection' mock_message.triples = [] with pytest.raises(Exception, match="Connection failed"): - await processor.store_triples(mock_message) + await processor.store_triples('new_user', mock_message) # Table should remain unchanged since self.table = table happens after try/except assert processor.table == ('old_user', 'old_collection') @@ -468,11 +458,10 @@ class 
TestCassandraPerformanceOptimizations: processor = Processor(taskgroup=taskgroup_mock) mock_message = MagicMock() - mock_message.metadata.user = 'user1' mock_message.metadata.collection = 'collection1' mock_message.triples = [] - await processor.store_triples(mock_message) + await processor.store_triples('user1', mock_message) # Verify KnowledgeGraph instance uses legacy mode assert mock_tg_instance is not None @@ -489,11 +478,10 @@ class TestCassandraPerformanceOptimizations: processor = Processor(taskgroup=taskgroup_mock) mock_message = MagicMock() - mock_message.metadata.user = 'user1' mock_message.metadata.collection = 'collection1' mock_message.triples = [] - await processor.store_triples(mock_message) + await processor.store_triples('user1', mock_message) # Verify KnowledgeGraph instance is in optimized mode assert mock_tg_instance is not None @@ -523,11 +511,10 @@ class TestCassandraPerformanceOptimizations: triple.g = None mock_message = MagicMock() - mock_message.metadata.user = 'user1' mock_message.metadata.collection = 'collection1' mock_message.triples = [triple] - await processor.store_triples(mock_message) + await processor.store_triples('user1', mock_message) # Verify insert was called for the triple (implementation details tested in KnowledgeGraph) mock_tg_instance.insert.assert_called_once_with( diff --git a/tests/unit/test_storage/test_triples_falkordb_storage.py b/tests/unit/test_storage/test_triples_falkordb_storage.py index 05dcb2e5..c5b0848e 100644 --- a/tests/unit/test_storage/test_triples_falkordb_storage.py +++ b/tests/unit/test_storage/test_triples_falkordb_storage.py @@ -17,7 +17,6 @@ class TestFalkorDBStorageProcessor: """Create a mock message for testing""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Create a test triple @@ -89,13 +88,13 @@ class TestFalkorDBStorageProcessor: processor.io.query.return_value = mock_result - 
processor.create_node(test_uri, 'test_user', 'test_collection') + processor.create_node(test_uri, 'test_workspace', 'test_collection') processor.io.query.assert_called_once_with( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", params={ "uri": test_uri, - "user": 'test_user', + "workspace": 'test_workspace', "collection": 'test_collection', }, ) @@ -109,13 +108,13 @@ class TestFalkorDBStorageProcessor: processor.io.query.return_value = mock_result - processor.create_literal(test_value, 'test_user', 'test_collection') + processor.create_literal(test_value, 'test_workspace', 'test_collection') processor.io.query.assert_called_once_with( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", params={ "value": test_value, - "user": 'test_user', + "workspace": 'test_workspace', "collection": 'test_collection', }, ) @@ -132,17 +131,17 @@ class TestFalkorDBStorageProcessor: processor.io.query.return_value = mock_result - processor.relate_node(src_uri, pred_uri, dest_uri, 'test_user', 'test_collection') + processor.relate_node(src_uri, pred_uri, dest_uri, 'test_workspace', 'test_collection') processor.io.query.assert_called_once_with( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", params={ "src": src_uri, "dest": dest_uri, "uri": pred_uri, - "user": 'test_user', + "workspace": 'test_workspace', "collection": 'test_collection', }, ) @@ 
-159,17 +158,17 @@ class TestFalkorDBStorageProcessor: processor.io.query.return_value = mock_result - processor.relate_literal(src_uri, pred_uri, literal_value, 'test_user', 'test_collection') + processor.relate_literal(src_uri, pred_uri, literal_value, 'test_workspace', 'test_collection') processor.io.query.assert_called_once_with( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", params={ "src": src_uri, "dest": literal_value, "uri": pred_uri, - "user": 'test_user', + "workspace": 'test_workspace', "collection": 'test_collection', }, ) @@ -179,7 +178,6 @@ class TestFalkorDBStorageProcessor: """Test storing triple with URI object""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' triple = Triple( @@ -200,21 +198,21 @@ class TestFalkorDBStorageProcessor: with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(message) + await processor.store_triples('test_workspace', message) # Verify queries were called in the correct order expected_calls = [ # Create subject node - (("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",), - {"params": {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection"}}), + (("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",), + {"params": {"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection"}}), # Create object node - (("MERGE (n:Node {uri: $uri, 
user: $user, collection: $collection})",), - {"params": {"uri": "http://example.com/object", "user": "test_user", "collection": "test_collection"}}), + (("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",), + {"params": {"uri": "http://example.com/object", "workspace": "test_workspace", "collection": "test_collection"}}), # Create relationship - (("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",), - {"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate", "user": "test_user", "collection": "test_collection"}}), + (("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",), + {"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate", "workspace": "test_workspace", "collection": "test_collection"}}), ] assert processor.io.query.call_count == 3 @@ -237,21 +235,21 @@ class TestFalkorDBStorageProcessor: with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(mock_message) - + await processor.store_triples('test_workspace', mock_message) + # Verify queries were called in the correct order expected_calls = [ # Create subject node - (("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",), - {"params": {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection"}}), + (("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",), + {"params": {"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": 
"test_collection"}}), # Create literal object - (("MERGE (n:Literal {value: $value, user: $user, collection: $collection})",), - {"params": {"value": "literal object", "user": "test_user", "collection": "test_collection"}}), + (("MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",), + {"params": {"value": "literal object", "workspace": "test_workspace", "collection": "test_collection"}}), # Create relationship - (("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",), - {"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate", "user": "test_user", "collection": "test_collection"}}), + (("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",), + {"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate", "workspace": "test_workspace", "collection": "test_collection"}}), ] assert processor.io.query.call_count == 3 @@ -265,7 +263,6 @@ class TestFalkorDBStorageProcessor: """Test storing multiple triples""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' triple1 = Triple( @@ -291,7 +288,7 @@ class TestFalkorDBStorageProcessor: with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(message) + await processor.store_triples('test_workspace', message) # Verify total number of queries (3 per triple) assert processor.io.query.call_count == 6 @@ -313,7 +310,6 @@ class TestFalkorDBStorageProcessor: """Test storing empty 
triples list""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' message.triples = [] @@ -323,7 +319,7 @@ class TestFalkorDBStorageProcessor: with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(message) + await processor.store_triples('test_workspace', message) # Verify no queries were made processor.io.query.assert_not_called() @@ -333,7 +329,6 @@ class TestFalkorDBStorageProcessor: """Test storing triples with mixed URI and literal objects""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' triple1 = Triple( @@ -359,7 +354,7 @@ class TestFalkorDBStorageProcessor: with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(message) + await processor.store_triples('test_workspace', message) # Verify total number of queries (3 per triple) assert processor.io.query.call_count == 6 @@ -450,13 +445,13 @@ class TestFalkorDBStorageProcessor: processor.io.query.return_value = mock_result - processor.create_node(test_uri, 'test_user', 'test_collection') + processor.create_node(test_uri, 'test_workspace', 'test_collection') processor.io.query.assert_called_once_with( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", params={ "uri": test_uri, - "user": 'test_user', + "workspace": 'test_workspace', "collection": 'test_collection', }, ) @@ -470,13 +465,13 @@ class TestFalkorDBStorageProcessor: processor.io.query.return_value = mock_result - processor.create_literal(test_value, 'test_user', 'test_collection') + processor.create_literal(test_value, 'test_workspace', 'test_collection') processor.io.query.assert_called_once_with( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + "MERGE 
(n:Literal {value: $value, workspace: $workspace, collection: $collection})", params={ "value": test_value, - "user": 'test_user', + "workspace": 'test_workspace', "collection": 'test_collection', }, ) \ No newline at end of file diff --git a/tests/unit/test_storage/test_triples_memgraph_storage.py b/tests/unit/test_storage/test_triples_memgraph_storage.py index 162586d5..6a0c68a2 100644 --- a/tests/unit/test_storage/test_triples_memgraph_storage.py +++ b/tests/unit/test_storage/test_triples_memgraph_storage.py @@ -17,7 +17,6 @@ class TestMemgraphStorageProcessor: """Create a mock message for testing""" message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' # Create a test triple @@ -43,7 +42,7 @@ class TestMemgraphStorageProcessor: taskgroup=MagicMock(), id='test-memgraph-storage', graph_host='bolt://localhost:7687', username='test_user', password='test_pass', database='test_db' ) @@ -105,9 +104,9 @@ class TestMemgraphStorageProcessor: "CREATE INDEX ON :Node(uri)", "CREATE INDEX ON :Literal", "CREATE INDEX ON :Literal(value)", - "CREATE INDEX ON :Node(user)", + "CREATE INDEX ON :Node(workspace)", "CREATE INDEX ON :Node(collection)", - "CREATE INDEX ON :Literal(user)", + "CREATE INDEX ON :Literal(workspace)", "CREATE INDEX ON :Literal(collection)" ] @@ -145,12 +144,12 @@ class TestMemgraphStorageProcessor: processor.io.execute_query.return_value = mock_result - processor.create_node(test_uri, "test_user", "test_collection") + processor.create_node(test_uri, "test_workspace", "test_collection") processor.io.execute_query.assert_called_once_with( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", uri=test_uri, - user="test_user", + workspace="test_workspace", collection="test_collection", database_=processor.db ) @@ -166,12 +165,12 @@ class
TestMemgraphStorageProcessor: processor.io.execute_query.return_value = mock_result - processor.create_literal(test_value, "test_user", "test_collection") + processor.create_literal(test_value, "test_workspace", "test_collection") processor.io.execute_query.assert_called_once_with( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", value=test_value, - user="test_user", + workspace="test_workspace", collection="test_collection", database_=processor.db ) @@ -190,14 +189,14 @@ class TestMemgraphStorageProcessor: processor.io.execute_query.return_value = mock_result - processor.relate_node(src_uri, pred_uri, dest_uri, "test_user", "test_collection") + processor.relate_node(src_uri, pred_uri, dest_uri, "test_workspace", "test_collection") processor.io.execute_query.assert_called_once_with( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", src=src_uri, dest=dest_uri, uri=pred_uri, - user="test_user", collection="test_collection", + workspace="test_workspace", collection="test_collection", database_=processor.db ) @@ -215,14 +214,14 @@ class TestMemgraphStorageProcessor: processor.io.execute_query.return_value = mock_result - processor.relate_literal(src_uri, pred_uri, literal_value, "test_user", "test_collection") + processor.relate_literal(src_uri, pred_uri, literal_value, "test_workspace", "test_collection") processor.io.execute_query.assert_called_once_with( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " 
- "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", src=src_uri, dest=literal_value, uri=pred_uri, - user="test_user", collection="test_collection", + workspace="test_workspace", collection="test_collection", database_=processor.db ) @@ -236,22 +235,22 @@ class TestMemgraphStorageProcessor: o=Term(type=IRI, iri='http://example.com/object') ) - processor.create_triple(mock_tx, triple, "test_user", "test_collection") + processor.create_triple(mock_tx, triple, "test_workspace", "test_collection") # Verify transaction calls expected_calls = [ # Create subject node - ("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", - {'uri': 'http://example.com/subject', 'user': 'test_user', 'collection': 'test_collection'}), + ("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", + {'uri': 'http://example.com/subject', 'workspace': 'test_workspace', 'collection': 'test_collection'}), # Create object node - ("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", - {'uri': 'http://example.com/object', 'user': 'test_user', 'collection': 'test_collection'}), + ("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", + {'uri': 'http://example.com/object', 'workspace': 'test_workspace', 'collection': 'test_collection'}), # Create relationship - ("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + ("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " 
+ "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", {'src': 'http://example.com/subject', 'dest': 'http://example.com/object', 'uri': 'http://example.com/predicate', - 'user': 'test_user', 'collection': 'test_collection'}) + 'workspace': 'test_workspace', 'collection': 'test_collection'}) ] assert mock_tx.run.call_count == 3 @@ -270,22 +269,22 @@ class TestMemgraphStorageProcessor: o=Term(type=LITERAL, value='literal object') ) - processor.create_triple(mock_tx, triple, "test_user", "test_collection") + processor.create_triple(mock_tx, triple, "test_workspace", "test_collection") # Verify transaction calls expected_calls = [ # Create subject node - ("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", - {'uri': 'http://example.com/subject', 'user': 'test_user', 'collection': 'test_collection'}), + ("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", + {'uri': 'http://example.com/subject', 'workspace': 'test_workspace', 'collection': 'test_collection'}), # Create literal object - ("MERGE (n:Literal {value: $value, user: $user, collection: $collection})", - {'value': 'literal object', 'user': 'test_user', 'collection': 'test_collection'}), + ("MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", + {'value': 'literal object', 'workspace': 'test_workspace', 'collection': 'test_collection'}), # Create relationship - ("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + ("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, 
collection: $collection}]->(dest)", {'src': 'http://example.com/subject', 'dest': 'literal object', 'uri': 'http://example.com/predicate', - 'user': 'test_user', 'collection': 'test_collection'}) + 'workspace': 'test_workspace', 'collection': 'test_collection'}) ] assert mock_tx.run.call_count == 3 @@ -314,8 +313,8 @@ class TestMemgraphStorageProcessor: with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(mock_message) - + await processor.store_triples('test_workspace', mock_message) + # Verify execute_query was called for create_node, create_literal, and relate_literal # (since mock_message has a literal object) assert processor.io.execute_query.call_count == 3 @@ -323,7 +322,7 @@ class TestMemgraphStorageProcessor: # Verify user/collection parameters were included for call in processor.io.execute_query.call_args_list: call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1] - assert 'user' in call_kwargs + assert 'workspace' in call_kwargs assert 'collection' in call_kwargs @pytest.mark.asyncio @@ -343,7 +342,6 @@ class TestMemgraphStorageProcessor: # Create message with multiple triples message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' triple1 = Triple( @@ -364,7 +362,7 @@ class TestMemgraphStorageProcessor: with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(message) + await processor.store_triples('test_workspace', message) # Verify execute_query was called: # Triple1: create_node(s) + create_literal(o) + relate_literal = 3 calls @@ -375,7 +373,7 @@ class TestMemgraphStorageProcessor: # Verify user/collection parameters were included in all calls for call in processor.io.execute_query.call_args_list: call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1] - assert call_kwargs['user'] == 'test_user' + assert call_kwargs['workspace'] == 'test_workspace' 
assert call_kwargs['collection'] == 'test_collection' @pytest.mark.asyncio @@ -389,7 +387,6 @@ class TestMemgraphStorageProcessor: message = MagicMock() message.metadata = MagicMock() - message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' message.triples = [] @@ -399,7 +396,7 @@ class TestMemgraphStorageProcessor: with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(message) + await processor.store_triples('test_workspace', message) # Verify no session calls were made (no triples to process) processor.io.session.assert_not_called() diff --git a/tests/unit/test_storage/test_triples_neo4j_storage.py b/tests/unit/test_storage/test_triples_neo4j_storage.py index a5181ed9..0dcdb55e 100644 --- a/tests/unit/test_storage/test_triples_neo4j_storage.py +++ b/tests/unit/test_storage/test_triples_neo4j_storage.py @@ -68,9 +68,9 @@ class TestNeo4jStorageProcessor: "CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)", "CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)", "CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)", - "CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)", - "CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)", - "CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)", + "CREATE INDEX node_workspace_collection_uri FOR (n:Node) ON (n.workspace, n.collection, n.uri)", + "CREATE INDEX literal_workspace_collection_value FOR (n:Literal) ON (n.workspace, n.collection, n.value)", + "CREATE INDEX rel_workspace FOR ()-[r:Rel]-() ON (r.workspace)", "CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)" ] @@ -116,12 +116,12 @@ class TestNeo4jStorageProcessor: processor = Processor(taskgroup=taskgroup_mock) # Test create_node - processor.create_node("http://example.com/node", "test_user", "test_collection") + processor.create_node("http://example.com/node", "test_workspace", "test_collection") 
mock_driver.execute_query.assert_called_with( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", uri="http://example.com/node", - user="test_user", + workspace="test_workspace", collection="test_collection", database_="neo4j" ) @@ -146,12 +146,12 @@ class TestNeo4jStorageProcessor: processor = Processor(taskgroup=taskgroup_mock) # Test create_literal - processor.create_literal("literal value", "test_user", "test_collection") + processor.create_literal("literal value", "test_workspace", "test_collection") mock_driver.execute_query.assert_called_with( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", value="literal value", - user="test_user", + workspace="test_workspace", collection="test_collection", database_="neo4j" ) @@ -180,18 +180,18 @@ class TestNeo4jStorageProcessor: "http://example.com/subject", "http://example.com/predicate", "http://example.com/object", - "test_user", + "test_workspace", "test_collection" ) mock_driver.execute_query.assert_called_with( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", src="http://example.com/subject", dest="http://example.com/object", uri="http://example.com/predicate", - user="test_user", + workspace="test_workspace", collection="test_collection", database_="neo4j" ) @@ -220,18 +220,18 @@ class TestNeo4jStorageProcessor: "http://example.com/subject", "http://example.com/predicate", 
"literal value", - "test_user", + "test_workspace", "test_collection" ) mock_driver.execute_query.assert_called_with( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", src="http://example.com/subject", dest="literal value", uri="http://example.com/predicate", - user="test_user", + workspace="test_workspace", collection="test_collection", database_="neo4j" ) @@ -268,36 +268,35 @@ class TestNeo4jStorageProcessor: # Create mock message with metadata mock_message = MagicMock() mock_message.triples = [triple] - mock_message.metadata.user = "test_user" mock_message.metadata.collection = "test_collection" # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples("test_workspace", mock_message) # Verify create_node was called for subject and object # Verify relate_node was called expected_calls = [ # Subject node creation ( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", - {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection", "database_": "neo4j"} + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", + {"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"} ), # Object node creation ( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", - {"uri": "http://example.com/object", "user": "test_user", "collection": 
"test_collection", "database_": "neo4j"} + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", + {"uri": "http://example.com/object", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"} ), # Relationship creation ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", { "src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate", - "user": "test_user", + "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j" } @@ -340,12 +339,11 @@ class TestNeo4jStorageProcessor: # Create mock message with metadata mock_message = MagicMock() mock_message.triples = [triple] - mock_message.metadata.user = "test_user" mock_message.metadata.collection = "test_collection" # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples("test_workspace", mock_message) # Verify create_node was called for subject # Verify create_literal was called for object @@ -353,24 +351,24 @@ class TestNeo4jStorageProcessor: expected_calls = [ # Subject node creation ( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", - {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection", "database_": "neo4j"} + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", + {"uri": "http://example.com/subject", "workspace": 
"test_workspace", "collection": "test_collection", "database_": "neo4j"} ), # Literal creation ( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", - {"value": "literal value", "user": "test_user", "collection": "test_collection", "database_": "neo4j"} + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", + {"value": "literal value", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"} ), # Relationship creation ( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", { "src": "http://example.com/subject", "dest": "literal value", "uri": "http://example.com/predicate", - "user": "test_user", + "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j" } @@ -421,12 +419,11 @@ class TestNeo4jStorageProcessor: # Create mock message with metadata mock_message = MagicMock() mock_message.triples = [triple1, triple2] - mock_message.metadata.user = "test_user" mock_message.metadata.collection = "test_collection" # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples("test_workspace", mock_message) # Should have processed both triples # Triple1: 2 nodes + 1 relationship = 3 calls @@ -449,12 +446,11 @@ class TestNeo4jStorageProcessor: # Create mock message with empty triples and metadata mock_message = MagicMock() mock_message.triples = [] - mock_message.metadata.user = 
"test_user" mock_message.metadata.collection = "test_collection" # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples("test_workspace", mock_message) # Should not have made any execute_query calls beyond index creation # Only index creation calls should have been made during initialization @@ -568,38 +564,37 @@ class TestNeo4jStorageProcessor: mock_message = MagicMock() mock_message.triples = [triple] - mock_message.metadata.user = "test_user" mock_message.metadata.collection = "test_collection" # Mock collection_exists to bypass validation in unit tests with patch.object(processor, 'collection_exists', return_value=True): - await processor.store_triples(mock_message) + await processor.store_triples("test_workspace", mock_message) # Verify the triple was processed with special characters preserved mock_driver.execute_query.assert_any_call( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", uri="http://example.com/subject with spaces", - user="test_user", + workspace="test_workspace", collection="test_collection", database_="neo4j" ) mock_driver.execute_query.assert_any_call( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", value='literal with "quotes" and unicode: ñáéíóú', - user="test_user", + workspace="test_workspace", collection="test_collection", database_="neo4j" ) mock_driver.execute_query.assert_any_call( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, 
collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", src="http://example.com/subject with spaces", dest='literal with "quotes" and unicode: ñáéíóú', uri="http://example.com/predicate:with/symbols", - user="test_user", + workspace="test_workspace", collection="test_collection", database_="neo4j" ) diff --git a/tests/unit/test_structured_data/test_row_embeddings_query.py b/tests/unit/test_structured_data/test_row_embeddings_query.py index 3222ec83..51cf834f 100644 --- a/tests/unit/test_structured_data/test_row_embeddings_query.py +++ b/tests/unit/test_structured_data/test_row_embeddings_query.py @@ -24,11 +24,10 @@ def _make_processor(qdrant_client=None): return proc -def _make_request(vector=None, user="test-user", collection="test-col", +def _make_request(vector=None, collection="test-col", schema_name="customers", limit=10, index_name=None): return RowEmbeddingsRequest( vector=vector or [0.1, 0.2, 0.3], - user=user, collection=collection, schema_name=schema_name, limit=limit, @@ -36,6 +35,14 @@ def _make_request(vector=None, user="test-user", collection="test-col", ) +def _make_flow(workspace="test-workspace", pub=None): + """Make a mock flow object that is callable and has .workspace.""" + flow = MagicMock() + flow.return_value = pub if pub is not None else AsyncMock() + flow.workspace = workspace + return flow + + def _make_search_point(index_name, index_value, text, score): point = MagicMock() point.payload = { @@ -85,34 +92,33 @@ class TestFindCollection: def test_finds_matching_collection(self): proc = _make_processor() mock_coll = MagicMock() - mock_coll.name = "rows_test_user_test_col_customers_384" + mock_coll.name = "rows_test_workspace_test_col_customers_384" mock_collections = MagicMock() mock_collections.collections = [mock_coll] proc.qdrant.get_collections.return_value = mock_collections - result = 
proc.find_collection("test-user", "test-col", "customers") + result = proc.find_collection("test-workspace", "test-col", "customers") - # Prefix: rows_test_user_test_col_customers_ - assert result == "rows_test_user_test_col_customers_384" + assert result == "rows_test_workspace_test_col_customers_384" def test_returns_none_when_no_match(self): proc = _make_processor() mock_coll = MagicMock() - mock_coll.name = "rows_other_user_other_col_schema_768" + mock_coll.name = "rows_other_workspace_other_col_schema_768" mock_collections = MagicMock() mock_collections.collections = [mock_coll] proc.qdrant.get_collections.return_value = mock_collections - result = proc.find_collection("test-user", "test-col", "customers") + result = proc.find_collection("test-workspace", "test-col", "customers") assert result is None def test_returns_none_on_error(self): proc = _make_processor() proc.qdrant.get_collections.side_effect = Exception("connection error") - result = proc.find_collection("user", "col", "schema") + result = proc.find_collection("workspace", "col", "schema") assert result is None @@ -127,7 +133,7 @@ class TestQueryRowEmbeddings: proc = _make_processor() request = _make_request(vector=[]) - result = await proc.query_row_embeddings(request) + result = await proc.query_row_embeddings("test-workspace", request) assert result == [] @pytest.mark.asyncio @@ -136,13 +142,13 @@ class TestQueryRowEmbeddings: proc.find_collection = MagicMock(return_value=None) request = _make_request() - result = await proc.query_row_embeddings(request) + result = await proc.query_row_embeddings("test-workspace", request) assert result == [] @pytest.mark.asyncio async def test_successful_query_returns_matches(self): proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_u_c_s_384") + proc.find_collection = MagicMock(return_value="rows_w_c_s_384") points = [ _make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95), @@ -153,7 +159,7 @@ class TestQueryRowEmbeddings: 
proc.qdrant.query_points.return_value = mock_result request = _make_request() - result = await proc.query_row_embeddings(request) + result = await proc.query_row_embeddings("test-workspace", request) assert len(result) == 2 assert isinstance(result[0], RowIndexMatch) @@ -166,14 +172,14 @@ class TestQueryRowEmbeddings: async def test_index_name_filter_applied(self): """When index_name is specified, a Qdrant filter should be used.""" proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_u_c_s_384") + proc.find_collection = MagicMock(return_value="rows_w_c_s_384") mock_result = MagicMock() mock_result.points = [] proc.qdrant.query_points.return_value = mock_result request = _make_request(index_name="address") - await proc.query_row_embeddings(request) + await proc.query_row_embeddings("test-workspace", request) call_kwargs = proc.qdrant.query_points.call_args[1] assert call_kwargs["query_filter"] is not None @@ -182,14 +188,14 @@ class TestQueryRowEmbeddings: async def test_no_index_name_no_filter(self): """When index_name is empty, no filter should be applied.""" proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_u_c_s_384") + proc.find_collection = MagicMock(return_value="rows_w_c_s_384") mock_result = MagicMock() mock_result.points = [] proc.qdrant.query_points.return_value = mock_result request = _make_request(index_name="") - await proc.query_row_embeddings(request) + await proc.query_row_embeddings("test-workspace", request) call_kwargs = proc.qdrant.query_points.call_args[1] assert call_kwargs["query_filter"] is None @@ -198,7 +204,7 @@ class TestQueryRowEmbeddings: async def test_missing_payload_fields_default(self): """Points with missing payload fields should use defaults.""" proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_u_c_s_384") + proc.find_collection = MagicMock(return_value="rows_w_c_s_384") point = MagicMock() point.payload = {} # Empty payload @@ -209,7 +215,7 @@ 
class TestQueryRowEmbeddings: proc.qdrant.query_points.return_value = mock_result request = _make_request() - result = await proc.query_row_embeddings(request) + result = await proc.query_row_embeddings("test-workspace", request) assert len(result) == 1 assert result[0].index_name == "" @@ -219,13 +225,13 @@ class TestQueryRowEmbeddings: @pytest.mark.asyncio async def test_qdrant_error_propagates(self): proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_u_c_s_384") + proc.find_collection = MagicMock(return_value="rows_w_c_s_384") proc.qdrant.query_points.side_effect = Exception("qdrant down") request = _make_request() with pytest.raises(Exception, match="qdrant down"): - await proc.query_row_embeddings(request) + await proc.query_row_embeddings("test-workspace", request) # --------------------------------------------------------------------------- @@ -243,7 +249,7 @@ class TestOnMessage: ]) mock_pub = AsyncMock() - flow = lambda name: mock_pub + flow = _make_flow(pub=mock_pub) msg = MagicMock() msg.value.return_value = _make_request() @@ -264,7 +270,7 @@ class TestOnMessage: ) mock_pub = AsyncMock() - flow = lambda name: mock_pub + flow = _make_flow(pub=mock_pub) msg = MagicMock() msg.value.return_value = _make_request() @@ -284,7 +290,7 @@ class TestOnMessage: proc.query_row_embeddings = AsyncMock(return_value=[]) mock_pub = AsyncMock() - flow = lambda name: mock_pub + flow = _make_flow(pub=mock_pub) msg = MagicMock() msg.value.return_value = _make_request() diff --git a/tests/unit/test_tables/test_knowledge_table_store.py b/tests/unit/test_tables/test_knowledge_table_store.py index 4ab0ffeb..59d15b45 100644 --- a/tests/unit/test_tables/test_knowledge_table_store.py +++ b/tests/unit/test_tables/test_knowledge_table_store.py @@ -45,12 +45,9 @@ class TestGetGraphEmbeddings: with `vector=` (singular) — the schema field name. A previous version used `vectors=` and TypeError'd at runtime. 
""" - # Arrange — fake row matching the get_triples_stmt result shape: - # row[0..2] are unused by the method, row[3] is the entities blob fake_row = ( None, None, None, [ - # ((value, is_uri), vector) (("http://example.org/alice", True), [0.1, 0.2, 0.3]), (("http://example.org/bob", True), [0.4, 0.5, 0.6]), (("a literal entity", False), [0.7, 0.8, 0.9]), @@ -67,14 +64,8 @@ class TestGetGraphEmbeddings: async def receiver(msg): received.append(msg) - # Act - await store.get_graph_embeddings( - user="alice", - document_id="doc-1", - receiver=receiver, - ) + await store.get_graph_embeddings("alice", "doc-1", receiver) - # Assert mock_async_execute.assert_called_once_with( store.cassandra, store.get_graph_embeddings_stmt, @@ -86,7 +77,6 @@ class TestGetGraphEmbeddings: assert isinstance(ge, GraphEmbeddings) assert isinstance(ge.metadata, Metadata) assert ge.metadata.id == "doc-1" - assert ge.metadata.user == "alice" assert len(ge.entities) == 3 assert all(isinstance(e, EntityEmbeddings) for e in ge.entities) @@ -122,7 +112,7 @@ class TestGetGraphEmbeddings: async def receiver(msg): received.append(msg) - await store.get_graph_embeddings("u", "d", receiver) + await store.get_graph_embeddings("w", "d", receiver) assert len(received) == 1 assert received[0].entities == [] @@ -149,7 +139,7 @@ class TestGetGraphEmbeddings: async def receiver(msg): received.append(msg) - await store.get_graph_embeddings("u", "d", receiver) + await store.get_graph_embeddings("w", "d", receiver) assert len(received) == 2 assert received[0].entities[0].entity.iri == "http://example.org/a" @@ -194,7 +184,6 @@ class TestGetTriples: assert isinstance(triples_msg, Triples) assert isinstance(triples_msg.metadata, Metadata) assert triples_msg.metadata.id == "doc-1" - assert triples_msg.metadata.user == "alice" assert len(triples_msg.triples) == 1 t = triples_msg.triples[0] diff --git a/tests/unit/test_translators/test_document_embeddings_translator_roundtrip.py 
b/tests/unit/test_translators/test_document_embeddings_translator_roundtrip.py index 72f4796b..56a1583e 100644 --- a/tests/unit/test_translators/test_document_embeddings_translator_roundtrip.py +++ b/tests/unit/test_translators/test_document_embeddings_translator_roundtrip.py @@ -30,7 +30,6 @@ def sample(): metadata=Metadata( id="doc-1", root="", - user="alice", collection="testcoll", ), chunks=[ @@ -56,7 +55,6 @@ class TestDocumentEmbeddingsTranslator: assert isinstance(decoded, DocumentEmbeddings) assert isinstance(decoded.metadata, Metadata) assert decoded.metadata.id == "doc-1" - assert decoded.metadata.user == "alice" assert decoded.metadata.collection == "testcoll" assert len(decoded.chunks) == 2 diff --git a/tests/unit/test_translators/test_knowledge_translator_roundtrip.py b/tests/unit/test_translators/test_knowledge_translator_roundtrip.py index 57e7ae17..64f2e5d4 100644 --- a/tests/unit/test_translators/test_knowledge_translator_roundtrip.py +++ b/tests/unit/test_translators/test_knowledge_translator_roundtrip.py @@ -41,7 +41,7 @@ def translator(): def graph_embeddings_request(): return KnowledgeRequest( operation="put-kg-core", - user="alice", + workspace="alice", id="doc-1", flow="default", collection="testcoll", @@ -49,7 +49,6 @@ def graph_embeddings_request(): metadata=Metadata( id="doc-1", root="", - user="alice", collection="testcoll", ), entities=[ @@ -70,7 +69,6 @@ def graph_embeddings_request(): def triples_request(): return KnowledgeRequest( operation="put-kg-core", - user="alice", id="doc-1", flow="default", collection="testcoll", @@ -78,7 +76,6 @@ def triples_request(): metadata=Metadata( id="doc-1", root="", - user="alice", collection="testcoll", ), triples=[ @@ -113,7 +110,7 @@ class TestKnowledgeRequestTranslatorGraphEmbeddings: assert isinstance(decoded, KnowledgeRequest) assert decoded.operation == "put-kg-core" - assert decoded.user == "alice" + assert decoded.workspace == "alice" assert decoded.id == "doc-1" assert decoded.flow == 
"default" assert decoded.collection == "testcoll" @@ -123,7 +120,6 @@ class TestKnowledgeRequestTranslatorGraphEmbeddings: assert isinstance(ge, GraphEmbeddings) assert isinstance(ge.metadata, Metadata) assert ge.metadata.id == "doc-1" - assert ge.metadata.user == "alice" assert ge.metadata.collection == "testcoll" assert len(ge.entities) == 2 @@ -143,7 +139,6 @@ class TestKnowledgeRequestTranslatorTriples: assert decoded.triples is not None assert isinstance(decoded.triples.metadata, Metadata) assert decoded.triples.metadata.id == "doc-1" - assert decoded.triples.metadata.user == "alice" assert decoded.triples.metadata.collection == "testcoll" assert len(decoded.triples.triples) == 1 diff --git a/trustgraph-base/trustgraph/api/__init__.py b/trustgraph-base/trustgraph/api/__init__.py index 2f44aad0..e14c61a3 100644 --- a/trustgraph-base/trustgraph/api/__init__.py +++ b/trustgraph-base/trustgraph/api/__init__.py @@ -27,7 +27,6 @@ Quick Start: # Execute a graph RAG query response = flow.graph_rag( query="What are the main topics?", - user="trustgraph", collection="default" ) ``` @@ -38,7 +37,7 @@ For streaming and async operations: socket = api.socket() flow = socket.flow("default") - for chunk in flow.agent(question="Hello", user="trustgraph"): + for chunk in flow.agent(question="Hello"): print(chunk.content) # Async operations diff --git a/trustgraph-base/trustgraph/api/api.py b/trustgraph-base/trustgraph/api/api.py index dbdce0a8..9074bac1 100644 --- a/trustgraph-base/trustgraph/api/api.py +++ b/trustgraph-base/trustgraph/api/api.py @@ -50,7 +50,7 @@ class Api: token: Optional bearer token for authentication """ - def __init__(self, url="http://localhost:8088/", timeout=60, token: Optional[str] = None): + def __init__(self, url="http://localhost:8088/", timeout=60, token: Optional[str] = None, workspace: str = "default"): """ Initialize the TrustGraph API client. 
@@ -82,6 +82,7 @@ class Api: self.timeout = timeout self.token = token + self.workspace = workspace # Lazy initialization for new clients self._socket_client = None @@ -137,7 +138,7 @@ class Api: config.put([ConfigValue(type="llm", key="model", value="gpt-4")]) ``` """ - return Config(api=self) + return Config(api=self, workspace=self.workspace) def knowledge(self): """ @@ -151,10 +152,10 @@ class Api: knowledge = api.knowledge() # List available KG cores - cores = knowledge.list_kg_cores(user="trustgraph") + cores = knowledge.list_kg_cores() # Load a KG core - knowledge.load_kg_core(id="core-123", user="trustgraph") + knowledge.load_kg_core(id="core-123") ``` """ return Knowledge(api=self) @@ -191,6 +192,12 @@ class Api: if self.token: headers["Authorization"] = f"Bearer {self.token}" + # Ensure every REST request carries the workspace so services can + # scope their behaviour. Callers that already set workspace in the + # payload (e.g. Library client) take precedence. + if isinstance(request, dict) and "workspace" not in request: + request = {**request, "workspace": self.workspace} + # Invoke the API, input is passed as JSON resp = requests.post(url, json=request, timeout=self.timeout, headers=headers) @@ -227,13 +234,12 @@ class Api: document=b"Document content", id="doc-123", metadata=[], - user="trustgraph", title="My Document", comments="Test document" ) # List documents - docs = library.get_documents(user="trustgraph") + docs = library.get_documents() ``` """ return Library(self) @@ -253,11 +259,10 @@ class Api: collection = api.collection() # List collections - colls = collection.list_collections(user="trustgraph") + colls = collection.list_collections() # Update collection metadata collection.update_collection( - user="trustgraph", collection="default", name="Default Collection", description="Main data collection" @@ -286,7 +291,6 @@ class Api: # Stream agent responses for chunk in flow.agent( question="Explain quantum computing", - user="trustgraph", 
streaming=True ): if hasattr(chunk, 'content'): @@ -297,7 +301,10 @@ class Api: from . socket_client import SocketClient # Extract base URL (remove api/v1/ suffix) base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/") - self._socket_client = SocketClient(base_url, self.timeout, self.token) + self._socket_client = SocketClient( + base_url, self.timeout, self.token, + workspace=self.workspace, + ) return self._socket_client def bulk(self): @@ -406,7 +413,6 @@ class Api: # Stream agent responses async for chunk in flow.agent( question="Explain quantum computing", - user="trustgraph", streaming=True ): if hasattr(chunk, 'content'): @@ -417,7 +423,10 @@ class Api: from . async_socket_client import AsyncSocketClient # Extract base URL (remove api/v1/ suffix) base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/") - self._async_socket_client = AsyncSocketClient(base_url, self.timeout, self.token) + self._async_socket_client = AsyncSocketClient( + base_url, self.timeout, self.token, + workspace=self.workspace, + ) return self._async_socket_client def async_bulk(self): diff --git a/trustgraph-base/trustgraph/api/async_flow.py b/trustgraph-base/trustgraph/api/async_flow.py index 68899341..bf0b2ba1 100644 --- a/trustgraph-base/trustgraph/api/async_flow.py +++ b/trustgraph-base/trustgraph/api/async_flow.py @@ -326,9 +326,7 @@ class AsyncFlow: # Use flow services result = await flow.graph_rag( - query="What is TrustGraph?", - user="trustgraph", - collection="default" + query="What is TrustGraph?",collection="default" ) ``` """ @@ -385,7 +383,7 @@ class AsyncFlowInstance: """ return await self.flow.request(f"flow/{self.flow_id}/service/{service}", request_data) - async def agent(self, question: str, user: str, state: Optional[Dict] = None, + async def agent(self, question: str, state: Optional[Dict] = None, group: Optional[str] = None, history: Optional[List] = None, **kwargs: Any) -> Dict[str, Any]: """ Execute an agent operation (non-streaming). 
@@ -399,7 +397,6 @@ class AsyncFlowInstance: Args: question: User question or instruction - user: User identifier state: Optional state dictionary for conversation context group: Optional group identifier for session management history: Optional conversation history list @@ -416,14 +413,12 @@ class AsyncFlowInstance: # Execute agent result = await flow.agent( question="What is the capital of France?", - user="trustgraph" - ) + ) print(f"Answer: {result.get('response')}") ``` """ request_data = { "question": question, - "user": user, "streaming": False # REST doesn't support streaming } if state is not None: @@ -481,7 +476,7 @@ class AsyncFlowInstance: model=result.get("model"), ) - async def graph_rag(self, query: str, user: str, collection: str, + async def graph_rag(self, query: str, collection: str, max_subgraph_size: int = 1000, max_subgraph_count: int = 5, max_entity_distance: int = 3, **kwargs: Any) -> str: """ @@ -496,7 +491,6 @@ class AsyncFlowInstance: Args: query: User query text - user: User identifier collection: Collection identifier containing the knowledge graph max_subgraph_size: Maximum number of triples per subgraph (default: 1000) max_subgraph_count: Maximum number of subgraphs to retrieve (default: 5) @@ -513,9 +507,7 @@ class AsyncFlowInstance: # Query knowledge graph response = await flow.graph_rag( - query="What are the relationships between these entities?", - user="trustgraph", - collection="medical-kb", + query="What are the relationships between these entities?",collection="medical-kb", max_subgraph_count=3 ) print(response) @@ -523,7 +515,6 @@ class AsyncFlowInstance: """ request_data = { "query": query, - "user": user, "collection": collection, "max-subgraph-size": max_subgraph_size, "max-subgraph-count": max_subgraph_count, @@ -535,7 +526,7 @@ class AsyncFlowInstance: result = await self.request("graph-rag", request_data) return result.get("response", "") - async def document_rag(self, query: str, user: str, collection: str, + async 
def document_rag(self, query: str, collection: str, doc_limit: int = 10, **kwargs: Any) -> str: """ Execute document-based RAG query (non-streaming). @@ -549,7 +540,6 @@ class AsyncFlowInstance: Args: query: User query text - user: User identifier collection: Collection identifier containing documents doc_limit: Maximum number of document chunks to retrieve (default: 10) **kwargs: Additional service-specific parameters @@ -564,9 +554,7 @@ class AsyncFlowInstance: # Query documents response = await flow.document_rag( - query="What does the documentation say about authentication?", - user="trustgraph", - collection="docs", + query="What does the documentation say about authentication?",collection="docs", doc_limit=5 ) print(response) @@ -574,7 +562,6 @@ class AsyncFlowInstance: """ request_data = { "query": query, - "user": user, "collection": collection, "doc-limit": doc_limit, "streaming": False @@ -584,7 +571,7 @@ class AsyncFlowInstance: result = await self.request("document-rag", request_data) return result.get("response", "") - async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs: Any): + async def graph_embeddings_query(self, text: str, collection: str, limit: int = 10, **kwargs: Any): """ Query graph embeddings for semantic entity search. 
@@ -593,7 +580,6 @@ class AsyncFlowInstance: Args: text: Query text for semantic search - user: User identifier collection: Collection identifier containing graph embeddings limit: Maximum number of results to return (default: 10) **kwargs: Additional service-specific parameters @@ -608,9 +594,7 @@ class AsyncFlowInstance: # Find related entities results = await flow.graph_embeddings_query( - text="machine learning algorithms", - user="trustgraph", - collection="tech-kb", + text="machine learning algorithms",collection="tech-kb", limit=5 ) @@ -624,7 +608,6 @@ class AsyncFlowInstance: request_data = { "vector": vector, - "user": user, "collection": collection, "limit": limit } @@ -663,7 +646,7 @@ class AsyncFlowInstance: return await self.request("embeddings", request_data) - async def triples_query(self, s=None, p=None, o=None, user=None, collection=None, limit=100, **kwargs: Any): + async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs: Any): """ Query RDF triples using pattern matching. 
@@ -674,7 +657,6 @@ class AsyncFlowInstance: s: Subject pattern (None for wildcard) p: Predicate pattern (None for wildcard) o: Object pattern (None for wildcard) - user: User identifier (None for all users) collection: Collection identifier (None for all collections) limit: Maximum number of triples to return (default: 100) **kwargs: Additional service-specific parameters @@ -689,9 +671,7 @@ class AsyncFlowInstance: # Find all triples with a specific predicate results = await flow.triples_query( - p="knows", - user="trustgraph", - collection="social", + p="knows",collection="social", limit=50 ) @@ -706,15 +686,13 @@ class AsyncFlowInstance: request_data["p"] = str(p) if o is not None: request_data["o"] = str(o) - if user is not None: - request_data["user"] = user if collection is not None: request_data["collection"] = collection request_data.update(kwargs) return await self.request("triples", request_data) - async def rows_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None, + async def rows_query(self, query: str, collection: str, variables: Optional[Dict] = None, operation_name: Optional[str] = None, **kwargs: Any): """ Execute a GraphQL query on stored rows. 
@@ -724,7 +702,6 @@ class AsyncFlowInstance: Args: query: GraphQL query string - user: User identifier collection: Collection identifier containing rows variables: Optional GraphQL query variables operation_name: Optional operation name for multi-operation queries @@ -750,9 +727,7 @@ class AsyncFlowInstance: ''' result = await flow.rows_query( - query=query, - user="trustgraph", - collection="users", + query=query,collection="users", variables={"status": "active"} ) @@ -762,7 +737,6 @@ class AsyncFlowInstance: """ request_data = { "query": query, - "user": user, "collection": collection } if variables: @@ -774,7 +748,7 @@ class AsyncFlowInstance: return await self.request("rows", request_data) async def row_embeddings_query( - self, text: str, schema_name: str, user: str = "trustgraph", + self, text: str, schema_name: str, collection: str = "default", index_name: Optional[str] = None, limit: int = 10, **kwargs: Any ): @@ -788,7 +762,6 @@ class AsyncFlowInstance: Args: text: Query text for semantic search schema_name: Schema name to search within - user: User identifier (default: "trustgraph") collection: Collection identifier (default: "default") index_name: Optional index name to filter search to specific index limit: Maximum number of results to return (default: 10) @@ -806,9 +779,7 @@ class AsyncFlowInstance: # Search for customers by name similarity results = await flow.row_embeddings_query( text="John Smith", - schema_name="customers", - user="trustgraph", - collection="sales", + schema_name="customers",collection="sales", limit=5 ) @@ -823,7 +794,6 @@ class AsyncFlowInstance: request_data = { "vector": vector, "schema_name": schema_name, - "user": user, "collection": collection, "limit": limit } diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index 6e5064ab..e5d553ea 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ 
b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -22,10 +22,14 @@ class AsyncSocketClient: Or call connect()/aclose() manually. """ - def __init__(self, url: str, timeout: int, token: Optional[str]): + def __init__( + self, url: str, timeout: int, token: Optional[str], + workspace: str = "default", + ): self.url = self._convert_to_ws_url(url) self.timeout = timeout self.token = token + self.workspace = workspace self._request_counter = 0 self._socket = None self._connect_cm = None @@ -117,6 +121,7 @@ class AsyncSocketClient: try: message = { "id": request_id, + "workspace": self.workspace, "service": service, "request": request } @@ -149,6 +154,7 @@ class AsyncSocketClient: try: message = { "id": request_id, + "workspace": self.workspace, "service": service, "request": request } @@ -251,13 +257,12 @@ class AsyncSocketFlowInstance: self.client = client self.flow_id = flow_id - async def agent(self, question: str, user: str, state: Optional[Dict[str, Any]] = None, + async def agent(self, question: str, state: Optional[Dict[str, Any]] = None, group: Optional[str] = None, history: Optional[list] = None, streaming: bool = False, **kwargs) -> Union[Dict[str, Any], AsyncIterator]: """Agent with optional streaming""" request = { "question": question, - "user": user, "streaming": streaming } if state is not None: @@ -303,13 +308,12 @@ class AsyncSocketFlowInstance: if isinstance(chunk, RAGChunk): yield chunk - async def graph_rag(self, query: str, user: str, collection: str, + async def graph_rag(self, query: str, collection: str, max_subgraph_size: int = 1000, max_subgraph_count: int = 5, max_entity_distance: int = 3, streaming: bool = False, **kwargs): """Graph RAG with optional streaming""" request = { "query": query, - "user": user, "collection": collection, "max-subgraph-size": max_subgraph_size, "max-subgraph-count": max_subgraph_count, @@ -330,12 +334,11 @@ class AsyncSocketFlowInstance: if hasattr(chunk, 'content'): yield chunk.content - async def 
document_rag(self, query: str, user: str, collection: str, + async def document_rag(self, query: str, collection: str, doc_limit: int = 10, streaming: bool = False, **kwargs): """Document RAG with optional streaming""" request = { "query": query, - "user": user, "collection": collection, "doc-limit": doc_limit, "streaming": streaming @@ -375,14 +378,13 @@ class AsyncSocketFlowInstance: if hasattr(chunk, 'content'): yield chunk.content - async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs): + async def graph_embeddings_query(self, text: str, collection: str, limit: int = 10, **kwargs): """Query graph embeddings for semantic search""" emb_result = await self.embeddings(texts=[text]) vector = emb_result.get("vectors", [[]])[0] request = { "vector": vector, - "user": user, "collection": collection, "limit": limit } @@ -397,7 +399,7 @@ class AsyncSocketFlowInstance: return await self.client._send_request("embeddings", self.flow_id, request) - async def triples_query(self, s=None, p=None, o=None, user=None, collection=None, limit=100, **kwargs): + async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs): """Triple pattern query""" request = {"limit": limit} if s is not None: @@ -406,20 +408,17 @@ class AsyncSocketFlowInstance: request["p"] = str(p) if o is not None: request["o"] = str(o) - if user is not None: - request["user"] = user if collection is not None: request["collection"] = collection request.update(kwargs) return await self.client._send_request("triples", self.flow_id, request) - async def rows_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None, + async def rows_query(self, query: str, collection: str, variables: Optional[Dict] = None, operation_name: Optional[str] = None, **kwargs): """GraphQL query against structured rows""" request = { "query": query, - "user": user, "collection": collection } if variables: @@ -441,7 +440,7 @@ class 
AsyncSocketFlowInstance: return await self.client._send_request("mcp-tool", self.flow_id, request) async def row_embeddings_query( - self, text: str, schema_name: str, user: str = "trustgraph", + self, text: str, schema_name: str, collection: str = "default", index_name: Optional[str] = None, limit: int = 10, **kwargs ): @@ -452,7 +451,6 @@ class AsyncSocketFlowInstance: request = { "vector": vector, "schema_name": schema_name, - "user": user, "collection": collection, "limit": limit } diff --git a/trustgraph-base/trustgraph/api/bulk_client.py b/trustgraph-base/trustgraph/api/bulk_client.py index 75999550..0e49fc4e 100644 --- a/trustgraph-base/trustgraph/api/bulk_client.py +++ b/trustgraph-base/trustgraph/api/bulk_client.py @@ -85,7 +85,7 @@ class BulkClient: Args: flow: Flow identifier triples: Iterator yielding Triple objects - metadata: Metadata dict with id, metadata, user, collection + metadata: Metadata dict with id, metadata, collection batch_size: Number of triples per batch (default 100) **kwargs: Additional parameters (reserved for future use) @@ -105,7 +105,7 @@ class BulkClient: bulk.import_triples( flow="default", triples=triple_generator(), - metadata={"id": "doc1", "metadata": [], "user": "user1", "collection": "default"} + metadata={"id": "doc1", "metadata": [], "collection": "default"} ) ``` """ @@ -121,7 +121,7 @@ class BulkClient: ws_url = f"{ws_url}?token={self.token}" if metadata is None: - metadata = {"id": "", "metadata": [], "user": "trustgraph", "collection": "default"} + metadata = {"id": "", "metadata": [], "collection": "default"} async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: batch = [] @@ -418,7 +418,7 @@ class BulkClient: Args: flow: Flow identifier contexts: Iterator yielding context dictionaries - metadata: Metadata dict with id, metadata, user, collection + metadata: Metadata dict with id, metadata, collection batch_size: Number of contexts per batch (default 100) **kwargs: 
Additional parameters (reserved for future use) @@ -435,7 +435,7 @@ class BulkClient: bulk.import_entity_contexts( flow="default", contexts=context_generator(), - metadata={"id": "doc1", "metadata": [], "user": "user1", "collection": "default"} + metadata={"id": "doc1", "metadata": [], "collection": "default"} ) ``` """ @@ -451,7 +451,7 @@ class BulkClient: ws_url = f"{ws_url}?token={self.token}" if metadata is None: - metadata = {"id": "", "metadata": [], "user": "trustgraph", "collection": "default"} + metadata = {"id": "", "metadata": [], "collection": "default"} async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: batch = [] diff --git a/trustgraph-base/trustgraph/api/collection.py b/trustgraph-base/trustgraph/api/collection.py index 414d07db..11cd2843 100644 --- a/trustgraph-base/trustgraph/api/collection.py +++ b/trustgraph-base/trustgraph/api/collection.py @@ -2,11 +2,9 @@ TrustGraph Collection Management This module provides interfaces for managing data collections in TrustGraph. -Collections provide logical grouping and isolation for documents and knowledge -graph data. +Collections provide logical grouping within a workspace. """ -import datetime import logging from . types import CollectionMetadata @@ -18,10 +16,9 @@ class Collection: """ Collection management client. - Provides methods for managing data collections, including listing, - updating metadata, and deleting collections. Collections organize - documents and knowledge graph data into logical groupings for - isolation and access control. + Provides methods for managing data collections within the configured + workspace, including listing, updating metadata, and deleting + collections. """ def __init__(self, api): @@ -45,45 +42,20 @@ class Collection: """ return self.api.request(f"collection-management", request) - def list_collections(self, user, tag_filter=None): + def list_collections(self, tag_filter=None): """ - List all collections for a user. 
- - Retrieves metadata for all collections owned by the specified user, - with optional filtering by tags. + List all collections in this workspace. Args: - user: User identifier - tag_filter: Optional list of tags to filter collections (default: None) + tag_filter: Optional list of tags to filter collections Returns: list[CollectionMetadata]: List of collection metadata objects - - Raises: - ProtocolException: If response format is invalid - - Example: - ```python - collection = api.collection() - - # List all collections - all_colls = collection.list_collections(user="trustgraph") - for coll in all_colls: - print(f"{coll.collection}: {coll.name}") - print(f" Description: {coll.description}") - print(f" Tags: {', '.join(coll.tags)}") - - # List collections with specific tags - research_colls = collection.list_collections( - user="trustgraph", - tag_filter=["research", "published"] - ) - ``` """ input = { "operation": "list-collections", - "user": user, + "workspace": self.api.workspace, } if tag_filter: @@ -92,7 +64,6 @@ class Collection: object = self.request(input) try: - # Handle case where collections might be None or missing if object is None or "collections" not in object: return [] @@ -102,7 +73,6 @@ class Collection: return [ CollectionMetadata( - user = v["user"], collection = v["collection"], name = v["name"], description = v["description"], @@ -114,15 +84,11 @@ class Collection: logger.error("Failed to parse collection list response", exc_info=True) raise ProtocolException(f"Response not formatted correctly") - def update_collection(self, user, collection, name=None, description=None, tags=None): + def update_collection(self, collection, name=None, description=None, tags=None): """ Update collection metadata. - Updates the name, description, and/or tags for an existing collection. - Only provided fields are updated; others remain unchanged. 
- Args: - user: User identifier collection: Collection identifier name: New collection name (optional) description: New collection description (optional) @@ -130,35 +96,11 @@ class Collection: Returns: CollectionMetadata: Updated collection metadata, or None if not found - - Raises: - ProtocolException: If response format is invalid - - Example: - ```python - collection_api = api.collection() - - # Update collection metadata - updated = collection_api.update_collection( - user="trustgraph", - collection="default", - name="Default Collection", - description="Main data collection for general use", - tags=["default", "production"] - ) - - # Update only specific fields - updated = collection_api.update_collection( - user="trustgraph", - collection="research", - description="Updated description" - ) - ``` """ input = { "operation": "update-collection", - "user": user, + "workspace": self.api.workspace, "collection": collection, } @@ -175,7 +117,6 @@ class Collection: if "collections" in object and object["collections"]: v = object["collections"][0] return CollectionMetadata( - user = v["user"], collection = v["collection"], name = v["name"], description = v["description"], @@ -186,37 +127,23 @@ class Collection: logger.error("Failed to parse collection update response", exc_info=True) raise ProtocolException(f"Response not formatted correctly") - def delete_collection(self, user, collection): + def delete_collection(self, collection): """ Delete a collection. - Removes a collection and all its associated data from the system. 
- Args: - user: User identifier collection: Collection identifier to delete Returns: dict: Empty response object - - Example: - ```python - collection_api = api.collection() - - # Delete a collection - collection_api.delete_collection( - user="trustgraph", - collection="old-collection" - ) - ``` """ input = { "operation": "delete-collection", - "user": user, + "workspace": self.api.workspace, "collection": collection, } - object = self.request(input) + self.request(input) - return {} \ No newline at end of file + return {} diff --git a/trustgraph-base/trustgraph/api/config.py b/trustgraph-base/trustgraph/api/config.py index c8c8d5bb..5f17672f 100644 --- a/trustgraph-base/trustgraph/api/config.py +++ b/trustgraph-base/trustgraph/api/config.py @@ -21,14 +21,16 @@ class Config: and list operations. """ - def __init__(self, api): + def __init__(self, api, workspace="default"): """ Initialize Config client. Args: api: Parent Api instance for making requests + workspace: Workspace to scope all config operations to """ self.api = api + self.workspace = workspace def request(self, request): """ @@ -75,9 +77,9 @@ class Config: ``` """ - # The input consists of system and prompt strings input = { "operation": "get", + "workspace": self.workspace, "keys": [ { "type": k.type, "key": k.key } for k in keys @@ -123,9 +125,9 @@ class Config: ``` """ - # The input consists of system and prompt strings input = { "operation": "put", + "workspace": self.workspace, "values": [ { "type": v.type, "key": v.key, "value": v.value } for v in values @@ -157,9 +159,9 @@ class Config: ``` """ - # The input consists of system and prompt strings input = { "operation": "delete", + "workspace": self.workspace, "keys": [ { "type": v.type, "key": v.key } for v in keys @@ -195,9 +197,9 @@ class Config: ``` """ - # The input consists of system and prompt strings input = { "operation": "list", + "workspace": self.workspace, "type": type, } @@ -235,9 +237,9 @@ class Config: ``` """ - # The input consists 
of system and prompt strings input = { "operation": "getvalues", + "workspace": self.workspace, "type": type, } @@ -255,6 +257,46 @@ class Config: except: raise ProtocolException(f"Response not formatted correctly") + def get_values_all_workspaces(self, type): + """ + Get all configuration values of a given type across all workspaces. + + Unlike get_values(), this is not scoped to a single workspace — + it returns every entry of the given type in the system. Each + returned ConfigValue includes its workspace field. Used by + shared processors to load type-scoped config at startup. + + Args: + type: Configuration type (e.g. "prompt", "schema") + + Returns: + list[ConfigValue]: Values across all workspaces; each has + its workspace field populated. + + Raises: + ProtocolException: If response format is invalid + """ + + input = { + "operation": "getvalues-all-ws", + "type": type, + } + + object = self.request(input) + + try: + return [ + ConfigValue( + type = v["type"], + key = v["key"], + value = v["value"], + workspace = v.get("workspace", ""), + ) + for v in object["values"] + ] + except Exception: + raise ProtocolException("Response not formatted correctly") + def all(self): """ Get complete configuration and version. 
@@ -279,9 +321,9 @@ class Config: ``` """ - # The input consists of system and prompt strings input = { - "operation": "config" + "operation": "config", + "workspace": self.workspace, } object = self.request(input) diff --git a/trustgraph-base/trustgraph/api/explainability.py b/trustgraph-base/trustgraph/api/explainability.py index 08d0b4e7..656ff95f 100644 --- a/trustgraph-base/trustgraph/api/explainability.py +++ b/trustgraph-base/trustgraph/api/explainability.py @@ -486,7 +486,6 @@ class ExplainabilityClient: self, uri: str, graph: Optional[str] = None, - user: Optional[str] = None, collection: Optional[str] = None ) -> Optional[ExplainEntity]: """ @@ -502,7 +501,6 @@ class ExplainabilityClient: Args: uri: The entity URI to fetch graph: Named graph to query (e.g., "urn:graph:retrieval") - user: User/keyspace identifier collection: Collection identifier Returns: @@ -515,7 +513,6 @@ class ExplainabilityClient: wire_triples = self.flow.triples_query( s=uri, g=graph, - user=user, collection=collection, limit=100 ) @@ -548,7 +545,7 @@ class ExplainabilityClient: if prev_triples: # Re-fetch and parse wire_triples = self.flow.triples_query( - s=uri, g=graph, user=user, collection=collection, limit=100 + s=uri, g=graph, collection=collection, limit=100 ) if wire_triples: triples = wire_triples_to_tuples(wire_triples) @@ -560,7 +557,6 @@ class ExplainabilityClient: self, uri: str, graph: Optional[str] = None, - user: Optional[str] = None, collection: Optional[str] = None ) -> Optional[EdgeSelection]: """ @@ -569,7 +565,6 @@ class ExplainabilityClient: Args: uri: The edge selection URI graph: Named graph to query - user: User/keyspace identifier collection: Collection identifier Returns: @@ -578,7 +573,6 @@ class ExplainabilityClient: wire_triples = self.flow.triples_query( s=uri, g=graph, - user=user, collection=collection, limit=100 ) @@ -593,7 +587,6 @@ class ExplainabilityClient: self, uri: str, graph: Optional[str] = None, - user: Optional[str] = None, collection: 
Optional[str] = None ) -> Optional[Focus]: """ @@ -602,20 +595,19 @@ class ExplainabilityClient: Args: uri: The Focus entity URI graph: Named graph to query - user: User/keyspace identifier collection: Collection identifier Returns: Focus with populated edge_selections, or None """ - entity = self.fetch_entity(uri, graph, user, collection) + entity = self.fetch_entity(uri, graph, collection) if not isinstance(entity, Focus): return None # Fetch each edge selection for edge_uri in entity.selected_edge_uris: - edge_sel = self.fetch_edge_selection(edge_uri, graph, user, collection) + edge_sel = self.fetch_edge_selection(edge_uri, graph, collection) if edge_sel: entity.edge_selections.append(edge_sel) @@ -624,7 +616,6 @@ class ExplainabilityClient: def resolve_label( self, uri: str, - user: Optional[str] = None, collection: Optional[str] = None ) -> str: """ @@ -632,7 +623,6 @@ class ExplainabilityClient: Args: uri: The URI to get label for - user: User/keyspace identifier collection: Collection identifier Returns: @@ -647,7 +637,6 @@ class ExplainabilityClient: wire_triples = self.flow.triples_query( s=uri, p=RDFS_LABEL, - user=user, collection=collection, limit=1 ) @@ -665,7 +654,6 @@ class ExplainabilityClient: def resolve_edge_labels( self, edge: Dict[str, str], - user: Optional[str] = None, collection: Optional[str] = None ) -> Tuple[str, str, str]: """ @@ -673,22 +661,20 @@ class ExplainabilityClient: Args: edge: Dict with "s", "p", "o" keys - user: User/keyspace identifier collection: Collection identifier Returns: Tuple of (s_label, p_label, o_label) """ - s_label = self.resolve_label(edge.get("s", ""), user, collection) - p_label = self.resolve_label(edge.get("p", ""), user, collection) - o_label = self.resolve_label(edge.get("o", ""), user, collection) + s_label = self.resolve_label(edge.get("s", ""), collection) + p_label = self.resolve_label(edge.get("p", ""), collection) + o_label = self.resolve_label(edge.get("o", ""), collection) return (s_label, 
p_label, o_label) def fetch_document_content( self, document_uri: str, api: Any, - user: Optional[str] = None, max_content: int = 10000 ) -> str: """ @@ -697,7 +683,6 @@ class ExplainabilityClient: Args: document_uri: The document URI in the librarian api: TrustGraph Api instance for librarian access - user: User identifier for librarian max_content: Maximum content length to return Returns: @@ -712,7 +697,7 @@ class ExplainabilityClient: for attempt in range(self.max_retries): try: library = api.library() - content_bytes = library.get_document_content(user=user, id=doc_id) + content_bytes = library.get_document_content(id=doc_id) # Decode as text try: @@ -736,7 +721,6 @@ class ExplainabilityClient: self, question_uri: str, graph: Optional[str] = None, - user: Optional[str] = None, collection: Optional[str] = None, api: Any = None, max_content: int = 10000 @@ -749,7 +733,6 @@ class ExplainabilityClient: Args: question_uri: The question entity URI graph: Named graph (default: urn:graph:retrieval) - user: User/keyspace identifier collection: Collection identifier api: TrustGraph Api instance for librarian access (optional) max_content: Maximum content length for synthesis @@ -769,7 +752,7 @@ class ExplainabilityClient: } # Fetch question - question = self.fetch_entity(question_uri, graph, user, collection) + question = self.fetch_entity(question_uri, graph, collection) if not isinstance(question, Question): return trace trace["question"] = question @@ -779,7 +762,6 @@ class ExplainabilityClient: p=PROV_WAS_DERIVED_FROM, o=question_uri, g=graph, - user=user, collection=collection, limit=10 ) @@ -790,7 +772,7 @@ class ExplainabilityClient: for t in grounding_triples ] for gnd_uri in grounding_uris: - grounding = self.fetch_entity(gnd_uri, graph, user, collection) + grounding = self.fetch_entity(gnd_uri, graph, collection) if isinstance(grounding, Grounding): trace["grounding"] = grounding break @@ -803,7 +785,6 @@ class ExplainabilityClient: p=PROV_WAS_DERIVED_FROM, 
o=trace["grounding"].uri, g=graph, - user=user, collection=collection, limit=10 ) @@ -814,7 +795,7 @@ class ExplainabilityClient: for t in exploration_triples ] for exp_uri in exploration_uris: - exploration = self.fetch_entity(exp_uri, graph, user, collection) + exploration = self.fetch_entity(exp_uri, graph, collection) if isinstance(exploration, Exploration): trace["exploration"] = exploration break @@ -827,7 +808,6 @@ class ExplainabilityClient: p=PROV_WAS_DERIVED_FROM, o=trace["exploration"].uri, g=graph, - user=user, collection=collection, limit=10 ) @@ -838,7 +818,7 @@ class ExplainabilityClient: for t in focus_triples ] for focus_uri in focus_uris: - focus = self.fetch_focus_with_edges(focus_uri, graph, user, collection) + focus = self.fetch_focus_with_edges(focus_uri, graph, collection) if focus: trace["focus"] = focus break @@ -851,7 +831,6 @@ class ExplainabilityClient: p=PROV_WAS_DERIVED_FROM, o=trace["focus"].uri, g=graph, - user=user, collection=collection, limit=10 ) @@ -862,7 +841,7 @@ class ExplainabilityClient: for t in synthesis_triples ] for synth_uri in synthesis_uris: - synthesis = self.fetch_entity(synth_uri, graph, user, collection) + synthesis = self.fetch_entity(synth_uri, graph, collection) if isinstance(synthesis, Synthesis): trace["synthesis"] = synthesis break @@ -873,7 +852,6 @@ class ExplainabilityClient: self, question_uri: str, graph: Optional[str] = None, - user: Optional[str] = None, collection: Optional[str] = None, api: Any = None, max_content: int = 10000 @@ -887,7 +865,6 @@ class ExplainabilityClient: Args: question_uri: The question entity URI graph: Named graph (default: urn:graph:retrieval) - user: User/keyspace identifier collection: Collection identifier api: TrustGraph Api instance for librarian access (optional) max_content: Maximum content length for synthesis @@ -906,7 +883,7 @@ class ExplainabilityClient: } # Fetch question - question = self.fetch_entity(question_uri, graph, user, collection) + question = 
self.fetch_entity(question_uri, graph, collection) if not isinstance(question, Question): return trace trace["question"] = question @@ -916,7 +893,6 @@ class ExplainabilityClient: p=PROV_WAS_DERIVED_FROM, o=question_uri, g=graph, - user=user, collection=collection, limit=10 ) @@ -927,7 +903,7 @@ class ExplainabilityClient: for t in grounding_triples ] for gnd_uri in grounding_uris: - grounding = self.fetch_entity(gnd_uri, graph, user, collection) + grounding = self.fetch_entity(gnd_uri, graph, collection) if isinstance(grounding, Grounding): trace["grounding"] = grounding break @@ -940,7 +916,6 @@ class ExplainabilityClient: p=PROV_WAS_DERIVED_FROM, o=trace["grounding"].uri, g=graph, - user=user, collection=collection, limit=10 ) @@ -951,7 +926,7 @@ class ExplainabilityClient: for t in exploration_triples ] for exp_uri in exploration_uris: - exploration = self.fetch_entity(exp_uri, graph, user, collection) + exploration = self.fetch_entity(exp_uri, graph, collection) if isinstance(exploration, Exploration): trace["exploration"] = exploration break @@ -964,7 +939,6 @@ class ExplainabilityClient: p=PROV_WAS_DERIVED_FROM, o=trace["exploration"].uri, g=graph, - user=user, collection=collection, limit=10 ) @@ -975,7 +949,7 @@ class ExplainabilityClient: for t in synthesis_triples ] for synth_uri in synthesis_uris: - synthesis = self.fetch_entity(synth_uri, graph, user, collection) + synthesis = self.fetch_entity(synth_uri, graph, collection) if isinstance(synthesis, Synthesis): trace["synthesis"] = synthesis break @@ -986,7 +960,6 @@ class ExplainabilityClient: self, session_uri: str, graph: Optional[str] = None, - user: Optional[str] = None, collection: Optional[str] = None, api: Any = None, max_content: int = 10000 @@ -1002,7 +975,6 @@ class ExplainabilityClient: Args: session_uri: The agent session/question URI graph: Named graph (default: urn:graph:retrieval) - user: User/keyspace identifier collection: Collection identifier api: TrustGraph Api instance for 
librarian access (optional) max_content: Maximum content length for conclusion @@ -1019,21 +991,21 @@ class ExplainabilityClient: } # Fetch question/session - question = self.fetch_entity(session_uri, graph, user, collection) + question = self.fetch_entity(session_uri, graph, collection) if not isinstance(question, Question): return trace trace["question"] = question # Follow the provenance chain from the question self._follow_provenance_chain( - session_uri, trace, graph, user, collection, + session_uri, trace, graph, collection, max_depth=50, ) return trace def _follow_provenance_chain( - self, current_uri, trace, graph, user, collection, + self, current_uri, trace, graph, collection, max_depth=50, ): """Recursively follow the provenance chain, handling branches.""" @@ -1044,7 +1016,7 @@ class ExplainabilityClient: derived_triples = self.flow.triples_query( p=PROV_WAS_DERIVED_FROM, o=current_uri, - g=graph, user=user, collection=collection, + g=graph, collection=collection, limit=20 ) @@ -1060,7 +1032,7 @@ class ExplainabilityClient: if not derived_uri: continue - entity = self.fetch_entity(derived_uri, graph, user, collection) + entity = self.fetch_entity(derived_uri, graph, collection) if entity is None: continue @@ -1070,7 +1042,7 @@ class ExplainabilityClient: # Continue following from this entity self._follow_provenance_chain( - derived_uri, trace, graph, user, collection, + derived_uri, trace, graph, collection, max_depth=max_depth - 1, ) @@ -1079,11 +1051,11 @@ class ExplainabilityClient: # Fetch the full sub-trace and embed it. 
if entity.question_type == "graph-rag": sub_trace = self.fetch_graphrag_trace( - derived_uri, graph, user, collection, + derived_uri, graph, collection, ) elif entity.question_type == "document-rag": sub_trace = self.fetch_docrag_trace( - derived_uri, graph, user, collection, + derived_uri, graph, collection, ) else: sub_trace = None @@ -1100,7 +1072,7 @@ class ExplainabilityClient: terminal = sub_trace.get("synthesis") if terminal: self._follow_provenance_chain( - terminal.uri, trace, graph, user, collection, + terminal.uri, trace, graph, collection, max_depth=max_depth - 1, ) @@ -1110,7 +1082,6 @@ class ExplainabilityClient: def list_sessions( self, graph: Optional[str] = None, - user: Optional[str] = None, collection: Optional[str] = None, limit: int = 50 ) -> List[Question]: @@ -1119,7 +1090,6 @@ class ExplainabilityClient: Args: graph: Named graph (default: urn:graph:retrieval) - user: User/keyspace identifier collection: Collection identifier limit: Maximum number of sessions to return @@ -1133,7 +1103,6 @@ class ExplainabilityClient: query_triples = self.flow.triples_query( p=TG_QUERY, g=graph, - user=user, collection=collection, limit=limit ) @@ -1142,7 +1111,7 @@ class ExplainabilityClient: for t in query_triples: question_uri = extract_term_value(t.get("s", {})) if question_uri: - entity = self.fetch_entity(question_uri, graph, user, collection) + entity = self.fetch_entity(question_uri, graph, collection) if isinstance(entity, Question): questions.append(entity) @@ -1154,7 +1123,6 @@ class ExplainabilityClient: s=q.uri, p=PROV_WAS_DERIVED_FROM, g=graph, - user=user, collection=collection, limit=1 ) @@ -1170,7 +1138,6 @@ class ExplainabilityClient: self, session_uri: str, graph: Optional[str] = None, - user: Optional[str] = None, collection: Optional[str] = None ) -> str: """ @@ -1179,7 +1146,6 @@ class ExplainabilityClient: Args: session_uri: The session/question URI graph: Named graph - user: User/keyspace identifier collection: Collection identifier 
Returns: @@ -1201,7 +1167,6 @@ class ExplainabilityClient: p=PROV_WAS_DERIVED_FROM, o=session_uri, g=graph, - user=user, collection=collection, limit=5 ) @@ -1212,7 +1177,7 @@ class ExplainabilityClient: ] for child_uri in all_child_uris: - entity = self.fetch_entity(child_uri, graph, user, collection) + entity = self.fetch_entity(child_uri, graph, collection) if isinstance(entity, (Analysis, Decomposition, Plan)): return "agent" if isinstance(entity, Exploration): diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index 7ee32dad..961e348b 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -115,72 +115,32 @@ class Flow: return FlowInstance(api=self, id=id) def list_blueprints(self): - """ - List all available flow blueprints. + """List blueprints in the current workspace.""" - Returns: - list[str]: List of blueprint names - - Example: - ```python - blueprints = api.flow().list_blueprints() - print(blueprints) # ['default', 'custom-flow', ...] - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "list-blueprints", + "workspace": self.api.workspace, } return self.request(request = input)["blueprint-names"] def get_blueprint(self, blueprint_name): - """ - Get a flow blueprint definition by name. + """Get a flow blueprint definition by name.""" - Args: - blueprint_name: Name of the blueprint to retrieve - - Returns: - dict: Blueprint definition as a dictionary - - Example: - ```python - blueprint = api.flow().get_blueprint("default") - print(blueprint) # Blueprint configuration - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "get-blueprint", + "workspace": self.api.workspace, "blueprint-name": blueprint_name, } return json.loads(self.request(request = input)["blueprint-definition"]) def put_blueprint(self, blueprint_name, definition): - """ - Create or update a flow blueprint. 
+ """Create or update a flow blueprint.""" - Args: - blueprint_name: Name for the blueprint - definition: Blueprint definition dictionary - - Example: - ```python - definition = { - "services": ["text-completion", "graph-rag"], - "parameters": {"model": "gpt-4"} - } - api.flow().put_blueprint("my-blueprint", definition) - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "put-blueprint", + "workspace": self.api.workspace, "blueprint-name": blueprint_name, "blueprint-definition": json.dumps(definition), } @@ -188,96 +148,43 @@ class Flow: self.request(request = input) def delete_blueprint(self, blueprint_name): - """ - Delete a flow blueprint. + """Delete a flow blueprint.""" - Args: - blueprint_name: Name of the blueprint to delete - - Example: - ```python - api.flow().delete_blueprint("old-blueprint") - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "delete-blueprint", + "workspace": self.api.workspace, "blueprint-name": blueprint_name, } self.request(request = input) def list(self): - """ - List all active flow instances. + """List flow instances in the current workspace.""" - Returns: - list[str]: List of flow instance IDs - - Example: - ```python - flows = api.flow().list() - print(flows) # ['default', 'flow-1', 'flow-2', ...] - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "list-flows", + "workspace": self.api.workspace, } return self.request(request = input)["flow-ids"] def get(self, id): - """ - Get the definition of a running flow instance. 
+ """Get the definition of a flow instance.""" - Args: - id: Flow instance ID - - Returns: - dict: Flow instance definition - - Example: - ```python - flow_def = api.flow().get("default") - print(flow_def) - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "get-flow", + "workspace": self.api.workspace, "flow-id": id, } return json.loads(self.request(request = input)["flow"]) def start(self, blueprint_name, id, description, parameters=None): - """ - Start a new flow instance from a blueprint. + """Start a new flow instance from a blueprint.""" - Args: - blueprint_name: Name of the blueprint to instantiate - id: Unique identifier for the flow instance - description: Human-readable description - parameters: Optional parameters dictionary - - Example: - ```python - api.flow().start( - blueprint_name="default", - id="my-flow", - description="My custom flow", - parameters={"model": "gpt-4"} - ) - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "start-flow", + "workspace": self.api.workspace, "flow-id": id, "blueprint-name": blueprint_name, "description": description, @@ -289,21 +196,11 @@ class Flow: self.request(request = input) def stop(self, id): - """ - Stop a running flow instance. + """Stop a running flow instance.""" - Args: - id: Flow instance ID to stop - - Example: - ```python - api.flow().stop("my-flow") - ``` - """ - - # The input consists of system and prompt strings input = { "operation": "stop-flow", + "workspace": self.api.workspace, "flow-id": id, } @@ -349,6 +246,13 @@ class FlowInstance: Returns: dict: Service response """ + # Inject workspace so the gateway can route to the right + # workspace's flow. If already present, keep the caller's value. 
+ if isinstance(request, dict) and "workspace" not in request: + request = { + "workspace": self.api.api.workspace, + **request, + } return self.api.request(path = f"{self.id}/{path}", request = request) def text_completion(self, system, prompt): @@ -392,7 +296,7 @@ class FlowInstance: model=result.get("model"), ) - def agent(self, question, user="trustgraph", state=None, group=None, history=None): + def agent(self, question,state=None, group=None, history=None): """ Execute an agent operation with reasoning and tool use capabilities. @@ -401,7 +305,6 @@ class FlowInstance: Args: question: User question or instruction - user: User identifier (default: "trustgraph") state: Optional state dictionary for stateful conversations group: Optional group identifier for multi-user contexts history: Optional conversation history as list of message dicts @@ -416,8 +319,7 @@ class FlowInstance: # Simple question answer = flow.agent( question="What is the capital of France?", - user="trustgraph" - ) + ) # With conversation history history = [ @@ -425,9 +327,7 @@ class FlowInstance: {"role": "assistant", "content": "Hi! 
How can I help?"} ] answer = flow.agent( - question="Tell me about Paris", - user="trustgraph", - history=history + question="Tell me about Paris",history=history ) ``` """ @@ -435,7 +335,6 @@ class FlowInstance: # The input consists of a question and optional context input = { "question": question, - "user": user, } # Only include state if it has a value @@ -455,7 +354,7 @@ class FlowInstance: )["answer"] def graph_rag( - self, query, user="trustgraph", collection="default", + self, query,collection="default", entity_limit=50, triple_limit=30, max_subgraph_size=150, max_path_length=2, edge_score_limit=30, edge_limit=25, ): @@ -467,7 +366,6 @@ class FlowInstance: Args: query: Natural language query - user: User/keyspace identifier (default: "trustgraph") collection: Collection identifier (default: "default") entity_limit: Maximum entities to retrieve (default: 50) triple_limit: Maximum triples per entity (default: 30) @@ -483,9 +381,7 @@ class FlowInstance: ```python flow = api.flow().id("default") response = flow.graph_rag( - query="Tell me about Marie Curie's discoveries", - user="trustgraph", - collection="scientists", + query="Tell me about Marie Curie's discoveries",collection="scientists", entity_limit=20, max_path_length=3 ) @@ -496,7 +392,6 @@ class FlowInstance: # The input consists of a question input = { "query": query, - "user": user, "collection": collection, "entity-limit": entity_limit, "triple-limit": triple_limit, @@ -519,7 +414,7 @@ class FlowInstance: ) def document_rag( - self, query, user="trustgraph", collection="default", + self, query,collection="default", doc_limit=10, ): """ @@ -530,7 +425,6 @@ class FlowInstance: Args: query: Natural language query - user: User/keyspace identifier (default: "trustgraph") collection: Collection identifier (default: "default") doc_limit: Maximum document chunks to retrieve (default: 10) @@ -541,9 +435,7 @@ class FlowInstance: ```python flow = api.flow().id("default") response = flow.document_rag( - 
query="Summarize the key findings", - user="trustgraph", - collection="research-papers", + query="Summarize the key findings",collection="research-papers", doc_limit=5 ) print(response) @@ -553,7 +445,6 @@ class FlowInstance: # The input consists of a question input = { "query": query, - "user": user, "collection": collection, "doc-limit": doc_limit, } @@ -600,7 +491,7 @@ class FlowInstance: input )["vectors"] - def graph_embeddings_query(self, text, user, collection, limit=10): + def graph_embeddings_query(self, text, collection, limit=10): """ Query knowledge graph entities using semantic similarity. @@ -609,7 +500,6 @@ class FlowInstance: Args: text: Query text for semantic search - user: User/keyspace identifier collection: Collection identifier limit: Maximum number of results (default: 10) @@ -620,9 +510,7 @@ class FlowInstance: ```python flow = api.flow().id("default") results = flow.graph_embeddings_query( - text="physicist who discovered radioactivity", - user="trustgraph", - collection="scientists", + text="physicist who discovered radioactivity",collection="scientists", limit=5 ) # results contains {"entities": [{"entity": {...}, "score": 0.95}, ...]} @@ -636,7 +524,6 @@ class FlowInstance: # Query graph embeddings for semantic search input = { "vector": vector, - "user": user, "collection": collection, "limit": limit } @@ -646,7 +533,7 @@ class FlowInstance: input ) - def document_embeddings_query(self, text, user, collection, limit=10): + def document_embeddings_query(self, text, collection, limit=10): """ Query document chunks using semantic similarity. 
@@ -655,7 +542,6 @@ class FlowInstance: Args: text: Query text for semantic search - user: User/keyspace identifier collection: Collection identifier limit: Maximum number of results (default: 10) @@ -666,9 +552,7 @@ class FlowInstance: ```python flow = api.flow().id("default") results = flow.document_embeddings_query( - text="machine learning algorithms", - user="trustgraph", - collection="research-papers", + text="machine learning algorithms",collection="research-papers", limit=5 ) # results contains {"chunks": [{"chunk_id": "doc1/p0/c0", "score": 0.95}, ...]} @@ -682,7 +566,6 @@ class FlowInstance: # Query document embeddings for semantic search input = { "vector": vector, - "user": user, "collection": collection, "limit": limit } @@ -805,7 +688,7 @@ class FlowInstance: def triples_query( self, s=None, p=None, o=None, - user=None, collection=None, limit=10000 + collection=None, limit=10000 ): """ Query knowledge graph triples using pattern matching. @@ -817,7 +700,6 @@ class FlowInstance: s: Subject URI (optional, use None for wildcard) p: Predicate URI (optional, use None for wildcard) o: Object URI or Literal (optional, use None for wildcard) - user: User/keyspace identifier (optional) collection: Collection identifier (optional) limit: Maximum results to return (default: 10000) @@ -835,9 +717,7 @@ class FlowInstance: # Find all triples about a specific subject triples = flow.triples_query( - s=Uri("http://example.org/person/marie-curie"), - user="trustgraph", - collection="scientists" + s=Uri("http://example.org/person/marie-curie"),collection="scientists" ) # Find all instances of a specific relationship @@ -851,10 +731,6 @@ class FlowInstance: input = { "limit": limit } - - if user: - input["user"] = user - if collection: input["collection"] = collection @@ -888,7 +764,7 @@ class FlowInstance: ] def load_document( - self, document, id=None, metadata=None, user=None, + self, document, id=None, metadata=None, collection=None, ): """ @@ -901,7 +777,6 @@ class 
FlowInstance: document: Document content as bytes id: Optional document identifier (auto-generated if None) metadata: Optional metadata (list of Triples or object with emit method) - user: User/keyspace identifier (optional) collection: Collection identifier (optional) Returns: @@ -918,9 +793,7 @@ class FlowInstance: with open("research.pdf", "rb") as f: result = flow.load_document( document=f.read(), - id="research-001", - user="trustgraph", - collection="papers" + id="research-001",collection="papers" ) ``` """ @@ -955,10 +828,6 @@ class FlowInstance: "metadata": triples, "data": base64.b64encode(document).decode("utf-8"), } - - if user: - input["user"] = user - if collection: input["collection"] = collection @@ -969,7 +838,7 @@ class FlowInstance: def load_text( self, text, id=None, metadata=None, charset="utf-8", - user=None, collection=None, + collection=None, ): """ Load text content for processing. @@ -982,7 +851,6 @@ class FlowInstance: id: Optional document identifier (auto-generated if None) metadata: Optional metadata (list of Triples or object with emit method) charset: Character encoding (default: "utf-8") - user: User/keyspace identifier (optional) collection: Collection identifier (optional) Returns: @@ -1000,9 +868,7 @@ class FlowInstance: result = flow.load_text( text=text_content, id="text-001", - charset="utf-8", - user="trustgraph", - collection="documents" + charset="utf-8",collection="documents" ) ``` """ @@ -1035,10 +901,6 @@ class FlowInstance: "charset": charset, "text": base64.b64encode(text).decode("utf-8"), } - - if user: - input["user"] = user - if collection: input["collection"] = collection @@ -1048,7 +910,7 @@ class FlowInstance: ) def rows_query( - self, query, user="trustgraph", collection="default", + self, query,collection="default", variables=None, operation_name=None ): """ @@ -1059,7 +921,6 @@ class FlowInstance: Args: query: GraphQL query string - user: User/keyspace identifier (default: "trustgraph") collection: Collection 
identifier (default: "default") variables: Optional query variables dictionary operation_name: Optional operation name for multi-operation documents @@ -1085,9 +946,7 @@ class FlowInstance: } ''' result = flow.rows_query( - query=query, - user="trustgraph", - collection="scientists" + query=query,collection="scientists" ) # Query with variables @@ -1109,7 +968,6 @@ class FlowInstance: # The input consists of a GraphQL query and optional variables input = { "query": query, - "user": user, "collection": collection, } @@ -1145,7 +1003,7 @@ class FlowInstance: return result def sparql_query( - self, query, user="trustgraph", collection="default", + self, query,collection="default", limit=10000 ): """ @@ -1153,7 +1011,6 @@ class FlowInstance: Args: query: SPARQL 1.1 query string - user: User/keyspace identifier (default: "trustgraph") collection: Collection identifier (default: "default") limit: Safety limit on results (default: 10000) @@ -1169,7 +1026,6 @@ class FlowInstance: input = { "query": query, - "user": user, "collection": collection, "limit": limit, } @@ -1213,14 +1069,13 @@ class FlowInstance: return response - def structured_query(self, question, user="trustgraph", collection="default"): + def structured_query(self, question,collection="default"): """ Execute a natural language question against structured data. Combines NLP query conversion and GraphQL execution. 
Args: question: Natural language question - user: Cassandra keyspace identifier (default: "trustgraph") collection: Data collection identifier (default: "default") Returns: @@ -1229,7 +1084,6 @@ class FlowInstance: input = { "question": question, - "user": user, "collection": collection } @@ -1383,7 +1237,7 @@ class FlowInstance: return response["schema-matches"] def row_embeddings_query( - self, text, schema_name, user="trustgraph", collection="default", + self, text, schema_name,collection="default", index_name=None, limit=10 ): """ @@ -1396,7 +1250,6 @@ class FlowInstance: Args: text: Query text for semantic search schema_name: Schema name to search within - user: User/keyspace identifier (default: "trustgraph") collection: Collection identifier (default: "default") index_name: Optional index name to filter search to specific index limit: Maximum number of results (default: 10) @@ -1412,9 +1265,7 @@ class FlowInstance: # Search for customers by name similarity results = flow.row_embeddings_query( text="John Smith", - schema_name="customers", - user="trustgraph", - collection="sales", + schema_name="customers",collection="sales", limit=5 ) @@ -1436,7 +1287,6 @@ class FlowInstance: input = { "vector": vector, "schema_name": schema_name, - "user": user, "collection": collection, "limit": limit } diff --git a/trustgraph-base/trustgraph/api/knowledge.py b/trustgraph-base/trustgraph/api/knowledge.py index 84f98918..c3ec2308 100644 --- a/trustgraph-base/trustgraph/api/knowledge.py +++ b/trustgraph-base/trustgraph/api/knowledge.py @@ -63,105 +63,50 @@ class Knowledge: """ return self.api.request(f"knowledge", request) - def list_kg_cores(self, user="trustgraph"): + def list_kg_cores(self): """ - List all available knowledge graph cores. - - Retrieves the IDs of all KG cores available for the specified user. - - Args: - user: User identifier (default: "trustgraph") + List all available knowledge graph cores in this workspace. 
Returns: list[str]: List of KG core identifiers - - Example: - ```python - knowledge = api.knowledge() - - # List available KG cores - cores = knowledge.list_kg_cores(user="trustgraph") - print(f"Available KG cores: {cores}") - ``` """ - # The input consists of system and prompt strings input = { "operation": "list-kg-cores", - "user": user, + "workspace": self.api.workspace, } return self.request(request = input)["ids"] - def delete_kg_core(self, id, user="trustgraph"): + def delete_kg_core(self, id): """ - Delete a knowledge graph core. - - Removes a KG core from storage. This does not affect currently loaded - cores in flows. + Delete a knowledge graph core in this workspace. Args: id: KG core identifier to delete - user: User identifier (default: "trustgraph") - - Example: - ```python - knowledge = api.knowledge() - - # Delete a KG core - knowledge.delete_kg_core(id="medical-kb-v1", user="trustgraph") - ``` """ - # The input consists of system and prompt strings input = { "operation": "delete-kg-core", - "user": user, + "workspace": self.api.workspace, "id": id, } self.request(request = input) - def load_kg_core(self, id, user="trustgraph", flow="default", - collection="default"): + def load_kg_core(self, id, flow="default", collection="default"): """ Load a knowledge graph core into a flow. - Makes a KG core available for use in queries and RAG operations within - the specified flow and collection. 
- Args: id: KG core identifier to load - user: User identifier (default: "trustgraph") flow: Flow instance to load into (default: "default") collection: Collection to associate with (default: "default") - - Example: - ```python - knowledge = api.knowledge() - - # Load a medical knowledge base into the default flow - knowledge.load_kg_core( - id="medical-kb-v1", - user="trustgraph", - flow="default", - collection="medical" - ) - - # Now the flow can use this KG core for RAG queries - flow = api.flow().id("default") - response = flow.graph_rag( - query="What are the symptoms of diabetes?", - user="trustgraph", - collection="medical" - ) - ``` """ - # The input consists of system and prompt strings input = { "operation": "load-kg-core", - "user": user, + "workspace": self.api.workspace, "id": id, "flow": flow, "collection": collection, @@ -169,35 +114,18 @@ class Knowledge: self.request(request = input) - def unload_kg_core(self, id, user="trustgraph", flow="default"): + def unload_kg_core(self, id, flow="default"): """ Unload a knowledge graph core from a flow. - Removes a KG core from active use in the specified flow, freeing - resources while keeping the core available in storage. 
- Args: id: KG core identifier to unload - user: User identifier (default: "trustgraph") flow: Flow instance to unload from (default: "default") - - Example: - ```python - knowledge = api.knowledge() - - # Unload a KG core when no longer needed - knowledge.unload_kg_core( - id="medical-kb-v1", - user="trustgraph", - flow="default" - ) - ``` """ - # The input consists of system and prompt strings input = { "operation": "unload-kg-core", - "user": user, + "workspace": self.api.workspace, "id": id, "flow": flow, } diff --git a/trustgraph-base/trustgraph/api/library.py b/trustgraph-base/trustgraph/api/library.py index c66598aa..8f99e601 100644 --- a/trustgraph-base/trustgraph/api/library.py +++ b/trustgraph-base/trustgraph/api/library.py @@ -94,7 +94,7 @@ class Library: return self.api.request(f"librarian", request) def add_document( - self, document, id, metadata, user, title, comments, + self, document, id, metadata, title, comments, kind="text/plain", tags=[], on_progress=None, ): """ @@ -108,7 +108,6 @@ class Library: document: Document content as bytes id: Document identifier (auto-generated if None) metadata: Document metadata as list of Triple objects or object with emit method - user: User/owner identifier title: Document title comments: Document description or comments kind: MIME type of the document (default: "text/plain") @@ -131,7 +130,6 @@ class Library: document=f.read(), id="research-001", metadata=[], - user="trustgraph", title="Research Paper", comments="Key findings in quantum computing", kind="application/pdf", @@ -147,7 +145,6 @@ class Library: document=f.read(), id="large-doc-001", metadata=[], - user="trustgraph", title="Large Document", comments="A very large document", kind="application/pdf", @@ -176,7 +173,6 @@ class Library: document=document, id=id, metadata=metadata, - user=user, title=title, comments=comments, kind=kind, @@ -213,6 +209,7 @@ class Library: input = { "operation": "add-document", + "workspace": self.api.workspace, 
"document-metadata": { "id": id, "time": int(time.time()), @@ -220,7 +217,7 @@ class Library: "title": title, "comments": comments, "metadata": triples, - "user": user, + "workspace": self.api.workspace, "tags": tags }, "content": base64.b64encode(document).decode("utf-8"), @@ -229,7 +226,7 @@ class Library: return self.request(input) def _add_document_chunked( - self, document, id, metadata, user, title, comments, + self, document, id, metadata, title, comments, kind, tags, on_progress=None, ): """ @@ -245,13 +242,14 @@ class Library: # Begin upload session begin_request = { "operation": "begin-upload", + "workspace": self.api.workspace, "document-metadata": { "id": id, "time": int(time.time()), "kind": kind, "title": title, "comments": comments, - "user": user, + "workspace": self.api.workspace, "tags": tags, }, "total-size": total_size, @@ -279,10 +277,10 @@ class Library: chunk_request = { "operation": "upload-chunk", + "workspace": self.api.workspace, "upload-id": upload_id, "chunk-index": chunk_index, "content": base64.b64encode(chunk_data).decode("utf-8"), - "user": user, } chunk_response = self.request(chunk_request) @@ -298,8 +296,8 @@ class Library: # Complete upload complete_request = { "operation": "complete-upload", + "workspace": self.api.workspace, "upload-id": upload_id, - "user": user, } complete_response = self.request(complete_request) @@ -314,8 +312,8 @@ class Library: try: abort_request = { "operation": "abort-upload", + "workspace": self.api.workspace, "upload-id": upload_id, - "user": user, } self.request(abort_request) logger.info(f"Aborted failed upload {upload_id}") @@ -323,15 +321,13 @@ class Library: logger.warning(f"Failed to abort upload: {abort_error}") raise - def get_documents(self, user, include_children=False): + def get_documents(self, include_children=False): """ - List all documents for a user. + List all documents in the current workspace. - Retrieves metadata for all documents owned by the specified user. 
By default, only returns top-level documents (not child/extracted documents). Args: - user: User identifier include_children: If True, also include child documents (default: False) Returns: @@ -345,7 +341,7 @@ class Library: library = api.library() # Get only top-level documents - docs = library.get_documents(user="trustgraph") + docs = library.get_documents() for doc in docs: print(f"{doc.id}: {doc.title} ({doc.kind})") @@ -353,13 +349,13 @@ class Library: print(f" Tags: {', '.join(doc.tags)}") # Get all documents including extracted pages - all_docs = library.get_documents(user="trustgraph", include_children=True) + all_docs = library.get_documents(include_children=True) ``` """ input = { "operation": "list-documents", - "user": user, + "workspace": self.api.workspace, "include-children": include_children, } @@ -381,7 +377,7 @@ class Library: ) for w in v["metadata"] ], - user = v["user"], + workspace = v.get("workspace", ""), tags = v["tags"], parent_id = v.get("parent-id", ""), document_type = v.get("document-type", "source"), @@ -392,14 +388,13 @@ class Library: logger.error("Failed to parse document list response", exc_info=True) raise ProtocolException(f"Response not formatted correctly") - def get_document(self, user, id): + def get_document(self, id): """ Get metadata for a specific document. Retrieves the metadata for a single document by ID. 
Args: - user: User identifier id: Document identifier Returns: @@ -411,7 +406,7 @@ class Library: Example: ```python library = api.library() - doc = library.get_document(user="trustgraph", id="doc-123") + doc = library.get_document(id="doc-123") print(f"Title: {doc.title}") print(f"Comments: {doc.comments}") ``` @@ -419,7 +414,7 @@ class Library: input = { "operation": "get-document", - "user": user, + "workspace": self.api.workspace, "document-id": id, } @@ -441,7 +436,7 @@ class Library: ) for w in doc["metadata"] ], - user = doc["user"], + workspace = doc.get("workspace", ""), tags = doc["tags"], parent_id = doc.get("parent-id", ""), document_type = doc.get("document-type", "source"), @@ -450,14 +445,13 @@ class Library: logger.error("Failed to parse document response", exc_info=True) raise ProtocolException(f"Response not formatted correctly") - def update_document(self, user, id, metadata): + def update_document(self, id, metadata): """ Update document metadata. Updates the metadata for an existing document in the library. 
Args: - user: User identifier id: Document identifier metadata: Updated DocumentMetadata object @@ -472,7 +466,7 @@ class Library: library = api.library() # Get existing document - doc = library.get_document(user="trustgraph", id="doc-123") + doc = library.get_document(id="doc-123") # Update metadata doc.title = "Updated Title" @@ -481,7 +475,6 @@ class Library: # Save changes updated_doc = library.update_document( - user="trustgraph", id="doc-123", metadata=doc ) @@ -490,8 +483,9 @@ class Library: input = { "operation": "update-document", + "workspace": self.api.workspace, "document-metadata": { - "user": user, + "workspace": self.api.workspace, "document-id": id, "time": metadata.time, "title": metadata.title, @@ -526,21 +520,20 @@ class Library: ) for w in doc["metadata"] ], - user = doc["user"], + workspace = doc.get("workspace", ""), tags = doc["tags"] ) except Exception as e: logger.error("Failed to parse document update response", exc_info=True) raise ProtocolException(f"Response not formatted correctly") - def remove_document(self, user, id): + def remove_document(self, id): """ Remove a document from the library. Deletes a document and its metadata from the library. Args: - user: User identifier id: Document identifier to remove Returns: @@ -549,13 +542,13 @@ class Library: Example: ```python library = api.library() - library.remove_document(user="trustgraph", id="doc-123") + library.remove_document(id="doc-123") ``` """ input = { "operation": "remove-document", - "user": user, + "workspace": self.api.workspace, "document-id": id, } @@ -565,7 +558,7 @@ class Library: def start_processing( self, id, document_id, flow="default", - user="trustgraph", collection="default", tags=[], + collection="default", tags=[], ): """ Start a document processing workflow. 
@@ -577,7 +570,6 @@ class Library: id: Unique processing job identifier document_id: ID of the document to process flow: Flow instance to use for processing (default: "default") - user: User identifier (default: "trustgraph") collection: Target collection for processed data (default: "default") tags: List of tags for the processing job (default: []) @@ -593,7 +585,6 @@ class Library: id="proc-001", document_id="doc-123", flow="default", - user="trustgraph", collection="research", tags=["automated", "extract"] ) @@ -602,12 +593,13 @@ class Library: input = { "operation": "add-processing", + "workspace": self.api.workspace, "processing-metadata": { "id": id, "document-id": document_id, "time": int(time.time()), "flow": flow, - "user": user, + "workspace": self.api.workspace, "collection": collection, "tags": tags, } @@ -618,7 +610,7 @@ class Library: return {} def stop_processing( - self, id, user="trustgraph", + self, id, ): """ Stop a running document processing job. @@ -627,7 +619,6 @@ class Library: Args: id: Processing job identifier to stop - user: User identifier (default: "trustgraph") Returns: dict: Empty response object @@ -635,29 +626,26 @@ class Library: Example: ```python library = api.library() - library.stop_processing(id="proc-001", user="trustgraph") + library.stop_processing(id="proc-001") ``` """ input = { "operation": "remove-processing", + "workspace": self.api.workspace, "processing-id": id, - "user": user, } object = self.request(input) return {} - def get_processings(self, user="trustgraph"): + def get_processings(self): """ List all active document processing jobs. Retrieves metadata for all currently running document processing workflows - for the specified user. - - Args: - user: User identifier (default: "trustgraph") + in the current workspace. 
Returns: list[ProcessingMetadata]: List of processing job metadata objects @@ -668,7 +656,7 @@ class Library: Example: ```python library = api.library() - jobs = library.get_processings(user="trustgraph") + jobs = library.get_processings() for job in jobs: print(f"Job {job.id}:") @@ -681,7 +669,7 @@ class Library: input = { "operation": "list-processing", - "user": user, + "workspace": self.api.workspace, } object = self.request(input) @@ -693,7 +681,7 @@ class Library: document_id = v["document-id"], time = datetime.datetime.fromtimestamp(v["time"]), flow = v["flow"], - user = v["user"], + workspace = v.get("workspace", ""), collection = v["collection"], tags = v["tags"], ) @@ -705,23 +693,20 @@ class Library: # Chunked upload management methods - def get_pending_uploads(self, user): + def get_pending_uploads(self): """ - List all pending (in-progress) uploads for a user. + List all pending (in-progress) uploads in the current workspace. Retrieves information about chunked uploads that have been started but not yet completed. - Args: - user: User identifier - Returns: list[dict]: List of pending upload information Example: ```python library = api.library() - pending = library.get_pending_uploads(user="trustgraph") + pending = library.get_pending_uploads() for upload in pending: print(f"Upload {upload['upload_id']}:") @@ -731,14 +716,14 @@ class Library: """ input = { "operation": "list-uploads", - "user": user, + "workspace": self.api.workspace, } response = self.request(input) return response.get("upload-sessions", []) - def get_upload_status(self, upload_id, user): + def get_upload_status(self, upload_id): """ Get the status of a specific upload. 
@@ -747,7 +732,6 @@ class Library: Args: upload_id: Upload session identifier - user: User identifier Returns: dict: Upload status information including: @@ -763,10 +747,7 @@ class Library: Example: ```python library = api.library() - status = library.get_upload_status( - upload_id="abc-123", - user="trustgraph" - ) + status = library.get_upload_status(upload_id="abc-123") if status['state'] == 'in-progress': print(f"Missing chunks: {status['missing_chunks']}") @@ -774,13 +755,13 @@ class Library: """ input = { "operation": "get-upload-status", + "workspace": self.api.workspace, "upload-id": upload_id, - "user": user, } return self.request(input) - def abort_upload(self, upload_id, user): + def abort_upload(self, upload_id): """ Abort an in-progress upload. @@ -788,7 +769,6 @@ class Library: Args: upload_id: Upload session identifier - user: User identifier Returns: dict: Empty response on success @@ -796,18 +776,18 @@ class Library: Example: ```python library = api.library() - library.abort_upload(upload_id="abc-123", user="trustgraph") + library.abort_upload(upload_id="abc-123") ``` """ input = { "operation": "abort-upload", + "workspace": self.api.workspace, "upload-id": upload_id, - "user": user, } return self.request(input) - def resume_upload(self, upload_id, document, user, on_progress=None): + def resume_upload(self, upload_id, document, on_progress=None): """ Resume an interrupted upload. 
@@ -817,7 +797,6 @@ class Library: Args: upload_id: Upload session identifier to resume document: Complete document content as bytes - user: User identifier on_progress: Optional callback(bytes_sent, total_bytes) for progress updates Returns: @@ -828,23 +807,19 @@ class Library: library = api.library() # Check what's missing - status = library.get_upload_status( - upload_id="abc-123", - user="trustgraph" - ) + status = library.get_upload_status(upload_id="abc-123") if status['state'] == 'in-progress': # Resume with the same document with open("large_document.pdf", "rb") as f: library.resume_upload( upload_id="abc-123", - document=f.read(), - user="trustgraph" + document=f.read() ) ``` """ # Get current status - status = self.get_upload_status(upload_id, user) + status = self.get_upload_status(upload_id) if status.get("upload-state") == "expired": raise RuntimeError("Upload session has expired, please start a new upload") @@ -867,10 +842,10 @@ class Library: chunk_request = { "operation": "upload-chunk", + "workspace": self.api.workspace, "upload-id": upload_id, "chunk-index": chunk_index, "content": base64.b64encode(chunk_data).decode("utf-8"), - "user": user, } self.request(chunk_request) @@ -886,8 +861,8 @@ class Library: # Complete upload complete_request = { "operation": "complete-upload", + "workspace": self.api.workspace, "upload-id": upload_id, - "user": user, } return self.request(complete_request) @@ -895,7 +870,7 @@ class Library: # Child document methods def add_child_document( - self, document, id, parent_id, user, title, comments, + self, document, id, parent_id, title, comments, kind="text/plain", tags=[], metadata=None, ): """ @@ -909,7 +884,6 @@ class Library: document: Document content as bytes id: Document identifier (auto-generated if None) parent_id: Parent document identifier (required) - user: User/owner identifier title: Document title comments: Document description or comments kind: MIME type of the document (default: "text/plain") @@ -931,7 
+905,6 @@ class Library: document=page_text.encode('utf-8'), id="doc-123-page-1", parent_id="doc-123", - user="trustgraph", title="Page 1 of Research Paper", comments="First page extracted from PDF", kind="text/plain", @@ -964,6 +937,7 @@ class Library: input = { "operation": "add-child-document", + "workspace": self.api.workspace, "document-metadata": { "id": id, "time": int(time.time()), @@ -971,7 +945,7 @@ class Library: "title": title, "comments": comments, "metadata": triples, - "user": user, + "workspace": self.api.workspace, "tags": tags, "parent-id": parent_id, "document-type": "extracted", @@ -981,13 +955,12 @@ class Library: return self.request(input) - def list_children(self, document_id, user): + def list_children(self, document_id): """ List all child documents for a given parent document. Args: document_id: Parent document identifier - user: User identifier Returns: list[DocumentMetadata]: List of child document metadata objects @@ -995,10 +968,7 @@ class Library: Example: ```python library = api.library() - children = library.list_children( - document_id="doc-123", - user="trustgraph" - ) + children = library.list_children(document_id="doc-123") for child in children: print(f"{child.id}: {child.title}") @@ -1006,8 +976,8 @@ class Library: """ input = { "operation": "list-children", + "workspace": self.api.workspace, "document-id": document_id, - "user": user, } response = self.request(input) @@ -1028,7 +998,7 @@ class Library: ) for w in v.get("metadata", []) ], - user=v["user"], + workspace=v.get("workspace", ""), tags=v.get("tags", []), parent_id=v.get("parent-id", ""), document_type=v.get("document-type", "source"), @@ -1039,14 +1009,13 @@ class Library: logger.error("Failed to parse children response", exc_info=True) raise ProtocolException("Response not formatted correctly") - def get_document_content(self, user, id): + def get_document_content(self, id): """ Get the content of a document. Retrieves the full content of a document as bytes. 
Args: - user: User identifier id: Document identifier Returns: @@ -1055,10 +1024,7 @@ class Library: Example: ```python library = api.library() - content = library.get_document_content( - user="trustgraph", - id="doc-123" - ) + content = library.get_document_content(id="doc-123") # Write to file with open("output.pdf", "wb") as f: @@ -1067,7 +1033,7 @@ class Library: """ input = { "operation": "get-document-content", - "user": user, + "workspace": self.api.workspace, "document-id": id, } @@ -1076,7 +1042,7 @@ class Library: return base64.b64decode(content_b64) - def stream_document_to_file(self, user, id, file_path, chunk_size=1024*1024, on_progress=None): + def stream_document_to_file(self, id, file_path, chunk_size=1024*1024, on_progress=None): """ Stream document content to a file. @@ -1084,7 +1050,6 @@ class Library: enabling memory-efficient handling of large documents. Args: - user: User identifier id: Document identifier file_path: Path to write the document content chunk_size: Size of each chunk to download (default 1MB) @@ -1101,7 +1066,6 @@ class Library: print(f"Downloaded {received}/{total} bytes") library.stream_document_to_file( - user="trustgraph", id="large-doc-123", file_path="/tmp/document.pdf", on_progress=progress @@ -1116,7 +1080,7 @@ class Library: while True: input = { "operation": "stream-document", - "user": user, + "workspace": self.api.workspace, "document-id": id, "chunk-index": chunk_index, "chunk-size": chunk_size, diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index c590c9b4..4eade3e8 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -84,10 +84,14 @@ class SocketClient: for streaming responses. 
""" - def __init__(self, url: str, timeout: int, token: Optional[str]) -> None: + def __init__( + self, url: str, timeout: int, token: Optional[str], + workspace: str = "default", + ) -> None: self.url: str = self._convert_to_ws_url(url) self.timeout: int = timeout self.token: Optional[str] = token + self.workspace: str = workspace self._request_counter: int = 0 self._lock: Lock = Lock() self._loop: Optional[asyncio.AbstractEventLoop] = None @@ -251,6 +255,7 @@ class SocketClient: try: message = { "id": request_id, + "workspace": self.workspace, "service": service, "request": request } @@ -290,6 +295,7 @@ class SocketClient: try: message = { "id": request_id, + "workspace": self.workspace, "service": service, "request": request } @@ -328,6 +334,7 @@ class SocketClient: try: message = { "id": request_id, + "workspace": self.workspace, "service": service, "request": request } @@ -488,7 +495,6 @@ class SocketFlowInstance: def agent( self, question: str, - user: str, state: Optional[Dict[str, Any]] = None, group: Optional[str] = None, history: Optional[List[Dict[str, Any]]] = None, @@ -498,7 +504,6 @@ class SocketFlowInstance: """Execute an agent operation with streaming support.""" request = { "question": question, - "user": user, "streaming": streaming } if state is not None: @@ -514,7 +519,6 @@ class SocketFlowInstance: def agent_explain( self, question: str, - user: str, collection: str, state: Optional[Dict[str, Any]] = None, group: Optional[str] = None, @@ -524,7 +528,6 @@ class SocketFlowInstance: """Execute an agent operation with explainability support.""" request = { "question": question, - "user": user, "collection": collection, "streaming": True } @@ -574,7 +577,6 @@ class SocketFlowInstance: def graph_rag( self, query: str, - user: str, collection: str, entity_limit: int = 50, triple_limit: int = 30, @@ -592,7 +594,6 @@ class SocketFlowInstance: """ request = { "query": query, - "user": user, "collection": collection, "entity-limit": entity_limit, 
"triple-limit": triple_limit, @@ -619,7 +620,6 @@ class SocketFlowInstance: def graph_rag_explain( self, query: str, - user: str, collection: str, entity_limit: int = 50, triple_limit: int = 30, @@ -632,7 +632,6 @@ class SocketFlowInstance: """Execute graph-based RAG query with explainability support.""" request = { "query": query, - "user": user, "collection": collection, "entity-limit": entity_limit, "triple-limit": triple_limit, @@ -653,7 +652,6 @@ class SocketFlowInstance: def document_rag( self, query: str, - user: str, collection: str, doc_limit: int = 10, streaming: bool = False, @@ -666,7 +664,6 @@ class SocketFlowInstance: """ request = { "query": query, - "user": user, "collection": collection, "doc-limit": doc_limit, "streaming": streaming @@ -688,7 +685,6 @@ class SocketFlowInstance: def document_rag_explain( self, query: str, - user: str, collection: str, doc_limit: int = 10, **kwargs: Any @@ -696,7 +692,6 @@ class SocketFlowInstance: """Execute document-based RAG query with explainability support.""" request = { "query": query, - "user": user, "collection": collection, "doc-limit": doc_limit, "streaming": True, @@ -748,7 +743,6 @@ class SocketFlowInstance: def graph_embeddings_query( self, text: str, - user: str, collection: str, limit: int = 10, **kwargs: Any @@ -759,7 +753,6 @@ class SocketFlowInstance: request = { "vector": vector, - "user": user, "collection": collection, "limit": limit } @@ -770,7 +763,6 @@ class SocketFlowInstance: def document_embeddings_query( self, text: str, - user: str, collection: str, limit: int = 10, **kwargs: Any @@ -781,7 +773,6 @@ class SocketFlowInstance: request = { "vector": vector, - "user": user, "collection": collection, "limit": limit } @@ -802,7 +793,6 @@ class SocketFlowInstance: p: Optional[Union[str, Dict[str, Any]]] = None, o: Optional[Union[str, Dict[str, Any]]] = None, g: Optional[str] = None, - user: Optional[str] = None, collection: Optional[str] = None, limit: int = 100, **kwargs: Any @@ -822,8 +812,6 
@@ class SocketFlowInstance: request["o"] = o_term if g is not None: request["g"] = g - if user is not None: - request["user"] = user if collection is not None: request["collection"] = collection request.update(kwargs) @@ -839,7 +827,6 @@ class SocketFlowInstance: p: Optional[Union[str, Dict[str, Any]]] = None, o: Optional[Union[str, Dict[str, Any]]] = None, g: Optional[str] = None, - user: Optional[str] = None, collection: Optional[str] = None, limit: int = 100, batch_size: int = 20, @@ -864,8 +851,6 @@ class SocketFlowInstance: request["o"] = o_term if g is not None: request["g"] = g - if user is not None: - request["user"] = user if collection is not None: request["collection"] = collection request.update(kwargs) @@ -879,7 +864,6 @@ class SocketFlowInstance: def sparql_query_stream( self, query: str, - user: str = "trustgraph", collection: str = "default", limit: int = 10000, batch_size: int = 20, @@ -888,7 +872,6 @@ class SocketFlowInstance: """Execute a SPARQL query with streaming batches.""" request = { "query": query, - "user": user, "collection": collection, "limit": limit, "streaming": True, @@ -904,7 +887,6 @@ class SocketFlowInstance: def rows_query( self, query: str, - user: str, collection: str, variables: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, @@ -913,7 +895,6 @@ class SocketFlowInstance: """Execute a GraphQL query against structured rows.""" request = { "query": query, - "user": user, "collection": collection } if variables: @@ -943,7 +924,6 @@ class SocketFlowInstance: self, text: str, schema_name: str, - user: str = "trustgraph", collection: str = "default", index_name: Optional[str] = None, limit: int = 10, @@ -956,7 +936,6 @@ class SocketFlowInstance: request = { "vector": vector, "schema_name": schema_name, - "user": user, "collection": collection, "limit": limit } diff --git a/trustgraph-base/trustgraph/api/types.py b/trustgraph-base/trustgraph/api/types.py index f5987b0e..129f807a 100644 --- 
a/trustgraph-base/trustgraph/api/types.py +++ b/trustgraph-base/trustgraph/api/types.py @@ -45,10 +45,13 @@ class ConfigValue: type: Configuration type/category key: Specific configuration key value: Configuration value as string + workspace: Workspace the value belongs to. Only populated for + responses to getvalues-all-ws; empty otherwise. """ type : str key : str value : str + workspace : str = "" @dataclasses.dataclass class DocumentMetadata: @@ -62,7 +65,7 @@ class DocumentMetadata: title: Document title comments: Additional comments or description metadata: List of RDF triples providing structured metadata - user: User/owner identifier + workspace: Workspace the document belongs to tags: List of tags for categorization parent_id: Parent document ID for child documents (empty for top-level docs) document_type: "source" for uploaded documents, "extracted" for derived content @@ -73,7 +76,7 @@ class DocumentMetadata: title : str comments : str metadata : List[Triple] - user : str + workspace : str tags : List[str] parent_id : str = "" document_type : str = "source" @@ -88,7 +91,7 @@ class ProcessingMetadata: document_id: ID of the document being processed time: Processing start timestamp flow: Flow instance handling the processing - user: User identifier + workspace: Workspace the processing job belongs to collection: Target collection for processed data tags: List of tags for categorization """ @@ -96,7 +99,7 @@ class ProcessingMetadata: document_id : str time : datetime.datetime flow : str - user : str + workspace : str collection : str tags : List[str] @@ -105,17 +108,15 @@ class CollectionMetadata: """ Metadata for a data collection. - Collections provide logical grouping and isolation for documents and - knowledge graph data. + Collections provide logical grouping within a workspace for documents + and knowledge graph data. 
Attributes: - user: User/owner identifier collection: Collection identifier name: Human-readable collection name description: Collection description tags: List of tags for categorization """ - user : str collection : str name : str description : str diff --git a/trustgraph-base/trustgraph/base/async_processor.py b/trustgraph-base/trustgraph/base/async_processor.py index 9b9328cb..a7ce4961 100644 --- a/trustgraph-base/trustgraph/base/async_processor.py +++ b/trustgraph-base/trustgraph/base/async_processor.py @@ -125,21 +125,39 @@ class AsyncProcessor: response_metrics = config_resp_metrics, ) - async def fetch_config(self): - """Fetch full config from config service using a short-lived - request/response client. Returns (config, version) or raises.""" - client = self._create_config_client() - try: - await client.start() - resp = await client.request( - ConfigRequest(operation="config"), - timeout=10, - ) - if resp.error: - raise RuntimeError(f"Config error: {resp.error.message}") - return resp.config, resp.version - finally: - await client.stop() + async def _fetch_type_workspace(self, client, workspace, config_type): + """Fetch config values of a single type within one workspace. + Returns dict of {key: value}.""" + resp = await client.request( + ConfigRequest( + operation="getvalues", + workspace=workspace, + type=config_type, + ), + timeout=10, + ) + if resp.error: + raise RuntimeError(f"Config error: {resp.error.message}") + return {v.key: v.value for v in resp.values} + + async def _fetch_type_all_workspaces(self, client, config_type): + """Fetch config values of a single type across all workspaces. 
+ Returns dict of {workspace: {key: value}}.""" + resp = await client.request( + ConfigRequest( + operation="getvalues-all-ws", + type=config_type, + ), + timeout=10, + ) + if resp.error: + raise RuntimeError(f"Config error: {resp.error.message}") + + grouped = {} + for v in resp.values: + ws = grouped.setdefault(v.workspace, {}) + ws[v.key] = v.value + return grouped, resp.version # This is called to start dynamic behaviour. # Implements the subscribe-then-fetch pattern to avoid race conditions. @@ -155,21 +173,51 @@ class AsyncProcessor: # processed by on_config_notify, which does the version check async def fetch_and_apply_config(self): - """Fetch full config from config service and apply to all handlers. - Retries until successful — config service may not be ready yet.""" + """Startup: for each registered handler, fetch config for all its + types across all workspaces and invoke the handler once per + workspace. Retries until successful — config service may not be + ready yet.""" while self.running: try: - config, version = await self.fetch_config() + client = self._create_config_client() + try: + await client.start() - logger.info(f"Fetched config version {version}") + version = 0 - self.config_version = version + for entry in self.config_handlers: + handler_types = entry["types"] - # Apply to all handlers (startup = invoke all) - for entry in self.config_handlers: - await entry["handler"](config, version) + # Handlers registered without types get nothing + # at startup (there is no "all types" fetch). 
+ if not handler_types: + continue + + # Group all registered types by workspace: + # {workspace: {type: {key: value}}} + per_ws = {} + for t in handler_types: + type_data, v = \ + await self._fetch_type_all_workspaces( + client, t, + ) + version = max(version, v) + for ws, kv in type_data.items(): + per_ws.setdefault(ws, {})[t] = kv + + # Call the handler once per workspace + for ws, config in per_ws.items(): + await entry["handler"](ws, config, version) + + logger.info( + f"Applied startup config version {version}" + ) + self.config_version = version + + finally: + await client.stop() return @@ -204,8 +252,9 @@ class AsyncProcessor: # Called when a config notify message arrives async def on_config_notify(self, message, consumer, flow): - notify_version = message.value().version - notify_types = set(message.value().types) + v = message.value() + notify_version = v.version + changes = v.changes # dict of type -> [workspaces] # Skip if we already have this version or newer if notify_version <= self.config_version: @@ -215,41 +264,60 @@ class AsyncProcessor: ) return - # Check if any handler cares about the affected types - if notify_types: - any_interested = False - for entry in self.config_handlers: - handler_types = entry["types"] - if handler_types is None or notify_types & handler_types: - any_interested = True - break + notify_types = set(changes.keys()) - if not any_interested: - logger.debug( - f"Ignoring config notify v{notify_version}, " - f"no handlers for types {notify_types}" - ) - self.config_version = notify_version - return + # Filter out handlers that don't care about any of the changed + # types. A handler registered without types never fires on + # notifications (nothing to scope to). 
+ interested = [] + for entry in self.config_handlers: + handler_types = entry["types"] + if handler_types and notify_types & handler_types: + interested.append(entry) + + if not interested: + logger.debug( + f"Ignoring config notify v{notify_version}, " + f"no handlers for types {notify_types}" + ) + self.config_version = notify_version + return logger.info( - f"Config notify v{notify_version} types={list(notify_types)}, " - f"fetching config..." + f"Config notify v{notify_version} " + f"types={list(notify_types)}, fetching config..." ) - # Fetch full config using short-lived client try: - config, version = await self.fetch_config() + client = self._create_config_client() + try: + await client.start() - self.config_version = version + for entry in interested: + handler_types = entry["types"] - # Invoke handlers that care about the affected types - for entry in self.config_handlers: - handler_types = entry["types"] - if handler_types is None: - await entry["handler"](config, version) - elif not notify_types or notify_types & handler_types: - await entry["handler"](config, version) + # Build {workspace: {type: {key: value}}} for types + # this handler cares about, where the workspace was + # affected for that type. 
+ per_ws = {} + for t in handler_types: + if t not in changes: + continue + for ws in changes[t]: + kv = await self._fetch_type_workspace( + client, ws, t, + ) + per_ws.setdefault(ws, {})[t] = kv + + for ws, config in per_ws.items(): + await entry["handler"]( + ws, config, notify_version, + ) + + finally: + await client.stop() + + self.config_version = notify_version except Exception as e: logger.error( diff --git a/trustgraph-base/trustgraph/base/chunking_service.py b/trustgraph-base/trustgraph/base/chunking_service.py index 4bd78428..3771d78e 100644 --- a/trustgraph-base/trustgraph/base/chunking_service.py +++ b/trustgraph-base/trustgraph/base/chunking_service.py @@ -48,12 +48,13 @@ class ChunkingService(FlowProcessor): await super(ChunkingService, self).start() await self.librarian.start() - async def get_document_text(self, doc): + async def get_document_text(self, doc, workspace): """ Get text content from a TextDocument, fetching from librarian if needed. Args: doc: TextDocument with either inline text or document_id + workspace: Workspace for librarian lookup (from flow.workspace) Returns: str: The document text content @@ -62,7 +63,7 @@ class ChunkingService(FlowProcessor): logger.info(f"Fetching document {doc.document_id} from librarian...") text = await self.librarian.fetch_document_text( document_id=doc.document_id, - user=doc.metadata.user, + workspace=workspace, ) logger.info(f"Fetched {len(text)} characters from librarian") return text diff --git a/trustgraph-base/trustgraph/base/collection_config_handler.py b/trustgraph-base/trustgraph/base/collection_config_handler.py index 8c1af822..4cb91c53 100644 --- a/trustgraph-base/trustgraph/base/collection_config_handler.py +++ b/trustgraph-base/trustgraph/base/collection_config_handler.py @@ -15,114 +15,139 @@ class CollectionConfigHandler: Storage services should: 1. Inherit from this class along with their service base class 2. Call register_config_handler(self.on_collection_config) in __init__ - 3. 
Implement create_collection(user, collection, metadata) method - 4. Implement delete_collection(user, collection) method + 3. Implement create_collection(workspace, collection, metadata) method + 4. Implement delete_collection(workspace, collection) method """ def __init__(self, **kwargs): - # Track known collections: {(user, collection): metadata_dict} + # Track known collections: {(workspace, collection): metadata_dict} self.known_collections: Dict[tuple, dict] = {} # Pass remaining kwargs up the inheritance chain super().__init__(**kwargs) - async def on_collection_config(self, config: dict, version: int): + async def on_collection_config( + self, workspace: str, config: dict, version: int + ): """ Handle config push messages and extract collection information + for a single workspace. Args: + workspace: Workspace the config applies to config: Configuration dictionary from ConfigPush message version: Configuration version number """ - logger.info(f"Processing collection configuration (version {version})") + logger.info( + f"Processing collection configuration " + f"(version {version}, workspace {workspace})" + ) - # Extract collections from config (treat missing key as empty) + # Extract collections from config (treat missing key as empty). + # Each config key IS the collection name — config is already + # partitioned by workspace, so no workspace prefix is needed + # on the key. 
collection_config = config.get("collection", {}) # Track which collections we've seen in this config current_collections: Set[tuple] = set() - # Process each collection in the config - for key, value_json in collection_config.items(): + for collection, value_json in collection_config.items(): try: - # Parse user:collection key - if ":" not in key: - logger.warning(f"Invalid collection key format (expected user:collection): {key}") - continue + current_collections.add((workspace, collection)) - user, collection = key.split(":", 1) - current_collections.add((user, collection)) - - # Parse metadata metadata = json.loads(value_json) - # Check if this is a new collection or updated - collection_key = (user, collection) - if collection_key not in self.known_collections: - logger.info(f"New collection detected: {user}/{collection}") - await self.create_collection(user, collection, metadata) - self.known_collections[collection_key] = metadata + key = (workspace, collection) + if key not in self.known_collections: + logger.info( + f"New collection detected: {workspace}/{collection}" + ) + await self.create_collection( + workspace, collection, metadata + ) + self.known_collections[key] = metadata else: - # Collection already exists, update metadata if changed - if self.known_collections[collection_key] != metadata: - logger.info(f"Collection metadata updated: {user}/{collection}") - # Most storage services don't need to do anything for metadata updates - # They just need to know the collection exists - self.known_collections[collection_key] = metadata + if self.known_collections[key] != metadata: + logger.info( + f"Collection metadata updated: " + f"{workspace}/{collection}" + ) + self.known_collections[key] = metadata except Exception as e: - logger.error(f"Error processing collection config for key {key}: {e}", exc_info=True) + logger.error( + f"Error processing collection config for " + f"{workspace}/{collection}: {e}", + exc_info=True, + ) - # Find collections that were 
deleted (in known but not in current) - deleted_collections = set(self.known_collections.keys()) - current_collections - for user, collection in deleted_collections: - logger.info(f"Collection deleted: {user}/{collection}") + # Find collections for THIS workspace that were deleted (in + # known but not in current). Only compare collections owned by + # this workspace — other workspaces' collections are not + # affected by this config update. + known_for_ws = { + (w, c) for (w, c) in self.known_collections.keys() + if w == workspace + } + deleted_collections = known_for_ws - current_collections + for ws, collection in deleted_collections: + logger.info(f"Collection deleted: {ws}/{collection}") try: - # Remove from known_collections FIRST to immediately reject new writes - # This eliminates race condition with worker threads - del self.known_collections[(user, collection)] - # Physical deletion happens after - worker threads already rejecting writes - await self.delete_collection(user, collection) + # Remove from known_collections FIRST to immediately + # reject new writes + del self.known_collections[(ws, collection)] + await self.delete_collection(ws, collection) except Exception as e: - logger.error(f"Error deleting collection {user}/{collection}: {e}", exc_info=True) - # If physical deletion failed, should we re-add to known_collections? - # For now, keep it removed - collection is logically deleted per config + logger.error( + f"Error deleting collection {ws}/{collection}: {e}", + exc_info=True, + ) - logger.debug(f"Collection config processing complete. Known collections: {len(self.known_collections)}") + logger.debug( + f"Collection config processing complete. " + f"Known collections: {len(self.known_collections)}" + ) - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection( + self, workspace: str, collection: str, metadata: dict, + ): """ Create a collection in the storage backend. 
Subclasses must implement this method. Args: - user: User ID + workspace: Workspace ID collection: Collection ID metadata: Collection metadata dictionary """ - raise NotImplementedError("Storage service must implement create_collection method") + raise NotImplementedError( + "Storage service must implement create_collection method" + ) - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """ Delete a collection from the storage backend. Subclasses must implement this method. Args: - user: User ID + workspace: Workspace ID collection: Collection ID """ - raise NotImplementedError("Storage service must implement delete_collection method") + raise NotImplementedError( + "Storage service must implement delete_collection method" + ) - def collection_exists(self, user: str, collection: str) -> bool: + def collection_exists(self, workspace: str, collection: str) -> bool: """ - Check if a collection is known to exist + Check if a collection is known to exist. Args: - user: User ID + workspace: Workspace ID collection: Collection ID Returns: True if collection exists, False otherwise """ - return (user, collection) in self.known_collections + return (workspace, collection) in self.known_collections diff --git a/trustgraph-base/trustgraph/base/config_client.py b/trustgraph-base/trustgraph/base/config_client.py index c9ec3f9b..504a6d58 100644 --- a/trustgraph-base/trustgraph/base/config_client.py +++ b/trustgraph-base/trustgraph/base/config_client.py @@ -18,10 +18,11 @@ class ConfigClient(RequestResponse): ) return resp - async def get(self, type, key, timeout=CONFIG_TIMEOUT): + async def get(self, workspace, type, key, timeout=CONFIG_TIMEOUT): """Get a single config value. 
Returns the value string or None.""" resp = await self._request( operation="get", + workspace=workspace, keys=[ConfigKey(type=type, key=key)], timeout=timeout, ) @@ -29,19 +30,21 @@ class ConfigClient(RequestResponse): return resp.values[0].value return None - async def put(self, type, key, value, timeout=CONFIG_TIMEOUT): + async def put(self, workspace, type, key, value, timeout=CONFIG_TIMEOUT): """Put a single config value.""" await self._request( operation="put", + workspace=workspace, values=[ConfigValue(type=type, key=key, value=value)], timeout=timeout, ) - async def put_many(self, values, timeout=CONFIG_TIMEOUT): - """Put multiple config values in a single request. - values is a list of (type, key, value) tuples.""" + async def put_many(self, workspace, values, timeout=CONFIG_TIMEOUT): + """Put multiple config values in a single request within a + single workspace. values is a list of (type, key, value) tuples.""" await self._request( operation="put", + workspace=workspace, values=[ ConfigValue(type=t, key=k, value=v) for t, k, v in values @@ -49,19 +52,21 @@ class ConfigClient(RequestResponse): timeout=timeout, ) - async def delete(self, type, key, timeout=CONFIG_TIMEOUT): + async def delete(self, workspace, type, key, timeout=CONFIG_TIMEOUT): """Delete a single config key.""" await self._request( operation="delete", + workspace=workspace, keys=[ConfigKey(type=type, key=key)], timeout=timeout, ) - async def delete_many(self, keys, timeout=CONFIG_TIMEOUT): - """Delete multiple config keys in a single request. - keys is a list of (type, key) tuples.""" + async def delete_many(self, workspace, keys, timeout=CONFIG_TIMEOUT): + """Delete multiple config keys in a single request within a + single workspace. 
keys is a list of (type, key) tuples.""" await self._request( operation="delete", + workspace=workspace, keys=[ ConfigKey(type=t, key=k) for t, k in keys @@ -69,15 +74,26 @@ class ConfigClient(RequestResponse): timeout=timeout, ) - async def keys(self, type, timeout=CONFIG_TIMEOUT): - """List all keys for a config type.""" + async def keys(self, workspace, type, timeout=CONFIG_TIMEOUT): + """List all keys for a config type within a workspace.""" resp = await self._request( operation="list", + workspace=workspace, type=type, timeout=timeout, ) return resp.directory + async def workspaces_for_type(self, type, timeout=CONFIG_TIMEOUT): + """Return the set of distinct workspaces with any config of + the given type.""" + resp = await self._request( + operation="getvalues-all-ws", + type=type, + timeout=timeout, + ) + return {v.workspace for v in resp.values if v.workspace} + class ConfigClientSpec(RequestResponseSpec): def __init__( diff --git a/trustgraph-base/trustgraph/base/consumer_spec.py b/trustgraph-base/trustgraph/base/consumer_spec.py index 023537df..af072cca 100644 --- a/trustgraph-base/trustgraph/base/consumer_spec.py +++ b/trustgraph-base/trustgraph/base/consumer_spec.py @@ -24,7 +24,10 @@ class ConsumerSpec(Spec): flow = flow, backend = processor.pubsub, topic = definition["topics"][self.name], - subscriber = processor.id + "--" + flow.name + "--" + self.name, + subscriber = ( + processor.id + "--" + flow.workspace + "--" + + flow.name + "--" + self.name + ), schema = self.schema, handler = self.handler, metrics = consumer_metrics, diff --git a/trustgraph-base/trustgraph/base/document_embeddings_client.py b/trustgraph-base/trustgraph/base/document_embeddings_client.py index dd985eab..a93cdc87 100644 --- a/trustgraph-base/trustgraph/base/document_embeddings_client.py +++ b/trustgraph-base/trustgraph/base/document_embeddings_client.py @@ -9,14 +9,12 @@ from .. 
knowledge import Uri, Literal logger = logging.getLogger(__name__) class DocumentEmbeddingsClient(RequestResponse): - async def query(self, vector, limit=20, user="trustgraph", - collection="default", timeout=30): + async def query(self, vector, limit=20, collection="default", timeout=30): resp = await self.request( DocumentEmbeddingsRequest( vector = vector, limit = limit, - user = user, collection = collection ), timeout=timeout diff --git a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py index d5bf8421..cd9e91b1 100644 --- a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py +++ b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py @@ -60,7 +60,9 @@ class DocumentEmbeddingsQueryService(FlowProcessor): logger.debug(f"Handling document embeddings query request {id}...") - docs = await self.query_document_embeddings(request) + docs = await self.query_document_embeddings( + flow.workspace, request, + ) logger.debug("Sending document embeddings query response...") r = DocumentEmbeddingsResponse(chunks=docs, error=None) diff --git a/trustgraph-base/trustgraph/base/document_embeddings_store_service.py b/trustgraph-base/trustgraph/base/document_embeddings_store_service.py index 0c7921db..96b7781f 100644 --- a/trustgraph-base/trustgraph/base/document_embeddings_store_service.py +++ b/trustgraph-base/trustgraph/base/document_embeddings_store_service.py @@ -41,7 +41,8 @@ class DocumentEmbeddingsStoreService(FlowProcessor): request = msg.value() - await self.store_document_embeddings(request) + # Workspace comes from the flow the message arrived on. 
+ await self.store_document_embeddings(flow.workspace, request) except TooManyRequests as e: raise e diff --git a/trustgraph-base/trustgraph/base/dynamic_tool_service.py b/trustgraph-base/trustgraph/base/dynamic_tool_service.py index bcfb71ab..00a457de 100644 --- a/trustgraph-base/trustgraph/base/dynamic_tool_service.py +++ b/trustgraph-base/trustgraph/base/dynamic_tool_service.py @@ -2,7 +2,7 @@ Base class for dynamically pluggable tool services. Tool services are Pulsar services that can be invoked as agent tools. -They receive a ToolServiceRequest with user, config, and arguments, +They receive a ToolServiceRequest with config and arguments, and return a ToolServiceResponse with the result. Uses direct Pulsar topics (no flow configuration required): @@ -42,7 +42,6 @@ class DynamicToolService(AsyncProcessor): the tool's logic. The invoke method receives: - - user: The user context for multi-tenancy - config: Dict of config values from the tool descriptor - arguments: Dict of arguments from the LLM @@ -115,14 +114,13 @@ class DynamicToolService(AsyncProcessor): id = msg.properties().get("id", "unknown") # Parse the request - user = request.user or "trustgraph" config = json.loads(request.config) if request.config else {} arguments = json.loads(request.arguments) if request.arguments else {} - logger.debug(f"Tool service request: user={user}, config={config}, arguments={arguments}") + logger.debug(f"Tool service request: config={config}, arguments={arguments}") # Invoke the tool implementation - response = await self.invoke(user, config, arguments) + response = await self.invoke(config, arguments) # Send success response await self.producer.send( @@ -159,14 +157,13 @@ class DynamicToolService(AsyncProcessor): properties={"id": id if id else "unknown"} ) - async def invoke(self, user, config, arguments): + async def invoke(self, config, arguments): """ Invoke the tool service. Override this method in subclasses to implement the tool's logic. 
Args: - user: The user context for multi-tenancy config: Dict of config values from the tool descriptor arguments: Dict of arguments from the LLM diff --git a/trustgraph-base/trustgraph/base/flow.py b/trustgraph-base/trustgraph/base/flow.py index 9a515bf8..2caad938 100644 --- a/trustgraph-base/trustgraph/base/flow.py +++ b/trustgraph-base/trustgraph/base/flow.py @@ -4,15 +4,16 @@ import asyncio class Flow: """ Runtime representation of a deployed flow process. - + This class maintains internal processor states and orchestrates - lifecycles (start, stop) for inputs (consumers) and parameters + lifecycles (start, stop) for inputs (consumers) and parameters that drive data flowing across linked nodes. """ - def __init__(self, id, flow, processor, defn): + def __init__(self, id, flow, workspace, processor, defn): self.id = id self.name = flow + self.workspace = workspace self.producer = {} diff --git a/trustgraph-base/trustgraph/base/flow_processor.py b/trustgraph-base/trustgraph/base/flow_processor.py index 99cb0f53..aa7bf921 100644 --- a/trustgraph-base/trustgraph/base/flow_processor.py +++ b/trustgraph-base/trustgraph/base/flow_processor.py @@ -35,6 +35,8 @@ class FlowProcessor(AsyncProcessor): ) # Initialise flow information state + # Keyed by (workspace, flow) tuples; each workspace has its own + # set of flow variants for this processor. 
self.flows = {} # These can be overriden by a derived class: @@ -48,23 +50,28 @@ class FlowProcessor(AsyncProcessor): def register_specification(self, spec: Any) -> None: self.specifications.append(spec) - # Start processing for a new flow - async def start_flow(self, flow, defn): - self.flows[flow] = Flow(self.id, flow, self, defn) - await self.flows[flow].start() - logger.info(f"Started flow: {flow}") - - # Stop processing for a new flow - async def stop_flow(self, flow): - if flow in self.flows: - await self.flows[flow].stop() - del self.flows[flow] - logger.info(f"Stopped flow: {flow}") + # Start processing for a new flow within a workspace + async def start_flow(self, workspace, flow, defn): + key = (workspace, flow) + self.flows[key] = Flow(self.id, flow, workspace, self, defn) + await self.flows[key].start() + logger.info(f"Started flow: {workspace}/{flow}") - # Event handler - called for a configuration change - async def on_configure_flows(self, config, version): + # Stop processing for a flow within a workspace + async def stop_flow(self, workspace, flow): + key = (workspace, flow) + if key in self.flows: + await self.flows[key].stop() + del self.flows[key] + logger.info(f"Stopped flow: {workspace}/{flow}") - logger.info(f"Got config version {version}") + # Event handler - called for a configuration change for a single + # workspace + async def on_configure_flows(self, workspace, config, version): + + logger.info( + f"Got config version {version} for workspace {workspace}" + ) config_type = f"processor:{self.id}" @@ -76,26 +83,28 @@ class FlowProcessor(AsyncProcessor): for k, v in config[config_type].items() } else: - logger.debug("No configuration settings for me.") + logger.debug( + f"No configuration settings for me in {workspace}." 
+ ) flow_config = {} - # Get list of flows which should be running and are currently - # running - wanted_flows = flow_config.keys() - # This takes a copy, needed because dict gets modified by stop_flow - current_flows = list(self.flows.keys()) + # Get list of flows which should be running in this workspace, + # and the list currently running in this workspace + wanted_flows = set(flow_config.keys()) + current_flows = { + f for (ws, f) in self.flows.keys() if ws == workspace + } - # Start all the flows which arent currently running - for flow in wanted_flows: - if flow not in current_flows: - await self.start_flow(flow, flow_config[flow]) + # Start all the flows which aren't currently running in this + # workspace + for flow in wanted_flows - current_flows: + await self.start_flow(workspace, flow, flow_config[flow]) - # Stop all the unwanted flows which are due to be stopped - for flow in current_flows: - if flow not in wanted_flows: - await self.stop_flow(flow) + # Stop all the unwanted flows in this workspace + for flow in current_flows - wanted_flows: + await self.stop_flow(workspace, flow) - logger.info("Handled config update") + logger.info(f"Handled config update for workspace {workspace}") # Start threads, just call parent async def start(self): diff --git a/trustgraph-base/trustgraph/base/graph_embeddings_client.py b/trustgraph-base/trustgraph/base/graph_embeddings_client.py index fe717bf1..a9348c19 100644 --- a/trustgraph-base/trustgraph/base/graph_embeddings_client.py +++ b/trustgraph-base/trustgraph/base/graph_embeddings_client.py @@ -22,14 +22,12 @@ def to_value(x: Any) -> Any: return Literal(x.value or x.iri) class GraphEmbeddingsClient(RequestResponse): - async def query(self, vector, limit=20, user="trustgraph", - collection="default", timeout=30): + async def query(self, vector, limit=20, collection="default", timeout=30): resp = await self.request( GraphEmbeddingsRequest( vector = vector, limit = limit, - user = user, collection = collection ), 
timeout=timeout diff --git a/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py b/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py index 55c8efa9..cbce810c 100644 --- a/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py +++ b/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py @@ -60,7 +60,9 @@ class GraphEmbeddingsQueryService(FlowProcessor): logger.debug(f"Handling graph embeddings query request {id}...") - entities = await self.query_graph_embeddings(request) + entities = await self.query_graph_embeddings( + flow.workspace, request, + ) logger.debug("Sending graph embeddings query response...") r = GraphEmbeddingsResponse(entities=entities, error=None) diff --git a/trustgraph-base/trustgraph/base/graph_embeddings_store_service.py b/trustgraph-base/trustgraph/base/graph_embeddings_store_service.py index 09bbbe6a..10cfe93c 100644 --- a/trustgraph-base/trustgraph/base/graph_embeddings_store_service.py +++ b/trustgraph-base/trustgraph/base/graph_embeddings_store_service.py @@ -41,7 +41,8 @@ class GraphEmbeddingsStoreService(FlowProcessor): request = msg.value() - await self.store_graph_embeddings(request) + # Workspace comes from the flow the message arrived on. + await self.store_graph_embeddings(flow.workspace, request) except TooManyRequests as e: raise e diff --git a/trustgraph-base/trustgraph/base/graph_rag_client.py b/trustgraph-base/trustgraph/base/graph_rag_client.py index 9db23293..e07781f9 100644 --- a/trustgraph-base/trustgraph/base/graph_rag_client.py +++ b/trustgraph-base/trustgraph/base/graph_rag_client.py @@ -3,7 +3,7 @@ from . request_response_spec import RequestResponse, RequestResponseSpec from .. 
schema import GraphRagQuery, GraphRagResponse class GraphRagClient(RequestResponse): - async def rag(self, query, user="trustgraph", collection="default", + async def rag(self, query, collection="default", chunk_callback=None, explain_callback=None, parent_uri="", timeout=600): @@ -12,7 +12,6 @@ class GraphRagClient(RequestResponse): Args: query: The question to ask - user: User identifier collection: Collection identifier chunk_callback: Optional async callback(text, end_of_stream) for text chunks explain_callback: Optional async callback(explain_id, explain_graph, explain_triples) for explain notifications @@ -49,7 +48,6 @@ class GraphRagClient(RequestResponse): await self.request( GraphRagQuery( query = query, - user = user, collection = collection, parent_uri = parent_uri, ), diff --git a/trustgraph-base/trustgraph/base/librarian_client.py b/trustgraph-base/trustgraph/base/librarian_client.py index 5ad97f47..1876602b 100644 --- a/trustgraph-base/trustgraph/base/librarian_client.py +++ b/trustgraph-base/trustgraph/base/librarian_client.py @@ -10,7 +10,7 @@ Usage: id=id, backend=self.pubsub, taskgroup=self.taskgroup, **params ) await self.librarian.start() - content = await self.librarian.fetch_document_content(doc_id, user) + content = await self.librarian.fetch_document_content(doc_id, workspace) """ import asyncio @@ -150,7 +150,7 @@ class LibrarianClient: finally: self._streams.pop(request_id, None) - async def fetch_document_content(self, document_id, user, timeout=120): + async def fetch_document_content(self, document_id, workspace, timeout=120): """Fetch document content using streaming. Returns base64-encoded content. Caller is responsible for decoding. 
@@ -158,7 +158,7 @@ class LibrarianClient: req = LibrarianRequest( operation="stream-document", document_id=document_id, - user=user, + workspace=workspace, ) chunks = await self.stream(req, timeout=timeout) @@ -176,24 +176,24 @@ class LibrarianClient: return base64.b64encode(raw) - async def fetch_document_text(self, document_id, user, timeout=120): + async def fetch_document_text(self, document_id, workspace, timeout=120): """Fetch document content and decode as UTF-8 text.""" content = await self.fetch_document_content( - document_id, user, timeout=timeout, + document_id, workspace, timeout=timeout, ) return base64.b64decode(content).decode("utf-8") - async def fetch_document_metadata(self, document_id, user, timeout=120): + async def fetch_document_metadata(self, document_id, workspace, timeout=120): """Fetch document metadata from the librarian.""" req = LibrarianRequest( operation="get-document-metadata", document_id=document_id, - user=user, + workspace=workspace, ) response = await self.request(req, timeout=timeout) return response.document_metadata - async def save_child_document(self, doc_id, parent_id, user, content, + async def save_child_document(self, doc_id, parent_id, workspace, content, document_type="chunk", title=None, kind="text/plain", timeout=120): """Save a child document to the librarian.""" @@ -202,7 +202,7 @@ class LibrarianClient: doc_metadata = DocumentMetadata( id=doc_id, - user=user, + workspace=workspace, kind=kind, title=title or doc_id, parent_id=parent_id, @@ -218,7 +218,7 @@ class LibrarianClient: await self.request(req, timeout=timeout) return doc_id - async def save_document(self, doc_id, user, content, title=None, + async def save_document(self, doc_id, workspace, content, title=None, document_type="answer", kind="text/plain", timeout=120): """Save a document to the librarian.""" @@ -227,7 +227,7 @@ class LibrarianClient: doc_metadata = DocumentMetadata( id=doc_id, - user=user, + workspace=workspace, kind=kind, title=title or 
doc_id, document_type=document_type, @@ -238,7 +238,7 @@ class LibrarianClient: document_id=doc_id, document_metadata=doc_metadata, content=base64.b64encode(content).decode("utf-8"), - user=user, + workspace=workspace, ) await self.request(req, timeout=timeout) diff --git a/trustgraph-base/trustgraph/base/request_response_spec.py b/trustgraph-base/trustgraph/base/request_response_spec.py index b91c655c..aa934a7f 100644 --- a/trustgraph-base/trustgraph/base/request_response_spec.py +++ b/trustgraph-base/trustgraph/base/request_response_spec.py @@ -133,8 +133,9 @@ class RequestResponseSpec(Spec): # Make subscription names unique, so that all subscribers get # to see all response messages subscription = ( - processor.id + "--" + flow.name + "--" + self.request_name + - "--" + str(uuid.uuid4()) + processor.id + "--" + flow.workspace + "--" + + flow.name + "--" + self.request_name + "--" + + str(uuid.uuid4()) ), consumer_name = flow.id, request_topic = definition["topics"][self.request_name], diff --git a/trustgraph-base/trustgraph/base/row_embeddings_query_client.py b/trustgraph-base/trustgraph/base/row_embeddings_query_client.py index 811adf40..98c2e0a7 100644 --- a/trustgraph-base/trustgraph/base/row_embeddings_query_client.py +++ b/trustgraph-base/trustgraph/base/row_embeddings_query_client.py @@ -3,13 +3,12 @@ from .. 
schema import RowEmbeddingsRequest, RowEmbeddingsResponse class RowEmbeddingsQueryClient(RequestResponse): async def row_embeddings_query( - self, vector, schema_name, user="trustgraph", collection="default", + self, vector, schema_name, collection="default", index_name=None, limit=10, timeout=600 ): request = RowEmbeddingsRequest( vector=vector, schema_name=schema_name, - user=user, collection=collection, limit=limit ) diff --git a/trustgraph-base/trustgraph/base/structured_query_client.py b/trustgraph-base/trustgraph/base/structured_query_client.py index 84d6bff3..49b30cd1 100644 --- a/trustgraph-base/trustgraph/base/structured_query_client.py +++ b/trustgraph-base/trustgraph/base/structured_query_client.py @@ -2,11 +2,10 @@ from . request_response_spec import RequestResponse, RequestResponseSpec from .. schema import StructuredQueryRequest, StructuredQueryResponse class StructuredQueryClient(RequestResponse): - async def structured_query(self, question, user="trustgraph", collection="default", timeout=600): + async def structured_query(self, question, collection="default", timeout=600): resp = await self.request( StructuredQueryRequest( question = question, - user = user, collection = collection ), timeout=timeout diff --git a/trustgraph-base/trustgraph/base/subscriber_spec.py b/trustgraph-base/trustgraph/base/subscriber_spec.py index bf35f869..80f9b0d5 100644 --- a/trustgraph-base/trustgraph/base/subscriber_spec.py +++ b/trustgraph-base/trustgraph/base/subscriber_spec.py @@ -21,7 +21,7 @@ class SubscriberSpec(Spec): subscriber = Subscriber( backend = processor.pubsub, topic = definition["topics"][self.name], - subscription = flow.id, + subscription = flow.id + "--" + flow.workspace + "--" + flow.name, consumer_name = flow.id, schema = self.schema, metrics = subscriber_metrics, diff --git a/trustgraph-base/trustgraph/base/tool_service.py b/trustgraph-base/trustgraph/base/tool_service.py index 3ff977d1..eeaced6a 100644 --- 
a/trustgraph-base/trustgraph/base/tool_service.py +++ b/trustgraph-base/trustgraph/base/tool_service.py @@ -64,6 +64,7 @@ class ToolService(FlowProcessor): id = msg.properties()["id"] response = await self.invoke_tool( + flow.workspace, request.name, json.loads(request.parameters) if request.parameters else {}, ) diff --git a/trustgraph-base/trustgraph/base/tool_service_client.py b/trustgraph-base/trustgraph/base/tool_service_client.py index 81930ba0..db5946e9 100644 --- a/trustgraph-base/trustgraph/base/tool_service_client.py +++ b/trustgraph-base/trustgraph/base/tool_service_client.py @@ -11,12 +11,11 @@ logger = logging.getLogger(__name__) class ToolServiceClient(RequestResponse): """Client for invoking dynamically configured tool services.""" - async def call(self, user, config, arguments, timeout=600): + async def call(self, config, arguments, timeout=600): """ Call a tool service. Args: - user: User context for multi-tenancy config: Dict of config values (e.g., {"collection": "customers"}) arguments: Dict of arguments from LLM timeout: Request timeout in seconds @@ -26,7 +25,6 @@ class ToolServiceClient(RequestResponse): """ resp = await self.request( ToolServiceRequest( - user=user, config=json.dumps(config) if config else "{}", arguments=json.dumps(arguments) if arguments else "{}", ), @@ -38,12 +36,11 @@ class ToolServiceClient(RequestResponse): return resp.response - async def call_streaming(self, user, config, arguments, callback, timeout=600): + async def call_streaming(self, config, arguments, callback, timeout=600): """ Call a tool service with streaming response. 
Args: - user: User context for multi-tenancy config: Dict of config values arguments: Dict of arguments from LLM callback: Async function called with each response chunk @@ -66,7 +63,6 @@ class ToolServiceClient(RequestResponse): await self.request( ToolServiceRequest( - user=user, config=json.dumps(config) if config else "{}", arguments=json.dumps(arguments) if arguments else "{}", ), diff --git a/trustgraph-base/trustgraph/base/triples_client.py b/trustgraph-base/trustgraph/base/triples_client.py index a81a5cd0..2601a1e1 100644 --- a/trustgraph-base/trustgraph/base/triples_client.py +++ b/trustgraph-base/trustgraph/base/triples_client.py @@ -45,7 +45,7 @@ def from_value(x: Any) -> Any: class TriplesClient(RequestResponse): async def query(self, s=None, p=None, o=None, limit=20, - user="trustgraph", collection="default", + collection="default", timeout=30, g=None): resp = await self.request( @@ -54,7 +54,6 @@ class TriplesClient(RequestResponse): p = from_value(p), o = from_value(o), limit = limit, - user = user, collection = collection, g = g, ), @@ -72,7 +71,7 @@ class TriplesClient(RequestResponse): return triples async def query_stream(self, s=None, p=None, o=None, limit=20, - user="trustgraph", collection="default", + collection="default", batch_size=20, timeout=30, batch_callback=None, g=None): """ @@ -81,7 +80,6 @@ class TriplesClient(RequestResponse): Args: s, p, o: Triple pattern (None for wildcard) limit: Maximum total triples to return - user: User/keyspace collection: Collection name batch_size: Triples per batch timeout: Request timeout in seconds @@ -116,7 +114,6 @@ class TriplesClient(RequestResponse): p=from_value(p), o=from_value(o), limit=limit, - user=user, collection=collection, streaming=True, batch_size=batch_size, diff --git a/trustgraph-base/trustgraph/base/triples_query_service.py b/trustgraph-base/trustgraph/base/triples_query_service.py index 832ff6f1..5850307c 100644 --- a/trustgraph-base/trustgraph/base/triples_query_service.py +++ 
b/trustgraph-base/trustgraph/base/triples_query_service.py @@ -58,9 +58,13 @@ class TriplesQueryService(FlowProcessor): logger.debug(f"Handling triples query request {id}...") + workspace = flow.workspace + if request.streaming: # Streaming mode: send batches - async for batch, is_final in self.query_triples_stream(request): + async for batch, is_final in self.query_triples_stream( + workspace, request, + ): r = TriplesQueryResponse( triples=batch, error=None, @@ -70,7 +74,7 @@ class TriplesQueryService(FlowProcessor): logger.debug("Triples query streaming completed") else: # Non-streaming mode: single response - triples = await self.query_triples(request) + triples = await self.query_triples(workspace, request) logger.debug("Sending triples query response...") r = TriplesQueryResponse(triples=triples, error=None) await flow("response").send(r, properties={"id": id}) @@ -92,13 +96,13 @@ class TriplesQueryService(FlowProcessor): await flow("response").send(r, properties={"id": id}) - async def query_triples_stream(self, request): + async def query_triples_stream(self, workspace, request): """ Streaming query - yields (batch, is_final) tuples. Default implementation batches results from query_triples. Override for true streaming from backend. """ - triples = await self.query_triples(request) + triples = await self.query_triples(workspace, request) batch_size = request.batch_size if request.batch_size > 0 else 20 for i in range(0, len(triples), batch_size): diff --git a/trustgraph-base/trustgraph/base/triples_store_service.py b/trustgraph-base/trustgraph/base/triples_store_service.py index abd3aab8..7c44fe29 100644 --- a/trustgraph-base/trustgraph/base/triples_store_service.py +++ b/trustgraph-base/trustgraph/base/triples_store_service.py @@ -45,7 +45,10 @@ class TriplesStoreService(FlowProcessor): request = msg.value() - await self.store_triples(request) + # Workspace is derived from the flow the message arrived on, + # not from fields in the message payload. 
Topic routing is + # the isolation boundary. + await self.store_triples(flow.workspace, request) except TooManyRequests as e: raise e diff --git a/trustgraph-base/trustgraph/clients/config_client.py b/trustgraph-base/trustgraph/clients/config_client.py index 78b62688..25c1af94 100644 --- a/trustgraph-base/trustgraph/clients/config_client.py +++ b/trustgraph-base/trustgraph/clients/config_client.py @@ -33,6 +33,7 @@ class ConfigClient(BaseClient): subscriber=None, input_queue=None, output_queue=None, + workspace="default", **pubsub_config, ): @@ -51,10 +52,13 @@ class ConfigClient(BaseClient): **pubsub_config, ) + self.workspace = workspace + def get(self, keys, timeout=300): resp = self.call( operation="get", + workspace=self.workspace, keys=[ ConfigKey( type = k["type"], @@ -78,6 +82,7 @@ class ConfigClient(BaseClient): resp = self.call( operation="list", + workspace=self.workspace, type=type, timeout=timeout ) @@ -88,6 +93,7 @@ class ConfigClient(BaseClient): resp = self.call( operation="getvalues", + workspace=self.workspace, type=type, timeout=timeout ) @@ -101,10 +107,31 @@ class ConfigClient(BaseClient): for v in resp.values ] + def getvalues_all_ws(self, type, timeout=300): + """Fetch all values of a given type across all workspaces. 
+ Returns a list of dicts including a 'workspace' field.""" + + resp = self.call( + operation="getvalues-all-ws", + type=type, + timeout=timeout + ) + + return [ + { + "workspace": v.workspace, + "type": v.type, + "key": v.key, + "value": v.value, + } + for v in resp.values + ] + def delete(self, keys, timeout=300): resp = self.call( operation="delete", + workspace=self.workspace, keys=[ ConfigKey( type = k["type"], @@ -121,6 +148,7 @@ class ConfigClient(BaseClient): resp = self.call( operation="put", + workspace=self.workspace, values=[ ConfigValue( type = v["type"], @@ -138,6 +166,7 @@ class ConfigClient(BaseClient): resp = self.call( operation="config", + workspace=self.workspace, timeout=timeout ) diff --git a/trustgraph-base/trustgraph/clients/document_embeddings_client.py b/trustgraph-base/trustgraph/clients/document_embeddings_client.py index ebbad397..ad20206c 100644 --- a/trustgraph-base/trustgraph/clients/document_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/document_embeddings_client.py @@ -35,11 +35,11 @@ class DocumentEmbeddingsClient(BaseClient): ) def request( - self, vector, user="trustgraph", collection="default", + self, vector, collection="default", limit=10, timeout=300 ): return self.call( - user=user, collection=collection, + collection=collection, vector=vector, limit=limit, timeout=timeout ).chunks diff --git a/trustgraph-base/trustgraph/clients/document_rag_client.py b/trustgraph-base/trustgraph/clients/document_rag_client.py index 365ea09d..e8deaafd 100644 --- a/trustgraph-base/trustgraph/clients/document_rag_client.py +++ b/trustgraph-base/trustgraph/clients/document_rag_client.py @@ -33,14 +33,13 @@ class DocumentRagClient(BaseClient): output_schema=DocumentRagResponse, ) - def request(self, query, user="trustgraph", collection="default", + def request(self, query, collection="default", chunk_callback=None, explain_callback=None, timeout=300): """ Request a document RAG query with optional streaming callbacks. 
Args: query: The question to ask - user: User identifier collection: Collection identifier chunk_callback: Optional callback(text, end_of_stream) for text chunks explain_callback: Optional callback(explain_id, explain_graph, explain_triples) for explain notifications @@ -71,7 +70,7 @@ class DocumentRagClient(BaseClient): return False # Continue receiving self.call( - query=query, user=user, collection=collection, + query=query, collection=collection, inspect=inspect, timeout=timeout ) diff --git a/trustgraph-base/trustgraph/clients/graph_embeddings_client.py b/trustgraph-base/trustgraph/clients/graph_embeddings_client.py index 62a55609..9b38a11b 100644 --- a/trustgraph-base/trustgraph/clients/graph_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/graph_embeddings_client.py @@ -35,11 +35,11 @@ class GraphEmbeddingsClient(BaseClient): ) def request( - self, vector, user="trustgraph", collection="default", + self, vector, collection="default", limit=10, timeout=300 ): return self.call( - user=user, collection=collection, + collection=collection, vector=vector, limit=limit, timeout=timeout ).entities diff --git a/trustgraph-base/trustgraph/clients/graph_rag_client.py b/trustgraph-base/trustgraph/clients/graph_rag_client.py index 0d33bf91..f1d2374e 100644 --- a/trustgraph-base/trustgraph/clients/graph_rag_client.py +++ b/trustgraph-base/trustgraph/clients/graph_rag_client.py @@ -34,7 +34,7 @@ class GraphRagClient(BaseClient): ) def request( - self, query, user="trustgraph", collection="default", + self, query, collection="default", chunk_callback=None, explain_callback=None, timeout=500 @@ -44,7 +44,6 @@ class GraphRagClient(BaseClient): Args: query: The question to ask - user: User identifier collection: Collection identifier chunk_callback: Optional callback(text, end_of_stream) for text chunks explain_callback: Optional callback(explain_id, explain_graph, explain_triples) for explain notifications @@ -76,7 +75,7 @@ class GraphRagClient(BaseClient): 
return False # Continue receiving self.call( - user=user, collection=collection, query=query, + collection=collection, query=query, inspect=inspect, timeout=timeout ) diff --git a/trustgraph-base/trustgraph/clients/row_embeddings_client.py b/trustgraph-base/trustgraph/clients/row_embeddings_client.py index 6e10de29..c2329f9d 100644 --- a/trustgraph-base/trustgraph/clients/row_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/row_embeddings_client.py @@ -35,11 +35,11 @@ class RowEmbeddingsClient(BaseClient): ) def request( - self, vector, schema_name, user="trustgraph", collection="default", + self, vector, schema_name, collection="default", index_name=None, limit=10, timeout=300 ): kwargs = dict( - user=user, collection=collection, + collection=collection, vector=vector, schema_name=schema_name, limit=limit, timeout=timeout ) diff --git a/trustgraph-base/trustgraph/clients/triples_query_client.py b/trustgraph-base/trustgraph/clients/triples_query_client.py index 403d02ea..864f4442 100644 --- a/trustgraph-base/trustgraph/clients/triples_query_client.py +++ b/trustgraph-base/trustgraph/clients/triples_query_client.py @@ -45,16 +45,15 @@ class TriplesQueryClient(BaseClient): return Term(type=LITERAL, value=ent) def request( - self, + self, s, p, o, - user="trustgraph", collection="default", + collection="default", limit=10, timeout=120, ): return self.call( s=self.create_value(s), p=self.create_value(p), o=self.create_value(o), - user=user, collection=collection, limit=limit, timeout=timeout, diff --git a/trustgraph-base/trustgraph/messaging/translators/agent.py b/trustgraph-base/trustgraph/messaging/translators/agent.py index 7df59907..d1e13e33 100644 --- a/trustgraph-base/trustgraph/messaging/translators/agent.py +++ b/trustgraph-base/trustgraph/messaging/translators/agent.py @@ -13,7 +13,6 @@ class AgentRequestTranslator(MessageTranslator): state=data.get("state", None), group=data.get("group", None), history=data.get("history", []), - 
user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), streaming=data.get("streaming", False), session_id=data.get("session_id", ""), @@ -33,7 +32,6 @@ class AgentRequestTranslator(MessageTranslator): "state": obj.state, "group": obj.group, "history": obj.history, - "user": obj.user, "collection": getattr(obj, "collection", "default"), "streaming": getattr(obj, "streaming", False), "session_id": getattr(obj, "session_id", ""), diff --git a/trustgraph-base/trustgraph/messaging/translators/collection.py b/trustgraph-base/trustgraph/messaging/translators/collection.py index 2e39e8c2..cd07bc99 100644 --- a/trustgraph-base/trustgraph/messaging/translators/collection.py +++ b/trustgraph-base/trustgraph/messaging/translators/collection.py @@ -9,7 +9,7 @@ class CollectionManagementRequestTranslator(MessageTranslator): def decode(self, data: Dict[str, Any]) -> CollectionManagementRequest: return CollectionManagementRequest( operation=data.get("operation"), - user=data.get("user"), + workspace=data.get("workspace", ""), collection=data.get("collection"), timestamp=data.get("timestamp"), name=data.get("name"), @@ -24,8 +24,8 @@ class CollectionManagementRequestTranslator(MessageTranslator): if obj.operation is not None: result["operation"] = obj.operation - if obj.user is not None: - result["user"] = obj.user + if obj.workspace: + result["workspace"] = obj.workspace if obj.collection is not None: result["collection"] = obj.collection if obj.timestamp is not None: @@ -63,7 +63,6 @@ class CollectionManagementResponseTranslator(MessageTranslator): if "collections" in data: for coll_data in data["collections"]: collections.append(CollectionMetadata( - user=coll_data.get("user"), collection=coll_data.get("collection"), name=coll_data.get("name"), description=coll_data.get("description"), @@ -91,7 +90,6 @@ class CollectionManagementResponseTranslator(MessageTranslator): result["collections"] = [] for coll in obj.collections: 
result["collections"].append({ - "user": coll.user, "collection": coll.collection, "name": coll.name, "description": coll.description, diff --git a/trustgraph-base/trustgraph/messaging/translators/config.py b/trustgraph-base/trustgraph/messaging/translators/config.py index e166362a..223db6c8 100644 --- a/trustgraph-base/trustgraph/messaging/translators/config.py +++ b/trustgraph-base/trustgraph/messaging/translators/config.py @@ -23,13 +23,15 @@ class ConfigRequestTranslator(MessageTranslator): ConfigValue( type=v["type"], key=v["key"], - value=v["value"] + value=v["value"], + workspace=v.get("workspace", ""), ) for v in data["values"] ] return ConfigRequest( operation=data.get("operation"), + workspace=data.get("workspace", ""), keys=keys, type=data.get("type"), values=values @@ -37,10 +39,13 @@ class ConfigRequestTranslator(MessageTranslator): def encode(self, obj: ConfigRequest) -> Dict[str, Any]: result = {} - + if obj.operation is not None: result["operation"] = obj.operation + if obj.workspace is not None: + result["workspace"] = obj.workspace + if obj.type is not None: result["type"] = obj.type @@ -56,13 +61,14 @@ class ConfigRequestTranslator(MessageTranslator): if obj.values is not None: result["values"] = [ { + **({"workspace": v.workspace} if v.workspace else {}), "type": v.type, "key": v.key, - "value": v.value + "value": v.value, } for v in obj.values ] - + return result @@ -81,13 +87,14 @@ class ConfigResponseTranslator(MessageTranslator): if obj.values is not None: result["values"] = [ { + **({"workspace": v.workspace} if v.workspace else {}), "type": v.type, "key": v.key, - "value": v.value + "value": v.value, } for v in obj.values ] - + if obj.directory is not None: result["directory"] = obj.directory diff --git a/trustgraph-base/trustgraph/messaging/translators/document_loading.py b/trustgraph-base/trustgraph/messaging/translators/document_loading.py index df2aa3ba..61917321 100644 --- 
a/trustgraph-base/trustgraph/messaging/translators/document_loading.py +++ b/trustgraph-base/trustgraph/messaging/translators/document_loading.py @@ -39,7 +39,6 @@ class DocumentTranslator(SendTranslator): metadata=Metadata( id=data.get("id"), root=data.get("root", ""), - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), ), data=base64.b64encode(doc).decode("utf-8") @@ -56,8 +55,6 @@ class DocumentTranslator(SendTranslator): metadata_dict["id"] = obj.metadata.id if obj.metadata.root: metadata_dict["root"] = obj.metadata.root - if obj.metadata.user: - metadata_dict["user"] = obj.metadata.user if obj.metadata.collection: metadata_dict["collection"] = obj.metadata.collection @@ -79,7 +76,6 @@ class TextDocumentTranslator(SendTranslator): metadata=Metadata( id=data.get("id"), root=data.get("root", ""), - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), ), text=text.encode("utf-8") @@ -96,8 +92,6 @@ class TextDocumentTranslator(SendTranslator): metadata_dict["id"] = obj.metadata.id if obj.metadata.root: metadata_dict["root"] = obj.metadata.root - if obj.metadata.user: - metadata_dict["user"] = obj.metadata.user if obj.metadata.collection: metadata_dict["collection"] = obj.metadata.collection @@ -115,7 +109,6 @@ class ChunkTranslator(SendTranslator): metadata=Metadata( id=data.get("id"), root=data.get("root", ""), - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), ), chunk=data["chunk"].encode("utf-8") if isinstance(data["chunk"], str) else data["chunk"] @@ -132,8 +125,6 @@ class ChunkTranslator(SendTranslator): metadata_dict["id"] = obj.metadata.id if obj.metadata.root: metadata_dict["root"] = obj.metadata.root - if obj.metadata.user: - metadata_dict["user"] = obj.metadata.user if obj.metadata.collection: metadata_dict["collection"] = obj.metadata.collection @@ -161,7 +152,6 @@ class DocumentEmbeddingsTranslator(SendTranslator): metadata=Metadata( id=metadata.get("id"), 
root=metadata.get("root", ""), - user=metadata.get("user", "trustgraph"), collection=metadata.get("collection", "default"), ), chunks=chunks @@ -184,8 +174,6 @@ class DocumentEmbeddingsTranslator(SendTranslator): metadata_dict["id"] = obj.metadata.id if obj.metadata.root: metadata_dict["root"] = obj.metadata.root - if obj.metadata.user: - metadata_dict["user"] = obj.metadata.user if obj.metadata.collection: metadata_dict["collection"] = obj.metadata.collection diff --git a/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py b/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py index fce1625e..c435ba48 100644 --- a/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py @@ -15,7 +15,6 @@ class DocumentEmbeddingsRequestTranslator(MessageTranslator): return DocumentEmbeddingsRequest( vector=data["vector"], limit=int(data.get("limit", 10)), - user=data.get("user", "trustgraph"), collection=data.get("collection", "default") ) @@ -23,7 +22,6 @@ class DocumentEmbeddingsRequestTranslator(MessageTranslator): return { "vector": obj.vector, "limit": obj.limit, - "user": obj.user, "collection": obj.collection } @@ -60,7 +58,6 @@ class GraphEmbeddingsRequestTranslator(MessageTranslator): return GraphEmbeddingsRequest( vector=data["vector"], limit=int(data.get("limit", 10)), - user=data.get("user", "trustgraph"), collection=data.get("collection", "default") ) @@ -68,7 +65,6 @@ class GraphEmbeddingsRequestTranslator(MessageTranslator): return { "vector": obj.vector, "limit": obj.limit, - "user": obj.user, "collection": obj.collection } @@ -108,7 +104,6 @@ class RowEmbeddingsRequestTranslator(MessageTranslator): return RowEmbeddingsRequest( vector=data["vector"], limit=int(data.get("limit", 10)), - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), schema_name=data.get("schema_name", ""), index_name=data.get("index_name") @@ 
-118,7 +113,6 @@ class RowEmbeddingsRequestTranslator(MessageTranslator): result = { "vector": obj.vector, "limit": obj.limit, - "user": obj.user, "collection": obj.collection, "schema_name": obj.schema_name, } diff --git a/trustgraph-base/trustgraph/messaging/translators/flow.py b/trustgraph-base/trustgraph/messaging/translators/flow.py index 2047475e..07304c18 100644 --- a/trustgraph-base/trustgraph/messaging/translators/flow.py +++ b/trustgraph-base/trustgraph/messaging/translators/flow.py @@ -9,18 +9,21 @@ class FlowRequestTranslator(MessageTranslator): def decode(self, data: Dict[str, Any]) -> FlowRequest: return FlowRequest( operation=data.get("operation"), + workspace=data.get("workspace", ""), blueprint_name=data.get("blueprint-name"), blueprint_definition=data.get("blueprint-definition"), description=data.get("description"), flow_id=data.get("flow-id"), parameters=data.get("parameters") ) - + def encode(self, obj: FlowRequest) -> Dict[str, Any]: result = {} if obj.operation is not None: result["operation"] = obj.operation + if obj.workspace is not None: + result["workspace"] = obj.workspace if obj.blueprint_name is not None: result["blueprint-name"] = obj.blueprint_name if obj.blueprint_definition is not None: diff --git a/trustgraph-base/trustgraph/messaging/translators/knowledge.py b/trustgraph-base/trustgraph/messaging/translators/knowledge.py index f819dc9c..83cdbbf4 100644 --- a/trustgraph-base/trustgraph/messaging/translators/knowledge.py +++ b/trustgraph-base/trustgraph/messaging/translators/knowledge.py @@ -21,7 +21,6 @@ class KnowledgeRequestTranslator(MessageTranslator): metadata=Metadata( id=data["triples"]["metadata"]["id"], root=data["triples"]["metadata"].get("root", ""), - user=data["triples"]["metadata"]["user"], collection=data["triples"]["metadata"]["collection"] ), triples=self.subgraph_translator.decode(data["triples"]["triples"]), @@ -33,7 +32,6 @@ class KnowledgeRequestTranslator(MessageTranslator): metadata=Metadata( 
id=data["graph-embeddings"]["metadata"]["id"], root=data["graph-embeddings"]["metadata"].get("root", ""), - user=data["graph-embeddings"]["metadata"]["user"], collection=data["graph-embeddings"]["metadata"]["collection"] ), entities=[ @@ -47,7 +45,7 @@ class KnowledgeRequestTranslator(MessageTranslator): return KnowledgeRequest( operation=data.get("operation"), - user=data.get("user"), + workspace=data.get("workspace", ""), id=data.get("id"), flow=data.get("flow"), collection=data.get("collection"), @@ -60,8 +58,8 @@ class KnowledgeRequestTranslator(MessageTranslator): if obj.operation: result["operation"] = obj.operation - if obj.user: - result["user"] = obj.user + if obj.workspace: + result["workspace"] = obj.workspace if obj.id: result["id"] = obj.id if obj.flow: @@ -74,7 +72,6 @@ class KnowledgeRequestTranslator(MessageTranslator): "metadata": { "id": obj.triples.metadata.id, "root": obj.triples.metadata.root, - "user": obj.triples.metadata.user, "collection": obj.triples.metadata.collection, }, "triples": self.subgraph_translator.encode(obj.triples.triples), @@ -85,7 +82,6 @@ class KnowledgeRequestTranslator(MessageTranslator): "metadata": { "id": obj.graph_embeddings.metadata.id, "root": obj.graph_embeddings.metadata.root, - "user": obj.graph_embeddings.metadata.user, "collection": obj.graph_embeddings.metadata.collection, }, "entities": [ @@ -122,7 +118,6 @@ class KnowledgeResponseTranslator(MessageTranslator): "metadata": { "id": obj.triples.metadata.id, "root": obj.triples.metadata.root, - "user": obj.triples.metadata.user, "collection": obj.triples.metadata.collection, }, "triples": self.subgraph_translator.encode(obj.triples.triples), @@ -136,7 +131,6 @@ class KnowledgeResponseTranslator(MessageTranslator): "metadata": { "id": obj.graph_embeddings.metadata.id, "root": obj.graph_embeddings.metadata.root, - "user": obj.graph_embeddings.metadata.user, "collection": obj.graph_embeddings.metadata.collection, }, "entities": [ diff --git 
a/trustgraph-base/trustgraph/messaging/translators/library.py b/trustgraph-base/trustgraph/messaging/translators/library.py index 7c77c39c..d528097e 100644 --- a/trustgraph-base/trustgraph/messaging/translators/library.py +++ b/trustgraph-base/trustgraph/messaging/translators/library.py @@ -49,7 +49,7 @@ class LibraryRequestTranslator(MessageTranslator): document_metadata=doc_metadata, processing_metadata=proc_metadata, content=content, - user=data.get("user", ""), + workspace=data.get("workspace", ""), collection=data.get("collection", ""), criteria=criteria, # Chunked upload fields @@ -76,8 +76,8 @@ class LibraryRequestTranslator(MessageTranslator): result["processing-metadata"] = self.proc_metadata_translator.encode(obj.processing_metadata) if obj.content: result["content"] = obj.content.decode("utf-8") if isinstance(obj.content, bytes) else obj.content - if obj.user: - result["user"] = obj.user + if obj.workspace: + result["workspace"] = obj.workspace if obj.collection: result["collection"] = obj.collection if obj.criteria is not None: diff --git a/trustgraph-base/trustgraph/messaging/translators/metadata.py b/trustgraph-base/trustgraph/messaging/translators/metadata.py index 3e141c19..9da5d5c0 100644 --- a/trustgraph-base/trustgraph/messaging/translators/metadata.py +++ b/trustgraph-base/trustgraph/messaging/translators/metadata.py @@ -19,7 +19,7 @@ class DocumentMetadataTranslator(Translator): title=data.get("title"), comments=data.get("comments"), metadata=self.subgraph_translator.decode(metadata) if metadata is not None else [], - user=data.get("user"), + workspace=data.get("workspace"), tags=data.get("tags"), parent_id=data.get("parent-id", ""), document_type=data.get("document-type", "source"), @@ -40,8 +40,8 @@ class DocumentMetadataTranslator(Translator): result["comments"] = obj.comments if obj.metadata is not None: result["metadata"] = self.subgraph_translator.encode(obj.metadata) - if obj.user: - result["user"] = obj.user + if obj.workspace: + 
result["workspace"] = obj.workspace if obj.tags is not None: result["tags"] = obj.tags if obj.parent_id: @@ -61,7 +61,7 @@ class ProcessingMetadataTranslator(Translator): document_id=data.get("document-id"), time=data.get("time"), flow=data.get("flow"), - user=data.get("user"), + workspace=data.get("workspace"), collection=data.get("collection"), tags=data.get("tags") ) @@ -77,8 +77,8 @@ class ProcessingMetadataTranslator(Translator): result["time"] = obj.time if obj.flow: result["flow"] = obj.flow - if obj.user: - result["user"] = obj.user + if obj.workspace: + result["workspace"] = obj.workspace if obj.collection: result["collection"] = obj.collection if obj.tags is not None: diff --git a/trustgraph-base/trustgraph/messaging/translators/retrieval.py b/trustgraph-base/trustgraph/messaging/translators/retrieval.py index e37b76e1..fe766522 100644 --- a/trustgraph-base/trustgraph/messaging/translators/retrieval.py +++ b/trustgraph-base/trustgraph/messaging/translators/retrieval.py @@ -10,7 +10,6 @@ class DocumentRagRequestTranslator(MessageTranslator): def decode(self, data: Dict[str, Any]) -> DocumentRagQuery: return DocumentRagQuery( query=data["query"], - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), doc_limit=int(data.get("doc-limit", 20)), streaming=data.get("streaming", False) @@ -19,7 +18,6 @@ class DocumentRagRequestTranslator(MessageTranslator): def encode(self, obj: DocumentRagQuery) -> Dict[str, Any]: return { "query": obj.query, - "user": obj.user, "collection": obj.collection, "doc-limit": obj.doc_limit, "streaming": getattr(obj, "streaming", False) @@ -96,7 +94,6 @@ class GraphRagRequestTranslator(MessageTranslator): def decode(self, data: Dict[str, Any]) -> GraphRagQuery: return GraphRagQuery( query=data["query"], - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), entity_limit=int(data.get("entity-limit", 50)), triple_limit=int(data.get("triple-limit", 30)), @@ -110,7 +107,6 @@ 
class GraphRagRequestTranslator(MessageTranslator): def encode(self, obj: GraphRagQuery) -> Dict[str, Any]: return { "query": obj.query, - "user": obj.user, "collection": obj.collection, "entity-limit": obj.entity_limit, "triple-limit": obj.triple_limit, diff --git a/trustgraph-base/trustgraph/messaging/translators/rows_query.py b/trustgraph-base/trustgraph/messaging/translators/rows_query.py index 6153901c..3d3f682f 100644 --- a/trustgraph-base/trustgraph/messaging/translators/rows_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/rows_query.py @@ -9,7 +9,6 @@ class RowsQueryRequestTranslator(MessageTranslator): def decode(self, data: Dict[str, Any]) -> RowsQueryRequest: return RowsQueryRequest( - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), query=data.get("query", ""), variables=data.get("variables", {}), @@ -18,7 +17,6 @@ class RowsQueryRequestTranslator(MessageTranslator): def encode(self, obj: RowsQueryRequest) -> Dict[str, Any]: result = { - "user": obj.user, "collection": obj.collection, "query": obj.query, "variables": dict(obj.variables) if obj.variables else {} diff --git a/trustgraph-base/trustgraph/messaging/translators/sparql_query.py b/trustgraph-base/trustgraph/messaging/translators/sparql_query.py index a8b13865..e69d998a 100644 --- a/trustgraph-base/trustgraph/messaging/translators/sparql_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/sparql_query.py @@ -12,7 +12,6 @@ class SparqlQueryRequestTranslator(MessageTranslator): def decode(self, data: Dict[str, Any]) -> SparqlQueryRequest: return SparqlQueryRequest( - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), query=data.get("query", ""), limit=int(data.get("limit", 10000)), @@ -22,7 +21,6 @@ class SparqlQueryRequestTranslator(MessageTranslator): def encode(self, obj: SparqlQueryRequest) -> Dict[str, Any]: return { - "user": obj.user, "collection": obj.collection, "query": obj.query, "limit": 
obj.limit, diff --git a/trustgraph-base/trustgraph/messaging/translators/structured_query.py b/trustgraph-base/trustgraph/messaging/translators/structured_query.py index 6b0b38a1..bb76f3e7 100644 --- a/trustgraph-base/trustgraph/messaging/translators/structured_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/structured_query.py @@ -10,14 +10,12 @@ class StructuredQueryRequestTranslator(MessageTranslator): def decode(self, data: Dict[str, Any]) -> StructuredQueryRequest: return StructuredQueryRequest( question=data.get("question", ""), - user=data.get("user", "trustgraph"), # Default fallback - collection=data.get("collection", "default") # Default fallback + collection=data.get("collection", "default") ) - + def encode(self, obj: StructuredQueryRequest) -> Dict[str, Any]: return { "question": obj.question, - "user": obj.user, "collection": obj.collection } diff --git a/trustgraph-base/trustgraph/messaging/translators/triples.py b/trustgraph-base/trustgraph/messaging/translators/triples.py index 21d2698f..7a48ff15 100644 --- a/trustgraph-base/trustgraph/messaging/translators/triples.py +++ b/trustgraph-base/trustgraph/messaging/translators/triples.py @@ -22,16 +22,14 @@ class TriplesQueryRequestTranslator(MessageTranslator): o=o, g=g, limit=int(data.get("limit", 10000)), - user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), streaming=data.get("streaming", False), batch_size=int(data.get("batch-size", 20)), ) - + def encode(self, obj: TriplesQueryRequest) -> Dict[str, Any]: result = { "limit": obj.limit, - "user": obj.user, "collection": obj.collection, "streaming": obj.streaming, "batch-size": obj.batch_size, diff --git a/trustgraph-base/trustgraph/schema/core/metadata.py b/trustgraph-base/trustgraph/schema/core/metadata.py index a37a8d62..a307db4f 100644 --- a/trustgraph-base/trustgraph/schema/core/metadata.py +++ b/trustgraph-base/trustgraph/schema/core/metadata.py @@ -8,6 +8,7 @@ class Metadata: # Root document 
identifier (set by librarian, preserved through pipeline) root: str = "" - # Collection management - user: str = "" + # Collection the message belongs to. Workspace is NOT carried on the + # message — consumers derive it from flow.workspace (the flow the + # message arrived on), which is the trusted isolation boundary. collection: str = "" diff --git a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py index 0c4a9f7c..37969566 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py +++ b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py @@ -17,7 +17,7 @@ from .embeddings import GraphEmbeddings # <- (error) # list-kg-cores -# -> (user) +# -> (workspace) # <- () # <- (error) @@ -27,8 +27,8 @@ class KnowledgeRequest: # load-kg-core, unload-kg-core operation: str = "" - # list-kg-cores, delete-kg-core, put-kg-core - user: str = "" + # Workspace the cores belong to. Partition / isolation boundary. + workspace: str = "" # get-kg-core, list-kg-cores, delete-kg-core, put-kg-core, # load-kg-core, unload-kg-core diff --git a/trustgraph-base/trustgraph/schema/services/agent.py b/trustgraph-base/trustgraph/schema/services/agent.py index cd4a2b45..50ac1dd1 100644 --- a/trustgraph-base/trustgraph/schema/services/agent.py +++ b/trustgraph-base/trustgraph/schema/services/agent.py @@ -22,7 +22,6 @@ class AgentStep: action: str = "" arguments: dict[str, str] = field(default_factory=dict) observation: str = "" - user: str = "" # User context for the step step_type: str = "" # "react", "plan", "execute", "decompose", "synthesise" plan: list[PlanStep] = field(default_factory=list) # Plan steps (for plan-then-execute) subagent_results: dict[str, str] = field(default_factory=dict) # Subagent results keyed by goal @@ -33,7 +32,6 @@ class AgentRequest: state: str = "" group: list[str] | None = None history: list[AgentStep] = field(default_factory=list) - user: str = "" # User context for multi-tenancy 
collection: str = "default" # Collection for provenance traces streaming: bool = False # Enable streaming response delivery (default false) session_id: str = "" # For provenance tracking across iterations diff --git a/trustgraph-base/trustgraph/schema/services/collection.py b/trustgraph-base/trustgraph/schema/services/collection.py index f4b5fc6e..13dd0607 100644 --- a/trustgraph-base/trustgraph/schema/services/collection.py +++ b/trustgraph-base/trustgraph/schema/services/collection.py @@ -13,7 +13,6 @@ from ..core.topic import queue @dataclass class CollectionMetadata: """Collection metadata record""" - user: str = "" collection: str = "" name: str = "" description: str = "" @@ -23,11 +22,17 @@ class CollectionMetadata: @dataclass class CollectionManagementRequest: - """Request for collection management operations""" + """Request for collection management operations. + + Collection-management is a global (non-flow-scoped) service, so the + workspace has to travel on the wire — it's the isolation boundary + for which workspace's collections the request operates on. + """ operation: str = "" # e.g., "delete-collection" - # For 'list-collections' - user: str = "" + # Workspace the collection belongs to. 
+ workspace: str = "" + collection: str = "" timestamp: str = "" # ISO timestamp name: str = "" diff --git a/trustgraph-base/trustgraph/schema/services/config.py b/trustgraph-base/trustgraph/schema/services/config.py index c08e96d7..3bcbc72c 100644 --- a/trustgraph-base/trustgraph/schema/services/config.py +++ b/trustgraph-base/trustgraph/schema/services/config.py @@ -7,12 +7,19 @@ from ..core.primitives import Error ############################################################################ # Config service: -# get(keys) -> (version, values) -# list(type) -> (version, values) -# getvalues(type) -> (version, values) -# put(values) -> () -# delete(keys) -> () -# config() -> (version, config) +# get(workspace, keys) -> (version, values) +# list(workspace, type) -> (version, directory) +# getvalues(workspace, type) -> (version, values) +# getvalues-all-ws(type) -> (version, values with workspace field) +# put(workspace, values) -> () +# delete(workspace, keys) -> () +# config(workspace) -> (version, config) +# +# Most operations are scoped to a workspace. The workspace field on the +# request identifies which workspace's config to read or modify. +# getvalues-all-ws returns values across all workspaces for a single +# type — used by shared processors to load type-scoped config at startup. + @dataclass class ConfigKey: type: str = "" @@ -23,16 +30,24 @@ class ConfigValue: type: str = "" key: str = "" value: str = "" + # Populated by getvalues-all-ws responses so callers can identify + # which workspace each value belongs to. Empty otherwise. + workspace: str = "" -# Prompt services, abstract the prompt generation @dataclass class ConfigRequest: - operation: str = "" # get, list, getvalues, delete, put, config + # Operations: get, list, getvalues, getvalues-all-ws, delete, put, + # config + operation: str = "" + + # Workspace scope — required on all operations except + # getvalues-all-ws which spans all workspaces. 
+ workspace: str = "" # get, delete keys: list[ConfigKey] = field(default_factory=list) - # list, getvalues + # list, getvalues, getvalues-all-ws type: str = "" # put @@ -58,7 +73,12 @@ class ConfigResponse: @dataclass class ConfigPush: version: int = 0 - types: list[str] = field(default_factory=list) + + # Dict of config type -> list of affected workspaces. + # Handlers look up their registered type and get the list of + # workspaces that need refreshing. + # e.g. {"prompt": ["workspace-a", "workspace-b"], "schema": ["workspace-a"]} + changes: dict[str, list[str]] = field(default_factory=dict) config_request_queue = queue('config', cls='request') config_response_queue = queue('config', cls='response') diff --git a/trustgraph-base/trustgraph/schema/services/flow.py b/trustgraph-base/trustgraph/schema/services/flow.py index 0d497dd7..586c160d 100644 --- a/trustgraph-base/trustgraph/schema/services/flow.py +++ b/trustgraph-base/trustgraph/schema/services/flow.py @@ -17,12 +17,14 @@ from ..core.primitives import Error # start_flow(flowid, blueprintname) -> () # stop_flow(flowid) -> () -# Prompt services, abstract the prompt generation @dataclass class FlowRequest: operation: str = "" # list-blueprints, get-blueprint, put-blueprint, delete-blueprint # list-flows, get-flow, start-flow, stop-flow + # Workspace scope — all operations act within this workspace + workspace: str = "" + # get_blueprint, put_blueprint, delete_blueprint, start_flow blueprint_name: str = "" diff --git a/trustgraph-base/trustgraph/schema/services/library.py b/trustgraph-base/trustgraph/schema/services/library.py index f5d4592c..961b47dc 100644 --- a/trustgraph-base/trustgraph/schema/services/library.py +++ b/trustgraph-base/trustgraph/schema/services/library.py @@ -43,12 +43,12 @@ from ..core.metadata import Metadata # <- (error) # list-documents -# -> (user, collection?) +# -> (workspace, collection?) # <- (document_metadata[]) # <- (error) # list-processing -# -> (user, collection?) 
+# -> (workspace, collection?) # <- (processing_metadata[]) # <- (error) @@ -78,7 +78,7 @@ from ..core.metadata import Metadata # <- (error) # list-uploads -# -> (user) +# -> (workspace) # <- (uploads[]) # <- (error) @@ -90,7 +90,7 @@ class DocumentMetadata: title: str = "" comments: str = "" metadata: list[Triple] = field(default_factory=list) - user: str = "" + workspace: str = "" tags: list[str] = field(default_factory=list) # Child document support parent_id: str = "" # Empty for top-level docs, set for children @@ -107,7 +107,7 @@ class ProcessingMetadata: document_id: str = "" time: int = 0 flow: str = "" - user: str = "" + workspace: str = "" collection: str = "" tags: list[str] = field(default_factory=list) @@ -162,8 +162,8 @@ class LibrarianRequest: # add-document, upload-chunk content: bytes = b"" - # list-documents, list-processing, list-uploads - user: str = "" + # Workspace scopes every library operation. + workspace: str = "" # list-documents?, list-processing? collection: str = "" diff --git a/trustgraph-base/trustgraph/schema/services/query.py b/trustgraph-base/trustgraph/schema/services/query.py index f9f08658..9c11a157 100644 --- a/trustgraph-base/trustgraph/schema/services/query.py +++ b/trustgraph-base/trustgraph/schema/services/query.py @@ -11,7 +11,6 @@ from ..core.topic import queue class GraphEmbeddingsRequest: vector: list[float] = field(default_factory=list) limit: int = 0 - user: str = "" collection: str = "" @dataclass @@ -31,7 +30,6 @@ class GraphEmbeddingsResponse: @dataclass class TriplesQueryRequest: - user: str = "" collection: str = "" s: Term | None = None p: Term | None = None @@ -55,7 +53,6 @@ class TriplesQueryResponse: class DocumentEmbeddingsRequest: vector: list[float] = field(default_factory=list) limit: int = 0 - user: str = "" collection: str = "" @dataclass @@ -89,7 +86,6 @@ class RowEmbeddingsRequest: """Request for row embeddings semantic search""" vector: list[float] = field(default_factory=list) # Query vector limit: 
int = 10 # Max results to return - user: str = "" # User/keyspace collection: str = "" # Collection name schema_name: str = "" # Schema name to search within index_name: str | None = None # Optional: filter to specific index diff --git a/trustgraph-base/trustgraph/schema/services/retrieval.py b/trustgraph-base/trustgraph/schema/services/retrieval.py index a1af9170..e937e720 100644 --- a/trustgraph-base/trustgraph/schema/services/retrieval.py +++ b/trustgraph-base/trustgraph/schema/services/retrieval.py @@ -8,7 +8,6 @@ from ..core.primitives import Error, Term, Triple @dataclass class GraphRagQuery: query: str = "" - user: str = "" collection: str = "" entity_limit: int = 0 triple_limit: int = 0 @@ -40,7 +39,6 @@ class GraphRagResponse: @dataclass class DocumentRagQuery: query: str = "" - user: str = "" collection: str = "" doc_limit: int = 0 streaming: bool = False diff --git a/trustgraph-base/trustgraph/schema/services/rows_query.py b/trustgraph-base/trustgraph/schema/services/rows_query.py index e3c4f14c..ea0759f1 100644 --- a/trustgraph-base/trustgraph/schema/services/rows_query.py +++ b/trustgraph-base/trustgraph/schema/services/rows_query.py @@ -15,7 +15,6 @@ class GraphQLError: @dataclass class RowsQueryRequest: - user: str = "" # Cassandra keyspace (follows pattern from TriplesQueryRequest) collection: str = "" # Data collection identifier (required for partition key) query: str = "" # GraphQL query string variables: dict[str, str] = field(default_factory=dict) # GraphQL variables diff --git a/trustgraph-base/trustgraph/schema/services/sparql_query.py b/trustgraph-base/trustgraph/schema/services/sparql_query.py index 62c02c93..a5ed502f 100644 --- a/trustgraph-base/trustgraph/schema/services/sparql_query.py +++ b/trustgraph-base/trustgraph/schema/services/sparql_query.py @@ -16,7 +16,6 @@ class SparqlBinding: @dataclass class SparqlQueryRequest: - user: str = "" collection: str = "" query: str = "" # SPARQL query string limit: int = 10000 # Safety limit on 
results diff --git a/trustgraph-base/trustgraph/schema/services/structured_query.py b/trustgraph-base/trustgraph/schema/services/structured_query.py index 5f54ac16..272643ac 100644 --- a/trustgraph-base/trustgraph/schema/services/structured_query.py +++ b/trustgraph-base/trustgraph/schema/services/structured_query.py @@ -9,7 +9,6 @@ from ..core.primitives import Error @dataclass class StructuredQueryRequest: question: str = "" - user: str = "" # Cassandra keyspace identifier collection: str = "" # Data collection identifier @dataclass diff --git a/trustgraph-base/trustgraph/schema/services/tool_service.py b/trustgraph-base/trustgraph/schema/services/tool_service.py index 18315f29..a42fd5e3 100644 --- a/trustgraph-base/trustgraph/schema/services/tool_service.py +++ b/trustgraph-base/trustgraph/schema/services/tool_service.py @@ -7,8 +7,6 @@ from ..core.primitives import Error @dataclass class ToolServiceRequest: """Request to a dynamically configured tool service.""" - # User context for multi-tenancy - user: str = "" # Config values (collection, etc.) 
as JSON config: str = "" # Arguments from LLM as JSON diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index a60b2bba..a5738449 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -95,6 +95,8 @@ tg-list-config-items = "trustgraph.cli.list_config_items:main" tg-get-config-item = "trustgraph.cli.get_config_item:main" tg-put-config-item = "trustgraph.cli.put_config_item:main" tg-delete-config-item = "trustgraph.cli.delete_config_item:main" +tg-export-workspace-config = "trustgraph.cli.export_workspace_config:main" +tg-import-workspace-config = "trustgraph.cli.import_workspace_config:main" tg-list-collections = "trustgraph.cli.list_collections:main" tg-set-collection = "trustgraph.cli.set_collection:main" tg-delete-collection = "trustgraph.cli.delete_collection:main" diff --git a/trustgraph-cli/trustgraph/cli/add_library_document.py b/trustgraph-cli/trustgraph/cli/add_library_document.py index 3273e63d..8d08d11a 100644 --- a/trustgraph-cli/trustgraph/cli/add_library_document.py +++ b/trustgraph-cli/trustgraph/cli/add_library_document.py @@ -15,17 +15,17 @@ from trustgraph.knowledge import Organization, PublicationEvent from trustgraph.knowledge import DigitalDocument default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") class Loader: def __init__( - self, id, url, user, metadata, title, comments, kind, tags - ): + self, id, url, metadata, title, comments, kind, tags + , token=None, workspace="default"): - self.api = Api(url).library() + self.api = Api(url, token=token, workspace=workspace).library() - self.user = user self.metadata = metadata self.title = title self.comments = comments @@ -55,13 +55,13 @@ class Loader: else: id = hash(data) id = to_uri(PREF_DOC, id) - + self.metadata.id = id self.api.add_document( - document=data, id=id, 
metadata=self.metadata, - user=self.user, kind=self.kind, title=self.title, + document=data, id=id, metadata=self.metadata, + kind=self.kind, title=self.title, comments=self.comments, tags=self.tags ) @@ -83,11 +83,16 @@ def main(): default=default_url, help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -186,12 +191,13 @@ def main(): p = Loader( id=args.identifier, url=args.url, - user=args.user, metadata=document, title=args.name, comments=args.description, kind=args.kind, tags=args.tags, + token=args.token, + workspace=args.workspace, ) p.load(args.files) diff --git a/trustgraph-cli/trustgraph/cli/delete_collection.py b/trustgraph-cli/trustgraph/cli/delete_collection.py index 3e19ac09..aedd801a 100644 --- a/trustgraph-cli/trustgraph/cli/delete_collection.py +++ b/trustgraph-cli/trustgraph/cli/delete_collection.py @@ -7,9 +7,11 @@ import os from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = "trustgraph" +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def delete_collection(url, user, collection, confirm): + +def delete_collection(url, collection, confirm, token=None, workspace="default"): if not confirm: response = input(f"Are you sure you want to delete collection '{collection}' and all its data? 
(y/N): ") @@ -17,9 +19,9 @@ def delete_collection(url, user, collection, confirm): print("Operation cancelled.") return - api = Api(url).collection() + api = Api(url, token=token, workspace=workspace).collection() - api.delete_collection(user=user, collection=collection) + api.delete_collection(collection=collection) print(f"Collection '{collection}' deleted successfully.") @@ -41,27 +43,34 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-y', '--yes', action='store_true', help='Skip confirmation prompt' ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: delete_collection( url = args.api_url, - user = args.user, collection = args.collection, - confirm = args.yes + confirm = args.yes, + token = args.token, + workspace = args.workspace, ) except Exception as e: @@ -69,4 +78,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/delete_config_item.py b/trustgraph-cli/trustgraph/cli/delete_config_item.py index cf4cba93..801c2a99 100644 --- a/trustgraph-cli/trustgraph/cli/delete_config_item.py +++ b/trustgraph-cli/trustgraph/cli/delete_config_item.py @@ -9,10 +9,11 @@ from trustgraph.api.types import ConfigKey default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def delete_config_item(url, config_type, key, token=None): +def delete_config_item(url, config_type, key, token=None, workspace="default"): - api = Api(url, token=token).config() + 
api = Api(url, token=token, workspace=workspace).config() config_key = ConfigKey(type=config_type, key=key) api.delete([config_key]) @@ -50,6 +51,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -59,6 +66,8 @@ def main(): config_type=args.type, key=args.key, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/delete_flow_blueprint.py b/trustgraph-cli/trustgraph/cli/delete_flow_blueprint.py index 9ff8aeba..62140f0e 100644 --- a/trustgraph-cli/trustgraph/cli/delete_flow_blueprint.py +++ b/trustgraph-cli/trustgraph/cli/delete_flow_blueprint.py @@ -9,10 +9,13 @@ from trustgraph.api import Api import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def delete_flow_blueprint(url, blueprint_name): +def delete_flow_blueprint(url, blueprint_name, token=None, + workspace="default"): - api = Api(url).flow() + api = Api(url, token=token, workspace=workspace).flow() blueprint_names = api.delete_blueprint(blueprint_name) @@ -29,6 +32,18 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-n', '--blueprint-name', help=f'Flow blueprint name', @@ -41,6 +56,8 @@ def main(): delete_flow_blueprint( url=args.api_url, blueprint_name=args.blueprint_name, + token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/delete_kg_core.py 
b/trustgraph-cli/trustgraph/cli/delete_kg_core.py index 81f95e45..0e0753e0 100644 --- a/trustgraph-cli/trustgraph/cli/delete_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/delete_kg_core.py @@ -1,20 +1,20 @@ """ -Deletes a flow class +Deletes a knowledge core """ import argparse import os -import tabulate from trustgraph.api import Api -import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def delete_kg_core(url, user, id): +def delete_kg_core(url, id, token=None, workspace="default"): - api = Api(url).knowledge() + api = Api(url, token=token, workspace=workspace).knowledge() - class_names = api.delete_kg_core(user = user, id = id) + api.delete_kg_core(id=id) def main(): @@ -29,26 +29,33 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default="trustgraph", - help='API URL (default: trustgraph)', - ) - parser.add_argument( '--id', '--identifier', required=True, help=f'Knowledge core ID', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: delete_kg_core( url=args.api_url, - user=args.user, id=args.id, + token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/delete_mcp_tool.py b/trustgraph-cli/trustgraph/cli/delete_mcp_tool.py index a3ae7e77..eed9ed21 100644 --- a/trustgraph-cli/trustgraph/cli/delete_mcp_tool.py +++ b/trustgraph-cli/trustgraph/cli/delete_mcp_tool.py @@ -10,12 +10,16 @@ import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = 
os.getenv("TRUSTGRAPH_WORKSPACE", "default") def delete_mcp_tool( url : str, id : str, + token=None, + workspace="default", ): - api = Api(url).config() + api = Api(url, token=token, workspace=workspace).config() # Check if the tool exists first try: @@ -73,6 +77,18 @@ def main(): help='MCP tool ID to delete', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -81,8 +97,10 @@ def main(): raise RuntimeError("Must specify --id for MCP tool to delete") delete_mcp_tool( - url=args.api_url, - id=args.id + url=args.api_url, + id=args.id, + token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/delete_tool.py b/trustgraph-cli/trustgraph/cli/delete_tool.py index 961c9aa8..50f43fdd 100644 --- a/trustgraph-cli/trustgraph/cli/delete_tool.py +++ b/trustgraph-cli/trustgraph/cli/delete_tool.py @@ -12,12 +12,16 @@ import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def delete_tool( url : str, id : str, + token=None, + workspace="default", ): - api = Api(url).config() + api = Api(url, token=token, workspace=workspace).config() # Check if the tool configuration exists try: @@ -78,6 +82,18 @@ def main(): help='Tool ID to delete', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -86,8 +102,10 @@ def main(): raise RuntimeError("Must specify --id for tool to delete") delete_tool( - 
url=args.api_url, - id=args.id + url=args.api_url, + id=args.id, + token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/export_workspace_config.py b/trustgraph-cli/trustgraph/cli/export_workspace_config.py new file mode 100644 index 00000000..feef97de --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/export_workspace_config.py @@ -0,0 +1,114 @@ +""" +Exports a curated subset of a workspace's configuration to a JSON file +for later reload into another workspace (useful for cloning test setups). + +The subset covers the config types that define workspace behaviour: +mcp-tool, tool, flow-blueprint, token-cost, agent-pattern, +agent-task-type, parameter-type, interface-description, prompt. +""" + +import argparse +import os +import json +import sys +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") + +EXPORT_TYPES = [ + "mcp-tool", + "tool", + "flow-blueprint", + "token-cost", + "agent-pattern", + "agent-task-type", + "parameter-type", + "interface-description", + "prompt", +] + + +def export_workspace_config(url, workspace, output, token=None): + + api = Api(url, token=token, workspace=workspace).config() + + config, version = api.all() + + subset = {} + for t in EXPORT_TYPES: + if t in config: + subset[t] = config[t] + + payload = { + "source_workspace": workspace, + "source_version": version, + "config": subset, + } + + if output == "-": + json.dump(payload, sys.stdout, indent=2) + sys.stdout.write("\n") + else: + with open(output, "w") as f: + json.dump(payload, f, indent=2) + + total = sum(len(v) for v in subset.values()) + print( + f"Exported {total} items across {len(subset)} types " + f"from workspace '{workspace}' (version {version}).", + file=sys.stderr, + ) + + +def main(): + + parser = argparse.ArgumentParser( + 
prog='tg-export-workspace-config', + description=__doc__, + ) + + parser.add_argument( + '-u', '--api-url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Source workspace (default: {default_workspace})', + ) + + parser.add_argument( + '-o', '--output', + required=True, + help='Output JSON file path (use "-" for stdout)', + ) + + args = parser.parse_args() + + try: + + export_workspace_config( + url=args.api_url, + workspace=args.workspace, + output=args.output, + token=args.token, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/get_config_item.py b/trustgraph-cli/trustgraph/cli/get_config_item.py index c2421e94..028cc064 100644 --- a/trustgraph-cli/trustgraph/cli/get_config_item.py +++ b/trustgraph-cli/trustgraph/cli/get_config_item.py @@ -10,10 +10,12 @@ from trustgraph.api.types import ConfigKey default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def get_config_item(url, config_type, key, format_type, token=None): +def get_config_item(url, config_type, key, format_type, token=None, + workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() config_key = ConfigKey(type=config_type, key=key) values = api.get([config_key]) @@ -66,6 +68,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -76,6 +84,7 @@ def main(): 
key=args.key, format_type=args.format, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/get_document_content.py b/trustgraph-cli/trustgraph/cli/get_document_content.py index 3d70f37d..62fa7ca2 100644 --- a/trustgraph-cli/trustgraph/cli/get_document_content.py +++ b/trustgraph-cli/trustgraph/cli/get_document_content.py @@ -9,21 +9,19 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = "trustgraph" +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def get_content(url, user, document_id, output_file, token=None): +def get_content(url, document_id, output_file, token=None, workspace="default"): - api = Api(url, token=token).library() + api = Api(url, token=token, workspace=workspace).library() - content = api.get_document_content(user=user, id=document_id) + content = api.get_document_content(id=document_id) if output_file: with open(output_file, 'wb') as f: f.write(content) print(f"Written {len(content)} bytes to {output_file}") else: - # Write to stdout - # Try to decode as text, fall back to binary info try: text = content.decode('utf-8') print(text) @@ -51,9 +49,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -73,10 +71,10 @@ def main(): get_content( url=args.api_url, - user=args.user, document_id=args.document_id, output_file=args.output, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/get_flow_blueprint.py b/trustgraph-cli/trustgraph/cli/get_flow_blueprint.py index 817b8f47..56d43a7c 100644 --- a/trustgraph-cli/trustgraph/cli/get_flow_blueprint.py +++ 
b/trustgraph-cli/trustgraph/cli/get_flow_blueprint.py @@ -9,10 +9,12 @@ from trustgraph.api import Api import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def get_flow_blueprint(url, blueprint_name): +def get_flow_blueprint(url, blueprint_name, token=None, workspace="default"): - api = Api(url).flow() + api = Api(url, token=token, workspace=workspace).flow() cls = api.get_blueprint(blueprint_name) @@ -31,6 +33,18 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-n', '--blueprint-name', required=True, @@ -44,6 +58,8 @@ def main(): get_flow_blueprint( url=args.api_url, blueprint_name=args.blueprint_name, + token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/get_kg_core.py b/trustgraph-cli/trustgraph/cli/get_kg_core.py index b75f7155..8bee4115 100644 --- a/trustgraph-cli/trustgraph/cli/get_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/get_kg_core.py @@ -5,7 +5,6 @@ to a local file in msgpack format. 
import argparse import os -import textwrap import uuid import asyncio import json @@ -13,17 +12,16 @@ from websockets.asyncio.client import connect import msgpack default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') -default_user = 'trustgraph' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def write_triple(f, data): msg = ( "t", { "m": { - "i": data["metadata"]["id"], + "i": data["metadata"]["id"], "m": data["metadata"]["metadata"], - "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, "t": data["triples"], @@ -36,9 +34,8 @@ def write_ge(f, data): "ge", { "m": { - "i": data["metadata"]["id"], + "i": data["metadata"]["id"], "m": data["metadata"]["metadata"], - "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, "e": [ @@ -52,7 +49,7 @@ def write_ge(f, data): ) f.write(msgpack.packb(msg, use_bin_type=True)) -async def fetch(url, user, id, output, token=None): +async def fetch(url, workspace, id, output, token=None): if not url.endswith("/"): url += "/" @@ -68,10 +65,11 @@ async def fetch(url, user, id, output, token=None): req = json.dumps({ "id": mid, + "workspace": workspace, "service": "knowledge", "request": { "operation": "get-kg-core", - "user": user, + "workspace": workspace, "id": id, } }) @@ -124,10 +122,11 @@ def main(): default=default_url, help=f'API URL (default: {default_url})', ) + parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -154,11 +153,11 @@ def main(): asyncio.run( fetch( - url = args.url, - user = args.user, - id = args.id, - output = args.output, - token = args.token, + url=args.url, + workspace=args.workspace, + id=args.id, + output=args.output, + token=args.token, ) ) @@ -167,4 +166,4 @@ def main(): print("Exception:", e, flush=True) 
if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/graph_to_turtle.py b/trustgraph-cli/trustgraph/cli/graph_to_turtle.py index 840f8574..4d4a94b3 100644 --- a/trustgraph-cli/trustgraph/cli/graph_to_turtle.py +++ b/trustgraph-cli/trustgraph/cli/graph_to_turtle.py @@ -13,9 +13,9 @@ import os from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' default_collection = 'default' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def term_to_rdflib(term): @@ -58,9 +58,10 @@ def term_to_rdflib(term): return rdflib.term.Literal(str(term)) -def show_graph(url, flow_id, user, collection, limit, batch_size, token=None): +def show_graph(url, flow_id, collection, limit, batch_size, + token=None, workspace="default"): - socket = Api(url, token=token).socket() + socket = Api(url, token=token, workspace=workspace).socket() flow = socket.flow(flow_id) g = rdflib.Graph() @@ -68,7 +69,7 @@ def show_graph(url, flow_id, user, collection, limit, batch_size, token=None): try: for batch in flow.triples_query_stream( s=None, p=None, o=None, - user=user, collection=collection, + collection=collection, limit=limit, batch_size=batch_size, ): @@ -108,12 +109,6 @@ def main(): help=f'Flow ID (default: default)' ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -126,6 +121,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-l', '--limit', type=int, @@ -147,11 +148,11 @@ def main(): show_graph( url = args.api_url, flow_id = args.flow_id, - user = args.user, collection = 
args.collection, limit = args.limit, batch_size = args.batch_size, token = args.token, + workspace = args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/import_workspace_config.py b/trustgraph-cli/trustgraph/cli/import_workspace_config.py new file mode 100644 index 00000000..3fe3be97 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/import_workspace_config.py @@ -0,0 +1,143 @@ +""" +Imports a workspace-config dump produced by tg-export-workspace-config +into a target workspace. Writes mcp-tool, tool, flow-blueprint, +token-cost, agent-pattern, agent-task-type, parameter-type, +interface-description and prompt items verbatim. + +Existing items with the same (type, key) are overwritten. +""" + +import argparse +import os +import json +import sys +from trustgraph.api import Api +from trustgraph.api.types import ConfigValue + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") + +IMPORT_TYPES = { + "mcp-tool", + "tool", + "flow-blueprint", + "token-cost", + "agent-pattern", + "agent-task-type", + "parameter-type", + "interface-description", + "prompt", +} + + +def import_workspace_config(url, workspace, input_path, token=None, + dry_run=False): + + if input_path == "-": + payload = json.load(sys.stdin) + else: + with open(input_path, "r") as f: + payload = json.load(f) + + # Accept both the wrapped export format and a bare {type: {key: value}} + # dict, so hand-written files are also loadable. 
+ if isinstance(payload, dict) and "config" in payload \ + and isinstance(payload["config"], dict): + config = payload["config"] + source = payload.get("source_workspace") + else: + config = payload + source = None + + skipped_types = set(config.keys()) - IMPORT_TYPES + if skipped_types: + print( + f"Ignoring unsupported types: {sorted(skipped_types)}", + file=sys.stderr, + ) + + values = [] + for t in IMPORT_TYPES: + items = config.get(t, {}) + for key, value in items.items(): + values.append(ConfigValue(type=t, key=key, value=value)) + + if not values: + print("Nothing to import.", file=sys.stderr) + return + + if dry_run: + print( + f"[dry-run] would import {len(values)} items into " + f"workspace '{workspace}'" + + (f" (from '{source}')" if source else "") + ) + return + + api = Api(url, token=token, workspace=workspace).config() + api.put(values) + + print( + f"Imported {len(values)} items into workspace '{workspace}'" + + (f" (from '{source}')." if source else "."), + ) + + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-import-workspace-config', + description=__doc__, + ) + + parser.add_argument( + '-u', '--api-url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Target workspace (default: {default_workspace})', + ) + + parser.add_argument( + '-i', '--input', + required=True, + help='Input JSON file path (use "-" for stdin)', + ) + + parser.add_argument( + '--dry-run', + action='store_true', + help='Parse and validate the input without writing anything', + ) + + args = parser.parse_args() + + try: + + import_workspace_config( + url=args.api_url, + workspace=args.workspace, + input_path=args.input, + token=args.token, + dry_run=args.dry_run, + ) + + except Exception as e: + + print("Exception:", e, 
flush=True) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/init_trustgraph.py b/trustgraph-cli/trustgraph/cli/init_trustgraph.py index 18c240ef..d984f925 100644 --- a/trustgraph-cli/trustgraph/cli/init_trustgraph.py +++ b/trustgraph-cli/trustgraph/cli/init_trustgraph.py @@ -69,10 +69,11 @@ def ensure_namespace(url, tenant, namespace, config): print(f"Namespace {tenant}/{namespace} created.", flush=True) -def ensure_config(config, **pubsub_config): +def ensure_config(config, workspace="default", **pubsub_config): cli = ConfigClient( subscriber=subscriber, + workspace=workspace, **pubsub_config, ) @@ -147,7 +148,8 @@ def init_pulsar(pulsar_admin_url, tenant): }) -def push_config(config_json, config_file, **pubsub_config): +def push_config(config_json, config_file, workspace="default", + **pubsub_config): """Push initial config if provided.""" if config_json is not None: @@ -160,7 +162,7 @@ def push_config(config_json, config_file, **pubsub_config): print("Exception:", e, flush=True) raise e - ensure_config(dec, **pubsub_config) + ensure_config(dec, workspace=workspace, **pubsub_config) elif config_file is not None: @@ -172,7 +174,7 @@ def push_config(config_json, config_file, **pubsub_config): print("Exception:", e, flush=True) raise e - ensure_config(dec, **pubsub_config) + ensure_config(dec, workspace=workspace, **pubsub_config) else: print("No config to update.", flush=True) @@ -207,6 +209,12 @@ def main(): help=f'Tenant (default: tg)', ) + parser.add_argument( + '-w', '--workspace', + default="default", + help=f'Workspace (default: default)', + ) + add_pubsub_args(parser) args = parser.parse_args() @@ -216,7 +224,10 @@ def main(): # Extract pubsub config from args pubsub_config = { k: v for k, v in vars(args).items() - if k not in ('pulsar_admin_url', 'config', 'config_file', 'tenant') + if k not in ( + 'pulsar_admin_url', 'config', 'config_file', 'tenant', + 'workspace', + ) } while True: @@ -241,6 +252,7 @@ def 
main(): # Push config (works with any backend) push_config( args.config, args.config_file, + workspace=args.workspace, **pubsub_config, ) diff --git a/trustgraph-cli/trustgraph/cli/invoke_agent.py b/trustgraph-cli/trustgraph/cli/invoke_agent.py index b379c2df..d815aacd 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_agent.py +++ b/trustgraph-cli/trustgraph/cli/invoke_agent.py @@ -26,7 +26,7 @@ from trustgraph.api import ( default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' class Outputter: @@ -115,11 +115,12 @@ def output(text, prefix="> ", width=78): print(out) def question_explainable( - url, question_text, flow_id, user, collection, - state=None, group=None, verbose=False, token=None, debug=False + url, question_text, flow_id, collection, + state=None, group=None, verbose=False, token=None, debug=False, + workspace="default", ): """Execute agent with explainability - shows provenance events inline.""" - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10) @@ -132,7 +133,6 @@ def question_explainable( # Stream agent with explainability - process events as they arrive for item in flow.agent_explain( question=question_text, - user=user, collection=collection, state=state, group=group, @@ -191,7 +191,6 @@ def question_explainable( entity = explain_client.fetch_entity( prov_id, graph=explain_graph, - user=user, collection=collection ) @@ -269,11 +268,11 @@ def question_explainable( def question( - url, question, flow_id, user, collection, + url, question, flow_id, collection, plan=None, state=None, group=None, pattern=None, verbose=False, streaming=True, token=None, explainable=False, debug=False, - show_usage=False 
+ show_usage=False, workspace="default", ): # Explainable mode uses the API to capture and process provenance events if explainable: @@ -281,13 +280,13 @@ def question( url=url, question_text=question, flow_id=flow_id, - user=user, collection=collection, state=state, group=group, verbose=verbose, token=token, - debug=debug + debug=debug, + workspace=workspace, ) return @@ -296,14 +295,13 @@ def question( print() # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) # Prepare request parameters request_params = { "question": question, - "user": user, "streaming": streaming, } @@ -418,6 +416,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -430,12 +434,6 @@ def main(): help=f'Question to answer', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -502,7 +500,6 @@ def main(): url = args.url, flow_id = args.flow_id, question = args.question, - user = args.user, collection = args.collection, plan = args.plan, state = args.state, @@ -514,6 +511,7 @@ def main(): explainable = args.explainable, debug = args.debug, show_usage = args.show_usage, + workspace = args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py b/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py index 43bcc985..ed851dff 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py +++ b/trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py @@ -9,11 +9,12 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') 
default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def query(url, flow_id, query_text, user, collection, limit, token=None): +def query(url, flow_id, query_text, collection, limit, token=None, workspace="default"): # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) @@ -21,7 +22,6 @@ def query(url, flow_id, query_text, user, collection, limit, token=None): # Call document embeddings query service result = flow.document_embeddings_query( text=query_text, - user=user, collection=collection, limit=limit ) @@ -59,15 +59,15 @@ def main(): ) parser.add_argument( - '-f', '--flow-id', - default="default", - help=f'Flow ID (default: default)' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( - '-U', '--user', - default="trustgraph", - help='User/keyspace (default: trustgraph)', + '-f', '--flow-id', + default="default", + help=f'Flow ID (default: default)' ) parser.add_argument( @@ -97,10 +97,10 @@ def main(): url=args.url, flow_id=args.flow_id, query_text=args.query[0], - user=args.user, collection=args.collection, limit=args.limit, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py index d566f51d..01512ac8 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py @@ -18,16 +18,17 @@ from trustgraph.api import ( default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' default_doc_limit = 10 def question_explainable( - url, flow_id, 
question_text, user, collection, doc_limit, token=None, debug=False + url, flow_id, question_text, collection, doc_limit, token=None, debug=False, + workspace="default", ): """Execute document RAG with explainability - shows provenance events inline.""" - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10) @@ -36,8 +37,7 @@ def question_explainable( # Stream DocumentRAG with explainability - process events as they arrive for item in flow.document_rag_explain( query=question_text, - user=user, - collection=collection, + collection=collection, doc_limit=doc_limit, ): if isinstance(item, RAGChunk): @@ -54,8 +54,7 @@ def question_explainable( entity = explain_client.fetch_entity( prov_id, graph=explain_graph, - user=user, - collection=collection + collection=collection ) if entity is None: @@ -98,9 +97,9 @@ def question_explainable( def question( - url, flow_id, question_text, user, collection, doc_limit, + url, flow_id, question_text, collection, doc_limit, streaming=True, token=None, explainable=False, debug=False, - show_usage=False + show_usage=False, workspace="default", ): # Explainable mode uses the API to capture and process provenance events if explainable: @@ -108,16 +107,16 @@ def question( url=url, flow_id=flow_id, question_text=question_text, - user=user, - collection=collection, + collection=collection, doc_limit=doc_limit, token=token, - debug=debug + debug=debug, + workspace=workspace, ) return # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) if streaming: # Use socket client for streaming @@ -127,8 +126,7 @@ def question( try: response = flow.document_rag( query=question_text, - user=user, - collection=collection, + collection=collection, doc_limit=doc_limit, streaming=True ) @@ -155,8 +153,7 @@ def question( flow = 
api.flow().id(flow_id) result = flow.document_rag( query=question_text, - user=user, - collection=collection, + collection=collection, doc_limit=doc_limit, ) print(result.text) @@ -189,6 +186,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -201,12 +204,6 @@ def main(): help=f'Question to answer', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -252,7 +249,6 @@ def main(): url=args.url, flow_id=args.flow_id, question_text=args.question, - user=args.user, collection=args.collection, doc_limit=args.doc_limit, streaming=not args.no_streaming, @@ -260,6 +256,7 @@ def main(): explainable=args.explainable, debug=args.debug, show_usage=args.show_usage, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_embeddings.py b/trustgraph-cli/trustgraph/cli/invoke_embeddings.py index 699a85cf..62eaa039 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_embeddings.py +++ b/trustgraph-cli/trustgraph/cli/invoke_embeddings.py @@ -9,11 +9,12 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def query(url, flow_id, texts, token=None): +def query(url, flow_id, texts, token=None, workspace="default"): # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) @@ -51,6 +52,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + 
default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -72,6 +79,8 @@ def main(): flow_id=args.flow_id, texts=args.texts, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_embeddings.py b/trustgraph-cli/trustgraph/cli/invoke_graph_embeddings.py index 5b0f4c67..c7237c06 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_embeddings.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_embeddings.py @@ -9,11 +9,12 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def query(url, flow_id, query_text, user, collection, limit, token=None): +def query(url, flow_id, query_text, collection, limit, token=None, workspace="default"): # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) @@ -21,7 +22,6 @@ def query(url, flow_id, query_text, user, collection, limit, token=None): # Call graph embeddings query service result = flow.graph_embeddings_query( text=query_text, - user=user, collection=collection, limit=limit ) @@ -69,15 +69,15 @@ def main(): ) parser.add_argument( - '-f', '--flow-id', - default="default", - help=f'Flow ID (default: default)' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( - '-U', '--user', - default="trustgraph", - help='User/keyspace (default: trustgraph)', + '-f', '--flow-id', + default="default", + help=f'Flow ID (default: default)' ) parser.add_argument( @@ -107,10 +107,10 @@ def main(): url=args.url, flow_id=args.flow_id, query_text=args.query[0], - user=args.user, collection=args.collection, limit=args.limit, token=args.token, + 
workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py index c9efe54d..23d6bcac 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py @@ -22,7 +22,7 @@ from trustgraph.api import ( default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' default_entity_limit = 50 default_triple_limit = 30 @@ -108,7 +108,7 @@ def _format_provenance_details(event_type, triples): return lines -async def _query_triples_once(ws_url, flow_id, prov_id, user, collection, graph=None, debug=False): +async def _query_triples_once(ws_url, flow_id, prov_id, collection, graph=None, debug=False): """Query triples for a provenance node (single attempt)""" request = { "id": "triples-request", @@ -116,7 +116,6 @@ async def _query_triples_once(ws_url, flow_id, prov_id, user, collection, graph= "flow": flow_id, "request": { "s": {"t": "i", "i": prov_id}, - "user": user, "collection": collection, "limit": 100 } @@ -182,10 +181,10 @@ async def _query_triples_once(ws_url, flow_id, prov_id, user, collection, graph= return triples -async def _query_triples(ws_url, flow_id, prov_id, user, collection, graph=None, max_retries=5, retry_delay=0.2, debug=False): +async def _query_triples(ws_url, flow_id, prov_id, collection, graph=None, max_retries=5, retry_delay=0.2, debug=False): """Query triples for a provenance node with retries for race condition""" for attempt in range(max_retries): - triples = await _query_triples_once(ws_url, flow_id, prov_id, user, collection, graph=graph, debug=debug) + triples = await _query_triples_once(ws_url, flow_id, prov_id, collection, graph=graph, debug=debug) if triples: return triples # Wait before retry if empty 
(triples may not be stored yet) @@ -196,7 +195,7 @@ async def _query_triples(ws_url, flow_id, prov_id, user, collection, graph=None, return [] -async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, user, collection, debug=False): +async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, collection, debug=False): """ Query for provenance of an edge (s, p, o) in the knowledge graph. @@ -220,7 +219,6 @@ async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, user, "o": {"t": "i", "i": edge_o} if edge_o.startswith("http") or edge_o.startswith("urn:") else {"t": "l", "v": edge_o}, } }, - "user": user, "collection": collection, "limit": 10 } @@ -273,7 +271,6 @@ async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, user, "request": { "s": {"t": "i", "i": stmt_uri}, "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM}, - "user": user, "collection": collection, "limit": 10 } @@ -312,7 +309,7 @@ async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, user, return sources -async def _query_derived_from(ws_url, flow_id, uri, user, collection, debug=False): +async def _query_derived_from(ws_url, flow_id, uri, collection, debug=False): """Query for the prov:wasDerivedFrom parent of a URI. Returns None if no parent.""" request = { "id": "parent-request", @@ -321,7 +318,6 @@ async def _query_derived_from(ws_url, flow_id, uri, user, collection, debug=Fals "request": { "s": {"t": "i", "i": uri}, "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM}, - "user": user, "collection": collection, "limit": 1 } @@ -355,7 +351,7 @@ async def _query_derived_from(ws_url, flow_id, uri, user, collection, debug=Fals return None -async def _trace_provenance_chain(ws_url, flow_id, source_uri, user, collection, label_cache, debug=False): +async def _trace_provenance_chain(ws_url, flow_id, source_uri, collection, label_cache, debug=False): """ Trace the full provenance chain from a source URI up to the root document. 
Returns a list of (uri, label) tuples from leaf to root. @@ -369,11 +365,11 @@ async def _trace_provenance_chain(ws_url, flow_id, source_uri, user, collection, break # Get label for current entity - label = await _query_label(ws_url, flow_id, current, user, collection, label_cache, debug) + label = await _query_label(ws_url, flow_id, current, collection, label_cache, debug) chain.append((current, label)) # Get parent - parent = await _query_derived_from(ws_url, flow_id, current, user, collection, debug) + parent = await _query_derived_from(ws_url, flow_id, current, collection, debug) if not parent or parent == current: break current = parent @@ -401,7 +397,7 @@ def _is_iri(value): return value.startswith("http://") or value.startswith("https://") or value.startswith("urn:") -async def _query_label(ws_url, flow_id, iri, user, collection, label_cache, debug=False): +async def _query_label(ws_url, flow_id, iri, collection, label_cache, debug=False): """ Query for the rdfs:label of an IRI. Uses label_cache to avoid repeated queries. @@ -421,7 +417,6 @@ async def _query_label(ws_url, flow_id, iri, user, collection, label_cache, debu "request": { "s": {"t": "i", "i": iri}, "p": {"t": "i", "i": RDFS_LABEL}, - "user": user, "collection": collection, "limit": 1 } @@ -460,7 +455,7 @@ async def _query_label(ws_url, flow_id, iri, user, collection, label_cache, debu return label -async def _resolve_edge_labels(ws_url, flow_id, edge_triple, user, collection, label_cache, debug=False): +async def _resolve_edge_labels(ws_url, flow_id, edge_triple, collection, label_cache, debug=False): """ Resolve labels for all IRI components of an edge triple. Returns (s_label, p_label, o_label). 
@@ -469,15 +464,15 @@ async def _resolve_edge_labels(ws_url, flow_id, edge_triple, user, collection, l p = edge_triple.get("p", "?") o = edge_triple.get("o", "?") - s_label = await _query_label(ws_url, flow_id, s, user, collection, label_cache, debug) - p_label = await _query_label(ws_url, flow_id, p, user, collection, label_cache, debug) - o_label = await _query_label(ws_url, flow_id, o, user, collection, label_cache, debug) + s_label = await _query_label(ws_url, flow_id, s, collection, label_cache, debug) + p_label = await _query_label(ws_url, flow_id, p, collection, label_cache, debug) + o_label = await _query_label(ws_url, flow_id, o, collection, label_cache, debug) return s_label, p_label, o_label async def _question_explainable( - url, flow_id, question, user, collection, entity_limit, triple_limit, + url, flow_id, question, collection, entity_limit, triple_limit, max_subgraph_size, max_path_length, token=None, debug=False ): """Execute graph RAG with explainability - shows provenance events with details""" @@ -502,7 +497,6 @@ async def _question_explainable( "flow": flow_id, "request": { "query": question, - "user": user, "collection": collection, "entity-limit": entity_limit, "triple-limit": triple_limit, @@ -549,7 +543,7 @@ async def _question_explainable( # Query triples for this explain node (using named graph filter) triples = await _query_triples( - ws_url, flow_id, explain_id, user, collection, graph=explain_graph, debug=debug + ws_url, flow_id, explain_id, collection, graph=explain_graph, debug=debug ) # Format and display details @@ -564,7 +558,7 @@ async def _question_explainable( print(f" Seed entities: {len(entity_iris)}", file=sys.stderr) for iri in entity_iris: label = await _query_label( - ws_url, flow_id, iri, user, collection, + ws_url, flow_id, iri, collection, label_cache, debug=debug ) print(f" - {label}", file=sys.stderr) @@ -579,7 +573,7 @@ async def _question_explainable( print(f" [debug] querying edge selection: {o}", file=sys.stderr) 
# Query the edge selection entity (using named graph filter) edge_triples = await _query_triples( - ws_url, flow_id, o, user, collection, graph=explain_graph, debug=debug + ws_url, flow_id, o, collection, graph=explain_graph, debug=debug ) if debug: print(f" [debug] got {len(edge_triples)} edge triples", file=sys.stderr) @@ -597,7 +591,7 @@ async def _question_explainable( if edge_triple: # Resolve labels for edge components s_label, p_label, o_label = await _resolve_edge_labels( - ws_url, flow_id, edge_triple, user, collection, + ws_url, flow_id, edge_triple, collection, label_cache, debug=debug ) print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr) @@ -605,21 +599,21 @@ async def _question_explainable( r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning print(f" Reason: {r_short}", file=sys.stderr) - # Trace edge provenance in the user's collection (not explainability) + # Trace edge provenance in the workspace collection (not explainability) if edge_triple: sources = await _query_edge_provenance( ws_url, flow_id, edge_triple.get("s", ""), edge_triple.get("p", ""), edge_triple.get("o", ""), - user, collection, # Use the query collection, not explainability + collection, # Use the query collection, not explainability debug=debug ) if sources: for src in sources: # Trace full chain from source to root document chain = await _trace_provenance_chain( - ws_url, flow_id, src, user, collection, + ws_url, flow_id, src, collection, label_cache, debug=debug ) chain_str = _format_provenance_chain(chain) @@ -639,12 +633,12 @@ async def _question_explainable( def _question_explainable_api( - url, flow_id, question_text, user, collection, entity_limit, triple_limit, + url, flow_id, question_text, collection, entity_limit, triple_limit, max_subgraph_size, max_path_length, edge_score_limit=30, - edge_limit=25, token=None, debug=False + edge_limit=25, token=None, debug=False, workspace="default", ): """Execute graph RAG with explainability 
using the new API classes.""" - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) explain_client = ExplainabilityClient(flow, retry_delay=0.2, max_retries=10) @@ -653,8 +647,7 @@ def _question_explainable_api( # Stream GraphRAG with explainability - process events as they arrive for item in flow.graph_rag_explain( query=question_text, - user=user, - collection=collection, + collection=collection, entity_limit=entity_limit, triple_limit=triple_limit, max_subgraph_size=max_subgraph_size, @@ -676,8 +669,7 @@ def _question_explainable_api( entity = explain_client.fetch_entity( prov_id, graph=explain_graph, - user=user, - collection=collection + collection=collection ) if entity is None: @@ -707,7 +699,7 @@ def _question_explainable_api( if entity.entities: print(f" Seed entities: {len(entity.entities)}", file=sys.stderr) for ent in entity.entities: - label = explain_client.resolve_label(ent, user, collection) + label = explain_client.resolve_label(ent, collection) print(f" - {label}", file=sys.stderr) elif isinstance(entity, Focus): @@ -719,15 +711,14 @@ def _question_explainable_api( focus_full = explain_client.fetch_focus_with_edges( prov_id, graph=explain_graph, - user=user, - collection=collection + collection=collection ) if focus_full and focus_full.edge_selections: for edge_sel in focus_full.edge_selections: if edge_sel.edge: # Resolve labels for edge components s_label, p_label, o_label = explain_client.resolve_edge_labels( - edge_sel.edge, user, collection + edge_sel.edge, collection ) print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr) if edge_sel.reasoning: @@ -750,10 +741,11 @@ def _question_explainable_api( def question( - url, flow_id, question, user, collection, entity_limit, triple_limit, + url, flow_id, question, collection, entity_limit, triple_limit, max_subgraph_size, max_path_length, edge_score_limit=50, edge_limit=25, streaming=True, token=None, 
- explainable=False, debug=False, show_usage=False + explainable=False, debug=False, show_usage=False, + workspace="default", ): # Explainable mode uses the API to capture and process provenance events @@ -762,8 +754,7 @@ def question( url=url, flow_id=flow_id, question_text=question, - user=user, - collection=collection, + collection=collection, entity_limit=entity_limit, triple_limit=triple_limit, max_subgraph_size=max_subgraph_size, @@ -771,12 +762,13 @@ def question( edge_score_limit=edge_score_limit, edge_limit=edge_limit, token=token, - debug=debug + debug=debug, + workspace=workspace, ) return # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) if streaming: # Use socket client for streaming @@ -786,8 +778,7 @@ def question( try: response = flow.graph_rag( query=question, - user=user, - collection=collection, + collection=collection, entity_limit=entity_limit, triple_limit=triple_limit, max_subgraph_size=max_subgraph_size, @@ -819,8 +810,7 @@ def question( flow = api.flow().id(flow_id) result = flow.graph_rag( query=question, - user=user, - collection=collection, + collection=collection, entity_limit=entity_limit, triple_limit=triple_limit, max_subgraph_size=max_subgraph_size, @@ -857,6 +847,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -869,12 +865,6 @@ def main(): help=f'Question to answer', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -955,7 +945,6 @@ def main(): url=args.url, flow_id=args.flow_id, question=args.question, - user=args.user, collection=args.collection, entity_limit=args.entity_limit, triple_limit=args.triple_limit, @@ 
-968,6 +957,7 @@ def main(): explainable=args.explainable, debug=args.debug, show_usage=args.show_usage, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_llm.py b/trustgraph-cli/trustgraph/cli/invoke_llm.py index 3bf521f6..2006e9e8 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_llm.py +++ b/trustgraph-cli/trustgraph/cli/invoke_llm.py @@ -9,12 +9,13 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def query(url, flow_id, system, prompt, streaming=True, token=None, - show_usage=False): + show_usage=False, workspace="default"): # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) @@ -74,6 +75,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( 'system', nargs=1, @@ -116,6 +123,7 @@ def main(): streaming=not args.no_streaming, token=args.token, show_usage=args.show_usage, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_mcp_tool.py b/trustgraph-cli/trustgraph/cli/invoke_mcp_tool.py index c5700c5c..32c20768 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_mcp_tool.py +++ b/trustgraph-cli/trustgraph/cli/invoke_mcp_tool.py @@ -11,10 +11,12 @@ import json from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def query(url, flow_id, name, parameters): +def query(url, flow_id, name, parameters, token=None, workspace="default"): - api = 
Api(url).flow().id(flow_id) + api = Api(url, token=token, workspace=workspace).flow().id(flow_id) resp = api.mcp_tool(name=name, parameters=parameters) @@ -36,6 +38,18 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -68,6 +82,8 @@ def main(): flow_id = args.flow_id, name = args.name, parameters = parameters, + token = args.token, + workspace = args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_nlp_query.py b/trustgraph-cli/trustgraph/cli/invoke_nlp_query.py index 8b01187c..332531db 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_nlp_query.py +++ b/trustgraph-cli/trustgraph/cli/invoke_nlp_query.py @@ -10,9 +10,11 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -def nlp_query(url, flow_id, question, max_results, output_format='json'): +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") +def nlp_query(url, flow_id, question, max_results, output_format='json', token=None, workspace="default"): - api = Api(url).flow().id(flow_id) + api = Api(url, token=token, workspace=workspace).flow().id(flow_id) resp = api.nlp_query( question=question, @@ -63,6 +65,17 @@ def main(): default=default_url, help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) parser.add_argument( '-f', '--flow-id', @@ -100,6 +113,11 @@ def main(): 
question=args.question, max_results=args.max_results, output_format=args.format, + + token = args.token, + + workspace = args.workspace, + ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_prompt.py b/trustgraph-cli/trustgraph/cli/invoke_prompt.py index 86f7a024..ed47df90 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_prompt.py +++ b/trustgraph-cli/trustgraph/cli/invoke_prompt.py @@ -14,12 +14,13 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def query(url, flow_id, template_id, variables, streaming=True, token=None, - show_usage=False): + show_usage=False, workspace="default"): # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) @@ -80,6 +81,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -135,6 +142,7 @@ specified multiple times''', streaming=not args.no_streaming, token=args.token, show_usage=args.show_usage, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_row_embeddings.py b/trustgraph-cli/trustgraph/cli/invoke_row_embeddings.py index 7393b4c3..8244ae99 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_row_embeddings.py +++ b/trustgraph-cli/trustgraph/cli/invoke_row_embeddings.py @@ -9,11 +9,12 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def query(url, flow_id, query_text, schema_name, user, collection, 
index_name, limit, token=None): +def query(url, flow_id, query_text, schema_name, collection, index_name, limit, token=None, workspace="default"): # Create API client - api = Api(url=url, token=token) + api = Api(url=url, token=token, workspace=workspace) socket = api.socket() flow = socket.flow(flow_id) @@ -22,7 +23,6 @@ def query(url, flow_id, query_text, schema_name, user, collection, index_name, l result = flow.row_embeddings_query( text=query_text, schema_name=schema_name, - user=user, collection=collection, index_name=index_name, limit=limit @@ -60,15 +60,15 @@ def main(): ) parser.add_argument( - '-f', '--flow-id', - default="default", - help=f'Flow ID (default: default)' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( - '-U', '--user', - default="trustgraph", - help='User/keyspace (default: trustgraph)', + '-f', '--flow-id', + default="default", + help=f'Flow ID (default: default)' ) parser.add_argument( @@ -111,11 +111,11 @@ def main(): flow_id=args.flow_id, query_text=args.query[0], schema_name=args.schema_name, - user=args.user, collection=args.collection, index_name=args.index_name, limit=args.limit, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_rows_query.py b/trustgraph-cli/trustgraph/cli/invoke_rows_query.py index 962f353c..46fba4d7 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_rows_query.py +++ b/trustgraph-cli/trustgraph/cli/invoke_rows_query.py @@ -12,10 +12,11 @@ from trustgraph.api import Api from tabulate import tabulate default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' -def format_output(data, output_format): +def format_output(data, output_format, token=None, workspace="default"): """Format 
GraphQL response data in the specified format""" if not data: return "No data returned" @@ -82,10 +83,10 @@ def format_table_data(rows, table_name, output_format): return json.dumps({table_name: rows}, indent=2) def rows_query( - url, flow_id, query, user, collection, variables, operation_name, output_format='table' + url, flow_id, query, collection, variables, operation_name, output_format='table', token=None, workspace="default" ): - api = Api(url).flow().id(flow_id) + api = Api(url, token=token, workspace=workspace).flow().id(flow_id) # Parse variables if provided as JSON string parsed_variables = {} @@ -98,7 +99,6 @@ def rows_query( resp = api.rows_query( query=query, - user=user, collection=collection, variables=parsed_variables if parsed_variables else None, operation_name=operation_name @@ -135,6 +135,17 @@ def main(): default=default_url, help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) parser.add_argument( '-f', '--flow-id', @@ -148,12 +159,6 @@ def main(): help='GraphQL query to execute', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -185,11 +190,13 @@ def main(): url=args.url, flow_id=args.flow_id, query=args.query, - user=args.user, collection=args.collection, variables=args.variables, operation_name=args.operation_name, output_format=args.format, + token=args.token, + workspace=args.workspace, + ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py b/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py index 7b1ae9a6..26e03929 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py +++ 
b/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py @@ -9,7 +9,8 @@ import sys from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' @@ -44,10 +45,10 @@ def _term_str(val): return str(val) -def sparql_query(url, token, flow_id, query, user, collection, limit, - batch_size, output_format): +def sparql_query(url, token, flow_id, query, collection, limit, + batch_size, output_format, workspace="default"): - socket = Api(url=url, token=token).socket() + socket = Api(url=url, token=token, workspace=workspace).socket() flow = socket.flow(flow_id) variables = None @@ -57,7 +58,6 @@ def sparql_query(url, token, flow_id, query, user, collection, limit, for response in flow.sparql_query_stream( query=query, - user=user, collection=collection, limit=limit, batch_size=batch_size, @@ -154,8 +154,14 @@ def main(): parser.add_argument( '-t', '--token', - default=os.getenv("TRUSTGRAPH_TOKEN"), - help='API bearer token (default: TRUSTGRAPH_TOKEN env var)', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -174,12 +180,6 @@ def main(): help='Read SPARQL query from file (use - for stdin)', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})', - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -228,11 +228,11 @@ def main(): token=args.token, flow_id=args.flow_id, query=query, - user=args.user, collection=args.collection, limit=args.limit, batch_size=args.batch_size, output_format=args.format, + workspace=args.workspace, ) except Exception as e: diff --git 
a/trustgraph-cli/trustgraph/cli/invoke_structured_query.py b/trustgraph-cli/trustgraph/cli/invoke_structured_query.py index 9f5f8540..af2060bb 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_structured_query.py +++ b/trustgraph-cli/trustgraph/cli/invoke_structured_query.py @@ -13,7 +13,9 @@ from tabulate import tabulate default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -def format_output(data, output_format): +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") +def format_output(data, output_format, token=None, workspace="default"): """Format structured query response data in the specified format""" if not data: return "No data returned" @@ -79,11 +81,11 @@ def format_table_data(rows, table_name, output_format): else: return json.dumps({table_name: rows}, indent=2) -def structured_query(url, flow_id, question, user='trustgraph', collection='default', output_format='table'): +def structured_query(url, flow_id, question, collection='default', output_format='table', token=None, workspace="default"): - api = Api(url).flow().id(flow_id) + api = Api(url, token=token, workspace=workspace).flow().id(flow_id) - resp = api.structured_query(question=question, user=user, collection=collection) + resp = api.structured_query(question=question, collection=collection) # Check for errors if "error" in resp and resp["error"]: @@ -119,6 +121,17 @@ def main(): default=default_url, help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) parser.add_argument( '-f', '--flow-id', @@ -132,12 +145,6 @@ def main(): help='Natural language question to execute', ) - parser.add_argument( - '--user', - default='trustgraph', - help='Cassandra keyspace identifier 
(default: trustgraph)' - ) - parser.add_argument( '--collection', default='default', @@ -159,9 +166,12 @@ def main(): url=args.url, flow_id=args.flow_id, question=args.question, - user=args.user, collection=args.collection, output_format=args.format, + token=args.token, + + workspace = args.workspace, + ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/list_collections.py b/trustgraph-cli/trustgraph/cli/list_collections.py index 4086f471..e2f90f56 100644 --- a/trustgraph-cli/trustgraph/cli/list_collections.py +++ b/trustgraph-cli/trustgraph/cli/list_collections.py @@ -1,23 +1,22 @@ """ -List collections for a user +List collections in a workspace """ import argparse import os import tabulate from trustgraph.api import Api -import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = "trustgraph" +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def list_collections(url, user, tag_filter): +def list_collections(url, tag_filter, token=None, workspace="default"): - api = Api(url).collection() + api = Api(url, token=token, workspace=workspace).collection() - collections = api.list_collections(user=user, tag_filter=tag_filter) + collections = api.list_collections(tag_filter=tag_filter) - # Handle None or empty collections if not collections or len(collections) == 0: print("No collections found.") return @@ -54,26 +53,33 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-t', '--tag-filter', action='append', help='Filter by tags (can be specified multiple times)' ) + parser.add_argument( + '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: 
{default_workspace})', + ) + args = parser.parse_args() try: list_collections( url = args.api_url, - user = args.user, - tag_filter = args.tag_filter + tag_filter = args.tag_filter, + token = args.token, + workspace = args.workspace, ) except Exception as e: @@ -81,4 +87,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/list_config_items.py b/trustgraph-cli/trustgraph/cli/list_config_items.py index 5cd0f233..8bc3f683 100644 --- a/trustgraph-cli/trustgraph/cli/list_config_items.py +++ b/trustgraph-cli/trustgraph/cli/list_config_items.py @@ -9,10 +9,12 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def list_config_items(url, config_type, format_type, token=None): +def list_config_items(url, config_type, format_type, token=None, + workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() keys = api.list(config_type) @@ -54,6 +56,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -63,6 +71,7 @@ def main(): config_type=args.type, format_type=args.format, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/list_explain_traces.py b/trustgraph-cli/trustgraph/cli/list_explain_traces.py index e6d1e075..9bc87db6 100644 --- a/trustgraph-cli/trustgraph/cli/list_explain_traces.py +++ b/trustgraph-cli/trustgraph/cli/list_explain_traces.py @@ -18,7 +18,7 @@ from trustgraph.api import Api, ExplainabilityClient default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') 
default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' # Retrieval graph @@ -86,9 +86,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})', + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -120,7 +120,7 @@ def main(): args = parser.parse_args() try: - api = Api(args.api_url, token=args.token) + api = Api(args.api_url, token=args.token, workspace=args.workspace) socket = api.socket() flow = socket.flow(args.flow_id) explain_client = ExplainabilityClient(flow) @@ -129,7 +129,6 @@ def main(): # List all sessions — uses persistent websocket via SocketClient questions = explain_client.list_sessions( graph=RETRIEVAL_GRAPH, - user=args.user, collection=args.collection, limit=args.limit, ) @@ -141,7 +140,6 @@ def main(): session_type = explain_client.detect_session_type( q.uri, graph=RETRIEVAL_GRAPH, - user=args.user, collection=args.collection ) diff --git a/trustgraph-cli/trustgraph/cli/load_doc_embeds.py b/trustgraph-cli/trustgraph/cli/load_doc_embeds.py index 20c78515..a776c59b 100644 --- a/trustgraph-cli/trustgraph/cli/load_doc_embeds.py +++ b/trustgraph-cli/trustgraph/cli/load_doc_embeds.py @@ -46,7 +46,6 @@ async def load_de(running, queue, url): "metadata": { "id": msg["m"]["i"], "metadata": msg["m"]["m"], - "user": msg["m"]["u"], "collection": msg["m"]["c"], }, "chunks": [ @@ -77,7 +76,7 @@ async def stats(running): f"Graph embeddings: {de_counts:10d}" ) -async def loader(running, de_queue, path, format, user, collection): +async def loader(running, de_queue, path, format, collection): if format == "json": @@ -96,9 +95,6 @@ async def loader(running, de_queue, path, format, user, collection): except: break - if user: - unpacked["metadata"]["user"] = user - if collection: 
unpacked["metadata"]["collection"] = collection @@ -148,9 +144,9 @@ async def run(running, **args): running=running, de_queue=de_q, path=args["input_file"], format=args["format"], - user=args["user"], collection=args["collection"], + collection=args["collection"], ) - + ) de_task = asyncio.create_task( @@ -178,7 +174,6 @@ async def main(running): ) default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/") - default_user = "trustgraph" collection = "default" parser.add_argument( @@ -207,11 +202,6 @@ async def main(running): help=f'Output format (default: msgpack)', ) - parser.add_argument( - '--user', - help=f'User ID to load as (default: from input)' - ) - parser.add_argument( '--collection', help=f'Collection ID to load as (default: from input)' diff --git a/trustgraph-cli/trustgraph/cli/load_kg_core.py b/trustgraph-cli/trustgraph/cli/load_kg_core.py index 008b124f..281255be 100644 --- a/trustgraph-cli/trustgraph/cli/load_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/load_kg_core.py @@ -6,20 +6,19 @@ run this utility. 
import argparse import os -import tabulate from trustgraph.api import Api -import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_flow = "default" default_collection = "default" -def load_kg_core(url, user, id, flow, collection): +def load_kg_core(url, id, flow, collection, token=None, workspace="default"): - api = Api(url).knowledge() + api = Api(url, token=token, workspace=workspace).knowledge() - class_names = api.load_kg_core(user = user, id = id, flow=flow, - collection=collection) + api.load_kg_core(id=id, flow=flow, collection=collection) def main(): @@ -34,12 +33,6 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default="trustgraph", - help='API URL (default: trustgraph)', - ) - parser.add_argument( '--id', '--identifier', required=True, @@ -49,13 +42,25 @@ def main(): parser.add_argument( '-f', '--flow-id', default=default_flow, - help=f'Flow ID (default: {default_flow}', + help=f'Flow ID (default: {default_flow})', ) parser.add_argument( '-C', '--collection', default=default_collection, - help=f'Collection ID (default: {default_collection}', + help=f'Collection ID (default: {default_collection})', + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) args = parser.parse_args() @@ -64,10 +69,11 @@ def main(): load_kg_core( url=args.api_url, - user=args.user, id=args.id, flow=args.flow_id, collection=args.collection, + token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/load_knowledge.py b/trustgraph-cli/trustgraph/cli/load_knowledge.py index 5e96850f..7e9dadd4 100644 --- 
a/trustgraph-cli/trustgraph/cli/load_knowledge.py +++ b/trustgraph-cli/trustgraph/cli/load_knowledge.py @@ -13,7 +13,7 @@ from trustgraph.log_level import LogLevel default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' class KnowledgeLoader: @@ -22,19 +22,18 @@ class KnowledgeLoader: self, files, flow, - user, collection, document_id, url=default_url, - token=None, + token=None, workspace="default", ): self.files = files self.flow = flow - self.user = user self.collection = collection self.document_id = document_id self.url = url self.token = token + self.workspace = workspace def load_triples_from_file(self, file) -> Iterator[Triple]: """Generator that yields Triple objects from a Turtle file""" @@ -43,11 +42,9 @@ class KnowledgeLoader: g.parse(file, format="turtle") for e in g: - # Extract subject, predicate, object s_value = str(e[0]) p_value = str(e[1]) - # Check if object is a URI or literal if isinstance(e[2], rdflib.term.URIRef): o_value = str(e[2]) o_is_uri = True @@ -55,9 +52,6 @@ class KnowledgeLoader: o_value = str(e[2]) o_is_uri = False - # Create Triple object - # Note: The Triple dataclass has 's', 'p', 'o' fields as strings - # The API will handle the metadata wrapping yield Triple(s=s_value, p=p_value, o=o_value) def load_entity_contexts_from_file(self, file) -> Iterator[Tuple[str, str]]: @@ -67,11 +61,9 @@ class KnowledgeLoader: g.parse(file, format="turtle") for s, p, o in g: - # If object is a URI, skip (we only want literal contexts) if isinstance(o, rdflib.term.URIRef): continue - # If object is a literal, create entity context for subject s_str = str(s) o_str = str(o) @@ -81,11 +73,9 @@ class KnowledgeLoader: """Load triples and entity contexts using Python API""" try: - # Create API client - api = Api(url=self.url, token=self.token) + api = 
Api(url=self.url, token=self.token, workspace=self.workspace) bulk = api.bulk() - # Load triples from all files print("Loading triples...") total_triples = 0 for file in self.files: @@ -104,7 +94,6 @@ class KnowledgeLoader: metadata={ "id": self.document_id, "metadata": [], - "user": self.user, "collection": self.collection } ) @@ -113,20 +102,16 @@ class KnowledgeLoader: print(f"Triples loaded. Total: {total_triples}") - # Load entity contexts from all files print("Loading entity contexts...") total_contexts = 0 for file in self.files: print(f" Processing {file}...") count = 0 - # Convert tuples to the format expected by import_entity_contexts - # Entity must be in Term format: {"t": "i", "i": uri} for IRI def entity_context_generator(): nonlocal count for entity, context in self.load_entity_contexts_from_file(file): count += 1 - # Entities from RDF are URIs, use IRI term format yield { "entity": {"t": "i", "i": entity}, "context": context @@ -138,7 +123,6 @@ class KnowledgeLoader: metadata={ "id": self.document_id, "metadata": [], - "user": self.user, "collection": self.collection } ) @@ -170,6 +154,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-i', '--document-id', required=True, @@ -182,12 +172,6 @@ def main(): help=f'Flow ID (default: default)' ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -210,8 +194,8 @@ def main(): token=args.token, flow=args.flow_id, files=args.files, - user=args.user, collection=args.collection, + workspace=args.workspace, ) loader.run() diff --git a/trustgraph-cli/trustgraph/cli/load_sample_documents.py b/trustgraph-cli/trustgraph/cli/load_sample_documents.py index 186006a8..0398864c 100644 --- 
a/trustgraph-cli/trustgraph/cli/load_sample_documents.py +++ b/trustgraph-cli/trustgraph/cli/load_sample_documents.py @@ -12,8 +12,8 @@ from trustgraph.api import Api from trustgraph.api.types import hash, Uri, Literal, Triple default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") from requests.adapters import HTTPAdapter @@ -656,11 +656,10 @@ documents = [ class Loader: def __init__( - self, url, user, token=None + self, url, token=None, workspace="default", ): - self.api = Api(url, token=token).library() - self.user = user + self.api = Api(url, token=token, workspace=workspace).library() def load(self, documents): @@ -689,10 +688,10 @@ class Loader: print(" adding...") self.api.add_document( - id = doc["id"], metadata = doc["metadata"], - user = self.user, kind = doc["kind"], title = doc["title"], - comments = doc["comments"], tags = doc["tags"], - document = content + id=doc["id"], metadata=doc["metadata"], + kind=doc["kind"], title=doc["title"], + comments=doc["comments"], tags=doc["tags"], + document=content, ) print(" successful.") @@ -714,26 +713,26 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-t', '--token', default=default_token, help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: p = Loader( url=args.url, - user=args.user, token=args.token, + workspace=args.workspace, ) p.load(documents) diff --git a/trustgraph-cli/trustgraph/cli/load_structured_data.py b/trustgraph-cli/trustgraph/cli/load_structured_data.py index fa167917..3cd2a229 100644 --- 
a/trustgraph-cli/trustgraph/cli/load_structured_data.py +++ b/trustgraph-cli/trustgraph/cli/load_structured_data.py @@ -23,6 +23,7 @@ logger = logging.getLogger(__name__) default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def load_structured_data( @@ -39,11 +40,11 @@ def load_structured_data( sample_chars: int = 500, schema_name: str = None, flow: str = 'default', - user: str = 'trustgraph', collection: str = 'default', dry_run: bool = False, verbose: bool = False, - token: str = None + token: str = None, + workspace: str = "default", ): """ Load structured data using a descriptor configuration. @@ -62,7 +63,6 @@ def load_structured_data( sample_chars: Maximum characters to read for sampling schema_name: Target schema name for generation flow: TrustGraph flow name to use for prompts - user: User name for metadata (default: trustgraph) collection: Collection name for metadata (default: default) dry_run: If True, validate but don't import data verbose: Enable verbose logging @@ -78,7 +78,7 @@ def load_structured_data( logger.info("Step 1: Analyzing data to discover best matching schema...") # Step 1: Auto-discover schema (reuse discover_schema logic) - discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger) + discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace) if not discovered_schema: logger.error("Failed to discover suitable schema automatically") print("❌ Could not automatically determine the best schema for your data.") @@ -90,7 +90,7 @@ def load_structured_data( # Step 2: Auto-generate descriptor logger.info("Step 2: Generating descriptor configuration...") - auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger) + auto_descriptor = _auto_generate_descriptor(api_url, input_file, 
discovered_schema, sample_chars, flow, logger, workspace=workspace) if not auto_descriptor: logger.error("Failed to generate descriptor automatically") print("❌ Could not automatically generate descriptor configuration.") @@ -110,7 +110,7 @@ def load_structured_data( try: # Use shared pipeline for preview (small sample) - preview_objects, _ = _process_data_pipeline(input_file, temp_descriptor.name, user, collection, sample_size=5) + preview_objects, _ = _process_data_pipeline(input_file, temp_descriptor.name, collection, sample_size=5) # Show preview print("📊 Data Preview (first few records):") @@ -131,13 +131,13 @@ def load_structured_data( print("🚀 Importing data to TrustGraph...") # Use shared pipeline for full processing (no sample limit) - output_objects, descriptor = _process_data_pipeline(input_file, temp_descriptor.name, user, collection) + output_objects, descriptor = _process_data_pipeline(input_file, temp_descriptor.name, collection) # Get batch size from descriptor batch_size = descriptor.get('output', {}).get('options', {}).get('batch_size', 1000) # Send to TrustGraph using shared function - imported_count = _send_to_trustgraph(output_objects, api_url, flow, batch_size, token=token) + imported_count = _send_to_trustgraph(output_objects, api_url, flow, batch_size, token=token, workspace=workspace) # Summary format_info = descriptor.get('format', {}) @@ -172,7 +172,7 @@ def load_structured_data( logger.info(f"Sample chars: {sample_chars} characters") # Use the helper function to discover schema (get raw response for display) - response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True) + response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, workspace=workspace) if response: # Debug: print response type and content @@ -203,7 +203,7 @@ def load_structured_data( # If no schema specified, discover it first if not schema_name: logger.info("No schema 
specified, auto-discovering...") - schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger) + schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace) if not schema_name: print("Error: Could not determine schema automatically.") print("Please specify a schema using --schema-name or run --discover-schema first.") @@ -213,7 +213,7 @@ def load_structured_data( logger.info(f"Target schema: {schema_name}") # Generate descriptor using helper function - descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger) + descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace=workspace) if descriptor: # Output the generated descriptor @@ -242,7 +242,7 @@ def load_structured_data( logger.info(f"Parsing {input_file} with descriptor {descriptor_file}...") # Use shared pipeline - output_records, descriptor = _process_data_pipeline(input_file, descriptor_file, user, collection, sample_size) + output_records, descriptor = _process_data_pipeline(input_file, descriptor_file, collection, sample_size) # Output results if output_file: @@ -286,7 +286,7 @@ def load_structured_data( logger.info(f"Loading {input_file} to TrustGraph using descriptor {descriptor_file}...") # Use shared pipeline (no sample_size limit for full load) - output_records, descriptor = _process_data_pipeline(input_file, descriptor_file, user, collection) + output_records, descriptor = _process_data_pipeline(input_file, descriptor_file, collection) # Get batch size from descriptor or use default batch_size = descriptor.get('output', {}).get('options', {}).get('batch_size', 1000) @@ -527,18 +527,17 @@ def _apply_transformations(records, mappings): return processed_records -def _format_extracted_objects(processed_records, descriptor, user, collection): +def _format_extracted_objects(processed_records, descriptor, collection): """Convert to TrustGraph 
ExtractedObject format""" output_records = [] schema_name = descriptor.get('output', {}).get('schema_name', 'default') confidence = descriptor.get('output', {}).get('options', {}).get('confidence', 0.9) - + for record in processed_records: output_record = { "metadata": { "id": f"parsed-{len(output_records)+1}", "metadata": [], # Empty metadata triples - "user": user, "collection": collection }, "schema_name": schema_name, @@ -551,7 +550,7 @@ def _format_extracted_objects(processed_records, descriptor, user, collection): return output_records -def _process_data_pipeline(input_file, descriptor_file, user, collection, sample_size=None): +def _process_data_pipeline(input_file, descriptor_file, collection, sample_size=None): """Shared pipeline: load descriptor → read → parse → transform → format""" # Load descriptor configuration descriptor = _load_descriptor(descriptor_file) @@ -568,12 +567,12 @@ def _process_data_pipeline(input_file, descriptor_file, user, collection, sample processed_records = _apply_transformations(parsed_records, mappings) # Format output for TrustGraph ExtractedObject structure - output_records = _format_extracted_objects(processed_records, descriptor, user, collection) + output_records = _format_extracted_objects(processed_records, descriptor, collection) return output_records, descriptor -def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None): +def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None, workspace="default"): """Send ExtractedObject records to TrustGraph using Python API""" from trustgraph.api import Api @@ -582,7 +581,7 @@ def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None): logger.info(f"Importing {total_records} records to TrustGraph...") # Use Python API bulk import - api = Api(api_url, token=token) + api = Api(api_url, token=token, workspace=workspace) bulk = api.bulk() bulk.import_rows(flow=flow, rows=iter(rows)) @@ -604,7 +603,7 @@ def _send_to_trustgraph(rows, 
api_url, flow, batch_size=1000, token=None): # Helper functions for auto mode -def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False): +def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, workspace="default"): """Auto-discover the best matching schema for the input data Args: @@ -627,7 +626,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur # Import API modules from trustgraph.api import Api from trustgraph.api.types import ConfigKey - api = Api(api_url) + api = Api(api_url, workspace=workspace) config_api = api.config() # Get available schemas @@ -708,7 +707,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur return None -def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger): +def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace="default"): """Auto-generate descriptor configuration for the discovered schema""" try: # Read sample data @@ -718,7 +717,7 @@ def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, fl # Import API modules from trustgraph.api import Api from trustgraph.api.types import ConfigKey - api = Api(api_url) + api = Api(api_url, workspace=workspace) config_api = api.config() # Get schema definition @@ -885,12 +884,6 @@ For more information on the descriptor format, see: help='TrustGraph flow name to use for prompts and import (default: default)' ) - parser.add_argument( - '--user', - default='trustgraph', - help='User name for metadata (default: trustgraph)' - ) - parser.add_argument( '--collection', default='default', @@ -997,6 +990,12 @@ For more information on the descriptor format, see: help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) 
+ args = parser.parse_args() # Input validation @@ -1046,11 +1045,11 @@ For more information on the descriptor format, see: sample_chars=args.sample_chars, schema_name=args.schema_name, flow=args.flow, - user=args.user, collection=args.collection, dry_run=args.dry_run, verbose=args.verbose, - token=args.token + token=args.token, + workspace=args.workspace, ) except FileNotFoundError as e: print(f"Error: File not found - {e}", file=sys.stderr) diff --git a/trustgraph-cli/trustgraph/cli/load_turtle.py b/trustgraph-cli/trustgraph/cli/load_turtle.py index adb578f5..43ef9e6f 100644 --- a/trustgraph-cli/trustgraph/cli/load_turtle.py +++ b/trustgraph-cli/trustgraph/cli/load_turtle.py @@ -13,7 +13,7 @@ from trustgraph.log_level import LogLevel default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' class Loader: @@ -22,15 +22,14 @@ class Loader: self, files, flow, - user, collection, document_id, url=default_url, - token=None, + token=None, workspace="default", ): self.files = files self.flow = flow - self.user = user + self.workspace = workspace self.collection = collection self.document_id = document_id self.url = url @@ -43,28 +42,23 @@ class Loader: g.parse(file, format="turtle") for e in g: - # Extract subject, predicate, object s_value = str(e[0]) p_value = str(e[1]) - # Check if object is a URI or literal if isinstance(e[2], rdflib.term.URIRef): o_value = str(e[2]) else: o_value = str(e[2]) - # Create Triple object yield Triple(s=s_value, p=p_value, o=o_value) def run(self): """Load triples using Python API""" try: - # Create API client - api = Api(url=self.url, token=self.token) + api = Api(url=self.url, token=self.token, workspace=self.workspace) bulk = api.bulk() - # Load triples from all files print("Loading triples...") for file in self.files: print(f" Processing {file}...") 
@@ -76,7 +70,6 @@ class Loader: metadata={ "id": self.document_id, "metadata": [], - "user": self.user, "collection": self.collection } ) @@ -106,6 +99,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-i', '--document-id', required=True, @@ -118,12 +117,6 @@ def main(): help=f'Flow ID (default: default)' ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -146,8 +139,8 @@ def main(): token=args.token, flow=args.flow_id, files=args.files, - user=args.user, collection=args.collection, + workspace=args.workspace, ) loader.run() diff --git a/trustgraph-cli/trustgraph/cli/put_config_item.py b/trustgraph-cli/trustgraph/cli/put_config_item.py index d79864a4..fda9cbeb 100644 --- a/trustgraph-cli/trustgraph/cli/put_config_item.py +++ b/trustgraph-cli/trustgraph/cli/put_config_item.py @@ -10,10 +10,12 @@ from trustgraph.api.types import ConfigValue default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def put_config_item(url, config_type, key, value, token=None): +def put_config_item(url, config_type, key, value, token=None, + workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() config_value = ConfigValue(type=config_type, key=key, value=value) api.put([config_value]) @@ -63,6 +65,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -78,6 +86,7 @@ def main(): 
key=args.key, value=value, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/put_flow_blueprint.py b/trustgraph-cli/trustgraph/cli/put_flow_blueprint.py index 740a224a..96db6bec 100644 --- a/trustgraph-cli/trustgraph/cli/put_flow_blueprint.py +++ b/trustgraph-cli/trustgraph/cli/put_flow_blueprint.py @@ -10,10 +10,12 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def put_flow_blueprint(url, blueprint_name, config, token=None): +def put_flow_blueprint(url, blueprint_name, config, token=None, + workspace="default"): - api = Api(url, token=token) + api = Api(url, token=token, workspace=workspace) blueprint_names = api.flow().put_blueprint(blueprint_name, config) @@ -36,6 +38,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-n', '--blueprint-name', help=f'Flow blueprint name', @@ -55,6 +63,7 @@ def main(): blueprint_name=args.blueprint_name, config=json.loads(args.config), token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/put_kg_core.py b/trustgraph-cli/trustgraph/cli/put_kg_core.py index cd0738fe..bd3169c8 100644 --- a/trustgraph-cli/trustgraph/cli/put_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/put_kg_core.py @@ -1,10 +1,9 @@ """ -Uses the agent service to answer a question +Puts a knowledge core into the knowledge manager via the API socket. 
""" import argparse import os -import textwrap import uuid import asyncio import json @@ -12,18 +11,17 @@ from websockets.asyncio.client import connect import msgpack default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') -default_user = 'trustgraph' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") + +def read_message(unpacked, id): -def read_message(unpacked, id, user): - if unpacked[0] == "ge": msg = unpacked[1] return "ge", { "metadata": { "id": id, "metadata": msg["m"]["m"], - "user": user, "collection": "default", # Not used? }, "entities": [ @@ -40,7 +38,6 @@ def read_message(unpacked, id, user): "metadata": { "id": id, "metadata": msg["m"]["m"], - "user": user, "collection": "default", # Not used by receiver? }, "triples": msg["t"], @@ -48,7 +45,7 @@ def read_message(unpacked, id, user): else: raise RuntimeError("Unpacked unexpected messsage type", unpacked[0]) -async def put(url, user, id, input, token=None): +async def put(url, workspace, id, input, token=None): if not url.endswith("/"): url += "/" @@ -60,7 +57,6 @@ async def put(url, user, id, input, token=None): async with connect(url) as ws: - ge = 0 t = 0 @@ -75,7 +71,7 @@ async def put(url, user, id, input, token=None): except: break - kind, msg = read_message(unpacked, id, user) + kind, msg = read_message(unpacked, id) mid = str(uuid.uuid4()) @@ -85,10 +81,11 @@ async def put(url, user, id, input, token=None): req = json.dumps({ "id": mid, + "workspace": workspace, "service": "knowledge", "request": { "operation": "put-kg-core", - "user": user, + "workspace": workspace, "id": id, "graph-embeddings": msg } @@ -100,10 +97,11 @@ async def put(url, user, id, input, token=None): req = json.dumps({ "id": mid, + "workspace": workspace, "service": "knowledge", "request": { "operation": "put-kg-core", - "user": user, + "workspace": workspace, "id": id, "triples": msg } @@ -117,7 +115,7 @@ async def put(url, user, id, input, 
token=None): # Retry loop, wait for right response to come back while True: - + msg = await ws.recv() msg = json.loads(msg) @@ -146,10 +144,11 @@ def main(): default=default_url, help=f'API URL (default: {default_url})', ) + parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -176,11 +175,11 @@ def main(): asyncio.run( put( - url = args.url, - user = args.user, - id = args.id, - input = args.input, - token = args.token, + url=args.url, + workspace=args.workspace, + id=args.id, + input=args.input, + token=args.token, ) ) @@ -189,4 +188,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/query_graph.py b/trustgraph-cli/trustgraph/cli/query_graph.py index a2c38353..091f0599 100644 --- a/trustgraph-cli/trustgraph/cli/query_graph.py +++ b/trustgraph-cli/trustgraph/cli/query_graph.py @@ -23,9 +23,9 @@ import sys from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' default_collection = 'default' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def parse_inline_quoted_triple(value): @@ -285,15 +285,16 @@ def output_jsonl(triples): def query_graph( - url, flow_id, user, collection, limit, batch_size, + url, flow_id, collection, limit, batch_size, subject=None, predicate=None, obj=None, graph=None, - output_format="space", headers=False, token=None + output_format="space", headers=False, token=None, + workspace="default", ): """Query the triple store with pattern matching. Uses the API's triples_query_stream for efficient streaming delivery. 
""" - socket = Api(url, token=token).socket() + socket = Api(url, token=token, workspace=workspace).socket() flow = socket.flow(flow_id) all_triples = [] @@ -305,7 +306,6 @@ def query_graph( p=predicate, o=obj, g=graph, - user=user, collection=collection, limit=limit, batch_size=batch_size, @@ -456,13 +456,6 @@ def main(): help='Flow ID (default: default)' ) - std_group.add_argument( - '-U', '--user', - default=default_user, - metavar='USER', - help=f'User/keyspace (default: {default_user})' - ) - std_group.add_argument( '-C', '--collection', default=default_collection, @@ -477,6 +470,12 @@ def main(): help='Auth token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + std_group.add_argument( '-l', '--limit', type=int, @@ -550,7 +549,6 @@ def main(): query_graph( url=args.api_url, flow_id=args.flow_id, - user=args.user, collection=args.collection, limit=args.limit, batch_size=args.batch_size, @@ -561,6 +559,8 @@ def main(): output_format=args.format, headers=args.headers, token=args.token, + + workspace=args.workspace, ) except json.JSONDecodeError as e: diff --git a/trustgraph-cli/trustgraph/cli/remove_library_document.py b/trustgraph-cli/trustgraph/cli/remove_library_document.py index 07a1fd59..d6500d50 100644 --- a/trustgraph-cli/trustgraph/cli/remove_library_document.py +++ b/trustgraph-cli/trustgraph/cli/remove_library_document.py @@ -4,20 +4,19 @@ Remove a document from the library import argparse import os -import uuid from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def remove_doc(url, user, id, token=None): +def remove_doc(url, id, token=None, workspace="default"): - api = Api(url, token=token).library() + api = Api(url, token=token, 
workspace=workspace).library() - api.remove_document(user=user, id=id) + api.remove_document(id=id) def main(): @@ -32,12 +31,6 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '--identifier', '--id', required=True, @@ -50,15 +43,24 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: - remove_doc(args.url, args.user, args.identifier, token=args.token) + remove_doc( + args.url, args.identifier, + token=args.token, workspace=args.workspace, + ) except Exception as e: print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/save_doc_embeds.py b/trustgraph-cli/trustgraph/cli/save_doc_embeds.py index ca8d25de..99d6b4db 100644 --- a/trustgraph-cli/trustgraph/cli/save_doc_embeds.py +++ b/trustgraph-cli/trustgraph/cli/save_doc_embeds.py @@ -21,7 +21,7 @@ class Running: def get(self): return self.running def stop(self): self.running = False -async def fetch_de(running, queue, user, collection, url): +async def fetch_de(running, queue, collection, url): async with aiohttp.ClientSession() as session: @@ -38,10 +38,6 @@ async def fetch_de(running, queue, user, collection, url): data = msg.json() - if user: - if data["metadata"]["user"] != user: - continue - if collection: if data["metadata"]["collection"] != collection: continue @@ -52,7 +48,6 @@ async def fetch_de(running, queue, user, collection, url): "m": { "i": data["metadata"]["id"], "m": data["metadata"]["metadata"], - "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, "c": [ @@ -119,7 +114,7 @@ async def run(running, **args): de_task = asyncio.create_task( fetch_de( 
running=running, - queue=q, user=args["user"], collection=args["collection"], + queue=q, collection=args["collection"], url = f"{url}api/v1/flow/{flow_id}/export/document-embeddings" ) ) @@ -148,7 +143,6 @@ async def main(running): ) default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/") - default_user = "trustgraph" collection = "default" parser.add_argument( @@ -177,11 +171,6 @@ async def main(running): help=f'Output format (default: msgpack)', ) - parser.add_argument( - '--user', - help=f'User ID to filter on (default: no filter)' - ) - parser.add_argument( '--collection', help=f'Collection ID to filter on (default: no filter)' diff --git a/trustgraph-cli/trustgraph/cli/set_collection.py b/trustgraph-cli/trustgraph/cli/set_collection.py index dd4148ea..53aaa74d 100644 --- a/trustgraph-cli/trustgraph/cli/set_collection.py +++ b/trustgraph-cli/trustgraph/cli/set_collection.py @@ -8,15 +8,14 @@ import tabulate from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = "trustgraph" default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def set_collection(url, user, collection, name, description, tags, token=None): +def set_collection(url, collection, name, description, tags, token=None, workspace="default"): - api = Api(url, token=token).collection() + api = Api(url, token=token, workspace=workspace).collection() result = api.update_collection( - user=user, collection=collection, name=name, description=description, @@ -59,12 +58,6 @@ def main(): help=f'API URL (default: {default_url})', ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-n', '--name', help='Collection name' @@ -88,18 +81,24 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + 
help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: set_collection( url = args.api_url, - user = args.user, collection = args.collection, name = args.name, description = args.description, tags = args.tags, - token = args.token + token = args.token, + workspace=args.workspace, ) except Exception as e: @@ -107,4 +106,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/set_mcp_tool.py b/trustgraph-cli/trustgraph/cli/set_mcp_tool.py index 7976adbc..65c640c6 100644 --- a/trustgraph-cli/trustgraph/cli/set_mcp_tool.py +++ b/trustgraph-cli/trustgraph/cli/set_mcp_tool.py @@ -21,6 +21,7 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def set_mcp_tool( url : str, @@ -29,9 +30,10 @@ def set_mcp_tool( tool_url : str, auth_token : str = None, token : str = None, + workspace : str = "default", ): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() # Build the MCP tool configuration config = { @@ -80,6 +82,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-i', '--id', required=True, @@ -126,6 +134,8 @@ def main(): tool_url=args.tool_url, auth_token=args.auth_token, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/set_prompt.py b/trustgraph-cli/trustgraph/cli/set_prompt.py index bffc2cf2..dbf9c326 100644 --- a/trustgraph-cli/trustgraph/cli/set_prompt.py +++ b/trustgraph-cli/trustgraph/cli/set_prompt.py @@ -11,10 +11,11 @@ import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 
'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def set_system(url, system, token=None): +def set_system(url, system, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() api.put([ ConfigValue(type="prompt", key="system", value=json.dumps(system)) @@ -22,9 +23,9 @@ def set_system(url, system, token=None): print("System prompt set.") -def set_prompt(url, id, prompt, response, schema, token=None): +def set_prompt(url, id, prompt, response, schema, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() values = api.get([ ConfigKey(type="prompt", key="template-index") @@ -78,6 +79,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '--id', help=f'Prompt ID', diff --git a/trustgraph-cli/trustgraph/cli/set_token_costs.py b/trustgraph-cli/trustgraph/cli/set_token_costs.py index 19b8c703..9b046a7d 100644 --- a/trustgraph-cli/trustgraph/cli/set_token_costs.py +++ b/trustgraph-cli/trustgraph/cli/set_token_costs.py @@ -11,10 +11,11 @@ import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def set_costs(api_url, model, input_costs, output_costs, token=None): +def set_costs(api_url, model, input_costs, output_costs, token=None, workspace="default"): - api = Api(api_url, token=token).config() + api = Api(api_url, token=token, workspace=workspace).config() api.put([ ConfigValue( @@ -26,9 +27,9 @@ def set_costs(api_url, model, input_costs, output_costs, token=None): ), ]) -def set_prompt(url, id, prompt, 
response, schema): +def set_prompt(url, id, prompt, response, schema, workspace="default"): - api = Api(url) + api = Api(url, workspace=workspace) values = api.config_get([ ConfigKey(type="prompt", key="template-index") @@ -102,6 +103,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: diff --git a/trustgraph-cli/trustgraph/cli/set_tool.py b/trustgraph-cli/trustgraph/cli/set_tool.py index c6412e48..45295089 100644 --- a/trustgraph-cli/trustgraph/cli/set_tool.py +++ b/trustgraph-cli/trustgraph/cli/set_tool.py @@ -28,6 +28,7 @@ import dataclasses default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") @dataclasses.dataclass class Argument: @@ -73,9 +74,10 @@ def set_tool( state : str, applicable_states : List[str], token : str = None, + workspace : str = "default", ): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() values = api.get([ ConfigKey(type="agent", key="tool-index") @@ -181,6 +183,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '--id', help=f'Unique tool identifier', @@ -303,6 +311,8 @@ def main(): state=args.state, applicable_states=args.applicable_states, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_config.py b/trustgraph-cli/trustgraph/cli/show_config.py index 6f426533..130c59b7 100644 --- a/trustgraph-cli/trustgraph/cli/show_config.py +++ b/trustgraph-cli/trustgraph/cli/show_config.py @@ -9,10 +9,11 @@ import json default_url = 
os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_config(url, token=None): +def show_config(url, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() config, version = api.all() @@ -38,6 +39,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -45,6 +52,7 @@ def main(): show_config( url=args.api_url, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_explain_trace.py b/trustgraph-cli/trustgraph/cli/show_explain_trace.py index 90c0e452..17aaca1a 100644 --- a/trustgraph-cli/trustgraph/cli/show_explain_trace.py +++ b/trustgraph-cli/trustgraph/cli/show_explain_trace.py @@ -36,7 +36,7 @@ from trustgraph.api import ( default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' # Graphs @@ -50,13 +50,12 @@ PROV = "http://www.w3.org/ns/prov#" PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom" -def trace_edge_provenance(flow, user, collection, edge, label_cache, explain_client): +def trace_edge_provenance(flow, collection, edge, label_cache, explain_client): """ Trace an edge back to its source document via reification. 
Args: flow: SocketFlowInstance - user: User identifier collection: Collection identifier edge: Dict with s, p, o keys label_cache: Dict for caching labels @@ -90,7 +89,6 @@ def trace_edge_provenance(flow, user, collection, edge, label_cache, explain_cli p=TG_CONTAINS, o=quoted_triple, g=SOURCE_GRAPH, - user=user, collection=collection, limit=10 ) @@ -108,14 +106,14 @@ def trace_edge_provenance(flow, user, collection, edge, label_cache, explain_cli # For each statement, trace wasDerivedFrom chain provenance_chains = [] for stmt_uri in stmt_uris: - chain = trace_provenance_chain(flow, user, collection, stmt_uri, label_cache, explain_client) + chain = trace_provenance_chain(flow, collection, stmt_uri, label_cache, explain_client) if chain: provenance_chains.append(chain) return provenance_chains -def trace_provenance_chain(flow, user, collection, start_uri, label_cache, explain_client, max_depth=10): +def trace_provenance_chain(flow, collection, start_uri, label_cache, explain_client, max_depth=10): """Trace prov:wasDerivedFrom chain from start_uri to root.""" chain = [] current = start_uri @@ -128,7 +126,7 @@ def trace_provenance_chain(flow, user, collection, start_uri, label_cache, expla if current in label_cache: label = label_cache[current] else: - label = explain_client.resolve_label(current, user, collection) + label = explain_client.resolve_label(current, collection) label_cache[current] = label chain.append({"uri": current, "label": label}) @@ -139,7 +137,6 @@ def trace_provenance_chain(flow, user, collection, start_uri, label_cache, expla s=current, p=PROV_WAS_DERIVED_FROM, g=SOURCE_GRAPH, - user=user, collection=collection, limit=1 ) @@ -167,7 +164,7 @@ def format_provenance_chain(chain): return " -> ".join(labels) -def print_graphrag_text(trace, explain_client, flow, user, collection, api=None, show_provenance=False): +def print_graphrag_text(trace, explain_client, flow, collection, api=None, show_provenance=False): """Print GraphRAG trace in text 
format.""" question = trace.get("question") @@ -202,7 +199,7 @@ def print_graphrag_text(trace, explain_client, flow, user, collection, api=None, for i, edge_sel in enumerate(edges, 1): if edge_sel.edge: s_label, p_label, o_label = explain_client.resolve_edge_labels( - edge_sel.edge, user, collection + edge_sel.edge, collection ) print(f" {i}. ({s_label}, {p_label}, {o_label})") @@ -212,7 +209,7 @@ def print_graphrag_text(trace, explain_client, flow, user, collection, api=None, if show_provenance and edge_sel.edge: provenance = trace_edge_provenance( - flow, user, collection, edge_sel.edge, + flow, collection, edge_sel.edge, label_cache, explain_client ) for chain in provenance: @@ -238,7 +235,7 @@ def print_graphrag_text(trace, explain_client, flow, user, collection, api=None, content = "" if synthesis.document and api: content = explain_client.fetch_document_content( - synthesis.document, api, user + synthesis.document, api ) if content: print("Answer:") @@ -252,7 +249,7 @@ def print_graphrag_text(trace, explain_client, flow, user, collection, api=None, print("No synthesis data found") -def print_docrag_text(trace, explain_client, api, user): +def print_docrag_text(trace, explain_client, api): """Print DocRAG trace in text format.""" question = trace.get("question") @@ -288,7 +285,7 @@ def print_docrag_text(trace, explain_client, api, user): content = "" if synthesis.document and api: content = explain_client.fetch_document_content( - synthesis.document, api, user + synthesis.document, api ) if content: print("Answer:") @@ -302,14 +299,14 @@ def print_docrag_text(trace, explain_client, api, user): print("No synthesis data found") -def _print_document_content(explain_client, api, user, document_uri, label="Answer"): +def _print_document_content(explain_client, api, document_uri, label="Answer"): """Fetch and print document content, or fall back to URI.""" if not document_uri: return content = "" if api: content = explain_client.fetch_document_content( - 
document_uri, api, user + document_uri, api ) if content: print(f"{label}:") @@ -319,7 +316,7 @@ def _print_document_content(explain_client, api, user, document_uri, label="Answ print(f"Document: {document_uri}") -def print_agent_text(trace, explain_client, api, user): +def print_agent_text(trace, explain_client, api): """Print Agent trace in text format.""" question = trace.get("question") @@ -348,7 +345,7 @@ def print_agent_text(trace, explain_client, api, user): print("--- Finding ---") print(f"Goal: {step.goal}") _print_document_content( - explain_client, api, user, step.document, "Result", + explain_client, api, step.document, "Result", ) print() @@ -363,7 +360,7 @@ def print_agent_text(trace, explain_client, api, user): print("--- Step Result ---") print(f"Step: {step.step}") _print_document_content( - explain_client, api, user, step.document, "Result", + explain_client, api, step.document, "Result", ) print() @@ -385,21 +382,21 @@ def print_agent_text(trace, explain_client, api, user): elif isinstance(step, Observation): print("--- Observation ---") _print_document_content( - explain_client, api, user, step.document, "Content", + explain_client, api, step.document, "Content", ) print() elif isinstance(step, Synthesis): print("--- Synthesis ---") _print_document_content( - explain_client, api, user, step.document, "Answer", + explain_client, api, step.document, "Answer", ) print() elif isinstance(step, Conclusion): print("--- Conclusion ---") _print_document_content( - explain_client, api, user, step.document, "Answer", + explain_client, api, step.document, "Answer", ) print() @@ -559,9 +556,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})', + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -599,7 +596,7 @@ def main(): args = parser.parse_args() try: - api = Api(args.api_url, token=args.token) + api = 
Api(args.api_url, token=args.token, workspace=args.workspace) socket = api.socket() flow = socket.flow(args.flow_id) explain_client = ExplainabilityClient(flow) @@ -609,7 +606,6 @@ def main(): trace_type = explain_client.detect_session_type( args.question_id, graph=RETRIEVAL_GRAPH, - user=args.user, collection=args.collection, ) @@ -618,7 +614,6 @@ def main(): trace = explain_client.fetch_agent_trace( args.question_id, graph=RETRIEVAL_GRAPH, - user=args.user, collection=args.collection, api=api, max_content=args.max_answer, @@ -627,14 +622,13 @@ def main(): if args.format == 'json': print(json.dumps(trace_to_dict(trace, "agent"), indent=2)) else: - print_agent_text(trace, explain_client, api, args.user) + print_agent_text(trace, explain_client, api) elif trace_type == "docrag": # Fetch and display DocRAG trace trace = explain_client.fetch_docrag_trace( args.question_id, graph=RETRIEVAL_GRAPH, - user=args.user, collection=args.collection, api=api, max_content=args.max_answer, @@ -643,14 +637,13 @@ def main(): if args.format == 'json': print(json.dumps(trace_to_dict(trace, "docrag"), indent=2)) else: - print_docrag_text(trace, explain_client, api, args.user) + print_docrag_text(trace, explain_client, api) else: # Fetch and display GraphRAG trace trace = explain_client.fetch_graphrag_trace( args.question_id, graph=RETRIEVAL_GRAPH, - user=args.user, collection=args.collection, api=api, max_content=args.max_answer, @@ -661,7 +654,7 @@ def main(): else: print_graphrag_text( trace, explain_client, flow, - args.user, args.collection, + args.collection, api=api, show_provenance=args.show_provenance ) diff --git a/trustgraph-cli/trustgraph/cli/show_extraction_provenance.py b/trustgraph-cli/trustgraph/cli/show_extraction_provenance.py index 4f87712c..49bf78ee 100644 --- a/trustgraph-cli/trustgraph/cli/show_extraction_provenance.py +++ b/trustgraph-cli/trustgraph/cli/show_extraction_provenance.py @@ -17,7 +17,7 @@ from trustgraph.api import Api default_url = 
os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = 'trustgraph' +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' # Predicates @@ -45,10 +45,9 @@ TYPE_MAP = { SOURCE_GRAPH = "urn:graph:source" -def query_triples(socket, flow_id, user, collection, s=None, p=None, o=None, g=None, limit=1000): +def query_triples(socket, flow_id, collection, s=None, p=None, o=None, g=None, limit=1000): """Query triples using the socket API.""" request = { - "user": user, "collection": collection, "limit": limit, "streaming": False, @@ -120,9 +119,9 @@ def extract_value(term): return str(term) -def get_node_metadata(socket, flow_id, user, collection, node_uri): +def get_node_metadata(socket, flow_id, collection, node_uri): """Get metadata for a node (label, types, title, format).""" - triples = query_triples(socket, flow_id, user, collection, s=node_uri, g=SOURCE_GRAPH) + triples = query_triples(socket, flow_id, collection, s=node_uri, g=SOURCE_GRAPH) metadata = {"uri": node_uri, "types": []} for s, p, o in triples: @@ -146,20 +145,20 @@ def classify_node(metadata): return "unknown" -def get_children(socket, flow_id, user, collection, parent_uri): +def get_children(socket, flow_id, collection, parent_uri): """Get children of a node via prov:wasDerivedFrom.""" triples = query_triples( - socket, flow_id, user, collection, + socket, flow_id, collection, p=PROV_WAS_DERIVED_FROM, o=parent_uri, g=SOURCE_GRAPH ) return [s for s, p, o in triples] -def get_document_content(api, user, doc_id, max_content): +def get_document_content(api, doc_id, max_content): """Fetch document content from librarian API.""" try: library = api.library() - content = library.get_document_content(user=user, id=doc_id) + content = library.get_document_content(id=doc_id) # Try to decode as text try: @@ -173,7 +172,7 @@ def get_document_content(api, user, doc_id, max_content): return f"[Error fetching 
content: {e}]" -def build_hierarchy(socket, flow_id, user, collection, root_uri, api=None, show_content=False, max_content=200, visited=None): +def build_hierarchy(socket, flow_id, collection, root_uri, api=None, show_content=False, max_content=200, visited=None): """Build document hierarchy tree recursively.""" if visited is None: visited = set() @@ -182,7 +181,7 @@ def build_hierarchy(socket, flow_id, user, collection, root_uri, api=None, show_ return None visited.add(root_uri) - metadata = get_node_metadata(socket, flow_id, user, collection, root_uri) + metadata = get_node_metadata(socket, flow_id, collection, root_uri) node_type = classify_node(metadata) node = { @@ -195,21 +194,21 @@ def build_hierarchy(socket, flow_id, user, collection, root_uri, api=None, show_ # Fetch content if requested if show_content and api: - content = get_document_content(api, user, root_uri, max_content) + content = get_document_content(api, root_uri, max_content) if content: node["content"] = content # Get children - children_uris = get_children(socket, flow_id, user, collection, root_uri) + children_uris = get_children(socket, flow_id, collection, root_uri) for child_uri in children_uris: - child_metadata = get_node_metadata(socket, flow_id, user, collection, child_uri) + child_metadata = get_node_metadata(socket, flow_id, collection, child_uri) child_type = classify_node(child_metadata) if child_type == "subgraph": # Subgraphs contain extracted edges — inline them contains_triples = query_triples( - socket, flow_id, user, collection, + socket, flow_id, collection, s=child_uri, p=TG_CONTAINS, g=SOURCE_GRAPH ) for _, _, edge in contains_triples: @@ -218,7 +217,7 @@ def build_hierarchy(socket, flow_id, user, collection, root_uri, api=None, show_ else: # Recurse into pages, chunks, etc. 
child_node = build_hierarchy( - socket, flow_id, user, collection, child_uri, + socket, flow_id, collection, child_uri, api=api, show_content=show_content, max_content=max_content, visited=visited ) @@ -331,9 +330,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})', + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -371,14 +370,13 @@ def main(): args = parser.parse_args() try: - api = Api(args.api_url, token=args.token) + api = Api(args.api_url, token=args.token, workspace=args.workspace) socket = api.socket() try: hierarchy = build_hierarchy( socket=socket, flow_id=args.flow_id, - user=args.user, collection=args.collection, root_uri=args.document_id, api=api if args.show_content else None, diff --git a/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py b/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py index 8d16d098..4924c925 100644 --- a/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py +++ b/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py @@ -11,6 +11,7 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def format_parameters(params_metadata, param_type_defs): """ @@ -44,12 +45,13 @@ def format_parameters(params_metadata, param_type_defs): return "\n".join(param_list) -async def fetch_data(client): +async def fetch_data(client, workspace): """Fetch all data needed for show_flow_blueprints concurrently.""" # Round 1: list blueprints resp = await client._send_request("flow", None, { "operation": "list-blueprints", + "workspace": workspace, }) blueprint_names = resp.get("blueprint-names", []) @@ -60,6 +62,7 @@ async def fetch_data(client): blueprint_tasks = [ client._send_request("flow", None, { "operation": "get-blueprint", + "workspace": 
workspace, "blueprint-name": name, }) for name in blueprint_names @@ -84,6 +87,7 @@ async def fetch_data(client): param_type_tasks = [ client._send_request("config", None, { "operation": "get", + "workspace": workspace, "keys": [{"type": "parameter-type", "key": pt}], }) for pt in param_types_needed @@ -100,14 +104,16 @@ async def fetch_data(client): return blueprint_names, blueprints, param_type_defs -async def _show_flow_blueprints_async(url, token=None): +async def _show_flow_blueprints_async(url, token=None, workspace="default"): async with AsyncSocketClient(url, timeout=60, token=token) as client: - return await fetch_data(client) + return await fetch_data(client, workspace) -def show_flow_blueprints(url, token=None): +def show_flow_blueprints(url, token=None, workspace="default"): blueprint_names, blueprints, param_type_defs = asyncio.run( - _show_flow_blueprints_async(url, token=token) + _show_flow_blueprints_async( + url, token=token, workspace=workspace, + ) ) if not blueprint_names: @@ -156,6 +162,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -163,6 +175,7 @@ def main(): show_flow_blueprints( url=args.api_url, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_flow_state.py b/trustgraph-cli/trustgraph/cli/show_flow_state.py index d5d87f2c..8fec04ec 100644 --- a/trustgraph-cli/trustgraph/cli/show_flow_state.py +++ b/trustgraph-cli/trustgraph/cli/show_flow_state.py @@ -10,10 +10,12 @@ import os default_metrics_url = "http://localhost:8088/api/metrics" default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def dump_status(metrics_url, api_url, flow_id, token=None): 
+def dump_status(metrics_url, api_url, flow_id, token=None, + workspace="default"): - api = Api(api_url, token=token).flow() + api = Api(api_url, token=token, workspace=workspace).flow() flow = api.get(flow_id) blueprint_name = flow["blueprint-name"] @@ -84,11 +86,20 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: - dump_status(args.metrics_url, args.api_url, args.flow_id, token=args.token) + dump_status( + args.metrics_url, args.api_url, args.flow_id, + token=args.token, workspace=args.workspace, + ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_flows.py b/trustgraph-cli/trustgraph/cli/show_flows.py index f7a14469..6e9479f9 100644 --- a/trustgraph-cli/trustgraph/cli/show_flows.py +++ b/trustgraph-cli/trustgraph/cli/show_flows.py @@ -11,6 +11,7 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def describe_interfaces(intdefs, flow): @@ -97,17 +98,19 @@ def format_parameters(flow_params, blueprint_params_metadata, param_type_defs): return "\n".join(param_list) if param_list else "None" -async def fetch_show_flows(client): +async def fetch_show_flows(client, workspace): """Fetch all data needed for show_flows concurrently.""" # Round 1: list interfaces and list flows in parallel interface_names_resp, flow_ids_resp = await asyncio.gather( client._send_request("config", None, { "operation": "list", + "workspace": workspace, "type": "interface-description", }), client._send_request("flow", None, { "operation": "list-flows", + "workspace": workspace, }), ) @@ -115,12 +118,13 @@ async def fetch_show_flows(client): flow_ids = flow_ids_resp.get("flow-ids", []) if not flow_ids: - return {}, [], {}, {} + return {}, 
[], {}, {}, {} # Round 2: get all interfaces + all flows in parallel interface_tasks = [ client._send_request("config", None, { "operation": "get", + "workspace": workspace, "keys": [{"type": "interface-description", "key": name}], }) for name in interface_names @@ -129,6 +133,7 @@ async def fetch_show_flows(client): flow_tasks = [ client._send_request("flow", None, { "operation": "get-flow", + "workspace": workspace, "flow-id": fid, }) for fid in flow_ids @@ -163,6 +168,7 @@ async def fetch_show_flows(client): blueprint_tasks = [ client._send_request("flow", None, { "operation": "get-blueprint", + "workspace": workspace, "blueprint-name": bp_name, }) for bp_name in blueprint_names @@ -186,6 +192,7 @@ async def fetch_show_flows(client): param_type_tasks = [ client._send_request("config", None, { "operation": "get", + "workspace": workspace, "keys": [{"type": "parameter-type", "key": pt}], }) for pt in param_types_needed @@ -204,14 +211,16 @@ async def fetch_show_flows(client): return interface_defs, flow_ids, flows, blueprints, param_type_defs -async def _show_flows_async(url, token=None): +async def _show_flows_async(url, token=None, workspace="default"): async with AsyncSocketClient(url, timeout=60, token=token) as client: - return await fetch_show_flows(client) + return await fetch_show_flows(client, workspace) -def show_flows(url, token=None): +def show_flows(url, token=None, workspace="default"): - result = asyncio.run(_show_flows_async(url, token=token)) + result = asyncio.run(_show_flows_async( + url, token=token, workspace=workspace, + )) interface_defs, flow_ids, flows, blueprints, param_type_defs = result @@ -269,6 +278,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -276,6 +291,7 @@ def main(): show_flows( url=args.api_url, token=args.token, + 
workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_graph.py b/trustgraph-cli/trustgraph/cli/show_graph.py index 8db4edf4..6063b05a 100644 --- a/trustgraph-cli/trustgraph/cli/show_graph.py +++ b/trustgraph-cli/trustgraph/cli/show_graph.py @@ -13,9 +13,9 @@ import os from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = 'trustgraph' default_collection = 'default' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") # Named graph constants for convenience GRAPH_DEFAULT = "" @@ -23,14 +23,13 @@ GRAPH_SOURCE = "urn:graph:source" GRAPH_RETRIEVAL = "urn:graph:retrieval" -def show_graph(url, flow_id, user, collection, limit, batch_size, graph=None, show_graph_column=False, token=None): +def show_graph(url, flow_id, collection, limit, batch_size, graph=None, show_graph_column=False, token=None, workspace="default"): - socket = Api(url, token=token).socket() + socket = Api(url, token=token, workspace=workspace).socket() flow = socket.flow(flow_id) try: for batch in flow.triples_query_stream( - user=user, collection=collection, s=None, p=None, o=None, g=graph, # Filter by named graph (None = all graphs) @@ -73,12 +72,6 @@ def main(): help=f'Flow ID (default: default)' ) - parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-C', '--collection', default=default_collection, @@ -91,6 +84,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-l', '--limit', type=int, @@ -129,13 +128,13 @@ def main(): show_graph( url = args.api_url, flow_id = args.flow_id, - user = args.user, collection = args.collection, limit = args.limit, batch_size = args.batch_size, 
graph = graph, show_graph_column = args.show_graph, token = args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_kg_cores.py b/trustgraph-cli/trustgraph/cli/show_kg_cores.py index ea295543..c9d47889 100644 --- a/trustgraph-cli/trustgraph/cli/show_kg_cores.py +++ b/trustgraph-cli/trustgraph/cli/show_kg_cores.py @@ -4,16 +4,15 @@ Shows knowledge cores import argparse import os -import tabulate -from trustgraph.api import Api, ConfigKey -import json +from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_cores(url, user, token=None): +def show_cores(url, token=None, workspace="default"): - api = Api(url, token=token).knowledge() + api = Api(url, token=token, workspace=workspace).knowledge() ids = api.list_kg_cores() @@ -26,7 +25,7 @@ def show_cores(url, user, token=None): def main(): parser = argparse.ArgumentParser( - prog='tg-show-flows', + prog='tg-show-kg-cores', description=__doc__, ) @@ -43,9 +42,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default="trustgraph", - help='API URL (default: trustgraph)', + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) args = parser.parse_args() @@ -54,8 +53,8 @@ def main(): show_cores( url=args.api_url, - user=args.user, token=args.token, + workspace=args.workspace, ) except Exception as e: @@ -63,4 +62,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/show_library_documents.py b/trustgraph-cli/trustgraph/cli/show_library_documents.py index 6eeceb70..12a89f1a 100644 --- a/trustgraph-cli/trustgraph/cli/show_library_documents.py +++ b/trustgraph-cli/trustgraph/cli/show_library_documents.py @@ -10,13 +10,13 @@ import 
json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = "trustgraph" +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_docs(url, user, token=None): +def show_docs(url, token=None, workspace="default"): - api = Api(url, token=token).library() + api = Api(url, token=token, workspace=workspace).library() - docs = api.get_documents(user=user) + docs = api.get_documents() if len(docs) == 0: print("No documents.") @@ -60,9 +60,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) args = parser.parse_args() @@ -71,8 +71,8 @@ def main(): show_docs( url = args.api_url, - user = args.user, token = args.token, + workspace = args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_library_processing.py b/trustgraph-cli/trustgraph/cli/show_library_processing.py index 9ab69355..700a0f83 100644 --- a/trustgraph-cli/trustgraph/cli/show_library_processing.py +++ b/trustgraph-cli/trustgraph/cli/show_library_processing.py @@ -4,18 +4,17 @@ import argparse import os import tabulate -from trustgraph.api import Api, ConfigKey -import json +from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -default_user = "trustgraph" default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_procs(url, user, token=None): +def show_procs(url, token=None, workspace="default"): - api = Api(url, token=token).library() + api = Api(url, token=token, workspace=workspace).library() - procs = api.get_processings(user = user) + procs = api.get_processings() if len(procs) == 0: print("No processing objects.") @@ -52,24 +51,26 @@ def main(): help=f'API URL (default: {default_url})', ) 
- parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - parser.add_argument( '-t', '--token', default=default_token, help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: show_procs( - url = args.api_url, user = args.user, token = args.token + url=args.api_url, + token=args.token, + workspace=args.workspace, ) except Exception as e: @@ -77,4 +78,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/show_mcp_tools.py b/trustgraph-cli/trustgraph/cli/show_mcp_tools.py index 24cbfcfe..d5f7a1c1 100644 --- a/trustgraph-cli/trustgraph/cli/show_mcp_tools.py +++ b/trustgraph-cli/trustgraph/cli/show_mcp_tools.py @@ -11,10 +11,11 @@ import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_config(url, token=None): +def show_config(url, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() values = api.get_values(type="mcp") @@ -64,6 +65,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -71,6 +78,8 @@ def main(): show_config( url=args.api_url, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_prompts.py b/trustgraph-cli/trustgraph/cli/show_prompts.py index 0e1cb2ae..cad6f317 100644 --- a/trustgraph-cli/trustgraph/cli/show_prompts.py +++ 
b/trustgraph-cli/trustgraph/cli/show_prompts.py @@ -11,10 +11,11 @@ import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_config(url, token=None): +def show_config(url, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() values = api.get([ ConfigKey(type="prompt", key="system"), @@ -85,6 +86,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -92,6 +99,8 @@ def main(): show_config( url=args.api_url, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_token_costs.py b/trustgraph-cli/trustgraph/cli/show_token_costs.py index adc13ad7..c7a7bff2 100644 --- a/trustgraph-cli/trustgraph/cli/show_token_costs.py +++ b/trustgraph-cli/trustgraph/cli/show_token_costs.py @@ -13,10 +13,11 @@ tabulate.PRESERVE_WHITESPACE = True default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_config(url, token=None): +def show_config(url, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() models = api.list("token-cost") @@ -68,6 +69,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -75,6 +82,8 @@ def main(): show_config( url=args.api_url, token=args.token, + + workspace=args.workspace, ) 
except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_tools.py b/trustgraph-cli/trustgraph/cli/show_tools.py index d77f1fae..51aeacbf 100644 --- a/trustgraph-cli/trustgraph/cli/show_tools.py +++ b/trustgraph-cli/trustgraph/cli/show_tools.py @@ -19,10 +19,11 @@ import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def show_config(url, token=None): +def show_config(url, token=None, workspace="default"): - api = Api(url, token=token).config() + api = Api(url, token=token, workspace=workspace).config() values = api.get_values(type="tool") @@ -116,6 +117,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + args = parser.parse_args() try: @@ -123,6 +130,8 @@ def main(): show_config( url=args.api_url, token=args.token, + + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/start_flow.py b/trustgraph-cli/trustgraph/cli/start_flow.py index e04e241d..f65ffc49 100644 --- a/trustgraph-cli/trustgraph/cli/start_flow.py +++ b/trustgraph-cli/trustgraph/cli/start_flow.py @@ -18,10 +18,12 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def start_flow(url, blueprint_name, flow_id, description, parameters=None, token=None): +def start_flow(url, blueprint_name, flow_id, description, parameters=None, + token=None, workspace="default"): - api = Api(url, token=token).flow() + api = Api(url, token=token, workspace=workspace).flow() api.start( blueprint_name = blueprint_name, @@ -49,6 +51,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + 
parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-n', '--blueprint-name', required=True, @@ -120,6 +128,7 @@ def main(): description = args.description, parameters = parameters, token = args.token, + workspace = args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/start_library_processing.py b/trustgraph-cli/trustgraph/cli/start_library_processing.py index ff87ea9f..27b5f33d 100644 --- a/trustgraph-cli/trustgraph/cli/start_library_processing.py +++ b/trustgraph-cli/trustgraph/cli/start_library_processing.py @@ -4,19 +4,18 @@ Submits a library document for processing import argparse import os -import tabulate -from trustgraph.api import Api, ConfigKey -import json +from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = "trustgraph" +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def start_processing( - url, user, document_id, id, flow, collection, tags, token=None + url, document_id, id, flow, collection, tags, + token=None, workspace="default", ): - api = Api(url, token=token).library() + api = Api(url, token=token, workspace=workspace).library() if tags: tags = tags.split(",") @@ -27,9 +26,8 @@ def start_processing( id = id, document_id = document_id, flow = flow, - user = user, collection = collection, - tags = tags + tags = tags, ) def main(): @@ -52,9 +50,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -91,14 +89,14 @@ def main(): try: start_processing( - url = args.api_url, - user = args.user, - document_id = args.document_id, - id = args.id, - flow = args.flow_id, - collection = args.collection, 
- tags = args.tags, - token = args.token, + url=args.api_url, + document_id=args.document_id, + id=args.id, + flow=args.flow_id, + collection=args.collection, + tags=args.tags, + token=args.token, + workspace=args.workspace, ) except Exception as e: @@ -106,4 +104,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/stop_flow.py b/trustgraph-cli/trustgraph/cli/stop_flow.py index ae3a0415..7e2d0798 100644 --- a/trustgraph-cli/trustgraph/cli/stop_flow.py +++ b/trustgraph-cli/trustgraph/cli/stop_flow.py @@ -10,10 +10,11 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def stop_flow(url, flow_id, token=None): +def stop_flow(url, flow_id, token=None, workspace="default"): - api = Api(url, token=token).flow() + api = Api(url, token=token, workspace=workspace).flow() api.stop(id = flow_id) @@ -36,6 +37,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-i', '--flow-id', required=True, @@ -50,6 +57,7 @@ def main(): url=args.api_url, flow_id=args.flow_id, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/stop_library_processing.py b/trustgraph-cli/trustgraph/cli/stop_library_processing.py index 3d8a2c56..72a8dbb8 100644 --- a/trustgraph-cli/trustgraph/cli/stop_library_processing.py +++ b/trustgraph-cli/trustgraph/cli/stop_library_processing.py @@ -5,21 +5,17 @@ procesing, it doesn't stop in-flight processing at the moment. 
import argparse import os -import tabulate -from trustgraph.api import Api, ConfigKey -import json +from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -default_user = "trustgraph" +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") -def stop_processing( - url, user, id, token=None -): +def stop_processing(url, id, token=None, workspace="default"): - api = Api(url, token=token).library() + api = Api(url, token=token, workspace=workspace).library() - api.stop_processing(user = user, id = id) + api.stop_processing(id=id) def main(): @@ -41,9 +37,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default=default_user, - help=f'User ID (default: {default_user})' + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -57,10 +53,10 @@ def main(): try: stop_processing( - url = args.api_url, - user = args.user, - id = args.id, - token = args.token, + url=args.api_url, + id=args.id, + token=args.token, + workspace=args.workspace, ) except Exception as e: @@ -68,4 +64,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/unload_kg_core.py b/trustgraph-cli/trustgraph/cli/unload_kg_core.py index 47f811f3..45c56067 100644 --- a/trustgraph-cli/trustgraph/cli/unload_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/unload_kg_core.py @@ -1,25 +1,21 @@ """ -Starts a load operation on a knowledge core which is already stored by -the knowledge manager. You could load a core with tg-put-kg-core and then -run this utility. +Unloads a knowledge core from a flow. 
""" import argparse import os -import tabulate from trustgraph.api import Api -import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_flow = "default" -default_collection = "default" -def unload_kg_core(url, user, id, flow, token=None): +def unload_kg_core(url, id, flow, token=None, workspace="default"): - api = Api(url, token=token).knowledge() + api = Api(url, token=token, workspace=workspace).knowledge() - class_names = api.unload_kg_core(user = user, id = id, flow=flow) + api.unload_kg_core(id=id, flow=flow) def main(): @@ -41,9 +37,9 @@ def main(): ) parser.add_argument( - '-U', '--user', - default="trustgraph", - help='API URL (default: trustgraph)', + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', ) parser.add_argument( @@ -55,7 +51,7 @@ def main(): parser.add_argument( '-f', '--flow-id', default=default_flow, - help=f'Flow ID (default: {default_flow}', + help=f'Flow ID (default: {default_flow})', ) args = parser.parse_args() @@ -64,10 +60,10 @@ def main(): unload_kg_core( url=args.api_url, - user=args.user, id=args.id, flow=args.flow_id, token=args.token, + workspace=args.workspace, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/verify_system_status.py b/trustgraph-cli/trustgraph/cli/verify_system_status.py index 9491deaa..4ec055b7 100644 --- a/trustgraph-cli/trustgraph/cli/verify_system_status.py +++ b/trustgraph-cli/trustgraph/cli/verify_system_status.py @@ -25,6 +25,7 @@ default_pulsar_url = "http://localhost:8080" default_api_url = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/") default_ui_url = "http://localhost:8888" default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") class HealthChecker: @@ -210,10 +211,10 @@ def check_processors(url: str, min_processors: 
int, timeout: int, tr, token: Opt return False, tr.t("cli.verify_system_status.processors.error", error=str(e)) -def check_flow_blueprints(url: str, timeout: int, tr, token: Optional[str] = None) -> Tuple[bool, str]: +def check_flow_blueprints(url: str, timeout: int, tr, token: Optional[str] = None, workspace: str = "default") -> Tuple[bool, str]: """Check if flow blueprints are loaded.""" try: - api = Api(url, token=token, timeout=timeout) + api = Api(url, token=token, timeout=timeout, workspace=workspace) flow_api = api.flow() blueprints = flow_api.list_blueprints() @@ -227,10 +228,10 @@ def check_flow_blueprints(url: str, timeout: int, tr, token: Optional[str] = Non return False, tr.t("cli.verify_system_status.flow_blueprints.error", error=str(e)) -def check_flows(url: str, timeout: int, tr, token: Optional[str] = None) -> Tuple[bool, str]: +def check_flows(url: str, timeout: int, tr, token: Optional[str] = None, workspace: str = "default") -> Tuple[bool, str]: """Check if flow manager is responding.""" try: - api = Api(url, token=token, timeout=timeout) + api = Api(url, token=token, timeout=timeout, workspace=workspace) flow_api = api.flow() flows = flow_api.list() @@ -242,10 +243,10 @@ def check_flows(url: str, timeout: int, tr, token: Optional[str] = None) -> Tupl return False, tr.t("cli.verify_system_status.flows.error", error=str(e)) -def check_prompts(url: str, timeout: int, tr, token: Optional[str] = None) -> Tuple[bool, str]: +def check_prompts(url: str, timeout: int, tr, token: Optional[str] = None, workspace: str = "default") -> Tuple[bool, str]: """Check if prompts are loaded.""" try: - api = Api(url, token=token, timeout=timeout) + api = Api(url, token=token, timeout=timeout, workspace=workspace) config = api.config() # Import ConfigKey here to avoid top-level import issues @@ -268,14 +269,14 @@ def check_prompts(url: str, timeout: int, tr, token: Optional[str] = None) -> Tu return False, tr.t("cli.verify_system_status.prompts.error", error=str(e)) 
-def check_library(url: str, timeout: int, tr, token: Optional[str] = None) -> Tuple[bool, str]: +def check_library(url: str, timeout: int, tr, token: Optional[str] = None, workspace: str = "default") -> Tuple[bool, str]: """Check if library service is responding.""" try: - api = Api(url, token=token, timeout=timeout) + api = Api(url, token=token, timeout=timeout, workspace=workspace) library_api = api.library() - # Try to get documents (with default user) - docs = library_api.get_documents(user="trustgraph") + # Try to get documents + docs = library_api.get_documents() # Success if we get a valid response (even if empty) return True, tr.t("cli.verify_system_status.library.responding", count=len(docs)) @@ -376,6 +377,12 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)' ) + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + parser.add_argument( '-v', '--verbose', action='store_true', @@ -438,6 +445,7 @@ def main(): args.check_timeout, tr, args.token, + args.workspace, ) checker.run_check( @@ -447,6 +455,7 @@ def main(): args.check_timeout, tr, args.token, + args.workspace, ) checker.run_check( @@ -456,6 +465,7 @@ def main(): args.check_timeout, tr, args.token, + args.workspace, ) print() @@ -471,6 +481,7 @@ def main(): args.check_timeout, tr, args.token, + args.workspace, ) print() diff --git a/trustgraph-flow/trustgraph/agent/mcp_tool/service.py b/trustgraph-flow/trustgraph/agent/mcp_tool/service.py index c793f9ca..8ea72260 100755 --- a/trustgraph-flow/trustgraph/agent/mcp_tool/service.py +++ b/trustgraph-flow/trustgraph/agent/mcp_tool/service.py @@ -26,42 +26,50 @@ class Service(ToolService): self.register_config_handler(self.on_mcp_config, types=["mcp"]) + # Per-workspace MCP service registries self.mcp_services = {} - async def on_mcp_config(self, config, version): + async def on_mcp_config(self, workspace, config, version): - logger.info(f"Got config version 
{version}") + logger.info( + f"Got config version {version} for workspace {workspace}" + ) if "mcp" not in config: - self.mcp_services = {} + self.mcp_services[workspace] = {} return - self.mcp_services = { + self.mcp_services[workspace] = { k: json.loads(v) for k, v in config["mcp"].items() } - async def invoke_tool(self, name, parameters): + async def invoke_tool(self, workspace, name, parameters): try: - if name not in self.mcp_services: - raise RuntimeError(f"MCP service {name} not known") + ws_services = self.mcp_services.get(workspace, {}) - if "url" not in self.mcp_services[name]: + if name not in ws_services: + raise RuntimeError( + f"MCP service {name} not known in workspace " + f"{workspace}" + ) + + if "url" not in ws_services[name]: raise RuntimeError(f"MCP service {name} URL not defined") - url = self.mcp_services[name]["url"] + url = ws_services[name]["url"] - if "remote-name" in self.mcp_services[name]: - remote_name = self.mcp_services[name]["remote-name"] + if "remote-name" in ws_services[name]: + remote_name = ws_services[name]["remote-name"] else: remote_name = name # Build headers with optional bearer token headers = {} - if "auth-token" in self.mcp_services[name]: - token = self.mcp_services[name]["auth-token"] + if "auth-token" in ws_services[name]: + token = ws_services[name]["auth-token"] headers["Authorization"] = f"Bearer {token}" logger.info(f"Invoking {remote_name} at {url}") diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py b/trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py index cc5eb85c..c06b8c54 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py @@ -108,7 +108,7 @@ class Aggregator: ) def build_synthesis_request(self, correlation_id, original_question, - user, collection): + collection): """ Build the AgentRequest that triggers the synthesis phase. 
""" @@ -139,7 +139,6 @@ class Aggregator: state="", group=template.group if template else [], history=history, - user=user, collection=collection, streaming=template.streaming if template else False, session_id=parent_session_id, diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py index 6daba1a1..01abedf3 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py @@ -46,25 +46,20 @@ from ..tool_filter import filter_tools_by_group_and_state, get_next_state logger = logging.getLogger(__name__) -class UserAwareContext: - """Wraps flow interface to inject user context for tools that need it.""" +class FlowContext: + """Wraps flow interface with orchestrator-only scratch state + (explain URIs, response handle, streaming flag). Workspace isolation + is enforced by the flow layer (flow.workspace), not by this class.""" - def __init__(self, flow, user, respond=None, streaming=False): + def __init__(self, flow, respond=None, streaming=False): self._flow = flow - self._user = user self.respond = respond self.streaming = streaming self.current_explain_uri = None self.last_sub_explain_uri = None def __call__(self, service_name): - client = self._flow(service_name) - if service_name in ( - "structured-query-request", - "row-embeddings-query-request", - ): - client._current_user = self._user - return client + return self._flow(service_name) class UsageTracker: @@ -131,7 +126,6 @@ class PatternBase: state="", group=getattr(request, 'group', []), history=[completion_step], - user=request.user, collection=getattr(request, 'collection', 'default'), streaming=False, session_id=getattr(request, 'session_id', ''), @@ -158,9 +152,9 @@ class PatternBase: current_state=getattr(request, 'state', None), ) - def make_context(self, flow, user, respond=None, streaming=False): - """Create a user-aware context wrapper.""" - 
return UserAwareContext(flow, user, respond=respond, streaming=streaming) + def make_context(self, flow, respond=None, streaming=False): + """Create a flow context wrapper.""" + return FlowContext(flow, respond=respond, streaming=streaming) def build_history(self, request): """Convert AgentStep history into Action objects.""" @@ -249,7 +243,7 @@ class PatternBase: # ---- Provenance emission ------------------------------------------------ - async def emit_session_triples(self, flow, session_uri, question, user, + async def emit_session_triples(self, flow, session_uri, question, collection, respond, streaming, parent_uri=None): """Emit provenance triples for a new session.""" @@ -264,7 +258,6 @@ class PatternBase: await flow("explainability").send(Triples( metadata=Metadata( id=session_uri, - user=user, collection=collection, ), triples=triples, @@ -281,7 +274,7 @@ class PatternBase: async def emit_pattern_decision_triples( self, flow, session_id, session_uri, pattern, task_type, - user, collection, respond, + collection, respond, ): """Emit provenance triples for a meta-router pattern decision.""" uri = agent_pattern_decision_uri(session_id) @@ -292,7 +285,7 @@ class PatternBase: GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( - metadata=Metadata(id=uri, user=user, collection=collection), + metadata=Metadata(id=uri, collection=collection), triples=triples, )) await respond(AgentResponse( @@ -329,7 +322,7 @@ class PatternBase: try: await self.processor.save_answer_content( doc_id=thought_doc_id, - user=request.user, + workspace=flow.workspace, content=act.thought, title=f"Agent Thought: {act.name}", ) @@ -360,7 +353,6 @@ class PatternBase: await flow("explainability").send(Triples( metadata=Metadata( id=iteration_uri, - user=request.user, collection=getattr(request, 'collection', 'default'), ), triples=iter_triples, @@ -399,7 +391,7 @@ class PatternBase: try: await self.processor.save_answer_content( doc_id=observation_doc_id, - user=request.user, + 
workspace=flow.workspace, content=observation_text, title=f"Agent Observation", ) @@ -420,7 +412,6 @@ class PatternBase: await flow("explainability").send(Triples( metadata=Metadata( id=observation_entity_uri, - user=request.user, collection=getattr(request, 'collection', 'default'), ), triples=obs_triples, @@ -456,7 +447,7 @@ class PatternBase: try: await self.processor.save_answer_content( doc_id=answer_doc_id, - user=request.user, + workspace=flow.workspace, content=answer_text, title=f"Agent Answer: {request.question[:50]}...", ) @@ -478,7 +469,6 @@ class PatternBase: await flow("explainability").send(Triples( metadata=Metadata( id=final_uri, - user=request.user, collection=getattr(request, 'collection', 'default'), ), triples=final_triples, @@ -496,7 +486,7 @@ class PatternBase: # ---- Orchestrator provenance helpers ------------------------------------ async def emit_decomposition_triples( - self, flow, session_id, session_uri, goals, user, collection, + self, flow, session_id, session_uri, goals, collection, respond, streaming, ): """Emit provenance for a supervisor decomposition step.""" @@ -506,7 +496,7 @@ class PatternBase: GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( - metadata=Metadata(id=uri, user=user, collection=collection), + metadata=Metadata(id=uri, collection=collection), triples=triples, )) await respond(AgentResponse( @@ -516,7 +506,7 @@ class PatternBase: )) async def emit_finding_triples( - self, flow, session_id, index, goal, answer_text, user, collection, + self, flow, session_id, index, goal, answer_text, collection, respond, streaming, subagent_session_id="", ): """Emit provenance for a subagent finding.""" @@ -532,7 +522,7 @@ class PatternBase: doc_id = f"urn:trustgraph:agent:{session_id}/finding/{index}/doc" try: await self.processor.save_answer_content( - doc_id=doc_id, user=user, + doc_id=doc_id, workspace=flow.workspace, content=answer_text, title=f"Finding: {goal[:60]}", ) @@ -545,7 +535,7 @@ class PatternBase: 
GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( - metadata=Metadata(id=uri, user=user, collection=collection), + metadata=Metadata(id=uri, collection=collection), triples=triples, )) await respond(AgentResponse( @@ -555,7 +545,7 @@ class PatternBase: )) async def emit_plan_triples( - self, flow, session_id, session_uri, steps, user, collection, + self, flow, session_id, session_uri, steps, collection, respond, streaming, ): """Emit provenance for a plan creation.""" @@ -565,7 +555,7 @@ class PatternBase: GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( - metadata=Metadata(id=uri, user=user, collection=collection), + metadata=Metadata(id=uri, collection=collection), triples=triples, )) await respond(AgentResponse( @@ -575,7 +565,7 @@ class PatternBase: )) async def emit_step_result_triples( - self, flow, session_id, index, goal, answer_text, user, collection, + self, flow, session_id, index, goal, answer_text, collection, respond, streaming, ): """Emit provenance for a plan step result.""" @@ -585,7 +575,7 @@ class PatternBase: doc_id = f"urn:trustgraph:agent:{session_id}/step/{index}/doc" try: await self.processor.save_answer_content( - doc_id=doc_id, user=user, + doc_id=doc_id, workspace=flow.workspace, content=answer_text, title=f"Step result: {goal[:60]}", ) @@ -598,7 +588,7 @@ class PatternBase: GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( - metadata=Metadata(id=uri, user=user, collection=collection), + metadata=Metadata(id=uri, collection=collection), triples=triples, )) await respond(AgentResponse( @@ -608,7 +598,7 @@ class PatternBase: )) async def emit_synthesis_triples( - self, flow, session_id, previous_uris, answer_text, user, collection, + self, flow, session_id, previous_uris, answer_text, collection, respond, streaming, termination_reason=None, ): """Emit provenance for a synthesis answer.""" @@ -617,7 +607,7 @@ class PatternBase: doc_id = f"urn:trustgraph:agent:{session_id}/synthesis/doc" try: await 
self.processor.save_answer_content( - doc_id=doc_id, user=user, + doc_id=doc_id, workspace=flow.workspace, content=answer_text, title="Synthesis", ) @@ -633,7 +623,7 @@ class PatternBase: GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( - metadata=Metadata(id=uri, user=user, collection=collection), + metadata=Metadata(id=uri, collection=collection), triples=triples, )) await respond(AgentResponse( @@ -751,7 +741,6 @@ class PatternBase: ) for h in history ], - user=request.user, collection=collection, streaming=streaming, session_id=session_id, diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py index 1de31a92..0cc9013f 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py @@ -53,7 +53,7 @@ class PlanThenExecutePattern(PatternBase): if iteration_num == 1: await self.emit_session_triples( flow, session_uri, request.question, - request.user, collection, respond, streaming, + collection, respond, streaming, ) logger.info( @@ -109,11 +109,17 @@ class PlanThenExecutePattern(PatternBase): think = self.make_think_callback(respond, streaming) - tools = self.filter_tools(self.processor.agent.tools, request) + agent = self.processor.agents.get(flow.workspace) + if agent is None: + raise RuntimeError( + f"No agent configuration for workspace {flow.workspace}" + ) + + tools = self.filter_tools(agent.tools, request) framing = getattr(request, 'framing', '') context = self.make_context( - flow, request.user, + flow, respond=respond, streaming=streaming, ) client = context("prompt-request") @@ -147,7 +153,7 @@ class PlanThenExecutePattern(PatternBase): step_goals = [ps.get("goal", "") for ps in plan_steps] await self.emit_plan_triples( flow, session_id, session_uri, step_goals, - request.user, collection, respond, streaming, + collection, respond, streaming, ) # Build PlanStep objects @@ -179,7 
+185,6 @@ class PlanThenExecutePattern(PatternBase): state=request.state, group=getattr(request, 'group', []), history=new_history, - user=request.user, collection=collection, streaming=streaming, session_id=session_id, @@ -237,9 +242,15 @@ class PlanThenExecutePattern(PatternBase): "result": dep_result, }) - tools = self.filter_tools(self.processor.agent.tools, request) + agent = self.processor.agents.get(flow.workspace) + if agent is None: + raise RuntimeError( + f"No agent configuration for workspace {flow.workspace}" + ) + + tools = self.filter_tools(agent.tools, request) context = self.make_context( - flow, request.user, + flow, respond=respond, streaming=streaming, ) @@ -307,7 +318,7 @@ class PlanThenExecutePattern(PatternBase): # Emit step result provenance await self.emit_step_result_triples( flow, session_id, pending_idx, goal, step_result, - request.user, collection, respond, streaming, + collection, respond, streaming, ) # Build execution step for history @@ -327,7 +338,6 @@ class PlanThenExecutePattern(PatternBase): state=request.state, group=getattr(request, 'group', []), history=new_history, - user=request.user, collection=collection, streaming=streaming, session_id=session_id, @@ -352,7 +362,7 @@ class PlanThenExecutePattern(PatternBase): framing = getattr(request, 'framing', '') context = self.make_context( - flow, request.user, + flow, respond=respond, streaming=streaming, ) client = context("prompt-request") @@ -387,7 +397,7 @@ class PlanThenExecutePattern(PatternBase): last_step_uri = make_step_result_uri(session_id, len(plan) - 1) await self.emit_synthesis_triples( flow, session_id, last_step_uri, - response_text, request.user, collection, respond, streaming, + response_text, collection, respond, streaming, termination_reason="plan-complete", ) diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py index 25264c26..4920ebf1 100644 --- 
a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py @@ -61,7 +61,7 @@ class ReactPattern(PatternBase): ) await self.emit_session_triples( flow, session_uri, request.question, - request.user, collection, respond, streaming, + collection, respond, streaming, parent_uri=parent_uri, ) @@ -80,13 +80,20 @@ class ReactPattern(PatternBase): observe = self.make_observe_callback(respond, streaming, message_id=observation_msg_id) answer_cb = self.make_answer_callback(respond, streaming, message_id=answer_msg_id) + # Look up the per-workspace agent + agent = self.processor.agents.get(flow.workspace) + if agent is None: + raise RuntimeError( + f"No agent configuration for workspace {flow.workspace}" + ) + # Filter tools filtered_tools = self.filter_tools( - self.processor.agent.tools, request, + agent.tools, request, ) # Create temporary agent with filtered tools and optional framing - additional_context = self.processor.agent.additional_context + additional_context = agent.additional_context framing = getattr(request, 'framing', '') if framing: if additional_context: @@ -100,7 +107,7 @@ class ReactPattern(PatternBase): ) context = self.make_context( - flow, request.user, + flow, respond=respond, streaming=streaming, ) diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/service.py b/trustgraph-flow/trustgraph/agent/orchestrator/service.py index 3d08154d..b57ca79d 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/service.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/service.py @@ -42,7 +42,7 @@ from ..tool_filter import validate_tool_config from ..react.types import Final, Action, Tool, Argument from . meta_router import MetaRouter -from . pattern_base import PatternBase, UserAwareContext +from . pattern_base import PatternBase, FlowContext from . react_pattern import ReactPattern from . plan_pattern import PlanThenExecutePattern from . 
supervisor_pattern import SupervisorPattern @@ -76,10 +76,9 @@ class Processor(AgentService): } ) - self.agent = AgentManager( - tools={}, - additional_context="", - ) + # Per-workspace agent managers and meta-routers + self.agents = {} + self.meta_routers = {} self.tool_service_clients = {} @@ -91,9 +90,6 @@ class Processor(AgentService): # Aggregator for supervisor fan-in self.aggregator = Aggregator() - # Meta-router (initialised on first config load) - self.meta_router = None - self.register_config_handler( self.on_tools_config, types=["tool", "tool-service"] ) @@ -204,13 +200,13 @@ class Processor(AgentService): future = self.pending_librarian_requests.pop(request_id) future.set_result(response) - async def save_answer_content(self, doc_id, user, content, title=None, + async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120): request_id = str(uuid.uuid4()) doc_metadata = DocumentMetadata( id=doc_id, - user=user, + workspace=workspace, kind="text/plain", title=title or "Agent Answer", document_type="answer", @@ -221,7 +217,7 @@ class Processor(AgentService): document_id=doc_id, document_metadata=doc_metadata, content=base64.b64encode(content.encode("utf-8")).decode("utf-8"), - user=user, + workspace=workspace, ) future = asyncio.get_event_loop().create_future() @@ -247,9 +243,12 @@ class Processor(AgentService): def provenance_session_uri(self, session_id): return agent_session_uri(session_id) - async def on_tools_config(self, config, version): + async def on_tools_config(self, workspace, config, version): - logger.info(f"Loading configuration version {version}") + logger.info( + f"Loading configuration version {version} " + f"for workspace {workspace}" + ) try: tools = {} @@ -316,7 +315,6 @@ class Processor(AgentService): impl = functools.partial( StructuredQueryImpl, collection=data.get("collection"), - user=None, ) arguments = StructuredQueryImpl.get_arguments() elif impl_id == "row-embeddings-query": @@ -324,7 +322,6 @@ class 
Processor(AgentService): RowEmbeddingsQueryImpl, schema_name=data.get("schema-name"), collection=data.get("collection"), - user=None, index_name=data.get("index-name"), limit=int(data.get("limit", 10)), ) @@ -408,15 +405,17 @@ class Processor(AgentService): agent_config = config[self.config_key] additional = agent_config.get("additional-context", None) - self.agent = AgentManager( + self.agents[workspace] = AgentManager( tools=tools, additional_context=additional, ) - # Re-initialise meta-router with config - self.meta_router = MetaRouter(config=config) + # Re-initialise meta-router with config for this workspace + self.meta_routers[workspace] = MetaRouter(config=config) - logger.info(f"Loaded {len(tools)} tools") + logger.info( + f"Loaded {len(tools)} tools for workspace {workspace}" + ) except Exception as e: logger.error( @@ -466,7 +465,7 @@ class Processor(AgentService): await self.supervisor_pattern.emit_finding_triples( flow, parent_session_id, finding_index, subagent_goal, answer_text, - template.user, collection, + collection, respond, template.streaming, subagent_session_id=subagent_session_id, ) @@ -486,7 +485,6 @@ class Processor(AgentService): synthesis_request = self.aggregator.build_synthesis_request( correlation_id, original_question=template.question, - user=template.user, collection=getattr(template, 'collection', 'default'), ) @@ -515,10 +513,11 @@ class Processor(AgentService): # If no pattern set and this is the first iteration, route if not pattern and not request.history: - context = UserAwareContext(flow, request.user) + context = FlowContext(flow) - if self.meta_router: - pattern, task_type, framing = await self.meta_router.route( + meta_router = self.meta_routers.get(flow.workspace) + if meta_router: + pattern, task_type, framing = await meta_router.route( request.question, context, usage=usage, ) else: @@ -553,7 +552,6 @@ class Processor(AgentService): await selected.emit_pattern_decision_triples( flow, session_id, session_uri, pattern, 
getattr(request, 'task_type', ''), - request.user, getattr(request, 'collection', 'default'), respond, ) diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py index 973a9966..f9a1751d 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py @@ -54,7 +54,7 @@ class SupervisorPattern(PatternBase): if iteration_num == 1: await self.emit_session_triples( flow, session_uri, request.question, - request.user, collection, respond, streaming, + collection, respond, streaming, ) logger.info( @@ -99,10 +99,16 @@ class SupervisorPattern(PatternBase): ) framing = getattr(request, 'framing', '') - tools = self.filter_tools(self.processor.agent.tools, request) + agent = self.processor.agents.get(flow.workspace) + if agent is None: + raise RuntimeError( + f"No agent configuration for workspace {flow.workspace}" + ) + + tools = self.filter_tools(agent.tools, request) context = self.make_context( - flow, request.user, + flow, respond=respond, streaming=streaming, ) client = context("prompt-request") @@ -144,7 +150,7 @@ class SupervisorPattern(PatternBase): # Emit decomposition provenance await self.emit_decomposition_triples( flow, session_id, session_uri, goals, - request.user, collection, respond, streaming, + collection, respond, streaming, ) # Fan out: emit a subagent request for each goal @@ -155,7 +161,6 @@ class SupervisorPattern(PatternBase): state="", group=getattr(request, 'group', []), history=[], - user=request.user, collection=collection, streaming=False, # Subagents don't stream session_id=subagent_session, @@ -207,7 +212,7 @@ class SupervisorPattern(PatternBase): subagent_results = {"(no results)": "No subagent results available"} context = self.make_context( - flow, request.user, + flow, respond=respond, streaming=streaming, ) client = context("prompt-request") @@ -237,7 
+242,7 @@ class SupervisorPattern(PatternBase): ] await self.emit_synthesis_triples( flow, session_id, finding_uris, - response_text, request.user, collection, respond, streaming, + response_text, collection, respond, streaming, termination_reason="subagents-complete", ) diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index 1512fa83..7140284f 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py +++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -10,6 +10,7 @@ import sys import functools import logging import uuid +from typing import Dict from datetime import datetime, timezone # Module logger @@ -73,10 +74,8 @@ class Processor(AgentService): } ) - self.agent = AgentManager( - tools={}, - additional_context="", - ) + # Per-workspace agent managers + self.agents: Dict[str, AgentManager] = {} # Track active tool service clients for cleanup self.tool_service_clients = {} @@ -193,13 +192,13 @@ class Processor(AgentService): future = self.pending_librarian_requests.pop(request_id) future.set_result(response) - async def save_answer_content(self, doc_id, user, content, title=None, timeout=120): + async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120): """ Save answer content to the librarian. 
Args: doc_id: ID for the answer document - user: User ID + workspace: Workspace for isolation content: Answer text content title: Optional title timeout: Request timeout in seconds @@ -211,7 +210,7 @@ class Processor(AgentService): doc_metadata = DocumentMetadata( id=doc_id, - user=user, + workspace=workspace, kind="text/plain", title=title or "Agent Answer", document_type="answer", @@ -222,7 +221,7 @@ class Processor(AgentService): document_id=doc_id, document_metadata=doc_metadata, content=base64.b64encode(content.encode("utf-8")).decode("utf-8"), - user=user, + workspace=workspace, ) # Create future for response @@ -249,9 +248,12 @@ class Processor(AgentService): self.pending_librarian_requests.pop(request_id, None) raise RuntimeError(f"Timeout saving answer document {doc_id}") - async def on_tools_config(self, config, version): + async def on_tools_config(self, workspace, config, version): - logger.info(f"Loading configuration version {version}") + logger.info( + f"Loading configuration version {version} " + f"for workspace {workspace}" + ) try: @@ -321,7 +323,6 @@ class Processor(AgentService): impl = functools.partial( StructuredQueryImpl, collection=data.get("collection"), - user=None # User will be provided dynamically via context ) arguments = StructuredQueryImpl.get_arguments() elif impl_id == "row-embeddings-query": @@ -329,7 +330,6 @@ class Processor(AgentService): RowEmbeddingsQueryImpl, schema_name=data.get("schema-name"), collection=data.get("collection"), - user=None, # User will be provided dynamically via context index_name=data.get("index-name"), # Optional filter limit=int(data.get("limit", 10)) # Max results ) @@ -409,13 +409,17 @@ class Processor(AgentService): agent_config = config[self.config_key] additional = agent_config.get("additional-context", None) - self.agent = AgentManager( + self.agents[workspace] = AgentManager( tools=tools, additional_context=additional ) - logger.info(f"Loaded {len(tools)} tools") - logger.info("Tool 
configuration reloaded.") + logger.info( + f"Loaded {len(tools)} tools for workspace {workspace}" + ) + logger.info( + f"Tool configuration reloaded for workspace {workspace}." + ) except Exception as e: @@ -460,7 +464,6 @@ class Processor(AgentService): await flow("explainability").send(Triples( metadata=Metadata( id=session_uri, - user=request.user, collection=collection, ), triples=triples, @@ -557,35 +560,41 @@ class Processor(AgentService): await respond(r) + # Look up the agent for this workspace + workspace = flow.workspace + agent = self.agents.get(workspace) + if agent is None: + logger.error( + f"No agent configuration loaded for workspace " + f"{workspace}" + ) + raise RuntimeError( + f"No agent configuration for workspace {workspace}" + ) + # Apply tool filtering based on request groups and state filtered_tools = filter_tools_by_group_and_state( - tools=self.agent.tools, + tools=agent.tools, requested_groups=getattr(request, 'group', None), current_state=getattr(request, 'state', None) ) - + # Create temporary agent with filtered tools temp_agent = AgentManager( tools=filtered_tools, - additional_context=self.agent.additional_context + additional_context=agent.additional_context ) logger.debug("Call React") - # Create user-aware context wrapper that preserves the flow interface - # but adds user information for tools that need it - class UserAwareContext: - def __init__(self, flow, user): + # Thin wrapper around flow — carries only explain URI state. 
+ class _Context: + def __init__(self, flow): self._flow = flow - self._user = user self.last_sub_explain_uri = None def __call__(self, service_name): - client = self._flow(service_name) - # For query clients that need user context, store it - if service_name in ("structured-query-request", "row-embeddings-query-request"): - client._current_user = self._user - return client + return self._flow(service_name) # Callback: emit Analysis+ToolUse triples before tool executes async def on_action(act_decision): @@ -604,7 +613,7 @@ class Processor(AgentService): try: await self.save_answer_content( doc_id=t_doc_id, - user=request.user, + workspace=flow.workspace, content=act_decision.thought, title=f"Agent Thought: {act_decision.name}", ) @@ -629,7 +638,6 @@ class Processor(AgentService): await flow("explainability").send(Triples( metadata=Metadata( id=iter_uri, - user=request.user, collection=collection, ), triples=iter_triples, @@ -644,7 +652,7 @@ class Processor(AgentService): explain_triples=iter_triples, )) - user_context = UserAwareContext(flow, request.user) + user_context = _Context(flow) act = await temp_agent.react( question = request.question, @@ -685,7 +693,7 @@ class Processor(AgentService): try: await self.save_answer_content( doc_id=answer_doc_id, - user=request.user, + workspace=flow.workspace, content=f, title=f"Agent Answer: {request.question[:50]}...", ) @@ -706,7 +714,6 @@ class Processor(AgentService): await flow("explainability").send(Triples( metadata=Metadata( id=final_uri, - user=request.user, collection=collection, ), triples=final_triples, @@ -763,7 +770,7 @@ class Processor(AgentService): try: await self.save_answer_content( doc_id=observation_doc_id, - user=request.user, + workspace=flow.workspace, content=act.observation, title=f"Agent Observation", ) @@ -783,7 +790,6 @@ class Processor(AgentService): await flow("explainability").send(Triples( metadata=Metadata( id=observation_entity_uri, - user=request.user, collection=collection, ), 
triples=obs_triples, @@ -820,7 +826,6 @@ class Processor(AgentService): ) for h in history ], - user=request.user, collection=collection, streaming=streaming, session_id=session_id, # Pass session_id for provenance continuity diff --git a/trustgraph-flow/trustgraph/agent/react/tools.py b/trustgraph-flow/trustgraph/agent/react/tools.py index 6674c999..ae9507ab 100644 --- a/trustgraph-flow/trustgraph/agent/react/tools.py +++ b/trustgraph-flow/trustgraph/agent/react/tools.py @@ -116,31 +116,26 @@ class McpToolImpl: # This tool implementation knows how to query structured data using natural language class StructuredQueryImpl: - def __init__(self, context, collection=None, user=None): + def __init__(self, context, collection=None): self.context = context - self.collection = collection # For multi-tenant scenarios - self.user = user # User context for multi-tenancy - + self.collection = collection + @staticmethod def get_arguments(): return [ Argument( name="question", - type="string", + type="string", description="Natural language question about structured data (tables, databases, etc.)" ) ] - + async def invoke(self, **arguments): client = self.context("structured-query-request") logger.debug("Structured query question...") - - # Get user from client context if available, otherwise use instance user or default - user = getattr(client, '_current_user', self.user or "trustgraph") - + result = await client.structured_query( question=arguments.get("question"), - user=user, collection=self.collection or "default" ) @@ -159,11 +154,10 @@ class StructuredQueryImpl: # This tool implementation knows how to query row embeddings for semantic search class RowEmbeddingsQueryImpl: - def __init__(self, context, schema_name, collection=None, user=None, index_name=None, limit=10): + def __init__(self, context, schema_name, collection=None, index_name=None, limit=10): self.context = context self.schema_name = schema_name self.collection = collection - self.user = user self.index_name = 
index_name # Optional: filter to specific index self.limit = limit # Max results to return @@ -190,13 +184,9 @@ class RowEmbeddingsQueryImpl: client = self.context("row-embeddings-query-request") logger.debug("Row embeddings query...") - # Get user from client context if available - user = getattr(client, '_current_user', self.user or "trustgraph") - matches = await client.row_embeddings_query( vector=vector, schema_name=self.schema_name, - user=user, collection=self.collection or "default", index_name=self.index_name, limit=self.limit @@ -250,7 +240,7 @@ class ToolServiceImpl: Initialize a tool service implementation. Args: - context: The context function (provides user info) + context: Flow context (callable resolving service names to clients) request_queue: Full Pulsar topic for requests response_queue: Full Pulsar topic for responses config_values: Dict of config values (e.g., {"collection": "customers"}) @@ -325,17 +315,10 @@ class ToolServiceImpl: logger.debug(f"Config: {self.config_values}") logger.debug(f"Arguments: {arguments}") - # Get user from context if available - user = "trustgraph" - if hasattr(self.context, '_user'): - user = self.context._user - # Get or create the client client = await self._get_or_create_client() - # Call the tool service response = await client.call( - user=user, config=self.config_values, arguments=arguments, ) diff --git a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py index dc7b357c..a0052c79 100755 --- a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py @@ -95,7 +95,7 @@ class Processor(ChunkingService): logger.info(f"Chunking document {v.metadata.id}...") # Get text content (fetches from librarian if needed) - text = await self.get_document_text(v) + text = await self.get_document_text(v, flow.workspace) # Extract chunk parameters from flow (allows runtime override) chunk_size, 
chunk_overlap = await self.chunk_document( @@ -144,7 +144,7 @@ class Processor(ChunkingService): await self.librarian.save_child_document( doc_id=chunk_doc_id, parent_id=parent_doc_id, - user=v.metadata.user, + workspace=flow.workspace, content=chunk_content, document_type="chunk", title=f"Chunk {chunk_index}", @@ -168,7 +168,6 @@ class Processor(ChunkingService): metadata=Metadata( id=c_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), triples=set_graph(prov_triples, GRAPH_SOURCE), @@ -179,7 +178,6 @@ class Processor(ChunkingService): metadata=Metadata( id=c_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), chunk=chunk_content, diff --git a/trustgraph-flow/trustgraph/chunking/token/chunker.py b/trustgraph-flow/trustgraph/chunking/token/chunker.py index 3f31beb9..c3935e4b 100755 --- a/trustgraph-flow/trustgraph/chunking/token/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/token/chunker.py @@ -92,7 +92,7 @@ class Processor(ChunkingService): logger.info(f"Chunking document {v.metadata.id}...") # Get text content (fetches from librarian if needed) - text = await self.get_document_text(v) + text = await self.get_document_text(v, flow.workspace) # Extract chunk parameters from flow (allows runtime override) chunk_size, chunk_overlap = await self.chunk_document( @@ -140,7 +140,7 @@ class Processor(ChunkingService): await self.librarian.save_child_document( doc_id=chunk_doc_id, parent_id=parent_doc_id, - user=v.metadata.user, + workspace=flow.workspace, content=chunk_content, document_type="chunk", title=f"Chunk {chunk_index}", @@ -164,7 +164,6 @@ class Processor(ChunkingService): metadata=Metadata( id=c_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), triples=set_graph(prov_triples, GRAPH_SOURCE), @@ -175,7 +174,6 @@ class Processor(ChunkingService): metadata=Metadata( id=c_uri, root=v.metadata.root, - user=v.metadata.user, 
collection=v.metadata.collection, ), chunk=chunk_content, diff --git a/trustgraph-flow/trustgraph/config/service/config.py b/trustgraph-flow/trustgraph/config/service/config.py index 6c897f6b..36af6026 100644 --- a/trustgraph-flow/trustgraph/config/service/config.py +++ b/trustgraph-flow/trustgraph/config/service/config.py @@ -9,42 +9,8 @@ from ... tables.config import ConfigTableStore # Module logger logger = logging.getLogger(__name__) -class ConfigurationClass: - - async def keys(self): - return await self.table_store.get_keys(self.type) - - async def values(self): - vals = await self.table_store.get_values(self.type) - return { - v[0]: v[1] - for v in vals - } - - async def get(self, key): - return await self.table_store.get_value(self.type, key) - - async def put(self, key, value): - return await self.table_store.put_config(self.type, key, value) - - async def delete(self, key): - return await self.table_store.delete_key(self.type, key) - - async def has(self, key): - val = await self.table_store.get_value(self.type, key) - return val is not None - class Configuration: - # FIXME: The state is held internally. This only works if there's - # one config service. Should be more than one, and use a - # back-end state store. - - # FIXME: This has state now, but does it address all of the above? 
- # REVIEW: Above - - # FIXME: Some version vs config race conditions - def __init__(self, push, host, username, password, keyspace): # External function to respond to update @@ -60,34 +26,17 @@ class Configuration: async def get_version(self): return await self.table_store.get_version() - def get(self, type): - - c = ConfigurationClass() - c.table_store = self.table_store - c.type = type - - return c - async def handle_get(self, v): - # for k in v.keys: - # if k.type not in self or k.key not in self[k.type]: - # return ConfigResponse( - # version = None, - # values = None, - # directory = None, - # config = None, - # error = Error( - # type = "key-error", - # message = f"Key error" - # ) - # ) + workspace = v.workspace values = [ ConfigValue( type = k.type, key = k.key, - value = await self.table_store.get_value(k.type, k.key) + value = await self.table_store.get_value( + workspace, k.type, k.key + ) ) for k in v.keys ] @@ -96,43 +45,19 @@ class Configuration: version = await self.get_version(), values = values, ) - + async def handle_list(self, v): - # if v.type not in self: - - # return ConfigResponse( - # version = None, - # values = None, - # directory = None, - # config = None, - # error = Error( - # type = "key-error", - # message = "No such type", - # ), - # ) - return ConfigResponse( version = await self.get_version(), - directory = await self.table_store.get_keys(v.type), + directory = await self.table_store.get_keys( + v.workspace, v.type + ), ) async def handle_getvalues(self, v): - # if v.type not in self: - - # return ConfigResponse( - # version = None, - # values = None, - # directory = None, - # config = None, - # error = Error( - # type = "key-error", - # message = f"Key error" - # ) - # ) - - vals = await self.table_store.get_values(v.type) + vals = await self.table_store.get_values(v.workspace, v.type) values = map( lambda x: ConfigValue( @@ -146,39 +71,63 @@ class Configuration: values = list(values), ) + async def handle_getvalues_all_ws(self, 
v): + """Fetch all values of a given type across all workspaces. + Used by shared processors to load type-scoped config at + startup without enumerating workspaces separately.""" + + vals = await self.table_store.get_values_all_ws(v.type) + + values = [ + ConfigValue( + workspace = row[0], + type = v.type, + key = row[1], + value = row[2], + ) + for row in vals + ] + + return ConfigResponse( + version = await self.get_version(), + values = values, + ) + async def handle_delete(self, v): + workspace = v.workspace types = list(set(k.type for k in v.keys)) for k in v.keys: - - await self.table_store.delete_key(k.type, k.key) + await self.table_store.delete_key(workspace, k.type, k.key) await self.inc_version() - await self.push(types=types) + await self.push(changes={t: [workspace] for t in types}) return ConfigResponse( ) async def handle_put(self, v): + workspace = v.workspace types = list(set(k.type for k in v.values)) for k in v.values: - - await self.table_store.put_config(k.type, k.key, k.value) + await self.table_store.put_config( + workspace, k.type, k.key, k.value + ) await self.inc_version() - await self.push(types=types) + await self.push(changes={t: [workspace] for t in types}) return ConfigResponse( ) - async def get_config(self): + async def get_config(self, workspace): - table = await self.table_store.get_all() + table = await self.table_store.get_all_for_workspace(workspace) config = {} @@ -191,7 +140,7 @@ class Configuration: async def handle_config(self, v): - config = await self.get_config() + config = await self.get_config(v.workspace) return ConfigResponse( version = await self.get_version(), @@ -200,7 +149,20 @@ class Configuration: async def handle(self, msg): - logger.debug(f"Handling config message: {msg.operation}") + logger.debug( + f"Handling config message: {msg.operation} " + f"workspace={msg.workspace}" + ) + + # getvalues-all-ws spans all workspaces, so no workspace + # required; everything else is workspace-scoped. 
+ if msg.operation != "getvalues-all-ws" and not msg.workspace: + return ConfigResponse( + error=Error( + type = "bad-request", + message = "Workspace is required" + ) + ) if msg.operation == "get": @@ -214,6 +176,10 @@ class Configuration: resp = await self.handle_getvalues(msg) + elif msg.operation == "getvalues-all-ws": + + resp = await self.handle_getvalues_all_ws(msg) + elif msg.operation == "delete": resp = await self.handle_delete(msg) diff --git a/trustgraph-flow/trustgraph/config/service/service.py b/trustgraph-flow/trustgraph/config/service/service.py index fe44b852..56a54ee0 100644 --- a/trustgraph-flow/trustgraph/config/service/service.py +++ b/trustgraph-flow/trustgraph/config/service/service.py @@ -128,18 +128,21 @@ class Processor(AsyncProcessor): await self.push() # Startup poke: empty types = everything await self.config_request_consumer.start() - async def push(self, types=None): + async def push(self, changes=None): version = await self.config.get_version() resp = ConfigPush( version = version, - types = types or [], + changes = changes or {}, ) await self.config_push_producer.send(resp) - logger.info(f"Pushed config poke version {version}, types={resp.types}") + logger.info( + f"Pushed config poke version {version}, " + f"changes={resp.changes}" + ) async def on_config_request(self, msg, consumer, flow): diff --git a/trustgraph-flow/trustgraph/cores/knowledge.py b/trustgraph-flow/trustgraph/cores/knowledge.py index d03d4ed6..ab5f78f0 100644 --- a/trustgraph-flow/trustgraph/cores/knowledge.py +++ b/trustgraph-flow/trustgraph/cores/knowledge.py @@ -33,7 +33,7 @@ class KnowledgeManager: logger.info("Deleting knowledge core...") await self.table_store.delete_kg_core( - request.user, request.id + request.workspace, request.id ) await respond( @@ -63,7 +63,7 @@ class KnowledgeManager: # Remove doc table row await self.table_store.get_triples( - request.user, + request.workspace, request.id, publish_triples, ) @@ -81,7 +81,7 @@ class KnowledgeManager: 
# Remove doc table row await self.table_store.get_graph_embeddings( - request.user, + request.workspace, request.id, publish_ge, ) @@ -100,7 +100,7 @@ class KnowledgeManager: async def list_kg_cores(self, request, respond): - ids = await self.table_store.list_kg_cores(request.user) + ids = await self.table_store.list_kg_cores(request.workspace) await respond( KnowledgeResponse( @@ -114,12 +114,14 @@ class KnowledgeManager: async def put_kg_core(self, request, respond): + workspace = request.workspace + if request.triples: - await self.table_store.add_triples(request.triples) + await self.table_store.add_triples(workspace, request.triples) if request.graph_embeddings: await self.table_store.add_graph_embeddings( - request.graph_embeddings + workspace, request.graph_embeddings ) await respond( @@ -178,10 +180,15 @@ class KnowledgeManager: if request.flow is None: raise RuntimeError("Flow ID must be specified") - if request.flow not in self.flow_config.flows: - raise RuntimeError("Invalid flow") + workspace = request.workspace + ws_flows = self.flow_config.flows.get(workspace, {}) + if request.flow not in ws_flows: + raise RuntimeError( + f"Invalid flow {request.flow} for workspace " + f"{workspace}" + ) - flow = self.flow_config.flows[request.flow] + flow = ws_flows[request.flow] if "interfaces" not in flow: raise RuntimeError("No defined interfaces") @@ -257,7 +264,7 @@ class KnowledgeManager: # Remove doc table row await self.table_store.get_triples( - request.user, + request.workspace, request.id, publish_triples, ) @@ -272,7 +279,7 @@ class KnowledgeManager: # Remove doc table row await self.table_store.get_graph_embeddings( - request.user, + request.workspace, request.id, publish_ge, ) diff --git a/trustgraph-flow/trustgraph/cores/service.py b/trustgraph-flow/trustgraph/cores/service.py index 93017c30..15e8feb6 100755 --- a/trustgraph-flow/trustgraph/cores/service.py +++ b/trustgraph-flow/trustgraph/cores/service.py @@ -124,19 +124,21 @@ class 
Processor(AsyncProcessor): await self.knowledge_request_consumer.start() await self.knowledge_response_producer.start() - async def on_knowledge_config(self, config, version): + async def on_knowledge_config(self, workspace, config, version): - logger.info(f"Configuration version: {version}") + logger.info( + f"Configuration version: {version} workspace: {workspace}" + ) if "flow" in config: - self.flows = { + self.flows[workspace] = { k: json.loads(v) for k, v in config["flow"].items() } else: - self.flows = {} + self.flows[workspace] = {} - logger.debug(f"Flows: {self.flows}") + logger.debug(f"Flows for {workspace}: {self.flows[workspace]}") async def process_request(self, v, id): diff --git a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py index 40b8c566..3436ca51 100755 --- a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py +++ b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py @@ -200,7 +200,7 @@ class Processor(FlowProcessor): if v.document_id: doc_meta = await self.librarian.fetch_document_metadata( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf": logger.error( @@ -215,7 +215,7 @@ class Processor(FlowProcessor): logger.info(f"Fetching document {v.document_id} from librarian...") content = await self.librarian.fetch_document_content( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) if isinstance(content, str): content = content.encode('utf-8') @@ -243,7 +243,7 @@ class Processor(FlowProcessor): await self.librarian.save_child_document( doc_id=page_doc_id, parent_id=source_doc_id, - user=v.metadata.user, + workspace=flow.workspace, content=page_content, document_type="page", title=f"Page {page_num}", @@ -265,7 +265,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=pg_uri, root=v.metadata.root, - user=v.metadata.user, 
collection=v.metadata.collection, ), triples=set_graph(prov_triples, GRAPH_SOURCE), @@ -277,7 +276,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=pg_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), document_id=page_doc_id, diff --git a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py index 7f9ca71d..f3eb3881 100755 --- a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py +++ b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py @@ -93,7 +93,7 @@ class Processor(FlowProcessor): if v.document_id: doc_meta = await self.librarian.fetch_document_metadata( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf": logger.error( @@ -114,7 +114,7 @@ class Processor(FlowProcessor): content = await self.librarian.fetch_document_content( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) # Content is base64 encoded @@ -157,7 +157,7 @@ class Processor(FlowProcessor): await self.librarian.save_child_document( doc_id=page_doc_id, parent_id=source_doc_id, - user=v.metadata.user, + workspace=flow.workspace, content=page_content, document_type="page", title=f"Page {page_num}", @@ -179,7 +179,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=pg_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), triples=set_graph(prov_triples, GRAPH_SOURCE), @@ -191,7 +190,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=pg_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), document_id=page_doc_id, diff --git a/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py b/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py index 66bfe31f..4564fe1f 100644 --- a/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py +++ 
b/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py @@ -6,9 +6,9 @@ import re logger = logging.getLogger(__name__) -def make_safe_collection_name(user, collection, prefix): +def make_safe_collection_name(workspace, collection, prefix): """ - Create a safe Milvus collection name from user/collection parameters. + Create a safe Milvus collection name from workspace/collection parameters. Milvus only allows letters, numbers, and underscores. """ def sanitize(s): @@ -23,10 +23,10 @@ def make_safe_collection_name(user, collection, prefix): safe = 'default' return safe - safe_user = sanitize(user) + safe_workspace = sanitize(workspace) safe_collection = sanitize(collection) - return f"{prefix}_{safe_user}_{safe_collection}" + return f"{prefix}_{safe_workspace}_{safe_collection}" class DocVectors: @@ -49,26 +49,26 @@ class DocVectors: self.next_reload = time.time() + self.reload_time logger.debug(f"Reload at {self.next_reload}") - def collection_exists(self, user, collection): + def collection_exists(self, workspace, collection): """ - Check if any collection exists for this user/collection combination. + Check if any collection exists for this workspace/collection combination. Since collections are dimension-specific, this checks if ANY dimension variant exists. """ - base_name = make_safe_collection_name(user, collection, self.prefix) + base_name = make_safe_collection_name(workspace, collection, self.prefix) prefix = f"{base_name}_" all_collections = self.client.list_collections() return any(coll.startswith(prefix) for coll in all_collections) - def create_collection(self, user, collection, dimension=384): + def create_collection(self, workspace, collection, dimension=384): """ No-op for explicit collection creation. Collections are created lazily on first insert with actual dimension. 
""" - logger.info(f"Collection creation requested for {user}/{collection} - will be created lazily on first insert") + logger.info(f"Collection creation requested for {workspace}/{collection} - will be created lazily on first insert") - def init_collection(self, dimension, user, collection): + def init_collection(self, dimension, workspace, collection): - base_name = make_safe_collection_name(user, collection, self.prefix) + base_name = make_safe_collection_name(workspace, collection, self.prefix) collection_name = f"{base_name}_{dimension}" pkey_field = FieldSchema( @@ -116,15 +116,15 @@ class DocVectors: index_params=index_params ) - self.collections[(dimension, user, collection)] = collection_name + self.collections[(dimension, workspace, collection)] = collection_name logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}") - def insert(self, embeds, chunk_id, user, collection): + def insert(self, embeds, chunk_id, workspace, collection): dim = len(embeds) - if (dim, user, collection) not in self.collections: - self.init_collection(dim, user, collection) + if (dim, workspace, collection) not in self.collections: + self.init_collection(dim, workspace, collection) data = [ { @@ -134,25 +134,25 @@ class DocVectors: ] self.client.insert( - collection_name=self.collections[(dim, user, collection)], + collection_name=self.collections[(dim, workspace, collection)], data=data ) - def search(self, embeds, user, collection, fields=["chunk_id"], limit=10): + def search(self, embeds, workspace, collection, fields=["chunk_id"], limit=10): dim = len(embeds) # Check if collection exists - return empty if not - if (dim, user, collection) not in self.collections: - base_name = make_safe_collection_name(user, collection, self.prefix) + if (dim, workspace, collection) not in self.collections: + base_name = make_safe_collection_name(workspace, collection, self.prefix) collection_name = f"{base_name}_{dim}" if not 
self.client.has_collection(collection_name): logger.info(f"Collection {collection_name} does not exist, returning empty results") return [] # Collection exists but not in cache, add it - self.collections[(dim, user, collection)] = collection_name + self.collections[(dim, workspace, collection)] = collection_name - coll = self.collections[(dim, user, collection)] + coll = self.collections[(dim, workspace, collection)] logger.debug("Loading...") self.client.load_collection( @@ -181,12 +181,12 @@ class DocVectors: return res - def delete_collection(self, user, collection): + def delete_collection(self, workspace, collection): """ - Delete all dimension variants of the collection for the given user/collection. + Delete all dimension variants of the collection for the given workspace/collection. Since collections are created with dimension suffixes, we need to find and delete all. """ - base_name = make_safe_collection_name(user, collection, self.prefix) + base_name = make_safe_collection_name(workspace, collection, self.prefix) prefix = f"{base_name}_" # Get all collections and filter for matches @@ -199,10 +199,10 @@ class DocVectors: for collection_name in matching_collections: self.client.drop_collection(collection_name) logger.info(f"Deleted Milvus collection: {collection_name}") - logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}") + logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}") # Remove from our local cache - keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection] + keys_to_remove = [key for key in self.collections.keys() if key[1] == workspace and key[2] == collection] for key in keys_to_remove: del self.collections[key] diff --git a/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py b/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py index dcbf6734..7d5a640b 100644 --- 
a/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py +++ b/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py @@ -6,9 +6,9 @@ import re logger = logging.getLogger(__name__) -def make_safe_collection_name(user, collection, prefix): +def make_safe_collection_name(workspace, collection, prefix): """ - Create a safe Milvus collection name from user/collection parameters. + Create a safe Milvus collection name from workspace/collection parameters. Milvus only allows letters, numbers, and underscores. """ def sanitize(s): @@ -23,10 +23,10 @@ def make_safe_collection_name(user, collection, prefix): safe = 'default' return safe - safe_user = sanitize(user) + safe_workspace = sanitize(workspace) safe_collection = sanitize(collection) - return f"{prefix}_{safe_user}_{safe_collection}" + return f"{prefix}_{safe_workspace}_{safe_collection}" class EntityVectors: @@ -49,26 +49,26 @@ class EntityVectors: self.next_reload = time.time() + self.reload_time logger.debug(f"Reload at {self.next_reload}") - def collection_exists(self, user, collection): + def collection_exists(self, workspace, collection): """ - Check if any collection exists for this user/collection combination. + Check if any collection exists for this workspace/collection combination. Since collections are dimension-specific, this checks if ANY dimension variant exists. """ - base_name = make_safe_collection_name(user, collection, self.prefix) + base_name = make_safe_collection_name(workspace, collection, self.prefix) prefix = f"{base_name}_" all_collections = self.client.list_collections() return any(coll.startswith(prefix) for coll in all_collections) - def create_collection(self, user, collection, dimension=384): + def create_collection(self, workspace, collection, dimension=384): """ No-op for explicit collection creation. Collections are created lazily on first insert with actual dimension. 
""" - logger.info(f"Collection creation requested for {user}/{collection} - will be created lazily on first insert") + logger.info(f"Collection creation requested for {workspace}/{collection} - will be created lazily on first insert") - def init_collection(self, dimension, user, collection): + def init_collection(self, dimension, workspace, collection): - base_name = make_safe_collection_name(user, collection, self.prefix) + base_name = make_safe_collection_name(workspace, collection, self.prefix) collection_name = f"{base_name}_{dimension}" pkey_field = FieldSchema( @@ -122,15 +122,15 @@ class EntityVectors: index_params=index_params ) - self.collections[(dimension, user, collection)] = collection_name + self.collections[(dimension, workspace, collection)] = collection_name logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}") - def insert(self, embeds, entity, user, collection, chunk_id=""): + def insert(self, embeds, entity, workspace, collection, chunk_id=""): dim = len(embeds) - if (dim, user, collection) not in self.collections: - self.init_collection(dim, user, collection) + if (dim, workspace, collection) not in self.collections: + self.init_collection(dim, workspace, collection) data = [ { @@ -141,25 +141,25 @@ class EntityVectors: ] self.client.insert( - collection_name=self.collections[(dim, user, collection)], + collection_name=self.collections[(dim, workspace, collection)], data=data ) - def search(self, embeds, user, collection, fields=["entity"], limit=10): + def search(self, embeds, workspace, collection, fields=["entity"], limit=10): dim = len(embeds) # Check if collection exists - return empty if not - if (dim, user, collection) not in self.collections: - base_name = make_safe_collection_name(user, collection, self.prefix) + if (dim, workspace, collection) not in self.collections: + base_name = make_safe_collection_name(workspace, collection, self.prefix) collection_name = f"{base_name}_{dim}" if not 
self.client.has_collection(collection_name): logger.info(f"Collection {collection_name} does not exist, returning empty results") return [] # Collection exists but not in cache, add it - self.collections[(dim, user, collection)] = collection_name + self.collections[(dim, workspace, collection)] = collection_name - coll = self.collections[(dim, user, collection)] + coll = self.collections[(dim, workspace, collection)] logger.debug("Loading...") self.client.load_collection( @@ -188,12 +188,12 @@ class EntityVectors: return res - def delete_collection(self, user, collection): + def delete_collection(self, workspace, collection): """ - Delete all dimension variants of the collection for the given user/collection. + Delete all dimension variants of the collection for the given workspace/collection. Since collections are created with dimension suffixes, we need to find and delete all. """ - base_name = make_safe_collection_name(user, collection, self.prefix) + base_name = make_safe_collection_name(workspace, collection, self.prefix) prefix = f"{base_name}_" # Get all collections and filter for matches @@ -206,10 +206,10 @@ class EntityVectors: for collection_name in matching_collections: self.client.drop_collection(collection_name) logger.info(f"Deleted Milvus collection: {collection_name}") - logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}") + logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}") # Remove from our local cache - keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection] + keys_to_remove = [key for key in self.collections.keys() if key[1] == workspace and key[2] == collection] for key in keys_to_remove: del self.collections[key] diff --git a/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py index 362bdec9..12f4cdc6 100644 --- 
a/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py +++ b/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py @@ -69,19 +69,26 @@ class Processor(CollectionConfigHandler, FlowProcessor): self.register_config_handler(self.on_schema_config, types=["schema"]) self.register_config_handler(self.on_collection_config, types=["collection"]) - # Schema storage: name -> RowSchema - self.schemas: Dict[str, RowSchema] = {} + # Per-workspace schema storage: {workspace: {name: RowSchema}} + self.schemas: Dict[str, Dict[str, RowSchema]] = {} - async def on_schema_config(self, config, version): + async def on_schema_config(self, workspace, config, version): """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") + logger.info( + f"Loading schema configuration version {version} " + f"for workspace {workspace}" + ) - # Clear existing schemas - self.schemas = {} + # Replace existing schemas for this workspace + ws_schemas: Dict[str, RowSchema] = {} + self.schemas[workspace] = ws_schemas # Check if our config type exists if self.config_key not in config: - logger.warning(f"No '{self.config_key}' type in configuration") + logger.warning( + f"No '{self.config_key}' type in configuration " + f"for {workspace}" + ) return # Get the schemas dictionary for our type @@ -115,13 +122,19 @@ class Processor(CollectionConfigHandler, FlowProcessor): fields=fields ) - self.schemas[schema_name] = row_schema - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") + ws_schemas[schema_name] = row_schema + logger.info( + f"Loaded schema: {schema_name} with " + f"{len(fields)} fields for {workspace}" + ) except Exception as e: logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + logger.info( + f"Schema configuration loaded for {workspace}: " + f"{len(ws_schemas)} schemas" + ) def get_index_names(self, schema: 
RowSchema) -> List[str]: """Get all index names for a schema.""" @@ -149,23 +162,29 @@ class Processor(CollectionConfigHandler, FlowProcessor): """Process incoming ExtractedObject and compute embeddings""" obj = msg.value() + workspace = flow.workspace logger.info( f"Computing embeddings for {len(obj.values)} rows, " - f"schema {obj.schema_name}, doc {obj.metadata.id}" + f"schema {obj.schema_name}, doc {obj.metadata.id}, " + f"workspace {workspace}" ) # Validate collection exists before processing - if not self.collection_exists(obj.metadata.user, obj.metadata.collection): + if not self.collection_exists(workspace, obj.metadata.collection): logger.warning( - f"Collection {obj.metadata.collection} for user {obj.metadata.user} " + f"Collection {obj.metadata.collection} for workspace {workspace} " f"does not exist in config. Dropping message." ) return - # Get schema definition - schema = self.schemas.get(obj.schema_name) + # Get schema definition for this workspace + ws_schemas = self.schemas.get(workspace, {}) + schema = ws_schemas.get(obj.schema_name) if not schema: - logger.warning(f"No schema found for {obj.schema_name} - skipping") + logger.warning( + f"No schema found for {obj.schema_name} in " + f"workspace {workspace} - skipping" + ) return # Get all index names for this schema @@ -239,13 +258,13 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.error("Exception during embedding computation", exc_info=True) raise e - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """Collection creation notification - no action needed for embedding stage""" - logger.debug(f"Row embeddings collection notification for {user}/{collection}") + logger.debug(f"Row embeddings collection notification for {workspace}/{collection}") - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, 
collection: str): """Collection deletion notification - no action needed for embedding stage""" - logger.debug(f"Row embeddings collection delete notification for {user}/{collection}") + logger.debug(f"Row embeddings collection delete notification for {workspace}/{collection}") @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/extract/kg/agent/extract.py b/trustgraph-flow/trustgraph/extract/kg/agent/extract.py index ce8d6aae..285b956c 100644 --- a/trustgraph-flow/trustgraph/extract/kg/agent/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/agent/extract.py @@ -75,24 +75,36 @@ class Processor(FlowProcessor): ) ) - # Null configuration, should reload quickly - self.manager = PromptManager() + # Per-workspace prompt managers + self.managers = {} - async def on_prompt_config(self, config, version): + async def on_prompt_config(self, workspace, config, version): - logger.info(f"Loading configuration version {version}") + logger.info( + f"Loading configuration version {version} " + f"for workspace {workspace}" + ) if self.config_key not in config: - logger.warning(f"No key {self.config_key} in config") + logger.warning( + f"No key {self.config_key} in config for {workspace}" + ) return - config = config[self.config_key] + prompt_config = config[self.config_key] try: - self.manager.load_config(config) + manager = self.managers.get(workspace) + if manager is None: + manager = PromptManager() + self.managers[workspace] = manager - logger.info("Prompt configuration reloaded") + manager.load_config(prompt_config) + + logger.info( + f"Prompt configuration reloaded for {workspace}" + ) except Exception as e: @@ -107,7 +119,6 @@ class Processor(FlowProcessor): metadata = Metadata( id = metadata.id, root = metadata.root, - user = metadata.user, collection = metadata.collection, ), triples = triples, @@ -120,7 +131,6 @@ class Processor(FlowProcessor): metadata = Metadata( id = metadata.id, root = metadata.root, - user = metadata.user, collection = 
metadata.collection, ), entities = entity_contexts, @@ -170,13 +180,24 @@ class Processor(FlowProcessor): try: v = msg.value() + workspace = flow.workspace # Extract chunk text chunk_text = v.chunk.decode('utf-8') - logger.debug("Processing chunk for agent extraction") + logger.debug( + f"Processing chunk for agent extraction, " + f"workspace {workspace}" + ) - prompt = self.manager.render( + manager = self.managers.get(workspace) + if manager is None: + logger.error( + f"No prompt configuration for workspace {workspace}" + ) + return + + prompt = manager.render( self.template_id, { "text": chunk_text diff --git a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py index 9b5bbb79..31f45ae9 100755 --- a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py @@ -213,7 +213,6 @@ class Processor(FlowProcessor): Metadata( id=v.metadata.id, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), batch @@ -227,7 +226,6 @@ class Processor(FlowProcessor): Metadata( id=v.metadata.id, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), batch diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py index e024ad40..a05f4dfe 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py @@ -109,20 +109,22 @@ class Processor(FlowProcessor): # Register config handler for ontology updates self.register_config_handler(self.on_ontology_config, types=["ontology"]) - # Shared components (not flow-specific) - self.ontology_loader = OntologyLoader() + # Per-workspace ontology loaders + self.ontology_loaders = {} # workspace -> OntologyLoader self.text_processor = TextProcessor() - # Per-flow components (each flow gets its own embedder/vector 
store/selector) - self.flow_components = {} # flow_id -> {embedder, vector_store, selector} + # Per-flow components (each flow gets its own embedder/vector + # store/selector). Keyed by id(flow) — Flow objects are unique + # per (workspace, flow), so this is implicitly workspace-scoped. + self.flow_components = {} # Configuration self.top_k = params.get("top_k", 10) self.similarity_threshold = params.get("similarity_threshold", 0.3) - # Track loaded ontology version - self.current_ontology_version = None - self.loaded_ontology_ids = set() + # Per-workspace ontology version tracking + self.current_ontology_versions = {} # workspace -> version + self.loaded_ontology_ids = {} # workspace -> set of ids async def initialize_flow_components(self, flow): """Initialize per-flow OntoRAG components. @@ -167,17 +169,23 @@ class Processor(FlowProcessor): vector_store=vector_store ) - # Embed all loaded ontologies for this flow - if self.ontology_loader.get_all_ontologies(): - logger.info(f"Embedding ontologies for flow {flow_id}") - for ont_id, ontology in self.ontology_loader.get_all_ontologies().items(): + workspace = flow.workspace + + # Embed all loaded ontologies for this workspace + loader = self.ontology_loaders.get(workspace) + if loader is not None and loader.get_all_ontologies(): + logger.info( + f"Embedding ontologies for flow {flow_id} " + f"(workspace {workspace})" + ) + for ont_id, ontology in loader.get_all_ontologies().items(): await ontology_embedder.embed_ontology(ontology) logger.info(f"Embedded {ontology_embedder.get_embedded_count()} ontology elements for flow {flow_id}") # Initialize ontology selector ontology_selector = OntologySelector( ontology_embedder=ontology_embedder, - ontology_loader=self.ontology_loader, + ontology_loader=loader, top_k=self.top_k, similarity_threshold=self.similarity_threshold ) @@ -187,7 +195,8 @@ class Processor(FlowProcessor): 'embedder': ontology_embedder, 'vector_store': vector_store, 'selector': ontology_selector, - 
'dimension': dimension + 'dimension': dimension, + 'workspace': workspace, } logger.info(f"Flow {flow_id} components initialized successfully (dimension={dimension})") @@ -197,31 +206,27 @@ class Processor(FlowProcessor): logger.error(f"Failed to initialize flow {flow_id} components: {e}", exc_info=True) raise - async def on_ontology_config(self, config, version): - """ - Handle ontology configuration updates from ConfigPush queue. - - Parses and stores ontologies. Embedding happens per-flow on first message. - - Called automatically when: - - Processor starts (gets full config history via start_of_messages=True) - - Config service pushes updates (immediate event-driven notification) - - Args: - config: Full configuration map - config[type][key] = value - version: Config version number (monotonically increasing) - """ + async def on_ontology_config(self, workspace, config, version): + """Handle ontology configuration updates for a workspace.""" try: - logger.info(f"Received ontology config update, version={version}") + logger.info( + f"Received ontology config update, " + f"version={version} workspace={workspace}" + ) - # Skip if we've already processed this version - if version == self.current_ontology_version: - logger.debug(f"Already at version {version}, skipping") + # Skip if we've already processed this version for this workspace + if version == self.current_ontology_versions.get(workspace): + logger.debug( + f"Already at version {version} for {workspace}, " + f"skipping" + ) return # Extract ontology configurations if "ontology" not in config: - logger.warning("No 'ontology' section in config") + logger.warning( + f"No 'ontology' section in config for {workspace}" + ) return ontology_configs = config["ontology"] @@ -235,38 +240,65 @@ class Processor(FlowProcessor): logger.error(f"Failed to parse ontology '{ont_id}': {e}") continue - logger.info(f"Loaded {len(ontologies)} ontology definitions") + logger.info( + f"Loaded {len(ontologies)} ontology definitions 
" + f"for {workspace}" + ) - # Determine what changed (for incremental updates) + # Determine what changed for this workspace + ws_loaded_ids = self.loaded_ontology_ids.get(workspace, set()) new_ids = set(ontologies.keys()) - added_ids = new_ids - self.loaded_ontology_ids - removed_ids = self.loaded_ontology_ids - new_ids - updated_ids = new_ids & self.loaded_ontology_ids # May have changed content + added_ids = new_ids - ws_loaded_ids + removed_ids = ws_loaded_ids - new_ids + updated_ids = new_ids & ws_loaded_ids # May have changed content if added_ids: - logger.info(f"New ontologies: {added_ids}") + logger.info(f"New ontologies in {workspace}: {added_ids}") if removed_ids: - logger.info(f"Removed ontologies: {removed_ids}") + logger.info(f"Removed ontologies in {workspace}: {removed_ids}") if updated_ids: - logger.info(f"Updated ontologies: {updated_ids}") + logger.info(f"Updated ontologies in {workspace}: {updated_ids}") - # Update ontology loader's internal state - self.ontology_loader.update_ontologies(ontologies) + # Get or create per-workspace loader + loader = self.ontology_loaders.get(workspace) + if loader is None: + loader = OntologyLoader() + self.ontology_loaders[workspace] = loader + loader.update_ontologies(ontologies) - # Clear all flow components to force re-embedding with new ontologies + # Clear flow components for this workspace to force + # re-embedding with new ontologies. 
if added_ids or removed_ids or updated_ids: - logger.info("Clearing flow components to trigger re-embedding") - self.flow_components.clear() + self._clear_workspace_flow_components(workspace) # Update tracking - self.current_ontology_version = version - self.loaded_ontology_ids = new_ids + self.current_ontology_versions[workspace] = version + self.loaded_ontology_ids[workspace] = new_ids - logger.info(f"Ontology config update complete, version={version}") + logger.info( + f"Ontology config update complete for {workspace}, " + f"version={version}" + ) except Exception as e: logger.error(f"Failed to process ontology config: {e}", exc_info=True) + def _clear_workspace_flow_components(self, workspace): + """Drop cached flow components belonging to the given workspace + so they're re-initialised on next message with fresh ontology + embeddings.""" + to_remove = [ + fid for fid, comp in self.flow_components.items() + if comp.get("workspace") == workspace + ] + if to_remove: + logger.info( + f"Clearing {len(to_remove)} flow components for " + f"workspace {workspace}" + ) + for fid in to_remove: + del self.flow_components[fid] + async def on_message(self, msg, consumer, flow): """Process incoming chunk message.""" v = msg.value() @@ -624,7 +656,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=metadata.id, root=metadata.root, - user=metadata.user, collection=metadata.collection, ), triples=triples, @@ -637,7 +668,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=metadata.id, root=metadata.root, - user=metadata.user, collection=metadata.collection, ), entities=entities, diff --git a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py index 8068a23d..ee3e2ed2 100755 --- a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py @@ -207,7 +207,6 @@ class Processor(FlowProcessor): Metadata( id=v.metadata.id, 
root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), batch diff --git a/trustgraph-flow/trustgraph/extract/kg/rows/processor.py b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py index 973bb3d7..f1dd4fe0 100644 --- a/trustgraph-flow/trustgraph/extract/kg/rows/processor.py +++ b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py @@ -84,32 +84,39 @@ class Processor(FlowProcessor): # Register config handler for schema updates self.register_config_handler(self.on_schema_config, types=["schema"]) - # Schema storage: name -> RowSchema - self.schemas: Dict[str, RowSchema] = {} + # Per-workspace schema storage: {workspace: {name: RowSchema}} + self.schemas: Dict[str, Dict[str, RowSchema]] = {} - async def on_schema_config(self, config, version): + async def on_schema_config(self, workspace, config, version): """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") + logger.info( + f"Loading schema configuration version {version} " + f"for workspace {workspace}" + ) - # Clear existing schemas - self.schemas = {} + # Replace existing schemas for this workspace + ws_schemas: Dict[str, RowSchema] = {} + self.schemas[workspace] = ws_schemas # Check if our config type exists if self.config_key not in config: - logger.warning(f"No '{self.config_key}' type in configuration") + logger.warning( + f"No '{self.config_key}' type in configuration " + f"for {workspace}" + ) return # Get the schemas dictionary for our type schemas_config = config[self.config_key] - + # Process each schema in the schemas config for schema_name, schema_json in schemas_config.items(): - + try: # Parse the JSON schema definition schema_def = json.loads(schema_json) - + # Create Field objects fields = [] for field_def in schema_def.get("fields", []): @@ -124,21 +131,27 @@ class Processor(FlowProcessor): indexed=field_def.get("indexed", False) ) fields.append(field) - + # Create RowSchema row_schema = RowSchema( 
name=schema_def.get("name", schema_name), description=schema_def.get("description", ""), fields=fields ) - - self.schemas[schema_name] = row_schema - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") - + + ws_schemas[schema_name] = row_schema + logger.info( + f"Loaded schema: {schema_name} with " + f"{len(fields)} fields for {workspace}" + ) + except Exception as e: logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + logger.info( + f"Schema configuration loaded for {workspace}: " + f"{len(ws_schemas)} schemas" + ) async def extract_objects_for_schema(self, text: str, schema_name: str, schema: RowSchema, flow) -> List[Dict[str, Any]]: """Extract objects from text for a specific schema""" @@ -234,18 +247,26 @@ class Processor(FlowProcessor): """Process incoming chunk and extract objects""" v = msg.value() - logger.info(f"Extracting objects from chunk {v.metadata.id}...") + workspace = flow.workspace + logger.info( + f"Extracting objects from chunk {v.metadata.id} " + f"(workspace {workspace})..." 
+ ) chunk_text = v.chunk.decode("utf-8") - # If no schemas configured, log warning and return - if not self.schemas: - logger.warning("No schemas configured - skipping extraction") + # If no schemas configured for this workspace, log and return + ws_schemas = self.schemas.get(workspace, {}) + if not ws_schemas: + logger.warning( + f"No schemas configured for workspace {workspace} " + f"- skipping extraction" + ) return try: # Extract objects for each configured schema - for schema_name, schema in self.schemas.items(): + for schema_name, schema in ws_schemas.items(): logger.debug(f"Extracting {schema_name} objects from chunk") @@ -274,7 +295,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=f"{v.metadata.id}:{schema_name}", root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), schema_name=schema_name, diff --git a/trustgraph-flow/trustgraph/flow/service/flow.py b/trustgraph-flow/trustgraph/flow/service/flow.py index a5e4a7e1..ed0158f6 100644 --- a/trustgraph-flow/trustgraph/flow/service/flow.py +++ b/trustgraph-flow/trustgraph/flow/service/flow.py @@ -17,14 +17,18 @@ class FlowConfig: self.config = config self.pubsub = pubsub - # Cache for parameter type definitions to avoid repeated lookups + # Per-workspace cache for parameter type definitions + # Keyed by (workspace, type-name) self.param_type_cache = {} - async def resolve_parameters(self, flow_blueprint, user_params): + async def resolve_parameters( + self, workspace, flow_blueprint, user_params + ): """ Resolve parameters by merging user-provided values with defaults. 
Args: + workspace: Workspace containing the parameter-type definitions flow_blueprint: The flow blueprint definition dict user_params: User-provided parameters dict (may be None or empty) @@ -55,24 +59,25 @@ class FlowConfig: # Look up the parameter type definition param_type = param_meta.get("type") if param_type: + cache_key = (workspace, param_type) # Check cache first - if param_type not in self.param_type_cache: + if cache_key not in self.param_type_cache: try: # Fetch parameter type definition from config store type_def = await self.config.get( - "parameter-type", param_type + workspace, "parameter-type", param_type ) if type_def: - self.param_type_cache[param_type] = json.loads(type_def) + self.param_type_cache[cache_key] = json.loads(type_def) else: logger.warning(f"Parameter type '{param_type}' not found in config") - self.param_type_cache[param_type] = {} + self.param_type_cache[cache_key] = {} except Exception as e: logger.error(f"Error fetching parameter type '{param_type}': {e}") - self.param_type_cache[param_type] = {} + self.param_type_cache[cache_key] = {} # Apply default from type definition (as string) - type_def = self.param_type_cache[param_type] + type_def = self.param_type_cache[cache_key] if "default" in type_def: default_value = type_def["default"] # Convert to string based on type @@ -94,8 +99,9 @@ class FlowConfig: else: # Controller has no value, try to get default from type definition param_type = param_meta.get("type") - if param_type and param_type in self.param_type_cache: - type_def = self.param_type_cache[param_type] + cache_key = (workspace, param_type) if param_type else None + if cache_key and cache_key in self.param_type_cache: + type_def = self.param_type_cache[cache_key] if "default" in type_def: default_value = type_def["default"] # Convert to string based on type @@ -114,7 +120,9 @@ class FlowConfig: async def handle_list_blueprints(self, msg): - names = list(await self.config.keys("flow-blueprint")) + names = list(await 
self.config.keys( + msg.workspace, "flow-blueprint" + )) return FlowResponse( error = None, @@ -126,14 +134,14 @@ class FlowConfig: return FlowResponse( error = None, blueprint_definition = await self.config.get( - "flow-blueprint", msg.blueprint_name + msg.workspace, "flow-blueprint", msg.blueprint_name ), ) async def handle_put_blueprint(self, msg): await self.config.put( - "flow-blueprint", + msg.workspace, "flow-blueprint", msg.blueprint_name, msg.blueprint_definition ) @@ -145,7 +153,9 @@ class FlowConfig: logger.debug(f"Flow config message: {msg}") - await self.config.delete("flow-blueprint", msg.blueprint_name) + await self.config.delete( + msg.workspace, "flow-blueprint", msg.blueprint_name + ) return FlowResponse( error = None, @@ -153,7 +163,7 @@ class FlowConfig: async def handle_list_flows(self, msg): - names = list(await self.config.keys("flow")) + names = list(await self.config.keys(msg.workspace, "flow")) return FlowResponse( error = None, @@ -162,7 +172,9 @@ class FlowConfig: async def handle_get_flow(self, msg): - flow_data = await self.config.get("flow", msg.flow_id) + flow_data = await self.config.get( + msg.workspace, "flow", msg.flow_id + ) flow = json.loads(flow_data) return FlowResponse( @@ -174,37 +186,49 @@ class FlowConfig: async def handle_start_flow(self, msg): + workspace = msg.workspace + if msg.blueprint_name is None: raise RuntimeError("No blueprint name") if msg.flow_id is None: raise RuntimeError("No flow ID") - if msg.flow_id in await self.config.keys("flow"): + if msg.flow_id in await self.config.keys(workspace, "flow"): raise RuntimeError("Flow already exists") if msg.description is None: raise RuntimeError("No description") - if msg.blueprint_name not in await self.config.keys("flow-blueprint"): + if msg.blueprint_name not in await self.config.keys( + workspace, "flow-blueprint" + ): raise RuntimeError("Blueprint does not exist") cls = json.loads( - await self.config.get("flow-blueprint", msg.blueprint_name) + await 
self.config.get( + workspace, "flow-blueprint", msg.blueprint_name + ) ) # Resolve parameters by merging user-provided values with defaults user_params = msg.parameters if msg.parameters else {} - parameters = await self.resolve_parameters(cls, user_params) + parameters = await self.resolve_parameters( + workspace, cls, user_params + ) # Log the resolved parameters for debugging logger.debug(f"User provided parameters: {user_params}") logger.debug(f"Resolved parameters (with defaults): {parameters}") - # Apply parameter substitution to template replacement function + # Apply parameter substitution to template replacement function. + # {workspace} is substituted from msg.workspace to isolate + # queue names across workspaces. def repl_template_with_params(tmp): result = tmp.replace( + "{workspace}", workspace + ).replace( "{blueprint}", msg.blueprint_name ).replace( "{id}", msg.flow_id @@ -253,7 +277,7 @@ class FlowConfig: json.dumps(entry), )) - await self.config.put_many(updates) + await self.config.put_many(workspace, updates) def repl_interface(i): return { @@ -270,7 +294,7 @@ class FlowConfig: interfaces = {} await self.config.put( - "flow", msg.flow_id, + workspace, "flow", msg.flow_id, json.dumps({ "description": msg.description, "blueprint-name": msg.blueprint_name, @@ -283,68 +307,77 @@ class FlowConfig: error = None, ) - async def ensure_existing_flow_topics(self): - """Ensure topics exist for all already-running flows. + async def ensure_existing_flow_topics(self, workspaces): + """Ensure topics exist for all already-running flows across + the given workspaces. Called on startup to handle flows that were started before this version of the flow service was deployed, or before a restart. 
""" - flow_ids = await self.config.keys("flow") + for workspace in workspaces: + flow_ids = await self.config.keys(workspace, "flow") - for flow_id in flow_ids: - try: - flow_data = await self.config.get("flow", flow_id) - if flow_data is None: - continue - - flow = json.loads(flow_data) - - blueprint_name = flow.get("blueprint-name") - if blueprint_name is None: - continue - - # Skip flows that are mid-shutdown - if flow.get("status") == "stopping": - continue - - parameters = flow.get("parameters", {}) - - blueprint_data = await self.config.get( - "flow-blueprint", blueprint_name - ) - if blueprint_data is None: - logger.warning( - f"Blueprint '{blueprint_name}' not found for " - f"flow '{flow_id}', skipping topic creation" + for flow_id in flow_ids: + try: + flow_data = await self.config.get( + workspace, "flow", flow_id ) - continue + if flow_data is None: + continue - cls = json.loads(blueprint_data) + flow = json.loads(flow_data) - def repl_template(tmp): - result = tmp.replace( - "{blueprint}", blueprint_name - ).replace( - "{id}", flow_id + blueprint_name = flow.get("blueprint-name") + if blueprint_name is None: + continue + + # Skip flows that are mid-shutdown + if flow.get("status") == "stopping": + continue + + parameters = flow.get("parameters", {}) + + blueprint_data = await self.config.get( + workspace, "flow-blueprint", blueprint_name ) - for param_name, param_value in parameters.items(): - result = result.replace( - f"{{{param_name}}}", str(param_value) + if blueprint_data is None: + logger.warning( + f"Blueprint '{blueprint_name}' not found " + f"for flow '{workspace}/{flow_id}', skipping " + f"topic creation" ) - return result + continue - topics = self._collect_flow_topics(cls, repl_template) - for topic in topics: - await self.pubsub.ensure_topic(topic) + cls = json.loads(blueprint_data) - logger.info( - f"Ensured topics for existing flow '{flow_id}'" - ) + def repl_template(tmp): + result = tmp.replace( + "{workspace}", workspace + ).replace( + 
"{blueprint}", blueprint_name + ).replace( + "{id}", flow_id + ) + for param_name, param_value in parameters.items(): + result = result.replace( + f"{{{param_name}}}", str(param_value) + ) + return result - except Exception as e: - logger.error( - f"Failed to ensure topics for flow '{flow_id}': {e}" - ) + topics = self._collect_flow_topics(cls, repl_template) + for topic in topics: + await self.pubsub.ensure_topic(topic) + + logger.info( + f"Ensured topics for existing flow " + f"'{workspace}/{flow_id}'" + ) + + except Exception as e: + logger.error( + f"Failed to ensure topics for flow " + f"'{workspace}/{flow_id}': {e}" + ) def _collect_flow_topics(self, cls, repl_template): """Collect unique topic identifiers from the blueprint. @@ -393,79 +426,95 @@ class FlowConfig: return topics - async def _live_owned_topic_closure(self, exclude_flow_id=None): - """Union of flow-owned topics referenced by all live flows. + async def _live_owned_topic_closure( + self, exclude_workspace=None, exclude_flow_id=None, + ): + """Union of flow-owned topics referenced by all live flows, + across every workspace. Walks every flow record currently registered in the config - service (except ``exclude_flow_id``, typically the flow being - torn down), resolves its blueprint + parameter templates, and - collects the set of flow-owned topics those templates produce. + service (except the single ``(exclude_workspace, exclude_flow_id)`` + pair — typically the flow being torn down), resolves its + blueprint + parameter templates, and collects the set of + flow-owned topics those templates produce. Used to drive closure-based topic cleanup on flow stop: a - topic may only be deleted if no remaining live flow would - still template to it. 
This handles all three scoping cases - transparently — ``{id}`` topics have no other references once - their flow is excluded; ``{blueprint}`` topics stay alive - while another flow of the same blueprint exists; ``{workspace}`` - (when introduced) stays alive while any flow in the workspace - exists. + topic may only be deleted if no remaining live flow (in any + workspace) would still template to it. This handles all + scoping cases transparently — ``{id}`` topics have no other + references once their flow is excluded; ``{blueprint}`` topics + stay alive while another flow of the same blueprint exists; + ``{workspace}`` topics stay alive while any flow in the same + workspace remains. """ live = set() - flow_ids = await self.config.keys("flow") + workspaces = await self.config.workspaces_for_type("flow") - for fid in flow_ids: + for ws in workspaces: - if fid == exclude_flow_id: - continue + flow_ids = await self.config.keys(ws, "flow") - try: - frec_raw = await self.config.get("flow", fid) - if frec_raw is None: + for fid in flow_ids: + + if ws == exclude_workspace and fid == exclude_flow_id: continue - frec = json.loads(frec_raw) - except Exception as e: - logger.warning( - f"Closure sweep: skipping flow {fid}: {e}" - ) - continue - # Flows mid-shutdown don't keep their topics alive. 
- if frec.get("status") == "stopping": - continue - - bp_name = frec.get("blueprint-name") - if bp_name is None: - continue - - try: - bp_raw = await self.config.get("flow-blueprint", bp_name) - if bp_raw is None: - continue - bp = json.loads(bp_raw) - except Exception as e: - logger.warning( - f"Closure sweep: skipping flow {fid} " - f"(blueprint {bp_name}): {e}" - ) - continue - - parameters = frec.get("parameters", {}) - - def repl(tmp, bp_name=bp_name, fid=fid, parameters=parameters): - result = tmp.replace( - "{blueprint}", bp_name - ).replace( - "{id}", fid - ) - for pname, pvalue in parameters.items(): - result = result.replace( - f"{{{pname}}}", str(pvalue) + try: + frec_raw = await self.config.get(ws, "flow", fid) + if frec_raw is None: + continue + frec = json.loads(frec_raw) + except Exception as e: + logger.warning( + f"Closure sweep: skipping flow {ws}/{fid}: {e}" ) - return result + continue - live.update(self._collect_owned_topics(bp, repl)) + # Flows mid-shutdown don't keep their topics alive. 
+ if frec.get("status") == "stopping": + continue + + bp_name = frec.get("blueprint-name") + if bp_name is None: + continue + + try: + bp_raw = await self.config.get( + ws, "flow-blueprint", bp_name + ) + if bp_raw is None: + continue + bp = json.loads(bp_raw) + except Exception as e: + logger.warning( + f"Closure sweep: skipping flow {ws}/{fid} " + f"(blueprint {bp_name}): {e}" + ) + continue + + parameters = frec.get("parameters", {}) + + def repl( + tmp, + ws=ws, bp_name=bp_name, fid=fid, + parameters=parameters, + ): + result = tmp.replace( + "{workspace}", ws + ).replace( + "{blueprint}", bp_name + ).replace( + "{id}", fid + ) + for pname, pvalue in parameters.items(): + result = result.replace( + f"{{{pname}}}", str(pvalue) + ) + return result + + live.update(self._collect_owned_topics(bp, repl)) return live @@ -501,13 +550,17 @@ class FlowConfig: async def handle_stop_flow(self, msg): + workspace = msg.workspace + if msg.flow_id is None: raise RuntimeError("No flow ID") - if msg.flow_id not in await self.config.keys("flow"): + if msg.flow_id not in await self.config.keys(workspace, "flow"): raise RuntimeError("Flow ID invalid") - flow = json.loads(await self.config.get("flow", msg.flow_id)) + flow = json.loads( + await self.config.get(workspace, "flow", msg.flow_id) + ) if "blueprint-name" not in flow: raise RuntimeError("Internal error: flow has no flow blueprint") @@ -516,11 +569,15 @@ class FlowConfig: parameters = flow.get("parameters", {}) cls = json.loads( - await self.config.get("flow-blueprint", blueprint_name) + await self.config.get( + workspace, "flow-blueprint", blueprint_name + ) ) def repl_template(tmp): result = tmp.replace( + "{workspace}", workspace + ).replace( "{blueprint}", blueprint_name ).replace( "{id}", msg.flow_id @@ -539,7 +596,7 @@ class FlowConfig: # The config push tells processors to shut down their consumers. 
flow["status"] = "stopping" await self.config.put( - "flow", msg.flow_id, json.dumps(flow) + workspace, "flow", msg.flow_id, json.dumps(flow) ) # Delete all processor config entries for this flow. @@ -552,7 +609,7 @@ class FlowConfig: deletes.append((f"processor:{processor}", variant)) - await self.config.delete_many(deletes) + await self.config.delete_many(workspace, deletes) # Phase 2: Closure-based sweep. Only delete topics that no # other live flow still references via its blueprint templates. @@ -560,6 +617,7 @@ class FlowConfig: # of the same blueprint is still running, and {workspace}-scoped # topics while any flow in that workspace remains. live_owned = await self._live_owned_topic_closure( + exclude_workspace=workspace, exclude_flow_id=msg.flow_id, ) @@ -571,13 +629,13 @@ class FlowConfig: kept = this_flow_owned - to_delete if kept: logger.info( - f"Flow {msg.flow_id}: keeping {len(kept)} topics " - f"still referenced by other live flows" + f"Flow {workspace}/{msg.flow_id}: keeping {len(kept)} " + f"topics still referenced by other live flows" ) # Phase 3: Remove the flow record. 
- if msg.flow_id in await self.config.keys("flow"): - await self.config.delete("flow", msg.flow_id) + if msg.flow_id in await self.config.keys(workspace, "flow"): + await self.config.delete(workspace, "flow", msg.flow_id) return FlowResponse( error = None, @@ -585,7 +643,18 @@ class FlowConfig: async def handle(self, msg): - logger.debug(f"Handling flow message: {msg.operation}") + logger.debug( + f"Handling flow message: {msg.operation} " + f"workspace={msg.workspace}" + ) + + if not msg.workspace: + return FlowResponse( + error=Error( + type="bad-request", + message="Workspace is required", + ), + ) if msg.operation == "list-blueprints": resp = await self.handle_list_blueprints(msg) diff --git a/trustgraph-flow/trustgraph/flow/service/service.py b/trustgraph-flow/trustgraph/flow/service/service.py index e1997452..74077ccb 100644 --- a/trustgraph-flow/trustgraph/flow/service/service.py +++ b/trustgraph-flow/trustgraph/flow/service/service.py @@ -103,7 +103,12 @@ class Processor(AsyncProcessor): await self.pubsub.ensure_topic(self.flow_request_topic) await self.config_client.start() - await self.flow.ensure_existing_flow_topics() + + # Discover workspaces with existing flow config and ensure + # their topics exist before we start accepting requests. 
+ workspaces = await self.config_client.workspaces_for_type("flow") + await self.flow.ensure_existing_flow_topics(workspaces) + await self.flow_request_consumer.start() async def on_flow_request(self, msg, consumer, flow): diff --git a/trustgraph-flow/trustgraph/gateway/config/receiver.py b/trustgraph-flow/trustgraph/gateway/config/receiver.py index c721a46a..5bc781a9 100755 --- a/trustgraph-flow/trustgraph/gateway/config/receiver.py +++ b/trustgraph-flow/trustgraph/gateway/config/receiver.py @@ -30,6 +30,7 @@ class ConfigReceiver: self.flow_handlers = [] + # Per-workspace flow tracking: {workspace: {flow_id: flow_def}} self.flows = {} self.config_version = 0 @@ -43,7 +44,7 @@ class ConfigReceiver: v = msg.value() notify_version = v.version - notify_types = set(v.types) + changes = v.changes # Skip if we already have this version or newer if notify_version <= self.config_version: @@ -53,20 +54,27 @@ class ConfigReceiver: ) return - # Gateway cares about flow config - if notify_types and "flow" not in notify_types: + # Gateway cares about flow config — check if any flow + # types changed in any workspace + flow_workspaces = changes.get("flow", []) + if changes and not flow_workspaces: logger.debug( f"Ignoring config notify v{notify_version}, " - f"no flow types in {notify_types}" + f"no flow changes" ) self.config_version = notify_version return logger.info( - f"Config notify v{notify_version}, fetching config..." + f"Config notify v{notify_version} " + f"types={list(changes.keys())}, fetching config..." ) - await self.fetch_and_apply() + # Refresh config for each affected workspace + for workspace in flow_workspaces: + await self.fetch_and_apply_workspace(workspace) + + self.config_version = notify_version except Exception as e: logger.error( @@ -98,20 +106,25 @@ class ConfigReceiver: response_metrics=config_resp_metrics, ) - async def fetch_and_apply(self, retry=False): - """Fetch full config and apply flow changes. 
+ async def fetch_and_apply_workspace(self, workspace, retry=False): + """Fetch config for a single workspace and apply flow changes. If retry=True, keeps retrying until successful.""" while True: try: - logger.info("Fetching config from config service...") + logger.info( + f"Fetching config for workspace {workspace}..." + ) client = self._create_config_client() try: await client.start() resp = await client.request( - ConfigRequest(operation="config"), + ConfigRequest( + operation="config", + workspace=workspace, + ), timeout=10, ) finally: @@ -137,18 +150,22 @@ class ConfigReceiver: flows = config.get("flow", {}) + ws_flows = self.flows.get(workspace, {}) + wanted = list(flows.keys()) - current = list(self.flows.keys()) + current = list(ws_flows.keys()) for k in wanted: if k not in current: - self.flows[k] = json.loads(flows[k]) - await self.start_flow(k, self.flows[k]) + ws_flows[k] = json.loads(flows[k]) + await self.start_flow(workspace, k, ws_flows[k]) for k in current: if k not in wanted: - await self.stop_flow(k, self.flows[k]) - del self.flows[k] + await self.stop_flow(workspace, k, ws_flows[k]) + del ws_flows[k] + + self.flows[workspace] = ws_flows return @@ -164,27 +181,91 @@ class ConfigReceiver: ) return - async def start_flow(self, id, flow): + async def fetch_all_workspaces(self, retry=False): + """Fetch config for all workspaces at startup. 
+ Discovers workspaces via the config service getvalues-all-ws + operation on the flow type.""" - logger.info(f"Starting flow: {id}") + while True: + + try: + logger.info("Discovering workspaces with flows...") + + client = self._create_config_client() + try: + await client.start() + + # Discover workspaces that have any flow config + resp = await client.request( + ConfigRequest( + operation="getvalues-all-ws", + type="flow", + ), + timeout=10, + ) + + if resp.error: + raise RuntimeError( + f"Config error: {resp.error.message}" + ) + + workspaces = { + v.workspace for v in resp.values if v.workspace + } + + # Always include the default workspace, even if + # empty, so that newly-created flows in it can be + # picked up by subsequent notifications. + workspaces.add("default") + + logger.info( + f"Found workspaces with flows: {workspaces}" + ) + + finally: + await client.stop() + + # Fetch and apply config for each workspace + for workspace in workspaces: + await self.fetch_and_apply_workspace( + workspace, retry=retry + ) + + return + + except Exception as e: + if retry: + logger.warning( + f"Workspace fetch failed: {e}, retrying in 2s..." 
+ ) + await asyncio.sleep(2) + continue + logger.error( + f"Workspace fetch exception: {e}", exc_info=True + ) + return + + async def start_flow(self, workspace, id, flow): + + logger.info(f"Starting flow: {workspace}/{id}") for handler in self.flow_handlers: try: - await handler.start_flow(id, flow) + await handler.start_flow(workspace, id, flow) except Exception as e: logger.error( f"Config processing exception: {e}", exc_info=True ) - async def stop_flow(self, id, flow): + async def stop_flow(self, workspace, id, flow): - logger.info(f"Stopping flow: {id}") + logger.info(f"Stopping flow: {workspace}/{id}") for handler in self.flow_handlers: try: - await handler.stop_flow(id, flow) + await handler.stop_flow(workspace, id, flow) except Exception as e: logger.error( f"Config processing exception: {e}", exc_info=True @@ -218,7 +299,7 @@ class ConfigReceiver: # Fetch current config (subscribe-then-fetch pattern) # Retry until config service is available - await self.fetch_and_apply(retry=True) + await self.fetch_all_workspaces(retry=True) logger.info( "Config loader initialised, waiting for notifys..." 
diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py index 3a37c4e3..6696afbe 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py @@ -16,7 +16,7 @@ class CoreExport: async def process(self, data, error, ok, request): id = request.query["id"] - user = request.query["user"] + workspace = request.query.get("workspace", "default") response = await ok() @@ -41,7 +41,6 @@ class CoreExport: { "m": { "i": data["metadata"]["id"], - "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, "e": [ @@ -65,7 +64,6 @@ class CoreExport: { "m": { "i": data["metadata"]["id"], - "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, "t": data["triples"], @@ -78,7 +76,7 @@ class CoreExport: await kr.process( { "operation": "get-kg-core", - "user": user, + "workspace": workspace, "id": id, }, responder diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py index 0ca07319..d03d4efd 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py @@ -17,7 +17,7 @@ class CoreImport: async def process(self, data, error, ok, request): id = request.query["id"] - user = request.query["user"] + workspace = request.query.get("workspace", "default") kr = KnowledgeRequestor( backend = self.backend, @@ -43,12 +43,11 @@ class CoreImport: msg = unpacked[1] msg = { "operation": "put-kg-core", - "user": user, + "workspace": workspace, "id": id, "triples": { "metadata": { "id": id, - "user": user, "collection": "default", # Not used? 
}, "triples": msg["t"], @@ -61,12 +60,11 @@ class CoreImport: msg = unpacked[1] msg = { "operation": "put-kg-core", - "user": user, + "workspace": workspace, "id": id, "graph-embeddings": { "metadata": { "id": id, - "user": user, "collection": "default", # Not used? }, "entities": [ diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py index e70bf6de..2992d99f 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py @@ -14,12 +14,12 @@ class DocumentStreamExport: async def process(self, data, error, ok, request): - user = request.query.get("user") + workspace = request.query.get("workspace", "default") document_id = request.query.get("document-id") chunk_size = int(request.query.get("chunk-size", 1024 * 1024)) - if not user or not document_id: - return await error("Missing required parameters: user, document-id") + if not document_id: + return await error("Missing required parameter: document-id") response = await ok() @@ -45,7 +45,7 @@ class DocumentStreamExport: await lr.process( { "operation": "stream-document", - "user": user, + "workspace": workspace, "document-id": document_id, "chunk-size": chunk_size, }, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py index de0fe52d..91e47aaf 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py @@ -48,7 +48,6 @@ class EntityContextsImport: elt = EntityContexts( metadata=Metadata( id=data["metadata"]["id"], - user=data["metadata"]["user"], collection=data["metadata"]["collection"], ), entities=[ diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py 
b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py index 7c7dc915..3e246335 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py @@ -48,7 +48,6 @@ class GraphEmbeddingsImport: elt = GraphEmbeddings( metadata=Metadata( id=data["metadata"]["id"], - user=data["metadata"]["user"], collection=data["metadata"]["collection"], ), entities=[ diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index 592120b1..f3db3290 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -116,18 +116,20 @@ class DispatcherManager: # Format: {"config": {"request": "...", "response": "..."}, ...} self.queue_overrides = queue_overrides or {} + # Flows keyed by (workspace, flow_id) self.flows = {} + # Dispatchers keyed by (workspace, flow_id, kind) self.dispatchers = {} self.dispatcher_lock = asyncio.Lock() - async def start_flow(self, id, flow): - logger.info(f"Starting flow {id}") - self.flows[id] = flow + async def start_flow(self, workspace, id, flow): + logger.info(f"Starting flow {workspace}/{id}") + self.flows[(workspace, id)] = flow return - async def stop_flow(self, id, flow): - logger.info(f"Stopping flow {id}") - del self.flows[id] + async def stop_flow(self, workspace, id, flow): + logger.info(f"Stopping flow {workspace}/{id}") + del self.flows[(workspace, id)] return def dispatch_global_service(self): @@ -203,18 +205,20 @@ class DispatcherManager: async def process_flow_import(self, ws, running, params): + workspace = params.get("workspace", "default") flow = params.get("flow") kind = params.get("kind") - if flow not in self.flows: - raise RuntimeError("Invalid flow") + flow_key = (workspace, flow) + if flow_key not in self.flows: + raise RuntimeError(f"Invalid flow {workspace}/{flow}") if kind not in 
import_dispatchers: raise RuntimeError("Invalid kind") - key = (flow, kind) + key = (workspace, flow, kind) - intf_defs = self.flows[flow]["interfaces"] + intf_defs = self.flows[flow_key]["interfaces"] # FIXME: The -store bit, does it make sense? if kind == "entity-contexts": @@ -242,18 +246,20 @@ class DispatcherManager: async def process_flow_export(self, ws, running, params): + workspace = params.get("workspace", "default") flow = params.get("flow") kind = params.get("kind") - if flow not in self.flows: - raise RuntimeError("Invalid flow") + flow_key = (workspace, flow) + if flow_key not in self.flows: + raise RuntimeError(f"Invalid flow {workspace}/{flow}") if kind not in export_dispatchers: raise RuntimeError("Invalid kind") - key = (flow, kind) + key = (workspace, flow, kind) - intf_defs = self.flows[flow]["interfaces"] + intf_defs = self.flows[flow_key]["interfaces"] # FIXME: The -store bit, does it make sense? if kind == "entity-contexts": @@ -286,22 +292,36 @@ class DispatcherManager: async def process_flow_service(self, data, responder, params): + # Workspace can come from URL or from request body, defaulting + # to "default". Having it in the URL allows gateway routing to + # be workspace-aware without touching the body. 
+ workspace = params.get("workspace") + if not workspace and isinstance(data, dict): + workspace = data.get("workspace") + if not workspace: + workspace = "default" + flow = params.get("flow") kind = params.get("kind") - return await self.invoke_flow_service(data, responder, flow, kind) + return await self.invoke_flow_service( + data, responder, workspace, flow, kind, + ) - async def invoke_flow_service(self, data, responder, flow, kind): + async def invoke_flow_service( + self, data, responder, workspace, flow, kind, + ): - if flow not in self.flows: - raise RuntimeError("Invalid flow") + flow_key = (workspace, flow) + if flow_key not in self.flows: + raise RuntimeError(f"Invalid flow {workspace}/{flow}") - key = (flow, kind) + key = (workspace, flow, kind) if key not in self.dispatchers: async with self.dispatcher_lock: if key not in self.dispatchers: - intf_defs = self.flows[flow]["interfaces"] + intf_defs = self.flows[flow_key]["interfaces"] if kind not in intf_defs: raise RuntimeError("This kind not supported by flow") @@ -314,8 +334,8 @@ class DispatcherManager: request_queue = qconfig["request"], response_queue = qconfig["response"], timeout = 120, - consumer = f"{self.prefix}-{flow}-{kind}-request", - subscriber = f"{self.prefix}-{flow}-{kind}-request", + consumer = f"{self.prefix}-{workspace}-{flow}-{kind}-request", + subscriber = f"{self.prefix}-{workspace}-{flow}-{kind}-request", ) elif kind in sender_dispatchers: dispatcher = sender_dispatchers[kind]( diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py index fabd5c44..3d610dca 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py @@ -47,7 +47,9 @@ class Mux: raise RuntimeError("Bad message") await self.q.put(( - data["id"], data.get("flow"), + data["id"], + data.get("workspace", "default"), + data.get("flow"), data["service"], data["request"] )) @@ -87,8 +89,10 @@ class 
Mux: # worker[0] still running, move on break - async def start_request_task(self, ws, id, flow, svc, request, workers): - + async def start_request_task( + self, ws, id, workspace, flow, svc, request, workers, + ): + # Wait for outstanding requests to go below MAX_OUTSTANDING_REQUESTS while len(workers) > MAX_OUTSTANDING_REQUESTS: @@ -106,19 +110,23 @@ class Mux: }) worker = asyncio.create_task( - self.request_task(id, request, responder, flow, svc) + self.request_task( + id, request, responder, workspace, flow, svc, + ) ) workers.append(worker) - async def request_task(self, id, request, responder, flow, svc): + async def request_task( + self, id, request, responder, workspace, flow, svc, + ): try: if flow: await self.dispatcher_manager.invoke_flow_service( - request, responder, flow, svc + request, responder, workspace, flow, svc, ) else: @@ -148,7 +156,7 @@ class Mux: # Get next request on queue item = await asyncio.wait_for(self.q.get(), 1) - id, flow, svc, request = item + id, workspace, flow, svc, request = item except TimeoutError: continue @@ -172,7 +180,7 @@ class Mux: try: await self.start_request_task( - self.ws, id, flow, svc, request, workers + self.ws, id, workspace, flow, svc, request, workers ) except Exception as e: diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py index ad634cab..8f92fa59 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py @@ -53,7 +53,6 @@ class RowsImport: elt = ExtractedObject( metadata=Metadata( id=data["metadata"]["id"], - user=data["metadata"]["user"], collection=data["metadata"]["collection"], ), schema_name=data["schema_name"], diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py b/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py index 7267e320..28b0ded5 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py +++ 
b/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py @@ -38,7 +38,6 @@ def serialize_triples(message): "metadata": { "id": message.metadata.id, "root": message.metadata.root, - "user": message.metadata.user, "collection": message.metadata.collection, }, "triples": serialize_subgraph(message.triples), @@ -50,7 +49,6 @@ def serialize_graph_embeddings(message): "metadata": { "id": message.metadata.id, "root": message.metadata.root, - "user": message.metadata.user, "collection": message.metadata.collection, }, "entities": [ @@ -68,7 +66,6 @@ def serialize_entity_contexts(message): "metadata": { "id": message.metadata.id, "root": message.metadata.root, - "user": message.metadata.user, "collection": message.metadata.collection, }, "entities": [ @@ -86,7 +83,6 @@ def serialize_document_embeddings(message): "metadata": { "id": message.metadata.id, "root": message.metadata.root, - "user": message.metadata.user, "collection": message.metadata.collection, }, "chunks": [ @@ -120,8 +116,8 @@ def serialize_document_metadata(message): if message.metadata: ret["metadata"] = serialize_subgraph(message.metadata) - if message.user: - ret["user"] = message.user + if message.workspace: + ret["workspace"] = message.workspace if message.tags is not None: ret["tags"] = message.tags @@ -144,8 +140,8 @@ def serialize_processing_metadata(message): if message.flow: ret["flow"] = message.flow - if message.user: - ret["user"] = message.user + if message.workspace: + ret["workspace"] = message.workspace if message.collection: ret["collection"] = message.collection @@ -164,7 +160,7 @@ def to_document_metadata(x): title = x.get("title", None), comments = x.get("comments", None), metadata = to_subgraph(x["metadata"]), - user = x.get("user", None), + workspace = x.get("workspace", None), tags = x.get("tags", None), ) @@ -175,7 +171,7 @@ def to_processing_metadata(x): document_id = x.get("document-id", None), time = x.get("time", None), flow = x.get("flow", None), - user = x.get("user", None), 
+ workspace = x.get("workspace", None), collection = x.get("collection", None), tags = x.get("tags", None), ) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py index 37f123fa..358faa8d 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py @@ -49,7 +49,6 @@ class TriplesImport: metadata=Metadata( id=data["metadata"]["id"], root=data["metadata"].get("root", ""), - user=data["metadata"]["user"], collection=data["metadata"]["collection"], ), triples=to_subgraph(data["triples"]), diff --git a/trustgraph-flow/trustgraph/librarian/collection_manager.py b/trustgraph-flow/trustgraph/librarian/collection_manager.py index 34ce1de8..09932adf 100644 --- a/trustgraph-flow/trustgraph/librarian/collection_manager.py +++ b/trustgraph-flow/trustgraph/librarian/collection_manager.py @@ -3,6 +3,7 @@ Collection management for the librarian - uses config service for storage """ import asyncio +import dataclasses import logging import json import uuid @@ -20,7 +21,6 @@ logger = logging.getLogger(__name__) def metadata_to_dict(metadata: CollectionMetadata) -> dict: """Convert CollectionMetadata to dictionary for JSON serialization""" return { - 'user': metadata.user, 'collection': metadata.collection, 'name': metadata.name, 'description': metadata.description, @@ -92,38 +92,38 @@ class CollectionManager: self.pending_config_requests[response_id + "_response"] = response self.pending_config_requests[response_id].set() - async def ensure_collection_exists(self, user: str, collection: str): + async def ensure_collection_exists(self, workspace: str, collection: str): """ Ensure a collection exists, creating it if necessary Args: - user: User ID + workspace: Workspace ID collection: Collection ID """ try: # Check if collection exists via config service request = ConfigRequest( operation='get', - 
keys=[ConfigKey(type='collection', key=f'{user}:{collection}')] + workspace=workspace, + keys=[ConfigKey(type='collection', key=collection)] ) response = await self.send_config_request(request) # Validate response if not response.values or len(response.values) == 0: - raise Exception(f"Invalid response from config service when checking collection {user}/{collection}") + raise Exception(f"Invalid response from config service when checking collection {workspace}/{collection}") # Check if collection exists (value not None means it exists) if response.values[0].value is not None: - logger.debug(f"Collection {user}/{collection} already exists") + logger.debug(f"Collection {workspace}/{collection} already exists") return # Collection doesn't exist (value is None), proceed to create # Create new collection with default metadata - logger.info(f"Auto-creating collection {user}/{collection}") + logger.info(f"Auto-creating collection {workspace}/{collection}") metadata = CollectionMetadata( - user=user, collection=collection, name=collection, # Default name to collection ID description="", @@ -132,9 +132,10 @@ class CollectionManager: request = ConfigRequest( operation='put', + workspace=workspace, values=[ConfigValue( type='collection', - key=f'{user}:{collection}', + key=collection, value=json.dumps(metadata_to_dict(metadata)) )] ) @@ -144,7 +145,7 @@ class CollectionManager: if response.error: raise RuntimeError(f"Config update failed: {response.error.message}") - logger.info(f"Collection {user}/{collection} auto-created in config service") + logger.info(f"Collection {workspace}/{collection} auto-created in config service") except Exception as e: logger.error(f"Error ensuring collection exists: {e}") @@ -161,9 +162,10 @@ class CollectionManager: CollectionManagementResponse with list of collections """ try: - # Get all collections from config service + # Get all collections in this workspace from config service config_request = ConfigRequest( operation='getvalues', + 
workspace=request.workspace, type='collection' ) @@ -172,15 +174,19 @@ class CollectionManager: if response.error: raise RuntimeError(f"Config query failed: {response.error.message}") - # Parse collections and filter by user + # Every value in this workspace is a collection. + # Filter to fields the current schema knows about — older + # persisted values may carry fields that have since been + # dropped (e.g. `user` from the pre-workspace-refactor era). + known_fields = {f.name for f in dataclasses.fields(CollectionMetadata)} collections = [] for config_value in response.values: - if ":" in config_value.key: - coll_user, coll_name = config_value.key.split(":", 1) - if coll_user == request.user: - metadata_dict = json.loads(config_value.value) - metadata = CollectionMetadata(**metadata_dict) - collections.append(metadata) + metadata_dict = json.loads(config_value.value) + metadata_dict = { + k: v for k, v in metadata_dict.items() if k in known_fields + } + metadata = CollectionMetadata(**metadata_dict) + collections.append(metadata) # Apply tag filtering if specified if request.tag_filter: @@ -221,7 +227,6 @@ class CollectionManager: tags = list(request.tags) if request.tags else [] metadata = CollectionMetadata( - user=request.user, collection=request.collection, name=name, description=description, @@ -231,9 +236,10 @@ class CollectionManager: # Send put request to config service config_request = ConfigRequest( operation='put', + workspace=request.workspace, values=[ConfigValue( type='collection', - key=f'{request.user}:{request.collection}', + key=request.collection, value=json.dumps(metadata_to_dict(metadata)) )] ) @@ -243,7 +249,7 @@ class CollectionManager: if response.error: raise RuntimeError(f"Config update failed: {response.error.message}") - logger.info(f"Collection {request.user}/{request.collection} updated in config service") + logger.info(f"Collection {request.workspace}/{request.collection} updated in config service") # Config service will trigger 
config push automatically # Storage services will receive update and create/update collections @@ -269,12 +275,13 @@ class CollectionManager: CollectionManagementResponse indicating success or failure """ try: - logger.info(f"Deleting collection {request.user}/{request.collection}") + logger.info(f"Deleting collection {request.workspace}/{request.collection}") # Send delete request to config service config_request = ConfigRequest( operation='delete', - keys=[ConfigKey(type='collection', key=f'{request.user}:{request.collection}')] + workspace=request.workspace, + keys=[ConfigKey(type='collection', key=request.collection)] ) response = await self.send_config_request(config_request) @@ -282,7 +289,7 @@ class CollectionManager: if response.error: raise RuntimeError(f"Config delete failed: {response.error.message}") - logger.info(f"Collection {request.user}/{request.collection} deleted from config service") + logger.info(f"Collection {request.workspace}/{request.collection} deleted from config service") # Config service will trigger config push automatically # Storage services will receive update and delete collections diff --git a/trustgraph-flow/trustgraph/librarian/librarian.py b/trustgraph-flow/trustgraph/librarian/librarian.py index 77232650..af1d69b1 100644 --- a/trustgraph-flow/trustgraph/librarian/librarian.py +++ b/trustgraph-flow/trustgraph/librarian/librarian.py @@ -48,7 +48,7 @@ class Librarian: raise RequestError("Document kind (MIME type) is required") if await self.table_store.document_exists( - request.document_metadata.user, + request.document_metadata.workspace, request.document_metadata.id ): raise RuntimeError("Document already exists") @@ -78,7 +78,7 @@ class Librarian: logger.debug("Removing document...") if not await self.table_store.document_exists( - request.user, + request.workspace, request.document_id, ): raise RuntimeError("Document does not exist") @@ -89,17 +89,17 @@ class Librarian: logger.debug(f"Cascade deleting child document 
{child.id}") try: child_object_id = await self.table_store.get_document_object_id( - child.user, + child.workspace, child.id ) await self.blob_store.remove(child_object_id) - await self.table_store.remove_document(child.user, child.id) + await self.table_store.remove_document(child.workspace, child.id) except Exception as e: logger.warning(f"Failed to delete child document {child.id}: {e}") # Now remove the parent document object_id = await self.table_store.get_document_object_id( - request.user, + request.workspace, request.document_id ) @@ -108,7 +108,7 @@ class Librarian: # Remove doc table row await self.table_store.remove_document( - request.user, + request.workspace, request.document_id ) @@ -120,10 +120,10 @@ class Librarian: logger.debug("Updating document...") - # You can't update the document ID, user or kind. + # You can't update the document ID, workspace or kind. if not await self.table_store.document_exists( - request.document_metadata.user, + request.document_metadata.workspace, request.document_metadata.id ): raise RuntimeError("Document does not exist") @@ -139,7 +139,7 @@ class Librarian: logger.debug("Getting document metadata...") doc = await self.table_store.get_document( - request.user, + request.workspace, request.document_id ) @@ -156,7 +156,7 @@ class Librarian: logger.debug("Getting document content...") object_id = await self.table_store.get_document_object_id( - request.user, + request.workspace, request.document_id ) @@ -180,18 +180,18 @@ class Librarian: raise RuntimeError("Collection parameter is required") if await self.table_store.processing_exists( - request.processing_metadata.user, + request.processing_metadata.workspace, request.processing_metadata.id ): raise RuntimeError("Processing already exists") doc = await self.table_store.get_document( - request.processing_metadata.user, + request.processing_metadata.workspace, request.processing_metadata.document_id ) object_id = await self.table_store.get_document_object_id( - 
request.processing_metadata.user, + request.processing_metadata.workspace, request.processing_metadata.document_id ) @@ -222,14 +222,14 @@ class Librarian: logger.debug("Removing processing metadata...") if not await self.table_store.processing_exists( - request.user, + request.workspace, request.processing_id, ): raise RuntimeError("Processing object does not exist") # Remove doc table row await self.table_store.remove_processing( - request.user, + request.workspace, request.processing_id ) @@ -239,7 +239,7 @@ class Librarian: async def list_documents(self, request): - docs = await self.table_store.list_documents(request.user) + docs = await self.table_store.list_documents(request.workspace) # Filter out child documents and answer documents by default include_children = getattr(request, 'include_children', False) @@ -256,7 +256,7 @@ class Librarian: async def list_processing(self, request): - procs = await self.table_store.list_processing(request.user) + procs = await self.table_store.list_processing(request.workspace) return LibrarianResponse( processing_metadatas = procs, @@ -276,7 +276,7 @@ class Librarian: raise RequestError("Document kind (MIME type) is required") if await self.table_store.document_exists( - request.document_metadata.user, + request.document_metadata.workspace, request.document_metadata.id ): raise RequestError("Document already exists") @@ -312,14 +312,14 @@ class Librarian: "kind": request.document_metadata.kind, "title": request.document_metadata.title, "comments": request.document_metadata.comments, - "user": request.document_metadata.user, + "workspace": request.document_metadata.workspace, "tags": request.document_metadata.tags, }) # Store session in Cassandra await self.table_store.create_upload_session( upload_id=upload_id, - user=request.document_metadata.user, + workspace=request.document_metadata.workspace, document_id=request.document_metadata.id, document_metadata=doc_meta_json, s3_upload_id=s3_upload_id, @@ -352,7 +352,7 @@ 
class Librarian: raise RequestError("Upload session not found or expired") # Validate ownership - if session["user"] != request.user: + if session["workspace"] != request.workspace: raise RequestError("Not authorized to upload to this session") # Validate chunk index @@ -419,7 +419,7 @@ class Librarian: raise RequestError("Upload session not found or expired") # Validate ownership - if session["user"] != request.user: + if session["workspace"] != request.workspace: raise RequestError("Not authorized to complete this upload") # Verify all chunks received @@ -457,7 +457,7 @@ class Librarian: kind=doc_meta_dict["kind"], title=doc_meta_dict.get("title", ""), comments=doc_meta_dict.get("comments", ""), - user=doc_meta_dict["user"], + workspace=doc_meta_dict["workspace"], tags=doc_meta_dict.get("tags", []), metadata=[], # Triples not supported in chunked upload yet ) @@ -488,7 +488,7 @@ class Librarian: raise RequestError("Upload session not found or expired") # Validate ownership - if session["user"] != request.user: + if session["workspace"] != request.workspace: raise RequestError("Not authorized to abort this upload") # Abort S3 multipart upload @@ -520,7 +520,7 @@ class Librarian: ) # Validate ownership - if session["user"] != request.user: + if session["workspace"] != request.workspace: raise RequestError("Not authorized to view this upload") chunks_received = session["chunks_received"] @@ -548,11 +548,11 @@ class Librarian: async def list_uploads(self, request): """ - List all in-progress uploads for a user. + List all in-progress uploads for a workspace. 
""" - logger.debug(f"Listing uploads for user {request.user}") + logger.debug(f"Listing uploads for workspace {request.workspace}") - sessions = await self.table_store.list_upload_sessions(request.user) + sessions = await self.table_store.list_upload_sessions(request.workspace) upload_sessions = [ UploadSession( @@ -591,7 +591,7 @@ class Librarian: # Verify parent exists if not await self.table_store.document_exists( - request.document_metadata.user, + request.document_metadata.workspace, request.document_metadata.parent_id ): raise RequestError( @@ -599,7 +599,7 @@ class Librarian: ) if await self.table_store.document_exists( - request.document_metadata.user, + request.document_metadata.workspace, request.document_metadata.id ): raise RequestError("Document already exists") @@ -665,7 +665,7 @@ class Librarian: ) object_id = await self.table_store.get_document_object_id( - request.user, + request.workspace, request.document_id ) diff --git a/trustgraph-flow/trustgraph/librarian/service.py b/trustgraph-flow/trustgraph/librarian/service.py index ed005298..c24a5fe8 100755 --- a/trustgraph-flow/trustgraph/librarian/service.py +++ b/trustgraph-flow/trustgraph/librarian/service.py @@ -277,18 +277,22 @@ class Processor(AsyncProcessor): """Forward config responses to collection manager""" await self.collection_manager.on_config_response(message, consumer, flow) - async def on_librarian_config(self, config, version): + async def on_librarian_config(self, workspace, config, version): - logger.info(f"Configuration version: {version}") + logger.info( + f"Configuration version: {version} workspace: {workspace}" + ) if "flow" in config: - self.flows = { + self.flows[workspace] = { k: json.loads(v) for k, v in config["flow"].items() } + else: + self.flows[workspace] = {} - logger.debug(f"Flows: {self.flows}") + logger.debug(f"Flows for {workspace}: {self.flows[workspace]}") def __del__(self): @@ -345,7 +349,6 @@ class Processor(AsyncProcessor): metadata=Metadata( id=doc_uri, 
root=document.id, - user=processing.user, collection=processing.collection, ), triples=all_triples, @@ -363,10 +366,15 @@ class Processor(AsyncProcessor): logger.debug(f"Document: {document}, processing: {processing}, content length: {len(content)}") - if processing.flow not in self.flows: - raise RuntimeError("Invalid flow ID") + workspace = processing.workspace + ws_flows = self.flows.get(workspace, {}) + if processing.flow not in ws_flows: + raise RuntimeError( + f"Invalid flow ID {processing.flow} for workspace " + f"{workspace}" + ) - flow = self.flows[processing.flow] + flow = ws_flows[processing.flow] if document.kind == "text/plain": kind = "text-load" @@ -386,7 +394,6 @@ class Processor(AsyncProcessor): metadata = Metadata( id = document.id, root = document.id, - user = processing.user, collection = processing.collection ), document_id = document.id, @@ -398,7 +405,6 @@ class Processor(AsyncProcessor): metadata = Metadata( id = document.id, root = document.id, - user = processing.user, collection = processing.collection ), document_id = document.id, @@ -429,9 +435,9 @@ class Processor(AsyncProcessor): """ # Ensure collection exists when processing is added if hasattr(request, 'processing_metadata') and request.processing_metadata: - user = request.processing_metadata.user + workspace = request.processing_metadata.workspace collection = request.processing_metadata.collection - await self.collection_manager.ensure_collection_exists(user, collection) + await self.collection_manager.ensure_collection_exists(workspace, collection) # Call the original add_processing method return await self.librarian.add_processing(request) diff --git a/trustgraph-flow/trustgraph/metering/counter.py b/trustgraph-flow/trustgraph/metering/counter.py index 46460b1f..a63b60ae 100644 --- a/trustgraph-flow/trustgraph/metering/counter.py +++ b/trustgraph-flow/trustgraph/metering/counter.py @@ -50,30 +50,37 @@ class Processor(FlowProcessor): ) ) + # Per-workspace price tables 
self.prices = {} self.config_key = "token-cost" - # Load token costs from the config service - async def on_cost_config(self, config, version): + async def on_cost_config(self, workspace, config, version): - logger.info(f"Loading metering configuration version {version}") + logger.info( + f"Loading metering configuration version {version} " + f"for workspace {workspace}" + ) if self.config_key not in config: - logger.warning(f"No key {self.config_key} in config") + logger.warning( + f"No key {self.config_key} in config for {workspace}" + ) + self.prices[workspace] = {} return - config = config[self.config_key] + prices = config[self.config_key] - self.prices = { + self.prices[workspace] = { k: json.loads(v) - for k, v in config.items() + for k, v in prices.items() } - def get_prices(self, modelname): + def get_prices(self, workspace, modelname): - if modelname in self.prices: - model = self.prices[modelname] + ws_prices = self.prices.get(workspace, {}) + if modelname in ws_prices: + model = ws_prices[modelname] return model["input_price"], model["output_price"] return None, None # Return None if model is not found @@ -81,6 +88,8 @@ class Processor(FlowProcessor): v = msg.value() + workspace = flow.workspace + modelname = v.model or "unknown" num_in = v.in_token or 0 num_out = v.out_token or 0 @@ -89,7 +98,9 @@ class Processor(FlowProcessor): __class__.token_metric.labels(model=modelname, direction="input").inc(num_in) __class__.token_metric.labels(model=modelname, direction="output").inc(num_out) - model_input_price, model_output_price = self.get_prices(modelname) + model_input_price, model_output_price = self.get_prices( + workspace, modelname + ) if model_input_price == None: cost_per_call = f"Model Not Found in Price list" diff --git a/trustgraph-flow/trustgraph/prompt/template/service.py b/trustgraph-flow/trustgraph/prompt/template/service.py index c599ce77..5da329d3 100755 --- a/trustgraph-flow/trustgraph/prompt/template/service.py +++ 
b/trustgraph-flow/trustgraph/prompt/template/service.py @@ -66,24 +66,37 @@ class Processor(FlowProcessor): self.register_config_handler(self.on_prompt_config, types=["prompt"]) - # Null configuration, should reload quickly - self.manager = PromptManager() + # Per-workspace prompt managers. Populated lazily as config + # arrives for each workspace. + self.managers = {} - async def on_prompt_config(self, config, version): + async def on_prompt_config(self, workspace, config, version): - logger.info(f"Loading prompt configuration version {version}") + logger.info( + f"Loading prompt configuration version {version} " + f"for workspace {workspace}" + ) if self.config_key not in config: - logger.warning(f"No key {self.config_key} in config") + logger.warning( + f"No key {self.config_key} in config for {workspace}" + ) return - config = config[self.config_key] + prompt_config = config[self.config_key] try: - self.manager.load_config(config) + manager = self.managers.get(workspace) + if manager is None: + manager = PromptManager() + self.managers[workspace] = manager - logger.info("Prompt configuration reloaded") + manager.load_config(prompt_config) + + logger.info( + f"Prompt configuration reloaded for {workspace}" + ) except Exception as e: @@ -103,6 +116,29 @@ class Processor(FlowProcessor): # Check if streaming is requested streaming = getattr(v, 'streaming', False) + # Look up the prompt manager for this workspace. If none is + # loaded yet, the request can't be handled. 
+ workspace = flow.workspace + manager = self.managers.get(workspace) + if manager is None: + logger.error( + f"No prompt configuration loaded for workspace {workspace}" + ) + r = PromptResponse( + error=Error( + type="no-configuration", + message=( + f"No prompt configuration for workspace " + f"{workspace}" + ), + ), + text=None, + object=None, + end_of_stream=True, + ) + await flow("response").send(r, properties={"id": id}) + return + try: logger.debug(f"Prompt terms: {v.terms}") @@ -149,7 +185,7 @@ class Processor(FlowProcessor): return "" try: - await self.manager.invoke(kind, input, llm_streaming) + await manager.invoke(kind, input, llm_streaming) except Exception as e: logger.error(f"Prompt streaming exception: {e}", exc_info=True) raise e @@ -177,7 +213,7 @@ class Processor(FlowProcessor): return None try: - resp = await self.manager.invoke(kind, input, llm) + resp = await manager.invoke(kind, input, llm) except Exception as e: logger.error(f"Prompt invocation exception: {e}", exc_info=True) raise e diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py index 98350961..0a1d8e0f 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py @@ -31,7 +31,7 @@ class Processor(DocumentEmbeddingsQueryService): self.vecstore = DocVectors(store_uri) - async def query_document_embeddings(self, msg): + async def query_document_embeddings(self, workspace, msg): try: @@ -45,7 +45,7 @@ class Processor(DocumentEmbeddingsQueryService): resp = self.vecstore.search( vec, - msg.user, + workspace, msg.collection, limit=msg.limit ) diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py index 406f979c..e1bc39fc 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py +++ 
b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py @@ -48,7 +48,7 @@ class Processor(DocumentEmbeddingsQueryService): } ) - async def query_document_embeddings(self, msg): + async def query_document_embeddings(self, workspace, msg): try: @@ -63,7 +63,7 @@ class Processor(DocumentEmbeddingsQueryService): dim = len(vec) # Use dimension suffix in index name - index_name = f"d-{msg.user}-{msg.collection}-{dim}" + index_name = f"d-{workspace}-{msg.collection}-{dim}" # Check if index exists - return empty if not if not self.pinecone.has_index(index_name): diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py index f056b1c1..1d59c835 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py @@ -65,7 +65,7 @@ class Processor(DocumentEmbeddingsQueryService): """Check if collection exists (no implicit creation)""" return self.qdrant.collection_exists(collection) - async def query_document_embeddings(self, msg): + async def query_document_embeddings(self, workspace, msg): try: @@ -75,7 +75,7 @@ class Processor(DocumentEmbeddingsQueryService): # Use dimension suffix in collection name dim = len(vec) - collection = f"d_{msg.user}_{msg.collection}_{dim}" + collection = f"d_{workspace}_{msg.collection}_{dim}" # Check if collection exists - return empty if not if not self.collection_exists(collection): diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py index 94eee387..1c5e8160 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py @@ -37,7 +37,7 @@ class Processor(GraphEmbeddingsQueryService): else: return Term(type=LITERAL, value=ent) - async def query_graph_embeddings(self, msg): + 
async def query_graph_embeddings(self, workspace, msg): try: @@ -51,7 +51,7 @@ class Processor(GraphEmbeddingsQueryService): resp = self.vecstore.search( vec, - msg.user, + workspace, msg.collection, limit=msg.limit * 2 ) diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py index ca443a6f..f612e3e8 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py @@ -55,7 +55,7 @@ class Processor(GraphEmbeddingsQueryService): else: return Term(type=LITERAL, value=ent) - async def query_graph_embeddings(self, msg): + async def query_graph_embeddings(self, workspace, msg): try: @@ -70,7 +70,7 @@ class Processor(GraphEmbeddingsQueryService): dim = len(vec) # Use dimension suffix in index name - index_name = f"t-{msg.user}-{msg.collection}-{dim}" + index_name = f"t-{workspace}-{msg.collection}-{dim}" # Check if index exists - return empty if not if not self.pinecone.has_index(index_name): diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py index df93ad8b..b8fb1361 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py @@ -71,7 +71,7 @@ class Processor(GraphEmbeddingsQueryService): else: return Term(type=LITERAL, value=ent) - async def query_graph_embeddings(self, msg): + async def query_graph_embeddings(self, workspace, msg): try: @@ -81,7 +81,7 @@ class Processor(GraphEmbeddingsQueryService): # Use dimension suffix in collection name dim = len(vec) - collection = f"t_{msg.user}_{msg.collection}_{dim}" + collection = f"t_{workspace}_{msg.collection}_{dim}" # Check if collection exists - return empty if not if not self.collection_exists(collection): diff --git 
a/trustgraph-flow/trustgraph/query/graphql/schema.py b/trustgraph-flow/trustgraph/query/graphql/schema.py
index 0c97b1d9..af136cf7 100644
--- a/trustgraph-flow/trustgraph/query/graphql/schema.py
+++ b/trustgraph-flow/trustgraph/query/graphql/schema.py
@@ -70,7 +70,7 @@ class GraphQLSchemaBuilder:
         Build the GraphQL schema with the provided query callback.
 
         The query callback will be invoked when resolving queries, with:
-        - user: str
+        - workspace: str
         - collection: str
         - schema_name: str
         - row_schema: RowSchema
@@ -228,7 +228,7 @@ class GraphQLSchemaBuilder:
             limit: Optional[int] = 100
         ) -> List[graphql_type]:
             # Get context values
-            user = info.context["user"]
+            workspace = info.context["workspace"]
             collection = info.context["collection"]
 
             # Parse the where clause
@@ -236,7 +236,7 @@ class GraphQLSchemaBuilder:
 
             # Call the query backend
             results = await query_callback(
-                user, collection, schema_name, row_schema,
+                workspace, collection, schema_name, row_schema,
                 filters, limit, order_by, direction
             )
 
diff --git a/trustgraph-flow/trustgraph/query/ontology/query_explanation.py b/trustgraph-flow/trustgraph/query/ontology/query_explanation.py
index bd72aedc..6cced915 100644
--- a/trustgraph-flow/trustgraph/query/ontology/query_explanation.py
+++ b/trustgraph-flow/trustgraph/query/ontology/query_explanation.py
@@ -167,7 +167,7 @@ class QueryExplainer:
             question_components, query_results, processing_metadata
         )
 
-        # Generate user-friendly explanation
+        # Generate user-friendly explanation
         user_friendly_explanation = self._generate_user_friendly_explanation(
             question, question_components, ontology_subsets, final_answer
         )
@@ -503,7 +503,7 @@ class QueryExplainer:
                                             question_components: QuestionComponents,
                                             ontology_subsets: List[QueryOntologySubset],
                                             final_answer: str) -> str:
-        """Generate user-friendly explanation of the process."""
+        """Generate user-friendly explanation of the process."""
         explanation_parts = []
 
         # Introduction
diff --git 
a/trustgraph-flow/trustgraph/query/ontology/query_service.py b/trustgraph-flow/trustgraph/query/ontology/query_service.py
index ec7884ed..c6057cc1 100644
--- a/trustgraph-flow/trustgraph/query/ontology/query_service.py
+++ b/trustgraph-flow/trustgraph/query/ontology/query_service.py
@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
 
 @dataclass
 class QueryRequest:
-    """Query request from user."""
+    """Query request from user."""
     question: str
     context: Optional[str] = None
    ontology_hint: Optional[str] = None
diff --git a/trustgraph-flow/trustgraph/query/ontology/question_analyzer.py b/trustgraph-flow/trustgraph/query/ontology/question_analyzer.py
index 3e48ac78..de39a89c 100644
--- a/trustgraph-flow/trustgraph/query/ontology/question_analyzer.py
+++ b/trustgraph-flow/trustgraph/query/ontology/question_analyzer.py
@@ -1,6 +1,6 @@
 """
 Question analyzer for ontology-sensitive query system.
-Decomposes user questions into semantic components.
+Decomposes user questions into semantic components.
 """
 
 import logging
diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py
index 7fc20303..dd89a8d8 100644
--- a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py
+++ b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py
@@ -1,7 +1,7 @@
 """
 Row embeddings query service for Qdrant.
 
-Input is query vectors plus user/collection/schema context.
+Input is query vectors plus workspace/collection/schema context.
 Output is matching row index information (index_name, index_value)
 for use in subsequent Cassandra lookups.
""" @@ -70,10 +70,10 @@ class Processor(FlowProcessor): safe_name = 'r_' + safe_name return safe_name.lower() - def find_collection(self, user: str, collection: str, schema_name: str) -> Optional[str]: - """Find the Qdrant collection for a given user/collection/schema""" + def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]: + """Find the Qdrant collection for a given workspace/collection/schema""" prefix = ( - f"rows_{self.sanitize_name(user)}_" + f"rows_{self.sanitize_name(workspace)}_" f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_" ) @@ -93,22 +93,22 @@ class Processor(FlowProcessor): return None - async def query_row_embeddings(self, request: RowEmbeddingsRequest): + async def query_row_embeddings(self, workspace, request: RowEmbeddingsRequest): """Execute row embeddings query""" vec = request.vector if not vec: return [] - # Find the collection for this user/collection/schema + # Find the collection for this workspace/collection/schema qdrant_collection = self.find_collection( - request.user, request.collection, request.schema_name + workspace, request.collection, request.schema_name ) if not qdrant_collection: logger.info( f"No Qdrant collection found for " - f"{request.user}/{request.collection}/{request.schema_name}" + f"{workspace}/{request.collection}/{request.schema_name}" ) return [] @@ -163,11 +163,11 @@ class Processor(FlowProcessor): logger.debug( f"Handling row embeddings query for " - f"{request.user}/{request.collection}/{request.schema_name}..." + f"{flow.workspace}/{request.collection}/{request.schema_name}..." 
) # Execute query - matches = await self.query_row_embeddings(request) + matches = await self.query_row_embeddings(flow.workspace, request) response = RowEmbeddingsResponse( error=None, diff --git a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py index 019d5610..cabdf617 100644 --- a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py @@ -87,12 +87,12 @@ class Processor(FlowProcessor): # Register config handler for schema updates self.register_config_handler(self.on_schema_config, types=["schema"]) - # Schema storage: name -> RowSchema - self.schemas: Dict[str, RowSchema] = {} + # Per-workspace schema storage: {workspace: {name: RowSchema}} + self.schemas: Dict[str, Dict[str, RowSchema]] = {} - # GraphQL schema builder and generated schema - self.schema_builder = GraphQLSchemaBuilder() - self.graphql_schema = None + # Per-workspace GraphQL schema builders and compiled schemas + self.schema_builders: Dict[str, GraphQLSchemaBuilder] = {} + self.graphql_schemas: Dict[str, Any] = {} # Cassandra session self.cluster = None @@ -133,17 +133,27 @@ class Processor(FlowProcessor): safe_name = 'r_' + safe_name return safe_name.lower() - async def on_schema_config(self, config, version): + async def on_schema_config(self, workspace, config, version): """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") + logger.info( + f"Loading schema configuration version {version} " + f"for workspace {workspace}" + ) - # Clear existing schemas - self.schemas = {} - self.schema_builder.clear() + # Replace existing schemas for this workspace + ws_schemas: Dict[str, RowSchema] = {} + self.schemas[workspace] = ws_schemas + + builder = GraphQLSchemaBuilder() + self.schema_builders[workspace] = builder # Check if our config type exists if self.config_key not in config: - logger.warning(f"No 
'{self.config_key}' type in configuration") + logger.warning( + f"No '{self.config_key}' type in configuration " + f"for {workspace}" + ) + self.graphql_schemas[workspace] = None return # Get the schemas dictionary for our type @@ -177,17 +187,23 @@ class Processor(FlowProcessor): fields=fields ) - self.schemas[schema_name] = row_schema - self.schema_builder.add_schema(schema_name, row_schema) - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") + ws_schemas[schema_name] = row_schema + builder.add_schema(schema_name, row_schema) + logger.info( + f"Loaded schema: {schema_name} with " + f"{len(fields)} fields for {workspace}" + ) except Exception as e: logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + logger.info( + f"Schema configuration loaded for {workspace}: " + f"{len(ws_schemas)} schemas" + ) - # Regenerate GraphQL schema - self.graphql_schema = self.schema_builder.build(self.query_cassandra) + # Regenerate GraphQL schema for this workspace + self.graphql_schemas[workspace] = builder.build(self.query_cassandra) def get_index_names(self, schema: RowSchema) -> List[str]: """Get all index names for a schema.""" @@ -222,7 +238,7 @@ class Processor(FlowProcessor): async def query_cassandra( self, - user: str, + workspace: str, collection: str, schema_name: str, row_schema: RowSchema, @@ -240,7 +256,7 @@ class Processor(FlowProcessor): # Connect if needed self.connect_cassandra() - safe_keyspace = self.sanitize_name(user) + safe_keyspace = self.sanitize_name(workspace) # Try to find an index that matches the filters index_match = self.find_matching_index(row_schema, filters) @@ -389,26 +405,30 @@ class Processor(FlowProcessor): async def execute_graphql_query( self, + workspace: str, query: str, variables: Dict[str, Any], operation_name: Optional[str], - user: str, collection: str ) -> Dict[str, Any]: - """Execute a GraphQL query""" + """Execute a 
GraphQL query against the workspace's schema""" - if not self.graphql_schema: - raise RuntimeError("No GraphQL schema available - no schemas loaded") + graphql_schema = self.graphql_schemas.get(workspace) + if not graphql_schema: + raise RuntimeError( + f"No GraphQL schema available for workspace {workspace} " + f"- no schemas loaded" + ) # Create context for the query context = { "processor": self, - "user": user, + "workspace": workspace, "collection": collection } # Execute the query - result = await self.graphql_schema.execute( + result = await graphql_schema.execute( query, variable_values=variables, operation_name=operation_name, @@ -454,10 +474,10 @@ class Processor(FlowProcessor): # Execute GraphQL query result = await self.execute_graphql_query( + workspace=flow.workspace, query=request.query, variables=dict(request.variables) if request.variables else {}, operation_name=request.operation_name, - user=request.user, collection=request.collection ) diff --git a/trustgraph-flow/trustgraph/query/sparql/algebra.py b/trustgraph-flow/trustgraph/query/sparql/algebra.py index eda83efb..bff9a336 100644 --- a/trustgraph-flow/trustgraph/query/sparql/algebra.py +++ b/trustgraph-flow/trustgraph/query/sparql/algebra.py @@ -30,14 +30,14 @@ class EvaluationError(Exception): pass -async def evaluate(node, triples_client, user, collection, limit=10000): +async def evaluate(node, triples_client, workspace, collection, limit=10000): """ Evaluate a SPARQL algebra node. 
Args: node: rdflib CompValue algebra node triples_client: TriplesClient instance for triple pattern queries - user: user/keyspace identifier + workspace: workspace/keyspace identifier collection: collection identifier limit: safety limit on results @@ -55,24 +55,24 @@ async def evaluate(node, triples_client, user, collection, limit=10000): logger.warning(f"Unsupported algebra node: {name}") return [{}] - return await handler(node, triples_client, user, collection, limit) + return await handler(node, triples_client, workspace, collection, limit) # --- Node handlers --- -async def _eval_select_query(node, tc, user, collection, limit): +async def _eval_select_query(node, tc, workspace, collection, limit): """Evaluate a SelectQuery node.""" - return await evaluate(node.p, tc, user, collection, limit) + return await evaluate(node.p, tc, workspace, collection, limit) -async def _eval_project(node, tc, user, collection, limit): +async def _eval_project(node, tc, workspace, collection, limit): """Evaluate a Project node (SELECT variable projection).""" - solutions = await evaluate(node.p, tc, user, collection, limit) + solutions = await evaluate(node.p, tc, workspace, collection, limit) variables = [str(v) for v in node.PV] return project(solutions, variables) -async def _eval_bgp(node, tc, user, collection, limit): +async def _eval_bgp(node, tc, workspace, collection, limit): """ Evaluate a Basic Graph Pattern. 
@@ -107,7 +107,7 @@ async def _eval_bgp(node, tc, user, collection, limit): # Query the triples store results = await _query_pattern( - tc, s_val, p_val, o_val, user, collection, limit + tc, s_val, p_val, o_val, workspace, collection, limit ) # Map results back to variable bindings, @@ -130,17 +130,17 @@ async def _eval_bgp(node, tc, user, collection, limit): return solutions[:limit] -async def _eval_join(node, tc, user, collection, limit): +async def _eval_join(node, tc, workspace, collection, limit): """Evaluate a Join node.""" - left = await evaluate(node.p1, tc, user, collection, limit) - right = await evaluate(node.p2, tc, user, collection, limit) + left = await evaluate(node.p1, tc, workspace, collection, limit) + right = await evaluate(node.p2, tc, workspace, collection, limit) return hash_join(left, right)[:limit] -async def _eval_left_join(node, tc, user, collection, limit): +async def _eval_left_join(node, tc, workspace, collection, limit): """Evaluate a LeftJoin node (OPTIONAL).""" - left_sols = await evaluate(node.p1, tc, user, collection, limit) - right_sols = await evaluate(node.p2, tc, user, collection, limit) + left_sols = await evaluate(node.p1, tc, workspace, collection, limit) + right_sols = await evaluate(node.p2, tc, workspace, collection, limit) filter_fn = None if hasattr(node, "expr") and node.expr is not None: @@ -153,16 +153,16 @@ async def _eval_left_join(node, tc, user, collection, limit): return left_join(left_sols, right_sols, filter_fn)[:limit] -async def _eval_union(node, tc, user, collection, limit): +async def _eval_union(node, tc, workspace, collection, limit): """Evaluate a Union node.""" - left = await evaluate(node.p1, tc, user, collection, limit) - right = await evaluate(node.p2, tc, user, collection, limit) + left = await evaluate(node.p1, tc, workspace, collection, limit) + right = await evaluate(node.p2, tc, workspace, collection, limit) return union(left, right)[:limit] -async def _eval_filter(node, tc, user, collection, 
limit): +async def _eval_filter(node, tc, workspace, collection, limit): """Evaluate a Filter node.""" - solutions = await evaluate(node.p, tc, user, collection, limit) + solutions = await evaluate(node.p, tc, workspace, collection, limit) expr = node.expr return [ sol for sol in solutions @@ -170,22 +170,22 @@ async def _eval_filter(node, tc, user, collection, limit): ] -async def _eval_distinct(node, tc, user, collection, limit): +async def _eval_distinct(node, tc, workspace, collection, limit): """Evaluate a Distinct node.""" - solutions = await evaluate(node.p, tc, user, collection, limit) + solutions = await evaluate(node.p, tc, workspace, collection, limit) return distinct(solutions) -async def _eval_reduced(node, tc, user, collection, limit): +async def _eval_reduced(node, tc, workspace, collection, limit): """Evaluate a Reduced node (like Distinct but implementation-defined).""" # Treat same as Distinct - solutions = await evaluate(node.p, tc, user, collection, limit) + solutions = await evaluate(node.p, tc, workspace, collection, limit) return distinct(solutions) -async def _eval_order_by(node, tc, user, collection, limit): +async def _eval_order_by(node, tc, workspace, collection, limit): """Evaluate an OrderBy node.""" - solutions = await evaluate(node.p, tc, user, collection, limit) + solutions = await evaluate(node.p, tc, workspace, collection, limit) key_fns = [] for cond in node.expr: @@ -206,7 +206,7 @@ async def _eval_order_by(node, tc, user, collection, limit): return order_by(solutions, key_fns) -async def _eval_slice(node, tc, user, collection, limit): +async def _eval_slice(node, tc, workspace, collection, limit): """Evaluate a Slice node (LIMIT/OFFSET).""" # Pass tighter limit downstream if possible inner_limit = limit @@ -214,13 +214,13 @@ async def _eval_slice(node, tc, user, collection, limit): offset = node.start or 0 inner_limit = min(limit, offset + node.length) - solutions = await evaluate(node.p, tc, user, collection, inner_limit) + 
solutions = await evaluate(node.p, tc, workspace, collection, inner_limit) return slice_solutions(solutions, node.start or 0, node.length) -async def _eval_extend(node, tc, user, collection, limit): +async def _eval_extend(node, tc, workspace, collection, limit): """Evaluate an Extend node (BIND).""" - solutions = await evaluate(node.p, tc, user, collection, limit) + solutions = await evaluate(node.p, tc, workspace, collection, limit) var_name = str(node.var) expr = node.expr @@ -246,9 +246,9 @@ async def _eval_extend(node, tc, user, collection, limit): return result -async def _eval_group(node, tc, user, collection, limit): +async def _eval_group(node, tc, workspace, collection, limit): """Evaluate a Group node (GROUP BY with aggregation).""" - solutions = await evaluate(node.p, tc, user, collection, limit) + solutions = await evaluate(node.p, tc, workspace, collection, limit) # Extract grouping expressions group_exprs = [] @@ -289,9 +289,9 @@ async def _eval_group(node, tc, user, collection, limit): return result -async def _eval_aggregate_join(node, tc, user, collection, limit): +async def _eval_aggregate_join(node, tc, workspace, collection, limit): """Evaluate an AggregateJoin (aggregation functions after GROUP BY).""" - solutions = await evaluate(node.p, tc, user, collection, limit) + solutions = await evaluate(node.p, tc, workspace, collection, limit) result = [] for sol in solutions: @@ -310,7 +310,7 @@ async def _eval_aggregate_join(node, tc, user, collection, limit): return result -async def _eval_graph(node, tc, user, collection, limit): +async def _eval_graph(node, tc, workspace, collection, limit): """Evaluate a Graph node (GRAPH clause).""" term = node.term @@ -319,16 +319,16 @@ async def _eval_graph(node, tc, user, collection, limit): # We'd need to pass graph to triples queries # For now, evaluate inner pattern normally logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired") - return await evaluate(node.p, tc, user, collection, limit) 
+ return await evaluate(node.p, tc, workspace, collection, limit) elif isinstance(term, Variable): # GRAPH ?g { ... } — variable graph logger.info(f"GRAPH ?{term} clause - variable graph not yet wired") - return await evaluate(node.p, tc, user, collection, limit) + return await evaluate(node.p, tc, workspace, collection, limit) else: - return await evaluate(node.p, tc, user, collection, limit) + return await evaluate(node.p, tc, workspace, collection, limit) -async def _eval_values(node, tc, user, collection, limit): +async def _eval_values(node, tc, workspace, collection, limit): """Evaluate a VALUES clause (inline data).""" variables = [str(v) for v in node.var] solutions = [] @@ -343,9 +343,9 @@ async def _eval_values(node, tc, user, collection, limit): return solutions -async def _eval_to_multiset(node, tc, user, collection, limit): +async def _eval_to_multiset(node, tc, workspace, collection, limit): """Evaluate a ToMultiSet node (subquery).""" - return await evaluate(node.p, tc, user, collection, limit) + return await evaluate(node.p, tc, workspace, collection, limit) # --- Aggregate computation --- @@ -487,7 +487,7 @@ def _resolve_term(tmpl, solution): return rdflib_term_to_term(tmpl) -async def _query_pattern(tc, s, p, o, user, collection, limit): +async def _query_pattern(tc, s, p, o, workspace, collection, limit): """ Issue a streaming triple pattern query via TriplesClient. 
@@ -496,7 +496,7 @@ async def _query_pattern(tc, s, p, o, user, collection, limit): results = await tc.query( s=s, p=p, o=o, limit=limit, - user=user, + workspace=workspace, collection=collection, ) return results diff --git a/trustgraph-flow/trustgraph/query/sparql/service.py b/trustgraph-flow/trustgraph/query/sparql/service.py index 38488032..983cd4f6 100644 --- a/trustgraph-flow/trustgraph/query/sparql/service.py +++ b/trustgraph-flow/trustgraph/query/sparql/service.py @@ -141,7 +141,7 @@ class Processor(FlowProcessor): solutions = await evaluate( parsed.algebra, triples_client, - user=request.user or "trustgraph", + workspace=flow.workspace, collection=request.collection or "default", limit=request.limit or 10000, ) diff --git a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py index 905aaaf2..efce5968 100755 --- a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py @@ -178,34 +178,34 @@ class Processor(TriplesQueryService): self.cassandra_password = password self.table = None - def ensure_connection(self, user): + def ensure_connection(self, workspace): """Ensure we have a connection to the correct keyspace.""" - if user != self.table: + if workspace != self.table: KGClass = EntityCentricKnowledgeGraph if self.cassandra_username and self.cassandra_password: self.tg = KGClass( hosts=self.cassandra_host, - keyspace=user, + keyspace=workspace, username=self.cassandra_username, password=self.cassandra_password ) else: self.tg = KGClass( hosts=self.cassandra_host, - keyspace=user, + keyspace=workspace, ) - self.table = user + self.table = workspace - async def query_triples(self, query): + async def query_triples(self, workspace, query): try: # ensure_connection may construct a fresh # EntityCentricKnowledgeGraph which does sync schema # setup against Cassandra. 
Push it to a worker thread - # so the event loop doesn't block on first-use per user. - await asyncio.to_thread(self.ensure_connection, query.user) + # so the event loop doesn't block on first-use per workspace. + await asyncio.to_thread(self.ensure_connection, workspace) # Extract values from query s_val = get_term_value(query.s) @@ -359,13 +359,13 @@ class Processor(TriplesQueryService): logger.error(f"Exception querying triples: {e}", exc_info=True) raise e - async def query_triples_stream(self, query): + async def query_triples_stream(self, workspace, query): """ Streaming query - yields (batch, is_final) tuples. Uses Cassandra's paging to fetch results incrementally. """ try: - await asyncio.to_thread(self.ensure_connection, query.user) + await asyncio.to_thread(self.ensure_connection, workspace) batch_size = query.batch_size if query.batch_size > 0 else 20 limit = query.limit if query.limit > 0 else 10000 @@ -395,7 +395,7 @@ class Processor(TriplesQueryService): else: # For specific patterns, fall back to non-streaming # (these typically return small result sets anyway) - async for batch, is_final in self._fallback_stream(query, batch_size): + async for batch, is_final in self._fallback_stream(workspace, query, batch_size): yield batch, is_final return @@ -452,9 +452,9 @@ class Processor(TriplesQueryService): logger.error(f"Exception in streaming query: {e}", exc_info=True) raise e - async def _fallback_stream(self, query, batch_size): + async def _fallback_stream(self, workspace, query, batch_size): """Fallback to non-streaming query with post-hoc batching.""" - triples = await self.query_triples(query) + triples = await self.query_triples(workspace, query) for i in range(0, len(triples), batch_size): batch = triples[i:i + batch_size] diff --git a/trustgraph-flow/trustgraph/query/triples/falkordb/service.py b/trustgraph-flow/trustgraph/query/triples/falkordb/service.py index 14b24d52..9781aaaf 100755 --- 
a/trustgraph-flow/trustgraph/query/triples/falkordb/service.py +++ b/trustgraph-flow/trustgraph/query/triples/falkordb/service.py @@ -58,7 +58,7 @@ class Processor(TriplesQueryService): else: return Term(type=LITERAL, value=ent) - async def query_triples(self, query): + async def query_triples(self, workspace, query): try: diff --git a/trustgraph-flow/trustgraph/query/triples/memgraph/service.py b/trustgraph-flow/trustgraph/query/triples/memgraph/service.py index 37633f34..173f07dd 100755 --- a/trustgraph-flow/trustgraph/query/triples/memgraph/service.py +++ b/trustgraph-flow/trustgraph/query/triples/memgraph/service.py @@ -63,12 +63,11 @@ class Processor(TriplesQueryService): else: return Term(type=LITERAL, value=ent) - async def query_triples(self, query): + async def query_triples(self, workspace, query): try: - # Extract user and collection, use defaults if not provided - user = query.user if query.user else "default" + # workspace comes from the trusted flow layer, not the request collection = query.collection if query.collection else "default" triples = [] @@ -80,13 +79,13 @@ # SPO records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN $src as src " "LIMIT " + str(query.limit), src=get_term_value(query.s), rel=get_term_value(query.p), value=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -94,13 +93,13 @@ class Processor(TriplesQueryService): triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o)))
records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Node {uri: $uri, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) " "RETURN $src as src " "LIMIT " + str(query.limit), src=get_term_value(query.s), rel=get_term_value(query.p), uri=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -112,13 +111,13 @@ class Processor(TriplesQueryService): # SP records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN dest.value as dest " "LIMIT " + str(query.limit), src=get_term_value(query.s), rel=get_term_value(query.p), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -127,13 +126,13 @@ class Processor(TriplesQueryService): triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, 
collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN dest.uri as dest " "LIMIT " + str(query.limit), src=get_term_value(query.s), rel=get_term_value(query.p), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -148,13 +147,13 @@ class Processor(TriplesQueryService): # SO records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel " "LIMIT " + str(query.limit), src=get_term_value(query.s), value=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -163,13 +162,13 @@ class Processor(TriplesQueryService): triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o))) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {uri: $uri, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel " "LIMIT " + str(query.limit), src=get_term_value(query.s), uri=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -182,13 +181,13 @@ class Processor(TriplesQueryService): # S records, summary, keys = 
self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel, dest.value as dest " "LIMIT " + str(query.limit), src=get_term_value(query.s), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -197,13 +196,13 @@ class Processor(TriplesQueryService): triples.append((get_term_value(query.s), data["rel"], data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel, dest.uri as dest " "LIMIT " + str(query.limit), src=get_term_value(query.s), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -221,13 +220,13 @@ class Processor(TriplesQueryService): # PO records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src " "LIMIT " + 
str(query.limit), uri=get_term_value(query.p), value=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -236,13 +235,13 @@ class Processor(TriplesQueryService): triples.append((data["src"], get_term_value(query.p), get_term_value(query.o))) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Node {uri: $dest, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src " "LIMIT " + str(query.limit), uri=get_term_value(query.p), dest=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -255,13 +254,13 @@ class Processor(TriplesQueryService): # P records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, dest.value as dest " "LIMIT " + str(query.limit), uri=get_term_value(query.p), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -270,13 +269,13 @@ class Processor(TriplesQueryService): triples.append((data["src"], get_term_value(query.p), data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, 
collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, dest.uri as dest " "LIMIT " + str(query.limit), uri=get_term_value(query.p), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -291,13 +290,13 @@ class Processor(TriplesQueryService): # O records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel " "LIMIT " + str(query.limit), value=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -306,13 +305,13 @@ class Processor(TriplesQueryService): triples.append((data["src"], data["rel"], get_term_value(query.o))) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {uri: $uri, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel " "LIMIT " + str(query.limit), uri=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, 
) @@ -325,12 +324,12 @@ class Processor(TriplesQueryService): # * records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.value as dest " "LIMIT " + str(query.limit), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -339,12 +338,12 @@ class Processor(TriplesQueryService): triples.append((data["src"], data["rel"], data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.uri as dest " "LIMIT " + str(query.limit), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) diff --git a/trustgraph-flow/trustgraph/query/triples/neo4j/service.py b/trustgraph-flow/trustgraph/query/triples/neo4j/service.py index 4cb1ab21..b47d49a9 100755 --- a/trustgraph-flow/trustgraph/query/triples/neo4j/service.py +++ b/trustgraph-flow/trustgraph/query/triples/neo4j/service.py @@ -63,14 +63,12 @@ class Processor(TriplesQueryService): else: return Term(type=LITERAL, value=ent) - async def query_triples(self, query): + async def query_triples(self, workspace, query): try: - # Extract user and collection, use defaults if not provided - user = query.user if 
query.user else "default" collection = query.collection if query.collection else "default" - + triples = [] if query.s is not None: @@ -80,13 +78,13 @@ class Processor(TriplesQueryService): # SPO records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN $src as src " "LIMIT " + str(query.limit), src=get_term_value(query.s), rel=get_term_value(query.p), value=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -94,13 +92,13 @@ class Processor(TriplesQueryService): triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o))) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Node {uri: $uri, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) " "RETURN $src as src " "LIMIT " + str(query.limit), src=get_term_value(query.s), rel=get_term_value(query.p), uri=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -112,13 +110,13 @@ class Processor(TriplesQueryService): # SP records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: 
$collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN dest.value as dest " "LIMIT " + str(query.limit), src=get_term_value(query.s), rel=get_term_value(query.p), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -127,13 +125,13 @@ class Processor(TriplesQueryService): triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN dest.uri as dest " "LIMIT " + str(query.limit), src=get_term_value(query.s), rel=get_term_value(query.p), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -148,13 +146,13 @@ class Processor(TriplesQueryService): # SO records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel " "LIMIT " + 
str(query.limit), src=get_term_value(query.s), value=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -163,13 +161,13 @@ class Processor(TriplesQueryService): triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o))) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {uri: $uri, user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel " "LIMIT " + str(query.limit), src=get_term_value(query.s), uri=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -182,13 +180,13 @@ class Processor(TriplesQueryService): # S records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel, dest.value as dest " "LIMIT " + str(query.limit), src=get_term_value(query.s), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -197,13 +195,13 @@ class Processor(TriplesQueryService): triples.append((get_term_value(query.s), data["rel"], data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" - "[rel:Rel {user: $user, 
collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN rel.uri as rel, dest.uri as dest " "LIMIT " + str(query.limit), src=get_term_value(query.s), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -221,13 +219,13 @@ class Processor(TriplesQueryService): # PO records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src " "LIMIT " + str(query.limit), uri=get_term_value(query.p), value=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -236,13 +234,13 @@ class Processor(TriplesQueryService): triples.append((data["src"], get_term_value(query.p), get_term_value(query.o))) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Node {uri: $dest, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src " "LIMIT " + str(query.limit), uri=get_term_value(query.p), dest=get_term_value(query.o), - user=user, 
collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -255,13 +253,13 @@ class Processor(TriplesQueryService): # P records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, dest.value as dest " "LIMIT " + str(query.limit), uri=get_term_value(query.p), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -270,13 +268,13 @@ class Processor(TriplesQueryService): triples.append((data["src"], get_term_value(query.p), data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, dest.uri as dest " "LIMIT " + str(query.limit), uri=get_term_value(query.p), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -291,13 +289,13 @@ class Processor(TriplesQueryService): # O records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {value: $value, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, 
collection: $collection}]->" + "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel " "LIMIT " + str(query.limit), value=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -306,13 +304,13 @@ class Processor(TriplesQueryService): triples.append((data["src"], data["rel"], get_term_value(query.o))) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {uri: $uri, user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel " "LIMIT " + str(query.limit), uri=get_term_value(query.o), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -325,12 +323,12 @@ class Processor(TriplesQueryService): # * records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: $collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Literal {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Literal {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.value as dest " "LIMIT " + str(query.limit), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -339,12 +337,12 @@ class Processor(TriplesQueryService): triples.append((data["src"], data["rel"], data["dest"])) records, summary, keys = self.io.execute_query( - "MATCH (src:Node {user: $user, collection: 
$collection})-" - "[rel:Rel {user: $user, collection: $collection}]->" - "(dest:Node {user: $user, collection: $collection}) " + "MATCH (src:Node {workspace: $workspace, collection: $collection})-" + "[rel:Rel {workspace: $workspace, collection: $collection}]->" + "(dest:Node {workspace: $workspace, collection: $collection}) " "RETURN src.uri as src, rel.uri as rel, dest.uri as dest " "LIMIT " + str(query.limit), - user=user, collection=collection, + workspace=workspace, collection=collection, database_=self.db, ) @@ -367,7 +365,7 @@ class Processor(TriplesQueryService): logger.error(f"Exception querying triples: {e}", exc_info=True) raise e - + @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py index dfe4e051..1864e1ad 100644 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py @@ -26,11 +26,11 @@ LABEL="http://www.w3.org/2000/01/rdf-schema#label" class Query: def __init__( - self, rag, user, collection, verbose, + self, rag, workspace, collection, verbose, doc_limit=20, track_usage=None, ): self.rag = rag - self.user = user + self.workspace = workspace self.collection = collection self.verbose = verbose self.doc_limit = doc_limit @@ -97,7 +97,7 @@ class Query: async def query_concept(vec): return await self.rag.doc_embeddings_client.query( vector=vec, limit=per_concept_limit, - user=self.user, collection=self.collection, + collection=self.collection, ) results = await asyncio.gather( @@ -122,7 +122,7 @@ class Query: for match in chunk_matches: if match.chunk_id: try: - content = await self.rag.fetch_chunk(match.chunk_id, self.user) + content = await self.rag.fetch_chunk(match.chunk_id, self.workspace) docs.append(content) chunk_ids.append(match.chunk_id) except Exception as e: @@ -154,7 +154,7 @@ class DocumentRag: logger.debug("DocumentRag 
initialized") async def query( - self, query, user="trustgraph", collection="default", + self, query, workspace="default", collection="default", doc_limit=20, streaming=False, chunk_callback=None, explain_callback=None, save_answer_callback=None, ): @@ -163,7 +163,7 @@ class DocumentRag: Args: query: The query string - user: User identifier + workspace: Workspace for isolation (also scopes chunk lookup) collection: Collection identifier doc_limit: Max chunks to retrieve streaming: Enable streaming LLM response @@ -210,7 +210,8 @@ class DocumentRag: await explain_callback(q_triples, q_uri) q = Query( - rag=self, user=user, collection=collection, verbose=self.verbose, + rag=self, workspace=workspace, collection=collection, + verbose=self.verbose, doc_limit=doc_limit, track_usage=track_usage, ) diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py index dc7296ad..30333c0e 100755 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py @@ -96,19 +96,19 @@ class Processor(FlowProcessor): await super(Processor, self).start() await self.librarian.start() - async def fetch_chunk_content(self, chunk_id, user, timeout=120): + async def fetch_chunk_content(self, chunk_id, workspace, timeout=120): """Fetch chunk content from librarian. 
Chunks are small so single request-response is fine.""" return await self.librarian.fetch_document_text( - document_id=chunk_id, user=user, timeout=timeout, + document_id=chunk_id, workspace=workspace, timeout=timeout, ) - async def save_answer_content(self, doc_id, user, content, title=None, timeout=120): + async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120): """Save answer content to the librarian.""" doc_metadata = DocumentMetadata( id=doc_id, - user=user, + workspace=workspace, kind="text/plain", title=title or "DocumentRAG Answer", document_type="answer", @@ -119,7 +119,7 @@ class Processor(FlowProcessor): document_id=doc_id, document_metadata=doc_metadata, content=base64.b64encode(content.encode("utf-8")).decode("utf-8"), - user=user, + workspace=workspace, ) await self.librarian.request(request, timeout=timeout) @@ -150,14 +150,13 @@ class Processor(FlowProcessor): doc_limit = self.doc_limit # Real-time explainability callback - emits triples and IDs as they're generated - # Triples are stored in the user's collection with a named graph (urn:graph:retrieval) + # Triples are stored in the request's collection with a named graph (urn:graph:retrieval) async def send_explainability(triples, explain_id): # Send triples to explainability queue - stores in same collection with named graph await flow("explainability").send(Triples( metadata=Metadata( id=explain_id, - user=v.user, - collection=v.collection, # Store in user's collection + collection=v.collection, ), triples=triples, )) @@ -178,7 +177,7 @@ class Processor(FlowProcessor): async def save_answer(doc_id, answer_text): await self.save_answer_content( doc_id=doc_id, - user=v.user, + workspace=flow.workspace, content=answer_text, title=f"DocumentRAG Answer: {v.query[:50]}...", ) @@ -202,7 +201,7 @@ class Processor(FlowProcessor): # All chunks (including final one with end_of_stream=True) are sent via callback response, usage = await self.rag.query( v.query, - user=v.user, 
+ workspace=flow.workspace, collection=v.collection, doc_limit=doc_limit, streaming=True, @@ -227,7 +226,7 @@ class Processor(FlowProcessor): # Non-streaming path - single response with answer and token usage response, usage = await self.rag.query( v.query, - user=v.user, + workspace=flow.workspace, collection=v.collection, doc_limit=doc_limit, explain_callback=send_explainability, diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index a4b14644..81dc8fe2 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -75,12 +75,11 @@ def edge_id(s, p, o): class LRUCacheWithTTL: - """LRU cache with TTL for label caching + """LRU cache with TTL for label caching. - CRITICAL SECURITY WARNING: - This cache is shared within a GraphRag instance but GraphRag instances - are created per-request. Cache keys MUST include user:collection prefix - to ensure data isolation between different security contexts. + GraphRag instances are created per-request, so this cache is + request-scoped. Cache keys include the collection prefix to keep + entries from different collections distinct within one request. 
""" def __init__(self, max_size=5000, ttl=300): @@ -119,12 +118,11 @@ class LRUCacheWithTTL: class Query: def __init__( - self, rag, user, collection, verbose, + self, rag, collection, verbose, entity_limit=50, triple_limit=30, max_subgraph_size=1000, max_path_length=2, track_usage=None, ): self.rag = rag - self.user = user self.collection = collection self.verbose = verbose self.entity_limit = entity_limit @@ -194,7 +192,7 @@ class Query: entity_tasks = [ self.rag.graph_embeddings_client.query( vector=v, limit=per_concept_limit, - user=self.user, collection=self.collection, + collection=self.collection, ) for v in vectors ] @@ -222,18 +220,18 @@ class Query: async def maybe_label(self, e): - # CRITICAL SECURITY: Cache key MUST include user and collection - # to prevent data leakage between different contexts - cache_key = f"{self.user}:{self.collection}:{e}" + # The label cache lives on a per-request GraphRag instance — no + # cross-request isolation concern. The collection prefix keeps + # entries from different collections distinct within one request. 
+ cache_key = f"{self.collection}:{e}" - # Check LRU cache first with isolated key cached_label = self.rag.label_cache.get(cache_key) if cached_label is not None: return cached_label res = await self.rag.triples_client.query( s=e, p=LABEL, o=None, limit=1, - user=self.user, collection=self.collection, + collection=self.collection, g="", ) @@ -255,19 +253,19 @@ class Query: self.rag.triples_client.query_stream( s=entity, p=None, o=None, limit=limit_per_entity, - user=self.user, collection=self.collection, + collection=self.collection, batch_size=20, g="", ), self.rag.triples_client.query_stream( s=None, p=entity, o=None, limit=limit_per_entity, - user=self.user, collection=self.collection, + collection=self.collection, batch_size=20, g="", ), self.rag.triples_client.query_stream( s=None, p=None, o=entity, limit=limit_per_entity, - user=self.user, collection=self.collection, + collection=self.collection, batch_size=20, g="", ) ]) @@ -468,7 +466,7 @@ class Query: subgraph_tasks.append( self.rag.triples_client.query( s=None, p=TG_CONTAINS, o=quoted, limit=1, - user=self.user, collection=self.collection, + collection=self.collection, g=GRAPH_SOURCE, ) ) @@ -501,7 +499,7 @@ class Query: derivation_tasks = [ self.rag.triples_client.query( s=uri, p=PROV_WAS_DERIVED_FROM, o=None, limit=5, - user=self.user, collection=self.collection, + collection=self.collection, g=GRAPH_SOURCE, ) for uri in current_uris @@ -535,7 +533,7 @@ class Query: metadata_tasks = [ self.rag.triples_client.query( s=uri, p=None, o=None, limit=50, - user=self.user, collection=self.collection, + collection=self.collection, ) for uri in doc_uris ] @@ -560,11 +558,9 @@ class Query: class GraphRag: """ - CRITICAL SECURITY: - This class MUST be instantiated per-request to ensure proper isolation - between users and collections. The cache within this instance will only - live for the duration of a single request, preventing cross-contamination - of data between different security contexts. 
+ Must be instantiated per-request so the label cache lives only for + the duration of a single request. Workspace isolation is enforced + by the trusted flow layer (flow.workspace), not by this class. """ def __init__( @@ -587,7 +583,7 @@ class GraphRag: logger.debug("GraphRag initialized") async def query( - self, query, user = "trustgraph", collection = "default", + self, query, collection = "default", entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000, max_path_length = 2, edge_score_limit = 30, edge_limit = 25, streaming = False, @@ -600,7 +596,6 @@ class GraphRag: Args: query: The query string - user: User identifier collection: Collection identifier entity_limit: Max entities to retrieve triple_limit: Max triples per entity @@ -657,7 +652,7 @@ class GraphRag: await explain_callback(q_triples, q_uri) q = Query( - rag = self, user = user, collection = collection, + rag = self, collection = collection, verbose = self.verbose, entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index 15c30ba1..acb111e1 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -62,9 +62,9 @@ class Processor(FlowProcessor): self.default_edge_score_limit = edge_score_limit self.default_edge_limit = edge_limit - # CRITICAL SECURITY: NEVER share data between users or collections - # Each user/collection combination MUST have isolated data access - # Caching must NEVER allow information leakage across these boundaries + # Workspace isolation is enforced by the flow layer (flow.workspace). + # Per-request caching (see GraphRag) keeps within-request state + # scoped; no cross-request sharing here. 
self.register_specification( ConsumerSpec( @@ -170,13 +170,13 @@ class Processor(FlowProcessor): future = self.pending_librarian_requests.pop(request_id) future.set_result(response) - async def save_answer_content(self, doc_id, user, content, title=None, timeout=120): + async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120): """ Save answer content to the librarian. Args: doc_id: ID for the answer document - user: User ID + workspace: Workspace for isolation content: Answer text content title: Optional title timeout: Request timeout in seconds @@ -188,7 +188,7 @@ class Processor(FlowProcessor): doc_metadata = DocumentMetadata( id=doc_id, - user=user, + workspace=workspace, kind="text/plain", title=title or "GraphRAG Answer", document_type="answer", @@ -199,7 +199,7 @@ class Processor(FlowProcessor): document_id=doc_id, document_metadata=doc_metadata, content=base64.b64encode(content.encode("utf-8")).decode("utf-8"), - user=user, + workspace=workspace, ) # Create future for response @@ -241,14 +241,13 @@ class Processor(FlowProcessor): explainability_refs_emitted = [] # Real-time explainability callback - emits triples and IDs as they're generated - # Triples are stored in the user's collection with a named graph (urn:graph:retrieval) + # Triples are stored in the request's collection with a named graph (urn:graph:retrieval) async def send_explainability(triples, explain_id): # Send triples to explainability queue - stores in same collection with named graph await flow("explainability").send(Triples( metadata=Metadata( id=explain_id, - user=v.user, - collection=v.collection, # Store in user's collection, not separate explainability collection + collection=v.collection, ), triples=triples, )) @@ -266,9 +265,9 @@ class Processor(FlowProcessor): explainability_refs_emitted.append(explain_id) - # CRITICAL SECURITY: Create new GraphRag instance per request - # This ensures proper isolation between users and collections - # Flow clients 
are request-scoped and must not be shared + # Create new GraphRag instance per request — its label cache + # is request-scoped, and flow clients must not be shared + # across requests. rag = GraphRag( embeddings_client=flow("embeddings-request"), graph_embeddings_client=flow("graph-embeddings-request"), @@ -311,7 +310,7 @@ class Processor(FlowProcessor): async def save_answer(doc_id, answer_text): await self.save_answer_content( doc_id=doc_id, - user=v.user, + workspace=flow.workspace, content=answer_text, title=f"GraphRAG Answer: {v.query[:50]}...", ) @@ -333,7 +332,7 @@ class Processor(FlowProcessor): # Query with streaming and real-time explain response, usage = await rag.query( - query = v.query, user = v.user, collection = v.collection, + query = v.query, collection = v.collection, entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, @@ -349,7 +348,7 @@ class Processor(FlowProcessor): else: # Non-streaming path with real-time explain response, usage = await rag.query( - query = v.query, user = v.user, collection = v.collection, + query = v.query, collection = v.collection, entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, @@ -464,7 +463,7 @@ class Processor(FlowProcessor): help=f'Max edges after LLM scoring (default: 25)' ) - # Note: Explainability triples are now stored in the user's collection + # Note: Explainability triples are now stored in the request's collection # with the named graph urn:graph:retrieval (no separate collection needed) def run(): diff --git a/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py b/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py index b567cc7b..091069ad 100644 --- a/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py +++ b/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py @@ -66,32 +66,39 @@ class Processor(FlowProcessor): # Register 
config handler for schema updates self.register_config_handler(self.on_schema_config, types=["schema"]) - # Schema storage: name -> RowSchema - self.schemas: Dict[str, RowSchema] = {} - + # Per-workspace schema storage: {workspace: {name: RowSchema}} + self.schemas: Dict[str, Dict[str, RowSchema]] = {} + logger.info("NLP Query service initialized") - async def on_schema_config(self, config, version): + async def on_schema_config(self, workspace, config, version): """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") - - # Clear existing schemas - self.schemas = {} - + logger.info( + f"Loading schema configuration version {version} " + f"for workspace {workspace}" + ) + + # Replace existing schemas for this workspace + ws_schemas: Dict[str, RowSchema] = {} + self.schemas[workspace] = ws_schemas + # Check if our config type exists if self.config_key not in config: - logger.warning(f"No '{self.config_key}' type in configuration") + logger.warning( + f"No '{self.config_key}' type in configuration " + f"for {workspace}" + ) return - + # Get the schemas dictionary for our type schemas_config = config[self.config_key] - + # Process each schema in the schemas config for schema_name, schema_json in schemas_config.items(): try: # Parse the JSON schema definition schema_def = json.loads(schema_json) - + # Create Field objects fields = [] for field_def in schema_def.get("fields", []): @@ -106,29 +113,37 @@ class Processor(FlowProcessor): indexed=field_def.get("indexed", False) ) fields.append(field) - + # Create RowSchema row_schema = RowSchema( name=schema_def.get("name", schema_name), description=schema_def.get("description", ""), fields=fields ) - - self.schemas[schema_name] = row_schema - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") - + + ws_schemas[schema_name] = row_schema + logger.info( + f"Loaded schema: {schema_name} with " + f"{len(fields)} fields for {workspace}" + ) + except Exception as 
e: logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + + logger.info( + f"Schema configuration loaded for {workspace}: " + f"{len(ws_schemas)} schemas" + ) async def phase1_select_schemas(self, question: str, flow) -> List[str]: """Phase 1: Use prompt service to select relevant schemas for the question""" logger.info("Starting Phase 1: Schema selection") - + + ws_schemas = self.schemas.get(flow.workspace, {}) + # Prepare schema information for the prompt schema_info = [] - for name, schema in self.schemas.items(): + for name, schema in ws_schemas.items(): schema_desc = { "name": name, "description": schema.description, @@ -176,12 +191,14 @@ class Processor(FlowProcessor): async def phase2_generate_graphql(self, question: str, selected_schemas: List[str], flow) -> Dict[str, Any]: """Phase 2: Generate GraphQL query using selected schemas""" logger.info(f"Starting Phase 2: GraphQL generation for schemas: {selected_schemas}") - + + ws_schemas = self.schemas.get(flow.workspace, {}) + # Get detailed schema information for selected schemas only selected_schema_info = [] for schema_name in selected_schemas: - if schema_name in self.schemas: - schema = self.schemas[schema_name] + if schema_name in ws_schemas: + schema = ws_schemas[schema_name] schema_desc = { "name": schema_name, "description": schema.description, diff --git a/trustgraph-flow/trustgraph/retrieval/structured_diag/service.py b/trustgraph-flow/trustgraph/retrieval/structured_diag/service.py index b878bf61..6dd79cbb 100644 --- a/trustgraph-flow/trustgraph/retrieval/structured_diag/service.py +++ b/trustgraph-flow/trustgraph/retrieval/structured_diag/service.py @@ -72,21 +72,28 @@ class Processor(FlowProcessor): # Register config handler for schema updates self.register_config_handler(self.on_schema_config, types=["schema"]) - # Schema storage: name -> RowSchema - self.schemas: Dict[str, RowSchema] = {} + # 
Per-workspace schema storage: {workspace: {name: RowSchema}} + self.schemas: Dict[str, Dict[str, RowSchema]] = {} logger.info("Structured Data Diagnosis service initialized") - async def on_schema_config(self, config, version): + async def on_schema_config(self, workspace, config, version): """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") + logger.info( + f"Loading schema configuration version {version} " + f"for workspace {workspace}" + ) - # Clear existing schemas - self.schemas = {} + # Replace existing schemas for this workspace + ws_schemas: Dict[str, RowSchema] = {} + self.schemas[workspace] = ws_schemas # Check if our config type exists if self.config_key not in config: - logger.warning(f"No '{self.config_key}' type in configuration") + logger.warning( + f"No '{self.config_key}' type in configuration " + f"for {workspace}" + ) return # Get the schemas dictionary for our type @@ -120,13 +127,19 @@ class Processor(FlowProcessor): fields=fields ) - self.schemas[schema_name] = row_schema - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") + ws_schemas[schema_name] = row_schema + logger.info( + f"Loaded schema: {schema_name} with " + f"{len(fields)} fields for {workspace}" + ) except Exception as e: logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + logger.info( + f"Schema configuration loaded for {workspace}: " + f"{len(ws_schemas)} schemas" + ) async def on_message(self, msg, consumer, flow): """Handle incoming structured data diagnosis request""" @@ -216,15 +229,19 @@ class Processor(FlowProcessor): ) return StructuredDataDiagnosisResponse(error=error, operation=request.operation) - # Get target schema - if request.schema_name not in self.schemas: + # Get target schema from this workspace's schemas + ws_schemas = self.schemas.get(flow.workspace, {}) + if request.schema_name not in 
ws_schemas: error = Error( type="SchemaNotFound", - message=f"Schema '{request.schema_name}' not found in configuration" + message=( + f"Schema '{request.schema_name}' not found " + f"in configuration for workspace {flow.workspace}" + ) ) return StructuredDataDiagnosisResponse(error=error, operation=request.operation) - target_schema = self.schemas[request.schema_name] + target_schema = ws_schemas[request.schema_name] # Generate descriptor using prompt service descriptor = await self.generate_descriptor_with_prompt( @@ -260,26 +277,33 @@ class Processor(FlowProcessor): return StructuredDataDiagnosisResponse(error=error, operation=request.operation) # Step 2: Use provided schema name or auto-select first available + ws_schemas = self.schemas.get(flow.workspace, {}) schema_name = request.schema_name - if not schema_name and self.schemas: - schema_name = list(self.schemas.keys())[0] + if not schema_name and ws_schemas: + schema_name = list(ws_schemas.keys())[0] logger.info(f"Auto-selected schema: {schema_name}") if not schema_name: error = Error( type="NoSchemaAvailable", - message="No schema specified and no schemas available in configuration" + message=( + f"No schema specified and no schemas available " + f"in configuration for workspace {flow.workspace}" + ) ) return StructuredDataDiagnosisResponse(error=error, operation=request.operation) - if schema_name not in self.schemas: + if schema_name not in ws_schemas: error = Error( type="SchemaNotFound", - message=f"Schema '{schema_name}' not found in configuration" + message=( + f"Schema '{schema_name}' not found in " + f"configuration for workspace {flow.workspace}" + ) ) return StructuredDataDiagnosisResponse(error=error, operation=request.operation) - target_schema = self.schemas[schema_name] + target_schema = ws_schemas[schema_name] # Step 3: Generate descriptor descriptor = await self.generate_descriptor_with_prompt( @@ -316,8 +340,9 @@ class Processor(FlowProcessor): logger.info("Processing schema-selection 
operation") # Prepare all schemas for the prompt - match the original config format + ws_schemas = self.schemas.get(flow.workspace, {}) all_schemas = [] - for schema_name, row_schema in self.schemas.items(): + for schema_name, row_schema in ws_schemas.items(): schema_info = { "name": row_schema.name, "description": row_schema.description, diff --git a/trustgraph-flow/trustgraph/retrieval/structured_query/service.py b/trustgraph-flow/trustgraph/retrieval/structured_query/service.py index e39f9041..151703cb 100644 --- a/trustgraph-flow/trustgraph/retrieval/structured_query/service.py +++ b/trustgraph-flow/trustgraph/retrieval/structured_query/service.py @@ -111,9 +111,9 @@ class Processor(FlowProcessor): else: variables_as_strings[key] = str(value) - # Use user/collection values from request + # Use collection from request. Workspace isolation is + # enforced by flow.workspace at the rows-query service. objects_request = RowsQueryRequest( - user=request.user, collection=request.collection, query=nlp_response.graphql_query, variables=variables_as_strings, diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py index f5c12441..7c8db19d 100755 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py @@ -33,7 +33,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) - async def store_document_embeddings(self, message): + async def store_document_embeddings(self, workspace, message): for emb in message.chunks: @@ -45,7 +45,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): if vec: self.vecstore.insert( vec, chunk_id, - message.metadata.user, + workspace, message.metadata.collection ) @@ -60,27 +60,27 @@ class 
Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): help=f'Milvus store URI (default: {default_store_uri})' ) - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """ Create collection via config push - collections are created lazily on first write with the correct dimension determined from the actual embeddings. """ try: - logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") - self.vecstore.create_collection(user, collection) + logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write") + self.vecstore.create_collection(workspace, collection) except Exception as e: - logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete the collection for document embeddings via config push""" try: - self.vecstore.delete_collection(user, collection) - logger.info(f"Successfully deleted collection {user}/{collection}") + self.vecstore.delete_collection(workspace, collection) + logger.info(f"Successfully deleted collection {workspace}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py index 31a70f23..41a1e5a5 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py +++ 
b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py @@ -88,12 +88,12 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): "Gave up waiting for index creation" ) - async def store_document_embeddings(self, message): + async def store_document_embeddings(self, workspace, message): # Validate collection exists in config before processing - if not self.collection_exists(message.metadata.user, message.metadata.collection): + if not self.collection_exists(workspace, message.metadata.collection): logger.warning( - f"Collection {message.metadata.collection} for user {message.metadata.user} " + f"Collection {message.metadata.collection} for workspace {workspace} " f"does not exist in config (likely deleted while data was in-flight). " f"Dropping message." ) @@ -112,7 +112,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): # Create index name with dimension suffix for lazy creation dim = len(vec) index_name = ( - f"d-{message.metadata.user}-{message.metadata.collection}-{dim}" + f"d-{workspace}-{message.metadata.collection}-{dim}" ) # Lazily create index if it doesn't exist (but only if authorized in config) @@ -165,22 +165,22 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): help=f'Pinecone region, (default: {default_region}' ) - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """ Create collection via config push - indexes are created lazily on first write with the correct dimension determined from the actual embeddings. 
""" try: - logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") + logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write") except Exception as e: - logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete the collection for document embeddings via config push""" try: - prefix = f"d-{user}-{collection}-" + prefix = f"d-{workspace}-{collection}-" # Get all indexes and filter for matches all_indexes = self.pinecone.list_indexes() @@ -195,10 +195,10 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): for index_name in matching_indexes: self.pinecone.delete_index(index_name) logger.info(f"Deleted Pinecone index: {index_name}") - logger.info(f"Deleted {len(matching_indexes)} index(es) for {user}/{collection}") + logger.info(f"Deleted {len(matching_indexes)} index(es) for {workspace}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index e5e7e705..fb7166b5 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -39,12 +39,12 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) - async def 
store_document_embeddings(self, message): + async def store_document_embeddings(self, workspace, message): # Validate collection exists in config before processing - if not self.collection_exists(message.metadata.user, message.metadata.collection): + if not self.collection_exists(workspace, message.metadata.collection): logger.warning( - f"Collection {message.metadata.collection} for user {message.metadata.user} " + f"Collection {message.metadata.collection} for workspace {workspace} " f"does not exist in config (likely deleted while data was in-flight). " f"Dropping message." ) @@ -63,7 +63,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): # Create collection name with dimension suffix for lazy creation dim = len(vec) collection = ( - f"d_{message.metadata.user}_{message.metadata.collection}_{dim}" + f"d_{workspace}_{message.metadata.collection}_{dim}" ) # Lazily create collection if it doesn't exist (but only if authorized in config) @@ -107,22 +107,22 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): help=f'Qdrant API key (default: None)' ) - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """ Create collection via config push - collections are created lazily on first write with the correct dimension determined from the actual embeddings. 
""" try: - logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") + logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write") except Exception as e: - logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete the collection for document embeddings via config push""" try: - prefix = f"d_{user}_{collection}_" + prefix = f"d_{workspace}_{collection}_" # Get all collections and filter for matches all_collections = self.qdrant.get_collections().collections @@ -137,10 +137,10 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): for collection_name in matching_collections: self.qdrant.delete_collection(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") - logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}") + logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py index 9346c948..2068d58c 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py @@ -47,7 +47,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): # Register for config push notifications 
self.register_config_handler(self.on_collection_config, types=["collection"]) - async def store_graph_embeddings(self, message): + async def store_graph_embeddings(self, workspace, message): for entity in message.entities: entity_value = get_term_value(entity.entity) @@ -57,7 +57,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): if vec: self.vecstore.insert( vec, entity_value, - message.metadata.user, + workspace, message.metadata.collection, chunk_id=entity.chunk_id or "", ) @@ -73,27 +73,27 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): help=f'Milvus store URI (default: {default_store_uri})' ) - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """ Create collection via config push - collections are created lazily on first write with the correct dimension determined from the actual embeddings. """ try: - logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") - self.vecstore.create_collection(user, collection) + logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write") + self.vecstore.create_collection(workspace, collection) except Exception as e: - logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete the collection for graph embeddings via config push""" try: - self.vecstore.delete_collection(user, collection) - logger.info(f"Successfully deleted collection {user}/{collection}") + self.vecstore.delete_collection(workspace, collection) + logger.info(f"Successfully deleted collection {workspace}/{collection}") except Exception 
as e: - logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py index 6a95a38d..23662f7f 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py @@ -102,12 +102,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): "Gave up waiting for index creation" ) - async def store_graph_embeddings(self, message): + async def store_graph_embeddings(self, workspace, message): # Validate collection exists in config before processing - if not self.collection_exists(message.metadata.user, message.metadata.collection): + if not self.collection_exists(workspace, message.metadata.collection): logger.warning( - f"Collection {message.metadata.collection} for user {message.metadata.user} " + f"Collection {message.metadata.collection} for workspace {workspace} " f"does not exist in config (likely deleted while data was in-flight). " f"Dropping message." 
) @@ -126,7 +126,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): # Create index name with dimension suffix for lazy creation dim = len(vec) index_name = ( - f"t-{message.metadata.user}-{message.metadata.collection}-{dim}" + f"t-{workspace}-{message.metadata.collection}-{dim}" ) # Lazily create index if it doesn't exist (but only if authorized in config) @@ -183,22 +183,22 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): help=f'Pinecone region, (default: {default_region}' ) - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """ Create collection via config push - indexes are created lazily on first write with the correct dimension determined from the actual embeddings. """ try: - logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") + logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write") except Exception as e: - logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete the collection for graph embeddings via config push""" try: - prefix = f"t-{user}-{collection}-" + prefix = f"t-{workspace}-{collection}-" # Get all indexes and filter for matches all_indexes = self.pinecone.list_indexes() @@ -213,10 +213,10 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): for index_name in matching_indexes: self.pinecone.delete_index(index_name) logger.info(f"Deleted Pinecone index: {index_name}") - logger.info(f"Deleted {len(matching_indexes)} index(es) for {user}/{collection}") + logger.info(f"Deleted 
{len(matching_indexes)} index(es) for {workspace}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index 9a7672f8..391c2a04 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -54,12 +54,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) - async def store_graph_embeddings(self, message): + async def store_graph_embeddings(self, workspace, message): # Validate collection exists in config before processing - if not self.collection_exists(message.metadata.user, message.metadata.collection): + if not self.collection_exists(workspace, message.metadata.collection): logger.warning( - f"Collection {message.metadata.collection} for user {message.metadata.user} " + f"Collection {message.metadata.collection} for workspace {workspace} " f"does not exist in config (likely deleted while data was in-flight). " f"Dropping message." 
) @@ -78,7 +78,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): # Create collection name with dimension suffix for lazy creation dim = len(vec) collection = ( - f"t_{message.metadata.user}_{message.metadata.collection}_{dim}" + f"t_{workspace}_{message.metadata.collection}_{dim}" ) # Lazily create collection if it doesn't exist (but only if authorized in config) @@ -126,22 +126,22 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): help=f'Qdrant API key' ) - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """ Create collection via config push - collections are created lazily on first write with the correct dimension determined from the actual embeddings. """ try: - logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") + logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write") except Exception as e: - logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete the collection for graph embeddings via config push""" try: - prefix = f"t_{user}_{collection}_" + prefix = f"t_{workspace}_{collection}_" # Get all collections and filter for matches all_collections = self.qdrant.get_collections().collections @@ -156,10 +156,10 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): for collection_name in matching_collections: self.qdrant.delete_collection(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") - logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}") + 
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/storage/knowledge/store.py b/trustgraph-flow/trustgraph/storage/knowledge/store.py index 475604b6..57e1fe48 100644 --- a/trustgraph-flow/trustgraph/storage/knowledge/store.py +++ b/trustgraph-flow/trustgraph/storage/knowledge/store.py @@ -65,13 +65,13 @@ class Processor(FlowProcessor): v = msg.value() if v.triples: - await self.table_store.add_triples(v) + await self.table_store.add_triples(flow.workspace, v) async def on_graph_embeddings(self, msg, consumer, flow): v = msg.value() if v.entities: - await self.table_store.add_graph_embeddings(v) + await self.table_store.add_graph_embeddings(flow.workspace, v) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py index a6ec4ff7..32d87871 100644 --- a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py @@ -2,13 +2,13 @@ Row embeddings writer for Qdrant (Stage 2). Consumes RowEmbeddings messages (which already contain computed vectors) -and writes them to Qdrant. One Qdrant collection per (user, collection, schema_name) pair. +and writes them to Qdrant. One Qdrant collection per (workspace, collection, schema_name) triple.
This follows the two-stage pattern used by graph-embeddings and document-embeddings: Stage 1 (row-embeddings): Compute embeddings Stage 2 (this processor): Store embeddings -Collection naming: rows_{user}_{collection}_{schema_name}_{dimension} +Collection naming: rows_{workspace}_{collection}_{schema_name}_{dimension} Payload structure: - index_name: The indexed field(s) this embedding represents @@ -77,10 +77,10 @@ class Processor(CollectionConfigHandler, FlowProcessor): return safe_name.lower() def get_collection_name( - self, user: str, collection: str, schema_name: str, dimension: int + self, workspace: str, collection: str, schema_name: str, dimension: int ) -> str: """Generate Qdrant collection name""" - safe_user = self.sanitize_name(user) + safe_user = self.sanitize_name(workspace) safe_collection = self.sanitize_name(collection) safe_schema = self.sanitize_name(schema_name) return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}" @@ -114,18 +114,19 @@ class Processor(CollectionConfigHandler, FlowProcessor): f"{embeddings.schema_name} from {embeddings.metadata.id}" ) + workspace = flow.workspace + # Validate collection exists in config before processing if not self.collection_exists( - embeddings.metadata.user, embeddings.metadata.collection + workspace, embeddings.metadata.collection ): logger.warning( - f"Collection {embeddings.metadata.collection} for user " - f"{embeddings.metadata.user} does not exist in config. " + f"Collection {embeddings.metadata.collection} for workspace " + f"{workspace} does not exist in config. " f"Dropping message." 
) return - user = embeddings.metadata.user collection = embeddings.metadata.collection schema_name = embeddings.schema_name @@ -145,7 +146,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): # Create/get collection name (lazily on first vector) if qdrant_collection is None: qdrant_collection = self.get_collection_name( - user, collection, schema_name, dimension + workspace, collection, schema_name, dimension ) self.ensure_collection(qdrant_collection, dimension) @@ -168,17 +169,17 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.info(f"Wrote {embeddings_written} embeddings to Qdrant") - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """Collection creation via config push - collections created lazily on first write""" logger.info( - f"Row embeddings collection create request for {user}/{collection} - " + f"Row embeddings collection create request for {workspace}/{collection} - " f"will be created lazily on first write" ) - async def delete_collection(self, user: str, collection: str): - """Delete all Qdrant collections for a given user/collection""" + async def delete_collection(self, workspace: str, collection: str): + """Delete all Qdrant collections for a given workspace/collection""" try: - prefix = f"rows_{self.sanitize_name(user)}_{self.sanitize_name(collection)}_" + prefix = f"rows_{self.sanitize_name(workspace)}_{self.sanitize_name(collection)}_" # Get all collections and filter for matches all_collections = self.qdrant.get_collections().collections @@ -196,23 +197,23 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.info(f"Deleted Qdrant collection: {collection_name}") logger.info( f"Deleted {len(matching_collections)} collection(s) " - f"for {user}/{collection}" + f"for {workspace}/{collection}" ) except Exception as e: logger.error( - f"Failed to delete collection {user}/{collection}: {e}", + 
f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True ) raise async def delete_collection_schema( - self, user: str, collection: str, schema_name: str + self, workspace: str, collection: str, schema_name: str ): - """Delete Qdrant collection for a specific user/collection/schema""" + """Delete Qdrant collection for a specific workspace/collection/schema""" try: prefix = ( - f"rows_{self.sanitize_name(user)}_" + f"rows_{self.sanitize_name(workspace)}_" f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_" ) @@ -233,7 +234,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): except Exception as e: logger.error( - f"Failed to delete collection {user}/{collection}/{schema_name}: {e}", + f"Failed to delete collection {workspace}/{collection}/{schema_name}: {e}", exc_info=True ) raise diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py index d0eec2e1..acfe00d2 100755 --- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py @@ -119,19 +119,27 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True) raise - async def on_schema_config(self, config, version): + async def on_schema_config(self, workspace, config, version): """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") + logger.info( + f"Loading schema configuration version {version} " + f"for workspace {workspace}" + ) - # Track which schemas changed so we can clear partition cache - old_schema_names = set(self.schemas.keys()) + # Track which schemas changed in this workspace + old_schemas = self.schemas.get(workspace, {}) + old_schema_names = set(old_schemas.keys()) - # Clear existing schemas - self.schemas = {} + # Replace existing schemas for this workspace + ws_schemas: Dict[str, RowSchema] = 
{} + self.schemas[workspace] = ws_schemas # Check if our config type exists if self.config_key not in config: - logger.warning(f"No '{self.config_key}' type in configuration") + logger.warning( + f"No '{self.config_key}' type in configuration " + f"for {workspace}" + ) return # Get the schemas dictionary for our type @@ -165,24 +173,32 @@ class Processor(CollectionConfigHandler, FlowProcessor): fields=fields ) - self.schemas[schema_name] = row_schema - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") + ws_schemas[schema_name] = row_schema + logger.info( + f"Loaded schema: {schema_name} with " + f"{len(fields)} fields for {workspace}" + ) except Exception as e: logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + logger.info( + f"Schema configuration loaded for {workspace}: " + f"{len(ws_schemas)} schemas" + ) - # Clear partition cache for schemas that changed - # This ensures next write will re-register partitions - new_schema_names = set(self.schemas.keys()) + # Clear partition cache for schemas that changed in this workspace + new_schema_names = set(ws_schemas.keys()) changed_schemas = old_schema_names.symmetric_difference(new_schema_names) if changed_schemas: self.registered_partitions = { (col, sch) for col, sch in self.registered_partitions if sch not in changed_schemas } - logger.info(f"Cleared partition cache for changed schemas: {changed_schemas}") + logger.info( + f"Cleared partition cache for changed schemas " + f"in {workspace}: {changed_schemas}" + ) def sanitize_name(self, name: str) -> str: """Sanitize names for Cassandra compatibility""" @@ -286,7 +302,10 @@ class Processor(CollectionConfigHandler, FlowProcessor): return index_names - def register_partitions(self, keyspace: str, collection: str, schema_name: str): + def register_partitions( + self, keyspace: str, collection: str, schema_name: str, + workspace: str, + ): """ Register 
partition entries for a (collection, schema_name) pair. Called once on first row for each pair. @@ -295,9 +314,13 @@ class Processor(CollectionConfigHandler, FlowProcessor): if cache_key in self.registered_partitions: return - schema = self.schemas.get(schema_name) + ws_schemas = self.schemas.get(workspace, {}) + schema = ws_schemas.get(schema_name) if not schema: - logger.warning(f"Cannot register partitions - schema {schema_name} not found") + logger.warning( + f"Cannot register partitions - schema {schema_name} " + f"not found in workspace {workspace}" + ) return safe_keyspace = self.sanitize_name(keyspace) @@ -338,13 +361,14 @@ class Processor(CollectionConfigHandler, FlowProcessor): """Process incoming ExtractedObject and store in Cassandra""" obj = msg.value() + workspace = flow.workspace logger.info( f"Storing {len(obj.values)} rows for schema {obj.schema_name} " - f"from {obj.metadata.id}" + f"from {obj.metadata.id} (workspace {workspace})" ) # Validate collection exists before accepting writes - if not self.collection_exists(obj.metadata.user, obj.metadata.collection): + if not self.collection_exists(workspace, obj.metadata.collection): error_msg = ( f"Collection {obj.metadata.collection} does not exist. " f"Create it first via collection management API." 
@@ -352,13 +376,17 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.error(error_msg) raise ValueError(error_msg) - # Get schema definition - schema = self.schemas.get(obj.schema_name) + # Get schema definition for this workspace + ws_schemas = self.schemas.get(workspace, {}) + schema = ws_schemas.get(obj.schema_name) if not schema: - logger.warning(f"No schema found for {obj.schema_name} - skipping") + logger.warning( + f"No schema found for {obj.schema_name} in " + f"workspace {workspace} - skipping" + ) return - keyspace = obj.metadata.user + keyspace = workspace collection = obj.metadata.collection schema_name = obj.schema_name source = getattr(obj.metadata, 'source', '') or '' @@ -370,7 +398,8 @@ class Processor(CollectionConfigHandler, FlowProcessor): # Register partitions if first time seeing this (collection, schema_name) await asyncio.to_thread( - self.register_partitions, keyspace, collection, schema_name + self.register_partitions, + keyspace, collection, schema_name, workspace, ) safe_keyspace = self.sanitize_name(keyspace) @@ -430,25 +459,25 @@ class Processor(CollectionConfigHandler, FlowProcessor): f"({len(index_names)} indexes per row)" ) - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """Create/verify collection exists in Cassandra row store""" # Connect if not already connected (sync, push to thread) await asyncio.to_thread(self.connect_cassandra) # Ensure tables exist (sync DDL, push to thread) - await asyncio.to_thread(self.ensure_tables, user) + await asyncio.to_thread(self.ensure_tables, workspace) - logger.info(f"Collection {collection} ready for user {user}") + logger.info(f"Collection {collection} ready for workspace {workspace}") - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete all data for a specific collection 
using partition tracking""" # Connect if not already connected await asyncio.to_thread(self.connect_cassandra) - safe_keyspace = self.sanitize_name(user) + safe_keyspace = self.sanitize_name(workspace) # Check if keyspace exists - if user not in self.known_keyspaces: + if workspace not in self.known_keyspaces: check_keyspace_cql = """ SELECT keyspace_name FROM system_schema.keyspaces WHERE keyspace_name = %s @@ -459,7 +488,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): if not result: logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete") return - self.known_keyspaces.add(user) + self.known_keyspaces.add(workspace) # Discover all partitions for this collection select_partitions_cql = f""" @@ -522,12 +551,12 @@ class Processor(CollectionConfigHandler, FlowProcessor): f"from keyspace {safe_keyspace}" ) - async def delete_collection_schema(self, user: str, collection: str, schema_name: str): + async def delete_collection_schema(self, workspace: str, collection: str, schema_name: str): """Delete all data for a specific collection + schema combination""" # Connect if not already connected await asyncio.to_thread(self.connect_cassandra) - safe_keyspace = self.sanitize_name(user) + safe_keyspace = self.sanitize_name(workspace) # Discover partitions for this collection + schema select_partitions_cql = f""" diff --git a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py index 01d95c8b..05331d09 100755 --- a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py @@ -147,9 +147,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService): # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) - async def store_triples(self, message): - - user = message.metadata.user + async def store_triples(self, workspace, message): # The 
cassandra-driver work below — connection, schema # setup, and per-triple inserts — is all synchronous. @@ -159,7 +157,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService): def _do_store(): - if self.table is None or self.table != user: + if self.table is None or self.table != workspace: self.tg = None @@ -170,21 +168,21 @@ class Processor(CollectionConfigHandler, TriplesStoreService): if self.cassandra_username and self.cassandra_password: self.tg = KGClass( hosts=self.cassandra_host, - keyspace=message.metadata.user, + keyspace=workspace, username=self.cassandra_username, password=self.cassandra_password, ) else: self.tg = KGClass( hosts=self.cassandra_host, - keyspace=message.metadata.user, + keyspace=workspace, ) except Exception as e: logger.error(f"Exception: {e}", exc_info=True) time.sleep(1) raise e - self.table = user + self.table = workspace for t in message.triples: # Extract values from Term objects @@ -212,12 +210,12 @@ class Processor(CollectionConfigHandler, TriplesStoreService): await asyncio.to_thread(_do_store) - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """Create a collection in Cassandra triple store via config push""" def _do_create(): - # Create or reuse connection for this user's keyspace - if self.table is None or self.table != user: + # Create or reuse connection for this workspace's keyspace + if self.table is None or self.table != workspace: self.tg = None # Use factory function to select implementation @@ -227,23 +225,23 @@ class Processor(CollectionConfigHandler, TriplesStoreService): if self.cassandra_username and self.cassandra_password: self.tg = KGClass( hosts=self.cassandra_host, - keyspace=user, + keyspace=workspace, username=self.cassandra_username, password=self.cassandra_password, ) else: self.tg = KGClass( hosts=self.cassandra_host, - keyspace=user, + keyspace=workspace, ) except Exception as 
e: - logger.error(f"Failed to connect to Cassandra for user {user}: {e}") + logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}") raise - self.table = user + self.table = workspace # Create collection using the built-in method - logger.info(f"Creating collection {collection} for user {user}") + logger.info(f"Creating collection {collection} for workspace {workspace}") if self.tg.collection_exists(collection): logger.info(f"Collection {collection} already exists") @@ -254,15 +252,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService): try: await asyncio.to_thread(_do_create) except Exception as e: - logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete all data for a specific collection from the unified triples table""" def _do_delete(): - # Create or reuse connection for this user's keyspace - if self.table is None or self.table != user: + # Create or reuse connection for this workspace's keyspace + if self.table is None or self.table != workspace: self.tg = None # Use factory function to select implementation @@ -272,29 +270,29 @@ class Processor(CollectionConfigHandler, TriplesStoreService): if self.cassandra_username and self.cassandra_password: self.tg = KGClass( hosts=self.cassandra_host, - keyspace=user, + keyspace=workspace, username=self.cassandra_username, password=self.cassandra_password, ) else: self.tg = KGClass( hosts=self.cassandra_host, - keyspace=user, + keyspace=workspace, ) except Exception as e: - logger.error(f"Failed to connect to Cassandra for user {user}: {e}") + logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}") raise - self.table = user + self.table = workspace # Delete all triples for this collection using 
the built-in method self.tg.delete_collection(collection) - logger.info(f"Deleted all triples for collection {collection} from keyspace {user}") + logger.info(f"Deleted all triples for collection {collection} from keyspace {workspace}") try: await asyncio.to_thread(_do_delete) except Exception as e: - logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise @staticmethod diff --git a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py index 86f9a6e3..77c32919 100755 --- a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py @@ -59,15 +59,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService): # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) - def create_node(self, uri, user, collection): + def create_node(self, uri, workspace, collection): - logger.debug(f"Create node {uri} for user={user}, collection={collection}") + logger.debug(f"Create node {uri} for workspace={workspace}, collection={collection}") res = self.io.query( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", params={ "uri": uri, - "user": user, + "workspace": workspace, "collection": collection, }, ) @@ -77,15 +77,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=res.run_time_ms )) - def create_literal(self, value, user, collection): + def create_literal(self, value, workspace, collection): - logger.debug(f"Create literal {value} for user={user}, collection={collection}") + logger.debug(f"Create literal {value} for workspace={workspace}, collection={collection}") res = self.io.query( - "MERGE (n:Literal {value: $value, user: 
$user, collection: $collection})", + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", params={ "value": value, - "user": user, + "workspace": workspace, "collection": collection, }, ) @@ -95,19 +95,19 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=res.run_time_ms )) - def relate_node(self, src, uri, dest, user, collection): + def relate_node(self, src, uri, dest, workspace, collection): - logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}") + logger.debug(f"Create node rel {src} {uri} {dest} for workspace={workspace}, collection={collection}") res = self.io.query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", params={ "src": src, "dest": dest, "uri": uri, - "user": user, + "workspace": workspace, "collection": collection, }, ) @@ -117,19 +117,19 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=res.run_time_ms )) - def relate_literal(self, src, uri, dest, user, collection): + def relate_literal(self, src, uri, dest, workspace, collection): - logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}") + logger.debug(f"Create literal rel {src} {uri} {dest} for workspace={workspace}, collection={collection}") res = self.io.query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", + "MATCH (src:Node {uri: 
$src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", params={ "src": src, "dest": dest, "uri": uri, - "user": user, + "workspace": workspace, "collection": collection, }, ) @@ -139,36 +139,34 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=res.run_time_ms )) - def collection_exists(self, user, collection): + def collection_exists(self, workspace, collection): """Check if collection metadata node exists""" result = self.io.query( - "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) " + "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) " "RETURN c LIMIT 1", - params={"user": user, "collection": collection} + params={"workspace": workspace, "collection": collection} ) return result.result_set is not None and len(result.result_set) > 0 - def create_collection(self, user, collection): + def create_collection(self, workspace, collection): """Create collection metadata node""" import datetime self.io.query( - "MERGE (c:CollectionMetadata {user: $user, collection: $collection}) " + "MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) " "SET c.created_at = $created_at", params={ - "user": user, + "workspace": workspace, "collection": collection, "created_at": datetime.datetime.now().isoformat() } ) - logger.info(f"Created collection metadata node for {user}/{collection}") + logger.info(f"Created collection metadata node for {workspace}/{collection}") - async def store_triples(self, message): - # Extract user and collection from metadata - user = message.metadata.user if message.metadata.user else "default" + async def store_triples(self, workspace, message): collection = message.metadata.collection if message.metadata.collection else "default" # Validate collection exists before accepting writes - 
if not self.collection_exists(user, collection): + if not self.collection_exists(workspace, collection): error_msg = ( f"Collection {collection} does not exist. " f"Create it first via collection management API." @@ -182,14 +180,14 @@ class Processor(CollectionConfigHandler, TriplesStoreService): p_val = get_term_value(t.p) o_val = get_term_value(t.o) - self.create_node(s_val, user, collection) + self.create_node(s_val, workspace, collection) if t.o.type == IRI: - self.create_node(o_val, user, collection) - self.relate_node(s_val, p_val, o_val, user, collection) + self.create_node(o_val, workspace, collection) + self.relate_node(s_val, p_val, o_val, workspace, collection) else: - self.create_literal(o_val, user, collection) - self.relate_literal(s_val, p_val, o_val, user, collection) + self.create_literal(o_val, workspace, collection) + self.relate_literal(s_val, p_val, o_val, workspace, collection) @staticmethod def add_args(parser): @@ -208,58 +206,58 @@ class Processor(CollectionConfigHandler, TriplesStoreService): help=f'FalkorDB database (default: {default_database})' ) - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """Create collection metadata in FalkorDB via config push""" try: # Check if collection exists result = self.io.query( - "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) RETURN c LIMIT 1", - params={"user": user, "collection": collection} + "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) RETURN c LIMIT 1", + params={"workspace": workspace, "collection": collection} ) if result.result_set: - logger.info(f"Collection {user}/{collection} already exists") + logger.info(f"Collection {workspace}/{collection} already exists") else: # Create collection metadata node import datetime self.io.query( - "MERGE (c:CollectionMetadata {user: $user, collection: $collection}) " + "MERGE 
(c:CollectionMetadata {workspace: $workspace, collection: $collection}) " "SET c.created_at = $created_at", params={ - "user": user, + "workspace": workspace, "collection": collection, "created_at": datetime.datetime.now().isoformat() } ) - logger.info(f"Created collection {user}/{collection}") + logger.info(f"Created collection {workspace}/{collection}") except Exception as e: - logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete the collection for FalkorDB triples via config push""" try: - # Delete all nodes and literals for this user/collection + # Delete all nodes and literals for this workspace/collection node_result = self.io.query( - "MATCH (n:Node {user: $user, collection: $collection}) DETACH DELETE n", - params={"user": user, "collection": collection} + "MATCH (n:Node {workspace: $workspace, collection: $collection}) DETACH DELETE n", + params={"workspace": workspace, "collection": collection} ) literal_result = self.io.query( - "MATCH (n:Literal {user: $user, collection: $collection}) DETACH DELETE n", - params={"user": user, "collection": collection} + "MATCH (n:Literal {workspace: $workspace, collection: $collection}) DETACH DELETE n", + params={"workspace": workspace, "collection": collection} ) # Delete collection metadata node metadata_result = self.io.query( - "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) DELETE c", - params={"user": user, "collection": collection} + "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) DELETE c", + params={"workspace": workspace, "collection": collection} ) - logger.info(f"Deleted {node_result.nodes_deleted} nodes, {literal_result.nodes_deleted} literals, and {metadata_result.nodes_deleted} 
metadata nodes for collection {user}/{collection}") + logger.info(f"Deleted {node_result.nodes_deleted} nodes, {literal_result.nodes_deleted} literals, and {metadata_result.nodes_deleted} metadata nodes for collection {workspace}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py index 16a7d3ed..3e1a8288 100755 --- a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py @@ -117,10 +117,10 @@ class Processor(CollectionConfigHandler, TriplesStoreService): # Maybe index already exists logger.warning("Index create failure ignored") - # New indexes for user/collection filtering + # New indexes for workspace/collection filtering try: session.run( - "CREATE INDEX ON :Node(user)" + "CREATE INDEX ON :Node(workspace)" ) except Exception as e: logger.warning(f"User index create failure: {e}") @@ -136,7 +136,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService): try: session.run( - "CREATE INDEX ON :Literal(user)" + "CREATE INDEX ON :Literal(workspace)" ) except Exception as e: logger.warning(f"User index create failure: {e}") @@ -152,13 +152,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService): logger.info("Index creation done") - def create_node(self, uri, user, collection): + def create_node(self, uri, workspace, collection): - logger.debug(f"Create node {uri} for user={user}, collection={collection}") + logger.debug(f"Create node {uri} for workspace={workspace}, collection={collection}") summary = self.io.execute_query( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", - uri=uri, user=user, collection=collection, + "MERGE (n:Node {uri: 
$uri, workspace: $workspace, collection: $collection})", + uri=uri, workspace=workspace, collection=collection, database_=self.db, ).summary @@ -167,13 +167,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=summary.result_available_after )) - def create_literal(self, value, user, collection): + def create_literal(self, value, workspace, collection): - logger.debug(f"Create literal {value} for user={user}, collection={collection}") + logger.debug(f"Create literal {value} for workspace={workspace}, collection={collection}") summary = self.io.execute_query( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", - value=value, user=user, collection=collection, + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", + value=value, workspace=workspace, collection=collection, database_=self.db, ).summary @@ -182,15 +182,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=summary.result_available_after )) - def relate_node(self, src, uri, dest, user, collection): + def relate_node(self, src, uri, dest, workspace, collection): - logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}") + logger.debug(f"Create node rel {src} {uri} {dest} for workspace={workspace}, collection={collection}") summary = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", - src=src, dest=dest, uri=uri, user=user, collection=collection, + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", + src=src, dest=dest, uri=uri, workspace=workspace, collection=collection, 
database_=self.db, ).summary @@ -199,15 +199,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=summary.result_available_after )) - def relate_literal(self, src, uri, dest, user, collection): + def relate_literal(self, src, uri, dest, workspace, collection): - logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}") + logger.debug(f"Create literal rel {src} {uri} {dest} for workspace={workspace}, collection={collection}") summary = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", - src=src, dest=dest, uri=uri, user=user, collection=collection, + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", + src=src, dest=dest, uri=uri, workspace=workspace, collection=collection, database_=self.db, ).summary @@ -216,7 +216,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=summary.result_available_after )) - def create_triple(self, tx, t, user, collection): + def create_triple(self, tx, t, workspace, collection): s_val = get_term_value(t.s) p_val = get_term_value(t.p) @@ -224,48 +224,46 @@ class Processor(CollectionConfigHandler, TriplesStoreService): # Create new s node with given uri, if not exists result = tx.run( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", - uri=s_val, user=user, collection=collection + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", + uri=s_val, workspace=workspace, collection=collection ) if t.o.type == IRI: # Create new o node with given uri, if not exists result = tx.run( - "MERGE (n:Node {uri: $uri, 
user: $user, collection: $collection})", - uri=o_val, user=user, collection=collection + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", + uri=o_val, workspace=workspace, collection=collection ) result = tx.run( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", - src=s_val, dest=o_val, uri=p_val, user=user, collection=collection, + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", + src=s_val, dest=o_val, uri=p_val, workspace=workspace, collection=collection, ) else: # Create new o literal with given uri, if not exists result = tx.run( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", - value=o_val, user=user, collection=collection + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", + value=o_val, workspace=workspace, collection=collection ) result = tx.run( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", - src=s_val, dest=o_val, uri=p_val, user=user, collection=collection, + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", + src=s_val, dest=o_val, uri=p_val, workspace=workspace, collection=collection, ) - async def store_triples(self, message): + async def store_triples(self, workspace, message): - # Extract 
user and collection from metadata - user = message.metadata.user if message.metadata.user else "default" collection = message.metadata.collection if message.metadata.collection else "default" # Validate collection exists before accepting writes - if not self.collection_exists(user, collection): + if not self.collection_exists(workspace, collection): error_msg = ( f"Collection {collection} does not exist. " f"Create it first via collection management API." @@ -279,18 +277,18 @@ class Processor(CollectionConfigHandler, TriplesStoreService): p_val = get_term_value(t.p) o_val = get_term_value(t.o) - self.create_node(s_val, user, collection) + self.create_node(s_val, workspace, collection) if t.o.type == IRI: - self.create_node(o_val, user, collection) - self.relate_node(s_val, p_val, o_val, user, collection) + self.create_node(o_val, workspace, collection) + self.relate_node(s_val, p_val, o_val, workspace, collection) else: - self.create_literal(o_val, user, collection) - self.relate_literal(s_val, p_val, o_val, user, collection) + self.create_literal(o_val, workspace, collection) + self.relate_literal(s_val, p_val, o_val, workspace, collection) # Alternative implementation using transactions # with self.io.session(database=self.db) as session: - # session.execute_write(self.create_triple, t, user, collection) + # session.execute_write(self.create_triple, t, workspace, collection) @staticmethod def add_args(parser): @@ -321,72 +319,72 @@ class Processor(CollectionConfigHandler, TriplesStoreService): help=f'Memgraph database (default: {default_database})' ) - def _collection_exists_in_db(self, user, collection): + def _collection_exists_in_db(self, workspace, collection): """Check if collection metadata node exists""" with self.io.session(database=self.db) as session: result = session.run( - "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) " + "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) " "RETURN c LIMIT 1", - 
user=user, collection=collection + workspace=workspace, collection=collection ) return bool(list(result)) - def _create_collection_in_db(self, user, collection): + def _create_collection_in_db(self, workspace, collection): """Create collection metadata node""" import datetime with self.io.session(database=self.db) as session: session.run( - "MERGE (c:CollectionMetadata {user: $user, collection: $collection}) " + "MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) " "SET c.created_at = $created_at", - user=user, collection=collection, + workspace=workspace, collection=collection, created_at=datetime.datetime.now().isoformat() ) - logger.info(f"Created collection metadata node for {user}/{collection}") + logger.info(f"Created collection metadata node for {workspace}/{collection}") - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """Create collection metadata in Memgraph via config push""" try: - if self._collection_exists_in_db(user, collection): - logger.info(f"Collection {user}/{collection} already exists") + if self._collection_exists_in_db(workspace, collection): + logger.info(f"Collection {workspace}/{collection} already exists") else: - self._create_collection_in_db(user, collection) - logger.info(f"Created collection {user}/{collection}") + self._create_collection_in_db(workspace, collection) + logger.info(f"Created collection {workspace}/{collection}") except Exception as e: - logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete all data for a specific collection via config push""" try: with self.io.session(database=self.db) as session: - # Delete 
all nodes for this user and collection + # Delete all nodes for this workspace and collection node_result = session.run( - "MATCH (n:Node {user: $user, collection: $collection}) " + "MATCH (n:Node {workspace: $workspace, collection: $collection}) " "DETACH DELETE n", - user=user, collection=collection + workspace=workspace, collection=collection ) nodes_deleted = node_result.consume().counters.nodes_deleted - # Delete all literals for this user and collection + # Delete all literals for this workspace and collection literal_result = session.run( - "MATCH (n:Literal {user: $user, collection: $collection}) " + "MATCH (n:Literal {workspace: $workspace, collection: $collection}) " "DETACH DELETE n", - user=user, collection=collection + workspace=workspace, collection=collection ) literals_deleted = literal_result.consume().counters.nodes_deleted # Delete collection metadata node metadata_result = session.run( - "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) " + "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) " "DELETE c", - user=user, collection=collection + workspace=workspace, collection=collection ) metadata_deleted = metadata_result.consume().counters.nodes_deleted # Note: Relationships are automatically deleted with DETACH DELETE - logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {user}/{collection}") + logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {workspace}/{collection}") except Exception as e: logger.error(f"Failed to delete collection: {e}") diff --git a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py index f7b2d947..22e25153 100755 --- a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py @@ -80,14 +80,12 @@ class 
Processor(CollectionConfigHandler, TriplesStoreService): logger.info("Create indexes...") - # Legacy indexes for backwards compatibility try: session.run( "CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)", ) except Exception as e: logger.warning(f"Index create failure: {e}") - # Maybe index already exists logger.warning("Index create failure ignored") try: @@ -96,7 +94,6 @@ class Processor(CollectionConfigHandler, TriplesStoreService): ) except Exception as e: logger.warning(f"Index create failure: {e}") - # Maybe index already exists logger.warning("Index create failure ignored") try: @@ -105,13 +102,11 @@ class Processor(CollectionConfigHandler, TriplesStoreService): ) except Exception as e: logger.warning(f"Index create failure: {e}") - # Maybe index already exists logger.warning("Index create failure ignored") - # New compound indexes for user/collection filtering try: session.run( - "CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)", + "CREATE INDEX node_workspace_collection_uri FOR (n:Node) ON (n.workspace, n.collection, n.uri)", ) except Exception as e: logger.warning(f"Compound index create failure: {e}") @@ -119,17 +114,16 @@ class Processor(CollectionConfigHandler, TriplesStoreService): try: session.run( - "CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)", + "CREATE INDEX literal_workspace_collection_value FOR (n:Literal) ON (n.workspace, n.collection, n.value)", ) except Exception as e: logger.warning(f"Compound index create failure: {e}") logger.warning("Index create failure ignored") - # Note: Neo4j doesn't support compound indexes on relationships in all versions - # Try to create individual indexes on relationship properties + # Neo4j doesn't support compound indexes on relationships in all versions try: session.run( - "CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)", + "CREATE INDEX rel_workspace FOR ()-[r:Rel]-() ON (r.workspace)", ) except Exception as e: 
logger.warning(f"Relationship index create failure: {e}") @@ -145,13 +139,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService): logger.info("Index creation done") - def create_node(self, uri, user, collection): + def create_node(self, uri, workspace, collection): - logger.debug(f"Create node {uri} for user={user}, collection={collection}") + logger.debug(f"Create node {uri} for workspace={workspace}, collection={collection}") summary = self.io.execute_query( - "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", - uri=uri, user=user, collection=collection, + "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})", + uri=uri, workspace=workspace, collection=collection, database_=self.db, ).summary @@ -160,13 +154,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=summary.result_available_after )) - def create_literal(self, value, user, collection): + def create_literal(self, value, workspace, collection): - logger.debug(f"Create literal {value} for user={user}, collection={collection}") + logger.debug(f"Create literal {value} for workspace={workspace}, collection={collection}") summary = self.io.execute_query( - "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", - value=value, user=user, collection=collection, + "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})", + value=value, workspace=workspace, collection=collection, database_=self.db, ).summary @@ -175,15 +169,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=summary.result_available_after )) - def relate_node(self, src, uri, dest, user, collection): + def relate_node(self, src, uri, dest, workspace, collection): - logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}") + logger.debug(f"Create node rel {src} {uri} {dest} for workspace={workspace}, collection={collection}") summary = self.io.execute_query( - "MATCH 
(src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", - src=src, dest=dest, uri=uri, user=user, collection=collection, + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", + src=src, dest=dest, uri=uri, workspace=workspace, collection=collection, database_=self.db, ).summary @@ -192,15 +186,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=summary.result_available_after )) - def relate_literal(self, src, uri, dest, user, collection): + def relate_literal(self, src, uri, dest, workspace, collection): - logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}") + logger.debug(f"Create literal rel {src} {uri} {dest} for workspace={workspace}, collection={collection}") summary = self.io.execute_query( - "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " - "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " - "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", - src=src, dest=dest, uri=uri, user=user, collection=collection, + "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)", + src=src, dest=dest, uri=uri, workspace=workspace, collection=collection, database_=self.db, ).summary @@ -209,14 +203,12 @@ class Processor(CollectionConfigHandler, TriplesStoreService): time=summary.result_available_after )) - async def store_triples(self, message): + async def 
store_triples(self, workspace, message): - # Extract user and collection from metadata - user = message.metadata.user if message.metadata.user else "default" collection = message.metadata.collection if message.metadata.collection else "default" # Validate collection exists before accepting writes - if not self.collection_exists(user, collection): + if not self.collection_exists(workspace, collection): error_msg = ( f"Collection {collection} does not exist. " f"Create it first via collection management API." @@ -230,14 +222,14 @@ class Processor(CollectionConfigHandler, TriplesStoreService): p_val = get_term_value(t.p) o_val = get_term_value(t.o) - self.create_node(s_val, user, collection) + self.create_node(s_val, workspace, collection) if t.o.type == IRI: - self.create_node(o_val, user, collection) - self.relate_node(s_val, p_val, o_val, user, collection) + self.create_node(o_val, workspace, collection) + self.relate_node(s_val, p_val, o_val, workspace, collection) else: - self.create_literal(o_val, user, collection) - self.relate_literal(s_val, p_val, o_val, user, collection) + self.create_literal(o_val, workspace, collection) + self.relate_literal(s_val, p_val, o_val, workspace, collection) @staticmethod def add_args(parser): @@ -268,75 +260,70 @@ class Processor(CollectionConfigHandler, TriplesStoreService): help=f'Neo4j database (default: {default_database})' ) - def _collection_exists_in_db(self, user, collection): + def _collection_exists_in_db(self, workspace, collection): """Check if collection metadata node exists""" with self.io.session(database=self.db) as session: result = session.run( - "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) " + "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) " "RETURN c LIMIT 1", - user=user, collection=collection + workspace=workspace, collection=collection ) return bool(list(result)) - def _create_collection_in_db(self, user, collection): + def 
_create_collection_in_db(self, workspace, collection): """Create collection metadata node""" import datetime with self.io.session(database=self.db) as session: session.run( - "MERGE (c:CollectionMetadata {user: $user, collection: $collection}) " + "MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) " "SET c.created_at = $created_at", - user=user, collection=collection, + workspace=workspace, collection=collection, created_at=datetime.datetime.now().isoformat() ) - logger.info(f"Created collection metadata node for {user}/{collection}") + logger.info(f"Created collection metadata node for {workspace}/{collection}") - async def create_collection(self, user: str, collection: str, metadata: dict): + async def create_collection(self, workspace: str, collection: str, metadata: dict): """Create collection metadata in Neo4j via config push""" try: - if self._collection_exists_in_db(user, collection): - logger.info(f"Collection {user}/{collection} already exists") + if self._collection_exists_in_db(workspace, collection): + logger.info(f"Collection {workspace}/{collection} already exists") else: - self._create_collection_in_db(user, collection) - logger.info(f"Created collection {user}/{collection}") + self._create_collection_in_db(workspace, collection) + logger.info(f"Created collection {workspace}/{collection}") except Exception as e: - logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise - async def delete_collection(self, user: str, collection: str): + async def delete_collection(self, workspace: str, collection: str): """Delete all data for a specific collection via config push""" try: with self.io.session(database=self.db) as session: - # Delete all nodes for this user and collection node_result = session.run( - "MATCH (n:Node {user: $user, collection: $collection}) " + "MATCH (n:Node {workspace: $workspace, 
collection: $collection}) " "DETACH DELETE n", - user=user, collection=collection + workspace=workspace, collection=collection ) nodes_deleted = node_result.consume().counters.nodes_deleted - # Delete all literals for this user and collection literal_result = session.run( - "MATCH (n:Literal {user: $user, collection: $collection}) " + "MATCH (n:Literal {workspace: $workspace, collection: $collection}) " "DETACH DELETE n", - user=user, collection=collection + workspace=workspace, collection=collection ) literals_deleted = literal_result.consume().counters.nodes_deleted - # Note: Relationships are automatically deleted with DETACH DELETE - - # Delete collection metadata node metadata_result = session.run( - "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) " + "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) " "DELETE c", - user=user, collection=collection + workspace=workspace, collection=collection ) metadata_deleted = metadata_result.consume().counters.nodes_deleted - logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {user}/{collection}") + logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {workspace}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) + logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/tables/config.py b/trustgraph-flow/trustgraph/tables/config.py index d9a8711b..8fd00427 100644 --- a/trustgraph-flow/trustgraph/tables/config.py +++ b/trustgraph-flow/trustgraph/tables/config.py @@ -72,10 +72,11 @@ class ConfigTableStore: self.cassandra.execute(""" CREATE TABLE IF NOT EXISTS config ( + workspace text, class text, key text, value text, - PRIMARY KEY (class, key) + PRIMARY KEY ((workspace, class), key) ); 
"""); @@ -124,52 +125,63 @@ class ConfigTableStore: def prepare_statements(self): self.put_config_stmt = self.cassandra.prepare(""" - INSERT INTO config ( class, key, value ) - VALUES (?, ?, ?) - """) - - self.get_classes_stmt = self.cassandra.prepare(""" - SELECT DISTINCT class FROM config; + INSERT INTO config ( workspace, class, key, value ) + VALUES (?, ?, ?, ?) """) self.get_keys_stmt = self.cassandra.prepare(""" - SELECT key FROM config WHERE class = ?; + SELECT key FROM config + WHERE workspace = ? AND class = ?; """) self.get_value_stmt = self.cassandra.prepare(""" - SELECT value FROM config WHERE class = ? AND key = ?; + SELECT value FROM config + WHERE workspace = ? AND class = ? AND key = ?; """) self.delete_key_stmt = self.cassandra.prepare(""" DELETE FROM config - WHERE class = ? AND key = ?; + WHERE workspace = ? AND class = ? AND key = ?; """) self.get_all_stmt = self.cassandra.prepare(""" - SELECT class AS cls, key, value FROM config; + SELECT workspace, class AS cls, key, value FROM config; + """) + + self.get_all_for_workspace_stmt = self.cassandra.prepare(""" + SELECT class AS cls, key, value FROM config + WHERE workspace = ? + ALLOW FILTERING; """) self.get_values_stmt = self.cassandra.prepare(""" - SELECT key, value FROM config WHERE class = ?; + SELECT key, value FROM config + WHERE workspace = ? AND class = ?; """) - async def put_config(self, cls, key, value): + self.get_values_all_ws_stmt = self.cassandra.prepare(""" + SELECT workspace, key, value FROM config + WHERE class = ? 
+ ALLOW FILTERING; + """) + + async def put_config(self, workspace, cls, key, value): try: await async_execute( self.cassandra, self.put_config_stmt, - (cls, key, value), + (workspace, cls, key, value), ) except Exception: logger.error("Exception occurred", exc_info=True) raise - async def get_value(self, cls, key): + async def get_value(self, workspace, cls, key): try: rows = await async_execute( self.cassandra, self.get_value_stmt, - (cls, key), + (workspace, cls, key), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -179,12 +191,12 @@ class ConfigTableStore: return row[0] return None - async def get_values(self, cls): + async def get_values(self, workspace, cls): try: rows = await async_execute( self.cassandra, self.get_values_stmt, - (cls,), + (workspace, cls), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -192,18 +204,20 @@ class ConfigTableStore: return [[row[0], row[1]] for row in rows] - async def get_classes(self): + async def get_values_all_ws(self, cls): + """Return (workspace, key, value) tuples for all workspaces + with entries of the given class.""" try: rows = await async_execute( self.cassandra, - self.get_classes_stmt, - (), + self.get_values_all_ws_stmt, + (cls,), ) except Exception: logger.error("Exception occurred", exc_info=True) raise - return [row[0] for row in rows] + return [(row[0], row[1], row[2]) for row in rows] async def get_all(self): try: @@ -216,14 +230,27 @@ class ConfigTableStore: logger.error("Exception occurred", exc_info=True) raise + return [(row[0], row[1], row[2], row[3]) for row in rows] + + async def get_all_for_workspace(self, workspace): + try: + rows = await async_execute( + self.cassandra, + self.get_all_for_workspace_stmt, + (workspace,), + ) + except Exception: + logger.error("Exception occurred", exc_info=True) + raise + return [(row[0], row[1], row[2]) for row in rows] - async def get_keys(self, cls): + async def get_keys(self, workspace, cls): try: rows = await 
async_execute( self.cassandra, self.get_keys_stmt, - (cls,), + (workspace, cls), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -231,12 +258,12 @@ class ConfigTableStore: return [row[0] for row in rows] - async def delete_key(self, cls, key): + async def delete_key(self, workspace, cls, key): try: await async_execute( self.cassandra, self.delete_key_stmt, - (cls, key), + (workspace, cls, key), ) except Exception: logger.error("Exception occurred", exc_info=True) diff --git a/trustgraph-flow/trustgraph/tables/knowledge.py b/trustgraph-flow/trustgraph/tables/knowledge.py index b06f4862..4d729956 100644 --- a/trustgraph-flow/trustgraph/tables/knowledge.py +++ b/trustgraph-flow/trustgraph/tables/knowledge.py @@ -88,7 +88,7 @@ class KnowledgeTableStore: self.cassandra.execute(""" CREATE TABLE IF NOT EXISTS triples ( - user text, + workspace text, document_id text, id uuid, time timestamp, @@ -98,7 +98,7 @@ class KnowledgeTableStore: triples list>, - PRIMARY KEY ((user, document_id), id) + PRIMARY KEY ((workspace, document_id), id) ); """); @@ -106,7 +106,7 @@ class KnowledgeTableStore: self.cassandra.execute(""" create table if not exists graph_embeddings ( - user text, + workspace text, document_id text, id uuid, time timestamp, @@ -119,20 +119,20 @@ class KnowledgeTableStore: list > >, - PRIMARY KEY ((user, document_id), id) + PRIMARY KEY ((workspace, document_id), id) ); """); self.cassandra.execute(""" - CREATE INDEX IF NOT EXISTS graph_embeddings_user ON - graph_embeddings ( user ); + CREATE INDEX IF NOT EXISTS graph_embeddings_workspace ON + graph_embeddings ( workspace ); """); logger.debug("document_embeddings table...") self.cassandra.execute(""" create table if not exists document_embeddings ( - user text, + workspace text, document_id text, id uuid, time timestamp, @@ -145,13 +145,13 @@ class KnowledgeTableStore: list > >, - PRIMARY KEY ((user, document_id), id) + PRIMARY KEY ((workspace, document_id), id) ); """); 
self.cassandra.execute(""" - CREATE INDEX IF NOT EXISTS document_embeddings_user ON - document_embeddings ( user ); + CREATE INDEX IF NOT EXISTS document_embeddings_workspace ON + document_embeddings ( workspace ); """); logger.info("Cassandra schema OK.") @@ -161,7 +161,7 @@ class KnowledgeTableStore: self.insert_triples_stmt = self.cassandra.prepare(""" INSERT INTO triples ( - id, user, document_id, + id, workspace, document_id, time, metadata, triples ) VALUES (?, ?, ?, ?, ?, ?) @@ -170,7 +170,7 @@ class KnowledgeTableStore: self.insert_graph_embeddings_stmt = self.cassandra.prepare(""" INSERT INTO graph_embeddings ( - id, user, document_id, time, metadata, entity_embeddings + id, workspace, document_id, time, metadata, entity_embeddings ) VALUES (?, ?, ?, ?, ?, ?) """) @@ -178,45 +178,45 @@ class KnowledgeTableStore: self.insert_document_embeddings_stmt = self.cassandra.prepare(""" INSERT INTO document_embeddings ( - id, user, document_id, time, metadata, chunks + id, workspace, document_id, time, metadata, chunks ) VALUES (?, ?, ?, ?, ?, ?) """) self.list_cores_stmt = self.cassandra.prepare(""" - SELECT DISTINCT user, document_id FROM graph_embeddings - WHERE user = ? + SELECT DISTINCT workspace, document_id FROM graph_embeddings + WHERE workspace = ? """) self.get_triples_stmt = self.cassandra.prepare(""" SELECT id, time, metadata, triples FROM triples - WHERE user = ? AND document_id = ? + WHERE workspace = ? AND document_id = ? """) self.get_graph_embeddings_stmt = self.cassandra.prepare(""" SELECT id, time, metadata, entity_embeddings FROM graph_embeddings - WHERE user = ? AND document_id = ? + WHERE workspace = ? AND document_id = ? """) self.get_document_embeddings_stmt = self.cassandra.prepare(""" SELECT id, time, metadata, chunks FROM document_embeddings - WHERE user = ? AND document_id = ? + WHERE workspace = ? AND document_id = ? """) self.delete_triples_stmt = self.cassandra.prepare(""" DELETE FROM triples - WHERE user = ? AND document_id = ? 
+ WHERE workspace = ? AND document_id = ? """) self.delete_graph_embeddings_stmt = self.cassandra.prepare(""" DELETE FROM graph_embeddings - WHERE user = ? AND document_id = ? + WHERE workspace = ? AND document_id = ? """) - async def add_triples(self, m): + async def add_triples(self, workspace, m): when = int(time.time() * 1000) @@ -232,7 +232,7 @@ class KnowledgeTableStore: self.cassandra, self.insert_triples_stmt, ( - uuid.uuid4(), m.metadata.user, + uuid.uuid4(), workspace, m.metadata.root or m.metadata.id, when, [], triples, ), @@ -241,7 +241,7 @@ class KnowledgeTableStore: logger.error("Exception occurred", exc_info=True) raise - async def add_graph_embeddings(self, m): + async def add_graph_embeddings(self, workspace, m): when = int(time.time() * 1000) @@ -258,7 +258,7 @@ class KnowledgeTableStore: self.cassandra, self.insert_graph_embeddings_stmt, ( - uuid.uuid4(), m.metadata.user, + uuid.uuid4(), workspace, m.metadata.root or m.metadata.id, when, [], entities, ), @@ -267,7 +267,7 @@ class KnowledgeTableStore: logger.error("Exception occurred", exc_info=True) raise - async def add_document_embeddings(self, m): + async def add_document_embeddings(self, workspace, m): when = int(time.time() * 1000) @@ -284,7 +284,7 @@ class KnowledgeTableStore: self.cassandra, self.insert_document_embeddings_stmt, ( - uuid.uuid4(), m.metadata.user, + uuid.uuid4(), workspace, m.metadata.root or m.metadata.id, when, [], chunks, ), @@ -293,7 +293,7 @@ class KnowledgeTableStore: logger.error("Exception occurred", exc_info=True) raise - async def list_kg_cores(self, user): + async def list_kg_cores(self, workspace): logger.debug("List kg cores...") @@ -301,7 +301,7 @@ class KnowledgeTableStore: rows = await async_execute( self.cassandra, self.list_cores_stmt, - (user,), + (workspace,), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -313,7 +313,7 @@ class KnowledgeTableStore: return lst - async def delete_kg_core(self, user, document_id): + async def 
delete_kg_core(self, workspace, document_id): logger.debug("Delete kg cores...") @@ -321,7 +321,7 @@ class KnowledgeTableStore: await async_execute( self.cassandra, self.delete_triples_stmt, - (user, document_id), + (workspace, document_id), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -331,13 +331,13 @@ class KnowledgeTableStore: await async_execute( self.cassandra, self.delete_graph_embeddings_stmt, - (user, document_id), + (workspace, document_id), ) except Exception: logger.error("Exception occurred", exc_info=True) raise - async def get_triples(self, user, document_id, receiver): + async def get_triples(self, workspace, document_id, receiver): logger.debug("Get triples...") @@ -345,7 +345,7 @@ class KnowledgeTableStore: rows = await async_execute( self.cassandra, self.get_triples_stmt, - (user, document_id), + (workspace, document_id), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -369,7 +369,6 @@ class KnowledgeTableStore: Triples( metadata = Metadata( id = document_id, - user = user, collection = "default", # FIXME: What to put here? ), triples = triples @@ -378,7 +377,7 @@ class KnowledgeTableStore: logger.debug("Done") - async def get_graph_embeddings(self, user, document_id, receiver): + async def get_graph_embeddings(self, workspace, document_id, receiver): logger.debug("Get GE...") @@ -386,7 +385,7 @@ class KnowledgeTableStore: rows = await async_execute( self.cassandra, self.get_graph_embeddings_stmt, - (user, document_id), + (workspace, document_id), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -409,12 +408,11 @@ class KnowledgeTableStore: GraphEmbeddings( metadata = Metadata( id = document_id, - user = user, collection = "default", # FIXME: What to put here? 
), entities = entities ) - ) + ) logger.debug("Done") diff --git a/trustgraph-flow/trustgraph/tables/library.py b/trustgraph-flow/trustgraph/tables/library.py index c85ae72a..86706079 100644 --- a/trustgraph-flow/trustgraph/tables/library.py +++ b/trustgraph-flow/trustgraph/tables/library.py @@ -64,7 +64,7 @@ class LibraryTableStore: self.cluster = Cluster(cassandra_host) self.cassandra = self.cluster.connect() - + logger.info("Connected.") self.ensure_cassandra_schema() @@ -76,13 +76,13 @@ class LibraryTableStore: logger.debug("Ensure Cassandra schema...") logger.debug("Keyspace...") - + # FIXME: Replication factor should be configurable self.cassandra.execute(f""" create keyspace if not exists {self.keyspace} - with replication = {{ - 'class' : 'SimpleStrategy', - 'replication_factor' : 1 + with replication = {{ + 'class' : 'SimpleStrategy', + 'replication_factor' : 1 }}; """); @@ -93,7 +93,7 @@ class LibraryTableStore: self.cassandra.execute(""" CREATE TABLE IF NOT EXISTS document ( id text, - user text, + workspace text, time timestamp, kind text, title text, @@ -103,7 +103,9 @@ class LibraryTableStore: >>, tags list, object_id uuid, - PRIMARY KEY (user, id) + parent_id text, + document_type text, + PRIMARY KEY (workspace, id) ); """); @@ -114,27 +116,6 @@ class LibraryTableStore: ON document (object_id) """); - # Add parent_id and document_type columns for child document support - logger.debug("document table parent_id column...") - - try: - self.cassandra.execute(""" - ALTER TABLE document ADD parent_id text - """); - except Exception as e: - # Column may already exist - if "already exists" not in str(e).lower() and "Invalid column name" not in str(e): - logger.debug(f"parent_id column may already exist: {e}") - - try: - self.cassandra.execute(""" - ALTER TABLE document ADD document_type text - """); - except Exception as e: - # Column may already exist - if "already exists" not in str(e).lower() and "Invalid column name" not in str(e): - 
logger.debug(f"document_type column may already exist: {e}") - logger.debug("document parent index...") self.cassandra.execute(""" @@ -150,10 +131,10 @@ class LibraryTableStore: document_id text, time timestamp, flow text, - user text, + workspace text, collection text, tags list, - PRIMARY KEY (user, id) + PRIMARY KEY (workspace, id) ); """); @@ -162,7 +143,7 @@ class LibraryTableStore: self.cassandra.execute(""" CREATE TABLE IF NOT EXISTS upload_session ( upload_id text PRIMARY KEY, - user text, + workspace text, document_id text, document_metadata text, s3_upload_id text, @@ -176,11 +157,11 @@ class LibraryTableStore: ) WITH default_time_to_live = 86400; """); - logger.debug("upload_session user index...") + logger.debug("upload_session workspace index...") self.cassandra.execute(""" - CREATE INDEX IF NOT EXISTS upload_session_user - ON upload_session (user) + CREATE INDEX IF NOT EXISTS upload_session_workspace + ON upload_session (workspace) """); logger.info("Cassandra schema OK.") @@ -190,7 +171,7 @@ class LibraryTableStore: self.insert_document_stmt = self.cassandra.prepare(""" INSERT INTO document ( - id, user, time, + id, workspace, time, kind, title, comments, metadata, tags, object_id, parent_id, document_type @@ -202,25 +183,25 @@ class LibraryTableStore: UPDATE document SET time = ?, title = ?, comments = ?, metadata = ?, tags = ? - WHERE user = ? AND id = ? + WHERE workspace = ? AND id = ? """) self.get_document_stmt = self.cassandra.prepare(""" SELECT time, kind, title, comments, metadata, tags, object_id, parent_id, document_type FROM document - WHERE user = ? AND id = ? + WHERE workspace = ? AND id = ? """) self.delete_document_stmt = self.cassandra.prepare(""" DELETE FROM document - WHERE user = ? AND id = ? + WHERE workspace = ? AND id = ? """) self.test_document_exists_stmt = self.cassandra.prepare(""" SELECT id FROM document - WHERE user = ? AND id = ? + WHERE workspace = ? AND id = ? 
LIMIT 1 """) @@ -229,7 +210,7 @@ class LibraryTableStore: id, time, kind, title, comments, metadata, tags, object_id, parent_id, document_type FROM document - WHERE user = ? + WHERE workspace = ? """) self.list_document_by_tag_stmt = self.cassandra.prepare(""" @@ -237,7 +218,7 @@ class LibraryTableStore: id, time, kind, title, comments, metadata, tags, object_id, parent_id, document_type FROM document - WHERE user = ? AND tags CONTAINS ? + WHERE workspace = ? AND tags CONTAINS ? ALLOW FILTERING """) @@ -245,7 +226,7 @@ class LibraryTableStore: INSERT INTO processing ( id, document_id, time, - flow, user, collection, + flow, workspace, collection, tags ) VALUES (?, ?, ?, ?, ?, ?, ?) @@ -253,13 +234,13 @@ class LibraryTableStore: self.delete_processing_stmt = self.cassandra.prepare(""" DELETE FROM processing - WHERE user = ? AND id = ? + WHERE workspace = ? AND id = ? """) self.test_processing_exists_stmt = self.cassandra.prepare(""" SELECT id FROM processing - WHERE user = ? AND id = ? + WHERE workspace = ? AND id = ? LIMIT 1 """) @@ -267,14 +248,14 @@ class LibraryTableStore: SELECT id, document_id, time, flow, collection, tags FROM processing - WHERE user = ? + WHERE workspace = ? 
""") # Upload session prepared statements self.insert_upload_session_stmt = self.cassandra.prepare(""" INSERT INTO upload_session ( - upload_id, user, document_id, document_metadata, + upload_id, workspace, document_id, document_metadata, s3_upload_id, object_id, total_size, chunk_size, total_chunks, chunks_received, created_at, updated_at ) @@ -283,7 +264,7 @@ class LibraryTableStore: self.get_upload_session_stmt = self.cassandra.prepare(""" SELECT - upload_id, user, document_id, document_metadata, + upload_id, workspace, document_id, document_metadata, s3_upload_id, object_id, total_size, chunk_size, total_chunks, chunks_received, created_at, updated_at FROM upload_session @@ -308,25 +289,25 @@ class LibraryTableStore: total_size, chunk_size, total_chunks, chunks_received, created_at, updated_at FROM upload_session - WHERE user = ? + WHERE workspace = ? """) # Child document queries self.list_children_stmt = self.cassandra.prepare(""" SELECT - id, user, time, kind, title, comments, metadata, tags, + id, workspace, time, kind, title, comments, metadata, tags, object_id, parent_id, document_type FROM document WHERE parent_id = ? 
ALLOW FILTERING """) - async def document_exists(self, user, id): + async def document_exists(self, workspace, id): rows = await async_execute( self.cassandra, self.test_document_exists_stmt, - (user, id), + (workspace, id), ) return bool(rows) @@ -351,7 +332,7 @@ class LibraryTableStore: self.cassandra, self.insert_document_stmt, ( - document.id, document.user, int(document.time * 1000), + document.id, document.workspace, int(document.time * 1000), document.kind, document.title, document.comments, metadata, document.tags, object_id, parent_id, document_type @@ -381,7 +362,7 @@ class LibraryTableStore: ( int(document.time * 1000), document.title, document.comments, metadata, document.tags, - document.user, document.id + document.workspace, document.id ), ) except Exception: @@ -390,7 +371,7 @@ class LibraryTableStore: logger.debug("Update complete") - async def remove_document(self, user, document_id): + async def remove_document(self, workspace, document_id): logger.info(f"Removing document {document_id}") @@ -398,7 +379,7 @@ class LibraryTableStore: await async_execute( self.cassandra, self.delete_document_stmt, - (user, document_id), + (workspace, document_id), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -406,7 +387,7 @@ class LibraryTableStore: logger.debug("Delete complete") - async def list_documents(self, user): + async def list_documents(self, workspace): logger.debug("List documents...") @@ -414,7 +395,7 @@ class LibraryTableStore: rows = await async_execute( self.cassandra, self.list_document_stmt, - (user,), + (workspace,), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -423,7 +404,7 @@ class LibraryTableStore: lst = [ DocumentMetadata( id = row[0], - user = user, + workspace = workspace, time = int(time.mktime(row[1].timetuple())), kind = row[2], title = row[3], @@ -465,7 +446,7 @@ class LibraryTableStore: lst = [ DocumentMetadata( id = row[0], - user = row[1], + workspace = row[1], time = 
int(time.mktime(row[2].timetuple())), kind = row[3], title = row[4], @@ -489,7 +470,7 @@ class LibraryTableStore: return lst - async def get_document(self, user, id): + async def get_document(self, workspace, id): logger.debug("Get document") @@ -497,7 +478,7 @@ class LibraryTableStore: rows = await async_execute( self.cassandra, self.get_document_stmt, - (user, id), + (workspace, id), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -506,7 +487,7 @@ class LibraryTableStore: for row in rows: doc = DocumentMetadata( id = id, - user = user, + workspace = workspace, time = int(time.mktime(row[0].timetuple())), kind = row[1], title = row[2], @@ -529,7 +510,7 @@ class LibraryTableStore: raise RuntimeError("No such document row?") - async def get_document_object_id(self, user, id): + async def get_document_object_id(self, workspace, id): logger.debug("Get document obj ID") @@ -537,7 +518,7 @@ class LibraryTableStore: rows = await async_execute( self.cassandra, self.get_document_stmt, - (user, id), + (workspace, id), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -549,12 +530,12 @@ class LibraryTableStore: raise RuntimeError("No such document row?") - async def processing_exists(self, user, id): + async def processing_exists(self, workspace, id): rows = await async_execute( self.cassandra, self.test_processing_exists_stmt, - (user, id), + (workspace, id), ) return bool(rows) @@ -570,7 +551,7 @@ class LibraryTableStore: ( processing.id, processing.document_id, int(processing.time * 1000), processing.flow, - processing.user, processing.collection, + processing.workspace, processing.collection, processing.tags ), ) @@ -580,7 +561,7 @@ class LibraryTableStore: logger.debug("Add complete") - async def remove_processing(self, user, processing_id): + async def remove_processing(self, workspace, processing_id): logger.info(f"Removing processing {processing_id}") @@ -588,7 +569,7 @@ class LibraryTableStore: await async_execute( 
self.cassandra, self.delete_processing_stmt, - (user, processing_id), + (workspace, processing_id), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -596,7 +577,7 @@ class LibraryTableStore: logger.debug("Delete complete") - async def list_processing(self, user): + async def list_processing(self, workspace): logger.debug("List processing objects") @@ -604,7 +585,7 @@ class LibraryTableStore: rows = await async_execute( self.cassandra, self.list_processing_stmt, - (user,), + (workspace,), ) except Exception: logger.error("Exception occurred", exc_info=True) @@ -616,7 +597,7 @@ class LibraryTableStore: document_id = row[1], time = int(time.mktime(row[2].timetuple())), flow = row[3], - user = user, + workspace = workspace, collection = row[4], tags = row[5] if row[5] else [], ) @@ -632,7 +613,7 @@ class LibraryTableStore: async def create_upload_session( self, upload_id, - user, + workspace, document_id, document_metadata, s3_upload_id, @@ -652,7 +633,7 @@ class LibraryTableStore: self.cassandra, self.insert_upload_session_stmt, ( - upload_id, user, document_id, document_metadata, + upload_id, workspace, document_id, document_metadata, s3_upload_id, object_id, total_size, chunk_size, total_chunks, {}, now, now ), @@ -681,7 +662,7 @@ class LibraryTableStore: for row in rows: session = { "upload_id": row[0], - "user": row[1], + "workspace": row[1], "document_id": row[2], "document_metadata": row[3], "s3_upload_id": row[4], @@ -738,16 +719,16 @@ class LibraryTableStore: logger.debug("Upload session deleted") - async def list_upload_sessions(self, user): - """List all upload sessions for a user.""" + async def list_upload_sessions(self, workspace): + """List all upload sessions for a workspace.""" - logger.debug(f"List upload sessions for {user}") + logger.debug(f"List upload sessions for {workspace}") try: rows = await async_execute( self.cassandra, self.list_upload_sessions_stmt, - (user,), + (workspace,), ) except Exception: 
logger.error("Exception occurred", exc_info=True) diff --git a/trustgraph-flow/trustgraph/tool_service/joke/service.py b/trustgraph-flow/trustgraph/tool_service/joke/service.py index d9b7cde0..171156d8 100644 --- a/trustgraph-flow/trustgraph/tool_service/joke/service.py +++ b/trustgraph-flow/trustgraph/tool_service/joke/service.py @@ -2,7 +2,6 @@ Joke Tool Service - An example dynamic tool service. This service demonstrates the tool service integration by: -- Using the 'user' field to personalize responses - Using config params (style) to customize joke style - Using arguments (topic) to generate topic-specific jokes @@ -143,17 +142,16 @@ class Processor(DynamicToolService): super(Processor, self).__init__(**params) logger.info("Joke service initialized") - async def invoke(self, user, config, arguments): + async def invoke(self, config, arguments): """ Generate a joke based on the topic and style. Args: - user: The user requesting the joke config: Config values including 'style' (pun, dad-joke, one-liner) arguments: Arguments including 'topic' (programming, animals, food) Returns: - A personalized joke string + A joke string """ # Get style from config (default: random) style = config.get("style", random.choice(["pun", "dad-joke", "one-liner"])) @@ -183,10 +181,9 @@ class Processor(DynamicToolService): # Pick a random joke joke = random.choice(jokes) - # Personalize the response - response = f"Hey {user}! 
Here's a {style} for you:\n\n{joke}" + response = f"Here's a {style} for you:\n\n{joke}" - logger.debug(f"Generated joke for user={user}, style={style}, topic={topic}") + logger.debug(f"Generated joke: style={style}, topic={topic}") return response diff --git a/trustgraph-mcp/trustgraph/mcp_server/mcp.py b/trustgraph-mcp/trustgraph/mcp_server/mcp.py index eadd841b..7378db64 100755 --- a/trustgraph-mcp/trustgraph/mcp_server/mcp.py +++ b/trustgraph-mcp/trustgraph/mcp_server/mcp.py @@ -49,26 +49,26 @@ async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8 logging.info("Shutdown complete") -async def get_socket_manager(ctx, user): +async def get_socket_manager(ctx): lifespan_context = ctx.request_context.lifespan_context sockets = lifespan_context.sockets websocket_url = lifespan_context.websocket_url gateway_token = lifespan_context.gateway_token - if user in sockets: + if "default" in sockets: logging.info("Return existing socket manager") - return sockets[user] + return sockets["default"] logging.info(f"Opening socket to {websocket_url}...") # Create manager with empty pending requests manager = WebSocketManager(websocket_url, token=gateway_token) - + # Start reader task with the proper manager await manager.start() - - sockets[user] = manager + + sockets["default"] = manager logging.info("Return new socket manager") return manager @@ -372,7 +372,6 @@ class McpServer: async def graph_rag( self, question: str, - user: str | None = None, collection: str | None = None, entity_limit: int | None = None, triple_limit: int | None = None, @@ -391,7 +390,6 @@ class McpServer: Args: question: The question or query to answer using the knowledge graph. The system will find relevant entities and relationships to inform the response. - user: User identifier for access control and personalization (default: "trustgraph"). collection: Knowledge collection to query (default: "default"). Different collections may contain domain-specific knowledge. 
entity_limit: Maximum number of entities to retrieve during graph traversal. @@ -414,7 +412,6 @@ class McpServer: - Perform research queries across connected information """ - if user is None: user = "trustgraph" if collection is None: collection = "default" if flow_id is None: flow_id = "default" @@ -423,7 +420,7 @@ class McpServer: logging.info("GraphRAG request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -437,7 +434,6 @@ class McpServer: "query": question } - if user: request_data["user"] = user if collection: request_data["collection"] = collection if entity_limit: request_data["entity_limit"] = entity_limit if triple_limit: request_data["triple_limit"] = triple_limit @@ -466,7 +462,6 @@ class McpServer: async def agent( self, question: str, - user: str | None = None, collection: str | None = None, flow_id: str | None = None, ctx: Context = None, @@ -481,7 +476,6 @@ class McpServer: Args: question: The question or task for the agent to solve. Can be complex queries requiring multiple steps, analysis, or tool usage. - user: User identifier for personalization and access control (default: "trustgraph"). collection: Knowledge collection the agent can access (default: "default"). Determines what information and tools are available. flow_id: Agent workflow to use (default: "default"). Different flows @@ -501,7 +495,6 @@ class McpServer: through log messages, so you can follow its reasoning steps. 
""" - if user is None: user = "trustgraph" if collection is None: collection = "default" if flow_id is None: flow_id = "default" @@ -510,7 +503,7 @@ class McpServer: logging.info("Agent request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -524,7 +517,6 @@ class McpServer: "question": question } - if user: request_data["user"] = user if collection: request_data["collection"] = collection gen = manager.request("agent", request_data, flow_id) @@ -1143,23 +1135,18 @@ class McpServer: async def get_knowledge_cores( self, - user: str | None = None, ctx: Context = None, ) -> KnowledgeCoresResponse: """ - List all available knowledge graph cores for a user. - + List all available knowledge graph cores in the current workspace. + Knowledge cores are packaged collections of structured knowledge that can be loaded into the system for querying and reasoning. They contain entities, relationships, and facts organized as knowledge graphs. - - Args: - user: User identifier to list cores for (default: "trustgraph"). - Different users may have access to different knowledge cores. - + Returns: KnowledgeCoresResponse containing a list of available knowledge core IDs. 
- + Use this for: - Discovering available knowledge collections - Understanding what knowledge domains are accessible @@ -1167,14 +1154,12 @@ class McpServer: - Managing knowledge resources """ - if user is None: user = "trustgraph" - if ctx is None: raise RuntimeError("No context provided") logging.info("Get knowledge cores request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -1185,7 +1170,6 @@ class McpServer: request_data = { "operation": "list-kg-cores", - "user": user } gen = manager.request("knowledge", request_data, None) @@ -1199,40 +1183,35 @@ class McpServer: async def delete_kg_core( self, core_id: str, - user: str | None = None, ctx: Context = None, ) -> DeleteKgCoreResponse: """ Permanently delete a knowledge graph core. - + This operation removes a knowledge core from storage. Use with caution as this action cannot be undone. - + Args: core_id: Unique identifier of the knowledge core to delete. - user: User identifier (default: "trustgraph"). Only cores owned - by this user can be deleted. - + Returns: DeleteKgCoreResponse confirming the deletion. - + Use this for: - Cleaning up obsolete knowledge cores - Removing test or experimental data - Managing storage space - Maintaining organized knowledge collections - + Warning: This permanently deletes the knowledge core and all its data. 
""" - if user is None: user = "trustgraph" - if ctx is None: raise RuntimeError("No context provided") logging.info("Delete KG core request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -1244,7 +1223,6 @@ class McpServer: request_data = { "operation": "delete-kg-core", "id": core_id, - "user": user } gen = manager.request("knowledge", request_data, None) @@ -1258,27 +1236,25 @@ class McpServer: self, core_id: str, flow: str, - user: str | None = None, collection: str | None = None, ctx: Context = None, ) -> LoadKgCoreResponse: """ Load a knowledge graph core into the active system for querying. - + This operation makes a knowledge core available for GraphRAG queries, triple searches, and other knowledge-based operations. - + Args: core_id: Unique identifier of the knowledge core to load. flow: Processing flow to use for loading the core. Different flows may apply different processing, indexing, or optimization steps. - user: User identifier (default: "trustgraph"). collection: Target collection name (default: "default"). The loaded knowledge will be available under this collection name. - + Returns: LoadKgCoreResponse confirming the core has been loaded. 
- + Use this for: - Making knowledge cores available for queries - Switching between different knowledge domains @@ -1286,7 +1262,6 @@ class McpServer: - Preparing knowledge for GraphRAG operations """ - if user is None: user = "trustgraph" if collection is None: collection = "default" if ctx is None: @@ -1294,7 +1269,7 @@ class McpServer: logging.info("Load KG core request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -1307,7 +1282,6 @@ class McpServer: "operation": "load-kg-core", "id": core_id, "flow": flow, - "user": user, "collection": collection } @@ -1321,42 +1295,38 @@ class McpServer: async def get_kg_core( self, core_id: str, - user: str | None = None, ctx: Context = None, ) -> GetKgCoreResponse: """ Download and retrieve the complete content of a knowledge graph core. - + This tool streams the entire content of a knowledge core, returning all entities, relationships, and metadata. Due to potentially large data sizes, the content is streamed in chunks. - + Args: core_id: Unique identifier of the knowledge core to retrieve. - user: User identifier (default: "trustgraph"). - + Returns: GetKgCoreResponse containing all chunks of the knowledge core data. Each chunk contains part of the knowledge graph structure. - + Use this for: - Examining knowledge core content and structure - Debugging knowledge graph data - Exporting knowledge for backup or analysis - Understanding the scope and quality of knowledge - + Note: Large knowledge cores may take significant time to download. Progress updates are provided through log messages during streaming. 
""" - if user is None: user = "trustgraph" - if ctx is None: raise RuntimeError("No context provided") logging.info("Get KG core request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -1368,7 +1338,6 @@ class McpServer: request_data = { "operation": "get-kg-core", "id": core_id, - "user": user } # Collect all streaming responses @@ -1713,27 +1682,22 @@ class McpServer: async def get_documents( self, - user: str | None = None, ctx: Context = None, ) -> DocumentsResponse: """ List all documents stored in the TrustGraph document library. - + This tool returns metadata for all documents that have been uploaded to the system, including their processing status and properties. - - Args: - user: User identifier to list documents for (default: "trustgraph"). - Only documents owned by this user will be returned. - + Returns: DocumentsResponse containing metadata for each document including: - Document ID and title - - Upload timestamp and user + - Upload timestamp - MIME type and size information - Tags and custom metadata - Processing status - + Use this for: - Browsing available documents - Managing document collections @@ -1741,14 +1705,12 @@ class McpServer: - Auditing document storage """ - if user is None: user = "trustgraph" - if ctx is None: raise RuntimeError("No context provided") logging.info("Get documents request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -1759,7 +1721,6 @@ class McpServer: request_data = { "operation": "list-documents", - "user": user } gen = manager.request("librarian", request_data, None) @@ -1772,26 +1733,21 @@ class McpServer: async def get_processing( self, - user: str | None = None, ctx: Context = None, ) -> ProcessingResponse: """ List all documents currently in the processing queue. 
- + This tool shows documents that are being processed or waiting to be processed, along with their processing status and configuration. - - Args: - user: User identifier (default: "trustgraph"). Only processing - jobs for this user will be returned. - + Returns: ProcessingResponse containing processing metadata including: - Processing job ID and document ID - Processing flow and status - - Target collection and user + - Target collection - Timestamp and progress information - + Use this for: - Monitoring document processing progress - Debugging processing issues @@ -1799,14 +1755,12 @@ class McpServer: - Understanding system workload """ - if user is None: user = "trustgraph" - if ctx is None: raise RuntimeError("No context provided") logging.info("Get processing request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -1817,7 +1771,6 @@ class McpServer: request_data = { "operation": "list-processing", - "user": user } gen = manager.request("librarian", request_data, None) @@ -1837,16 +1790,15 @@ class McpServer: title: str = "", comments: str = "", tags: List[str] | None = None, - user: str | None = None, ctx: Context = None, ) -> LoadDocumentResponse: """ Upload a document to the TrustGraph document library. - + This tool stores documents with rich metadata for later processing, search, and knowledge extraction. Documents can be text files, PDFs, or other supported formats. - + Args: document: The document content as a string. For binary files, this should be base64-encoded content. @@ -1856,11 +1808,10 @@ class McpServer: title: Human-readable title for the document. comments: Optional description or notes about the document. tags: List of tags for categorizing and finding the document. - user: User identifier (default: "trustgraph"). - + Returns: LoadDocumentResponse confirming the document has been stored. 
- + Use this for: - Adding new documents to the knowledge base - Storing reference materials and data sources @@ -1868,7 +1819,6 @@ class McpServer: - Importing external content for analysis """ - if user is None: user = "trustgraph" if tags is None: tags = [] if ctx is None: @@ -1876,7 +1826,7 @@ class McpServer: logging.info("Load document request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -1897,7 +1847,6 @@ class McpServer: "title": title, "comments": comments, "metadata": metadata, - "user": user, "tags": tags }, "content": document @@ -1913,40 +1862,35 @@ class McpServer: async def remove_document( self, document_id: str, - user: str | None = None, ctx: Context = None, ) -> RemoveDocumentResponse: """ Permanently remove a document from the library. - + This operation deletes a document and all its associated metadata. Use with caution as this action cannot be undone. - + Args: document_id: Unique identifier of the document to remove. - user: User identifier (default: "trustgraph"). Only documents - owned by this user can be removed. - + Returns: RemoveDocumentResponse confirming the document has been deleted. - + Use this for: - Cleaning up obsolete or incorrect documents - Managing storage space - Removing sensitive or inappropriate content - Maintaining organized document collections - + Warning: This permanently deletes the document and all its metadata. 
""" - if user is None: user = "trustgraph" - if ctx is None: raise RuntimeError("No context provided") logging.info("Remove document request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -1958,7 +1902,6 @@ class McpServer: request_data = { "operation": "remove-document", "document-id": document_id, - "user": user } gen = manager.request("librarian", request_data, None) @@ -1973,42 +1916,39 @@ class McpServer: processing_id: str, document_id: str, flow: str, - user: str | None = None, collection: str | None = None, tags: List[str] | None = None, ctx: Context = None, ) -> AddProcessingResponse: """ Queue a document for processing through a specific workflow. - + This tool adds a document to the processing queue where it will be processed by the specified flow to extract knowledge, create embeddings, or perform other analysis operations. - + Args: processing_id: Unique identifier for this processing job. document_id: ID of the document to process (must exist in library). flow: Processing flow to use. Different flows perform different types of analysis (e.g., knowledge extraction, summarization). - user: User identifier (default: "trustgraph"). collection: Target collection for processed knowledge (default: "default"). Results will be stored under this collection name. tags: Optional tags for categorizing this processing job. - + Returns: AddProcessingResponse confirming the document has been queued. - + Use this for: - Processing uploaded documents into knowledge - Extracting entities and relationships from text - Creating searchable embeddings - Converting documents into structured knowledge - + Note: Processing may take time depending on document size and flow complexity. Use get_processing to monitor progress. 
""" - if user is None: user = "trustgraph" if collection is None: collection = "default" if tags is None: tags = [] @@ -2017,7 +1957,7 @@ class McpServer: logging.info("Add processing request made via websocket") - manager = await get_socket_manager(ctx, user) + manager = await get_socket_manager(ctx) await ctx.session.send_log_message( level="info", @@ -2036,7 +1976,6 @@ class McpServer: "document-id": document_id, "time": timestamp, "flow": flow, - "user": user, "collection": collection, "tags": tags } diff --git a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py index 4844b104..9d955d17 100755 --- a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py +++ b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py @@ -91,7 +91,7 @@ class Processor(FlowProcessor): if v.document_id: doc_meta = await self.librarian.fetch_document_metadata( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf": logger.error( @@ -106,7 +106,7 @@ class Processor(FlowProcessor): logger.info(f"Fetching document {v.document_id} from librarian...") content = await self.librarian.fetch_document_content( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) if isinstance(content, str): content = content.encode('utf-8') @@ -141,7 +141,7 @@ class Processor(FlowProcessor): await self.librarian.save_child_document( doc_id=page_doc_id, parent_id=source_doc_id, - user=v.metadata.user, + workspace=flow.workspace, content=page_content, document_type="page", title=f"Page {page_num}", @@ -163,7 +163,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=pg_uri, root=v.metadata.root, - user=v.metadata.user, collection=v.metadata.collection, ), triples=set_graph(prov_triples, GRAPH_SOURCE), @@ -175,7 +174,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=pg_uri, root=v.metadata.root, - user=v.metadata.user, 
collection=v.metadata.collection, ), document_id=page_doc_id, diff --git a/trustgraph-unstructured/trustgraph/decoding/universal/processor.py b/trustgraph-unstructured/trustgraph/decoding/universal/processor.py index 6b7d0246..b3723655 100644 --- a/trustgraph-unstructured/trustgraph/decoding/universal/processor.py +++ b/trustgraph-unstructured/trustgraph/decoding/universal/processor.py @@ -275,7 +275,7 @@ class Processor(FlowProcessor): await self.librarian.save_child_document( doc_id=doc_id, parent_id=parent_doc_id, - user=metadata.user, + workspace=flow.workspace, content=page_content, document_type="page" if is_page else "section", title=label, @@ -303,7 +303,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=entity_uri, root=metadata.root, - user=metadata.user, collection=metadata.collection, ), triples=set_graph(prov_triples, GRAPH_SOURCE), @@ -314,7 +313,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=entity_uri, root=metadata.root, - user=metadata.user, collection=metadata.collection, ), document_id=doc_id, @@ -356,7 +354,7 @@ class Processor(FlowProcessor): await self.librarian.save_child_document( doc_id=img_uri, parent_id=parent_doc_id, - user=metadata.user, + workspace=flow.workspace, content=img_content, document_type="image", title=f"Image from page {page_number}" if page_number else "Image", @@ -379,7 +377,6 @@ class Processor(FlowProcessor): metadata=Metadata( id=img_uri, root=metadata.root, - user=metadata.user, collection=metadata.collection, ), triples=set_graph(prov_triples, GRAPH_SOURCE), @@ -404,13 +401,13 @@ class Processor(FlowProcessor): doc_meta = await self.librarian.fetch_document_metadata( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) mime_type = doc_meta.kind if doc_meta else None content = await self.librarian.fetch_document_content( document_id=v.document_id, - user=v.metadata.user, + workspace=flow.workspace, ) if isinstance(content, str):