diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index 8248dfbf..88d2b79e 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -22,7 +22,7 @@ jobs: uses: actions/checkout@v3 - name: Setup packages - run: make update-package-versions VERSION=2.4.999 + run: make update-package-versions VERSION=2.5.999 - name: Setup environment run: python3 -m venv env diff --git a/.gitignore b/.gitignore index 32942156..366edb4a 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,7 @@ trustgraph-vertexai/trustgraph/vertexai_version.py trustgraph-unstructured/trustgraph/unstructured_version.py trustgraph-mcp/trustgraph/mcp_version.py trustgraph/trustgraph/trustgraph_version.py -vertexai/ \ No newline at end of file +vertexai/ +venv/ +.venv/ +.env diff --git a/containers/Containerfile.hf b/containers/Containerfile.hf index 93768b54..9cc6c2c8 100644 --- a/containers/Containerfile.hf +++ b/containers/Containerfile.hf @@ -23,7 +23,7 @@ RUN pip3 install --no-cache-dir \ langchain==1.2.16 langchain-core==1.3.2 langchain-huggingface==1.2.2 \ langchain-community==0.4.1 \ sentence-transformers==5.4.1 transformers==5.7.0 \ - huggingface-hub==1.13.0 \ + huggingface-hub==1.13.0 click \ pulsar-client==3.11.0 # Most commonly used embeddings model, just build it into the container diff --git a/dev-tools/library_client.py b/dev-tools/library_client.py index ae9d6857..30e0c344 100644 --- a/dev-tools/library_client.py +++ b/dev-tools/library_client.py @@ -25,7 +25,7 @@ BUCKET_URL = "https://storage.googleapis.com/trustgraph-library" INDEX_URL = f"{BUCKET_URL}/index.json" default_url = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/") -default_user = "trustgraph" +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_token = os.getenv("TRUSTGRAPH_TOKEN", None) @@ -113,7 +113,7 @@ def convert_metadata(metadata_json): return triples -def load_document(api, user, doc_entry): +def load_document(api, doc_entry): """Fetch metadata and content for a document, then load into TrustGraph.""" doc_id = doc_entry["id"] title = doc_entry["title"] @@ -133,7 +133,6 @@ def load_document(api, user, doc_entry): api.add_document( id=doc["id"], metadata=metadata, - user=user, kind=doc["kind"], title=doc["title"], comments=doc["comments"], @@ -144,12 +143,12 @@ def load_document(api, user, doc_entry): print(f" done.") -def load_documents(api, user, docs): +def load_documents(api, docs): """Load a list of documents.""" print(f"Loading {len(docs)} document(s)...\n") for doc in docs: try: - load_document(api, user, doc) + load_document(api, doc) except Exception as e: print(f" FAILED: {e}", file=sys.stderr) print() @@ -166,8 +165,8 @@ def main(): help=f"TrustGraph API URL (default: {default_url})", ) parser.add_argument( - "-U", "--user", default=default_user, - help=f"User ID (default: {default_user})", + "-w", "--workspace", default=default_workspace, + help=f"Workspace (default: {default_workspace})", ) parser.add_argument( "-t", "--token", default=default_token, @@ -212,22 +211,22 @@ def main(): return # Load commands need the API - api = Api(args.url, token=args.token).library() + api = Api(args.url, token=args.token, workspace=args.workspace).library() if args.command == "load-all": - load_documents(api, args.user, index) + load_documents(api, index) elif args.command == "load-doc": matches = [d for d in index if str(d.get("id")) == args.id] if not matches: print(f"No document with ID '{args.id}' found.", file=sys.stderr) sys.exit(1) - load_documents(api, args.user, matches) + load_documents(api, matches) elif args.command == "load-match": results = search_index(index, args.query) if results: - load_documents(api, args.user, results) + load_documents(api, results) else: print("No matches found.", file=sys.stderr) sys.exit(1) diff --git a/dev-tools/tests/relay/websocket_relay.py b/dev-tools/tests/relay/websocket_relay.py index d537f7da..6495ee49 100644 --- a/dev-tools/tests/relay/websocket_relay.py +++ b/dev-tools/tests/relay/websocket_relay.py @@ -3,208 +3,278 @@ WebSocket Relay Test Harness This script creates a relay server with two WebSocket endpoints: -- /in - for test clients to connect to -- /out - for reverse gateway to connect to +- /in - for test clients to connect to (speaks api-gateway protocol) +- /out - for reverse gateway to connect to (speaks rev-gateway protocol) -Messages are bidirectionally relayed between the two connections. +Clients on /in authenticate with a first-frame auth message: + {"type": "auth", "token": "..."} + +The relay stores the token and injects it into each subsequent message +before forwarding to /out. Responses from /out are forwarded back to +the originating /in connection unchanged. Usage: python websocket_relay.py [--port PORT] [--host HOST] """ import asyncio +import json import logging import argparse from aiohttp import web, WSMsgType -import weakref -from typing import Optional, Set +from typing import Dict, Optional -# Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger("websocket_relay") + +class InConnection: + def __init__(self, ws, conn_id): + self.ws = ws + self.conn_id = conn_id + self.token: Optional[str] = None + self.authenticated = False + + class WebSocketRelay: - """WebSocket relay that forwards messages between 'in' and 'out' connections""" - + def __init__(self): - self.in_connections: Set = weakref.WeakSet() - self.out_connections: Set = weakref.WeakSet() - + self.in_connections: Dict[str, InConnection] = {} + self.out_connections: set = set() + self._conn_counter = 0 + + def _next_conn_id(self): + self._conn_counter += 1 + return f"conn-{self._conn_counter}" + async def handle_in_connection(self, request): - """Handle incoming connections on /in endpoint""" ws = web.WebSocketResponse() await ws.prepare(request) - - self.in_connections.add(ws) - logger.info(f"New 'in' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}") - + + conn_id = self._next_conn_id() + conn = InConnection(ws, conn_id) + self.in_connections[conn_id] = conn + logger.info( + f"New 'in' connection {conn_id}. " + f"Total in: {len(self.in_connections)}, " + f"out: {len(self.out_connections)}" + ) + try: async for msg in ws: if msg.type == WSMsgType.TEXT: - data = msg.data - logger.info(f"IN → OUT: {data}") - await self._forward_to_out(data) - elif msg.type == WSMsgType.BINARY: - data = msg.data - logger.info(f"IN → OUT: {len(data)} bytes (binary)") - await self._forward_to_out(data, binary=True) + await self._handle_in_message(conn, msg.data) elif msg.type == WSMsgType.ERROR: - logger.error(f"WebSocket error on 'in' connection: {ws.exception()}") + logger.error( + f"WebSocket error on 'in' connection " + f"{conn_id}: {ws.exception()}" + ) break else: break except Exception as e: - logger.error(f"Error in 'in' connection handler: {e}") + logger.error( + f"Error in 'in' connection {conn_id}: {e}" + ) finally: - logger.info(f"'in' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}") - + del self.in_connections[conn_id] + logger.info( + f"'in' connection {conn_id} closed. " + f"Remaining in: {len(self.in_connections)}, " + f"out: {len(self.out_connections)}" + ) + return ws - - async def handle_out_connection(self, request): - """Handle outgoing connections on /out endpoint""" - ws = web.WebSocketResponse() - await ws.prepare(request) - - self.out_connections.add(ws) - logger.info(f"New 'out' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}") - + + async def _handle_in_message(self, conn, data): try: - async for msg in ws: - if msg.type == WSMsgType.TEXT: - data = msg.data - logger.info(f"OUT → IN: {data}") - await self._forward_to_in(data) - elif msg.type == WSMsgType.BINARY: - data = msg.data - logger.info(f"OUT → IN: {len(data)} bytes (binary)") - await self._forward_to_in(data, binary=True) - elif msg.type == WSMsgType.ERROR: - logger.error(f"WebSocket error on 'out' connection: {ws.exception()}") - break - else: - break - except Exception as e: - logger.error(f"Error in 'out' connection handler: {e}") - finally: - logger.info(f"'out' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}") - - return ws - - async def _forward_to_out(self, data, binary=False): - """Forward message from 'in' to all 'out' connections""" - if not self.out_connections: - logger.warning("No 'out' connections available to forward message") + message = json.loads(data) + except json.JSONDecodeError: + logger.warning( + f"{conn.conn_id}: received non-JSON message" + ) return - - closed_connections = [] + + if isinstance(message, dict) and message.get("type") == "auth": + conn.token = message.get("token", "") + conn.authenticated = True + logger.info(f"{conn.conn_id}: authenticated") + await conn.ws.send_json({ + "type": "auth-ok", + "workspace": "relayed", + }) + return + + if not conn.authenticated: + await conn.ws.send_json({ + "error": { + "message": "auth required", + "type": "auth-required", + }, + "complete": True, + }) + return + + message["token"] = conn.token + message["_relay_conn"] = conn.conn_id + + forwarded = json.dumps(message) + logger.info(f"IN {conn.conn_id} → OUT: {forwarded}") + await self._forward_to_out(forwarded) + + async def handle_out_connection(self, request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + self.out_connections.add(ws) + logger.info( + f"New 'out' connection. " + f"Total in: {len(self.in_connections)}, " + f"out: {len(self.out_connections)}" + ) + + try: + async for msg in ws: + if msg.type == WSMsgType.TEXT: + await self._handle_out_message(msg.data) + elif msg.type == WSMsgType.ERROR: + logger.error( + f"WebSocket error on 'out' connection: " + f"{ws.exception()}" + ) + break + else: + break + except Exception as e: + logger.error(f"Error in 'out' connection: {e}") + finally: + self.out_connections.discard(ws) + logger.info( + f"'out' connection closed. " + f"Remaining in: {len(self.in_connections)}, " + f"out: {len(self.out_connections)}" + ) + + return ws + + async def _handle_out_message(self, data): + try: + message = json.loads(data) + except json.JSONDecodeError: + logger.warning("OUT: received non-JSON message") + return + + conn_id = message.pop("_relay_conn", None) + + forwarded = json.dumps(message) + logger.info(f"OUT → IN {conn_id or 'broadcast'}: {forwarded}") + + if conn_id and conn_id in self.in_connections: + conn = self.in_connections[conn_id] + try: + if not conn.ws.closed: + await conn.ws.send_str(forwarded) + except Exception as e: + logger.error( + f"Error forwarding to 'in' {conn_id}: {e}" + ) + else: + await self._broadcast_to_in(forwarded) + + async def _broadcast_to_in(self, data): + closed = [] + for conn_id, conn in list(self.in_connections.items()): + try: + if conn.ws.closed: + closed.append(conn_id) + continue + await conn.ws.send_str(data) + except Exception as e: + logger.error( + f"Error broadcasting to 'in' {conn_id}: {e}" + ) + closed.append(conn_id) + for conn_id in closed: + self.in_connections.pop(conn_id, None) + + async def _forward_to_out(self, data): + closed = [] for ws in list(self.out_connections): try: if ws.closed: - closed_connections.append(ws) + closed.append(ws) continue - - if binary: - await ws.send_bytes(data) - else: - await ws.send_str(data) + await ws.send_str(data) except Exception as e: - logger.error(f"Error forwarding to 'out' connection: {e}") - closed_connections.append(ws) - - # Clean up closed connections - for ws in closed_connections: - if ws in self.out_connections: - self.out_connections.discard(ws) - - async def _forward_to_in(self, data, binary=False): - """Forward message from 'out' to all 'in' connections""" - if not self.in_connections: - logger.warning("No 'in' connections available to forward message") - return - - closed_connections = [] - for ws in list(self.in_connections): - try: - if ws.closed: - closed_connections.append(ws) - continue - - if binary: - await ws.send_bytes(data) - else: - await ws.send_str(data) - except Exception as e: - logger.error(f"Error forwarding to 'in' connection: {e}") - closed_connections.append(ws) - - # Clean up closed connections - for ws in closed_connections: - if ws in self.in_connections: - self.in_connections.discard(ws) + logger.error(f"Error forwarding to 'out': {e}") + closed.append(ws) + for ws in closed: + self.out_connections.discard(ws) + async def create_app(relay): - """Create the web application with routes""" app = web.Application() - - # Add routes - app.router.add_get('/in', relay.handle_in_connection) + + app.router.add_get('/in/api/v1/socket', relay.handle_in_connection) app.router.add_get('/out', relay.handle_out_connection) - - # Add a simple status endpoint + async def status(request): - status_info = { + return web.json_response({ 'in_connections': len(relay.in_connections), 'out_connections': len(relay.out_connections), - 'status': 'running' - } - return web.json_response(status_info) - + 'status': 'running', + }) + app.router.add_get('/status', status) - app.router.add_get('/', status) # Root also shows status - + app.router.add_get('/', status) + return app + def main(): parser = argparse.ArgumentParser( description="WebSocket Relay Test Harness" ) parser.add_argument( - '--host', + '--host', default='localhost', - help='Host to bind to (default: localhost)' + help='Host to bind to (default: localhost)', ) parser.add_argument( - '--port', - type=int, + '--port', + type=int, default=8080, - help='Port to bind to (default: 8080)' + help='Port to bind to (default: 8080)', ) parser.add_argument( '--verbose', '-v', action='store_true', - help='Enable verbose logging' + help='Enable verbose logging', ) - + args = parser.parse_args() - + if args.verbose: logging.getLogger().setLevel(logging.DEBUG) - + relay = WebSocketRelay() - + print(f"Starting WebSocket Relay on {args.host}:{args.port}") - print(f" 'in' endpoint: ws://{args.host}:{args.port}/in") + print(f" 'in' endpoint: ws://{args.host}:{args.port}/in/api/v1/socket") print(f" 'out' endpoint: ws://{args.host}:{args.port}/out") print(f" Status: http://{args.host}:{args.port}/status") print() - print("Usage:") - print(f" Test client connects to: ws://{args.host}:{args.port}/in") - print(f" Reverse gateway connects to: ws://{args.host}:{args.port}/out") - + print("Client protocol (same as api-gateway):") + print(' 1. Connect to /in/api/v1/socket') + print(' 2. Send: {"type": "auth", "token": "tg_..."}') + print(' 3. Receive: {"type": "auth-ok", "workspace": "relayed"}') + print(' 4. Send requests as normal') + web.run_app(create_app(relay), host=args.host, port=args.port) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/docs/tech-specs/no-auth-regime.md b/docs/tech-specs/no-auth-regime.md new file mode 100644 index 00000000..ae8b427f --- /dev/null +++ b/docs/tech-specs/no-auth-regime.md @@ -0,0 +1,186 @@ +--- +layout: default +title: "No-Auth IAM Regime" +parent: "Tech Specs" +--- + +# No-Auth IAM Regime + +## Overview + +A minimal IAM regime that permits all access unconditionally. +Implements the same Pulsar request/response protocol as `iam-svc` +(see [iam-contract.md](iam-contract.md)) so it is a drop-in +replacement: swap `iam-svc` for `no-auth-svc` in the deployment +and the gateway, bootstrapper, and all other components continue +to work without modification. + +Intended for development, testing, single-tenant self-hosted +deployments, and evaluation environments where authentication +overhead is unwanted. + +## Motivation + +The full IAM regime requires Cassandra tables, a bootstrap +sequence, API key management, and signing key rotation. For +many deployments this is unnecessary friction: + +- Local development and CI/CD pipelines. +- Single-user or small-team self-hosted instances. +- Evaluation and demo environments. +- Deployments behind an external authentication proxy + (e.g. OAuth2 reverse proxy, VPN-gated access). + +Today operators who want no auth must still deploy `iam-svc` and +complete the bootstrap ceremony. A purpose-built no-auth regime +eliminates that requirement entirely. + +## Design + +### Deployment + +Replace `iam-svc` with `no-auth-svc` in the processor group or +container configuration. No other services change. The no-auth +service listens on the standard IAM Pulsar topics: + +- Request: `request::iam` +- Response: `response::iam` + +### Dependencies + +None. No database, no config entries, no signing keys, no +bootstrap sequence. + +### Operation responses + +The service implements the IAM contract +([iam-contract.md](iam-contract.md)) with the following +behaviour for each operation: + +| Operation | Behaviour | +|---|---| +| `authenticate-anonymous` | Returns a default identity: `user_id="anonymous"`, `workspace="default"`, `roles=["admin"]`. This is the key operation that distinguishes no-auth from the full regime. | +| `resolve-api-key` | Accepts any token. Returns the same default identity as `authenticate-anonymous`. | +| `authorise` | Always allows. Returns `decision_allow=True`, `decision_ttl_seconds=3600`. | +| `authorise-many` | Always allows all checks. | +| `get-signing-key-public` | Returns an empty string. The gateway skips JWT validation when no key is available. | +| `bootstrap` | No-op. Returns empty admin user/key. | +| `bootstrap-status` | Returns `bootstrap_available=False`. | +| `whoami` | Returns a stub user record for the actor. | +| `login` | Returns empty JWT (not supported under no-auth). | +| `create-user`, `list-users`, `get-user`, `update-user`, `delete-user`, `disable-user`, `enable-user` | Return empty/stub responses. User management is meaningless without auth. | +| `create-workspace`, `list-workspaces`, `get-workspace`, `update-workspace`, `disable-workspace` | Return empty/stub responses. | +| `create-api-key`, `list-api-keys`, `revoke-api-key` | Return empty/stub responses. | +| `change-password`, `reset-password` | No-op. | +| `rotate-signing-key` | No-op. | +| Unknown operation | Returns an error response (same as `iam-svc`). | + +### Workspace resolution + +When `resolve-api-key` is called, the returned workspace +determines which workspace the request operates against. The +no-auth service defaults to `"default"`. + +A configurable `--default-workspace` flag allows operators to +change this without code changes. + +### Anonymous authentication + +A new `authenticate-anonymous` operation is added to the IAM +protocol. This is a small, backward-compatible addition to the +contract: + +**Gateway change** (`auth.py`): when `authenticate()` receives a +request with no `Authorization` header (or an empty bearer +token), instead of immediately returning 401, it sends an +`authenticate-anonymous` request to the IAM service. If the +regime returns a valid identity, the request proceeds. If the +regime returns an error, the gateway returns 401 as before. + +**`iam-svc` (full regime)**: returns `auth-failed` for +`authenticate-anonymous`. Behaviour is unchanged — unauthenticated +requests are rejected exactly as they are today. + +**`no-auth-svc`**: returns the default identity (`anonymous` / +`default` workspace). No token required. + +This keeps the policy decision ("is anonymous access allowed?") +in the IAM regime, not in the gateway. The gateway is a generic +enforcement point that asks and respects the answer. + +**Wire format**: uses the existing `IamRequest` / `IamResponse` +schema with `operation="authenticate-anonymous"`. No new fields +required — the response uses `resolved_user_id`, +`resolved_workspace`, and `resolved_roles`, same as +`resolve-api-key`. + +Requests that do carry a bearer token follow the existing +`resolve-api-key` / JWT paths unchanged. + +## Implementation + +### Service structure + +The service is a standard `AsyncProcessor` that consumes IAM +requests and produces IAM responses, identical in shape to the +existing `iam-svc` processor: + +``` +trustgraph-flow/ + trustgraph/ + iam/ + noauth/ + __init__.py + __main__.py + service.py # AsyncProcessor wiring + handler.py # Operation dispatch, always-allow logic +``` + +### Handler + +The handler is a single `handle(request) -> response` function +with a dispatch table. Each operation returns a pre-built +`IamResponse` with the appropriate fields set. No database +access, no crypto, no state. + +### Configuration + +| Flag | Default | Description | +|---|---|---| +| `--default-workspace` | `"default"` | Workspace returned by `resolve-api-key` | +| `--default-user-id` | `"anonymous"` | User ID returned by `resolve-api-key` | + +### Entry point + +``` +tg-no-auth-svc +``` + +Or via processor group: + +```yaml +- class: trustgraph.iam.noauth.Processor + params: + <<: *defaults + id: no-auth-svc +``` + +## Security considerations + +This regime provides **no security whatsoever**. Any caller with +network access to the API gateway has full admin access to all +workspaces. + +Operators must ensure that network-level controls (firewall, +VPN, private network) provide adequate protection when deploying +this regime. The regime is explicitly not suitable for multi- +tenant or internet-facing deployments. + +## Testing + +- Unit: verify each operation returns the expected stub response. +- Integration: deploy `no-auth-svc` in place of `iam-svc`, confirm + the gateway starts, accepts requests with a dummy bearer token, + and routes them to the default workspace. +- E2E: run the standard e2e test suite with `no-auth-svc` to + confirm no regressions. diff --git a/docs/tech-specs/ontorag.md b/docs/tech-specs/ontorag.md index 86a3cd19..460e72ba 100644 --- a/docs/tech-specs/ontorag.md +++ b/docs/tech-specs/ontorag.md @@ -278,7 +278,7 @@ The system uses **FAISS (Facebook AI Similarity Search)** with IndexFlatIP for e 3. **Similarity Search**: - For each text segment embedding, search the vector store - Retrieve top-k (e.g., 10) most similar ontology elements - - Apply similarity threshold (e.g., 0.7) to filter weak matches + - Apply similarity threshold (e.g., 0.3) to filter weak matches - Aggregate results across all segments, tracking match frequencies 4. **Dependency Resolution**: diff --git a/tests/integration/test_cassandra_config_end_to_end.py b/tests/integration/test_cassandra_config_end_to_end.py index 514a5dbf..290e1348 100644 --- a/tests/integration/test_cassandra_config_end_to_end.py +++ b/tests/integration/test_cassandra_config_end_to_end.py @@ -63,26 +63,26 @@ class TestEndToEndConfigurationFlow: 'CASSANDRA_USERNAME': 'obj-user', 'CASSANDRA_PASSWORD': 'obj-pass' } - + mock_auth_instance = MagicMock() mock_auth_provider.return_value = mock_auth_instance mock_cluster_instance = MagicMock() mock_session = MagicMock() mock_cluster_instance.connect.return_value = mock_session mock_cluster.return_value = mock_cluster_instance - + with patch.dict(os.environ, env_vars, clear=True): processor = RowsWriter(taskgroup=MagicMock()) - + # Trigger Cassandra connection processor.connect_cassandra() - + # Verify auth provider was created with env vars mock_auth_provider.assert_called_once_with( username='obj-user', password='obj-pass' ) - + # Verify cluster was created with hosts from env and auth mock_cluster.assert_called_once() call_args = mock_cluster.call_args @@ -188,37 +188,34 @@ class TestConfigurationPriorityEndToEnd: ) @pytest.mark.asyncio - @patch('trustgraph.direct.cassandra_kg.Cluster') - async def test_no_config_defaults_end_to_end(self, mock_cluster): + @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') + async def test_no_config_defaults_end_to_end(self, mock_kg_class): """Test that defaults are used when no configuration provided end-to-end.""" - mock_cluster_instance = MagicMock() - mock_session = MagicMock() - mock_cluster_instance.connect.return_value = mock_session - mock_cluster.return_value = mock_cluster_instance - + from unittest.mock import AsyncMock + + mock_tg_instance = MagicMock() + mock_tg_instance.async_get_all = AsyncMock(return_value=[]) + mock_kg_class.return_value = mock_tg_instance + with patch.dict(os.environ, {}, clear=True): processor = TriplesQuery(taskgroup=MagicMock()) - + # Mock query to trigger TrustGraph creation mock_query = MagicMock() mock_query.collection = 'default_collection' mock_query.s = None mock_query.p = None mock_query.o = None + mock_query.g = None mock_query.limit = 100 - - # Mock the get_all method to return empty list - mock_tg_instance = MagicMock() - mock_tg_instance.get_all.return_value = [] - processor.tg = mock_tg_instance - + await processor.query_triples('default_user', mock_query) - + # Should use defaults - mock_cluster.assert_called_once() - call_args = mock_cluster.call_args - assert call_args.args[0] == ['cassandra'] # Default host - assert 'auth_provider' not in call_args.kwargs # No auth with default config + mock_kg_class.assert_called_once_with( + hosts=['cassandra'], + keyspace='default_user' + ) class TestNoBackwardCompatibilityEndToEnd: @@ -324,16 +321,16 @@ class TestMultipleHostsHandling: env_vars = { 'CASSANDRA_HOST': 'host1,host2,host3,host4,host5' } - + mock_cluster_instance = MagicMock() mock_session = MagicMock() mock_cluster_instance.connect.return_value = mock_session mock_cluster.return_value = mock_cluster_instance - + with patch.dict(os.environ, env_vars, clear=True): processor = RowsWriter(taskgroup=MagicMock()) processor.connect_cassandra() - + # Verify all hosts were passed to Cluster mock_cluster.assert_called_once() call_args = mock_cluster.call_args @@ -392,27 +389,27 @@ class TestAuthenticationFlow: 'CASSANDRA_USERNAME': 'auth-user', 'CASSANDRA_PASSWORD': 'auth-secret' } - + mock_auth_instance = MagicMock() mock_auth_provider.return_value = mock_auth_instance mock_cluster_instance = MagicMock() mock_cluster.return_value = mock_cluster_instance - + with patch.dict(os.environ, env_vars, clear=True): processor = RowsWriter(taskgroup=MagicMock()) processor.connect_cassandra() - + # Auth provider should be created mock_auth_provider.assert_called_once_with( username='auth-user', password='auth-secret' ) - + # Cluster should be created with auth provider call_args = mock_cluster.call_args assert 'auth_provider' in call_args.kwargs assert call_args.kwargs['auth_provider'] == mock_auth_instance - + @patch('trustgraph.storage.rows.cassandra.write.Cluster') @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') def test_no_authentication_when_credentials_missing(self, mock_auth_provider, mock_cluster): @@ -421,21 +418,21 @@ class TestAuthenticationFlow: 'CASSANDRA_HOST': 'no-auth-host' # No username/password } - + mock_cluster_instance = MagicMock() mock_cluster.return_value = mock_cluster_instance - + with patch.dict(os.environ, env_vars, clear=True): processor = RowsWriter(taskgroup=MagicMock()) processor.connect_cassandra() - + # Auth provider should not be created mock_auth_provider.assert_not_called() - + # Cluster should be created without auth provider call_args = mock_cluster.call_args assert 'auth_provider' not in call_args.kwargs - + @patch('trustgraph.storage.rows.cassandra.write.Cluster') @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') def test_no_authentication_when_only_username_provided(self, mock_auth_provider, mock_cluster): @@ -446,15 +443,15 @@ class TestAuthenticationFlow: cassandra_username='partial-user' # No password ) - + mock_cluster_instance = MagicMock() mock_cluster.return_value = mock_cluster_instance - + processor.connect_cassandra() - + # Auth provider should not be created (needs both username AND password) mock_auth_provider.assert_not_called() - + # Cluster should be created without auth provider call_args = mock_cluster.call_args assert 'auth_provider' not in call_args.kwargs \ No newline at end of file diff --git a/tests/integration/test_rows_cassandra_integration.py b/tests/integration/test_rows_cassandra_integration.py index 1358d420..d668600c 100644 --- a/tests/integration/test_rows_cassandra_integration.py +++ b/tests/integration/test_rows_cassandra_integration.py @@ -101,6 +101,8 @@ class TestRowsCassandraIntegration: processor.session = None # Bind actual methods from the new unified table implementation + import asyncio + processor._setup_lock = asyncio.Lock() processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor) processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor) processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor) @@ -108,6 +110,7 @@ class TestRowsCassandraIntegration: processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) processor.build_index_value = Processor.build_index_value.__get__(processor, Processor) processor.register_partitions = Processor.register_partitions.__get__(processor, Processor) + processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor) processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) processor.on_object = Processor.on_object.__get__(processor, Processor) processor.collection_exists = MagicMock(return_value=True) diff --git a/tests/integration/test_rows_graphql_query_integration.py b/tests/integration/test_rows_graphql_query_integration.py index 29b4464d..a455accd 100644 --- a/tests/integration/test_rows_graphql_query_integration.py +++ b/tests/integration/test_rows_graphql_query_integration.py @@ -184,7 +184,7 @@ class TestObjectsGraphQLQueryIntegration: await processor.on_schema_config("default", sample_schema_config, version=1) # Connect to Cassandra - processor.connect_cassandra() + await processor.connect_cassandra() assert processor.session is not None # Create test keyspace and table @@ -219,7 +219,7 @@ class TestObjectsGraphQLQueryIntegration: """Test inserting data and querying via GraphQL""" # Load schema and connect await processor.on_schema_config("default", sample_schema_config, version=1) - processor.connect_cassandra() + await processor.connect_cassandra() # Setup test data keyspace = "test_user" @@ -293,7 +293,7 @@ class TestObjectsGraphQLQueryIntegration: """Test GraphQL queries with filtering on indexed fields""" # Setup (reuse previous setup) await processor.on_schema_config("default", sample_schema_config, version=1) - processor.connect_cassandra() + await processor.connect_cassandra() keyspace = "test_user" collection = "filter_test" @@ -387,7 +387,7 @@ class TestObjectsGraphQLQueryIntegration: """Test full message processing workflow""" # Setup await processor.on_schema_config("default", sample_schema_config, version=1) - processor.connect_cassandra() + await processor.connect_cassandra() # Create mock message request = RowsQueryRequest( @@ -433,7 +433,7 @@ class TestObjectsGraphQLQueryIntegration: """Test handling multiple concurrent GraphQL queries""" # Setup await processor.on_schema_config("default", sample_schema_config, version=1) - processor.connect_cassandra() + await processor.connect_cassandra() # Create multiple query tasks queries = [ @@ -519,7 +519,7 @@ class TestObjectsGraphQLQueryIntegration: """Test handling of large query result sets""" # Setup await processor.on_schema_config("default", sample_schema_config, version=1) - processor.connect_cassandra() + await processor.connect_cassandra() keyspace = "large_test_user" collection = "large_collection" diff --git a/tests/pytest.ini b/tests/pytest.ini index 5dcc095c..a89759ab 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -16,4 +16,13 @@ markers = unit: marks tests as unit tests contract: marks tests as contract tests (service interface validation) vertexai: marks tests as vertex ai specific tests - asyncio: marks tests that use asyncio \ No newline at end of file + asyncio: marks tests that use asyncio +# This is helpful if you're bored with deprecationwarnings. I prefer to +# keep the warnings for now, it avoids masking problems. +# +# filterwarnings = +# ignore:Core Pydantic V1 functionality isn't compatible with Python 3.14.*:UserWarning +# ignore:builtin type SwigPyPacked has no __module__ attribute:DeprecationWarning +# ignore:builtin type SwigPyObject has no __module__ attribute:DeprecationWarning +# ignore:builtin type swigvarlink has no __module__ attribute:DeprecationWarning +# ignore:.*_UnionGenericAlias.*is deprecated and slated for removal in Python 3.17:DeprecationWarning diff --git a/tests/unit/test_api/test_library_api.py b/tests/unit/test_api/test_library_api.py new file mode 100644 index 00000000..086ecd63 --- /dev/null +++ b/tests/unit/test_api/test_library_api.py @@ -0,0 +1,296 @@ +""" +Tests for the Library API wrapper round-trip behavior. +Covers the get_documents → update_document path and edge cases +from issue #893. +""" + +import datetime +import pytest +from unittest.mock import MagicMock, patch + +from trustgraph.api.library import Library, to_value, from_value +from trustgraph.api.types import DocumentMetadata, Triple +from trustgraph.knowledge import Uri, Literal + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_library(response=None): + api = MagicMock() + api.workspace = "default" + api.request.return_value = response or {} + lib = Library(api) + return lib, api + + +def _wire_triple(s_iri, p_iri, o_val): + return { + "s": {"t": "i", "i": s_iri}, + "p": {"t": "i", "i": p_iri}, + "o": {"t": "l", "v": o_val}, + } + + +def _doc_wire(id="doc-1", time=1700000000, title="Test Doc", + kind="text/plain", comments="", tags=None, + metadata=None, parent_id="", document_type="source", + include_title=True): + doc = { + "id": id, + "time": time, + "kind": kind, + "comments": comments, + "metadata": metadata or [], + "tags": tags or [], + "parent-id": parent_id, + "document-type": document_type, + } + if include_title: + doc["title"] = title + return doc + + +# --------------------------------------------------------------------------- +# Bug 1: get_documents tolerates missing title +# --------------------------------------------------------------------------- + +class TestGetDocumentsMissingTitle: + + def test_missing_title_defaults_to_empty(self): + doc = _doc_wire(include_title=False) + lib, api = _make_library({"document-metadatas": [doc]}) + + result = lib.get_documents() + + assert len(result) == 1 + assert result[0].title == "" + + def test_present_title_preserved(self): + doc = _doc_wire(title="My Title") + lib, api = _make_library({"document-metadatas": [doc]}) + + result = lib.get_documents() + + assert result[0].title == "My Title" + + +# --------------------------------------------------------------------------- +# Bug 2: update_document handles Triple objects (attribute access) +# --------------------------------------------------------------------------- + +class TestUpdateDocumentTripleAccess: + + def test_triple_objects_serialized_correctly(self): + lib, api = _make_library({}) + + metadata = DocumentMetadata( + id="doc-1", + time=datetime.datetime.fromtimestamp(1700000000), + kind="text/plain", + title="Test", + comments="", + metadata=[ + Triple( + s=Uri("http://example.org/entity/alice"), + p=Uri("http://example.org/rel/knows"), + o=Literal("Bob"), + ), + ], + tags=["test"], + ) + + lib.update_document(id="doc-1", metadata=metadata) + + call_args = api.request.call_args[0][1] + triples = call_args["document-metadata"]["metadata"] + + assert len(triples) == 1 + assert triples[0]["s"]["i"] == "http://example.org/entity/alice" + assert triples[0]["p"]["i"] == "http://example.org/rel/knows" + assert triples[0]["o"]["v"] == "Bob" + + def test_empty_metadata_list(self): + lib, api = _make_library({}) + + metadata = DocumentMetadata( + id="doc-1", + time=datetime.datetime.fromtimestamp(1700000000), + kind="text/plain", + title="Test", + comments="", + metadata=[], + tags=[], + ) + + lib.update_document(id="doc-1", metadata=metadata) + + call_args = api.request.call_args[0][1] + assert call_args["document-metadata"]["metadata"] == [] + + +# --------------------------------------------------------------------------- +# Bug 3: update_document serializes datetime to int seconds +# --------------------------------------------------------------------------- + +class TestUpdateDocumentTimeSerialization: + + def test_datetime_serialized_to_int(self): + lib, api = _make_library({}) + + ts = 1700000000 + metadata = DocumentMetadata( + id="doc-1", + time=datetime.datetime.fromtimestamp(ts), + kind="text/plain", + title="Test", + comments="", + metadata=[], + tags=[], + ) + + lib.update_document(id="doc-1", metadata=metadata) + + call_args = api.request.call_args[0][1] + wire_time = call_args["document-metadata"]["time"] + + assert isinstance(wire_time, int) + assert wire_time == ts + + def test_int_time_passed_through(self): + lib, api = _make_library({}) + + metadata = DocumentMetadata( + id="doc-1", + time=1700000000, + kind="text/plain", + title="Test", + comments="", + metadata=[], + tags=[], + ) + + lib.update_document(id="doc-1", metadata=metadata) + + call_args = api.request.call_args[0][1] + assert call_args["document-metadata"]["time"] == 1700000000 + + +# --------------------------------------------------------------------------- +# Bug 4: update_document handles empty server response +# --------------------------------------------------------------------------- + +class TestUpdateDocumentEmptyResponse: + + def test_empty_response_returns_input_metadata(self): + lib, api = _make_library({}) + + metadata = DocumentMetadata( + id="doc-1", + time=datetime.datetime.fromtimestamp(1700000000), + kind="text/plain", + title="Updated Title", + comments="notes", + metadata=[], + tags=["a"], + ) + + result = lib.update_document(id="doc-1", metadata=metadata) + + assert result is metadata + + def test_full_response_parsed(self): + response_doc = _doc_wire( + id="doc-1", title="Server Title", tags=["b"], + ) + lib, api = _make_library({"document-metadata": response_doc}) + + metadata = DocumentMetadata( + id="doc-1", + time=datetime.datetime.fromtimestamp(1700000000), + kind="text/plain", + title="Client Title", + comments="", + metadata=[], + tags=["a"], + ) + + result = lib.update_document(id="doc-1", metadata=metadata) + + assert result.title == "Server Title" + assert result.tags == ["b"] + + +# --------------------------------------------------------------------------- +# Bug 5: update_document sends both id and document-id +# --------------------------------------------------------------------------- + +class TestUpdateDocumentIdKeys: + + def test_both_id_keys_sent(self): + lib, api = _make_library({}) + + metadata = DocumentMetadata( + id="doc-1", + time=datetime.datetime.fromtimestamp(1700000000), + kind="text/plain", + title="Test", + comments="", + metadata=[], + tags=[], + ) + + lib.update_document(id="doc-1", metadata=metadata) + + call_args = api.request.call_args[0][1] + doc_meta = call_args["document-metadata"] + + assert doc_meta["id"] == "doc-1" + assert doc_meta["document-id"] == "doc-1" + + +# --------------------------------------------------------------------------- +# Round-trip: get_documents → update_document +# --------------------------------------------------------------------------- + +class TestGetUpdateRoundTrip: + + def test_full_round_trip(self): + wire_doc = _doc_wire( + id="doc-42", + title="Original", + tags=["v1"], + metadata=[_wire_triple( + "http://example.org/e/1", + "http://example.org/r/type", + "report", + )], + ) + + lib, api = _make_library({"document-metadatas": [wire_doc]}) + + docs = lib.get_documents() + assert len(docs) == 1 + + doc = docs[0] + doc.title = "Updated" + doc.tags.append("v2") + + # Server returns empty on update + api.request.return_value = {} + result = lib.update_document(id=doc.id, metadata=doc) + + # Should not raise, should return the input metadata + assert result.title == "Updated" + assert "v2" in result.tags + + # Verify the wire format sent + call_args = api.request.call_args[0][1] + doc_meta = call_args["document-metadata"] + + assert doc_meta["id"] == "doc-42" + assert doc_meta["title"] == "Updated" + assert isinstance(doc_meta["time"], int) + assert len(doc_meta["metadata"]) == 1 + assert doc_meta["metadata"][0]["o"]["v"] == "report" diff --git a/tests/unit/test_concurrency/test_consumer_concurrency.py b/tests/unit/test_concurrency/test_consumer_concurrency.py index 59c7f2b5..44d82182 100644 --- a/tests/unit/test_concurrency/test_consumer_concurrency.py +++ b/tests/unit/test_concurrency/test_consumer_concurrency.py @@ -272,23 +272,22 @@ class TestMetricsIntegration: class TestPollTimeout: @pytest.mark.asyncio - async def test_poll_timeout_is_100ms(self): - """Consumer receive timeout should be 100ms, not the original 2000ms. + async def test_poll_timeout_is_2000ms(self): + """Consumer receive timeout should be 2000ms. - A 2000ms poll timeout means every service adds up to 2s of idle - blocking between message bursts. With many sequential hops in a - query pipeline, this compounds into seconds of unnecessary latency. - 100ms keeps responsiveness high without significant CPU overhead. + receive() is a blocking call that returns immediately when a + message arrives — the timeout only governs how often the loop + checks the shutdown flag during idle periods. Lower values + (e.g. 100ms) generate excessive C++ client WARN logging with + no latency benefit. """ consumer = _make_consumer() - # Wire up a mock Pulsar consumer that records the receive kwargs mock_pulsar_consumer = MagicMock() received_kwargs = {} def capture_receive(**kwargs): received_kwargs.update(kwargs) - # Stop after one call consumer.running = False raise type('Timeout', (Exception,), {})("timeout") @@ -296,7 +295,7 @@ class TestPollTimeout: await consumer.consume_from_queue(mock_pulsar_consumer) - assert received_kwargs.get("timeout_millis") == 100 + assert received_kwargs.get("timeout_millis") == 2000 # --------------------------------------------------------------------------- diff --git a/tests/unit/test_concurrency/test_dispatcher_semaphore.py b/tests/unit/test_concurrency/test_dispatcher_semaphore.py index 6a1ae8ab..a5374678 100644 --- a/tests/unit/test_concurrency/test_dispatcher_semaphore.py +++ b/tests/unit/test_concurrency/test_dispatcher_semaphore.py @@ -25,16 +25,17 @@ class TestSemaphoreEnforcement: max_concurrent = 0 processing_event = asyncio.Event() - async def slow_process(message): + async def slow_process(message, sender): nonlocal concurrent_count, max_concurrent concurrent_count += 1 max_concurrent = max(max_concurrent, concurrent_count) await asyncio.sleep(0.05) concurrent_count -= 1 - return {"id": message.get("id"), "response": {"ok": True}} dispatcher._process_message = slow_process + sender = AsyncMock() + # Launch more tasks than max_workers messages = [ {"id": f"msg-{i}", "service": "test", "request": {}} @@ -42,7 +43,7 @@ class TestSemaphoreEnforcement: ] tasks = [ - asyncio.create_task(dispatcher.handle_message(m)) + asyncio.create_task(dispatcher.handle_message(m, sender)) for m in messages ] @@ -66,17 +67,17 @@ class TestSemaphoreEnforcement: original_process = dispatcher._process_message - async def tracking_process(message): + async def tracking_process(message, sender): nonlocal task_was_tracked # During processing, our task should be in active_tasks if len(dispatcher.active_tasks) > 0: task_was_tracked = True - return {"id": message.get("id"), "response": {"ok": True}} dispatcher._process_message = tracking_process await dispatcher.handle_message( - {"id": "test", "service": "test", "request": {}} + {"id": "test", "service": "test", "request": {}}, + AsyncMock(), ) assert task_was_tracked @@ -88,7 +89,7 @@ class TestSemaphoreEnforcement: """Semaphore should be released even if processing raises.""" dispatcher = MessageDispatcher(max_workers=2) - async def failing_process(message): + async def failing_process(message, sender): raise RuntimeError("process failed") dispatcher._process_message = failing_process @@ -96,7 +97,8 @@ class TestSemaphoreEnforcement: # Should not deadlock — semaphore must be released on error with pytest.raises(RuntimeError): await dispatcher.handle_message( - {"id": "test", "service": "test", "request": {}} + {"id": "test", "service": "test", "request": {}}, + AsyncMock(), ) # Semaphore should be back at max @@ -109,17 +111,18 @@ class TestSemaphoreEnforcement: order = [] - async def ordered_process(message): + async def ordered_process(message, sender): msg_id = message["id"] order.append(f"start-{msg_id}") await asyncio.sleep(0.02) order.append(f"end-{msg_id}") - return {"id": msg_id, "response": {"ok": True}} dispatcher._process_message = ordered_process + sender = AsyncMock() + messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)] - tasks = [asyncio.create_task(dispatcher.handle_message(m)) for m in messages] + tasks = [asyncio.create_task(dispatcher.handle_message(m, sender)) for m in messages] await asyncio.gather(*tasks) # With semaphore=1, each message should complete before next starts diff --git a/tests/unit/test_extract/test_ontology/test_triple_converter_validation.py b/tests/unit/test_extract/test_ontology/test_triple_converter_validation.py new file mode 100644 index 00000000..195e8adf --- /dev/null +++ b/tests/unit/test_extract/test_ontology/test_triple_converter_validation.py @@ -0,0 +1,389 @@ +""" +Tests for TripleConverter domain/range enforcement and +OntologySelector bypass for small ontologies. + +Covers fixes for #908 (bypass_selector_below) and #920 (domain/range validation). +""" + +import pytest +from unittest.mock import Mock, AsyncMock + +from trustgraph.extract.kg.ontology.triple_converter import TripleConverter +from trustgraph.extract.kg.ontology.ontology_selector import ( + OntologySelector, + OntologySubset, +) +from trustgraph.extract.kg.ontology.ontology_loader import ( + Ontology, + OntologyClass, + OntologyProperty, +) +from trustgraph.extract.kg.ontology.simplified_parser import ( + Relationship, + Attribute, +) +from trustgraph.extract.kg.ontology.text_processor import TextSegment + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def ontology_subset(): + """Ontology subset with classes, hierarchy, and constrained properties.""" + return OntologySubset( + ontology_id="test", + classes={ + "Person": { + "uri": "http://example.org/Person", + "type": "owl:Class", + "labels": [{"value": "Person"}], + "subclass_of": None, + }, + "Employee": { + "uri": "http://example.org/Employee", + "type": "owl:Class", + "labels": [{"value": "Employee"}], + "subclass_of": "Person", + }, + "Manager": { + "uri": "http://example.org/Manager", + "type": "owl:Class", + "labels": [{"value": "Manager"}], + "subclass_of": "Employee", + }, + "Company": { + "uri": "http://example.org/Company", + "type": "owl:Class", + "labels": [{"value": "Company"}], + "subclass_of": None, + }, + "Product": { + "uri": "http://example.org/Product", + "type": "owl:Class", + "labels": [{"value": "Product"}], + "subclass_of": None, + }, + }, + object_properties={ + "worksFor": { + "uri": "http://example.org/worksFor", + "type": "owl:ObjectProperty", + "labels": [{"value": "works for"}], + "domain": "Person", + "range": "Company", + }, + "manages": { + "uri": "http://example.org/manages", + "type": "owl:ObjectProperty", + "labels": [{"value": "manages"}], + "domain": "Manager", + "range": "Employee", + }, + "relatedTo": { + "uri": "http://example.org/relatedTo", + "type": "owl:ObjectProperty", + "labels": [{"value": "related to"}], + "domain": None, + "range": None, + }, + }, + datatype_properties={ + "employeeId": { + "uri": "http://example.org/employeeId", + "type": "owl:DatatypeProperty", + "labels": [{"value": "employee ID"}], + "domain": "Employee", + }, + "description": { + "uri": "http://example.org/description", + "type": "owl:DatatypeProperty", + "labels": [{"value": "description"}], + "domain": None, + }, + }, + metadata={"name": "Test Ontology"}, + ) + + +@pytest.fixture +def converter(ontology_subset): + return TripleConverter(ontology_subset=ontology_subset, ontology_id="test") + + +# --------------------------------------------------------------------------- +# Domain/range enforcement — relationships +# --------------------------------------------------------------------------- + +class TestRelationshipDomainRange: + + def test_valid_domain_and_range(self, converter): + rel = Relationship( + subject="Alice", subject_type="Person", + relation="worksFor", + object="Acme Corp", object_type="Company", + ) + triple = converter.convert_relationship(rel) + assert triple is not None + + def test_domain_violation_rejected(self, converter): + rel = Relationship( + subject="Widget", subject_type="Product", + relation="worksFor", + object="Acme Corp", object_type="Company", + ) + assert converter.convert_relationship(rel) is None + + def test_range_violation_rejected(self, converter): + rel = Relationship( + subject="Alice", subject_type="Person", + relation="worksFor", + object="Widget", object_type="Product", + ) + assert converter.convert_relationship(rel) is None + + def test_both_domain_and_range_violated(self, converter): + rel = Relationship( + subject="Widget", subject_type="Product", + relation="worksFor", + object="Gadget", object_type="Product", + ) + assert converter.convert_relationship(rel) is None + + +# --------------------------------------------------------------------------- +# Subclass acceptance +# --------------------------------------------------------------------------- + +class TestSubclassAcceptance: + + def test_direct_subclass_matches_domain(self, converter): + """Employee is subclass of Person; worksFor domain is Person.""" + rel = Relationship( + subject="Bob", subject_type="Employee", + relation="worksFor", + object="Acme Corp", object_type="Company", + ) + assert converter.convert_relationship(rel) is not None + + def test_transitive_subclass_matches_domain(self, converter): + """Manager → Employee → Person; worksFor domain is Person.""" + rel = Relationship( + subject="Carol", subject_type="Manager", + relation="worksFor", + object="Acme Corp", object_type="Company", + ) + assert converter.convert_relationship(rel) is not None + + def test_subclass_matches_range(self, converter): + """manages range is Employee; Manager is subclass of Employee.""" + rel = Relationship( + subject="Carol", subject_type="Manager", + relation="manages", + object="Dave", object_type="Manager", + ) + assert converter.convert_relationship(rel) is not None + + def test_superclass_does_not_match_subclass_constraint(self, converter): + """manages domain is Manager; Person is NOT a subclass of Manager.""" + rel = Relationship( + subject="Alice", subject_type="Person", + relation="manages", + object="Bob", object_type="Employee", + ) + assert converter.convert_relationship(rel) is None + + +# --------------------------------------------------------------------------- +# Polymorphic properties (no domain/range) +# --------------------------------------------------------------------------- + +class TestPolymorphicProperties: + + def test_no_domain_no_range_allows_anything(self, converter): + rel = Relationship( + subject="Alice", subject_type="Person", + relation="relatedTo", + object="Acme Corp", object_type="Company", + ) + assert converter.convert_relationship(rel) is not None + + def test_polymorphic_with_unrelated_types(self, converter): + rel = Relationship( + subject="Widget", subject_type="Product", + relation="relatedTo", + object="Bob", object_type="Employee", + ) + assert converter.convert_relationship(rel) is not None + + +# --------------------------------------------------------------------------- +# Datatype property domain enforcement +# --------------------------------------------------------------------------- + +class TestAttributeDomainValidation: + + def test_valid_domain(self, converter): + attr = Attribute( + entity="Bob", entity_type="Employee", + attribute="employeeId", value="E-1234", + ) + assert converter.convert_attribute(attr) is not None + + def test_subclass_matches_domain(self, converter): + """Manager is subclass of Employee; employeeId domain is Employee.""" + attr = Attribute( + entity="Carol", entity_type="Manager", + attribute="employeeId", value="M-5678", + ) + assert converter.convert_attribute(attr) is not None + + def test_domain_violation_rejected(self, converter): + attr = Attribute( + entity="Acme Corp", entity_type="Company", + attribute="employeeId", value="E-0000", + ) + assert converter.convert_attribute(attr) is None + + def test_no_domain_allows_anything(self, converter): + attr = Attribute( + entity="Widget", entity_type="Product", + attribute="description", value="A useful widget", + ) + assert converter.convert_attribute(attr) is not None + + +# --------------------------------------------------------------------------- +# OntologySelector bypass for small ontologies (#908) +# --------------------------------------------------------------------------- + +def _make_ontology(n_classes, n_obj_props=0, n_dt_props=0): + classes = { + f"C{i}": OntologyClass(uri=f"http://example.org/C{i}") + for i in range(n_classes) + } + obj_props = { + f"op{i}": OntologyProperty( + uri=f"http://example.org/op{i}", type="owl:ObjectProperty" + ) + for i in range(n_obj_props) + } + dt_props = { + f"dp{i}": OntologyProperty( + uri=f"http://example.org/dp{i}", type="owl:DatatypeProperty" + ) + for i in range(n_dt_props) + } + return Ontology( + id="tiny", + metadata={"name": "Tiny"}, + classes=classes, + object_properties=obj_props, + datatype_properties=dt_props, + ) + + +def _make_loader(ontology): + loader = Mock() + loader.get_ontology.return_value = ontology + loader.get_all_ontologies.return_value = {"tiny": ontology} + return loader + + +class TestBypassSelectorBelow: + + async def test_bypass_returns_full_ontology(self): + """With 3 elements and bypass_selector_below=5, selector is bypassed.""" + ont = _make_ontology(2, 1, 0) + loader = _make_loader(ont) + embedder = Mock() + + selector = OntologySelector( + ontology_embedder=embedder, + ontology_loader=loader, + bypass_selector_below=5, + ) + + segments = [TextSegment(text="some text", type="sentence", position=0)] + subsets = await selector.select_ontology_subset(segments) + + assert len(subsets) == 1 + assert subsets[0].ontology_id == "tiny" + assert len(subsets[0].classes) == 2 + assert len(subsets[0].object_properties) == 1 + assert subsets[0].relevance_score == 1.0 + # Embedder should never be called + embedder.embed_text.assert_not_called() + + async def test_no_bypass_when_above_threshold(self): + """With 10 elements and bypass_selector_below=5, selector runs normally.""" + ont = _make_ontology(6, 3, 1) + loader = _make_loader(ont) + + embedder = Mock() + embedder.embed_text = AsyncMock(return_value=[0.1, 0.2]) + vector_store = Mock() + vector_store.size.return_value = 10 + vector_store.search.return_value = [] + embedder.get_vector_store.return_value = vector_store + + selector = OntologySelector( + ontology_embedder=embedder, + ontology_loader=loader, + bypass_selector_below=5, + ) + + segments = [TextSegment(text="some text", type="sentence", position=0)] + subsets = await selector.select_ontology_subset(segments) + + # Vector store was consulted (selector ran normally) + vector_store.size.assert_called_once() + + async def test_bypass_at_exact_threshold_not_triggered(self): + """With exactly 5 elements and bypass_selector_below=5, selector runs (< not <=).""" + ont = _make_ontology(3, 1, 1) # total = 5 + loader = _make_loader(ont) + + embedder = Mock() + embedder.embed_text = AsyncMock(return_value=[0.1, 0.2]) + vector_store = Mock() + vector_store.size.return_value = 5 + vector_store.search.return_value = [] + embedder.get_vector_store.return_value = vector_store + + selector = OntologySelector( + ontology_embedder=embedder, + ontology_loader=loader, + bypass_selector_below=5, + ) + + segments = [TextSegment(text="some text", type="sentence", position=0)] + subsets = await selector.select_ontology_subset(segments) + + # Should NOT bypass — 5 is not < 5 + vector_store.size.assert_called_once() + + async def test_bypass_zero_disables(self): + """bypass_selector_below=0 means bypass never triggers.""" + ont = _make_ontology(0, 0, 0) # empty ontology + loader = _make_loader(ont) + + embedder = Mock() + embedder.embed_text = AsyncMock(return_value=[0.1]) + vector_store = Mock() + vector_store.size.return_value = 0 + vector_store.search.return_value = [] + embedder.get_vector_store.return_value = vector_store + + selector = OntologySelector( + ontology_embedder=embedder, + ontology_loader=loader, + bypass_selector_below=0, + ) + + segments = [TextSegment(text="some text", type="sentence", position=0)] + subsets = await selector.select_ontology_subset(segments) + + # 0 is not < 0, so bypass doesn't trigger + vector_store.size.assert_called_once() diff --git a/tests/unit/test_gateway/test_auth.py b/tests/unit/test_gateway/test_auth.py index 26e93fd9..8ffcafa1 100644 --- a/tests/unit/test_gateway/test_auth.py +++ b/tests/unit/test_gateway/test_auth.py @@ -165,22 +165,37 @@ class TestIamAuthDispatch: by shape of the bearer.""" @pytest.mark.asyncio - async def test_no_authorization_header_raises_401(self): + async def test_no_authorization_header_tries_anonymous(self): auth = IamAuth(backend=Mock()) - with pytest.raises(web.HTTPUnauthorized): - await auth.authenticate(make_request(None)) + + async def fake_with_client(op): + raise RuntimeError("auth-failed: anonymous access not permitted") + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request(None)) @pytest.mark.asyncio - async def test_non_bearer_header_raises_401(self): + async def test_non_bearer_header_tries_anonymous(self): auth = IamAuth(backend=Mock()) - with pytest.raises(web.HTTPUnauthorized): - await auth.authenticate(make_request("Basic whatever")) + + async def fake_with_client(op): + raise RuntimeError("auth-failed: anonymous access not permitted") + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request("Basic whatever")) @pytest.mark.asyncio - async def test_empty_bearer_raises_401(self): + async def test_empty_bearer_tries_anonymous(self): auth = IamAuth(backend=Mock()) - with pytest.raises(web.HTTPUnauthorized): - await auth.authenticate(make_request("Bearer ")) + + async def fake_with_client(op): + raise RuntimeError("auth-failed: anonymous access not permitted") + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request("Bearer ")) @pytest.mark.asyncio async def test_unknown_format_raises_401(self): @@ -445,3 +460,121 @@ class TestAuthorise: # Different resource → different cache key → two IAM calls. assert calls["n"] == 2 + + +# -- Anonymous authentication boundary ------------------------------------ + + +class TestAnonymousAuthBoundary: + """The gateway must only attempt anonymous auth when no credential + is presented. A malformed token must NOT fall through to the + anonymous path — that would let an attacker bypass a broken token + by simply sending garbage.""" + + @pytest.mark.asyncio + async def test_no_header_attempts_anonymous(self): + auth = IamAuth(backend=Mock()) + + async def fake_with_client(op): + return await op(Mock( + authenticate_anonymous=AsyncMock( + return_value=("anon", "default", ["reader"]), + ) + )) + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + ident = await auth.authenticate(make_request(None)) + assert ident.handle == "anon" + assert ident.source == "anonymous" + + @pytest.mark.asyncio + async def test_empty_bearer_attempts_anonymous(self): + auth = IamAuth(backend=Mock()) + + async def fake_with_client(op): + return await op(Mock( + authenticate_anonymous=AsyncMock( + return_value=("anon", "default", ["reader"]), + ) + )) + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + ident = await auth.authenticate(make_request("Bearer ")) + assert ident.handle == "anon" + assert ident.source == "anonymous" + + @pytest.mark.asyncio + async def test_malformed_token_does_not_fall_through_to_anonymous(self): + auth = IamAuth(backend=Mock()) + called = {"anonymous": False} + + original = auth._authenticate_anonymous + + async def spy_anonymous(): + called["anonymous"] = True + return await original() + + auth._authenticate_anonymous = spy_anonymous + + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request("Bearer garbage")) + assert not called["anonymous"] + + @pytest.mark.asyncio + async def test_bad_api_key_does_not_fall_through_to_anonymous(self): + auth = IamAuth(backend=Mock()) + called = {"anonymous": False} + + async def spy_anonymous(): + called["anonymous"] = True + + auth._authenticate_anonymous = spy_anonymous + + async def fake_with_client(op): + raise RuntimeError("auth-failed: unknown key") + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request("Bearer tg_bad")) + assert not called["anonymous"] + + @pytest.mark.asyncio + async def test_bad_jwt_does_not_fall_through_to_anonymous(self): + auth = IamAuth(backend=Mock()) + auth._signing_public_pem = "not-a-real-pem" + called = {"anonymous": False} + + async def spy_anonymous(): + called["anonymous"] = True + + auth._authenticate_anonymous = spy_anonymous + + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request("Bearer a.b.c")) + assert not called["anonymous"] + + @pytest.mark.asyncio + async def test_anonymous_rejected_by_iam_raises_401(self): + auth = IamAuth(backend=Mock()) + + async def fake_with_client(op): + raise RuntimeError("auth-failed: anonymous access not permitted") + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request(None)) + + @pytest.mark.asyncio + async def test_anonymous_with_empty_user_id_raises_401(self): + auth = IamAuth(backend=Mock()) + + async def fake_with_client(op): + return await op(Mock( + authenticate_anonymous=AsyncMock( + return_value=("", "default", []), + ) + )) + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request(None)) diff --git a/tests/unit/test_iam/__init__.py b/tests/unit/test_iam/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/test_iam/test_iam_rejects_anonymous.py b/tests/unit/test_iam/test_iam_rejects_anonymous.py new file mode 100644 index 00000000..492b570a --- /dev/null +++ b/tests/unit/test_iam/test_iam_rejects_anonymous.py @@ -0,0 +1,44 @@ +""" +Contract test: the full iam-svc MUST reject authenticate-anonymous. + +This is a safety pin — if someone accidentally adds anonymous access +to the production IAM handler, this test catches it. +""" + +import asyncio +from unittest.mock import Mock, AsyncMock + +import pytest + +from trustgraph.iam.service.iam import IamService + + +def _make_request(**kwargs): + req = Mock() + for k, v in kwargs.items(): + setattr(req, k, v) + return req + + +class TestIamRejectsAnonymous: + + @pytest.fixture + def handler(self): + svc = object.__new__(IamService) + svc.table_store = Mock(spec=[]) + svc.bootstrap_mode = "token" + svc.bootstrap_token = "tok" + svc._on_workspace_created = None + svc._on_workspace_deleted = None + svc._signing_key = None + svc._signing_key_lock = asyncio.Lock() + return svc + + @pytest.mark.asyncio + async def test_authenticate_anonymous_returns_auth_failed(self, handler): + resp = await handler.handle( + _make_request(operation="authenticate-anonymous") + ) + assert resp.error is not None + assert resp.error.type == "auth-failed" + assert "anonymous" in resp.error.message.lower() diff --git a/tests/unit/test_iam/test_noauth_handler.py b/tests/unit/test_iam/test_noauth_handler.py new file mode 100644 index 00000000..38461b62 --- /dev/null +++ b/tests/unit/test_iam/test_noauth_handler.py @@ -0,0 +1,138 @@ +""" +Tests for the no-auth IAM handler. + +Verifies that NoAuthHandler returns the expected permissive responses +and that the always-allow authorise path returns the correct shape. +""" + +import json +from unittest.mock import Mock + +import pytest + +from trustgraph.iam.noauth.handler import NoAuthHandler + + +def _make_request(**kwargs): + req = Mock() + for k, v in kwargs.items(): + setattr(req, k, v) + return req + + +class TestAuthenticateAnonymous: + + @pytest.mark.asyncio + async def test_returns_default_identity(self): + h = NoAuthHandler( + default_user_id="anon", default_workspace="ws", + ) + resp = await h.handle( + _make_request(operation="authenticate-anonymous") + ) + assert resp.error is None + assert resp.resolved_user_id == "anon" + assert resp.resolved_workspace == "ws" + assert "admin" in list(resp.resolved_roles) + + @pytest.mark.asyncio + async def test_custom_defaults_propagate(self): + h = NoAuthHandler( + default_user_id="dev-user", default_workspace="dev-ws", + ) + resp = await h.handle( + _make_request(operation="authenticate-anonymous") + ) + assert resp.resolved_user_id == "dev-user" + assert resp.resolved_workspace == "dev-ws" + + +class TestResolveApiKey: + + @pytest.mark.asyncio + async def test_any_key_resolves_to_default_identity(self): + h = NoAuthHandler() + resp = await h.handle( + _make_request(operation="resolve-api-key", api_key="tg_bogus") + ) + assert resp.error is None + assert resp.resolved_user_id == "anonymous" + assert resp.resolved_workspace == "default" + + +class TestAuthorise: + + @pytest.mark.asyncio + async def test_always_allows(self): + h = NoAuthHandler() + resp = await h.handle( + _make_request( + operation="authorise", + user_id="anyone", + capability="anything", + resource_json="{}", + parameters_json="{}", + ) + ) + assert resp.error is None + assert resp.decision_allow is True + assert resp.decision_ttl_seconds > 0 + + @pytest.mark.asyncio + async def test_authorise_many_returns_matching_count(self): + h = NoAuthHandler() + checks = [ + {"capability": "a", "resource": {}, "parameters": {}}, + {"capability": "b", "resource": {}, "parameters": {}}, + {"capability": "c", "resource": {}, "parameters": {}}, + ] + resp = await h.handle( + _make_request( + operation="authorise-many", + user_id="u", + authorise_checks=json.dumps(checks), + ) + ) + assert resp.error is None + decisions = json.loads(resp.decisions_json) + assert len(decisions) == 3 + assert all(d["allow"] is True for d in decisions) + + +class TestCreateWorkspaceCallback: + + @pytest.mark.asyncio + async def test_create_workspace_calls_callback(self): + called_with = [] + + async def on_created(ws_id): + called_with.append(ws_id) + + h = NoAuthHandler(on_workspace_created=on_created) + req = _make_request(operation="create-workspace") + req.workspace_record = Mock() + req.workspace_record.id = "test-ws" + resp = await h.handle(req) + assert resp.error is None + assert called_with == ["test-ws"] + + @pytest.mark.asyncio + async def test_create_workspace_without_callback_still_succeeds(self): + h = NoAuthHandler() + req = _make_request(operation="create-workspace") + req.workspace_record = Mock() + req.workspace_record.id = "test-ws" + resp = await h.handle(req) + assert resp.error is None + + +class TestUnknownOperation: + + @pytest.mark.asyncio + async def test_unknown_op_returns_error(self): + h = NoAuthHandler() + resp = await h.handle( + _make_request(operation="not-a-real-op") + ) + assert resp.error is not None + assert resp.error.type == "invalid-argument" diff --git a/tests/unit/test_prompt_manager.py b/tests/unit/test_prompt_manager.py index 3e73ab9c..22b735ac 100644 --- a/tests/unit/test_prompt_manager.py +++ b/tests/unit/test_prompt_manager.py @@ -7,7 +7,7 @@ including template rendering, term merging, JSON validation, and error handling. import pytest import json -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock from trustgraph.template.prompt_manager import PromptManager, PromptConfiguration, Prompt @@ -344,6 +344,42 @@ class TestPromptManager: assert pm.terms == {} # Default empty terms assert len(pm.prompts) == 0 + def test_load_config_does_not_swallow_keyboard_interrupt(self, monkeypatch): + """KeyboardInterrupt should propagate out of config parsing.""" + pm = PromptManager() + + def interrupt(_value): + raise KeyboardInterrupt + + monkeypatch.setattr("trustgraph.template.prompt_manager.json.loads", interrupt) + + with pytest.raises(KeyboardInterrupt): + pm.load_config({"system": json.dumps("Test")}) + + @pytest.mark.asyncio + async def test_json_parse_does_not_swallow_system_exit(self): + """SystemExit should propagate out of JSON response parsing.""" + pm = PromptManager() + config = { + "system": json.dumps("Test"), + "template-index": json.dumps(["json_response"]), + "template.json_response": json.dumps({ + "prompt": "Generate JSON", + "response-type": "json" + }) + } + pm.load_config(config) + + def exit_parse(_text): + raise SystemExit(2) + + pm.parse_json = exit_parse + mock_llm = AsyncMock() + mock_llm.return_value = "{}" + + with pytest.raises(SystemExit): + await pm.invoke("json_response", {}, mock_llm) + @pytest.mark.unit class TestPromptManagerJsonl: @@ -585,4 +621,4 @@ not json at all assert len(result) == 2 assert result[0] == {"any": "structure"} - assert result[1] == {"completely": "different"} \ No newline at end of file + assert result[1] == {"completely": "different"} diff --git a/tests/unit/test_python_api_client.py b/tests/unit/test_python_api_client.py index 0b6709fb..1fea0ee6 100644 --- a/tests/unit/test_python_api_client.py +++ b/tests/unit/test_python_api_client.py @@ -8,6 +8,7 @@ import pytest from unittest.mock import Mock, patch, MagicMock, call import json +from trustgraph.api.socket_client import SocketClient from trustgraph.api import ( Api, Triple, @@ -222,6 +223,82 @@ class TestSocketClient: for method in expected_methods: assert hasattr(flow_instance, method), f"Missing method: {method}" + def test_socket_client_close_does_not_swallow_base_exceptions(self): + """Test close cleanup does not suppress process-level interrupts.""" + + class InterruptingLoop: + def is_closed(self): + return False + + def run_until_complete(self, awaitable): + if hasattr(awaitable, "close"): + awaitable.close() + raise SystemExit("stop") + + socket = SocketClient(url="http://test/", timeout=60, token=None) + socket._loop = InterruptingLoop() + + with pytest.raises(SystemExit): + socket.close() + + @pytest.mark.parametrize( + ("generator_method", "async_method"), + [ + ("_streaming_generator", "_send_request_async_streaming"), + ("_streaming_generator_raw", "_send_request_async_streaming_raw"), + ], + ) + def test_socket_client_streaming_cleanup_does_not_swallow_base_exceptions( + self, generator_method, async_method + ): + """Test streaming cleanup does not suppress process-level interrupts.""" + + class FakeAsyncGenerator: + def __anext__(self): + return "next" + + def aclose(self): + return "close" + + class InterruptingLoop: + def run_until_complete(self, awaitable): + if awaitable == "next": + raise StopAsyncIteration + if awaitable == "close": + raise SystemExit("stop") + raise AssertionError(f"unexpected awaitable: {awaitable!r}") + + socket = SocketClient(url="http://test/", timeout=60, token=None) + setattr(socket, async_method, lambda *args, **kwargs: FakeAsyncGenerator()) + generator = getattr(socket, generator_method)( + "agent", "default", {}, InterruptingLoop() + ) + + with pytest.raises(SystemExit): + next(generator) + + @pytest.mark.asyncio + async def test_socket_client_reader_does_not_swallow_base_exceptions(self): + """Test reader error fanout does not suppress process-level interrupts.""" + + class FailingSocket: + def __aiter__(self): + return self + + async def __anext__(self): + raise ValueError("reader failed") + + class InterruptingQueue: + async def put(self, message): + raise SystemExit("stop") + + socket = SocketClient(url="http://test/", timeout=60, token=None) + socket._socket = FailingSocket() + socket._pending = {"req-1": InterruptingQueue()} + + with pytest.raises(SystemExit): + await socket._reader() + class TestBulkClient: """Test bulk operations client""" diff --git a/tests/unit/test_query/test_ontology_monitoring.py b/tests/unit/test_query/test_ontology_monitoring.py new file mode 100644 index 00000000..4b1b4253 --- /dev/null +++ b/tests/unit/test_query/test_ontology_monitoring.py @@ -0,0 +1,56 @@ +""" +Tests for ontology monitoring metrics. +""" + +from trustgraph.query.ontology.monitoring import ( + PerformanceMonitor, + _extract_metric_label, +) + + +def test_extract_metric_label_reads_unquoted_label_value(): + metric_name = "cache_requests_total{cache_type=entity,component=ontology}" + + assert _extract_metric_label(metric_name, "cache_type") == "entity" + + +def test_extract_metric_label_reads_quoted_label_value(): + metric_name = 'cache_requests_total{cache_type="entity",component="ontology"}' + + assert _extract_metric_label(metric_name, "cache_type") == "entity" + + +def test_extract_metric_label_returns_none_when_label_missing(): + metric_name = "cache_requests_total{component=ontology}" + + assert _extract_metric_label(metric_name, "cache_type") is None + + +def test_performance_report_ignores_counters_without_cache_type_label(): + monitor = PerformanceMonitor({"enabled": False}) + monitor.metrics_collector.increment( + "cache_requests_total", + labels={"component": "ontology"}, + ) + monitor.metrics_collector.increment( + "cache_type=not_a_label", + labels={"component": "ontology"}, + ) + monitor.metrics_collector.increment( + "cache_requests_total", + labels={"cache_type": "entity"}, + ) + monitor.metrics_collector.increment( + "cache_hits_total", + labels={"cache_type": "entity"}, + ) + + report = monitor.get_performance_report() + + assert report["cache_performance"] == { + "entity": { + "hit_rate": 1.0, + "total_requests": 1.0, + "total_hits": 1.0, + } + } diff --git a/tests/unit/test_query/test_rows_cassandra_query.py b/tests/unit/test_query/test_rows_cassandra_query.py index bb6bbe84..b61500a4 100644 --- a/tests/unit/test_query/test_rows_cassandra_query.py +++ b/tests/unit/test_query/test_rows_cassandra_query.py @@ -89,12 +89,15 @@ class TestRowsGraphQLQueryLogic: @pytest.mark.asyncio async def test_schema_config_parsing(self): """Test parsing of schema configuration""" + import asyncio processor = MagicMock() processor.schemas = {} processor.schema_builders = {} processor.graphql_schemas = {} processor.config_key = "schema" processor.query_cassandra = MagicMock() + processor._setup_lock = asyncio.Lock() + processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor) processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) # Create test config @@ -335,7 +338,7 @@ class TestUnifiedTableQueries: """Test query execution with matching index""" processor = MagicMock() processor.session = MagicMock() - processor.connect_cassandra = MagicMock() + processor.connect_cassandra = AsyncMock() processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor) @@ -396,7 +399,7 @@ class TestUnifiedTableQueries: """Test query execution without matching index (scan mode)""" processor = MagicMock() processor.session = MagicMock() - processor.connect_cassandra = MagicMock() + processor.connect_cassandra = AsyncMock() processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor) diff --git a/tests/unit/test_query/test_sparql_algebra.py b/tests/unit/test_query/test_sparql_algebra.py new file mode 100644 index 00000000..d2a49e99 --- /dev/null +++ b/tests/unit/test_query/test_sparql_algebra.py @@ -0,0 +1,580 @@ +""" +Tests for the SPARQL algebra evaluator. + +Verifies that evaluate() and _query_pattern() call TriplesClient.query() +with the correct arguments, and in particular that workspace is never +passed — workspace isolation is handled by pub/sub topic routing. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, call + +from rdflib.term import Variable, URIRef, Literal +from rdflib.plugins.sparql.parserutils import CompValue + +from trustgraph.schema import Term, IRI, LITERAL +from trustgraph.query.sparql.algebra import ( + evaluate, materialise, _query_pattern, _eval_bgp, +) + + +# --- Helpers --- + +def iri(v): + return Term(type=IRI, iri=v) + + +def lit(v): + return Term(type=LITERAL, value=v) + + +def make_tc(query_return=None, query_side_effect=None): + """Create a mock TriplesClient with both query() and query_gen() support.""" + tc = AsyncMock() + + if query_side_effect is not None: + tc.query.side_effect = query_side_effect + + async def gen_side_effect(**kwargs): + results = await query_side_effect(**kwargs) + for r in results: + yield r + + tc.query_gen = gen_side_effect + else: + items = query_return or [] + tc.query.return_value = items + + async def gen(**kwargs): + for item in items: + yield item + + tc.query_gen = gen + + return tc + + +def make_triple(s, p, o): + t = MagicMock() + t.s = s + t.p = p + t.o = o + return t + + +def make_bgp(*patterns): + """Build a CompValue BGP node from (s, p, o) tuples of rdflib terms.""" + node = CompValue("BGP") + node.triples = list(patterns) + return node + + +def make_project(inner, variables): + node = CompValue("Project") + node.p = inner + node.PV = [Variable(v) for v in variables] + return node + + +def make_select(inner): + node = CompValue("SelectQuery") + node.p = inner + return node + + +def make_join(left, right): + node = CompValue("Join") + node.p1 = left + node.p2 = right + return node + + +def make_union(left, right): + node = CompValue("Union") + node.p1 = left + node.p2 = right + return node + + +def make_slice(inner, start, length): + node = CompValue("Slice") + node.p = inner + node.start = start + node.length = length + return node + + +def make_distinct(inner): + node = CompValue("Distinct") + node.p = inner + return node + + +def make_filter(inner, expr): + node = CompValue("Filter") + node.p = inner + node.expr = expr + return node + + +def make_minus(left, right): + node = CompValue("Minus") + node.p1 = left + node.p2 = right + return node + + +class TestQueryPattern: + """Tests for _query_pattern — the leaf that calls TriplesClient.""" + + @pytest.mark.asyncio + async def test_passes_correct_args(self): + tc = AsyncMock() + tc.query.return_value = [] + + await _query_pattern( + tc, + s=iri("http://example.com/s"), + p=iri("http://example.com/p"), + o=None, + collection="my-collection", + limit=100, + ) + + tc.query.assert_called_once_with( + s=iri("http://example.com/s"), + p=iri("http://example.com/p"), + o=None, + limit=100, + collection="my-collection", + ) + + @pytest.mark.asyncio + async def test_workspace_not_passed(self): + tc = AsyncMock() + tc.query.return_value = [] + + await _query_pattern(tc, None, None, None, "default", 10) + + kwargs = tc.query.call_args.kwargs + assert "workspace" not in kwargs + + @pytest.mark.asyncio + async def test_returns_query_results(self): + tc = AsyncMock() + triple = make_triple(iri("http://a"), iri("http://b"), lit("c")) + tc.query.return_value = [triple] + + results = await _query_pattern(tc, None, None, None, "default", 10) + + assert len(results) == 1 + assert results[0].s.iri == "http://a" + + +class TestEvalBgp: + """Tests for BGP evaluation — triple pattern queries.""" + + @pytest.mark.asyncio + async def test_single_pattern_all_variables(self): + triple = make_triple(iri("http://s"), iri("http://p"), lit("o")) + tc = make_tc(query_return=[triple]) + + bgp = make_bgp( + (Variable("s"), Variable("p"), Variable("o")), + ) + + solutions = await materialise(bgp, tc, collection="default", limit=100) + + assert len(solutions) == 1 + assert solutions[0]["s"].iri == "http://s" + assert solutions[0]["p"].iri == "http://p" + assert solutions[0]["o"].value == "o" + + @pytest.mark.asyncio + async def test_single_pattern_bound_subject(self): + tc = make_tc(query_return=[ + make_triple(iri("http://s"), iri("http://p"), lit("val")), + ]) + + bgp = make_bgp( + (URIRef("http://s"), Variable("p"), Variable("o")), + ) + + solutions = await materialise(bgp, tc, collection="default") + + assert len(solutions) == 1 + + @pytest.mark.asyncio + async def test_empty_bgp_returns_empty_solution(self): + tc = make_tc() + + bgp = make_bgp() + + solutions = await materialise(bgp, tc, collection="default") + + assert solutions == [{}] + + @pytest.mark.asyncio + async def test_no_results_returns_empty(self): + tc = make_tc(query_return=[]) + + bgp = make_bgp( + (Variable("s"), Variable("p"), Variable("o")), + ) + + solutions = await materialise(bgp, tc, collection="default") + + assert solutions == [] + + +class TestEvaluate: + """Tests for the top-level evaluate() dispatcher.""" + + @pytest.mark.asyncio + async def test_select_query_node(self): + tc = make_tc(query_return=[ + make_triple(iri("http://s"), iri("http://p"), lit("o")), + ]) + + bgp = make_bgp( + (Variable("s"), Variable("p"), Variable("o")), + ) + select = make_select(make_project(bgp, ["s", "p"])) + + solutions = await materialise(select, tc, collection="default") + + assert len(solutions) == 1 + assert "s" in solutions[0] + assert "p" in solutions[0] + assert "o" not in solutions[0] + + @pytest.mark.asyncio + async def test_workspace_never_in_query_calls(self): + """Verify that no matter the algebra structure, workspace is never + passed to TriplesClient.query().""" + tc = make_tc(query_return=[ + make_triple(iri("http://s"), iri("http://p"), lit("o")), + ]) + + bgp1 = make_bgp((Variable("s"), Variable("p"), Variable("o"))) + bgp2 = make_bgp((Variable("a"), Variable("b"), Variable("c"))) + tree = make_select(make_project( + make_union(bgp1, bgp2), ["s", "p", "o"] + )) + + await materialise(tree, tc, collection="test-coll") + + @pytest.mark.asyncio + async def test_join(self): + call_count = 0 + + async def mock_query(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return [make_triple(iri("http://a"), iri("http://p"), lit("v"))] + else: + return [make_triple(iri("http://a"), iri("http://q"), lit("w"))] + + tc = make_tc(query_side_effect=mock_query) + + bgp1 = make_bgp((Variable("s"), URIRef("http://p"), Variable("v1"))) + bgp2 = make_bgp((Variable("s"), URIRef("http://q"), Variable("v2"))) + tree = make_join(bgp1, bgp2) + + solutions = await materialise(tree, tc, collection="default") + + assert len(solutions) == 1 + assert solutions[0]["s"].iri == "http://a" + + @pytest.mark.asyncio + async def test_slice(self): + triples = [ + make_triple(iri(f"http://s{i}"), iri("http://p"), lit(f"o{i}")) + for i in range(5) + ] + tc = make_tc(query_return=triples) + + bgp = make_bgp((Variable("s"), Variable("p"), Variable("o"))) + tree = make_slice(bgp, start=1, length=2) + + solutions = await materialise(tree, tc, collection="default") + + assert len(solutions) == 2 + + @pytest.mark.asyncio + async def test_distinct(self): + triple = make_triple(iri("http://s"), iri("http://p"), lit("o")) + tc = make_tc(query_return=[triple, triple]) + + bgp = make_bgp((Variable("s"), Variable("p"), Variable("o"))) + tree = make_distinct(bgp) + + solutions = await materialise(tree, tc, collection="default") + + assert len(solutions) == 1 + + @pytest.mark.asyncio + async def test_minus_removes_matching(self): + alice = iri("http://example.com/alice") + bob = iri("http://example.com/bob") + knows = iri("http://example.com/knows") + hates = iri("http://example.com/hates") + charlie = iri("http://example.com/charlie") + + left_triple = make_triple(alice, knows, bob) + right_triple2 = make_triple(alice, hates, charlie) + + async def mock_query(**kwargs): + pred = kwargs.get("p") + if pred and pred.iri == "http://example.com/knows": + return [left_triple] + elif pred and pred.iri == "http://example.com/hates": + return [right_triple2] + return [] + + tc = make_tc(query_side_effect=mock_query) + + left_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) + ) + right_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/hates"), Variable("r")) + ) + + tree = make_select( + make_project( + make_minus(left_bgp, right_bgp), + ["s", "o"] + ) + ) + + solutions = await materialise(tree, tc, collection="default") + + assert len(solutions) == 0 + + @pytest.mark.asyncio + async def test_minus_no_shared_vars_preserves_all(self): + alice = iri("http://example.com/alice") + bob = iri("http://example.com/bob") + + left_triple = make_triple(alice, iri("http://example.com/p"), bob) + + async def mock_query(**kwargs): + pred = kwargs.get("p") + if pred and pred.iri == "http://example.com/p": + return [left_triple] + return [] + + tc = make_tc(query_side_effect=mock_query) + + left_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/p"), Variable("o")) + ) + right_bgp = make_bgp( + (Variable("x"), URIRef("http://example.com/q"), Variable("y")) + ) + + tree = make_select( + make_project( + make_minus(left_bgp, right_bgp), + ["s", "o"] + ) + ) + + solutions = await materialise(tree, tc, collection="default") + + assert len(solutions) == 1 + + @pytest.mark.asyncio + async def test_filter_exists_keeps_matching(self): + alice = iri("http://example.com/alice") + bob = iri("http://example.com/bob") + charlie = iri("http://example.com/charlie") + + left_triple1 = make_triple(alice, iri("http://example.com/knows"), bob) + left_triple2 = make_triple(alice, iri("http://example.com/knows"), charlie) + exists_triple = make_triple(bob, iri("http://example.com/likes"), alice) + + async def mock_query(**kwargs): + pred = kwargs.get("p") + if pred and pred.iri == "http://example.com/knows": + return [left_triple1, left_triple2] + elif pred and pred.iri == "http://example.com/likes": + return [exists_triple] + return [] + + tc = make_tc(query_side_effect=mock_query) + + left_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) + ) + exists_bgp = make_bgp( + (Variable("o"), URIRef("http://example.com/likes"), Variable("_any")) + ) + + exists_expr = CompValue("Builtin_EXISTS") + exists_expr.graph = exists_bgp + + tree = make_select( + make_project( + make_filter(left_bgp, exists_expr), + ["s", "o"] + ) + ) + + solutions = await materialise(tree, tc, collection="default") + + result_objects = [s["o"].iri for s in solutions] + assert "http://example.com/bob" in result_objects + assert "http://example.com/charlie" not in result_objects + + @pytest.mark.asyncio + async def test_filter_not_exists_removes_matching(self): + alice = iri("http://example.com/alice") + bob = iri("http://example.com/bob") + charlie = iri("http://example.com/charlie") + + left_triple1 = make_triple(alice, iri("http://example.com/knows"), bob) + left_triple2 = make_triple(alice, iri("http://example.com/knows"), charlie) + exists_triple = make_triple(bob, iri("http://example.com/likes"), alice) + + async def mock_query(**kwargs): + pred = kwargs.get("p") + if pred and pred.iri == "http://example.com/knows": + return [left_triple1, left_triple2] + elif pred and pred.iri == "http://example.com/likes": + return [exists_triple] + return [] + + tc = make_tc(query_side_effect=mock_query) + + left_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) + ) + exists_bgp = make_bgp( + (Variable("o"), URIRef("http://example.com/likes"), Variable("_any")) + ) + + not_exists_expr = CompValue("Builtin_NOTEXISTS") + not_exists_expr.graph = exists_bgp + + tree = make_select( + make_project( + make_filter(left_bgp, not_exists_expr), + ["s", "o"] + ) + ) + + solutions = await materialise(tree, tc, collection="default") + + result_objects = [s["o"].iri for s in solutions] + assert "http://example.com/charlie" in result_objects + assert "http://example.com/bob" not in result_objects + + @pytest.mark.asyncio + async def test_join_values_uses_bind_join(self): + """When VALUES is joined with a BGP, the bind join should pass + the VALUES bindings into the BGP evaluation so the triple store + query is selective (not a wildcard).""" + alice = iri("http://example.com/alice") + bob = iri("http://example.com/bob") + knows = iri("http://example.com/knows") + + queries_issued = [] + + async def mock_query(**kwargs): + queries_issued.append(kwargs) + s, p = kwargs.get("s"), kwargs.get("p") + if s and s.iri == "http://example.com/alice" and p and p.iri == "http://example.com/knows": + return [make_triple(alice, knows, bob)] + return [] + + tc = make_tc(query_side_effect=mock_query) + + # VALUES ?s { } + values_node = CompValue("values") + values_node.var = [Variable("s")] + values_node.value = [[URIRef("http://example.com/alice")]] + values_node.res = None + + to_multiset = CompValue("ToMultiSet") + to_multiset.p = values_node + + bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")), + ) + + tree = make_join(to_multiset, bgp) + solutions = await materialise(tree, tc, collection="default") + + assert len(solutions) == 1 + assert solutions[0]["s"].iri == "http://example.com/alice" + assert solutions[0]["o"].iri == "http://example.com/bob" + + # The key assertion: the BGP query should have received + # s=alice (bound from VALUES), NOT s=None (wildcard) + assert len(queries_issued) == 1 + assert queries_issued[0]["s"] is not None + assert queries_issued[0]["s"].iri == "http://example.com/alice" + + @pytest.mark.asyncio + async def test_join_values_multiple_bindings(self): + """Bind join with multiple VALUES bindings.""" + alice = iri("http://example.com/alice") + bob = iri("http://example.com/bob") + knows = iri("http://example.com/knows") + charlie = iri("http://example.com/charlie") + + async def mock_query(**kwargs): + s = kwargs.get("s") + if s and s.iri == "http://example.com/alice": + return [make_triple(alice, knows, bob)] + elif s and s.iri == "http://example.com/bob": + return [make_triple(bob, knows, charlie)] + return [] + + tc = make_tc(query_side_effect=mock_query) + + values_node = CompValue("values") + values_node.var = [Variable("s")] + values_node.value = [ + [URIRef("http://example.com/alice")], + [URIRef("http://example.com/bob")], + ] + values_node.res = None + + to_multiset = CompValue("ToMultiSet") + to_multiset.p = values_node + + bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")), + ) + + tree = make_join(to_multiset, bgp) + solutions = await materialise(tree, tc, collection="default") + + assert len(solutions) == 2 + subjects = {s["s"].iri for s in solutions} + assert subjects == { + "http://example.com/alice", + "http://example.com/bob", + } + + @pytest.mark.asyncio + async def test_unsupported_node_returns_empty_solution(self): + tc = make_tc() + + node = CompValue("SomethingUnknown") + + solutions = await materialise(node, tc, collection="default") + + assert solutions == [{}] + + @pytest.mark.asyncio + async def test_non_compvalue_returns_empty_solution(self): + tc = make_tc() + + solutions = await materialise("not a node", tc, collection="default") + + assert solutions == [{}] diff --git a/tests/unit/test_query/test_sparql_expressions.py b/tests/unit/test_query/test_sparql_expressions.py index 63e9188f..87c862e8 100644 --- a/tests/unit/test_query/test_sparql_expressions.py +++ b/tests/unit/test_query/test_sparql_expressions.py @@ -300,6 +300,438 @@ class TestBuiltinFunctions: flags=None) assert evaluate_expression(expr, {"x": lit("hello")}) is False + def test_substr_three_args(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("SUBSTR", + arg=Variable("x"), + start=Literal(1), + length=Literal(4)) + result = evaluate_expression(expr, {"x": lit("2024-03-15")}) + assert result.type == LITERAL + assert result.value == "2024" + + def test_substr_two_args(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("SUBSTR", + arg=Variable("x"), + start=Literal(6), + length=None) + result = evaluate_expression(expr, {"x": lit("2024-03-15")}) + assert result.type == LITERAL + assert result.value == "03-15" + + def test_substr_middle(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("SUBSTR", + arg=Variable("x"), + start=Literal(6), + length=Literal(2)) + result = evaluate_expression(expr, {"x": lit("2024-03-15")}) + assert result.type == LITERAL + assert result.value == "03" + + def test_substr_null_start(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("SUBSTR", + arg=Variable("x"), + start=Variable("missing"), + length=None) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result is None + + def test_year(self): + from rdflib.term import Variable + expr = self._make_builtin("YEAR", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15", datatype=XSD + "date")} + ) + assert result == 2024 + + def test_month(self): + from rdflib.term import Variable + expr = self._make_builtin("MONTH", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15", datatype=XSD + "date")} + ) + assert result == 3 + + def test_day(self): + from rdflib.term import Variable + expr = self._make_builtin("DAY", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15", datatype=XSD + "date")} + ) + assert result == 15 + + def test_hours(self): + from rdflib.term import Variable + expr = self._make_builtin("HOURS", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")} + ) + assert result == 10 + + def test_minutes(self): + from rdflib.term import Variable + expr = self._make_builtin("MINUTES", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")} + ) + assert result == 30 + + def test_seconds(self): + from rdflib.term import Variable + expr = self._make_builtin("SECONDS", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")} + ) + assert result == 45 + + def test_year_from_datetime(self): + from rdflib.term import Variable + expr = self._make_builtin("YEAR", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")} + ) + assert result == 2024 + + def test_hours_from_date_returns_zero(self): + from rdflib.term import Variable + expr = self._make_builtin("HOURS", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15", datatype=XSD + "date")} + ) + assert result == 0 + + def test_year_invalid_date(self): + from rdflib.term import Variable + expr = self._make_builtin("YEAR", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("not-a-date")} + ) + assert result is None + + def test_floor(self): + from rdflib.term import Variable + expr = self._make_builtin("FLOOR", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("3.7")}) == 3 + + def test_floor_negative(self): + from rdflib.term import Variable + expr = self._make_builtin("FLOOR", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("-2.3")}) == -3 + + def test_floor_none(self): + from rdflib.term import Variable + expr = self._make_builtin("FLOOR", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("abc")}) is None + + def test_ceil(self): + from rdflib.term import Variable + expr = self._make_builtin("CEIL", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("3.2")}) == 4 + + def test_ceil_negative(self): + from rdflib.term import Variable + expr = self._make_builtin("CEIL", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("-2.7")}) == -2 + + def test_abs_positive(self): + from rdflib.term import Variable + expr = self._make_builtin("ABS", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("42")}) == 42 + + def test_abs_negative(self): + from rdflib.term import Variable + expr = self._make_builtin("ABS", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("-42")}) == 42 + + def test_abs_none(self): + from rdflib.term import Variable + expr = self._make_builtin("ABS", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("abc")}) is None + + def test_replace_simple(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("REPLACE", + arg=Variable("x"), + pattern=Literal(" BC"), + replacement=Literal(""), + flags=None) + result = evaluate_expression(expr, {"x": lit("500 BC")}) + assert result.type == LITERAL + assert result.value == "500" + + def test_replace_regex(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("REPLACE", + arg=Variable("x"), + pattern=Literal("[0-9]+"), + replacement=Literal("X"), + flags=None) + result = evaluate_expression(expr, {"x": lit("abc123def456")}) + assert result.value == "abcXdefX" + + def test_replace_case_insensitive(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("REPLACE", + arg=Variable("x"), + pattern=Literal("hello"), + replacement=Literal("world"), + flags=Literal("i")) + result = evaluate_expression(expr, {"x": lit("HELLO there")}) + assert result.value == "world there" + + def test_round_up(self): + from rdflib.term import Variable + expr = self._make_builtin("ROUND", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("3.7")}) == 4 + + def test_round_down(self): + from rdflib.term import Variable + expr = self._make_builtin("ROUND", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("3.2")}) == 3 + + def test_round_none(self): + from rdflib.term import Variable + expr = self._make_builtin("ROUND", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("abc")}) is None + + def test_strbefore(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("STRBEFORE", + arg1=Variable("x"), arg2=Literal("-")) + result = evaluate_expression(expr, {"x": lit("2024-03-15")}) + assert result.value == "2024" + + def test_strbefore_not_found(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("STRBEFORE", + arg1=Variable("x"), arg2=Literal("/")) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result.value == "" + + def test_strafter(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("STRAFTER", + arg1=Variable("x"), arg2=Literal("-")) + result = evaluate_expression(expr, {"x": lit("2024-03-15")}) + assert result.value == "03-15" + + def test_strafter_not_found(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("STRAFTER", + arg1=Variable("x"), arg2=Literal("/")) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result.value == "" + + def test_encode_for_uri(self): + from rdflib.term import Variable + expr = self._make_builtin("ENCODE_FOR_URI", arg=Variable("x")) + result = evaluate_expression(expr, {"x": lit("hello world")}) + assert result.value == "hello%20world" + + def test_encode_for_uri_special_chars(self): + from rdflib.term import Variable + expr = self._make_builtin("ENCODE_FOR_URI", arg=Variable("x")) + result = evaluate_expression(expr, {"x": lit("a/b?c=d&e")}) + assert result.value == "a%2Fb%3Fc%3Dd%26e" + + def test_langmatches_basic(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("LANGMATCHES", + arg1=Literal("en"), arg2=Literal("en")) + assert evaluate_expression(expr, {}) is True + + def test_langmatches_subtag(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("LANGMATCHES", + arg1=Literal("en-US"), arg2=Literal("en")) + assert evaluate_expression(expr, {}) is True + + def test_langmatches_wildcard(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("LANGMATCHES", + arg1=Literal("fr"), arg2=Literal("*")) + assert evaluate_expression(expr, {}) is True + + def test_langmatches_wildcard_empty(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("LANGMATCHES", + arg1=Literal(""), arg2=Literal("*")) + assert evaluate_expression(expr, {}) is False + + def test_langmatches_no_match(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("LANGMATCHES", + arg1=Literal("fr"), arg2=Literal("en")) + assert evaluate_expression(expr, {}) is False + + def test_iri_constructor(self): + from rdflib.term import Variable + expr = self._make_builtin("IRI", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("http://example.com/test")} + ) + assert result.type == IRI + assert result.iri == "http://example.com/test" + + def test_uri_constructor(self): + from rdflib.term import Variable + expr = self._make_builtin("URI", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("http://example.com/test")} + ) + assert result.type == IRI + assert result.iri == "http://example.com/test" + + def test_bnode_no_arg(self): + expr = self._make_builtin("BNODE") + result = evaluate_expression(expr, {}) + assert result.type == BLANK + assert len(result.id) > 0 + + def test_bnode_with_label(self): + from rdflib import Literal + expr = self._make_builtin("BNODE", arg=Literal("mynode")) + result = evaluate_expression(expr, {}) + assert result.type == BLANK + assert result.id == "mynode" + + def test_now(self): + import re as re_mod + expr = self._make_builtin("NOW") + result = evaluate_expression(expr, {}) + assert result.type == LITERAL + assert result.datatype == XSD + "dateTime" + assert re_mod.match(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}", result.value) + + def test_tz_with_utc(self): + from rdflib.term import Variable + expr = self._make_builtin("TZ", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15T10:30:45+0000", + datatype=XSD + "dateTime")} + ) + assert result.type == LITERAL + assert result.value == "+00:00" + + def test_tz_no_timezone(self): + from rdflib.term import Variable + expr = self._make_builtin("TZ", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15T10:30:45", + datatype=XSD + "dateTime")} + ) + assert result.value == "" + + def test_rand(self): + expr = self._make_builtin("RAND") + result = evaluate_expression(expr, {}) + assert isinstance(result, float) + assert 0.0 <= result < 1.0 + + def test_uuid(self): + import re as re_mod + expr = self._make_builtin("UUID") + result = evaluate_expression(expr, {}) + assert result.type == IRI + assert result.iri.startswith("urn:uuid:") + uuid_part = result.iri[len("urn:uuid:"):] + assert re_mod.match( + r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}", + uuid_part + ) + + def test_struuid(self): + import re as re_mod + expr = self._make_builtin("STRUUID") + result = evaluate_expression(expr, {}) + assert result.type == LITERAL + assert re_mod.match( + r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}", + result.value + ) + + def test_md5(self): + from rdflib.term import Variable + expr = self._make_builtin("MD5", arg=Variable("x")) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result.type == LITERAL + assert result.value == "5d41402abc4b2a76b9719d911017c592" + + def test_sha1(self): + from rdflib.term import Variable + expr = self._make_builtin("SHA1", arg=Variable("x")) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result.type == LITERAL + assert result.value == "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d" + + def test_sha256(self): + from rdflib.term import Variable + expr = self._make_builtin("SHA256", arg=Variable("x")) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result.type == LITERAL + assert result.value == ( + "2cf24dba5fb0a30e26e83b2ac5b9e29e" + "1b161e5c1fa7425e73043362938b9824" + ) + + def test_sha512(self): + from rdflib.term import Variable + expr = self._make_builtin("SHA512", arg=Variable("x")) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result.type == LITERAL + assert len(result.value) == 128 + + def test_exists_with_callback(self): + from rdflib.plugins.sparql.parserutils import CompValue + graph = CompValue("BGP") + expr = self._make_builtin("EXISTS", graph=graph) + cb = lambda g, s: True + result = evaluate_expression(expr, {}, exists_cb=cb) + assert result is True + + def test_exists_callback_false(self): + from rdflib.plugins.sparql.parserutils import CompValue + graph = CompValue("BGP") + expr = self._make_builtin("EXISTS", graph=graph) + cb = lambda g, s: False + result = evaluate_expression(expr, {}, exists_cb=cb) + assert result is False + + def test_notexists_with_callback(self): + from rdflib.plugins.sparql.parserutils import CompValue + graph = CompValue("BGP") + expr = self._make_builtin("NOTEXISTS", graph=graph) + cb = lambda g, s: True + result = evaluate_expression(expr, {}, exists_cb=cb) + assert result is False + + def test_notexists_callback_false(self): + from rdflib.plugins.sparql.parserutils import CompValue + graph = CompValue("BGP") + expr = self._make_builtin("NOTEXISTS", graph=graph) + cb = lambda g, s: False + result = evaluate_expression(expr, {}, exists_cb=cb) + assert result is True + class TestEffectiveBoolean: diff --git a/tests/unit/test_query/test_sparql_solutions.py b/tests/unit/test_query/test_sparql_solutions.py index 5805ca84..7588a95b 100644 --- a/tests/unit/test_query/test_sparql_solutions.py +++ b/tests/unit/test_query/test_sparql_solutions.py @@ -5,7 +5,7 @@ Tests for SPARQL solution sequence operations. import pytest from trustgraph.schema import Term, IRI, LITERAL from trustgraph.query.sparql.solutions import ( - hash_join, left_join, union, project, distinct, + hash_join, left_join, minus, union, project, distinct, order_by, slice_solutions, _terms_equal, _compatible, ) @@ -311,6 +311,30 @@ class TestOrderBy: result = order_by(solutions, []) assert len(result) == 1 + def test_order_by_numeric_literals(self): + solutions = [ + {"year": lit("1950")}, + {"year": lit("700")}, + {"year": lit("2000")}, + {"year": lit("450")}, + {"year": lit("1200")}, + ] + key_fns = [(lambda sol: sol.get("year"), True)] + result = order_by(solutions, key_fns) + values = [s["year"].value for s in result] + assert values == ["450", "700", "1200", "1950", "2000"] + + def test_order_by_numeric_descending(self): + solutions = [ + {"year": lit("1950")}, + {"year": lit("700")}, + {"year": lit("2000")}, + ] + key_fns = [(lambda sol: sol.get("year"), False)] + result = order_by(solutions, key_fns) + values = [s["year"].value for s in result] + assert values == ["2000", "1950", "700"] + class TestSlice: @@ -343,3 +367,37 @@ class TestSlice: solutions = [{"s": alice}, {"s": bob}] result = slice_solutions(solutions) assert len(result) == 2 + + +class TestMinus: + + def test_removes_compatible(self, alice, bob): + left = [{"s": alice}, {"s": bob}] + right = [{"s": alice}] + result = minus(left, right) + assert len(result) == 1 + assert result[0]["s"].iri == "http://example.com/bob" + + def test_empty_right_preserves_all(self, alice, bob): + left = [{"s": alice}, {"s": bob}] + result = minus(left, []) + assert len(result) == 2 + + def test_no_shared_variables_preserves_all(self, alice, bob): + left = [{"s": alice}] + right = [{"t": bob}] + result = minus(left, right) + assert len(result) == 1 + + def test_all_removed(self, alice): + left = [{"s": alice}] + right = [{"s": alice}] + result = minus(left, right) + assert len(result) == 0 + + def test_partial_shared_variables(self, alice, bob): + left = [{"s": alice, "p": lit("x")}, {"s": bob, "p": lit("y")}] + right = [{"s": alice}] + result = minus(left, right) + assert len(result) == 1 + assert result[0]["s"].iri == "http://example.com/bob" diff --git a/tests/unit/test_query/test_triples_cassandra_query.py b/tests/unit/test_query/test_triples_cassandra_query.py index 09681214..980fa904 100644 --- a/tests/unit/test_query/test_triples_cassandra_query.py +++ b/tests/unit/test_query/test_triples_cassandra_query.py @@ -2,8 +2,10 @@ Tests for Cassandra triples query service """ +import asyncio + import pytest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, AsyncMock from trustgraph.query.triples.cassandra.service import Processor, create_term from trustgraph.schema import Term, IRI, LITERAL @@ -18,7 +20,7 @@ class TestCassandraQueryProcessor: return Processor( taskgroup=MagicMock(), id='test-cassandra-query', - graph_host='localhost' + cassandra_host='localhost' ) def test_create_term_with_http_uri(self, processor): @@ -85,7 +87,7 @@ class TestCassandraQueryProcessor: mock_result.dtype = None mock_result.lang = None mock_result.o = 'test_object' - mock_tg_instance.get_spo.return_value = [mock_result] + mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result]) processor = Processor( taskgroup=MagicMock(), @@ -110,8 +112,8 @@ class TestCassandraQueryProcessor: keyspace='test_user' ) - # Verify get_spo was called with correct parameters - mock_tg_instance.get_spo.assert_called_once_with( + # Verify async_get_spo was called with correct parameters + mock_tg_instance.async_get_spo.assert_called_once_with( 'test_collection', 'test_subject', 'test_predicate', 'test_object', g=None, limit=100 ) @@ -130,23 +132,25 @@ class TestCassandraQueryProcessor: assert processor.cassandra_host == ['cassandra'] # Updated default assert processor.cassandra_username is None assert processor.cassandra_password is None - assert processor.table is None + assert processor._connections == {} + assert isinstance(processor._conn_lock, asyncio.Lock) def test_processor_initialization_with_custom_params(self): """Test processor initialization with custom parameters""" taskgroup_mock = MagicMock() - + processor = Processor( taskgroup=taskgroup_mock, cassandra_host='cassandra.example.com', cassandra_username='queryuser', cassandra_password='querypass' ) - + assert processor.cassandra_host == ['cassandra.example.com'] assert processor.cassandra_username == 'queryuser' assert processor.cassandra_password == 'querypass' - assert processor.table is None + assert processor._connections == {} + assert isinstance(processor._conn_lock, asyncio.Lock) @pytest.mark.asyncio @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') @@ -164,7 +168,7 @@ class TestCassandraQueryProcessor: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_sp.return_value = [mock_result] + mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -178,7 +182,7 @@ class TestCassandraQueryProcessor: result = await processor.query_triples('test_user', query) - mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50) + mock_tg_instance.async_get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50) assert len(result) == 1 assert result[0].s.iri == 'test_subject' assert result[0].p.iri == 'test_predicate' @@ -200,7 +204,7 @@ class TestCassandraQueryProcessor: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_s.return_value = [mock_result] + mock_tg_instance.async_get_s = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -214,7 +218,7 @@ class TestCassandraQueryProcessor: result = await processor.query_triples('test_user', query) - mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25) + mock_tg_instance.async_get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25) assert len(result) == 1 assert result[0].s.iri == 'test_subject' assert result[0].p.iri == 'result_predicate' @@ -236,7 +240,7 @@ class TestCassandraQueryProcessor: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_p.return_value = [mock_result] + mock_tg_instance.async_get_p = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -250,7 +254,7 @@ class TestCassandraQueryProcessor: result = await processor.query_triples('test_user', query) - mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10) + mock_tg_instance.async_get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10) assert len(result) == 1 assert result[0].s.iri == 'result_subject' assert result[0].p.iri == 'test_predicate' @@ -272,7 +276,7 @@ class TestCassandraQueryProcessor: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_o.return_value = [mock_result] + mock_tg_instance.async_get_o = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -286,7 +290,7 @@ class TestCassandraQueryProcessor: result = await processor.query_triples('test_user', query) - mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75) + mock_tg_instance.async_get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75) assert len(result) == 1 assert result[0].s.iri == 'result_subject' assert result[0].p.iri == 'result_predicate' @@ -305,11 +309,11 @@ class TestCassandraQueryProcessor: mock_result.s = 'all_subject' mock_result.p = 'all_predicate' mock_result.o = 'all_object' - mock_result.g = '' + mock_result.d = '' mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_all.return_value = [mock_result] + mock_tg_instance.async_get_all = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -323,7 +327,7 @@ class TestCassandraQueryProcessor: result = await processor.query_triples('test_user', query) - mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000) + mock_tg_instance.async_get_all.assert_called_once_with('test_collection', limit=1000) assert len(result) == 1 assert result[0].s.iri == 'all_subject' assert result[0].p.iri == 'all_predicate' @@ -410,7 +414,7 @@ class TestCassandraQueryProcessor: mock_result.dtype = None mock_result.lang = None mock_result.o = 'test_object' - mock_tg_instance.get_spo.return_value = [mock_result] + mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result]) processor = Processor( taskgroup=MagicMock(), @@ -451,7 +455,7 @@ class TestCassandraQueryProcessor: mock_result.dtype = None mock_result.lang = None mock_result.o = 'test_object' - mock_tg_instance.get_spo.return_value = [mock_result] + mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -489,8 +493,8 @@ class TestCassandraQueryProcessor: mock_result.lang = None mock_result.p = 'p' mock_result.o = 'o' - mock_tg_instance1.get_s.return_value = [mock_result] - mock_tg_instance2.get_s.return_value = [mock_result] + mock_tg_instance1.async_get_s = AsyncMock(return_value=[mock_result]) + mock_tg_instance2.async_get_s = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -504,7 +508,6 @@ class TestCassandraQueryProcessor: ) await processor.query_triples('user1', query1) - assert processor.table == 'user1' # Second query with different table query2 = TriplesQueryRequest( @@ -516,10 +519,11 @@ class TestCassandraQueryProcessor: ) await processor.query_triples('user2', query2) - assert processor.table == 'user2' - # Verify TrustGraph was created twice + # Verify TrustGraph was created twice for different workspaces assert mock_kg_class.call_count == 2 + mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1') + mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2') @pytest.mark.asyncio @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') @@ -529,7 +533,7 @@ class TestCassandraQueryProcessor: mock_tg_instance = MagicMock() mock_kg_class.return_value = mock_tg_instance - mock_tg_instance.get_spo.side_effect = Exception("Query failed") + mock_tg_instance.async_get_spo = AsyncMock(side_effect=Exception("Query failed")) processor = Processor(taskgroup=MagicMock()) @@ -566,7 +570,7 @@ class TestCassandraQueryProcessor: mock_result2.otype = None mock_result2.dtype = None mock_result2.lang = None - mock_tg_instance.get_sp.return_value = [mock_result1, mock_result2] + mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result1, mock_result2]) processor = Processor(taskgroup=MagicMock()) @@ -603,7 +607,7 @@ class TestCassandraQueryPerformanceOptimizations: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_po.return_value = [mock_result] + mock_tg_instance.async_get_po = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -618,8 +622,8 @@ class TestCassandraQueryPerformanceOptimizations: result = await processor.query_triples('test_user', query) - # Verify get_po was called (should use optimized po_table) - mock_tg_instance.get_po.assert_called_once_with( + # Verify async_get_po was called (should use optimized po_table) + mock_tg_instance.async_get_po.assert_called_once_with( 'test_collection', 'test_predicate', 'test_object', g=None, limit=50 ) @@ -643,7 +647,7 @@ class TestCassandraQueryPerformanceOptimizations: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_os.return_value = [mock_result] + mock_tg_instance.async_get_os = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -658,8 +662,8 @@ class TestCassandraQueryPerformanceOptimizations: result = await processor.query_triples('test_user', query) - # Verify get_os was called (should use optimized subject_table with clustering) - mock_tg_instance.get_os.assert_called_once_with( + # Verify async_get_os was called (should use optimized subject_table with clustering) + mock_tg_instance.async_get_os.assert_called_once_with( 'test_collection', 'test_object', 'test_subject', g=None, limit=25 ) @@ -678,28 +682,28 @@ class TestCassandraQueryPerformanceOptimizations: mock_kg_class.return_value = mock_tg_instance # Mock empty results for all queries - mock_tg_instance.get_all.return_value = [] - mock_tg_instance.get_s.return_value = [] - mock_tg_instance.get_p.return_value = [] - mock_tg_instance.get_o.return_value = [] - mock_tg_instance.get_sp.return_value = [] - mock_tg_instance.get_po.return_value = [] - mock_tg_instance.get_os.return_value = [] - mock_tg_instance.get_spo.return_value = [] + mock_tg_instance.async_get_all = AsyncMock(return_value=[]) + mock_tg_instance.async_get_s = AsyncMock(return_value=[]) + mock_tg_instance.async_get_p = AsyncMock(return_value=[]) + mock_tg_instance.async_get_o = AsyncMock(return_value=[]) + mock_tg_instance.async_get_sp = AsyncMock(return_value=[]) + mock_tg_instance.async_get_po = AsyncMock(return_value=[]) + mock_tg_instance.async_get_os = AsyncMock(return_value=[]) + mock_tg_instance.async_get_spo = AsyncMock(return_value=[]) processor = Processor(taskgroup=MagicMock()) # Test each query pattern test_patterns = [ # (s, p, o, expected_method) - (None, None, None, 'get_all'), # All triples - ('s1', None, None, 'get_s'), # Subject only - (None, 'p1', None, 'get_p'), # Predicate only - (None, None, 'o1', 'get_o'), # Object only - ('s1', 'p1', None, 'get_sp'), # Subject + Predicate - (None, 'p1', 'o1', 'get_po'), # Predicate + Object (CRITICAL OPTIMIZATION) - ('s1', None, 'o1', 'get_os'), # Object + Subject - ('s1', 'p1', 'o1', 'get_spo'), # All three + (None, None, None, 'async_get_all'), # All triples + ('s1', None, None, 'async_get_s'), # Subject only + (None, 'p1', None, 'async_get_p'), # Predicate only + (None, None, 'o1', 'async_get_o'), # Object only + ('s1', 'p1', None, 'async_get_sp'), # Subject + Predicate + (None, 'p1', 'o1', 'async_get_po'), # Predicate + Object (CRITICAL OPTIMIZATION) + ('s1', None, 'o1', 'async_get_os'), # Object + Subject + ('s1', 'p1', 'o1', 'async_get_spo'), # All three ] for s, p, o, expected_method in test_patterns: @@ -759,7 +763,7 @@ class TestCassandraQueryPerformanceOptimizations: mock_result.lang = None mock_results.append(mock_result) - mock_tg_instance.get_po.return_value = mock_results + mock_tg_instance.async_get_po = AsyncMock(return_value=mock_results) processor = Processor(taskgroup=MagicMock()) @@ -774,8 +778,8 @@ class TestCassandraQueryPerformanceOptimizations: result = await processor.query_triples('large_dataset_user', query) - # Verify optimized get_po was used (no ALLOW FILTERING needed!) - mock_tg_instance.get_po.assert_called_once_with( + # Verify optimized async_get_po was used (no ALLOW FILTERING needed!) + mock_tg_instance.async_get_po.assert_called_once_with( 'massive_collection', 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type', 'http://example.com/Person', diff --git a/tests/unit/test_reliability/test_null_embedding_protection.py b/tests/unit/test_reliability/test_null_embedding_protection.py index 2296e961..dbe06b40 100644 --- a/tests/unit/test_reliability/test_null_embedding_protection.py +++ b/tests/unit/test_reliability/test_null_embedding_protection.py @@ -113,12 +113,15 @@ class TestDocEmbeddingsNullProtection: @pytest.mark.asyncio async def test_valid_embedding_upserted(self): + import asyncio from trustgraph.storage.doc_embeddings.qdrant.write import Processor proc = Processor.__new__(Processor) proc.qdrant = MagicMock() proc.qdrant.collection_exists.return_value = True proc.collection_exists = MagicMock(return_value=True) + proc._cache_lock = asyncio.Lock() + proc._known_collections = set() msg = MagicMock() msg.metadata.collection = "col1" @@ -134,12 +137,15 @@ class TestDocEmbeddingsNullProtection: @pytest.mark.asyncio async def test_dimension_in_collection_name(self): """Collection name should include vector dimension.""" + import asyncio from trustgraph.storage.doc_embeddings.qdrant.write import Processor proc = Processor.__new__(Processor) proc.qdrant = MagicMock() proc.qdrant.collection_exists.return_value = True proc.collection_exists = MagicMock(return_value=True) + proc._cache_lock = asyncio.Lock() + proc._known_collections = set() msg = MagicMock() msg.metadata.collection = "docs" @@ -220,12 +226,15 @@ class TestGraphEmbeddingsNullProtection: @pytest.mark.asyncio async def test_valid_entity_and_vector_upserted(self): + import asyncio from trustgraph.storage.graph_embeddings.qdrant.write import Processor proc = Processor.__new__(Processor) proc.qdrant = MagicMock() proc.qdrant.collection_exists.return_value = True proc.collection_exists = MagicMock(return_value=True) + proc._cache_lock = asyncio.Lock() + proc._known_collections = set() msg = MagicMock() msg.metadata.collection = "col1" @@ -241,12 +250,15 @@ class TestGraphEmbeddingsNullProtection: @pytest.mark.asyncio async def test_lazy_collection_creation_on_new_dimension(self): + import asyncio from trustgraph.storage.graph_embeddings.qdrant.write import Processor proc = Processor.__new__(Processor) proc.qdrant = MagicMock() proc.qdrant.collection_exists.return_value = False proc.collection_exists = MagicMock(return_value=True) + proc._cache_lock = asyncio.Lock() + proc._known_collections = set() msg = MagicMock() msg.metadata.collection = "graphs" diff --git a/tests/unit/test_retrieval/test_graph_rag.py b/tests/unit/test_retrieval/test_graph_rag.py index e0f41357..d1979211 100644 --- a/tests/unit/test_retrieval/test_graph_rag.py +++ b/tests/unit/test_retrieval/test_graph_rag.py @@ -337,6 +337,57 @@ class TestQuery: cache_key = "test_collection:unlabeled_entity" mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity") + @pytest.mark.asyncio + async def test_triples_query_never_passes_workspace(self): + """Workspace isolation is handled by pub/sub topic routing, not + by passing workspace to TriplesClient.query(). Verify that + GraphRAG never passes workspace as a keyword argument.""" + mock_rag = MagicMock() + mock_cache = MagicMock() + mock_cache.get.return_value = None + mock_rag.label_cache = mock_cache + mock_triples_client = AsyncMock() + mock_rag.triples_client = mock_triples_client + + mock_triple = MagicMock() + mock_triple.o = "Label" + mock_triples_client.query.return_value = [mock_triple] + + query = Query( + rag=mock_rag, + collection="test_collection", + verbose=False + ) + + await query.maybe_label("http://example.com/entity") + + for c in mock_triples_client.query.call_args_list: + assert "workspace" not in c.kwargs + + @pytest.mark.asyncio + async def test_follow_edges_never_passes_workspace(self): + """Verify follow_edges never passes workspace to query_stream.""" + mock_rag = MagicMock() + mock_triples_client = AsyncMock() + mock_rag.triples_client = mock_triples_client + + mock_triple = MagicMock() + mock_triple.s, mock_triple.p, mock_triple.o = "e1", "p1", "o1" + mock_triples_client.query_stream.return_value = [mock_triple] + + query = Query( + rag=mock_rag, + collection="test_collection", + verbose=False, + triple_limit=10 + ) + + subgraph = set() + await query.follow_edges("e1", subgraph, path_length=1) + + for c in mock_triples_client.query_stream.call_args_list: + assert "workspace" not in c.kwargs + @pytest.mark.asyncio async def test_follow_edges_basic_functionality(self): """Test Query.follow_edges method basic triple discovery""" diff --git a/tests/unit/test_rev_gateway/__init__.py b/tests/unit/test_rev_gateway/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/test_rev_gateway/test_dispatcher.py b/tests/unit/test_rev_gateway/test_dispatcher.py index 2a9c8df0..6df786cf 100644 --- a/tests/unit/test_rev_gateway/test_dispatcher.py +++ b/tests/unit/test_rev_gateway/test_dispatcher.py @@ -3,275 +3,279 @@ Tests for Reverse Gateway Dispatcher """ import pytest -from unittest.mock import MagicMock, AsyncMock, patch +import asyncio +from unittest.mock import MagicMock, AsyncMock, patch, ANY -from trustgraph.rev_gateway.dispatcher import WebSocketResponder, MessageDispatcher - - -class TestWebSocketResponder: - """Test cases for WebSocketResponder class""" - - def test_websocket_responder_initialization(self): - """Test WebSocketResponder initialization""" - responder = WebSocketResponder() - - assert responder.response is None - assert responder.completed is False - - @pytest.mark.asyncio - async def test_websocket_responder_send_method(self): - """Test WebSocketResponder send method""" - responder = WebSocketResponder() - - test_response = {"data": "test response"} - - # Call send method - await responder.send(test_response) - - # Verify response was stored - assert responder.response == test_response - - @pytest.mark.asyncio - async def test_websocket_responder_call_method(self): - """Test WebSocketResponder __call__ method""" - responder = WebSocketResponder() - - test_response = {"result": "success"} - test_completed = True - - # Call the responder - await responder(test_response, test_completed) - - # Verify response and completed status were set - assert responder.response == test_response - assert responder.completed == test_completed - - @pytest.mark.asyncio - async def test_websocket_responder_call_method_with_false_completion(self): - """Test WebSocketResponder __call__ method with incomplete response""" - responder = WebSocketResponder() - - test_response = {"partial": "data"} - test_completed = False - - # Call the responder - await responder(test_response, test_completed) - - # Verify response was set and completed is True (since send() always sets completed=True) - assert responder.response == test_response - assert responder.completed is True +from trustgraph.rev_gateway.dispatcher import MessageDispatcher class TestMessageDispatcher: """Test cases for MessageDispatcher class""" def test_message_dispatcher_initialization_with_defaults(self): - """Test MessageDispatcher initialization with default parameters""" dispatcher = MessageDispatcher() - + assert dispatcher.max_workers == 10 assert dispatcher.semaphore._value == 10 assert dispatcher.active_tasks == set() assert dispatcher.backend is None + assert dispatcher.auth is None assert dispatcher.dispatcher_manager is None assert len(dispatcher.service_mapping) > 0 def test_message_dispatcher_initialization_with_custom_workers(self): - """Test MessageDispatcher initialization with custom max_workers""" dispatcher = MessageDispatcher(max_workers=5) - + assert dispatcher.max_workers == 5 assert dispatcher.semaphore._value == 5 @patch('trustgraph.rev_gateway.dispatcher.DispatcherManager') - def test_message_dispatcher_initialization_with_pulsar_client(self, mock_dispatcher_manager): - """Test MessageDispatcher initialization with pulsar_client and config_receiver""" + def test_message_dispatcher_initialization_with_backend( + self, mock_dispatcher_manager, + ): mock_backend = MagicMock() mock_config_receiver = MagicMock() + mock_auth = MagicMock() mock_dispatcher_instance = MagicMock() mock_dispatcher_manager.return_value = mock_dispatcher_instance - + dispatcher = MessageDispatcher( max_workers=8, config_receiver=mock_config_receiver, - backend=mock_backend + backend=mock_backend, + auth=mock_auth, + timeout=300, ) - + assert dispatcher.max_workers == 8 assert dispatcher.backend == mock_backend + assert dispatcher.auth == mock_auth assert dispatcher.dispatcher_manager == mock_dispatcher_instance mock_dispatcher_manager.assert_called_once_with( - mock_backend, mock_config_receiver, prefix="rev-gateway" + mock_backend, mock_config_receiver, + auth=mock_auth, prefix="rev-gateway", timeout=300, ) def test_message_dispatcher_service_mapping(self): - """Test MessageDispatcher service mapping contains expected services""" dispatcher = MessageDispatcher() - + expected_services = [ "text-completion", "graph-rag", "agent", "embeddings", "graph-embeddings", "triples", "document-load", "text-load", - "flow", "knowledge", "config", "librarian", "document-rag" + "flow", "knowledge", "config", "librarian", "document-rag", ] - + for service in expected_services: assert service in dispatcher.service_mapping - - # Test specific mappings - assert dispatcher.service_mapping["text-completion"] == "text-completion" + assert dispatcher.service_mapping["document-load"] == "document" assert dispatcher.service_mapping["text-load"] == "text-document" @pytest.mark.asyncio - async def test_message_dispatcher_handle_message_without_dispatcher_manager(self): - """Test MessageDispatcher handle_message without dispatcher manager""" + async def test_handle_message_without_dispatcher_manager(self): dispatcher = MessageDispatcher() - - test_message = { - "id": "test-123", - "service": "test-service", - "request": {"data": "test"} - } - - result = await dispatcher.handle_message(test_message) - - assert result["id"] == "test-123" - assert "error" in result["response"] - assert "DispatcherManager not available" in result["response"]["error"] + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + return_value=MagicMock(workspace="default") + ) + + sender = AsyncMock() + + await dispatcher.handle_message( + {"id": "test-1", "service": "test", "request": {}}, + sender, + ) + + sender.assert_called_once() + sent = sender.call_args[0][0] + assert sent["id"] == "test-1" + assert sent["error"]["message"] == "DispatcherManager not available" + assert sent["error"]["type"] == "error" + assert sent["complete"] is True @pytest.mark.asyncio - async def test_message_dispatcher_handle_message_with_exception(self): - """Test MessageDispatcher handle_message with exception during processing""" - mock_dispatcher_manager = MagicMock() - mock_dispatcher_manager.invoke_global_service = AsyncMock(side_effect=Exception("Test error")) - + async def test_handle_message_auth_failure(self): dispatcher = MessageDispatcher() - dispatcher.dispatcher_manager = mock_dispatcher_manager - - test_message = { - "id": "test-456", - "service": "text-completion", - "request": {"prompt": "test"} - } - - with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}): - result = await dispatcher.handle_message(test_message) - - assert result["id"] == "test-456" - assert "error" in result["response"] - assert "Test error" in result["response"]["error"] + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + side_effect=Exception("auth failure") + ) + dispatcher.dispatcher_manager = MagicMock() + + sender = AsyncMock() + + await dispatcher.handle_message( + {"id": "test-2", "token": "bad", "service": "test", "request": {}}, + sender, + ) + + sender.assert_called_once() + sent = sender.call_args[0][0] + assert sent["id"] == "test-2" + assert "auth failure" in sent["error"]["message"] + assert sent["complete"] is True @pytest.mark.asyncio - async def test_message_dispatcher_handle_message_global_service(self): - """Test MessageDispatcher handle_message with global service""" - mock_dispatcher_manager = MagicMock() - mock_dispatcher_manager.invoke_global_service = AsyncMock() - mock_responder = MagicMock() - mock_responder.completed = True - mock_responder.response = {"result": "success"} - + async def test_handle_message_global_service(self): + mock_dm = MagicMock() + mock_dm.invoke_global_service = AsyncMock() + dispatcher = MessageDispatcher() - dispatcher.dispatcher_manager = mock_dispatcher_manager - - test_message = { - "id": "test-789", - "service": "text-completion", - "request": {"prompt": "hello"} - } - - with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}): - with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder): - result = await dispatcher.handle_message(test_message) - - assert result["id"] == "test-789" - assert result["response"] == {"result": "success"} - mock_dispatcher_manager.invoke_global_service.assert_called_once() + dispatcher.dispatcher_manager = mock_dm + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + return_value=MagicMock(workspace="ws1") + ) + + sender = AsyncMock() + + with patch( + 'trustgraph.gateway.dispatch.manager.global_dispatchers', + {"text-completion": True}, + ): + await dispatcher.handle_message( + { + "id": "test-3", + "token": "tg_key", + "service": "text-completion", + "request": {"prompt": "hello"}, + }, + sender, + ) + + mock_dm.invoke_global_service.assert_called_once() + args, kwargs = mock_dm.invoke_global_service.call_args + assert args[0] == {"prompt": "hello"} + assert args[2] == "text-completion" + assert kwargs["workspace"] == "ws1" @pytest.mark.asyncio - async def test_message_dispatcher_handle_message_flow_service(self): - """Test MessageDispatcher handle_message with flow service""" - mock_dispatcher_manager = MagicMock() - mock_dispatcher_manager.invoke_flow_service = AsyncMock() - mock_responder = MagicMock() - mock_responder.completed = True - mock_responder.response = {"data": "flow_result"} - + async def test_handle_message_flow_service(self): + mock_dm = MagicMock() + mock_dm.invoke_flow_service = AsyncMock() + dispatcher = MessageDispatcher() - dispatcher.dispatcher_manager = mock_dispatcher_manager - - test_message = { - "id": "test-flow-123", - "service": "document-rag", - "request": {"query": "test"}, - "flow": "custom-flow" - } - - with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}): - with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder): - result = await dispatcher.handle_message(test_message) - - assert result["id"] == "test-flow-123" - assert result["response"] == {"data": "flow_result"} - mock_dispatcher_manager.invoke_flow_service.assert_called_once_with( - {"query": "test"}, mock_responder, "custom-flow", "document-rag" + dispatcher.dispatcher_manager = mock_dm + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + return_value=MagicMock(workspace="ws2") + ) + + sender = AsyncMock() + + with patch( + 'trustgraph.gateway.dispatch.manager.global_dispatchers', {}, + ): + await dispatcher.handle_message( + { + "id": "test-4", + "token": "tg_key", + "service": "document-rag", + "request": {"query": "test"}, + "flow": "my-flow", + }, + sender, + ) + + mock_dm.invoke_flow_service.assert_called_once_with( + {"query": "test"}, ANY, "ws2", "my-flow", "document-rag", ) @pytest.mark.asyncio - async def test_message_dispatcher_handle_message_incomplete_response(self): - """Test MessageDispatcher handle_message with incomplete response""" - mock_dispatcher_manager = MagicMock() - mock_dispatcher_manager.invoke_flow_service = AsyncMock() - mock_responder = MagicMock() - mock_responder.completed = False - mock_responder.response = None - + async def test_handle_message_responder_sends_frames(self): + mock_dm = MagicMock() + + async def fake_invoke(data, responder, svc, workspace=None): + await responder({"partial": 1}, False) + await responder({"partial": 2}, True) + + mock_dm.invoke_global_service = AsyncMock(side_effect=fake_invoke) + dispatcher = MessageDispatcher() - dispatcher.dispatcher_manager = mock_dispatcher_manager - - test_message = { - "id": "test-incomplete", - "service": "agent", - "request": {"input": "test"} + dispatcher.dispatcher_manager = mock_dm + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + return_value=MagicMock(workspace="ws1") + ) + + sender = AsyncMock() + + with patch( + 'trustgraph.gateway.dispatch.manager.global_dispatchers', + {"text-completion": True}, + ): + await dispatcher.handle_message( + { + "id": "test-5", + "token": "tg_key", + "service": "text-completion", + "request": {"prompt": "hi"}, + }, + sender, + ) + + assert sender.call_count == 2 + first = sender.call_args_list[0][0][0] + second = sender.call_args_list[1][0][0] + + assert first == { + "id": "test-5", "response": {"partial": 1}, "complete": False, + } + assert second == { + "id": "test-5", "response": {"partial": 2}, "complete": True, } - - with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}): - with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder): - result = await dispatcher.handle_message(test_message) - - assert result["id"] == "test-incomplete" - assert result["response"] == {"error": "No response received"} @pytest.mark.asyncio - async def test_message_dispatcher_shutdown(self): - """Test MessageDispatcher shutdown method""" - import asyncio - + async def test_handle_message_workspace_from_identity(self): + mock_dm = MagicMock() + mock_dm.invoke_flow_service = AsyncMock() + dispatcher = MessageDispatcher() - - # Create actual async tasks + dispatcher.dispatcher_manager = mock_dm + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + return_value=MagicMock(workspace="derived-ws") + ) + + sender = AsyncMock() + + with patch( + 'trustgraph.gateway.dispatch.manager.global_dispatchers', {}, + ): + await dispatcher.handle_message( + { + "id": "test-6", + "token": "tg_key", + "service": "agent", + "request": {"question": "test"}, + "flow": "default", + }, + sender, + ) + + args = mock_dm.invoke_flow_service.call_args[0] + assert args[2] == "derived-ws" + + @pytest.mark.asyncio + async def test_shutdown(self): + dispatcher = MessageDispatcher() + async def dummy_task(): await asyncio.sleep(0.01) - return "done" - + task1 = asyncio.create_task(dummy_task()) task2 = asyncio.create_task(dummy_task()) dispatcher.active_tasks = {task1, task2} - - # Call shutdown + await dispatcher.shutdown() - - # Verify tasks were completed + assert task1.done() assert task2.done() - assert len(dispatcher.active_tasks) == 2 # Tasks remain in set but are completed @pytest.mark.asyncio - async def test_message_dispatcher_shutdown_with_no_tasks(self): - """Test MessageDispatcher shutdown with no active tasks""" + async def test_shutdown_with_no_tasks(self): dispatcher = MessageDispatcher() - - # Call shutdown with no active tasks + await dispatcher.shutdown() - - # Should complete without error - assert dispatcher.active_tasks == set() \ No newline at end of file + + assert dispatcher.active_tasks == set() diff --git a/tests/unit/test_rev_gateway/test_rev_gateway_service.py b/tests/unit/test_rev_gateway/test_rev_gateway_service.py index 23aff18e..e2d1045d 100644 --- a/tests/unit/test_rev_gateway/test_rev_gateway_service.py +++ b/tests/unit/test_rev_gateway/test_rev_gateway_service.py @@ -8,22 +8,38 @@ from unittest.mock import MagicMock, AsyncMock, patch, Mock from aiohttp import WSMsgType, ClientWebSocketResponse import json -from trustgraph.rev_gateway.service import ReverseGateway, parse_args, run +from trustgraph.rev_gateway.service import ReverseGateway, run + + +MOCK_PATCHES = [ + 'trustgraph.rev_gateway.service.IamAuth', + 'trustgraph.rev_gateway.service.ConfigReceiver', + 'trustgraph.rev_gateway.service.MessageDispatcher', + 'trustgraph.rev_gateway.service.get_pubsub', +] + + +def make_gateway(**overrides): + config = {"websocket_uri": "ws://localhost:7650/out"} + config.update(overrides) + return ReverseGateway(**config) class TestReverseGateway: """Test cases for ReverseGateway class""" - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_initialization_defaults(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway initialization with default parameters""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() - + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_initialization_defaults( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() + assert gateway.websocket_uri == "ws://localhost:7650/out" assert gateway.host == "localhost" assert gateway.port == 7650 @@ -33,25 +49,22 @@ class TestReverseGateway: assert gateway.max_workers == 10 assert gateway.running is False assert gateway.reconnect_delay == 3.0 - assert gateway.pulsar_host == "pulsar://pulsar:6650" - assert gateway.pulsar_api_key is None - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_initialization_custom_params(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway initialization with custom parameters""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway( + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_initialization_custom_params( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway( websocket_uri="wss://example.com:8080/websocket", max_workers=20, - pulsar_host="pulsar://custom:6650", - pulsar_api_key="test-key", - pulsar_listener="test-listener" ) - + assert gateway.websocket_uri == "wss://example.com:8080/websocket" assert gateway.host == "example.com" assert gateway.port == 8080 @@ -59,340 +72,360 @@ class TestReverseGateway: assert gateway.path == "/websocket" assert gateway.url == "wss://example.com:8080/websocket" assert gateway.max_workers == 20 - assert gateway.pulsar_host == "pulsar://custom:6650" - assert gateway.pulsar_api_key == "test-key" - assert gateway.pulsar_listener == "test-listener" - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_initialization_with_missing_path(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway initialization with WebSocket URI missing path""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway(websocket_uri="ws://example.com") - + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_initialization_with_missing_path( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway(websocket_uri="ws://example.com") + assert gateway.path == "/ws" assert gateway.url == "ws://example.com/ws" - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_initialization_invalid_scheme(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway initialization with invalid WebSocket scheme""" + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_initialization_invalid_scheme( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): with pytest.raises(ValueError, match="WebSocket URI must use ws:// or wss:// scheme"): - ReverseGateway(websocket_uri="http://example.com") + make_gateway(websocket_uri="http://example.com") - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_initialization_missing_hostname(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway initialization with missing hostname""" + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_initialization_missing_hostname( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): with pytest.raises(ValueError, match="WebSocket URI must include hostname"): - ReverseGateway(websocket_uri="ws://") + make_gateway(websocket_uri="ws://") - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_pulsar_client_with_auth(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway creates backend with authentication""" + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_iam_auth_created( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): mock_backend = MagicMock() mock_get_pubsub.return_value = mock_backend - gateway = ReverseGateway( - pulsar_api_key="test-key", - pulsar_listener="test-listener" + gateway = make_gateway(id="test-rev-gw") + + mock_iam_auth.assert_called_once_with( + backend=mock_backend, + id="test-rev-gw", ) - # Verify get_pubsub was called with the correct parameters - mock_get_pubsub.assert_called_once_with( - pulsar_host="pulsar://pulsar:6650", - pulsar_api_key="test-key", - pulsar_listener="test-listener" + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_config_receiver_gets_auth( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend + mock_auth_instance = MagicMock() + mock_iam_auth.return_value = mock_auth_instance + + gateway = make_gateway() + + mock_config_receiver.assert_called_once_with( + mock_backend, auth=mock_auth_instance, ) - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @patch('trustgraph.rev_gateway.service.ClientSession') @pytest.mark.asyncio - async def test_reverse_gateway_connect_success(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway successful connection""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - + async def test_reverse_gateway_connect_success( + self, mock_session_class, mock_get_pubsub, + mock_dispatcher, mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + mock_session = AsyncMock() mock_ws = AsyncMock() mock_session.ws_connect.return_value = mock_ws mock_session_class.return_value = mock_session - - gateway = ReverseGateway() - + + gateway = make_gateway() + result = await gateway.connect() - + assert result is True assert gateway.session == mock_session assert gateway.ws == mock_ws mock_session.ws_connect.assert_called_once_with(gateway.url) - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @patch('trustgraph.rev_gateway.service.ClientSession') @pytest.mark.asyncio - async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway connection failure""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - + async def test_reverse_gateway_connect_failure( + self, mock_session_class, mock_get_pubsub, + mock_dispatcher, mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + mock_session = AsyncMock() mock_session.ws_connect.side_effect = Exception("Connection failed") mock_session_class.return_value = mock_session - - gateway = ReverseGateway() - + + gateway = make_gateway() + result = await gateway.connect() - + assert result is False - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_disconnect(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway disconnect""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() - - # Mock websocket and session + async def test_reverse_gateway_disconnect( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() + mock_ws = AsyncMock() mock_ws.closed = False mock_session = AsyncMock() mock_session.closed = False - + gateway.ws = mock_ws gateway.session = mock_session - + await gateway.disconnect() - + mock_ws.close.assert_called_once() mock_session.close.assert_called_once() assert gateway.ws is None assert gateway.session is None - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_send_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway send message""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() - - # Mock websocket + async def test_reverse_gateway_send_message( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() + mock_ws = AsyncMock() mock_ws.closed = False gateway.ws = mock_ws - + test_message = {"id": "test", "data": "hello"} - + await gateway.send_message(test_message) - + mock_ws.send_str.assert_called_once_with(json.dumps(test_message)) - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_send_message_closed_connection(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway send message with closed connection""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() - - # Mock closed websocket + async def test_reverse_gateway_send_message_closed_connection( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() + mock_ws = AsyncMock() mock_ws.closed = True gateway.ws = mock_ws - + test_message = {"id": "test", "data": "hello"} - + await gateway.send_message(test_message) - - # Should not call send_str on closed connection + mock_ws.send_str.assert_not_called() - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_handle_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway handle message""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - mock_dispatcher_instance = AsyncMock() - mock_dispatcher_instance.handle_message.return_value = {"response": "success"} - mock_dispatcher.return_value = mock_dispatcher_instance - - gateway = ReverseGateway() - - # Mock send_message - gateway.send_message = AsyncMock() - - test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}' - - await gateway.handle_message(test_message) - - mock_dispatcher_instance.handle_message.assert_called_once_with({ - "id": "test", - "service": "test-service", - "request": {"data": "test"} - }) - gateway.send_message.assert_called_once_with({"response": "success"}) + async def test_reverse_gateway_handle_message( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + mock_dispatcher_instance = AsyncMock() + mock_dispatcher.return_value = mock_dispatcher_instance + + gateway = make_gateway() - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - @pytest.mark.asyncio - async def test_reverse_gateway_handle_message_invalid_json(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway handle message with invalid JSON""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() - - # Mock send_message gateway.send_message = AsyncMock() - - test_message = 'invalid json' - - # Should not raise exception + + test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}' + await gateway.handle_message(test_message) - - # Should not call send_message due to error + + mock_dispatcher_instance.handle_message.assert_called_once_with( + { + "id": "test", + "service": "test-service", + "request": {"data": "test"}, + }, + gateway.send_message, + ) + + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + @pytest.mark.asyncio + async def test_reverse_gateway_handle_message_invalid_json( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() + + gateway.send_message = AsyncMock() + + await gateway.handle_message('invalid json') + gateway.send_message.assert_not_called() - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_listen_text_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway listen with text message""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() + async def test_reverse_gateway_listen_text_message( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() gateway.running = True - - # Mock websocket + mock_ws = AsyncMock() mock_ws.closed = False gateway.ws = mock_ws - - # Mock handle_message + gateway.handle_message = AsyncMock() - - # Mock message + mock_msg = MagicMock() mock_msg.type = WSMsgType.TEXT mock_msg.data = '{"test": "message"}' - - # Mock receive to return message once, then raise exception to stop loop + mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")] - - # listen() catches exceptions and breaks, so no exception should be raised + await gateway.listen() - + gateway.handle_message.assert_called_once_with('{"test": "message"}') - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_listen_binary_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway listen with binary message""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() + async def test_reverse_gateway_listen_binary_message( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() gateway.running = True - - # Mock websocket + mock_ws = AsyncMock() mock_ws.closed = False gateway.ws = mock_ws - - # Mock handle_message + gateway.handle_message = AsyncMock() - - # Mock message + mock_msg = MagicMock() mock_msg.type = WSMsgType.BINARY mock_msg.data = b'{"test": "binary"}' - - # Mock receive to return message once, then raise exception to stop loop + mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")] - - # listen() catches exceptions and breaks, so no exception should be raised + await gateway.listen() - + gateway.handle_message.assert_called_once_with('{"test": "binary"}') - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_listen_close_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway listen with close message""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() + async def test_reverse_gateway_listen_close_message( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() gateway.running = True - - # Mock websocket + mock_ws = AsyncMock() mock_ws.closed = False gateway.ws = mock_ws - - # Mock handle_message + gateway.handle_message = AsyncMock() - - # Mock message + mock_msg = MagicMock() mock_msg.type = WSMsgType.CLOSE - - # Mock receive to return close message + mock_ws.receive.return_value = mock_msg - + await gateway.listen() - - # Should not call handle_message for close message + gateway.handle_message.assert_not_called() - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_shutdown(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway shutdown""" + async def test_reverse_gateway_shutdown( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): mock_backend = MagicMock() mock_get_pubsub.return_value = mock_backend mock_dispatcher_instance = AsyncMock() mock_dispatcher.return_value = mock_dispatcher_instance - gateway = ReverseGateway() + gateway = make_gateway() gateway.running = True - # Mock disconnect gateway.disconnect = AsyncMock() await gateway.shutdown() @@ -402,46 +435,50 @@ class TestReverseGateway: gateway.disconnect.assert_called_once() mock_backend.close.assert_called_once() - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_stop(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway stop""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_stop( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() gateway.running = True - + gateway.stop() - + assert gateway.running is False class TestReverseGatewayRun: """Test cases for ReverseGateway run method""" - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_run_successful_cycle(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway run method with successful connect/listen cycle""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - + async def test_reverse_gateway_run_successful_cycle( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + mock_auth_instance = AsyncMock() + mock_iam_auth.return_value = mock_auth_instance + mock_config_receiver_instance = AsyncMock() mock_config_receiver.return_value = mock_config_receiver_instance - - gateway = ReverseGateway() - - # Mock methods - gateway.connect = AsyncMock(return_value=True) + + gateway = make_gateway() + gateway.listen = AsyncMock() gateway.disconnect = AsyncMock() gateway.shutdown = AsyncMock() - - # Stop after one iteration + call_count = 0 async def mock_connect(): nonlocal call_count @@ -451,91 +488,13 @@ class TestReverseGatewayRun: else: gateway.running = False return False - + gateway.connect = mock_connect - + await gateway.run() - + + mock_auth_instance.start.assert_called_once() mock_config_receiver_instance.start.assert_called_once() gateway.listen.assert_called_once() - # disconnect is called twice: once in the main loop, once in shutdown assert gateway.disconnect.call_count == 2 gateway.shutdown.assert_called_once() - - -class TestReverseGatewayArgs: - """Test cases for argument parsing and run function""" - - def test_parse_args_defaults(self): - """Test parse_args with default values""" - import sys - - # Mock sys.argv - original_argv = sys.argv - sys.argv = ['reverse-gateway'] - - try: - args = parse_args() - - assert args.websocket_uri is None - assert args.max_workers == 10 - assert args.pulsar_host is None - assert args.pulsar_api_key is None - assert args.pulsar_listener is None - finally: - sys.argv = original_argv - - def test_parse_args_custom_values(self): - """Test parse_args with custom values""" - import sys - - # Mock sys.argv - original_argv = sys.argv - sys.argv = [ - 'reverse-gateway', - '--websocket-uri', 'ws://custom:8080/ws', - '--max-workers', '20', - '--pulsar-host', 'pulsar://custom:6650', - '--pulsar-api-key', 'test-key', - '--pulsar-listener', 'test-listener' - ] - - try: - args = parse_args() - - assert args.websocket_uri == 'ws://custom:8080/ws' - assert args.max_workers == 20 - assert args.pulsar_host == 'pulsar://custom:6650' - assert args.pulsar_api_key == 'test-key' - assert args.pulsar_listener == 'test-listener' - finally: - sys.argv = original_argv - - @patch('trustgraph.rev_gateway.service.ReverseGateway') - @patch('asyncio.run') - def test_run_function(self, mock_asyncio_run, mock_gateway_class): - """Test run function""" - import sys - - # Mock sys.argv - original_argv = sys.argv - sys.argv = ['reverse-gateway', '--max-workers', '15'] - - try: - mock_gateway_instance = MagicMock() - mock_gateway_instance.url = "ws://localhost:7650/out" - mock_gateway_instance.pulsar_host = "pulsar://pulsar:6650" - mock_gateway_class.return_value = mock_gateway_instance - - run() - - mock_gateway_class.assert_called_once_with( - websocket_uri=None, - max_workers=15, - pulsar_host=None, - pulsar_api_key=None, - pulsar_listener=None - ) - mock_asyncio_run.assert_called_once_with(mock_gateway_instance.run()) - finally: - sys.argv = original_argv \ No newline at end of file diff --git a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py index ce6e6b3d..360ac3dc 100644 --- a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py @@ -413,8 +413,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Assert expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions - # Verify collection existence is checked on each write - mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection) + # Second write uses cached collection state — no collection_exists check + mock_qdrant_instance.collection_exists.assert_not_called() # But upsert should still be called mock_qdrant_instance.upsert.assert_called_once() diff --git a/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py index 8754f47c..44fdf516 100644 --- a/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py @@ -125,13 +125,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) - processor.ensure_collection("test_collection", 384) + await processor.ensure_collection("test_collection", 384) mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection") mock_qdrant_instance.create_collection.assert_called_once() # Verify the collection is cached - assert "test_collection" in processor.created_collections + assert "test_collection" in processor._known_collections @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') async def test_ensure_collection_skips_existing(self, mock_qdrant_client): @@ -149,7 +149,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) - processor.ensure_collection("existing_collection", 384) + await processor.ensure_collection("existing_collection", 384) mock_qdrant_instance.collection_exists.assert_called_once() mock_qdrant_instance.create_collection.assert_not_called() @@ -168,9 +168,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - processor.created_collections.add("cached_collection") + processor._known_collections.add("cached_collection") - processor.ensure_collection("cached_collection", 384) + await processor.ensure_collection("cached_collection", 384) # Should not check or create - just return mock_qdrant_instance.collection_exists.assert_not_called() @@ -391,7 +391,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - processor.created_collections.add('rows_test_workspace_test_collection_schema1_384') + processor._known_collections.add('rows_test_workspace_test_collection_schema1_384') await processor.delete_collection('test_workspace', 'test_collection') @@ -399,7 +399,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): assert mock_qdrant_instance.delete_collection.call_count == 2 # Verify the cached collection was removed - assert 'rows_test_workspace_test_collection_schema1_384' not in processor.created_collections + assert 'rows_test_workspace_test_collection_schema1_384' not in processor._known_collections @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') async def test_delete_collection_schema(self, mock_qdrant_client): diff --git a/tests/unit/test_storage/test_rows_cassandra_storage.py b/tests/unit/test_storage/test_rows_cassandra_storage.py index 852f01a1..3e5664ea 100644 --- a/tests/unit/test_storage/test_rows_cassandra_storage.py +++ b/tests/unit/test_storage/test_rows_cassandra_storage.py @@ -121,10 +121,13 @@ class TestRowsCassandraStorageLogic: @pytest.mark.asyncio async def test_schema_config_parsing(self): """Test parsing of schema configurations""" + import asyncio processor = MagicMock() processor.schemas = {} processor.config_key = "schema" processor.registered_partitions = set() + processor._setup_lock = asyncio.Lock() + processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor) processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) # Create test configuration diff --git a/tests/unit/test_storage/test_triples_cassandra_storage.py b/tests/unit/test_storage/test_triples_cassandra_storage.py index 04acbb16..394f0e54 100644 --- a/tests/unit/test_storage/test_triples_cassandra_storage.py +++ b/tests/unit/test_storage/test_triples_cassandra_storage.py @@ -2,6 +2,8 @@ Tests for Cassandra triples storage service """ +import asyncio + import pytest from unittest.mock import MagicMock, patch, AsyncMock @@ -24,12 +26,13 @@ class TestCassandraStorageProcessor: assert processor.cassandra_host == ['cassandra'] # Updated default assert processor.cassandra_username is None assert processor.cassandra_password is None - assert processor.table is None + assert processor._connections == {} + assert isinstance(processor._conn_lock, asyncio.Lock) def test_processor_initialization_with_custom_params(self): """Test processor initialization with custom parameters (new cassandra_* names)""" taskgroup_mock = MagicMock() - + processor = Processor( taskgroup=taskgroup_mock, id='custom-storage', @@ -37,11 +40,12 @@ class TestCassandraStorageProcessor: cassandra_username='testuser', cassandra_password='testpass' ) - + assert processor.cassandra_host == ['cassandra.example.com'] assert processor.cassandra_username == 'testuser' assert processor.cassandra_password == 'testpass' - assert processor.table is None + assert processor._connections == {} + assert isinstance(processor._conn_lock, asyncio.Lock) def test_processor_initialization_with_partial_auth(self): """Test processor initialization with only username (no password)""" @@ -92,6 +96,7 @@ class TestCassandraStorageProcessor: """Test table switching logic when authentication is provided""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor( @@ -114,7 +119,6 @@ class TestCassandraStorageProcessor: username='testuser', password='testpass' ) - assert processor.table == 'user1' @pytest.mark.asyncio @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') @@ -122,6 +126,7 @@ class TestCassandraStorageProcessor: """Test table switching logic when no authentication is provided""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -138,7 +143,6 @@ class TestCassandraStorageProcessor: hosts=['cassandra'], # Updated default keyspace='user2' ) - assert processor.table == 'user2' @pytest.mark.asyncio @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') @@ -146,6 +150,7 @@ class TestCassandraStorageProcessor: """Test that TrustGraph is not recreated when table hasn't changed""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -169,6 +174,7 @@ class TestCassandraStorageProcessor: """Test that triples are properly inserted into Cassandra""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -208,12 +214,12 @@ class TestCassandraStorageProcessor: await processor.store_triples('user1', mock_message) # Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters) - assert mock_tg_instance.insert.call_count == 2 - mock_tg_instance.insert.assert_any_call( + assert mock_tg_instance.async_insert.call_count == 2 + mock_tg_instance.async_insert.assert_any_call( 'collection1', 'subject1', 'predicate1', 'object1', g=DEFAULT_GRAPH, otype='l', dtype='', lang='' ) - mock_tg_instance.insert.assert_any_call( + mock_tg_instance.async_insert.assert_any_call( 'collection1', 'subject2', 'predicate2', 'object2', g=DEFAULT_GRAPH, otype='l', dtype='', lang='' ) @@ -224,6 +230,7 @@ class TestCassandraStorageProcessor: """Test behavior when message has no triples""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -236,19 +243,17 @@ class TestCassandraStorageProcessor: await processor.store_triples('user1', mock_message) # Verify no triples were inserted - mock_tg_instance.insert.assert_not_called() + mock_tg_instance.async_insert.assert_not_called() @pytest.mark.asyncio @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') - @patch('trustgraph.storage.triples.cassandra.write.time.sleep') - async def test_exception_handling_with_retry(self, mock_sleep, mock_kg_class): + async def test_exception_handling_on_connection_failure(self, mock_kg_class): """Test exception handling during TrustGraph creation""" taskgroup_mock = MagicMock() mock_kg_class.side_effect = Exception("Connection failed") processor = Processor(taskgroup=taskgroup_mock) - # Create mock message mock_message = MagicMock() mock_message.metadata.collection = 'collection1' mock_message.triples = [] @@ -256,9 +261,6 @@ class TestCassandraStorageProcessor: with pytest.raises(Exception, match="Connection failed"): await processor.store_triples('user1', mock_message) - # Verify sleep was called before re-raising - mock_sleep.assert_called_once_with(1) - def test_add_args_method(self): """Test that add_args properly configures argument parser""" from argparse import ArgumentParser @@ -359,8 +361,6 @@ class TestCassandraStorageProcessor: mock_message1.triples = [] await processor.store_triples('user1', mock_message1) - assert processor.table == 'user1' - assert processor.tg == mock_tg_instance1 # Second message with different table mock_message2 = MagicMock() @@ -368,11 +368,11 @@ class TestCassandraStorageProcessor: mock_message2.triples = [] await processor.store_triples('user2', mock_message2) - assert processor.table == 'user2' - assert processor.tg == mock_tg_instance2 - # Verify TrustGraph was created twice for different tables + # Verify TrustGraph was created twice for different workspaces assert mock_kg_class.call_count == 2 + mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1') + mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2') @pytest.mark.asyncio @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') @@ -380,6 +380,7 @@ class TestCassandraStorageProcessor: """Test storing triples with special characters and unicode""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -405,7 +406,7 @@ class TestCassandraStorageProcessor: await processor.store_triples('test_workspace', mock_message) # Verify the triple was inserted with special characters preserved - mock_tg_instance.insert.assert_called_once_with( + mock_tg_instance.async_insert.assert_called_once_with( 'test_collection', 'subject with spaces & symbols', 'predicate:with/colons', @@ -418,29 +419,29 @@ class TestCassandraStorageProcessor: @pytest.mark.asyncio @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') - async def test_store_triples_preserves_old_table_on_exception(self, mock_kg_class): - """Test that table remains unchanged when TrustGraph creation fails""" + async def test_connection_failure_does_not_cache_stale_state(self, mock_kg_class): + """Test that a failed connection doesn't leave stale cached state""" taskgroup_mock = MagicMock() + mock_good_instance = MagicMock() processor = Processor(taskgroup=taskgroup_mock) - # Set an initial table - processor.table = ('old_user', 'old_collection') - - # Mock TrustGraph to raise exception - mock_kg_class.side_effect = Exception("Connection failed") - mock_message = MagicMock() - mock_message.metadata.collection = 'new_collection' + mock_message.metadata.collection = 'collection1' mock_message.triples = [] + # First call fails + mock_kg_class.side_effect = Exception("Connection failed") with pytest.raises(Exception, match="Connection failed"): - await processor.store_triples('new_user', mock_message) + await processor.store_triples('user1', mock_message) - # Table should remain unchanged since self.table = table happens after try/except - assert processor.table == ('old_user', 'old_collection') - # TrustGraph should be set to None though - assert processor.tg is None + # Second call succeeds — should retry connection, not use stale state + mock_kg_class.side_effect = None + mock_kg_class.return_value = mock_good_instance + await processor.store_triples('user1', mock_message) + + # Connection was attempted twice (failed + succeeded) + assert mock_kg_class.call_count == 2 class TestCassandraPerformanceOptimizations: @@ -452,6 +453,7 @@ class TestCassandraPerformanceOptimizations: """Test that legacy mode still works with single table""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}): @@ -472,6 +474,7 @@ class TestCassandraPerformanceOptimizations: """Test that optimized mode uses multi-table schema""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}): @@ -492,6 +495,7 @@ class TestCassandraPerformanceOptimizations: """Test that all tables stay consistent during batch writes""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -517,7 +521,7 @@ class TestCassandraPerformanceOptimizations: await processor.store_triples('user1', mock_message) # Verify insert was called for the triple (implementation details tested in KnowledgeGraph) - mock_tg_instance.insert.assert_called_once_with( + mock_tg_instance.async_insert.assert_called_once_with( 'collection1', 'test_subject', 'test_predicate', 'test_object', g=DEFAULT_GRAPH, otype='l', dtype='', lang='' ) diff --git a/tests/unit/test_structured_data/test_row_embeddings_query.py b/tests/unit/test_structured_data/test_row_embeddings_query.py index 51cf834f..f1297e1c 100644 --- a/tests/unit/test_structured_data/test_row_embeddings_query.py +++ b/tests/unit/test_structured_data/test_row_embeddings_query.py @@ -89,7 +89,8 @@ class TestSanitizeName: class TestFindCollection: - def test_finds_matching_collection(self): + @pytest.mark.asyncio + async def test_finds_matching_collection(self): proc = _make_processor() mock_coll = MagicMock() mock_coll.name = "rows_test_workspace_test_col_customers_384" @@ -98,11 +99,12 @@ class TestFindCollection: mock_collections.collections = [mock_coll] proc.qdrant.get_collections.return_value = mock_collections - result = proc.find_collection("test-workspace", "test-col", "customers") + result = await proc.find_collection("test-workspace", "test-col", "customers") assert result == "rows_test_workspace_test_col_customers_384" - def test_returns_none_when_no_match(self): + @pytest.mark.asyncio + async def test_returns_none_when_no_match(self): proc = _make_processor() mock_coll = MagicMock() mock_coll.name = "rows_other_workspace_other_col_schema_768" @@ -111,14 +113,15 @@ class TestFindCollection: mock_collections.collections = [mock_coll] proc.qdrant.get_collections.return_value = mock_collections - result = proc.find_collection("test-workspace", "test-col", "customers") + result = await proc.find_collection("test-workspace", "test-col", "customers") assert result is None - def test_returns_none_on_error(self): + @pytest.mark.asyncio + async def test_returns_none_on_error(self): proc = _make_processor() proc.qdrant.get_collections.side_effect = Exception("connection error") - result = proc.find_collection("workspace", "col", "schema") + result = await proc.find_collection("workspace", "col", "schema") assert result is None @@ -139,7 +142,7 @@ class TestQueryRowEmbeddings: @pytest.mark.asyncio async def test_no_collection_returns_empty(self): proc = _make_processor() - proc.find_collection = MagicMock(return_value=None) + proc.find_collection = AsyncMock(return_value=None) request = _make_request() result = await proc.query_row_embeddings("test-workspace", request) @@ -148,7 +151,7 @@ class TestQueryRowEmbeddings: @pytest.mark.asyncio async def test_successful_query_returns_matches(self): proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_w_c_s_384") + proc.find_collection = AsyncMock(return_value="rows_w_c_s_384") points = [ _make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95), @@ -172,7 +175,7 @@ class TestQueryRowEmbeddings: async def test_index_name_filter_applied(self): """When index_name is specified, a Qdrant filter should be used.""" proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_w_c_s_384") + proc.find_collection = AsyncMock(return_value="rows_w_c_s_384") mock_result = MagicMock() mock_result.points = [] @@ -188,7 +191,7 @@ class TestQueryRowEmbeddings: async def test_no_index_name_no_filter(self): """When index_name is empty, no filter should be applied.""" proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_w_c_s_384") + proc.find_collection = AsyncMock(return_value="rows_w_c_s_384") mock_result = MagicMock() mock_result.points = [] @@ -204,7 +207,7 @@ class TestQueryRowEmbeddings: async def test_missing_payload_fields_default(self): """Points with missing payload fields should use defaults.""" proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_w_c_s_384") + proc.find_collection = AsyncMock(return_value="rows_w_c_s_384") point = MagicMock() point.payload = {} # Empty payload @@ -225,7 +228,7 @@ class TestQueryRowEmbeddings: @pytest.mark.asyncio async def test_qdrant_error_propagates(self): proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_w_c_s_384") + proc.find_collection = AsyncMock(return_value="rows_w_c_s_384") proc.qdrant.query_points.side_effect = Exception("qdrant down") request = _make_request() diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index ca9146b9..d18bee34 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -62,12 +62,6 @@ class AsyncSocketClient: if self._connected: return - if not self.token: - raise ProtocolException( - "AsyncSocketClient requires a token for first-frame " - "auth against /api/v1/socket" - ) - ws_url = self._build_ws_url() self._connect_cm = websockets.connect( ws_url, ping_interval=20, ping_timeout=self.timeout @@ -79,7 +73,7 @@ class AsyncSocketClient: # reader task so the response isn't consumed by the reader's # id-based routing. await self._socket.send(json.dumps({ - "type": "auth", "token": self.token, + "type": "auth", "token": self.token or "", })) try: raw = await asyncio.wait_for( diff --git a/trustgraph-base/trustgraph/api/knowledge.py b/trustgraph-base/trustgraph/api/knowledge.py index c3ec2308..06357d70 100644 --- a/trustgraph-base/trustgraph/api/knowledge.py +++ b/trustgraph-base/trustgraph/api/knowledge.py @@ -132,3 +132,34 @@ class Knowledge: self.request(request = input) + def list_de_cores(self): + + input = { + "operation": "list-de-cores", + "workspace": self.api.workspace, + } + + return self.request(request = input)["ids"] + + def delete_de_core(self, id): + + input = { + "operation": "delete-de-core", + "workspace": self.api.workspace, + "id": id, + } + + self.request(request = input) + + def load_de_core(self, id, flow="default", collection="default"): + + input = { + "operation": "load-de-core", + "workspace": self.api.workspace, + "id": id, + "flow": flow, + "collection": collection, + } + + self.request(request = input) + diff --git a/trustgraph-base/trustgraph/api/library.py b/trustgraph-base/trustgraph/api/library.py index 024e933d..b3506bb7 100644 --- a/trustgraph-base/trustgraph/api/library.py +++ b/trustgraph-base/trustgraph/api/library.py @@ -365,7 +365,7 @@ class Library: id = v["id"], time = datetime.datetime.fromtimestamp(v["time"]), kind = v["kind"], - title = v["title"], + title = v.get("title", ""), comments = v.get("comments", ""), metadata = [ Triple( @@ -482,14 +482,15 @@ class Library: "workspace": self.api.workspace, "document-metadata": { "document-id": id, - "time": metadata.time, + "id": id, + "time": int(metadata.time.timestamp()) if hasattr(metadata.time, 'timestamp') else metadata.time, "title": metadata.title, "comments": metadata.comments, "metadata": [ { - "s": from_value(t["s"]), - "p": from_value(t["p"]), - "o": from_value(t["o"]), + "s": from_value(t.s), + "p": from_value(t.p), + "o": from_value(t.o), } for t in metadata.metadata ], @@ -498,14 +499,17 @@ class Library: } object = self.request(input) - doc = object["document-metadata"] + doc = object.get("document-metadata") if isinstance(object, dict) else None + + if not doc: + return metadata try: - DocumentMetadata( + return DocumentMetadata( id = doc["id"], time = datetime.datetime.fromtimestamp(doc["time"]), kind = doc["kind"], - title = doc["title"], + title = doc.get("title", ""), comments = doc.get("comments", ""), metadata = [ Triple( @@ -513,10 +517,11 @@ class Library: p = to_value(w["p"]), o = to_value(w["o"]) ) - for w in doc["metadata"] + for w in doc.get("metadata", []) ], - workspace = doc.get("workspace", ""), - tags = doc["tags"] + tags = doc.get("tags", []), + parent_id = doc.get("parent-id", ""), + document_type = doc.get("document-type", "source"), ) except Exception as e: logger.error("Failed to parse document update response", exc_info=True) diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index aeb15f85..6eeb95ff 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -11,6 +11,7 @@ multiplexes requests by ID. import json import asyncio import websockets +from websockets.exceptions import ConnectionClosed from typing import Optional, Dict, Any, Iterator, Union, List from threading import Lock @@ -137,12 +138,6 @@ class SocketClient: if self._connected: return - if not self.token: - raise ProtocolException( - "SocketClient requires a token for first-frame auth " - "against /api/v1/socket" - ) - ws_url = self._build_ws_url() self._connect_cm = websockets.connect( ws_url, ping_interval=20, ping_timeout=self.timeout @@ -153,7 +148,7 @@ class SocketClient: # auth-ok / auth-failed response isn't consumed by the reader # loop's id-based routing. await self._socket.send(json.dumps({ - "type": "auth", "token": self.token, + "type": "auth", "token": self.token or "", })) try: raw = await asyncio.wait_for( @@ -197,13 +192,13 @@ class SocketClient: if request_id and request_id in self._pending: await self._pending[request_id].put(response) - except websockets.exceptions.ConnectionClosed: + except ConnectionClosed: pass except Exception as e: for queue in self._pending.values(): try: await queue.put({"error": str(e)}) - except: + except Exception: pass finally: self._connected = False @@ -256,7 +251,7 @@ class SocketClient: finally: try: loop.run_until_complete(async_gen.aclose()) - except: + except Exception: pass def _streaming_generator_raw( @@ -279,7 +274,7 @@ class SocketClient: finally: try: loop.run_until_complete(async_gen.aclose()) - except: + except Exception: pass async def _send_request_async_streaming_raw( @@ -491,12 +486,64 @@ class SocketClient: triples=raw_triples, ) + def get_kg_core(self, id: str) -> Iterator[Dict[str, Any]]: + request = { + "operation": "get-kg-core", + "workspace": self.workspace, + "id": id, + } + for response in self._send_request_sync( + "knowledge", None, request, streaming_raw=True, + ): + if response.get("eos"): + break + yield response + + def put_kg_core( + self, id: str, triples=None, graph_embeddings=None, + ) -> Dict[str, Any]: + request = { + "operation": "put-kg-core", + "workspace": self.workspace, + "id": id, + } + if triples is not None: + request["triples"] = triples + if graph_embeddings is not None: + request["graph-embeddings"] = graph_embeddings + return self._send_request_sync("knowledge", None, request) + + def get_de_core(self, id: str) -> Iterator[Dict[str, Any]]: + request = { + "operation": "get-de-core", + "workspace": self.workspace, + "id": id, + } + for response in self._send_request_sync( + "knowledge", None, request, streaming_raw=True, + ): + if response.get("eos"): + break + yield response + + def put_de_core( + self, id: str, document_embeddings=None, + ) -> Dict[str, Any]: + request = { + "operation": "put-de-core", + "workspace": self.workspace, + "id": id, + } + if document_embeddings is not None: + request["document-embeddings"] = document_embeddings + return self._send_request_sync("knowledge", None, request) + def close(self) -> None: """Close the persistent WebSocket connection.""" if self._loop and not self._loop.is_closed(): try: self._loop.run_until_complete(self._close_async()) - except: + except Exception: pass async def _close_async(self): diff --git a/trustgraph-base/trustgraph/base/consumer.py b/trustgraph-base/trustgraph/base/consumer.py index 5c59c515..b9f2ee0b 100644 --- a/trustgraph-base/trustgraph/base/consumer.py +++ b/trustgraph-base/trustgraph/base/consumer.py @@ -76,8 +76,10 @@ class Consumer: if hasattr(self, "consumer"): if self.consumer: - self.consumer.unsubscribe() - self.consumer.close() + try: + self.consumer.close() + except Exception: + pass self.consumer = None async def stop(self): @@ -157,12 +159,14 @@ class Consumer: except Exception as e: logger.error(f"Consumer loop exception: {e}", exc_info=True) - for c in consumers: + for i, c in enumerate(consumers): try: - c.unsubscribe() c.close() - except Exception: - pass + except Exception as ce: + logger.warning( + f"Consumer {i} close failed (error path): " + f"{type(ce).__name__}: {ce}" + ) for ex in executors: ex.shutdown(wait=False) consumers = [] @@ -171,12 +175,14 @@ class Consumer: continue finally: - for c in consumers: + for i, c in enumerate(consumers): try: - c.unsubscribe() c.close() - except Exception: - pass + except Exception as ce: + logger.warning( + f"Consumer {i} close failed: " + f"{type(ce).__name__}: {ce}" + ) for ex in executors: ex.shutdown(wait=False) @@ -188,7 +194,7 @@ class Consumer: try: msg = await loop.run_in_executor( executor, - lambda: consumer.receive(timeout_millis=100), + lambda: consumer.receive(timeout_millis=2000), ) except Exception as e: # Handle timeout from any backend diff --git a/trustgraph-base/trustgraph/base/flow.py b/trustgraph-base/trustgraph/base/flow.py index 0f42bbe2..3b928d3e 100644 --- a/trustgraph-base/trustgraph/base/flow.py +++ b/trustgraph-base/trustgraph/base/flow.py @@ -34,6 +34,8 @@ class Flow: async def stop(self): for c in self.consumer.values(): await c.stop() + for p in self.producer.values(): + await p.stop() if self.librarian: await self.librarian.stop() diff --git a/trustgraph-base/trustgraph/base/iam_client.py b/trustgraph-base/trustgraph/base/iam_client.py index 4be59de1..e0457d19 100644 --- a/trustgraph-base/trustgraph/base/iam_client.py +++ b/trustgraph-base/trustgraph/base/iam_client.py @@ -62,6 +62,22 @@ class IamClient(RequestResponse): ) return resp.user + async def authenticate_anonymous(self, timeout=IAM_TIMEOUT): + """Request anonymous access from the IAM regime. + + Returns ``(user_id, workspace, roles)`` if the regime permits + anonymous access, or raises ``RuntimeError`` with error type + ``auth-failed`` if it does not.""" + resp = await self._request( + operation="authenticate-anonymous", + timeout=timeout, + ) + return ( + resp.resolved_user_id, + resp.resolved_workspace, + list(resp.resolved_roles), + ) + async def resolve_api_key(self, api_key, timeout=IAM_TIMEOUT): """Resolve a plaintext API key to its identity triple. diff --git a/trustgraph-base/trustgraph/base/producer.py b/trustgraph-base/trustgraph/base/producer.py index 20b4b0d6..9af9d22e 100644 --- a/trustgraph-base/trustgraph/base/producer.py +++ b/trustgraph-base/trustgraph/base/producer.py @@ -34,6 +34,9 @@ class Producer: async def stop(self): self.running = False + if self.producer: + self.producer.close() + self.producer = None async def send(self, msg, properties={}): diff --git a/trustgraph-base/trustgraph/base/pubsub.py b/trustgraph-base/trustgraph/base/pubsub.py index fb4765c1..4ae8d2d0 100644 --- a/trustgraph-base/trustgraph/base/pubsub.py +++ b/trustgraph-base/trustgraph/base/pubsub.py @@ -10,6 +10,7 @@ logger = logging.getLogger(__name__) # Default connection settings from environment DEFAULT_PULSAR_HOST = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') DEFAULT_PULSAR_API_KEY = os.getenv("PULSAR_API_KEY", None) +DEFAULT_PULSAR_ADMIN_URL = os.getenv("PULSAR_ADMIN_URL", 'http://pulsar:8080') DEFAULT_RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", 'rabbitmq') DEFAULT_RABBITMQ_PORT = int(os.getenv("RABBITMQ_PORT", '5672')) @@ -43,6 +44,7 @@ def get_pubsub(**config: Any) -> Any: host=config.get('pulsar_host', DEFAULT_PULSAR_HOST), api_key=config.get('pulsar_api_key', DEFAULT_PULSAR_API_KEY), listener=config.get('pulsar_listener'), + admin_url=config.get('pulsar_admin_url', DEFAULT_PULSAR_ADMIN_URL), ) elif backend_type == 'rabbitmq': from .rabbitmq_backend import RabbitMQBackend @@ -77,6 +79,7 @@ def get_pubsub(**config: Any) -> Any: STANDALONE_PULSAR_HOST = 'pulsar://localhost:6650' +STANDALONE_PULSAR_ADMIN_URL = 'http://localhost:8080' def add_pubsub_args(parser: ArgumentParser, standalone: bool = False) -> None: @@ -88,6 +91,7 @@ def add_pubsub_args(parser: ArgumentParser, standalone: bool = False) -> None: that run outside containers) """ pulsar_host = STANDALONE_PULSAR_HOST if standalone else DEFAULT_PULSAR_HOST + pulsar_admin_url = STANDALONE_PULSAR_ADMIN_URL if standalone else DEFAULT_PULSAR_ADMIN_URL pulsar_listener = 'localhost' if standalone else None rabbitmq_host = 'localhost' if standalone else DEFAULT_RABBITMQ_HOST kafka_bootstrap = 'localhost:9092' if standalone else DEFAULT_KAFKA_BOOTSTRAP @@ -105,6 +109,12 @@ def add_pubsub_args(parser: ArgumentParser, standalone: bool = False) -> None: help=f'Pulsar host (default: {pulsar_host})', ) + parser.add_argument( + '--pulsar-admin-url', + default=pulsar_admin_url, + help=f'Pulsar admin REST API URL (default: {pulsar_admin_url})', + ) + parser.add_argument( '--pulsar-api-key', default=DEFAULT_PULSAR_API_KEY, diff --git a/trustgraph-base/trustgraph/base/pulsar_backend.py b/trustgraph-base/trustgraph/base/pulsar_backend.py index e27d16af..e85dfbef 100644 --- a/trustgraph-base/trustgraph/base/pulsar_backend.py +++ b/trustgraph-base/trustgraph/base/pulsar_backend.py @@ -7,8 +7,12 @@ handling topic mapping, serialization, and Pulsar client management. import pulsar import _pulsar +import asyncio import json import logging +import urllib.request +import urllib.error +import urllib.parse from typing import Any from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message @@ -117,7 +121,10 @@ class PulsarBackend: producers and consumers. """ - def __init__(self, host: str, api_key: str = None, listener: str = None): + def __init__( + self, host: str, api_key: str = None, listener: str = None, + admin_url: str = None, + ): """ Initialize Pulsar backend. @@ -125,10 +132,12 @@ class PulsarBackend: host: Pulsar broker URL (e.g., pulsar://localhost:6650) api_key: Optional API key for authentication listener: Optional listener name for multi-homed setups + admin_url: Pulsar admin REST API URL (e.g., http://pulsar:8080) """ self.host = host self.api_key = api_key self.listener = listener + self.admin_url = admin_url # Create Pulsar client client_args = {'service_url': host} @@ -139,6 +148,10 @@ class PulsarBackend: if api_key: client_args['authentication'] = pulsar.AuthenticationToken(api_key) + client_args['logger'] = pulsar.ConsoleLogger( + _pulsar.LoggerLevel.Error + ) + self.client = pulsar.Client(**client_args) logger.info(f"Pulsar client connected to {host}") @@ -266,24 +279,129 @@ class PulsarBackend: return PulsarBackendConsumer(pulsar_consumer, schema) + def _admin_api_path(self, pulsar_uri: str) -> str: + """ + Convert a Pulsar topic URI to an admin REST API path. + + persistent://tg/flow/triples-store:default:explain-flow + -> /admin/v2/persistent/tg/flow/triples-store%3Adefault%3Aexplain-flow + """ + scheme, rest = pulsar_uri.split('://', 1) + tenant, namespace, topic = rest.split('/', 2) + encoded_topic = urllib.parse.quote(topic, safe='') + return f"/admin/v2/{scheme}/{tenant}/{namespace}/{encoded_topic}" + + def _admin_request(self, method, path): + """ + Make a synchronous admin REST API request. + + Returns parsed JSON for GET, None for DELETE/PUT. + Raises urllib.error.HTTPError for non-404 errors. + 404 is treated as success (idempotent deletion). + """ + url = f"{self.admin_url}{path}" + req = urllib.request.Request(url, method=method) + + try: + with urllib.request.urlopen(req) as resp: + if method == 'GET': + return json.loads(resp.read().decode('utf-8')) + return None + except urllib.error.HTTPError as e: + if e.code == 404: + return None + raise + + def _delete_topic_sync(self, topic: str): + """ + Delete a persistent topic and all its subscriptions. + + Subscriptions must be removed first — Pulsar rejects topic + deletion while subscriptions exist. Force-deletes each + subscription to disconnect any lingering consumers. + """ + pulsar_uri = self.map_topic(topic) + + if pulsar_uri.startswith('non-persistent://'): + return + + api_path = self._admin_api_path(pulsar_uri) + + try: + subs = self._admin_request('GET', f"{api_path}/subscriptions") + except Exception as e: + logger.warning(f"Failed to list subscriptions for {topic}: {e}") + return + + if subs: + for sub in subs: + encoded_sub = urllib.parse.quote(sub, safe='') + try: + self._admin_request( + 'DELETE', + f"{api_path}/subscription/{encoded_sub}" + f"?force=true" + ) + logger.info( + f"Deleted subscription {sub} from {topic}" + ) + except Exception as e: + logger.warning( + f"Failed to delete subscription {sub} " + f"from {topic}: {e}" + ) + + try: + self._admin_request('DELETE', api_path) + logger.info(f"Deleted topic: {topic}") + except Exception as e: + logger.warning(f"Failed to delete topic {topic}: {e}") + + def _topic_exists_sync(self, topic: str) -> bool: + """Check topic existence via admin API.""" + pulsar_uri = self.map_topic(topic) + + if pulsar_uri.startswith('non-persistent://'): + return False + + api_path = self._admin_api_path(pulsar_uri) + + try: + result = self._admin_request('GET', f"{api_path}/stats") + return result is not None + except Exception: + return False + async def create_topic(self, topic: str) -> None: - """No-op — Pulsar auto-creates topics on first use. - TODO: Use admin REST API for explicit persistent topic creation.""" + """No-op — Pulsar auto-creates topics on first use.""" pass async def delete_topic(self, topic: str) -> None: - """No-op — to be replaced with admin REST API calls. - TODO: Delete persistent topic via admin API.""" - pass + """ + Delete a persistent topic and all its subscriptions via + the admin REST API. + + Called by the flow controller during deliberate flow deletion. + Non-persistent topics are skipped. Idempotent. + """ + if not self.admin_url: + logger.warning( + f"Cannot delete topic {topic}: " + f"no admin URL configured" + ) + return + + await asyncio.to_thread(self._delete_topic_sync, topic) async def topic_exists(self, topic: str) -> bool: - """Returns True — Pulsar auto-creates on subscribe. - TODO: Use admin REST API for actual existence check.""" - return True + """Check whether a persistent topic exists via the admin API.""" + if not self.admin_url: + return True + + return await asyncio.to_thread(self._topic_exists_sync, topic) async def ensure_topic(self, topic: str) -> None: - """No-op — Pulsar auto-creates topics on first use. - TODO: Use admin REST API for explicit creation.""" + """No-op — Pulsar auto-creates topics on first use.""" pass def close(self) -> None: diff --git a/trustgraph-base/trustgraph/base/triples_client.py b/trustgraph-base/trustgraph/base/triples_client.py index 2601a1e1..0506cb9f 100644 --- a/trustgraph-base/trustgraph/base/triples_client.py +++ b/trustgraph-base/trustgraph/base/triples_client.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from typing import Any from . request_response_spec import RequestResponse, RequestResponseSpec @@ -44,6 +45,60 @@ def from_value(x: Any) -> Any: return Term(type=LITERAL, value=str(x)) class TriplesClient(RequestResponse): + + async def query_gen(self, s=None, p=None, o=None, limit=20, + collection="default", + batch_size=20, timeout=30, g=None): + """Async generator yielding Triple objects as batches arrive.""" + queue = asyncio.Queue() + done = False + + async def recipient(resp): + if resp.error: + raise RuntimeError(resp.error.message) + + batch = [ + Triple(to_value(v.s), to_value(v.p), to_value(v.o)) + for v in resp.triples + ] + await queue.put(batch) + + if resp.is_final: + await queue.put(None) + + return resp.is_final + + # Launch the streaming request as a background task + task = asyncio.ensure_future(self.request( + TriplesQueryRequest( + s=from_value(s), + p=from_value(p), + o=from_value(o), + limit=limit, + collection=collection, + streaming=True, + batch_size=batch_size, + g=g, + ), + timeout=timeout, + recipient=recipient, + )) + + try: + while True: + batch = await queue.get() + if batch is None: + break + for triple in batch: + yield triple + finally: + if not task.done(): + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + async def query(self, s=None, p=None, o=None, limit=20, collection="default", timeout=30, g=None): diff --git a/trustgraph-base/trustgraph/messaging/translators/knowledge.py b/trustgraph-base/trustgraph/messaging/translators/knowledge.py index f2cc8e46..3830bf59 100644 --- a/trustgraph-base/trustgraph/messaging/translators/knowledge.py +++ b/trustgraph-base/trustgraph/messaging/translators/knowledge.py @@ -1,6 +1,7 @@ from typing import Dict, Any, Tuple, Optional from ...schema import ( KnowledgeRequest, KnowledgeResponse, Triples, GraphEmbeddings, + DocumentEmbeddings, ChunkEmbeddings, Metadata, EntityEmbeddings ) from .base import MessageTranslator @@ -43,6 +44,23 @@ class KnowledgeRequestTranslator(MessageTranslator): ] ) + document_embeddings = None + if "document-embeddings" in data: + document_embeddings = DocumentEmbeddings( + metadata=Metadata( + id=data["document-embeddings"]["metadata"]["id"], + root=data["document-embeddings"]["metadata"].get("root", ""), + collection=data["document-embeddings"]["metadata"]["collection"] + ), + chunks=[ + ChunkEmbeddings( + chunk_id=ch["chunk_id"], + vector=ch["vector"], + ) + for ch in data["document-embeddings"]["chunks"] + ] + ) + return KnowledgeRequest( operation=data.get("operation"), id=data.get("id"), @@ -50,6 +68,7 @@ class KnowledgeRequestTranslator(MessageTranslator): collection=data.get("collection"), triples=triples, graph_embeddings=graph_embeddings, + document_embeddings=document_embeddings, ) def encode(self, obj: KnowledgeRequest) -> Dict[str, Any]: @@ -90,6 +109,22 @@ class KnowledgeRequestTranslator(MessageTranslator): ], } + if obj.document_embeddings: + result["document-embeddings"] = { + "metadata": { + "id": obj.document_embeddings.metadata.id, + "root": obj.document_embeddings.metadata.root, + "collection": obj.document_embeddings.metadata.collection, + }, + "chunks": [ + { + "chunk_id": ch.chunk_id, + "vector": ch.vector, + } + for ch in obj.document_embeddings.chunks + ], + } + return result @@ -140,6 +175,25 @@ class KnowledgeResponseTranslator(MessageTranslator): } } + # Streaming document embeddings response + if obj.document_embeddings: + return { + "document-embeddings": { + "metadata": { + "id": obj.document_embeddings.metadata.id, + "root": obj.document_embeddings.metadata.root, + "collection": obj.document_embeddings.metadata.collection, + }, + "chunks": [ + { + "chunk_id": ch.chunk_id, + "vector": ch.vector, + } + for ch in obj.document_embeddings.chunks + ], + } + } + # End of stream marker if obj.eos is True: return {"eos": True} @@ -155,7 +209,7 @@ class KnowledgeResponseTranslator(MessageTranslator): is_final = ( obj.ids is not None or # List response obj.eos is True or # End of stream - (not obj.triples and not obj.graph_embeddings) # Empty response + (not obj.triples and not obj.graph_embeddings and not obj.document_embeddings) # Empty response ) return response, is_final \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py index 64cb7082..a3879103 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py +++ b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py @@ -4,7 +4,7 @@ from ..core.topic import queue from ..core.metadata import Metadata from .document import Document, TextDocument from .graph import Triples -from .embeddings import GraphEmbeddings +from .embeddings import GraphEmbeddings, DocumentEmbeddings # get-kg-core # -> (???) @@ -41,6 +41,9 @@ class KnowledgeRequest: triples: Triples | None = None graph_embeddings: GraphEmbeddings | None = None + # put-de-core + document_embeddings: DocumentEmbeddings | None = None + @dataclass class KnowledgeResponse: error: Error | None = None @@ -48,6 +51,7 @@ class KnowledgeResponse: eos: bool = False # Indicates end of knowledge core stream triples: Triples | None = None graph_embeddings: GraphEmbeddings | None = None + document_embeddings: DocumentEmbeddings | None = None knowledge_request_queue = queue('knowledge', cls='request') knowledge_response_queue = queue('knowledge', cls='response') diff --git a/trustgraph-bedrock/pyproject.toml b/trustgraph-bedrock/pyproject.toml index f0c8d571..7aa2f96a 100644 --- a/trustgraph-bedrock/pyproject.toml +++ b/trustgraph-bedrock/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.4,<2.5", + "trustgraph-base>=2.5,<2.6", "pulsar-client", "prometheus-client", "boto3", diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index e8062fba..fafcc9bc 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.4,<2.5", + "trustgraph-base>=2.5,<2.6", "requests", "pulsar-client", "aiohttp", @@ -37,6 +37,7 @@ tg-dump-msgpack = "trustgraph.cli.dump_msgpack:main" tg-dump-queues = "trustgraph.cli.dump_queues:main" tg-monitor-prompts = "trustgraph.cli.monitor_prompts:main" tg-get-flow-blueprint = "trustgraph.cli.get_flow_blueprint:main" +tg-get-de-core = "trustgraph.cli.get_de_core:main" tg-get-kg-core = "trustgraph.cli.get_kg_core:main" tg-get-document-content = "trustgraph.cli.get_document_content:main" tg-graph-to-turtle = "trustgraph.cli.graph_to_turtle:main" @@ -77,6 +78,7 @@ tg-load-turtle = "trustgraph.cli.load_turtle:main" tg-load-knowledge = "trustgraph.cli.load_knowledge:main" tg-load-structured-data = "trustgraph.cli.load_structured_data:main" tg-put-flow-blueprint = "trustgraph.cli.put_flow_blueprint:main" +tg-put-de-core = "trustgraph.cli.put_de_core:main" tg-put-kg-core = "trustgraph.cli.put_kg_core:main" tg-remove-library-document = "trustgraph.cli.remove_library_document:main" tg-save-doc-embeds = "trustgraph.cli.save_doc_embeds:main" @@ -119,6 +121,9 @@ tg-show-extraction-provenance = "trustgraph.cli.show_extraction_provenance:main" tg-list-explain-traces = "trustgraph.cli.list_explain_traces:main" tg-show-explain-trace = "trustgraph.cli.show_explain_trace:main" +[tool.setuptools.package-data] +"trustgraph.cli.sample_documents" = ["*.md", "*.pdf", "*.json"] + [tool.setuptools.packages.find] include = ["trustgraph*"] diff --git a/trustgraph-cli/trustgraph/cli/get_de_core.py b/trustgraph-cli/trustgraph/cli/get_de_core.py new file mode 100644 index 00000000..caf74ba9 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/get_de_core.py @@ -0,0 +1,111 @@ +""" +Uses the knowledge service to fetch a document embeddings core which is +saved to a local file in msgpack format. +""" + +import argparse +import os +import msgpack + +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") + +def write_de(f, data): + msg = ( + "de", + { + "m": { + "i": data["metadata"]["id"], + "m": data["metadata"]["root"], + "c": data["metadata"]["collection"], + }, + "c": [ + { + "i": ch["chunk_id"], + "v": ch["vector"], + } + for ch in data["chunks"] + ] + } + ) + f.write(msgpack.packb(msg, use_bin_type=True)) + +def fetch(url, workspace, id, output, token=None): + + api = Api(url=url, token=token, workspace=workspace) + socket = api.socket() + + try: + de = 0 + + with open(output, "wb") as f: + + for response in socket.get_de_core(id): + + if "document-embeddings" in response: + de += 1 + write_de(f, response["document-embeddings"]) + + print(f"Got: {de} document embeddings messages.") + + finally: + socket.close() + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-get-de-core', + description=__doc__, + ) + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + + parser.add_argument( + '--id', '--identifier', + required=True, + help=f'Document embeddings core ID', + ) + + parser.add_argument( + '-o', '--output', + required=True, + help=f'Output file' + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + args = parser.parse_args() + + try: + + fetch( + url=args.url, + workspace=args.workspace, + id=args.id, + output=args.output, + token=args.token, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/get_kg_core.py b/trustgraph-cli/trustgraph/cli/get_kg_core.py index 8bee4115..b4f37b81 100644 --- a/trustgraph-cli/trustgraph/cli/get_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/get_kg_core.py @@ -5,13 +5,11 @@ to a local file in msgpack format. import argparse import os -import uuid -import asyncio -import json -from websockets.asyncio.client import connect import msgpack -default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") @@ -21,7 +19,7 @@ def write_triple(f, data): { "m": { "i": data["metadata"]["id"], - "m": data["metadata"]["metadata"], + "m": data["metadata"]["root"], "c": data["metadata"]["collection"], }, "t": data["triples"], @@ -35,13 +33,13 @@ def write_ge(f, data): { "m": { "i": data["metadata"]["id"], - "m": data["metadata"]["metadata"], + "m": data["metadata"]["root"], "c": data["metadata"]["collection"], }, "e": [ { "e": ent["entity"], - "v": ent["vectors"], + "v": ent["vector"], } for ent in data["entities"] ] @@ -49,54 +47,18 @@ def write_ge(f, data): ) f.write(msgpack.packb(msg, use_bin_type=True)) -async def fetch(url, workspace, id, output, token=None): +def fetch(url, workspace, id, output, token=None): - if not url.endswith("/"): - url += "/" - - url = url + "api/v1/socket" - - if token: - url = f"{url}?token={token}" - - mid = str(uuid.uuid4()) - - async with connect(url) as ws: - - req = json.dumps({ - "id": mid, - "workspace": workspace, - "service": "knowledge", - "request": { - "operation": "get-kg-core", - "workspace": workspace, - "id": id, - } - }) - - await ws.send(req) + api = Api(url=url, token=token, workspace=workspace) + socket = api.socket() + try: ge = 0 t = 0 with open(output, "wb") as f: - while True: - - msg = await ws.recv() - - obj = json.loads(msg) - - if "response" not in obj: - raise RuntimeError("No response?") - - response = obj["response"] - - if "error" in response: - raise RuntimeError(obj["error"]) - - if "eos" in response: - if response["eos"]: break + for response in socket.get_kg_core(id): if "triples" in response: t += 1 @@ -108,7 +70,8 @@ async def fetch(url, workspace, id, output, token=None): print(f"Got: {t} triple, {ge} GE messages.") - await ws.close() + finally: + socket.close() def main(): @@ -151,14 +114,12 @@ def main(): try: - asyncio.run( - fetch( - url=args.url, - workspace=args.workspace, - id=args.id, - output=args.output, - token=args.token, - ) + fetch( + url=args.url, + workspace=args.workspace, + id=args.id, + output=args.output, + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py index 23d6bcac..f39cdab0 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py @@ -3,11 +3,8 @@ Uses the GraphRAG service to answer a question """ import argparse -import json import os import sys -import websockets -import asyncio from trustgraph.api import ( Api, ExplainabilityClient, @@ -31,607 +28,6 @@ default_max_path_length = 2 default_edge_score_limit = 30 default_edge_limit = 25 -# Provenance predicates -TG = "https://trustgraph.ai/ns/" -TG_QUERY = TG + "query" -TG_CONCEPT = TG + "concept" -TG_ENTITY = TG + "entity" -TG_EDGE_COUNT = TG + "edgeCount" -TG_SELECTED_EDGE = TG + "selectedEdge" -TG_EDGE = TG + "edge" -TG_REASONING = TG + "reasoning" -TG_DOCUMENT = TG + "document" -TG_CONTAINS = TG + "contains" -PROV = "http://www.w3.org/ns/prov#" -PROV_STARTED_AT_TIME = PROV + "startedAtTime" -PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom" -RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label" - - -def _get_event_type(prov_id): - """Extract event type from provenance_id""" - if "question" in prov_id: - return "question" - elif "grounding" in prov_id: - return "grounding" - elif "exploration" in prov_id: - return "exploration" - elif "focus" in prov_id: - return "focus" - elif "synthesis" in prov_id: - return "synthesis" - return "provenance" - - -def _format_provenance_details(event_type, triples): - """Format provenance details based on event type and triples""" - lines = [] - - if event_type == "question": - # Show query and timestamp - for s, p, o in triples: - if p == TG_QUERY: - lines.append(f" Query: {o}") - elif p == PROV_STARTED_AT_TIME: - lines.append(f" Time: {o}") - - elif event_type == "grounding": - # Show extracted concepts - concepts = [o for s, p, o in triples if p == TG_CONCEPT] - if concepts: - lines.append(f" Concepts: {len(concepts)}") - for concept in concepts: - lines.append(f" - {concept}") - - elif event_type == "exploration": - # Show edge count (seed entities resolved separately with labels) - for s, p, o in triples: - if p == TG_EDGE_COUNT: - lines.append(f" Edges explored: {o}") - - elif event_type == "focus": - # For focus, just count edge selection URIs - # The actual edge details are fetched separately via edge_selections parameter - edge_sel_uris = [] - for s, p, o in triples: - if p == TG_SELECTED_EDGE: - edge_sel_uris.append(o) - if edge_sel_uris: - lines.append(f" Focused on {len(edge_sel_uris)} edge(s)") - - elif event_type == "synthesis": - # Show document reference (content already streamed) - for s, p, o in triples: - if p == TG_DOCUMENT: - lines.append(f" Document: {o}") - - return lines - - -async def _query_triples_once(ws_url, flow_id, prov_id, collection, graph=None, debug=False): - """Query triples for a provenance node (single attempt)""" - request = { - "id": "triples-request", - "service": "triples", - "flow": flow_id, - "request": { - "s": {"t": "i", "i": prov_id}, - "collection": collection, - "limit": 100 - } - } - # Add graph filter if specified (for named graph queries) - if graph is not None: - request["request"]["g"] = graph - - if debug: - print(f" [debug] querying triples for s={prov_id}", file=sys.stderr) - - triples = [] - try: - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if debug: - print(f" [debug] response: {json.dumps(response)[:200]}", file=sys.stderr) - - if response.get("id") != "triples-request": - continue - - if "error" in response: - if debug: - print(f" [debug] error: {response['error']}", file=sys.stderr) - break - - if "response" in response: - resp = response["response"] - # Handle triples response - # Response format: {"response": [triples...]} - # Each triple uses compact keys: "i" for iri, "v" for value, "t" for type - triple_list = resp.get("response", []) - for t in triple_list: - s = t.get("s", {}).get("i", t.get("s", {}).get("v", "")) - p = t.get("p", {}).get("i", t.get("p", {}).get("v", "")) - # Handle quoted triples (type "t") and regular values - o_term = t.get("o", {}) - if o_term.get("t") == "t": - # Quoted triple - extract s, p, o from nested structure - tr = o_term.get("tr", {}) - o = { - "s": tr.get("s", {}).get("i", ""), - "p": tr.get("p", {}).get("i", ""), - "o": tr.get("o", {}).get("i", tr.get("o", {}).get("v", "")), - } - else: - o = o_term.get("i", o_term.get("v", "")) - triples.append((s, p, o)) - - if resp.get("complete") or response.get("complete"): - break - except Exception as e: - if debug: - print(f" [debug] exception: {e}", file=sys.stderr) - - if debug: - print(f" [debug] got {len(triples)} triples", file=sys.stderr) - - return triples - - -async def _query_triples(ws_url, flow_id, prov_id, collection, graph=None, max_retries=5, retry_delay=0.2, debug=False): - """Query triples for a provenance node with retries for race condition""" - for attempt in range(max_retries): - triples = await _query_triples_once(ws_url, flow_id, prov_id, collection, graph=graph, debug=debug) - if triples: - return triples - # Wait before retry if empty (triples may not be stored yet) - if attempt < max_retries - 1: - if debug: - print(f" [debug] retry {attempt + 1}/{max_retries}...", file=sys.stderr) - await asyncio.sleep(retry_delay) - return [] - - -async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, collection, debug=False): - """ - Query for provenance of an edge (s, p, o) in the knowledge graph. - - Finds subgraphs that contain the edge via tg:contains, then follows - prov:wasDerivedFrom to find source documents. - - Returns list of source URIs (chunks, pages, documents). - """ - # Query for subgraphs that contain this edge: ?subgraph tg:contains <> - request = { - "id": "edge-prov-request", - "service": "triples", - "flow": flow_id, - "request": { - "p": {"t": "i", "i": TG_CONTAINS}, - "o": { - "t": "t", # Quoted triple type - "tr": { - "s": {"t": "i", "i": edge_s}, - "p": {"t": "i", "i": edge_p}, - "o": {"t": "i", "i": edge_o} if edge_o.startswith("http") or edge_o.startswith("urn:") else {"t": "l", "v": edge_o}, - } - }, - "collection": collection, - "limit": 10 - } - } - - if debug: - print(f" [debug] querying edge provenance for ({edge_s}, {edge_p}, {edge_o})", file=sys.stderr) - - stmt_uris = [] - try: - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != "edge-prov-request": - continue - - if "error" in response: - if debug: - print(f" [debug] error: {response['error']}", file=sys.stderr) - break - - if "response" in response: - resp = response["response"] - triple_list = resp.get("response", []) - for t in triple_list: - s = t.get("s", {}).get("i", "") - if s: - stmt_uris.append(s) - - if resp.get("complete") or response.get("complete"): - break - except Exception as e: - if debug: - print(f" [debug] exception querying edge provenance: {e}", file=sys.stderr) - - if debug: - print(f" [debug] found {len(stmt_uris)} reifying statements", file=sys.stderr) - - # For each statement, query wasDerivedFrom to find sources - sources = [] - for stmt_uri in stmt_uris: - # Query: stmt_uri prov:wasDerivedFrom ?source - request = { - "id": "derived-from-request", - "service": "triples", - "flow": flow_id, - "request": { - "s": {"t": "i", "i": stmt_uri}, - "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM}, - "collection": collection, - "limit": 10 - } - } - - try: - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != "derived-from-request": - continue - - if "error" in response: - break - - if "response" in response: - resp = response["response"] - triple_list = resp.get("response", []) - for t in triple_list: - o = t.get("o", {}).get("i", "") - if o: - sources.append(o) - - if resp.get("complete") or response.get("complete"): - break - except Exception as e: - if debug: - print(f" [debug] exception querying wasDerivedFrom: {e}", file=sys.stderr) - - if debug: - print(f" [debug] found {len(sources)} source(s): {sources}", file=sys.stderr) - - return sources - - -async def _query_derived_from(ws_url, flow_id, uri, collection, debug=False): - """Query for the prov:wasDerivedFrom parent of a URI. Returns None if no parent.""" - request = { - "id": "parent-request", - "service": "triples", - "flow": flow_id, - "request": { - "s": {"t": "i", "i": uri}, - "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM}, - "collection": collection, - "limit": 1 - } - } - - try: - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != "parent-request": - continue - - if "error" in response: - break - - if "response" in response: - resp = response["response"] - triple_list = resp.get("response", []) - if triple_list: - return triple_list[0].get("o", {}).get("i", None) - - if resp.get("complete") or response.get("complete"): - break - except Exception as e: - if debug: - print(f" [debug] exception querying parent: {e}", file=sys.stderr) - - return None - - -async def _trace_provenance_chain(ws_url, flow_id, source_uri, collection, label_cache, debug=False): - """ - Trace the full provenance chain from a source URI up to the root document. - Returns a list of (uri, label) tuples from leaf to root. - """ - chain = [] - current = source_uri - max_depth = 10 # Prevent infinite loops - - for _ in range(max_depth): - if not current: - break - - # Get label for current entity - label = await _query_label(ws_url, flow_id, current, collection, label_cache, debug) - chain.append((current, label)) - - # Get parent - parent = await _query_derived_from(ws_url, flow_id, current, collection, debug) - if not parent or parent == current: - break - current = parent - - return chain - - -def _format_provenance_chain(chain): - """ - Format a provenance chain as a human-readable string. - Chain is [(uri, label), ...] from leaf to root. - """ - if not chain: - return "" - - # Show labels, from leaf to root - labels = [label for uri, label in chain] - return " → ".join(labels) - - -def _is_iri(value): - """Check if a value looks like an IRI.""" - if not isinstance(value, str): - return False - return value.startswith("http://") or value.startswith("https://") or value.startswith("urn:") - - -async def _query_label(ws_url, flow_id, iri, collection, label_cache, debug=False): - """ - Query for the rdfs:label of an IRI. - Uses label_cache to avoid repeated queries. - Returns the label if found, otherwise returns the IRI. - """ - if not _is_iri(iri): - return iri - - # Check cache first - if iri in label_cache: - return label_cache[iri] - - request = { - "id": "label-request", - "service": "triples", - "flow": flow_id, - "request": { - "s": {"t": "i", "i": iri}, - "p": {"t": "i", "i": RDFS_LABEL}, - "collection": collection, - "limit": 1 - } - } - - label = iri # Default to IRI if no label found - try: - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != "label-request": - continue - - if "error" in response: - break - - if "response" in response: - resp = response["response"] - triple_list = resp.get("response", []) - if triple_list: - # Get the label value - o = triple_list[0].get("o", {}) - label = o.get("v", o.get("i", iri)) - - if resp.get("complete") or response.get("complete"): - break - except Exception as e: - if debug: - print(f" [debug] exception querying label for {iri}: {e}", file=sys.stderr) - - # Cache the result - label_cache[iri] = label - return label - - -async def _resolve_edge_labels(ws_url, flow_id, edge_triple, collection, label_cache, debug=False): - """ - Resolve labels for all IRI components of an edge triple. - Returns (s_label, p_label, o_label). - """ - s = edge_triple.get("s", "?") - p = edge_triple.get("p", "?") - o = edge_triple.get("o", "?") - - s_label = await _query_label(ws_url, flow_id, s, collection, label_cache, debug) - p_label = await _query_label(ws_url, flow_id, p, collection, label_cache, debug) - o_label = await _query_label(ws_url, flow_id, o, collection, label_cache, debug) - - return s_label, p_label, o_label - - -async def _question_explainable( - url, flow_id, question, collection, entity_limit, triple_limit, - max_subgraph_size, max_path_length, token=None, debug=False -): - """Execute graph RAG with explainability - shows provenance events with details""" - # Convert HTTP URL to WebSocket URL - if url.startswith("http://"): - ws_url = url.replace("http://", "ws://", 1) - elif url.startswith("https://"): - ws_url = url.replace("https://", "wss://", 1) - else: - ws_url = f"ws://{url}" - - ws_url = f"{ws_url.rstrip('/')}/api/v1/socket" - if token: - ws_url = f"{ws_url}?token={token}" - - # Cache for label lookups to avoid repeated queries - label_cache = {} - - request = { - "id": "cli-request", - "service": "graph-rag", - "flow": flow_id, - "request": { - "query": question, - "collection": collection, - "entity-limit": entity_limit, - "triple-limit": triple_limit, - "max-subgraph-size": max_subgraph_size, - "max-path-length": max_path_length, - "streaming": True - } - } - - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=300) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != "cli-request": - continue - - if "error" in response: - print(f"\nError: {response['error']}", file=sys.stderr) - break - - if "response" in response: - resp = response["response"] - - # Check for errors in response - if "error" in resp and resp["error"]: - err = resp["error"] - print(f"\nError: {err.get('message', 'Unknown error')}", file=sys.stderr) - break - - message_type = resp.get("message_type", "") - - if debug: - print(f" [debug] message_type={message_type}, keys={list(resp.keys())}", file=sys.stderr) - - if message_type == "explain": - # Display explain event with details - explain_id = resp.get("explain_id", "") - explain_graph = resp.get("explain_graph") # Named graph (e.g., urn:graph:retrieval) - if explain_id: - event_type = _get_event_type(explain_id) - print(f"\n [{event_type}] {explain_id}", file=sys.stderr) - - # Query triples for this explain node (using named graph filter) - triples = await _query_triples( - ws_url, flow_id, explain_id, collection, graph=explain_graph, debug=debug - ) - - # Format and display details - details = _format_provenance_details(event_type, triples) - for line in details: - print(line, file=sys.stderr) - - # For exploration events, resolve entity labels - if event_type == "exploration": - entity_iris = [o for s, p, o in triples if p == TG_ENTITY] - if entity_iris: - print(f" Seed entities: {len(entity_iris)}", file=sys.stderr) - for iri in entity_iris: - label = await _query_label( - ws_url, flow_id, iri, collection, - label_cache, debug=debug - ) - print(f" - {label}", file=sys.stderr) - - # For focus events, query each edge selection for details - if event_type == "focus": - for s, p, o in triples: - if debug: - print(f" [debug] triple: p={p}, o={o}, o_type={type(o).__name__}", file=sys.stderr) - if p == TG_SELECTED_EDGE and isinstance(o, str): - if debug: - print(f" [debug] querying edge selection: {o}", file=sys.stderr) - # Query the edge selection entity (using named graph filter) - edge_triples = await _query_triples( - ws_url, flow_id, o, collection, graph=explain_graph, debug=debug - ) - if debug: - print(f" [debug] got {len(edge_triples)} edge triples", file=sys.stderr) - # Extract edge and reasoning - edge_triple = None # Store the actual triple for provenance lookup - reasoning = None - for es, ep, eo in edge_triples: - if debug: - print(f" [debug] edge triple: ep={ep}, eo={eo}", file=sys.stderr) - if ep == TG_EDGE and isinstance(eo, dict): - # eo is a quoted triple dict - edge_triple = eo - elif ep == TG_REASONING: - reasoning = eo - if edge_triple: - # Resolve labels for edge components - s_label, p_label, o_label = await _resolve_edge_labels( - ws_url, flow_id, edge_triple, collection, - label_cache, debug=debug - ) - print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr) - if reasoning: - r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning - print(f" Reason: {r_short}", file=sys.stderr) - - # Trace edge provenance in the workspace collection (not explainability) - if edge_triple: - sources = await _query_edge_provenance( - ws_url, flow_id, - edge_triple.get("s", ""), - edge_triple.get("p", ""), - edge_triple.get("o", ""), - collection, # Use the query collection, not explainability - debug=debug - ) - if sources: - for src in sources: - # Trace full chain from source to root document - chain = await _trace_provenance_chain( - ws_url, flow_id, src, collection, - label_cache, debug=debug - ) - chain_str = _format_provenance_chain(chain) - print(f" Source: {chain_str}", file=sys.stderr) - - elif message_type == "chunk" or not message_type: - # Display response chunk - chunk = resp.get("response", "") - if chunk: - print(chunk, end="", flush=True) - - # Check if session is complete - if resp.get("end_of_session"): - break - - print() # Final newline - - def _question_explainable_api( url, flow_id, question_text, collection, entity_limit, triple_limit, max_subgraph_size, max_path_length, edge_score_limit=30, diff --git a/trustgraph-cli/trustgraph/cli/load_sample_documents.py b/trustgraph-cli/trustgraph/cli/load_sample_documents.py index 0398864c..e35381aa 100644 --- a/trustgraph-cli/trustgraph/cli/load_sample_documents.py +++ b/trustgraph-cli/trustgraph/cli/load_sample_documents.py @@ -1,709 +1,82 @@ """ -Loads a PDF document into the library +Loads sample documents into the TrustGraph library from bundled package data. """ import argparse +import json import os -import uuid -import datetime -import requests +from importlib import resources from trustgraph.api import Api -from trustgraph.api.types import hash, Uri, Literal, Triple +from trustgraph.api.types import Uri, Literal, Triple default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") - -from requests.adapters import HTTPAdapter -from urllib3.response import HTTPResponse - -class FileAdapter(HTTPAdapter): - def send(self, request, *args, **kwargs): - resp = HTTPResponse(body=open(request.url[7:], 'rb'), status=200, preload_content=False) - return self.build_response(request, resp) - -session = requests.session() - -session.mount('file://', FileAdapter()) - -try: - os.mkdir("doc-cache") -except: - pass - -documents = [ - - { - "id": "https://trustgraph.ai/doc/challenger-report-vol-1", - "title": "Report of the Presidential Commission on the Space Shuttle Challenger Accident, Volume 1", - "comments": "The findings of the Commission regarding the circumstances surrounding the Challenger accident are reported and recommendations for corrective action are outlined", - "url": "https://ntrs.nasa.gov/api/citations/19860015255/downloads/19860015255.pdf", - "kind": "application/pdf", - "date": datetime.datetime.now().date(), - "tags": ["nasa", "safety-engineering", "space-shuttle"], - "metadata": [ - Triple( - s = Uri("https://trustgraph.ai/doc/challenger-report-vol-1"), - p = Uri("http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), - o = Uri("https://schema.org/DigitalDocument") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/challenger-report-vol-1"), - p = Uri("http://www.w3.org/2000/01/rdf-schema#label"), - o = Literal("Report of the Presidential Commission on the Space Shuttle Challenger Accident, Volume 1") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/challenger-report-vol-1"), - p = Uri("https://schema.org/name"), - o = Literal("Report of the Presidential Commission on the Space Shuttle Challenger Accident, Volume 1") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/challenger-report-vol-1"), - p = Uri("https://schema.org/description"), - o = Literal("The findings of the Commission regarding the circumstances surrounding the Challenger accident are reported and recommendations for corrective action are outlined") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/challenger-report-vol-1"), - p = Uri("https://schema.org/copyrightNotice"), - o = Literal("Work of the US Gov. Public Use Permitted") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/challenger-report-vol-1"), - p = Uri("https://schema.org/copyrightHolder"), - o = Literal("US Gov.") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/challenger-report-vol-1"), - p = Uri("https://schema.org/copyrightYear"), - o = Literal("1986") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/challenger-report-vol-1"), - p = Uri("https://schema.org/keywords"), - o = Literal("nasa") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/challenger-report-vol-1"), - p = Uri("https://schema.org/keywords"), - o = Literal("space-shuttle") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/challenger-report-vol-1"), - p = Uri("https://schema.org/keywords"), - o = Literal("safety-engineering") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/challenger-report-vol-1"), - p = Uri("https://schema.org/keywords"), - o = Literal("challenger") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/challenger-report-vol-1"), - p = Uri("https://schema.org/keywords"), - o = Literal("space-transportation") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/challenger-report-vol-1"), - p = Uri("https://schema.org/publication"), - o = Uri("https://trustgraph.ai/pubev/d946c320-0432-48c8-a015-26b0af3cedae") - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/d946c320-0432-48c8-a015-26b0af3cedae"), - p = Uri("http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), - o = Uri("https://schema.org/PublicationEvent") - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/d946c320-0432-48c8-a015-26b0af3cedae"), - p = Uri("https://schema.org/description"), - o = Literal("The findings of the Commission regarding the circumstances surrounding the Challenger accident are reported and recommendations for corrective action are outlined") - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/d946c320-0432-48c8-a015-26b0af3cedae"), - p = Uri("https://schema.org/publishedBy"), - o = Uri("https://trustgraph.ai/org/nasa") - ), - Triple( - s = Uri("https://trustgraph.ai/org/nasa"), - p = Uri("http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), - o = Uri("https://schema.org/Organization") - ), - Triple( - s = Uri("https://trustgraph.ai/org/nasa"), - p = Uri("http://www.w3.org/2000/01/rdf-schema#label"), - o = Literal("NASA") - ), - Triple( - s = Uri("https://trustgraph.ai/org/nasa"), - p = Uri("https://schema.org/name"), - o = Literal("NASA") - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/d946c320-0432-48c8-a015-26b0af3cedae"), - p = Uri("https://schema.org/startDate"), - o = Literal("1986-06-06") - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/d946c320-0432-48c8-a015-26b0af3cedae"), - p = Uri("https://schema.org/endDate"), - o = Literal("1986-06-06") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/challenger-report-vol-1"), - p = Uri("https://schema.org/url"), - o = Uri("https://ntrs.nasa.gov/api/citations/19860015255/downloads/19860015255.pdf") - ) - ] - }, - - { - "id": "https://trustgraph.ai/doc/icelandic-dictionary", - "title": "A Concise Dictionary of Old Icelandic", - "comments": "A Concise Dictionary of Old Icelandic, published in 1910, is a 551-page dictionary that offers a comprehensive overview of the Old Norse language, particularly Old Icelandic.", - "url": "https://css4.pub/2015/icelandic/dictionary.pdf", - "kind": "application/pdf", - "date": datetime.datetime.now().date(), - "tags": ["old-icelandic", "dictionary", "language", "grammar", "old-norse", "icelandic"], - "metadata": [ - Triple( - s = Uri("https://trustgraph.ai/doc/icelandic-dictionary"), - p = Uri("http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), - o = Uri("https://schema.org/DigitalDocument") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/icelandic-dictionary"), - p = Uri("http://www.w3.org/2000/01/rdf-schema#label"), - o = Literal("A Concise Dictionary of Old Icelandic"), - ), - Triple( - s = Uri("https://trustgraph.ai/doc/icelandic-dictionary"), - p = Uri("https://schema.org/name"), - o = Literal("A Concise Dictionary of Old Icelandic"), - ), - Triple( - s = Uri("https://trustgraph.ai/doc/icelandic-dictionary"), - p = Uri("https://schema.org/description"), - o = Literal("A Concise Dictionary of Old Icelandic, published in 1910, is a 551-page dictionary that offers a comprehensive overview of the Old Norse language, particularly Old Icelandic."), - ), - Triple( - s = Uri("https://trustgraph.ai/doc/icelandic-dictionary"), - p = Uri("https://schema.org/copyrightNotice"), - o = Literal("Copyright expired, public domain") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/icelandic-dictionary"), - p = Uri("https://schema.org/copyrightHolder"), - o = Literal("Geir Zoëga, Clarendon Press") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/icelandic-dictionary"), - p = Uri("https://schema.org/copyrightYear"), - o = Literal("1910") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/icelandic-dictionary"), - p = Uri("https://schema.org/keywords"), - o = Literal("icelandic") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/icelandic-dictionary"), - p = Uri("https://schema.org/keywords"), - o = Literal("old-norse") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/icelandic-dictionary"), - p = Uri("https://schema.org/keywords"), - o = Literal("dictionary") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/icelandic-dictionary"), - p = Uri("https://schema.org/keywords"), - o = Literal("grammar") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/icelandic-dictionary"), - p = Uri("https://schema.org/keywords"), - o = Literal("old-icelandic") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/icelandic-dictionary"), - p = Uri("https://schema.org/publication"), - o = Uri("https://trustgraph.ai/pubev/11a78156-3aea-4263-9f1b-0c63cbde69d7") - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/11a78156-3aea-4263-9f1b-0c63cbde69d7"), - p = Uri("http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), - o = Uri("https://schema.org/PublicationEvent") - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/11a78156-3aea-4263-9f1b-0c63cbde69d7"), - p = Uri("https://schema.org/description"), - o = Literal("Published by Clarendon Press in 1910"), - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/11a78156-3aea-4263-9f1b-0c63cbde69d7"), - p = Uri("https://schema.org/publishedBy"), - o = Uri("https://trustgraph.ai/org/clarendon-press") - ), - Triple( - s = Uri("https://trustgraph.ai/org/clarendon-press"), - p = Uri("http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), - o = Uri("https://schema.org/Organization") - ), - Triple( - s = Uri("https://trustgraph.ai/org/clarendon-press"), - p = Uri("http://www.w3.org/2000/01/rdf-schema#label"), - o = Literal("NASA") - ), - Triple( - s = Uri("https://trustgraph.ai/org/clarendon-press"), - p = Uri("https://schema.org/name"), - o = Literal("Clarendon Press") - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/11a78156-3aea-4263-9f1b-0c63cbde69d7"), - p = Uri("https://schema.org/startDate"), - o = Literal("1910-01-01") - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/11a78156-3aea-4263-9f1b-0c63cbde69d7"), - p = Uri("https://schema.org/endDate"), - o = Literal("1910-01-01") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/icelandic-dictionary"), - p = Uri("https://schema.org/url"), - o = Uri("https://digital-research-books-beta.nypl.org/edition/10476341") - ) - ] - }, +SAMPLE_DOCS_PACKAGE = "trustgraph.cli.sample_documents" - { - "id": "https://trustgraph.ai/doc/annual-threat-assessment-us-dni-march-2025", - "title": "Annual threat assessment of the U.S. intelligence community - March 2025", - "comments": "The report reflects the collective insights of the Intelligence Community (IC), which is committed to providing the nuanced, independent, and unvarnished intelligence that policymakers, warfighters, and domestic law enforcement personnel need to protect American lives and America’s interests anywhere in the world.", - "url": "https://www.intelligence.senate.gov/sites/default/files/2025%20Annual%20Threat%20Assessment%20of%20the%20U.S.%20Intelligence%20Community.pdf", - "kind": "application/pdf", - "date": datetime.datetime.now().date(), - "tags": ["adversary-cooperation", "cyberthreats", "supply-chain-vulnerabilities", "economic-competition", "national-security", "data-privacy"], - "metadata": [ - Triple( - s = Uri("https://trustgraph.ai/doc/annual-threat-assessment-us-dni-march-2025"), - p = Uri("http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), - o = Uri("https://schema.org/DigitalDocument") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/annual-threat-assessment-us-dni-march-2025"), - p = Uri("http://www.w3.org/2000/01/rdf-schema#label"), - o = Literal("Annual threat assessment of the U.S. intelligence community - March 2025"), - ), - Triple( - s = Uri("https://trustgraph.ai/doc/annual-threat-assessment-us-dni-march-2025"), - p = Uri("https://schema.org/name"), - o = Literal("Annual threat assessment of the U.S. intelligence community - March 2025"), - ), - Triple( - s = Uri("https://trustgraph.ai/doc/annual-threat-assessment-us-dni-march-2025"), - p = Uri("https://schema.org/description"), - o = Literal("The report reflects the collective insights of the Intelligence Community (IC), which is committed to providing the nuanced, independent, and unvarnished intelligence that policymakers, warfighters, and domestic law enforcement personnel need to protect American lives and America’s interests anywhere in the world."), - ), - Triple( - s = Uri("https://trustgraph.ai/doc/annual-threat-assessment-us-dni-march-2025"), - p = Uri("https://schema.org/copyrightNotice"), - o = Literal("Not copyright") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/annual-threat-assessment-us-dni-march-2025"), - p = Uri("https://schema.org/copyrightHolder"), - o = Literal("US Government") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/annual-threat-assessment-us-dni-march-2025"), - p = Uri("https://schema.org/copyrightYear"), - o = Literal("2025") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/annual-threat-assessment-us-dni-march-2025"), - p = Uri("https://schema.org/keywords"), - o = Literal("adversary-cooperation") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/annual-threat-assessment-us-dni-march-2025"), - p = Uri("https://schema.org/keywords"), - o = Literal("cyberthreats") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/annual-threat-assessment-us-dni-march-2025"), - p = Uri("https://schema.org/keywords"), - o = Literal("supply-chain-vulnerabilities") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/annual-threat-assessment-us-dni-march-2025"), - p = Uri("https://schema.org/keywords"), - o = Literal("economic-competition") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/annual-threat-assessment-us-dni-march-2025"), - p = Uri("https://schema.org/keywords"), - o = Literal("national-security") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/annual-threat-assessment-us-dni-march-2025"), - p = Uri("https://schema.org/publication"), - o = Uri("https://trustgraph.ai/pubev/0f1cfbe2-ce64-403b-8327-799aa8ba3cec") - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/0f1cfbe2-ce64-403b-8327-799aa8ba3cec"), - p = Uri("http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), - o = Uri("https://schema.org/PublicationEvent") - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/0f1cfbe2-ce64-403b-8327-799aa8ba3cec"), - p = Uri("https://schema.org/description"), - o = Literal("Published by the Director of National Intelligence (DNI)"), - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/0f1cfbe2-ce64-403b-8327-799aa8ba3cec"), - p = Uri("https://schema.org/publishedBy"), - o = Uri("https://trustgraph.ai/org/us-gov-dni") - ), - Triple( - s = Uri("https://trustgraph.ai/org/us-gov-dni"), - p = Uri("http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), - o = Uri("https://schema.org/Organization") - ), - Triple( - s = Uri("https://trustgraph.ai/org/us-gov-dni"), - p = Uri("http://www.w3.org/2000/01/rdf-schema#label"), - o = Literal("The Director of National Intelligence") - ), - Triple( - s = Uri("https://trustgraph.ai/org/us-gov-dni"), - p = Uri("https://schema.org/name"), - o = Literal("The Director of National Intelligence") - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/0f1cfbe2-ce64-403b-8327-799aa8ba3cec"), - p = Uri("https://schema.org/startDate"), - o = Literal("2025-03-18") - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/0f1cfbe2-ce64-403b-8327-799aa8ba3cec"), - p = Uri("https://schema.org/endDate"), - o = Literal("2025-03-18") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/annual-threat-assessment-us-dni-march-2025"), - p = Uri("https://schema.org/url"), - o = Uri("https://www.dni.gov/index.php/newsroom/reports-publications/reports-publications-2025/4058-2025-annual-threat-assessment") - ) - ] - }, +def get_data_path(): + return resources.files(SAMPLE_DOCS_PACKAGE) - { - "id": "https://trustgraph.ai/doc/intelligence-and-state", - "title": "The Role of Intelligence and State Policies in International Security", - "comments": "A volume by Mehmet Emin Erendor, published by Cambridge Scholars Publishing (2021). It is well-known that the understanding of security has changed since the end of the Cold War. This, in turn, has impacted the characteristics of intelligence, as states have needed to improve their security policies with new intelligence tactics. This volume investigates this new state of play in the international arena.", - "url": "https://www.cambridgescholars.com/resources/pdfs/978-1-5275-7604-9-sample.pdf", - "kind": "application/pdf", - "date": "2025-05-06", - "tags": ["intelligence", "state-policy", "international-security", "national-security", "geopolitics", "foreign-policy", "security-studies", "military", "crime"], - "metadata": [ - Triple( - s = Uri("https://trustgraph.ai/doc/intelligence-and-state"), - p = Uri("http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), - o = Uri("https://schema.org/Book") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/intelligence-and-state"), - p = Uri("http://www.w3.org/2000/01/rdf-schema#label"), - o = Literal("The Role of Intelligence and State Policies in International Security") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/intelligence-and-state"), - p = Uri("https://schema.org/name"), - o = Literal("The Role of Intelligence and State Policies in International Security") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/intelligence-and-state"), - p = Uri("https://schema.org/description"), - o = Literal("A volume by Mehmet Emin Erendor. It is well-known that the understanding of security has changed since the end of the Cold War. This, in turn, has impacted the characteristics of intelligence, as states have needed to improve their security policies with new intelligence tactics. This volume investigates this new state of play in the international arena.") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/intelligence-and-state"), - p = Uri("https://schema.org/author"), - o = Literal("Mehmet Emin Erendor") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/intelligence-and-state"), - p = Uri("https://schema.org/isbn"), - o = Literal("9781527576049") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/intelligence-and-state"), - p = Uri("https://schema.org/numberOfPages"), - o = Literal("220") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/intelligence-and-state"), - p = Uri("https://schema.org/keywords"), - o = Literal("intelligence") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/intelligence-and-state"), - p = Uri("https://schema.org/keywords"), - o = Literal("state policy") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/intelligence-and-state"), - p = Uri("https://schema.org/keywords"), - o = Literal("international security") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/intelligence-and-state"), - p = Uri("https://schema.org/keywords"), - o = Literal("national security") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/intelligence-and-state"), - p = Uri("https://schema.org/keywords"), - o = Literal("geopolitics") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/intelligence-and-state"), - p = Uri("https://schema.org/publication"), - o = Uri("https://trustgraph.ai/pubev/b4352222-5da0-480d-a00f-f7342fe77862") - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/b4352222-5da0-480d-a00f-f7342fe77862"), - p = Uri("http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), - o = Uri("https://schema.org/PublicationEvent") - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/b4352222-5da0-480d-a00f-f7342fe77862"), - p = Uri("https://schema.org/description"), - o = Literal("Published by Cambridge Scholars Publishing on October 28, 2021.") - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/b4352222-5da0-480d-a00f-f7342fe77862"), - p = Uri("https://schema.org/publishedBy"), - o = Uri("https://trustgraph.ai/org/cambridge-scholars-publishing") - ), - Triple( - s = Uri("https://trustgraph.ai/org/cambridge-scholars-publishing"), - p = Uri("http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), - o = Uri("https://schema.org/Organization") - ), - Triple( - s = Uri("https://trustgraph.ai/org/cambridge-scholars-publishing"), - p = Uri("http://www.w3.org/2000/01/rdf-schema#label"), - o = Literal("Cambridge Scholars Publishing") - ), - Triple( - s = Uri("https://trustgraph.ai/org/cambridge-scholars-publishing"), - p = Uri("https://schema.org/name"), - o = Literal("Cambridge Scholars Publishing") - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/b4352222-5da0-480d-a00f-f7342fe77862"), - p = Uri("https://schema.org/startDate"), - o = Literal("2021-10-28") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/intelligence-and-state"), - p = Uri("https://schema.org/url"), - o = Uri("https://www.cambridgescholars.com/resources/pdfs/978-1-5275-7604-9-sample.pdf") - ) - ] - }, - { - "id": "https://trustgraph.ai/doc/beyond-vigilant-state", - "title": "Beyond the vigilant state: globalisation and intelligence", - "comments": "This academic paper by Richard J. Aldrich examines the relationship between globalization and intelligence agencies, discussing how intelligence services have adapted to global changes in the post-Cold War era.", - "url": "https://warwick.ac.uk/fac/soc/pais/people/aldrich/publications/beyond.pdf", - "kind": "application/pdf", - "date": datetime.datetime.now().date(), - "tags": ["intelligence", "globalization", "security-studies", "surveillance", "international-relations", "post-cold-war"], - "metadata": [ - Triple( - s = Uri("https://trustgraph.ai/doc/beyond-vigilant-state"), - p = Uri("http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), - o = Uri("https://schema.org/ScholarlyArticle") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/beyond-vigilant-state"), - p = Uri("http://www.w3.org/2000/01/rdf-schema#label"), - o = Literal("Beyond the vigilant state: globalisation and intelligence"), - ), - Triple( - s = Uri("https://trustgraph.ai/doc/beyond-vigilant-state"), - p = Uri("https://schema.org/name"), - o = Literal("Beyond the vigilant state: globalisation and intelligence"), - ), - Triple( - s = Uri("https://trustgraph.ai/doc/beyond-vigilant-state"), - p = Uri("https://schema.org/description"), - o = Literal("This academic paper by Richard J. Aldrich examines the relationship between globalization and intelligence agencies, discussing how intelligence services have adapted to global changes in the post-Cold War era."), - ), - Triple( - s = Uri("https://trustgraph.ai/doc/beyond-vigilant-state"), - p = Uri("https://schema.org/copyrightNotice"), - o = Literal("(c) British International Studies Association") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/beyond-vigilant-state"), - p = Uri("https://schema.org/copyrightHolder"), - o = Literal("British International Studies Association") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/beyond-vigilant-state"), - p = Uri("https://schema.org/author"), - o = Uri("https://trustgraph.ai/person/3a45f8c9-b7d1-42e5-8631-d9f82c4a0e22") - ), - Triple( - s = Uri("https://trustgraph.ai/person/3a45f8c9-b7d1-42e5-8631-d9f82c4a0e22"), - p = Uri("http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), - o = Uri("https://schema.org/Person") - ), - Triple( - s = Uri("https://trustgraph.ai/person/3a45f8c9-b7d1-42e5-8631-d9f82c4a0e22"), - p = Uri("http://www.w3.org/2000/01/rdf-schema#label"), - o = Literal("Richard J. Aldrich") - ), - Triple( - s = Uri("https://trustgraph.ai/person/3a45f8c9-b7d1-42e5-8631-d9f82c4a0e22"), - p = Uri("https://schema.org/name"), - o = Literal("Richard J. Aldrich") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/beyond-vigilant-state"), - p = Uri("https://schema.org/keywords"), - o = Literal("intelligence") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/beyond-vigilant-state"), - p = Uri("https://schema.org/keywords"), - o = Literal("globalisation") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/beyond-vigilant-state"), - p = Uri("https://schema.org/keywords"), - o = Literal("security-studies") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/beyond-vigilant-state"), - p = Uri("https://schema.org/keywords"), - o = Literal("surveillance") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/beyond-vigilant-state"), - p = Uri("https://schema.org/keywords"), - o = Literal("international-relations") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/beyond-vigilant-state"), - p = Uri("https://schema.org/keywords"), - o = Literal("post-cold-war") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/beyond-vigilant-state"), - p = Uri("https://schema.org/publication"), - o = Uri("https://trustgraph.ai/pubev/75c83dfa-6b2e-4d89-bda1-c8e92f0e3410") - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/75c83dfa-6b2e-4d89-bda1-c8e92f0e3410"), - p = Uri("http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), - o = Uri("https://schema.org/PublicationEvent") - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/75c83dfa-6b2e-4d89-bda1-c8e92f0e3410"), - p = Uri("https://schema.org/description"), - o = Literal("Published in Review of International Studies"), - ), - Triple( - s = Uri("https://trustgraph.ai/pubev/75c83dfa-6b2e-4d89-bda1-c8e92f0e3410"), - p = Uri("https://schema.org/publishedBy"), - o = Uri("https://trustgraph.ai/org/british-international-studies-association") - ), - Triple( - s = Uri("https://trustgraph.ai/org/british-international-studies-association"), - p = Uri("http://www.w3.org/1999/02/22-rdf-syntax-ns#type"), - o = Uri("https://schema.org/Organization") - ), - Triple( - s = Uri("https://trustgraph.ai/org/british-international-studies-association"), - p = Uri("http://www.w3.org/2000/01/rdf-schema#label"), - o = Literal("British International Studies Association") - ), - Triple( - s = Uri("https://trustgraph.ai/org/british-international-studies-association"), - p = Uri("https://schema.org/name"), - o = Literal("British International Studies Association") - ), - Triple( - s = Uri("https://trustgraph.ai/doc/beyond-vigilant-state"), - p = Uri("https://schema.org/url"), - o = Uri("https://warwick.ac.uk/fac/soc/pais/people/aldrich/publications/beyond.pdf") - ) - ] - } +def load_metadata(): + data_path = get_data_path() + metadata_file = data_path / "metadata.json" + return json.loads(metadata_file.read_text(encoding="utf-8")) -] -class Loader: +def convert_value(v): + if v["type"] == "uri": + return Uri(v["value"]) + else: + return Literal(v["value"]) - def __init__( - self, url, token=None, workspace="default", - ): - self.api = Api(url, token=token, workspace=workspace).library() +def convert_metadata(metadata_json): + triples = [] + for t in metadata_json: + triples.append(Triple( + s=convert_value(t["s"]), + p=convert_value(t["p"]), + o=convert_value(t["o"]), + )) + return triples - def load(self, documents): - for doc in documents: - self.load_doc(doc) +def load_document(api, doc_entry, data_path): - def load_doc(self, doc): + doc_id = doc_entry["id"] + title = doc_entry["title"] + filename = doc_entry["file"] - try: + print(f" [{filename}] {title}") - print(doc["title"], ":") + print(f" reading content...") + content_file = data_path / filename + content = content_file.read_bytes() - hid = hash(doc["url"]) - cache_file = f"doc-cache/{hid}" + print(f" loading into TrustGraph ({len(content) // 1024}KB)...") + metadata = convert_metadata(doc_entry["metadata"]) - if os.path.isfile(cache_file): - print(" (use cache file)") - content = open(cache_file, "rb").read() - else: - print(" downloading...") - resp = session.get(doc["url"]) - content = resp.content - open(cache_file, "wb").write(content) - print(" done.") + api.add_document( + id=doc_id, + metadata=metadata, + kind=doc_entry["kind"], + title=title, + comments=doc_entry["comments"], + tags=doc_entry["tags"], + document=content, + ) - print(" adding...") + print(f" done.") - self.api.add_document( - id=doc["id"], metadata=doc["metadata"], - kind=doc["kind"], title=doc["title"], - comments=doc["comments"], tags=doc["tags"], - document=content, - ) - - print(" successful.") - - except Exception as e: - print("Failed: {str(e)}", flush=True) - raise e def main(): parser = argparse.ArgumentParser( - prog='tg-add-library-document', + prog='tg-load-sample-documents', description=__doc__, ) @@ -729,18 +102,27 @@ def main(): try: - p = Loader( - url=args.url, - token=args.token, - workspace=args.workspace, - ) + api = Api(args.url, token=args.token, workspace=args.workspace) + library = api.library() - p.load(documents) + data_path = get_data_path() + documents = load_metadata() + + print(f"Loading {len(documents)} sample document(s)...\n") + + for doc in documents: + try: + load_document(library, doc, data_path) + except Exception as e: + print(f" FAILED: {e}") + print() + + print("Complete.") except Exception as e: - - print("Exception:", e, flush=True) + print(f"Exception: {e}") raise e + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/put_de_core.py b/trustgraph-cli/trustgraph/cli/put_de_core.py new file mode 100644 index 00000000..1d6589af --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/put_de_core.py @@ -0,0 +1,119 @@ +""" +Puts a document embeddings core into the knowledge manager via the API +socket. +""" + +import argparse +import os +import msgpack + +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") + +def read_message(unpacked, id): + + if unpacked[0] == "de": + msg = unpacked[1] + return { + "metadata": { + "id": id, + "root": msg["m"]["m"], + "collection": "default", + }, + "chunks": [ + { + "chunk_id": ch["i"], + "vector": ch["v"], + } + for ch in msg["c"] + ], + } + else: + raise RuntimeError("Unexpected message type", unpacked[0]) + +def put(url, workspace, id, input, token=None): + + api = Api(url=url, token=token, workspace=workspace) + socket = api.socket() + + try: + de = 0 + + with open(input, "rb") as f: + + unpacker = msgpack.Unpacker(f, raw=False) + + while True: + + try: + unpacked = unpacker.unpack() + except msgpack.OutOfData: + break + + msg = read_message(unpacked, id) + de += 1 + socket.put_de_core(id, document_embeddings=msg) + + print(f"Put: {de} document embeddings messages.") + + finally: + socket.close() + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-put-de-core', + description=__doc__, + ) + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + + parser.add_argument( + '--id', '--identifier', + required=True, + help=f'Document embeddings core ID', + ) + + parser.add_argument( + '-i', '--input', + required=True, + help=f'Input file' + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + args = parser.parse_args() + + try: + + put( + url=args.url, + workspace=args.workspace, + id=args.id, + input=args.input, + token=args.token, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/put_kg_core.py b/trustgraph-cli/trustgraph/cli/put_kg_core.py index bd3169c8..fe0981a5 100644 --- a/trustgraph-cli/trustgraph/cli/put_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/put_kg_core.py @@ -4,13 +4,11 @@ Puts a knowledge core into the knowledge manager via the API socket. import argparse import os -import uuid -import asyncio -import json -from websockets.asyncio.client import connect import msgpack -default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") @@ -21,13 +19,13 @@ def read_message(unpacked, id): return "ge", { "metadata": { "id": id, - "metadata": msg["m"]["m"], - "collection": "default", # Not used? + "root": msg["m"]["m"], + "collection": "default", }, "entities": [ { "entity": ent["e"], - "vectors": ent["v"], + "vector": ent["v"], } for ent in msg["e"] ], @@ -37,26 +35,20 @@ def read_message(unpacked, id): return "t", { "metadata": { "id": id, - "metadata": msg["m"]["m"], - "collection": "default", # Not used by receiver? + "root": msg["m"]["m"], + "collection": "default", }, "triples": msg["t"], } else: raise RuntimeError("Unpacked unexpected messsage type", unpacked[0]) -async def put(url, workspace, id, input, token=None): +def put(url, workspace, id, input, token=None): - if not url.endswith("/"): - url += "/" - - url = url + "api/v1/socket" - - if token: - url = f"{url}?token={token}" - - async with connect(url) as ws: + api = Api(url=url, token=token, workspace=workspace) + socket = api.socket() + try: ge = 0 t = 0 @@ -68,69 +60,26 @@ async def put(url, workspace, id, input, token=None): try: unpacked = unpacker.unpack() - except: + except msgpack.OutOfData: break kind, msg = read_message(unpacked, id) - mid = str(uuid.uuid4()) - if kind == "ge": - ge += 1 - - req = json.dumps({ - "id": mid, - "workspace": workspace, - "service": "knowledge", - "request": { - "operation": "put-kg-core", - "workspace": workspace, - "id": id, - "graph-embeddings": msg - } - }) + socket.put_kg_core(id, graph_embeddings=msg) elif kind == "t": - t += 1 - - req = json.dumps({ - "id": mid, - "workspace": workspace, - "service": "knowledge", - "request": { - "operation": "put-kg-core", - "workspace": workspace, - "id": id, - "triples": msg - } - }) + socket.put_kg_core(id, triples=msg) else: - raise RuntimeError("Unexpected message kind", kind) - await ws.send(req) - - # Retry loop, wait for right response to come back - while True: - - msg = await ws.recv() - msg = json.loads(msg) - - if msg["id"] != mid: - continue - - if "response" in msg: - if "error" in msg["response"]: - raise RuntimeError(msg["response"]["error"]) - - break - print(f"Put: {t} triple, {ge} GE messages.") - await ws.close() + finally: + socket.close() def main(): @@ -173,14 +122,12 @@ def main(): try: - asyncio.run( - put( - url=args.url, - workspace=args.workspace, - id=args.id, - input=args.input, - token=args.token, - ) + put( + url=args.url, + workspace=args.workspace, + id=args.id, + input=args.input, + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/sample_documents/__init__.py b/trustgraph-cli/trustgraph/cli/sample_documents/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/trustgraph-cli/trustgraph/cli/sample_documents/belgian-beer.md b/trustgraph-cli/trustgraph/cli/sample_documents/belgian-beer.md new file mode 100644 index 00000000..fe704611 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/sample_documents/belgian-beer.md @@ -0,0 +1,53 @@ +# The Brewing Traditions of Belgium + +## Section 1: The Fermentation Foundations (Yeasts & Styles) + +### Component A: The Ardennes Isolation Strain + +- **Alternative Names:** Strain-71, The Trappist Yeast, Wallonian Culture. +- **Origin:** Historically isolated in the rugged forests of the Ardennes region, shared via mutual trade agreements between various monastic brewing sites. +- **Characteristics:** A high-attenuation, top-fermenting yeast culture (*Saccharomyces cerevisiae*) that thrives at elevated temperatures (20°C to 26°C). It produces heavy volatile esters reminiscent of clove, banana, and white pepper. +- **Downstream Dependencies:** Essential biological engine required to brew The Westvleteren Quad and The Chimay Grand Réserve. + +### Component B: Brettanomyces Bruxellensis + +- **Alternative Names:** Wild Yeast, The Brussels Funk, Pajottenland Air. +- **Origin:** Indigenous entirely to the Senne River Valley and the surrounding Pajottenland region southwest of Brussels. It cannot be cultured in a standard laboratory setting for primary fermentation; it must be caught ambiently from the atmosphere. +- **Characteristics:** A slow-acting, wild yeast strain that consumes complex sugars that standard yeasts cannot digest. It introduces lactic acid and earthy, "barnyard" aroma characteristics over a 1 to 3-year aging cycle. +- **Downstream Dependencies:** Biological prerequisite for Oude Geuze and The Flemish Red Ale. + +## Section 2: Monastic & Trappist Hierarchies (Appellation Controlled) + +### Beer 1: The Westvleteren Quad + +- **Alternative Names:** Westvleteren 12, The Yellow Cap. +- **Origin:** Brewed exclusively inside the walls of the Abbey of Saint-Sixtus in Westvleteren, Flanders. Holds the strict "Authentic Trappist Product" (ATP) legal designation. +- **Ingredients:** The Ardennes Isolation Strain, local soft water, pale malt, dark liquid candi sugar (sucrose solution), and Northern Brewer hops. +- **Process:** Primary fermentation utilizing the Ardennes strain for 7 days. Afterward, dark candi sugar is injected into the green beer to trigger a secondary fermentation stage. Crucially, the beer is bottled completely unfiltered with active yeast cells, requiring a mandatory 3-month cellar conditioning period to carbonate inside the bottle. + +### Beer 2: The Chimay Grand Réserve + +- **Alternative Names:** Chimay Blue, The Grande Réserve. +- **Origin:** Brewed inside the Scourmont Abbey in Hainaut, Wallonia. Also carries the ATP designation. +- **Ingredients:** The Ardennes Isolation Strain, estate-drawn well water, malted barley, Hallertau Mittelfrüh hops, and caramelized sugar. +- **Process:** Follows a parallel fermentation profile to the Westvleteren Quad, using the exact same ancestral yeast strain but utilizing a different mineral profile in the water, resulting in a drier, more dark-fruit-forward profile. + +## Section 3: Spontaneous & Sour Traditions (Wild Ecosystems) + +### Beer 3: Oude Geuze + +- **Alternative Names:** The Champagne of Belgium, Brussels Lambic. +- **Origin:** The Pajottenland region. It is legally protected; it cannot be called "Oude Geuze" unless it is spontaneously fermented by the regional air. +- **Ingredients:** Unmalted wheat (30%), Pale barley malt (70%), aged "suranné" hops (which lose their bitterness but retain preservative qualities), and ambient Brettanomyces Bruxellensis. +- **Process:** Boiling wort is pumped into an open-air shallow vessel called a "coolship" overnight to cool down, absorbing wild microbes from the Senne Valley breeze. + +**The Blending Protocol Dispute (Critical Logic Test):** + +- **The Traditionalist Assembly:** A true Oude Geuze is a blend of 1-year-old young lambic (which provides active sugars) and 3-year-old vintage lambic (which provides complex sourness). +- **The Industrial Controversy:** Some macro-breweries pasteurize the blend and inject artificial sweeteners (aspartame) to neutralize the sourness for commercial appeal. Traditionalists argue this strips the product of its geographic identity and violates the "Oude" (Old) designation. + +### Beer 4: The Flemish Red Ale + +- **Alternative Names:** Rodenbach style, West-Flemish Sour. +- **Origin:** Roeselare, West Flanders. +- **Ingredients:** Red-kilned malts, aged hops, standard top-fermenting yeast, and a secondary inoculation of Brettanomyces Bruxellensis. diff --git a/trustgraph-cli/trustgraph/cli/sample_documents/bronze-age-collapse.pdf b/trustgraph-cli/trustgraph/cli/sample_documents/bronze-age-collapse.pdf new file mode 100644 index 00000000..e7e3dbe5 Binary files /dev/null and b/trustgraph-cli/trustgraph/cli/sample_documents/bronze-age-collapse.pdf differ diff --git a/trustgraph-cli/trustgraph/cli/sample_documents/corporate-scandals.pdf b/trustgraph-cli/trustgraph/cli/sample_documents/corporate-scandals.pdf new file mode 100644 index 00000000..012d0ad4 Binary files /dev/null and b/trustgraph-cli/trustgraph/cli/sample_documents/corporate-scandals.pdf differ diff --git a/trustgraph-cli/trustgraph/cli/sample_documents/history-of-pets.md b/trustgraph-cli/trustgraph/cli/sample_documents/history-of-pets.md new file mode 100644 index 00000000..932b17f7 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/sample_documents/history-of-pets.md @@ -0,0 +1,13 @@ +# The Domestic Canopy: A Unified Narrative of Companionship + +The story of the human-animal bond begins not with a conscious decision to breed a companion, but with an ancient, mutual opportunism in the frozen wastes of the late Pleistocene. Long before the advent of agriculture, the gray wolf (*Canis lupus*) began to separate from its wild packs, drawn to the peripheral campfires of Eurasian hunter-gatherers. These ancestral canids, which would morph over millennia into the domesticated dog (*Canis lupus familiaris*), offered early humans an unparalleled early-warning system against apex predators and an invaluable partner in the persistence hunt. In return, humans provided a steady supply of megafauna marrow, cooked gristle, and proximity to warmth. This biological pact was so profound that it transcended mere utility, as evidenced by the late Paleolithic Natufian burial sites in the Levant, where human skeletons were interred with their hands resting gently upon the ribcages of wolf pups, marking the earliest archaeological signature of the transition from working tool to sentimental proxy. + +As the ice sheets retreated and humanity anchored itself to the soil during the Neolithic Revolution, the nature of animal companionship shifted dramatically, giving rise to an entirely different ecological dynamic in the Fertile Crescent. The rise of grain storehouses in ancient Egypt attracted unprecedented swarms of rodents, creating a pristine ecological niche for the North African wildcat (*Felis lybica*). Unlike the highly structured social hierarchy of the wolf, the cat domesticated itself on terms of aloof independence, transitioning from a tolerated pest-control mechanism to a revered icon of divine protection. By the time of the Egyptian Middle Kingdom, cats were so thoroughly integrated into the domestic fabric that they were granted formal mourning rites; Roman historians like Herodotus noted that when a house cat died of natural causes, the entire human household would shave their eyebrows as a public manifestation of grief. These felines were often mummified using the same costly natron resins reserved for the nobility and entombed in specialized necropolises like Bubastis, dedicated to the feline-headed deity Bastet, effectively blending religious cosmology with domestic affection. + +Parallel developments were unfolding across the globe, creating distinct regional pockets of companionship that would later collide through imperial trade. In the Andean highlands of South America, the Incas domesticated the guinea pig (*Cavia porcellus*), known locally as the cuy. While primarily a source of protein and a diagnostic tool used by folk healers to absorb illness from the sick, select lineages were kept by children as cherished house-dwellers. Meanwhile, in the imperial courts of Han Dynasty China, a parallel phenomenon saw the intensive breeding of the Pekingese dog. These small, flat-faced canids were selectively bred to resemble miniature lions — the mythical protectors of Buddhism — and were guarded so fiercely within the walls of the Forbidden City that stealing one was punishable by death. They lived a life of pampered luxury, carried in the sleeves of silk robes and tended to by dedicated eunuchs, establishing an early historical precedent where certain animal breeds functioned strictly as status symbols and manifestations of political sovereignty rather than utilitarian workers. + +The classical antiquity of Europe further complicated this tapestry, as the Roman elite integrated exoticism into their definition of the domestic sphere. Roman matrons frequently kept ring-necked parakeets (*Psittacula krameri*) imported from the conquests of India, housing them in elaborate cages of ivory and silver, and teaching them to speak the name of the Emperor. Concurrently, the Roman fondness for the ferret (*Mustela furo*) emerged as a dual-purpose phenomenon; these mustelids were kept both to flush rabbits from agricultural burrows and as slinky, playful companions within the villa. This Roman domestic ecosystem was heavily documented by Pliny the Elder in his Natural History, where he noted that the elite often developed deep, seemingly irrational emotional attachments to their companion animals, including pet fish like the moray eel, which the orator Hortensius reportedly wept over when it died in his private ornamental pond. + +The medieval period in Europe introduced a sharp class divide to the concept of the pet, often viewed through the suspicious lens of ecclesiastical authority. While the peasantry kept functional yard dogs and barn cats, the high nobility — particularly noblewomen and monastic figures — indulged in the keeping of lapdogs, such as the early Maltese, and refined birds of prey. These lapdogs were often criticized by conservative church theologians who argued that the excessive meat fed to pampered pets belonged in the mouths of the starving peasantry. Furthermore, during the height of the European witch trials, the domestic pet — particularly the black cat, the toad, or the ferret — was frequently demonized by inquisitors as a "familiar," a physical vessel housing a demonic spirit. This created a perilous cultural paradox where an animal could be viewed as a comforting hearth-companion in one household and an existential piece of heretical evidence in another. + +The modern concept of pet-keeping as a universal consumer phenomenon crystallized during the Industrial Revolution and the rise of the Victorian middle class. As populations migrated from rural farms to dense urban centers, the severed connection to nature triggered a romanticized counter-movement. The Victorians elevated the domestic home into a moral sanctuary, and the pet was introduced as a pedagogical tool to teach children empathy, kindness, and middle-class domestic virtues. This era saw the birth of the commercial pet industry: standard kibble formulations were patented by James Spratt in the 1860s, dog shows like Crufts were established to formalize breed standards, and specialized pet cemeteries, like the one in London's Hyde Park, emerged to afford animals a dignified transition into the afterlife. The pet was no longer a working asset or an eccentric luxury of the aristocratic elite; it had become an institutionalized member of the nuclear family unit, setting the stage for the hyper-commodified, emotionally complex multi-billion dollar pet industry of the contemporary era. diff --git a/trustgraph-cli/trustgraph/cli/sample_documents/metadata.json b/trustgraph-cli/trustgraph/cli/sample_documents/metadata.json new file mode 100644 index 00000000..c11d633e --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/sample_documents/metadata.json @@ -0,0 +1,527 @@ +[ + { + "id": "https://trustgraph.ai/doc/west-country-recipes", + "title": "The Foundations of West Country Cooking", + "comments": "A structured guide to traditional West Country recipes including clotted cream, scones, hedgerow jam, the cream tea, Cornish pasty, homity pie, and Devonshire fudge, with regional supply chain dependencies.", + "file": "recipes.md", + "kind": "text/markdown", + "tags": ["cooking", "west-country", "devon", "cornwall", "recipes", "food-history"], + "metadata": [ + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/west-country-recipes"}, + "p": {"type": "uri", "value": "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"}, + "o": {"type": "uri", "value": "https://schema.org/DigitalDocument"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/west-country-recipes"}, + "p": {"type": "uri", "value": "http://www.w3.org/2000/01/rdf-schema#label"}, + "o": {"type": "literal", "value": "The Foundations of West Country Cooking"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/west-country-recipes"}, + "p": {"type": "uri", "value": "https://schema.org/name"}, + "o": {"type": "literal", "value": "The Foundations of West Country Cooking"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/west-country-recipes"}, + "p": {"type": "uri", "value": "https://schema.org/description"}, + "o": {"type": "literal", "value": "A structured guide to traditional West Country recipes including clotted cream, scones, hedgerow jam, the cream tea, Cornish pasty, homity pie, and Devonshire fudge, with regional supply chain dependencies."} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/west-country-recipes"}, + "p": {"type": "uri", "value": "https://schema.org/copyrightNotice"}, + "o": {"type": "literal", "value": "Original content, public domain"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/west-country-recipes"}, + "p": {"type": "uri", "value": "https://schema.org/copyrightHolder"}, + "o": {"type": "literal", "value": "TrustGraph AI"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/west-country-recipes"}, + "p": {"type": "uri", "value": "https://schema.org/copyrightYear"}, + "o": {"type": "literal", "value": "2025"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/west-country-recipes"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "cooking"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/west-country-recipes"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "west-country"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/west-country-recipes"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "devon"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/west-country-recipes"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "cornwall"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/west-country-recipes"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "recipes"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/west-country-recipes"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "food-history"} + } + ] + }, + { + "id": "https://trustgraph.ai/doc/belgian-beer", + "title": "The Brewing Traditions of Belgium", + "comments": "An exploration of Belgian brewing traditions covering Trappist yeasts, wild fermentation, monastic beers including Westvleteren and Chimay, spontaneous sour traditions like Oude Geuze and Flemish Red Ale.", + "file": "belgian-beer.md", + "kind": "text/markdown", + "tags": ["brewing", "belgium", "beer", "trappist", "fermentation", "food-history"], + "metadata": [ + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/belgian-beer"}, + "p": {"type": "uri", "value": "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"}, + "o": {"type": "uri", "value": "https://schema.org/DigitalDocument"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/belgian-beer"}, + "p": {"type": "uri", "value": "http://www.w3.org/2000/01/rdf-schema#label"}, + "o": {"type": "literal", "value": "The Brewing Traditions of Belgium"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/belgian-beer"}, + "p": {"type": "uri", "value": "https://schema.org/name"}, + "o": {"type": "literal", "value": "The Brewing Traditions of Belgium"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/belgian-beer"}, + "p": {"type": "uri", "value": "https://schema.org/description"}, + "o": {"type": "literal", "value": "An exploration of Belgian brewing traditions covering Trappist yeasts, wild fermentation, monastic beers including Westvleteren and Chimay, spontaneous sour traditions like Oude Geuze and Flemish Red Ale."} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/belgian-beer"}, + "p": {"type": "uri", "value": "https://schema.org/copyrightNotice"}, + "o": {"type": "literal", "value": "Original content, public domain"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/belgian-beer"}, + "p": {"type": "uri", "value": "https://schema.org/copyrightHolder"}, + "o": {"type": "literal", "value": "TrustGraph AI"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/belgian-beer"}, + "p": {"type": "uri", "value": "https://schema.org/copyrightYear"}, + "o": {"type": "literal", "value": "2025"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/belgian-beer"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "brewing"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/belgian-beer"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "belgium"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/belgian-beer"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "beer"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/belgian-beer"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "trappist"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/belgian-beer"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "fermentation"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/belgian-beer"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "food-history"} + } + ] + }, + { + "id": "https://trustgraph.ai/doc/trade-routes-europe", + "title": "Traditional Trade Routes of Pre-Modern Europe", + "comments": "A structured overview of pre-modern European trade networks including the Hanseatic Baltic Route and Venetian Maritime Route, hub cities like Bruges, Lübeck, and Constantinople, specialized commodities, and geopolitical chokepoints.", + "file": "trade-routes-europe.md", + "kind": "text/markdown", + "tags": ["trade", "medieval", "europe", "hanseatic-league", "venice", "economic-history"], + "metadata": [ + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/trade-routes-europe"}, + "p": {"type": "uri", "value": "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"}, + "o": {"type": "uri", "value": "https://schema.org/DigitalDocument"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/trade-routes-europe"}, + "p": {"type": "uri", "value": "http://www.w3.org/2000/01/rdf-schema#label"}, + "o": {"type": "literal", "value": "Traditional Trade Routes of Pre-Modern Europe"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/trade-routes-europe"}, + "p": {"type": "uri", "value": "https://schema.org/name"}, + "o": {"type": "literal", "value": "Traditional Trade Routes of Pre-Modern Europe"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/trade-routes-europe"}, + "p": {"type": "uri", "value": "https://schema.org/description"}, + "o": {"type": "literal", "value": "A structured overview of pre-modern European trade networks including the Hanseatic Baltic Route and Venetian Maritime Route, hub cities like Bruges, Lübeck, and Constantinople, specialized commodities, and geopolitical chokepoints."} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/trade-routes-europe"}, + "p": {"type": "uri", "value": "https://schema.org/copyrightNotice"}, + "o": {"type": "literal", "value": "Original content, public domain"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/trade-routes-europe"}, + "p": {"type": "uri", "value": "https://schema.org/copyrightHolder"}, + "o": {"type": "literal", "value": "TrustGraph AI"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/trade-routes-europe"}, + "p": {"type": "uri", "value": "https://schema.org/copyrightYear"}, + "o": {"type": "literal", "value": "2025"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/trade-routes-europe"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "trade"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/trade-routes-europe"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "medieval"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/trade-routes-europe"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "europe"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/trade-routes-europe"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "hanseatic-league"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/trade-routes-europe"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "venice"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/trade-routes-europe"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "economic-history"} + } + ] + }, + { + "id": "https://trustgraph.ai/doc/corporate-scandals", + "title": "Global Corporate Fraud & Governance Failures", + "comments": "Detailed analysis of major corporate fraud cases including Enron and Wirecard, covering the financial engineering mechanisms, key executives, audit failures, whistleblowers, and collapse sequences.", + "file": "corporate-scandals.pdf", + "kind": "application/pdf", + "tags": ["corporate-fraud", "enron", "wirecard", "accounting", "governance", "financial-crime"], + "metadata": [ + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/corporate-scandals"}, + "p": {"type": "uri", "value": "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"}, + "o": {"type": "uri", "value": "https://schema.org/DigitalDocument"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/corporate-scandals"}, + "p": {"type": "uri", "value": "http://www.w3.org/2000/01/rdf-schema#label"}, + "o": {"type": "literal", "value": "Global Corporate Fraud & Governance Failures"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/corporate-scandals"}, + "p": {"type": "uri", "value": "https://schema.org/name"}, + "o": {"type": "literal", "value": "Global Corporate Fraud & Governance Failures"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/corporate-scandals"}, + "p": {"type": "uri", "value": "https://schema.org/description"}, + "o": {"type": "literal", "value": "Detailed analysis of major corporate fraud cases including Enron and Wirecard, covering the financial engineering mechanisms, key executives, audit failures, whistleblowers, and collapse sequences."} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/corporate-scandals"}, + "p": {"type": "uri", "value": "https://schema.org/copyrightNotice"}, + "o": {"type": "literal", "value": "Original content, public domain"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/corporate-scandals"}, + "p": {"type": "uri", "value": "https://schema.org/copyrightHolder"}, + "o": {"type": "literal", "value": "TrustGraph AI"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/corporate-scandals"}, + "p": {"type": "uri", "value": "https://schema.org/copyrightYear"}, + "o": {"type": "literal", "value": "2025"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/corporate-scandals"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "corporate-fraud"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/corporate-scandals"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "enron"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/corporate-scandals"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "wirecard"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/corporate-scandals"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "accounting"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/corporate-scandals"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "governance"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/corporate-scandals"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "financial-crime"} + } + ] + }, + { + "id": "https://trustgraph.ai/doc/history-of-pets", + "title": "The Domestic Canopy: A Unified Narrative of Companionship", + "comments": "A narrative history of human-animal companionship from Pleistocene wolf domestication through Egyptian cat worship, Roman exotic pets, medieval class divides, to the Victorian birth of the commercial pet industry.", + "file": "history-of-pets.md", + "kind": "text/markdown", + "tags": ["pets", "domestication", "animal-history", "cultural-history", "dogs", "cats"], + "metadata": [ + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/history-of-pets"}, + "p": {"type": "uri", "value": "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"}, + "o": {"type": "uri", "value": "https://schema.org/DigitalDocument"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/history-of-pets"}, + "p": {"type": "uri", "value": "http://www.w3.org/2000/01/rdf-schema#label"}, + "o": {"type": "literal", "value": "The Domestic Canopy: A Unified Narrative of Companionship"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/history-of-pets"}, + "p": {"type": "uri", "value": "https://schema.org/name"}, + "o": {"type": "literal", "value": "The Domestic Canopy: A Unified Narrative of Companionship"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/history-of-pets"}, + "p": {"type": "uri", "value": "https://schema.org/description"}, + "o": {"type": "literal", "value": "A narrative history of human-animal companionship from Pleistocene wolf domestication through Egyptian cat worship, Roman exotic pets, medieval class divides, to the Victorian birth of the commercial pet industry."} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/history-of-pets"}, + "p": {"type": "uri", "value": "https://schema.org/copyrightNotice"}, + "o": {"type": "literal", "value": "Original content, public domain"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/history-of-pets"}, + "p": {"type": "uri", "value": "https://schema.org/copyrightHolder"}, + "o": {"type": "literal", "value": "TrustGraph AI"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/history-of-pets"}, + "p": {"type": "uri", "value": "https://schema.org/copyrightYear"}, + "o": {"type": "literal", "value": "2025"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/history-of-pets"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "pets"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/history-of-pets"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "domestication"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/history-of-pets"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "animal-history"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/history-of-pets"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "cultural-history"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/history-of-pets"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "dogs"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/history-of-pets"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "cats"} + } + ] + }, + { + "id": "https://trustgraph.ai/doc/mil-fortifications-america-19th-c", + "title": "Military Fortifications in 19th Century America", + "comments": "The evolution of American military fortification from the Third System masonry forts through Civil War obsolescence to earthwork defenses, plus the parallel development of frontier forts for westward expansion.", + "file": "mil-fortifications-america-19th-c.md", + "kind": "text/markdown", + "tags": ["military-history", "fortifications", "civil-war", "engineering", "american-history", "coastal-defense"], + "metadata": [ + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/mil-fortifications-america-19th-c"}, + "p": {"type": "uri", "value": "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"}, + "o": {"type": "uri", "value": "https://schema.org/DigitalDocument"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/mil-fortifications-america-19th-c"}, + "p": {"type": "uri", "value": "http://www.w3.org/2000/01/rdf-schema#label"}, + "o": {"type": "literal", "value": "Military Fortifications in 19th Century America"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/mil-fortifications-america-19th-c"}, + "p": {"type": "uri", "value": "https://schema.org/name"}, + "o": {"type": "literal", "value": "Military Fortifications in 19th Century America"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/mil-fortifications-america-19th-c"}, + "p": {"type": "uri", "value": "https://schema.org/description"}, + "o": {"type": "literal", "value": "The evolution of American military fortification from the Third System masonry forts through Civil War obsolescence to earthwork defenses, plus the parallel development of frontier forts for westward expansion."} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/mil-fortifications-america-19th-c"}, + "p": {"type": "uri", "value": "https://schema.org/copyrightNotice"}, + "o": {"type": "literal", "value": "Original content, public domain"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/mil-fortifications-america-19th-c"}, + "p": {"type": "uri", "value": "https://schema.org/copyrightHolder"}, + "o": {"type": "literal", "value": "TrustGraph AI"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/mil-fortifications-america-19th-c"}, + "p": {"type": "uri", "value": "https://schema.org/copyrightYear"}, + "o": {"type": "literal", "value": "2025"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/mil-fortifications-america-19th-c"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "military-history"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/mil-fortifications-america-19th-c"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "fortifications"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/mil-fortifications-america-19th-c"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "civil-war"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/mil-fortifications-america-19th-c"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "engineering"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/mil-fortifications-america-19th-c"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "american-history"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/mil-fortifications-america-19th-c"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "coastal-defense"} + } + ] + }, + { + "id": "https://trustgraph.ai/doc/bronze-age-collapse", + "title": "Echoes of the Void: The Late Bronze Age Collapse", + "comments": "A synthesis of the Late Bronze Age Collapse (c. 1200-1150 BCE) covering the interconnected trade networks, the cascade of failure from drought through the Sea Peoples to internal revolt, archaeological evidence, and the transition to the Iron Age.", + "file": "bronze-age-collapse.pdf", + "kind": "application/pdf", + "tags": ["bronze-age", "ancient-history", "archaeology", "mediterranean", "trade", "collapse"], + "metadata": [ + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/bronze-age-collapse"}, + "p": {"type": "uri", "value": "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"}, + "o": {"type": "uri", "value": "https://schema.org/DigitalDocument"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/bronze-age-collapse"}, + "p": {"type": "uri", "value": "http://www.w3.org/2000/01/rdf-schema#label"}, + "o": {"type": "literal", "value": "Echoes of the Void: The Late Bronze Age Collapse"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/bronze-age-collapse"}, + "p": {"type": "uri", "value": "https://schema.org/name"}, + "o": {"type": "literal", "value": "Echoes of the Void: The Late Bronze Age Collapse"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/bronze-age-collapse"}, + "p": {"type": "uri", "value": "https://schema.org/description"}, + "o": {"type": "literal", "value": "A synthesis of the Late Bronze Age Collapse (c. 1200-1150 BCE) covering the interconnected trade networks, the cascade of failure from drought through the Sea Peoples to internal revolt, archaeological evidence, and the transition to the Iron Age."} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/bronze-age-collapse"}, + "p": {"type": "uri", "value": "https://schema.org/copyrightNotice"}, + "o": {"type": "literal", "value": "Original content, public domain"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/bronze-age-collapse"}, + "p": {"type": "uri", "value": "https://schema.org/copyrightHolder"}, + "o": {"type": "literal", "value": "TrustGraph AI"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/bronze-age-collapse"}, + "p": {"type": "uri", "value": "https://schema.org/copyrightYear"}, + "o": {"type": "literal", "value": "2025"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/bronze-age-collapse"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "bronze-age"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/bronze-age-collapse"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "ancient-history"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/bronze-age-collapse"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "archaeology"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/bronze-age-collapse"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "mediterranean"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/bronze-age-collapse"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "trade"} + }, + { + "s": {"type": "uri", "value": "https://trustgraph.ai/doc/bronze-age-collapse"}, + "p": {"type": "uri", "value": "https://schema.org/keywords"}, + "o": {"type": "literal", "value": "collapse"} + } + ] + } +] diff --git a/trustgraph-cli/trustgraph/cli/sample_documents/mil-fortifications-america-19th-c.md b/trustgraph-cli/trustgraph/cli/sample_documents/mil-fortifications-america-19th-c.md new file mode 100644 index 00000000..0b15863c --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/sample_documents/mil-fortifications-america-19th-c.md @@ -0,0 +1,11 @@ +# Military Fortifications in 19th Century America + +The evolution of coastal and frontier defense across North America during the nineteenth century reflects a turbulent transition from traditional European masonry concepts to the brutal realities of industrialized warfare. At the dawn of the century, the young United States found its sprawling coastline dangerously exposed to the naval might of European empires. In response, the federal government embarked on a massive, highly centralized building program known as the Third System of fortifications. Orchestrated primarily by the newly formed U.S. Army Corps of Engineers and heavily influenced by French military engineer Simon Bernard, this system sought to seal off strategic harbors, naval shipyards, and commercial estuaries from maritime invasion. These fortifications were characterized by massive, multi-tiered masonry walls constructed of brick and stone, designed to mount several tiers of heavy cannon firing through vaulted casemates, thereby concentrating overwhelming firepower against hostile warships. + +The architectural pinnacle of this philosophy was realized in structures like Fort Jefferson, situated on a remote key in the Dry Tortugas of the Gulf of Mexico, and Fort Sumter, guarding the entrance to Charleston Harbor. These fortresses were essentially artificial islands of masonry, featuring complex geometric designs — often pentagonal or hexagonal — to eliminate dead angles where an enemy could seek shelter from the garrison's fire. The walls were constructed using millions of locally fired bricks, backed by concrete and earth, creating a dense barrier meant to absorb the impact of smoothbore solid shot. Within these structures, the invention of the Totten shutter — a pair of iron doors that automatically closed over the cannon embrasure after firing — protected the artillerists from incoming musket fire and grape shot. For the first half of the century, these towering masonry sentinels were considered functionally impregnable to naval assault, as wooden warships could rarely sustain the prolonged, concentrated bombardment required to breach such thick brick facades. + +However, the catastrophic vulnerabilities of the Third System were violently exposed during the American Civil War, rendering traditional masonry fortifications obsolete almost overnight. The catalyst for this military revolution was the introduction of the rifled cannon, most notably the James and Parrott rifles. Unlike the round, smoothbore cannonballs that shattered against brickwork with diminishing effect, elongated rifled projectiles spun like rifle bullets, striking the masonry with immense kinetic energy and drilling into the brickwork with a devastating, jackhammer-like effect. This paradigm shift was demonstrated at the Siege of Fort Pulaski near Savannah, Georgia, in April 1862. Union forces stationed on Tybee Island opened fire with rifled artillery from a distance of over a mile — a range previously considered entirely safe by the fort's defenders. Within thirty hours, the rifled shells tore through the massive brick walls of Fort Pulaski, breaching the solid masonry and threatening to ignite the fort's main powder magazine, forcing an immediate surrender. + +This sudden obsolescence forced military engineers to radically re-evaluate defensive architecture, pivoting away from vertical masonry toward low-profile earthworks and subterranean engineering. It was discovered that simple mounds of loose sand and compacted earth absorbed the impact of rifled shells far better than rigid brick, as the displaced soil naturally filled in the craters left by explosions. This gave rise to formidable improvisations like Fort Fisher in North Carolina, often dubbed the "Malachite of America." Fort Fisher was an immense earthen stronghold constructed of sand face-traverses and underground bombproofs, which successfully withstood the largest naval bombardments of the war because the Union fleet's shells merely rearranged the sand rather than shattering a structural foundation. Consequently, post-Civil War modifications to coastal defense saw engineers cutting down the towering brick walls of older forts, burying them under massive earth glacis, and preparing the way for the Endicott System at the end of the century, which utilized reinforced concrete, low-profile designs, and disappearing guns. + +Concurrently, a completely different doctrine of military architecture was unfolding along the interior frontiers of the continent, where the purpose of fortification was not to resist heavy naval artillery, but to project geopolitical power, control trade routes, and subjugate Indigenous populations. These interior strongholds, such as Fort Laramie in Wyoming or Fort Snelling in Minnesota, abandoned the complex geometry and multi-tiered casemates of coastal engineering in favor of practical, localized utility. Often constructed initially of timber palisades or sun-dried adobe brick depending on the regional geography, these frontier forts served as fortified outposts for the U.S. Army, fur trading companies, and westward migrants. Rather than being designed to withstand a siege by a peer military force, their layouts usually featured a wide, open central parade ground surrounded by barracks, officer quarters, and a defensive perimeter designed to repel swift cavalry raids, protect supply depots, and enforce the shifting boundaries of American westward expansion. diff --git a/trustgraph-cli/trustgraph/cli/sample_documents/recipes.md b/trustgraph-cli/trustgraph/cli/sample_documents/recipes.md new file mode 100644 index 00000000..33a82a92 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/sample_documents/recipes.md @@ -0,0 +1,70 @@ +# The Foundations of West Country Cooking + +## Section 1: The Foundations (Primary Ingredients & Sub-Recipes) + +### Component A: West Country Clotted Cream + +- **Alternative Names:** Devonshire Cream, Cornish Cream. +- **Origin:** Universally produced across the pastures of Devon and Cornwall, utilizing milk from Red Ruby Devon cattle. +- **Ingredients:** 2 Litres of Unpasteurized Whole Milk (High-fat dairy). +- **Process:** Pour the raw milk into a shallow brass pan. Allow it to sit for 12 hours until a thick layer of cream rises to the surface. Heat the pan slowly over a low, indirect flame (traditionally over a wood-fired stove) until the cream begins to "crinkle" but never boil. Remove from heat and cool in a larder for 24 hours. Gently skim the thick, golden crust from the top. +- **Downstream Dependencies:** Essential component for The South West Cream Tea and Devonshire Fudge. + +### Component B: Sweet Scones + +- **Alternative Names:** Hearth Cakes, Country Splits. +- **Origin:** Common across Somerset, Dorset, Devon, and Cornwall. +- **Ingredients:** 450g Self-Raising Flour, 100g Salted Butter, 50g Caster Sugar, 1 pinch of Salt, 250ml Whole Milk. +- **Process:** Rub the salted butter into the self-raising flour until it resembles fine breadcrumbs. Stir in the caster sugar and salt. Pour in the whole milk gradually, mixing with a blunt knife until a soft dough forms. Roll out on a floured surface to 2cm thickness and stamp out rounds. Bake at 220°C for 12 minutes until golden. +- **Downstream Dependencies:** Required component for The South West Cream Tea. + +### Component C: Hedgerow Jam + +- **Alternative Names:** Whortleberry Preserve, Blackberry Jam. +- **Origin:** Produced extensively in the Exmoor and Dartmoor regions of Devon and Somerset. +- **Ingredients:** 1kg Wild Blackberries (Whortleberries), 1kg Granulated Sugar, Juice of 1 Lemon (for pectin). +- **Process:** Combine blackberries, sugar, and lemon juice in a large copper preserving pan. Heat gently until the sugar dissolves entirely. Bring to a rolling boil for 15 minutes until the setting point of 105°C is reached. Pour into sterilized jars. +- **Downstream Dependencies:** Required component for The South West Cream Tea. + +## Section 2: Regional Showcases (Assembled Dishes) + +### Dish 1: The South West Cream Tea + +- **Alternative Names:** Afternoon Tea, Devonshire Tea. +- **Origin:** Heavily disputed between the historical houses of Devon and Cornwall. +- **Core Components:** Requires 1 batch of Sweet Scones, 1 jar of Hedgerow Jam, and a generous portion of West Country Clotted Cream. Always served alongside a pot of hot Black Tea. + +**The Assembly Protocol Dispute (Critical Logic Test):** + +- **The Devon Method:** The scone is split in half. The West Country Clotted Cream is spread first onto the warm scone, acting as butter, and is then topped with a dollop of Hedgerow Jam. +- **The Cornish Method:** The scone is split in half. The Hedgerow Jam is spread directly onto the scone first, and is subsequently topped with a spoonful of West Country Clotted Cream. + +### Dish 2: The Cornish Pasty + +- **Alternative Names:** Oggy, The Miner's Lunch. +- **Origin:** Cornwall (Specifically protected by PGI status, meaning it must be prepared within the county borders). +- **Ingredients:** 500g Shortcrust Pastry, 400g Beef Skirt (cubed), 300g Potato (peeled and diced), 150g Swede (locally referred to as 'Turnip' in Cornwall), 150g Onion (finely chopped), Salt and Black Pepper, 1 Egg (beaten, for glaze). +- **Process:** Roll the pastry and cut into large circles. Place the raw beef, potato, swede, and onion on one half of the circle. Season generously. Fold the pastry over to form a D-shape and crimp the edges firmly to create a thick ridge. Brush with the beaten egg and bake at 200°C for 45 minutes. +- **Historical Context:** The thick crimped ridge allowed Cornish tin miners to hold the pasty with dirty, arsenic-covered hands and discard the crust afterward. + +### Dish 3: Devonshire Homity Pie + +- **Alternative Names:** Land Girls' Pie. +- **Origin:** Popularized in Devon by the Women's Land Army during world war rationing. +- **Ingredients:** 1 Pre-baked Shortcrust Pastry Case, 500g Potatoes (boiled and cubed), 2 Large Leeks (sliced), 1 Large Onion (chopped), 50g Salted Butter, 150g Mature Cheddar Cheese (sourced from Somerset), 2 tbsps Fresh Parsley. +- **Process:** Melt the salted butter in a pan and sauté the leeks and onions until soft. Stir in the cubed potatoes, parsley, and half of the Somerset-sourced Mature Cheddar Cheese. Spoon the filling into the pre-baked pastry case. Top with the remaining cheese and bake at 190°C for 25 minutes until bubbling. + +## Section 3: Sweets & Confectionery + +### Dish 4: Devonshire Fudge + +- **Alternative Names:** Clotted Cream Fudge. +- **Origin:** Widely sold along the coastal towns of South Devon. +- **Ingredients:** 175g West Country Clotted Cream, 450g Caster Sugar, 115g Glucose Syrup, 1 tbsp Vanilla Extract. +- **Process:** Combine the caster sugar, glucose syrup, and West Country Clotted Cream in a heavy-based saucepan. Bring to a boil, stirring continuously until the mixture reaches the "soft ball" stage (116°C). Remove from heat, add vanilla extract, and beat vigorously with a wooden spoon until the mixture thickens and loses its gloss. Pour into a lined tin and allow to set. + +## Section 4: Local Supply Chain & Geopolitical Geography + +- **The Dairy Belt:** Devon and Cornwall are primary rivals in dairy manufacturing. Devon relies heavily on the Red Ruby cattle breed, whereas Cornwall utilizes traditional mixed-pasture herds. Both rely on Somerset for auxiliary hard cheeses like Cheddar. +- **The Allium Link:** Onions and leeks are the common denominator connecting the savory dishes of Devon (Homity Pie) and Cornwall (Cornish Pasty). +- **The Pectin Shortage Risk:** Should a frost hit the Tamar Valley, the production of Hedgerow Jam ceases, creating a supply bottleneck that directly disables the serving of The South West Cream Tea, despite dairy availability. diff --git a/trustgraph-cli/trustgraph/cli/sample_documents/trade-routes-europe.md b/trustgraph-cli/trustgraph/cli/sample_documents/trade-routes-europe.md new file mode 100644 index 00000000..768f2438 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/sample_documents/trade-routes-europe.md @@ -0,0 +1,63 @@ +# Traditional Trade Routes of Pre-Modern Europe + +## Section 1: The Arteries (The Core Networks) + +### Network A: The Hanseatic Baltic Route + +- **Alternative Names:** The Hansa Network, The Northern Guild Rim. +- **Geographical Span:** Spans from the North Sea across the Baltic Sea, linking London, Bruges, Lübeck, Danzig, and Novgorod. +- **Primary Commodities:** Timber, Fur, Flax, Stockfish, Amber. +- **Downstream Dependencies:** Provides raw materials for Western European shipbuilding and winter clothing markets. + +### Network B: The Venetian Maritime Route + +- **Alternative Names:** The Levantine Silk Spoke, The Adriatic Lifeline. +- **Geographical Span:** Connects Venice through the Adriatic Sea, around Greece, to Constantinople and Alexandria. +- **Primary Commodities:** Silk, Pepper, Cinnamon, Alum, Glassware. +- **Downstream Dependencies:** Feeds the luxury markets of the Holy Roman Empire via alpine passes. + +## Section 2: Hub Cities & Commodity Crossings + +### Hub 1: Bruges (The Low Countries) + +- **Alternative Names:** Brugge, The Flanders Staple. +- **Geographical Intersection:** The primary terminus where The Hanseatic Baltic Route meets Western European land routes. + +**The Staple Right Dispute (Critical Logic Test):** + +- **The Guild Law:** By ducal decree, all foreign merchants traveling through Flanders must unload their ships at Bruges and offer their goods for sale for a mandatory 15 days before they can proceed. +- **The English Subversion:** English wool merchants, seeking to bypass the Bruges tax, began smuggling raw wool directly to Antwerp, sparking an economic blockade by the Hanseatic League against English shipping. + +### Hub 2: Lübeck (The Baltic Capital) + +- **Alternative Names:** Lubeca, The Queen of the Hansa. +- **Geographical Intersection:** Located in Northern Germany, acting as the administrative node connecting the North Sea (via the Kiel land-bridge) to the wider Baltic Sea. +- **Resource Matrix:** Completely dependent on the Lüneburg Salt Works for its primary processing industry (herring preservation). + +### Hub 3: Constantinople (The Gateway) + +- **Alternative Names:** Byzantium, Istanbul, Miklagard. +- **Geographical Intersection:** The western terminus of the Silk Road land routes and the northern terminus of The Venetian Maritime Route. +- **Controlling Entity:** Transferred from Byzantine control to Ottoman control in 1453, altering the tariff structures for all Christian merchants. + +## Section 3: Specialized Commodities & Processing Nodes + +### Item 1: Lüneburg Salt + +- **Alternative Names:** White Gold, Northern Brine. +- **Origin:** Extracted from the brine springs of Lüneburg, Germany. +- **Process:** Boiled in massive lead pans using timber sourced from local forests. +- **Critical Dependency Link:** This salt is shipped directly to Bergen (Norway) via Lübeck to pack and preserve Scania Herring. Without this specific salt supply, Baltic fish rots before reaching Western markets. + +### Item 2: Phocaean Alum + +- **Alternative Names:** The Weaver's Fixative, Anatolian Alum. +- **Origin:** Mined in the hills of Phocaea (Asia Minor) under the jurisdiction of the Genoese Republic, later seized by regional powers. +- **Process:** Shipped via Mediterranean maritime routes to Flanders and Florence. +- **Chemical Function:** A mandatory chemical mordant required to fix dyes to wool and textiles. Without Alum, the famous Flemish textile industry cannot produce colored cloth. + +## Section 4: Geopolitical Disruptions & Chokepoints + +- **The Sound Toll Bottleneck:** The King of Denmark levies a mandatory tax on all ships entering or leaving the Baltic Sea through the Øresund strait. A diplomatic dispute or military blockade of the Sound by Denmark instantly halts the flow of Russian timber to the English Royal Dockyards. +- **The Sound-to-Salt Ripple Effect:** If the forests around Lüneburg are depleted, salt production drops. This directly causes a collapse in the Bergen fish trade, which in turn causes a protein shortage and subsequent famine in the labor forces of the Flemish textile hubs. +- **The Alum Monopolization:** Following conflicts in the East, the discovery of a domestic alum mine in Tolfa (Papal States) in 1461 caused a massive geopolitical shift, as the Pope banned the import of "infidel alum" from the East, forcing Venetian merchants to pivot their supply lines inward. diff --git a/trustgraph-embeddings-hf/pyproject.toml b/trustgraph-embeddings-hf/pyproject.toml index 70489969..4bf17688 100644 --- a/trustgraph-embeddings-hf/pyproject.toml +++ b/trustgraph-embeddings-hf/pyproject.toml @@ -10,8 +10,8 @@ description = "HuggingFace embeddings support for TrustGraph." readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.4,<2.5", - "trustgraph-flow>=2.4,<2.5", + "trustgraph-base>=2.5,<2.6", + "trustgraph-flow>=2.5,<2.6", "torch", "urllib3", "transformers", diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index d8c690b5..547dea3c 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.4,<2.5", + "trustgraph-base>=2.5,<2.6", "aiohttp", "anthropic", "scylla-driver", @@ -64,6 +64,7 @@ bootstrap = "trustgraph.bootstrap.bootstrapper:run" config-svc = "trustgraph.config.service:run" flow-svc = "trustgraph.flow.service:run" iam-svc = "trustgraph.iam.service:run" +no-auth-svc = "trustgraph.iam.noauth:run" doc-embeddings-query-milvus = "trustgraph.query.doc_embeddings.milvus:run" doc-embeddings-query-pinecone = "trustgraph.query.doc_embeddings.pinecone:run" doc-embeddings-query-qdrant = "trustgraph.query.doc_embeddings.qdrant:run" diff --git a/trustgraph-flow/trustgraph/bootstrap/bootstrapper/service.py b/trustgraph-flow/trustgraph/bootstrap/bootstrapper/service.py index 3c658fe3..81b7e98d 100644 --- a/trustgraph-flow/trustgraph/bootstrap/bootstrapper/service.py +++ b/trustgraph-flow/trustgraph/bootstrap/bootstrapper/service.py @@ -326,6 +326,58 @@ class Processor(AsyncProcessor): # Main loop. # ------------------------------------------------------------------ + async def _run_pre_service(self): + """Run pre-service initialisers before opening pub/sub clients. + + These bring up infrastructure that other services depend on + (e.g. Pulsar tenant/namespaces). They use out-of-band APIs + (HTTP admin), not pub/sub, so they don't need a config client. + They run without flag tracking — they must be idempotent. + """ + pre_specs = [ + s for s in self.specs + if not s.instance.wait_for_services + ] + if not pre_specs: + return + + for spec in pre_specs: + child_logger = logger.getChild(spec.name) + child_ctx = InitContext( + logger=child_logger, + config=None, + make_flow_client=self._make_flow_client, + make_iam_client=self._make_iam_client, + ) + child_logger.info(f"Running pre-service initialiser") + try: + await spec.instance.run(child_ctx, None, spec.flag) + child_logger.info(f"Pre-service initialiser completed") + except Exception as e: + child_logger.error( + f"Pre-service initialiser failed: " + f"{type(e).__name__}: {e}", + exc_info=True, + ) + raise + + async def start(self): + # Run pre-service initialisers before opening any pub/sub + # connections. They bring up infrastructure (Pulsar + # namespaces, etc.) that super().start() depends on. + while self.running: + try: + await self._run_pre_service() + break + except Exception as e: + logger.info( + f"Pre-service initialisation failed " + f"({type(e).__name__}: {e}); retry in {GATE_BACKOFF}s" + ) + await asyncio.sleep(GATE_BACKOFF) + + await super().start() + async def run(self): logger.info( @@ -347,29 +399,18 @@ class Processor(AsyncProcessor): continue try: - # Phase 1: pre-service initialisers run unconditionally. - pre_specs = [ - s for s in self.specs - if not s.instance.wait_for_services - ] - pre_results = {} - for spec in pre_specs: - pre_results[spec.name] = await self._run_spec( - spec, config, - ) - - # Phase 2: gate. + # Phase 1: gate. gate_ok = await self._gate_ready(config) - # Phase 3: post-service initialisers, if gate passed. - post_results = {} + # Phase 2: post-service initialisers, if gate passed. + results = {} if gate_ok: post_specs = [ s for s in self.specs if s.instance.wait_for_services ] for spec in post_specs: - post_results[spec.name] = await self._run_spec( + results[spec.name] = await self._run_spec( spec, config, ) @@ -377,8 +418,7 @@ class Processor(AsyncProcessor): if not gate_ok: sleep_for = GATE_BACKOFF else: - all_results = {**pre_results, **post_results} - if any(r != "skip" for r in all_results.values()): + if any(r != "skip" for r in results.values()): sleep_for = INIT_RETRY else: sleep_for = STEADY_INTERVAL diff --git a/trustgraph-flow/trustgraph/bootstrap/initialisers/pulsar_topology.py b/trustgraph-flow/trustgraph/bootstrap/initialisers/pulsar_topology.py index 843fe056..1e4805de 100644 --- a/trustgraph-flow/trustgraph/bootstrap/initialisers/pulsar_topology.py +++ b/trustgraph-flow/trustgraph/bootstrap/initialisers/pulsar_topology.py @@ -112,6 +112,10 @@ class PulsarTopology(Initialiser): def _reconcile_sync(self, logger): if not self._tenant_exists(): clusters = self._get_clusters() + if not clusters: + raise RuntimeError( + "Pulsar cluster list is empty — broker not ready yet" + ) logger.info( f"Creating tenant {self.tenant!r} with clusters {clusters}" ) diff --git a/trustgraph-flow/trustgraph/cores/knowledge.py b/trustgraph-flow/trustgraph/cores/knowledge.py index 09c6137d..f1fa53f5 100644 --- a/trustgraph-flow/trustgraph/cores/knowledge.py +++ b/trustgraph-flow/trustgraph/cores/knowledge.py @@ -1,5 +1,6 @@ from .. schema import KnowledgeResponse, Error, Triples, GraphEmbeddings +from .. schema import DocumentEmbeddings from .. knowledge import hash from .. exceptions import RequestError from .. tables.knowledge import KnowledgeTableStore @@ -157,6 +158,98 @@ class KnowledgeManager: ) ) + async def list_de_cores(self, request, respond, workspace): + + ids = await self.table_store.list_de_cores(workspace) + + await respond( + KnowledgeResponse( + error = None, + ids = ids, + eos = False, + triples = None, + graph_embeddings = None, + ) + ) + + async def get_de_core(self, request, respond, workspace): + + logger.info("Getting document embeddings core...") + + async def publish_de(de): + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = False, + triples = None, + graph_embeddings = None, + document_embeddings = de, + ) + ) + + await self.table_store.get_document_embeddings( + workspace, + request.id, + publish_de, + ) + + logger.debug("Document embeddings core retrieval complete") + + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = True, + triples = None, + graph_embeddings = None, + ) + ) + + async def put_de_core(self, request, respond, workspace): + + if request.document_embeddings: + await self.table_store.add_document_embeddings( + workspace, request.document_embeddings + ) + + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = False, + triples = None, + graph_embeddings = None, + ) + ) + + async def delete_de_core(self, request, respond, workspace): + + logger.info("Deleting document embeddings core...") + + await self.table_store.delete_document_embeddings( + workspace, request.id + ) + + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = False, + triples = None, + graph_embeddings = None, + ) + ) + + async def load_de_core(self, request, respond, workspace): + + if self.background_task is None: + self.background_task = asyncio.create_task( + self.core_loader() + ) + + await self.loader_queue.put((request, respond, workspace)) + async def core_loader(self): logger.info("Knowledge background processor running...") @@ -165,7 +258,7 @@ class KnowledgeManager: logger.debug("Waiting for next load...") request, respond, workspace = await self.loader_queue.get() - logger.info(f"Loading knowledge: {request.id}") + logger.info(f"Loading: {request.operation} {request.id}") try: @@ -187,25 +280,14 @@ class KnowledgeManager: if "interfaces" not in flow: raise RuntimeError("No defined interfaces") - if "triples-store" not in flow["interfaces"]: - raise RuntimeError("Flow has no triples-store") - - if "graph-embeddings-store" not in flow["interfaces"]: - raise RuntimeError("Flow has no graph-embeddings-store") - - t_q = flow["interfaces"]["triples-store"]["flow"] - ge_q = flow["interfaces"]["graph-embeddings-store"]["flow"] - - # Got this far, it should all work - await respond( - KnowledgeResponse( - error = None, - ids = None, - eos = False, - triples = None, - graph_embeddings = None + if request.operation == "load-de-core": + await self._load_de_core( + request, respond, workspace, flow, + ) + else: + await self._load_kg_core( + request, respond, workspace, flow, ) - ) except Exception as e: @@ -223,72 +305,145 @@ class KnowledgeManager: ) ) - - logger.debug("Starting knowledge loading process...") - - try: - - t_pub = None - ge_pub = None - - logger.debug(f"Triples queue: {t_q}") - logger.debug(f"Graph embeddings queue: {ge_q}") - - t_pub = Publisher( - self.flow_config.pubsub, t_q, - schema=Triples, - ) - ge_pub = Publisher( - self.flow_config.pubsub, ge_q, - schema=GraphEmbeddings - ) - - logger.debug("Starting publishers...") - - await t_pub.start() - await ge_pub.start() - - async def publish_triples(t): - # Override collection with request collection - if hasattr(t, 'metadata') and hasattr(t.metadata, 'collection'): - t.metadata.collection = request.collection or "default" - await t_pub.send(None, t) - - logger.debug("Publishing triples...") - - await self.table_store.get_triples( - workspace, - request.id, - publish_triples, - ) - - async def publish_ge(g): - # Override collection with request collection - if hasattr(g, 'metadata') and hasattr(g.metadata, 'collection'): - g.metadata.collection = request.collection or "default" - await ge_pub.send(None, g) - - logger.debug("Publishing graph embeddings...") - - await self.table_store.get_graph_embeddings( - workspace, - request.id, - publish_ge, - ) - - logger.debug("Knowledge loading completed") - - except Exception as e: - - logger.error(f"Knowledge exception: {e}", exc_info=True) - - finally: - - logger.debug("Stopping publishers...") - - if t_pub: await t_pub.stop() - if ge_pub: await ge_pub.stop() - logger.debug("Knowledge processing done") continue + + async def _load_kg_core(self, request, respond, workspace, flow): + + if "triples-store" not in flow["interfaces"]: + raise RuntimeError("Flow has no triples-store") + + if "graph-embeddings-store" not in flow["interfaces"]: + raise RuntimeError("Flow has no graph-embeddings-store") + + t_q = flow["interfaces"]["triples-store"]["flow"] + ge_q = flow["interfaces"]["graph-embeddings-store"]["flow"] + + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = False, + triples = None, + graph_embeddings = None + ) + ) + + t_pub = None + ge_pub = None + + try: + + logger.debug(f"Triples queue: {t_q}") + logger.debug(f"Graph embeddings queue: {ge_q}") + + t_pub = Publisher( + self.flow_config.pubsub, t_q, + schema=Triples, + ) + ge_pub = Publisher( + self.flow_config.pubsub, ge_q, + schema=GraphEmbeddings + ) + + logger.debug("Starting publishers...") + + await t_pub.start() + await ge_pub.start() + + async def publish_triples(t): + if hasattr(t, 'metadata') and hasattr(t.metadata, 'collection'): + t.metadata.collection = request.collection or "default" + await t_pub.send(None, t) + + logger.debug("Publishing triples...") + + await self.table_store.get_triples( + workspace, + request.id, + publish_triples, + ) + + async def publish_ge(g): + if hasattr(g, 'metadata') and hasattr(g.metadata, 'collection'): + g.metadata.collection = request.collection or "default" + await ge_pub.send(None, g) + + logger.debug("Publishing graph embeddings...") + + await self.table_store.get_graph_embeddings( + workspace, + request.id, + publish_ge, + ) + + logger.debug("Knowledge core loading completed") + + except Exception as e: + + logger.error(f"Knowledge exception: {e}", exc_info=True) + + finally: + + logger.debug("Stopping publishers...") + + if t_pub: await t_pub.stop() + if ge_pub: await ge_pub.stop() + + async def _load_de_core(self, request, respond, workspace, flow): + + if "document-embeddings-store" not in flow["interfaces"]: + raise RuntimeError("Flow has no document-embeddings-store") + + de_q = flow["interfaces"]["document-embeddings-store"]["flow"] + + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = False, + triples = None, + graph_embeddings = None + ) + ) + + de_pub = None + + try: + + logger.debug(f"Document embeddings queue: {de_q}") + + de_pub = Publisher( + self.flow_config.pubsub, de_q, + schema=DocumentEmbeddings, + ) + + logger.debug("Starting publisher...") + + await de_pub.start() + + async def publish_de(de): + if hasattr(de, 'metadata') and hasattr(de.metadata, 'collection'): + de.metadata.collection = request.collection or "default" + await de_pub.send(None, de) + + logger.debug("Publishing document embeddings...") + + await self.table_store.get_document_embeddings( + workspace, + request.id, + publish_de, + ) + + logger.debug("Document embeddings core loading completed") + + except Exception as e: + + logger.error(f"Knowledge exception: {e}", exc_info=True) + + finally: + + logger.debug("Stopping publisher...") + + if de_pub: await de_pub.stop() diff --git a/trustgraph-flow/trustgraph/cores/service.py b/trustgraph-flow/trustgraph/cores/service.py index c84b536c..a04e42ca 100755 --- a/trustgraph-flow/trustgraph/cores/service.py +++ b/trustgraph-flow/trustgraph/cores/service.py @@ -187,6 +187,11 @@ class Processor(WorkspaceProcessor): "put-kg-core": self.knowledge.put_kg_core, "load-kg-core": self.knowledge.load_kg_core, "unload-kg-core": self.knowledge.unload_kg_core, + "list-de-cores": self.knowledge.list_de_cores, + "get-de-core": self.knowledge.get_de_core, + "delete-de-core": self.knowledge.delete_de_core, + "put-de-core": self.knowledge.put_de_core, + "load-de-core": self.knowledge.load_de_core, } if v.operation not in impls: diff --git a/trustgraph-flow/trustgraph/direct/cassandra_kg.py b/trustgraph-flow/trustgraph/direct/cassandra_kg.py index 59d2a2a1..d7abd1a9 100644 --- a/trustgraph-flow/trustgraph/direct/cassandra_kg.py +++ b/trustgraph-flow/trustgraph/direct/cassandra_kg.py @@ -1,10 +1,14 @@ +import datetime +import os +import logging + from cassandra.cluster import Cluster from cassandra.auth import PlainTextAuthProvider from cassandra.query import BatchStatement, SimpleStatement from ssl import SSLContext, PROTOCOL_TLSv1_2 -import os -import logging + +from ..tables.cassandra_async import async_execute # Global list to track clusters for cleanup _active_clusters = [] @@ -461,7 +465,6 @@ class KnowledgeGraph: def create_collection(self, collection): """Create collection by inserting metadata row""" try: - import datetime self.session.execute( f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)", (collection, datetime.datetime.now()) @@ -954,7 +957,6 @@ class EntityCentricKnowledgeGraph: def create_collection(self, collection): """Create collection by inserting metadata row""" try: - import datetime self.session.execute( f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)", (collection, datetime.datetime.now()) @@ -1045,6 +1047,222 @@ class EntityCentricKnowledgeGraph: logger.info(f"Deleted collection {collection}: {len(entities)} entity partitions, {len(quads)} quads") + # ======================================================================== + # Async methods — use cassandra driver's native async API via async_execute + # ======================================================================== + + async def async_insert(self, collection, s, p, o, g=None, otype=None, dtype="", lang=""): + if g is None: + g = DEFAULT_GRAPH + if otype is None: + if o.startswith("http://") or o.startswith("https://"): + otype = "u" + else: + otype = "l" + + batch = BatchStatement() + batch.add(self.insert_entity_stmt, (collection, s, 'S', p, otype, s, o, g, dtype, lang)) + batch.add(self.insert_entity_stmt, (collection, p, 'P', p, otype, s, o, g, dtype, lang)) + if otype == 'u' or otype == 't': + batch.add(self.insert_entity_stmt, (collection, o, 'O', p, otype, s, o, g, dtype, lang)) + if g != DEFAULT_GRAPH: + batch.add(self.insert_entity_stmt, (collection, g, 'G', p, otype, s, o, g, dtype, lang)) + batch.add(self.insert_collection_stmt, (collection, g, s, p, o, otype, dtype, lang)) + + await async_execute(self.session, batch) + + async def async_get_all(self, collection, limit=50): + return await async_execute( + self.session, self.get_collection_all_stmt, (collection, limit) + ) + + async def async_get_s(self, collection, s, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_s_stmt, (collection, s, limit) + ) + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=row.s, p=row.p, o=row.o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_p(self, collection, p, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_p_stmt, (collection, p, limit) + ) + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=row.s, p=row.p, o=row.o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_o(self, collection, o, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_o_stmt, (collection, o, limit) + ) + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=row.s, p=row.p, o=row.o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_sp(self, collection, s, p, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_s_p_stmt, (collection, s, p, limit) + ) + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=s, p=p, o=row.o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_po(self, collection, p, o, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_o_p_stmt, (collection, o, p, limit) + ) + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=row.s, p=p, o=o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_os(self, collection, o, s, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_s_stmt, (collection, s, limit) + ) + results = [] + for row in rows: + if row.o != o: + continue + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=s, p=row.p, o=o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_spo(self, collection, s, p, o, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_s_p_stmt, (collection, s, p, limit) + ) + results = [] + for row in rows: + if row.o != o: + continue + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=s, p=p, o=o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_g(self, collection, g, limit=50): + if g is None: + g = DEFAULT_GRAPH + return await async_execute( + self.session, self.get_collection_by_graph_stmt, (collection, g, limit) + ) + + async def async_collection_exists(self, collection): + try: + result = await async_execute( + self.session, + f"SELECT collection FROM {self.collection_metadata_table} WHERE collection = %s LIMIT 1", + (collection,) + ) + return bool(result) + except Exception as e: + logger.error(f"Error checking collection existence: {e}") + return False + + async def async_create_collection(self, collection): + await async_execute( + self.session, + f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)", + (collection, datetime.datetime.now()) + ) + logger.info(f"Created collection metadata for {collection}") + + async def async_delete_collection(self, collection): + rows = await async_execute( + self.session, + f"SELECT d, s, p, o, otype, dtype, lang FROM {self.collection_table} WHERE collection = %s", + (collection,) + ) + + entities = set() + quads = [] + for row in rows: + d, s, p, o = row.d, row.s, row.p, row.o + otype = row.otype + dtype = row.dtype if hasattr(row, 'dtype') else '' + lang = row.lang if hasattr(row, 'lang') else '' + quads.append((d, s, p, o, otype, dtype, lang)) + entities.add(s) + entities.add(p) + if otype == 'u' or otype == 't': + entities.add(o) + if d != DEFAULT_GRAPH: + entities.add(d) + + batch = BatchStatement() + count = 0 + for entity in entities: + batch.add(self.delete_entity_partition_stmt, (collection, entity)) + count += 1 + if count % 50 == 0: + await async_execute(self.session, batch) + batch = BatchStatement() + if count % 50 != 0: + await async_execute(self.session, batch) + + batch = BatchStatement() + count = 0 + for d, s, p, o, otype, dtype, lang in quads: + batch.add(self.delete_collection_row_stmt, (collection, d, s, p, o, otype, dtype, lang)) + count += 1 + if count % 50 == 0: + await async_execute(self.session, batch) + batch = BatchStatement() + if count % 50 != 0: + await async_execute(self.session, batch) + + await async_execute( + self.session, + f"DELETE FROM {self.collection_metadata_table} WHERE collection = %s", + (collection,) + ) + logger.info(f"Deleted collection {collection}: {len(entities)} entity partitions, {len(quads)} quads") + def close(self): """Close connections""" if hasattr(self, 'session') and self.session: diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py index 1d45d3f9..6a43e547 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py @@ -121,6 +121,7 @@ class Processor(FlowProcessor): # Configuration self.top_k = params.get("top_k", 10) self.similarity_threshold = params.get("similarity_threshold", 0.3) + self.bypass_selector_below = params.get("bypass_selector_below", 5) # Per-workspace ontology version tracking self.current_ontology_versions = {} # workspace -> version @@ -187,7 +188,8 @@ class Processor(FlowProcessor): ontology_embedder=ontology_embedder, ontology_loader=loader, top_k=self.top_k, - similarity_threshold=self.similarity_threshold + similarity_threshold=self.similarity_threshold, + bypass_selector_below=self.bypass_selector_below, ) # Store flow-specific components @@ -981,6 +983,13 @@ class Processor(FlowProcessor): default=0.3, help='Similarity threshold for ontology matching (default: 0.3, range: 0.0-1.0)' ) + parser.add_argument( + '--bypass-selector-below', + type=int, + default=5, + help='Bypass ontology selector when total ontology elements ' + '(classes + properties) is below this value (default: 5)' + ) parser.add_argument( '--triples-batch-size', type=int, diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_selector.py b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_selector.py index 5111529a..5fd60a0f 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_selector.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_selector.py @@ -33,19 +33,44 @@ class OntologySelector: def __init__(self, ontology_embedder: OntologyEmbedder, ontology_loader: OntologyLoader, top_k: int = 10, - similarity_threshold: float = 0.7): - """Initialize the ontology selector. - - Args: - ontology_embedder: Embedder with vector store - ontology_loader: Loader with ontology definitions - top_k: Number of top results to retrieve per segment - similarity_threshold: Minimum similarity score - """ + similarity_threshold: float = 0.3, + bypass_selector_below: int = 5): self.embedder = ontology_embedder self.loader = ontology_loader self.top_k = top_k self.similarity_threshold = similarity_threshold + self.bypass_selector_below = bypass_selector_below + + def _total_ontology_elements(self) -> int: + total = 0 + for ontology in self.loader.get_all_ontologies().values(): + total += len(ontology.classes) + total += len(ontology.object_properties) + total += len(ontology.datatype_properties) + return total + + def _build_full_subsets(self) -> List[OntologySubset]: + subsets = [] + for ont_id, ontology in self.loader.get_all_ontologies().items(): + subset = OntologySubset( + ontology_id=ont_id, + classes={ + cid: cls.__dict__ + for cid, cls in ontology.classes.items() + }, + object_properties={ + pid: prop.__dict__ + for pid, prop in ontology.object_properties.items() + }, + datatype_properties={ + pid: prop.__dict__ + for pid, prop in ontology.datatype_properties.items() + }, + metadata=ontology.metadata, + relevance_score=1.0, + ) + subsets.append(subset) + return subsets async def select_ontology_subset(self, segments: List[TextSegment]) -> List[OntologySubset]: """Select relevant ontology subsets for text segments. @@ -56,6 +81,15 @@ class OntologySelector: Returns: List of ontology subsets with relevant elements """ + total = self._total_ontology_elements() + if total < self.bypass_selector_below: + logger.info( + f"Ontology has {total} elements (below " + f"bypass_selector_below={self.bypass_selector_below}), " + f"using full ontology" + ) + return self._build_full_subsets() + # Collect all relevant elements relevant_elements = await self._find_relevant_elements(segments) diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/triple_converter.py b/trustgraph-flow/trustgraph/extract/kg/ontology/triple_converter.py index 06fff4f4..d9e6c837 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/triple_converter.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/triple_converter.py @@ -6,7 +6,7 @@ with full URIs and correct is_uri flags. """ import logging -from typing import List, Optional +from typing import List, Optional, Set from .... schema import Triple, Term, IRI, LITERAL from .... rdf import RDF_TYPE, RDF_LABEL @@ -32,6 +32,25 @@ class TripleConverter: self.ontology_id = ontology_id self.entity_registry = EntityRegistry(ontology_id) + def _get_ancestor_classes(self, class_id: str) -> Set[str]: + ancestors = set() + current = class_id + while current: + cls_def = self.ontology_subset.classes.get(current) + if not cls_def: + break + parent = cls_def.get("subclass_of") if isinstance(cls_def, dict) else getattr(cls_def, "subclass_of", None) + if not parent or parent in ancestors: + break + ancestors.add(parent) + current = parent + return ancestors + + def _matches_class_constraint(self, actual_type: str, expected_type: str) -> bool: + if actual_type == expected_type: + return True + return expected_type in self._get_ancestor_classes(actual_type) + def convert_all(self, extraction: ExtractionResult) -> List[Triple]: """Convert complete extraction result to RDF triples. @@ -129,6 +148,29 @@ class TripleConverter: logger.warning(f"Unknown relationship '{relationship.relation}', skipping") return None + # Enforce domain/range constraints when declared + prop_def = self.ontology_subset.object_properties.get( + relationship.relation, {} + ) + domain = prop_def.get("domain") if isinstance(prop_def, dict) else getattr(prop_def, "domain", None) + range_ = prop_def.get("range") if isinstance(prop_def, dict) else getattr(prop_def, "range", None) + + if domain and not self._matches_class_constraint(relationship.subject_type, domain): + logger.warning( + f"Domain violation: '{relationship.relation}' expects " + f"domain '{domain}', got subject type " + f"'{relationship.subject_type}', skipping" + ) + return None + + if range_ and not self._matches_class_constraint(relationship.object_type, range_): + logger.warning( + f"Range violation: '{relationship.relation}' expects " + f"range '{range_}', got object type " + f"'{relationship.object_type}', skipping" + ) + return None + # Generate triple: subject property object return Triple( s=Term(type=IRI, iri=subject_uri), @@ -157,11 +199,25 @@ class TripleConverter: logger.warning(f"Unknown attribute '{attribute.attribute}', skipping") return None + # Enforce domain constraint when declared + prop_def = self.ontology_subset.datatype_properties.get( + attribute.attribute, {} + ) + domain = prop_def.get("domain") if isinstance(prop_def, dict) else getattr(prop_def, "domain", None) + + if domain and not self._matches_class_constraint(attribute.entity_type, domain): + logger.warning( + f"Domain violation: attribute '{attribute.attribute}' " + f"expects domain '{domain}', got entity type " + f"'{attribute.entity_type}', skipping" + ) + return None + # Generate triple: entity property "literal value" return Triple( s=Term(type=IRI, iri=entity_uri), p=Term(type=IRI, iri=property_uri), - o=Term(type=LITERAL, value=attribute.value) # Literal! + o=Term(type=LITERAL, value=attribute.value) ) def _get_class_uri(self, class_id: str) -> Optional[str]: diff --git a/trustgraph-flow/trustgraph/gateway/auth.py b/trustgraph-flow/trustgraph/gateway/auth.py index 1309ecfc..273fcb5a 100644 --- a/trustgraph-flow/trustgraph/gateway/auth.py +++ b/trustgraph-flow/trustgraph/gateway/auth.py @@ -233,10 +233,10 @@ class IamAuth: header = request.headers.get("Authorization", "") if not header.startswith("Bearer "): - raise _auth_failure() + return await self._authenticate_anonymous() token = header[len("Bearer "):].strip() if not token: - raise _auth_failure() + return await self._authenticate_anonymous() # API keys always start with "tg_". JWTs have two dots and # no "tg_" prefix. Discriminate cheaply. @@ -266,6 +266,26 @@ class IamAuth: handle=sub, workspace=ws, principal_id=sub, source="jwt", ) + async def _authenticate_anonymous(self): + try: + async def _call(client): + return await client.authenticate_anonymous() + user_id, workspace, _roles = await self._with_client(_call) + except Exception as e: + logger.debug( + f"Anonymous authentication rejected: " + f"{type(e).__name__}: {e}" + ) + raise _auth_failure() + + if not user_id or not workspace: + raise _auth_failure() + + return Identity( + handle=user_id, workspace=workspace, + principal_id=user_id, source="anonymous", + ) + async def _resolve_api_key(self, plaintext): h = hashlib.sha256(plaintext.encode("utf-8")).hexdigest() diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index 51161f9b..bddb009d 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -135,13 +135,19 @@ class DispatcherWrapper: class DispatcherManager: def __init__(self, backend, config_receiver, auth, - prefix="api-gateway", queue_overrides=None): + prefix="api-gateway", queue_overrides=None, timeout=120): """ ``auth`` is required. It flows into the Mux for first-frame WebSocket authentication and into downstream dispatcher construction. There is no permissive default — constructing a DispatcherManager without an authenticator would be a silent downgrade to no-auth on the socket path. + + ``timeout`` is the per-request timeout in seconds, propagated + to every dispatcher created by this manager. Must match the + gateway's ``--timeout`` flag so that long-running requests + are not prematurely cut off at the old hard-coded 120 s + ceiling. """ if auth is None: raise ValueError( @@ -149,6 +155,8 @@ class DispatcherManager: "is no no-auth mode" ) + self.timeout = timeout + self.backend = backend self.config_receiver = config_receiver self.config_receiver.add_handler(self) @@ -291,7 +299,7 @@ class DispatcherManager: dispatcher = global_dispatchers[kind]( backend = self.backend, - timeout = 120, + timeout = self.timeout, consumer = consumer_name, subscriber = consumer_name, request_queue = request_queue, @@ -448,7 +456,7 @@ class DispatcherManager: backend = self.backend, request_queue = qconfig["request"], response_queue = qconfig["response"], - timeout = 120, + timeout = self.timeout, consumer = f"{self.prefix}-{workspace}-{flow}-{kind}-request", subscriber = f"{self.prefix}-{workspace}-{flow}-{kind}-request", ) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py index 02c0eed2..bdbd18d8 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py @@ -57,16 +57,13 @@ class Mux: (important for browsers, which treat a handshake-time 401 as terminal).""" token = data.get("token", "") - if not token: - await self.ws.send_json({ - "type": "auth-failed", - "error": "auth failure", - }) - return class _Shim: def __init__(self, tok): - self.headers = {"Authorization": f"Bearer {tok}"} + self.headers = ( + {"Authorization": f"Bearer {tok}"} if tok + else {} + ) try: identity = await self.auth.authenticate(_Shim(token)) diff --git a/trustgraph-flow/trustgraph/gateway/registry.py b/trustgraph-flow/trustgraph/gateway/registry.py index 5e3344f4..4d439097 100644 --- a/trustgraph-flow/trustgraph/gateway/registry.py +++ b/trustgraph-flow/trustgraph/gateway/registry.py @@ -457,6 +457,12 @@ for _op in ("put-kg-core", "delete-kg-core", "load-kg-core", "unload-kg-core"): _register_kind_op("knowledge", _op, "knowledge:write") +# knowledge: document-embeddings core service. +for _op in ("get-de-core", "list-de-cores"): + _register_kind_op("knowledge", _op, "knowledge:read") +for _op in ("put-de-core", "delete-de-core", "load-de-core"): + _register_kind_op("knowledge", _op, "knowledge:write") + # collection-management: workspace collection lifecycle. _register_kind_op("collection-management", "list-collections", "collections:read") diff --git a/trustgraph-flow/trustgraph/gateway/service.py b/trustgraph-flow/trustgraph/gateway/service.py index 0f6a5070..fb51e1a2 100755 --- a/trustgraph-flow/trustgraph/gateway/service.py +++ b/trustgraph-flow/trustgraph/gateway/service.py @@ -119,6 +119,7 @@ class Api: prefix = "gateway", queue_overrides = queue_overrides, auth = self.auth, + timeout = self.timeout, ) self.endpoint_manager = EndpointManager( diff --git a/trustgraph-flow/trustgraph/iam/noauth/__init__.py b/trustgraph-flow/trustgraph/iam/noauth/__init__.py new file mode 100644 index 00000000..98f4d9da --- /dev/null +++ b/trustgraph-flow/trustgraph/iam/noauth/__init__.py @@ -0,0 +1 @@ +from . service import * diff --git a/trustgraph-flow/trustgraph/iam/noauth/__main__.py b/trustgraph-flow/trustgraph/iam/noauth/__main__.py new file mode 100644 index 00000000..a731dd63 --- /dev/null +++ b/trustgraph-flow/trustgraph/iam/noauth/__main__.py @@ -0,0 +1,4 @@ + +from . service import run + +run() diff --git a/trustgraph-flow/trustgraph/iam/noauth/handler.py b/trustgraph-flow/trustgraph/iam/noauth/handler.py new file mode 100644 index 00000000..d457697e --- /dev/null +++ b/trustgraph-flow/trustgraph/iam/noauth/handler.py @@ -0,0 +1,131 @@ +""" +No-auth IAM handler. Implements the IAM contract with every operation +returning a permissive or stub response. No database, no crypto, +no state. +""" + +import json +import logging + +from trustgraph.schema import IamResponse, Error, UserRecord + +logger = logging.getLogger(__name__) + + +def _err(type, message): + return IamResponse(error=Error(type=type, message=message)) + + +class NoAuthHandler: + + def __init__(self, default_user_id="anonymous", + default_workspace="default", + on_workspace_created=None): + self.default_user_id = default_user_id + self.default_workspace = default_workspace + self._on_workspace_created = on_workspace_created + + def _default_identity_response(self): + return IamResponse( + resolved_user_id=self.default_user_id, + resolved_workspace=self.default_workspace, + resolved_roles=["admin"], + ) + + def _default_user_record(self): + return UserRecord( + id=self.default_user_id, + workspace=self.default_workspace, + username=self.default_user_id, + name="Anonymous User", + roles=["admin"], + enabled=True, + ) + + async def handle(self, v): + op = v.operation + + try: + if op == "authenticate-anonymous": + return self._default_identity_response() + + if op == "resolve-api-key": + return self._default_identity_response() + + if op == "authorise": + return IamResponse( + decision_allow=True, + decision_ttl_seconds=3600, + ) + + if op == "authorise-many": + checks = json.loads(v.authorise_checks or "[]") + decisions = [ + {"allow": True, "ttl": 3600} + for _ in checks + ] + return IamResponse( + decisions_json=json.dumps(decisions), + ) + + if op == "get-signing-key-public": + return IamResponse(signing_key_public="") + + if op == "bootstrap": + return IamResponse() + + if op == "bootstrap-status": + return IamResponse(bootstrap_available=False) + + if op == "whoami": + return IamResponse(user=self._default_user_record()) + + if op == "login": + return IamResponse() + + if op in ( + "create-user", "get-user", "update-user", + "disable-user", "enable-user", + ): + return IamResponse(user=self._default_user_record()) + + if op == "list-users": + return IamResponse(users=[self._default_user_record()]) + + if op == "delete-user": + return IamResponse() + + if op == "create-workspace": + if self._on_workspace_created and v.workspace_record: + await self._on_workspace_created(v.workspace_record.id) + return IamResponse() + + if op in ( + "get-workspace", "update-workspace", + "disable-workspace", + ): + return IamResponse() + + if op == "list-workspaces": + return IamResponse() + + if op in ("create-api-key", "list-api-keys", "revoke-api-key"): + return IamResponse() + + if op in ("change-password", "reset-password"): + return IamResponse() + + if op == "rotate-signing-key": + return IamResponse() + + return _err( + "invalid-argument", + f"unknown operation: {op!r}", + ) + + except Exception as e: + logger.error( + f"no-auth {op} failed: {type(e).__name__}: {e}", + exc_info=True, + ) + return _err("internal-error", str(e)) diff --git a/trustgraph-flow/trustgraph/iam/noauth/service.py b/trustgraph-flow/trustgraph/iam/noauth/service.py new file mode 100644 index 00000000..76d13a3c --- /dev/null +++ b/trustgraph-flow/trustgraph/iam/noauth/service.py @@ -0,0 +1,182 @@ +""" +No-auth IAM service. Drop-in replacement for iam-svc that permits +all access unconditionally. No database, no bootstrap, no signing keys. +""" + +import logging +import uuid + +from trustgraph.schema import Error +from trustgraph.schema import IamRequest, IamResponse +from trustgraph.schema import iam_request_queue, iam_response_queue +from trustgraph.schema import ConfigRequest, ConfigResponse, ConfigValue +from trustgraph.schema import config_request_queue, config_response_queue + +from trustgraph.base import AsyncProcessor, Consumer, Producer +from trustgraph.base import ConsumerMetrics, ProducerMetrics +from trustgraph.base.metrics import SubscriberMetrics +from trustgraph.base.request_response_spec import RequestResponse + +from . handler import NoAuthHandler + +logger = logging.getLogger(__name__) + +default_ident = "no-auth-svc" + +default_iam_request_queue = iam_request_queue +default_iam_response_queue = iam_response_queue + + +class Processor(AsyncProcessor): + + def __init__(self, **params): + + iam_req_q = params.get( + "iam_request_queue", default_iam_request_queue, + ) + iam_resp_q = params.get( + "iam_response_queue", default_iam_response_queue, + ) + + default_user_id = params.get("default_user_id", "anonymous") + default_workspace = params.get("default_workspace", "default") + + super().__init__(**params) + + iam_request_metrics = ConsumerMetrics( + processor=self.id, flow=None, name="iam-request", + ) + iam_response_metrics = ProducerMetrics( + processor=self.id, flow=None, name="iam-response", + ) + + self.iam_request_topic = iam_req_q + + self.iam_request_consumer = Consumer( + taskgroup=self.taskgroup, + backend=self.pubsub, + flow=None, + topic=iam_req_q, + subscriber=self.id, + schema=IamRequest, + handler=self.on_iam_request, + metrics=iam_request_metrics, + ) + + self.iam_response_producer = Producer( + backend=self.pubsub, + topic=iam_resp_q, + schema=IamResponse, + metrics=iam_response_metrics, + ) + + self.handler = NoAuthHandler( + default_user_id=default_user_id, + default_workspace=default_workspace, + on_workspace_created=self._ensure_workspace_registered, + ) + + logger.info( + f"No-auth IAM service initialised " + f"(user={default_user_id}, workspace={default_workspace})" + ) + + async def start(self): + await self.pubsub.ensure_topic(self.iam_request_topic) + await self.iam_request_consumer.start() + + def _create_config_client(self): + config_rr_id = str(uuid.uuid4()) + config_req_metrics = ProducerMetrics( + processor=self.id, flow=None, name="config-request", + ) + config_resp_metrics = SubscriberMetrics( + processor=self.id, flow=None, name="config-response", + ) + return RequestResponse( + backend=self.pubsub, + subscription=f"{self.id}--config--{config_rr_id}", + consumer_name=self.id, + request_topic=config_request_queue, + request_schema=ConfigRequest, + request_metrics=config_req_metrics, + response_topic=config_response_queue, + response_schema=ConfigResponse, + response_metrics=config_resp_metrics, + ) + + async def _ensure_workspace_registered(self, workspace_id): + client = self._create_config_client() + try: + await client.start() + await client.request( + ConfigRequest( + operation="put", + workspace="__workspaces__", + values=[ConfigValue( + type="workspace", key=workspace_id, + value='{"enabled": true}', + )], + ), + timeout=10, + ) + finally: + await client.stop() + logger.info( + f"Registered workspace in config: {workspace_id}" + ) + + async def on_iam_request(self, msg, consumer, flow): + + id = None + try: + v = msg.value() + id = msg.properties()["id"] + logger.debug( + f"Handling IAM request {id} op={v.operation!r}" + ) + resp = await self.handler.handle(v) + await self.iam_response_producer.send( + resp, properties={"id": id}, + ) + except Exception as e: + logger.error( + f"IAM request failed: {type(e).__name__}: {e}", + exc_info=True, + ) + resp = IamResponse( + error=Error(type="internal-error", message=str(e)), + ) + if id is not None: + await self.iam_response_producer.send( + resp, properties={"id": id}, + ) + + @staticmethod + def add_args(parser): + AsyncProcessor.add_args(parser) + + parser.add_argument( + "--iam-request-queue", + default=default_iam_request_queue, + help=f"IAM request queue (default: {default_iam_request_queue})", + ) + parser.add_argument( + "--iam-response-queue", + default=default_iam_response_queue, + help=f"IAM response queue (default: {default_iam_response_queue})", + ) + parser.add_argument( + "--default-user-id", + default="anonymous", + help="User ID for all requests (default: anonymous)", + ) + parser.add_argument( + "--default-workspace", + default="default", + help="Workspace for all requests (default: default)", + ) + + +def run(): + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/iam/service/iam.py b/trustgraph-flow/trustgraph/iam/service/iam.py index 755a1c5d..0335012e 100644 --- a/trustgraph-flow/trustgraph/iam/service/iam.py +++ b/trustgraph-flow/trustgraph/iam/service/iam.py @@ -287,6 +287,9 @@ class IamService: op = v.operation try: + if op == "authenticate-anonymous": + return _err("auth-failed", "anonymous access not permitted") + if op == "bootstrap": return await self.handle_bootstrap(v) if op == "bootstrap-status": @@ -394,8 +397,8 @@ class IamService: async def auto_bootstrap_if_token_mode(self): """Called from the service processor at startup. In - ``token`` mode, if tables are empty, seeds the default - workspace / admin / signing key using the operator-provided + ``token`` mode, if tables are empty, seeds the admin user, + API key, and signing key using the operator-provided bootstrap token. The admin's API key plaintext is *the* ``bootstrap_token`` — the operator already knows it, nothing needs to be returned or logged. @@ -405,7 +408,7 @@ class IamService: if self.bootstrap_mode != "token": return - if await self.table_store.any_workspace_exists(): + if await self.table_store.any_signing_key_exists(): logger.info( "IAM: token mode, tables already populated; skipping " "auto-bootstrap" @@ -420,22 +423,13 @@ class IamService: async def _seed_tables(self, api_key_plaintext): """Shared seeding logic used by token-mode auto-bootstrap and - bootstrap-mode handle_bootstrap. Creates the default - workspace, admin user, admin API key (using the given - plaintext), and an initial signing key. Returns the admin + bootstrap-mode handle_bootstrap. Creates the admin user, + admin API key (using the given plaintext), and an initial + signing key. The workspace is created separately by the + bootstrapper's WorkspaceInit initialiser. Returns the admin user id.""" now = _now_dt() - await self.table_store.put_workspace( - id=DEFAULT_WORKSPACE, - name="Default", - enabled=True, - created=now, - ) - - if self._on_workspace_created: - await self._on_workspace_created(DEFAULT_WORKSPACE) - admin_user_id = str(uuid.uuid4()) admin_password = secrets.token_urlsafe(32) await self.table_store.put_user( @@ -488,7 +482,7 @@ class IamService: if self.bootstrap_mode != "bootstrap": return _err("auth-failed", "auth failure") - if await self.table_store.any_workspace_exists(): + if await self.table_store.any_signing_key_exists(): return _err("auth-failed", "auth failure") plaintext = _generate_api_key() @@ -528,7 +522,7 @@ class IamService: instead of forcing callers to probe the masked-failure path.""" available = ( self.bootstrap_mode == "bootstrap" - and not await self.table_store.any_workspace_exists() + and not await self.table_store.any_signing_key_exists() ) return IamResponse(bootstrap_available=available) diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py index 1cefcbe9..c8ab9c36 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py @@ -104,7 +104,15 @@ class Processor(LlmService): return resp - except RateLimitError: + except RateLimitError as e: + try: + body = getattr(e, 'body', {}) + if isinstance(body, dict): + code = body.get('error', {}).get('code') + if code in ('insufficient_quota', 'invalid_api_key', 'account_deactivated'): + raise RuntimeError(f"OpenAI unrecoverable error: {code} - {body['error'].get('message', '')}") + except (ValueError, KeyError, TypeError, AttributeError): + pass # Leave rate limit retries to the base handler raise TooManyRequests() @@ -188,7 +196,16 @@ class Processor(LlmService): logger.debug("Streaming complete") - except RateLimitError: + except RateLimitError as e: + try: + body = getattr(e, 'body', {}) + if isinstance(body, dict): + code = body.get('error', {}).get('code') + if code in ('insufficient_quota', 'invalid_api_key', 'account_deactivated'): + logger.warning(f"Hit unrecoverable rate limit error during streaming: {code}") + raise RuntimeError(f"OpenAI unrecoverable error: {code} - {body['error'].get('message', '')}") + except (ValueError, KeyError, TypeError, AttributeError): + pass logger.warning("Hit rate limit during streaming") raise TooManyRequests() diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py index 1d59c835..f6770744 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py @@ -4,11 +4,10 @@ Document embeddings query service. Input is vector, output is an array of chunk_ids """ +import asyncio import logging from qdrant_client import QdrantClient -from qdrant_client.models import PointStruct -from qdrant_client.models import Distance, VectorParams from .... schema import DocumentEmbeddingsResponse, ChunkMatch from .... schema import Error @@ -38,32 +37,6 @@ class Processor(DocumentEmbeddingsQueryService): ) self.qdrant = QdrantClient(url=store_uri, api_key=api_key) - self.last_collection = None - - def ensure_collection_exists(self, collection, dim): - """Ensure collection exists, create if it doesn't""" - if collection != self.last_collection: - if not self.qdrant.collection_exists(collection): - try: - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, distance=Distance.COSINE - ), - ) - logger.info(f"Created collection: {collection}") - except Exception as e: - logger.error(f"Qdrant collection creation failed: {e}") - raise e - self.last_collection = collection - - def collection_exists(self, collection): - """Check if collection exists (no implicit creation)""" - return self.qdrant.collection_exists(collection) - - def collection_exists(self, collection): - """Check if collection exists (no implicit creation)""" - return self.qdrant.collection_exists(collection) async def query_document_embeddings(self, workspace, msg): @@ -73,21 +46,24 @@ class Processor(DocumentEmbeddingsQueryService): if not vec: return [] - # Use dimension suffix in collection name dim = len(vec) collection = f"d_{workspace}_{msg.collection}_{dim}" - # Check if collection exists - return empty if not - if not self.collection_exists(collection): + exists = await asyncio.to_thread( + self.qdrant.collection_exists, collection + ) + if not exists: logger.info(f"Collection {collection} does not exist, returning empty results") return [] - search_result = self.qdrant.query_points( + result = await asyncio.to_thread( + self.qdrant.query_points, collection_name=collection, query=vec, limit=msg.limit, with_payload=True, - ).points + ) + search_result = result.points chunks = [] for r in search_result: diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py index b8fb1361..167130c9 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py @@ -4,11 +4,10 @@ Graph embeddings query service. Input is vector, output is list of entities """ +import asyncio import logging from qdrant_client import QdrantClient -from qdrant_client.models import PointStruct -from qdrant_client.models import Distance, VectorParams from .... schema import GraphEmbeddingsResponse, EntityMatch from .... schema import Error, Term, IRI, LITERAL @@ -38,32 +37,6 @@ class Processor(GraphEmbeddingsQueryService): ) self.qdrant = QdrantClient(url=store_uri, api_key=api_key) - self.last_collection = None - - def ensure_collection_exists(self, collection, dim): - """Ensure collection exists, create if it doesn't""" - if collection != self.last_collection: - if not self.qdrant.collection_exists(collection): - try: - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, distance=Distance.COSINE - ), - ) - logger.info(f"Created collection: {collection}") - except Exception as e: - logger.error(f"Qdrant collection creation failed: {e}") - raise e - self.last_collection = collection - - def collection_exists(self, collection): - """Check if collection exists (no implicit creation)""" - return self.qdrant.collection_exists(collection) - - def collection_exists(self, collection): - """Check if collection exists (no implicit creation)""" - return self.qdrant.collection_exists(collection) def create_value(self, ent): if ent.startswith("http://") or ent.startswith("https://"): @@ -79,23 +52,26 @@ class Processor(GraphEmbeddingsQueryService): if not vec: return [] - # Use dimension suffix in collection name dim = len(vec) collection = f"t_{workspace}_{msg.collection}_{dim}" - # Check if collection exists - return empty if not - if not self.collection_exists(collection): + exists = await asyncio.to_thread( + self.qdrant.collection_exists, collection + ) + if not exists: logger.info(f"Collection {collection} does not exist") return [] # Heuristic hack, get (2*limit), so that we have more chance # of getting (limit) unique entities - search_result = self.qdrant.query_points( + result = await asyncio.to_thread( + self.qdrant.query_points, collection_name=collection, query=vec, limit=msg.limit * 2, with_payload=True, - ).points + ) + search_result = result.points entity_set = set() entities = [] diff --git a/trustgraph-flow/trustgraph/query/ontology/__init__.py b/trustgraph-flow/trustgraph/query/ontology/__init__.py index 60557ea9..c5cddd9c 100644 --- a/trustgraph-flow/trustgraph/query/ontology/__init__.py +++ b/trustgraph-flow/trustgraph/query/ontology/__init__.py @@ -7,7 +7,7 @@ Provides semantic query understanding, ontology matching, and answer generation. from .query_service import OntoRAGQueryService, QueryRequest, QueryResponse from .question_analyzer import QuestionAnalyzer, QuestionComponents, QuestionType -from .ontology_matcher import OntologyMatcher, QueryOntologySubset +from .ontology_matcher import OntologyMatcherForQueries, QueryOntologySubset from .backend_router import BackendRouter, BackendType, QueryRoute from .sparql_generator import SPARQLGenerator, SPARQLQuery from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult @@ -27,7 +27,7 @@ __all__ = [ 'QuestionType', # Ontology matching - 'OntologyMatcher', + 'OntologyMatcherForQueries', 'QueryOntologySubset', # Backend routing diff --git a/trustgraph-flow/trustgraph/query/ontology/monitoring.py b/trustgraph-flow/trustgraph/query/ontology/monitoring.py index 703c6e95..cb7e8a2e 100644 --- a/trustgraph-flow/trustgraph/query/ontology/monitoring.py +++ b/trustgraph-flow/trustgraph/query/ontology/monitoring.py @@ -4,6 +4,7 @@ Provides comprehensive monitoring of system performance, query patterns, and res """ import logging +import re import time import asyncio import inspect @@ -276,6 +277,26 @@ class MetricsCollector: return f"{name}{{{label_str}}}" +def _extract_metric_label(metric_name: str, label: str) -> Optional[str]: + """Extract a label value from an internal metric key.""" + labels_start = metric_name.find('{') + labels_end = metric_name.find('}', labels_start + 1) + + if labels_start == -1 or labels_end == -1: + return None + + labels = metric_name[labels_start + 1:labels_end] + label_match = re.search( + rf'(?:^|,){re.escape(label)}=(?:"([^"]*)"|([^,]*))', + labels, + ) + if not label_match: + return None + + quoted_value, unquoted_value = label_match.groups() + return quoted_value if quoted_value is not None else unquoted_value + + class PerformanceMonitor: """Monitors system performance and component health.""" @@ -474,8 +495,8 @@ class PerformanceMonitor: # Cache performance cache_types = set() for metric_name in self.metrics_collector.counters.keys(): - if 'cache_type=' in metric_name: - cache_type = metric_name.split('cache_type=')[1].split(',')[0].split('}')[0] + cache_type = _extract_metric_label(metric_name, 'cache_type') + if cache_type is not None: cache_types.add(cache_type) for cache_type in cache_types: diff --git a/trustgraph-flow/trustgraph/query/ontology/ontology_matcher.py b/trustgraph-flow/trustgraph/query/ontology/ontology_matcher.py index 895856f3..2dd6633a 100644 --- a/trustgraph-flow/trustgraph/query/ontology/ontology_matcher.py +++ b/trustgraph-flow/trustgraph/query/ontology/ontology_matcher.py @@ -7,10 +7,10 @@ import logging from typing import List, Dict, Any, Set, Optional from dataclasses import dataclass -from ...extract.kg.ontology.ontology_loader import Ontology, OntologyLoader -from ...extract.kg.ontology.ontology_embedder import OntologyEmbedder -from ...extract.kg.ontology.text_processor import TextSegment -from ...extract.kg.ontology.ontology_selector import OntologySelector, OntologySubset +from trustgraph.extract.kg.ontology.ontology_loader import Ontology, OntologyLoader +from trustgraph.extract.kg.ontology.ontology_embedder import OntologyEmbedder +from trustgraph.extract.kg.ontology.text_processor import TextSegment +from trustgraph.extract.kg.ontology.ontology_selector import OntologySelector, OntologySubset from .question_analyzer import QuestionComponents, QuestionType logger = logging.getLogger(__name__) diff --git a/trustgraph-flow/trustgraph/query/ontology/query_service.py b/trustgraph-flow/trustgraph/query/ontology/query_service.py index c6057cc1..77e60b50 100644 --- a/trustgraph-flow/trustgraph/query/ontology/query_service.py +++ b/trustgraph-flow/trustgraph/query/ontology/query_service.py @@ -8,13 +8,13 @@ from typing import Dict, Any, List, Optional, Union from dataclasses import dataclass from datetime import datetime -from ....flow.flow_processor import FlowProcessor -from ....tables.config import ConfigTableStore -from ...extract.kg.ontology.ontology_loader import OntologyLoader -from ...extract.kg.ontology.vector_store import InMemoryVectorStore +from trustgraph.base.flow_processor import FlowProcessor +from trustgraph.tables.config import ConfigTableStore +from trustgraph.extract.kg.ontology.ontology_loader import OntologyLoader +from trustgraph.extract.kg.ontology.vector_store import InMemoryVectorStore from .question_analyzer import QuestionAnalyzer, QuestionComponents -from .ontology_matcher import OntologyMatcher, QueryOntologySubset +from .ontology_matcher import OntologyMatcherForQueries, QueryOntologySubset from .backend_router import BackendRouter, QueryRoute, BackendType from .sparql_generator import SPARQLGenerator, SPARQLQuery from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult @@ -105,7 +105,7 @@ class OntoRAGQueryService(FlowProcessor): # Initialize ontology matcher matcher_config = self.config.get('ontology_matcher', {}) - self.ontology_matcher = OntologyMatcher( + self.ontology_matcher = OntologyMatcherForQueries( vector_store=self.vector_store, embedding_service=self.embedding_service, config=matcher_config diff --git a/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py b/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py index 688e7371..b7f0f423 100644 --- a/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py +++ b/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py @@ -28,7 +28,7 @@ try: except ImportError: CASSANDRA_AVAILABLE = False -from ....tables.config import ConfigTableStore +from trustgraph.tables.config import ConfigTableStore logger = logging.getLogger(__name__) diff --git a/trustgraph-flow/trustgraph/query/ontology/sparql_generator.py b/trustgraph-flow/trustgraph/query/ontology/sparql_generator.py index 44c7e0a1..97fc5f4d 100644 --- a/trustgraph-flow/trustgraph/query/ontology/sparql_generator.py +++ b/trustgraph-flow/trustgraph/query/ontology/sparql_generator.py @@ -202,11 +202,14 @@ ASK {{ if response and isinstance(response, dict): query = response.get('query', '').strip() - if query.upper().startswith(('SELECT', 'ASK', 'CONSTRUCT', 'DESCRIBE')): + parts = query.split() + if parts and parts[0].upper() in ( + 'SELECT', 'ASK', 'CONSTRUCT', 'DESCRIBE', + ): return SPARQLQuery( query=query, variables=self._extract_variables(query), - query_type=query.split()[0].upper(), + query_type=parts[0].upper(), explanation=response.get('explanation', 'Generated by LLM'), complexity_score=self._calculate_complexity(query) ) diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py index dd89a8d8..1534c044 100644 --- a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py @@ -6,6 +6,7 @@ Output is matching row index information (index_name, index_value) for use in subsequent Cassandra lookups. """ +import asyncio import logging import re from typing import Optional @@ -70,7 +71,7 @@ class Processor(FlowProcessor): safe_name = 'r_' + safe_name return safe_name.lower() - def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]: + async def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]: """Find the Qdrant collection for a given workspace/collection/schema""" prefix = ( f"rows_{self.sanitize_name(workspace)}_" @@ -78,14 +79,15 @@ class Processor(FlowProcessor): ) try: - all_collections = self.qdrant.get_collections().collections + all_collections = await asyncio.to_thread( + lambda: self.qdrant.get_collections().collections + ) matching = [ coll.name for coll in all_collections if coll.name.startswith(prefix) ] if matching: - # Return first match (there should typically be only one per dimension) return matching[0] except Exception as e: @@ -100,8 +102,7 @@ class Processor(FlowProcessor): if not vec: return [] - # Find the collection for this workspace/collection/schema - qdrant_collection = self.find_collection( + qdrant_collection = await self.find_collection( workspace, request.collection, request.schema_name ) @@ -113,7 +114,6 @@ class Processor(FlowProcessor): return [] try: - # Build optional filter for index_name query_filter = None if request.index_name: query_filter = Filter( @@ -125,16 +125,16 @@ class Processor(FlowProcessor): ] ) - # Query Qdrant - search_result = self.qdrant.query_points( + result = await asyncio.to_thread( + self.qdrant.query_points, collection_name=qdrant_collection, query=vec, limit=request.limit, with_payload=True, query_filter=query_filter, - ).points + ) + search_result = result.points - # Convert to RowIndexMatch objects matches = [] for point in search_result: payload = point.payload or {} diff --git a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py index 73cfcd83..7157daae 100644 --- a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py @@ -11,6 +11,7 @@ Queries against the unified 'rows' table with schema: - source: text """ +import asyncio import json import logging import re @@ -97,34 +98,38 @@ class Processor(FlowProcessor): # Cassandra session self.cluster = None self.session = None + self._setup_lock = asyncio.Lock() # Known keyspaces self.known_keyspaces: Set[str] = set() - def connect_cassandra(self): + async def connect_cassandra(self): """Connect to Cassandra cluster""" - if self.session: - return + async with self._setup_lock: + if self.session: + return - try: - if self.cassandra_username and self.cassandra_password: - auth_provider = PlainTextAuthProvider( - username=self.cassandra_username, - password=self.cassandra_password - ) - self.cluster = Cluster( - contact_points=self.cassandra_host, - auth_provider=auth_provider - ) - else: - self.cluster = Cluster(contact_points=self.cassandra_host) + try: + if self.cassandra_username and self.cassandra_password: + auth_provider = PlainTextAuthProvider( + username=self.cassandra_username, + password=self.cassandra_password + ) + cluster = Cluster( + contact_points=self.cassandra_host, + auth_provider=auth_provider + ) + else: + cluster = Cluster(contact_points=self.cassandra_host) - self.session = self.cluster.connect() - logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}") + session = await asyncio.to_thread(cluster.connect) + self.cluster = cluster + self.session = session + logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}") - except Exception as e: - logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True) - raise + except Exception as e: + logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True) + raise def sanitize_name(self, name: str) -> str: """Sanitize names for Cassandra compatibility""" @@ -140,14 +145,17 @@ class Processor(FlowProcessor): f"for workspace {workspace}" ) - # Replace existing schemas for this workspace + async with self._setup_lock: + await self._apply_schema_config(workspace, config) + + async def _apply_schema_config(self, workspace, config): + ws_schemas: Dict[str, RowSchema] = {} self.schemas[workspace] = ws_schemas builder = GraphQLSchemaBuilder() self.schema_builders[workspace] = builder - # Check if our config type exists if self.config_key not in config: logger.warning( f"No '{self.config_key}' type in configuration " @@ -156,16 +164,12 @@ class Processor(FlowProcessor): self.graphql_schemas[workspace] = None return - # Get the schemas dictionary for our type schemas_config = config[self.config_key] - # Process each schema in the schemas config for schema_name, schema_json in schemas_config.items(): try: - # Parse the JSON schema definition schema_def = json.loads(schema_json) - # Create Field objects fields = [] for field_def in schema_def.get("fields", []): field = SchemaField( @@ -180,7 +184,6 @@ class Processor(FlowProcessor): ) fields.append(field) - # Create RowSchema row_schema = RowSchema( name=schema_def.get("name", schema_name), description=schema_def.get("description", ""), @@ -202,7 +205,6 @@ class Processor(FlowProcessor): f"{len(ws_schemas)} schemas" ) - # Regenerate GraphQL schema for this workspace self.graphql_schemas[workspace] = builder.build(self.query_cassandra) def get_index_names(self, schema: RowSchema) -> List[str]: @@ -254,7 +256,7 @@ class Processor(FlowProcessor): For other queries, we need to scan and post-filter. """ # Connect if needed - self.connect_cassandra() + await self.connect_cassandra() safe_keyspace = self.sanitize_name(workspace) diff --git a/trustgraph-flow/trustgraph/query/sparql/algebra.py b/trustgraph-flow/trustgraph/query/sparql/algebra.py index bff9a336..c7542577 100644 --- a/trustgraph-flow/trustgraph/query/sparql/algebra.py +++ b/trustgraph-flow/trustgraph/query/sparql/algebra.py @@ -4,6 +4,10 @@ SPARQL algebra evaluator. Recursively evaluates an rdflib SPARQL algebra tree by issuing triple pattern queries via TriplesClient (streaming) and performing in-memory joins, filters, and projections. + +Handlers are async generators that yield solutions incrementally. +Blocking operators (joins, sort, group, distinct) materialise their +upstream into a list at the boundary, then yield results. """ import logging @@ -17,7 +21,7 @@ from ... knowledge import Uri from ... knowledge import Literal as KgLiteral from . parser import rdflib_term_to_term from . solutions import ( - hash_join, left_join, union, project, distinct, + hash_join, left_join, minus, union, project, distinct, order_by, slice_solutions, _term_key, ) from . expressions import evaluate_expression, _effective_boolean @@ -30,61 +34,60 @@ class EvaluationError(Exception): pass -async def evaluate(node, triples_client, workspace, collection, limit=10000): +async def evaluate(node, triples_client, collection, limit=10000): """ Evaluate a SPARQL algebra node. - Args: - node: rdflib CompValue algebra node - triples_client: TriplesClient instance for triple pattern queries - workspace: workspace/keyspace identifier - collection: collection identifier - limit: safety limit on results - - Returns: - list of solutions (dicts mapping variable names to Term values) + Yields solutions (dicts mapping variable names to Term values) + incrementally as an async generator. """ if not isinstance(node, CompValue): logger.warning(f"Expected CompValue, got {type(node)}: {node}") - return [{}] + yield {} + return name = node.name handler = _HANDLERS.get(name) if handler is None: logger.warning(f"Unsupported algebra node: {name}") - return [{}] + yield {} + return - return await handler(node, triples_client, workspace, collection, limit) + async for sol in handler(node, triples_client, collection, limit): + yield sol -# --- Node handlers --- - -async def _eval_select_query(node, tc, workspace, collection, limit): - """Evaluate a SelectQuery node.""" - return await evaluate(node.p, tc, workspace, collection, limit) +async def materialise(node, triples_client, collection, limit=10000): + """Collect all solutions from evaluate() into a list.""" + return [sol async for sol in evaluate(node, triples_client, collection, limit)] -async def _eval_project(node, tc, workspace, collection, limit): - """Evaluate a Project node (SELECT variable projection).""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) +# --- Node handlers (async generators) --- + +async def _eval_select_query(node, tc, collection, limit): + async for sol in evaluate(node.p, tc, collection, limit): + yield sol + + +async def _eval_project(node, tc, collection, limit): variables = [str(v) for v in node.PV] - return project(solutions, variables) + async for sol in evaluate(node.p, tc, collection, limit): + yield {v: sol[v] for v in variables if v in sol} -async def _eval_bgp(node, tc, workspace, collection, limit): +async def _eval_bgp(node, tc, collection, limit): """ Evaluate a Basic Graph Pattern. - Issues streaming triple pattern queries and joins results. Patterns - are ordered by selectivity (more bound terms first) and evaluated - sequentially with bound-variable substitution. + Patterns are ordered by selectivity and evaluated sequentially. + For the final pattern, results stream directly from the triple store. """ triples = node.triples if not triples: - return [{}] + yield {} + return - # Sort patterns by selectivity: more bound terms = more selective def selectivity(pattern): return sum(1 for t in pattern if not isinstance(t, Variable)) @@ -92,55 +95,222 @@ async def _eval_bgp(node, tc, workspace, collection, limit): enumerate(triples), key=lambda x: -selectivity(x[1]) ) + # For all patterns except the last, we must materialise intermediate + # solutions because each pattern depends on bindings from prior ones. + # The last pattern streams directly. solutions = [{}] - for _, pattern in sorted_patterns: + for pattern_idx, (_, pattern) in enumerate(sorted_patterns): s_tmpl, p_tmpl, o_tmpl = pattern + is_last = (pattern_idx == len(sorted_patterns) - 1) - new_solutions = [] + if is_last: + # Stream the final pattern — yield as triples arrive + count = 0 + for sol in solutions: + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) - for sol in solutions: - # Substitute known bindings into the pattern - s_val = _resolve_term(s_tmpl, sol) - p_val = _resolve_term(p_tmpl, sol) - o_val = _resolve_term(o_tmpl, sol) + async for triple in tc.query_gen( + s=s_val, p=p_val, o=o_val, + limit=limit, collection=collection, + ): + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + yield binding + count += 1 + if count >= limit: + return + else: + # Materialise intermediate patterns + new_solutions = [] + for sol in solutions: + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) - # Query the triples store - results = await _query_pattern( - tc, s_val, p_val, o_val, workspace, collection, limit - ) + async for triple in tc.query_gen( + s=s_val, p=p_val, o=o_val, + limit=limit, collection=collection, + ): + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + new_solutions.append(binding) - # Map results back to variable bindings, - # converting Uri/Literal to Term objects - for triple in results: - binding = dict(sol) - if isinstance(s_tmpl, Variable): - binding[str(s_tmpl)] = _to_term(triple.s) - if isinstance(p_tmpl, Variable): - binding[str(p_tmpl)] = _to_term(triple.p) - if isinstance(o_tmpl, Variable): - binding[str(o_tmpl)] = _to_term(triple.o) - new_solutions.append(binding) - - solutions = new_solutions - - if not solutions: - break - - return solutions[:limit] + solutions = new_solutions + if not solutions: + return -async def _eval_join(node, tc, workspace, collection, limit): - """Evaluate a Join node.""" - left = await evaluate(node.p1, tc, workspace, collection, limit) - right = await evaluate(node.p2, tc, workspace, collection, limit) - return hash_join(left, right)[:limit] +# --- Blocking operators: materialise upstream, then yield --- + +def _is_small_node(node): + """Check if a node is likely to produce a small number of solutions.""" + if not isinstance(node, CompValue): + return False + if node.name in ("values", "ToMultiSet"): + return True + if node.name == "Extend" and hasattr(node, "p"): + return _is_small_node(node.p) + return False -async def _eval_left_join(node, tc, workspace, collection, limit): - """Evaluate a LeftJoin node (OPTIONAL).""" - left_sols = await evaluate(node.p1, tc, workspace, collection, limit) - right_sols = await evaluate(node.p2, tc, workspace, collection, limit) +async def _eval_join(node, tc, collection, limit): + # Bind join: if one side is small (e.g. VALUES), materialise it and + # substitute its bindings into the other side's evaluation. This + # turns wildcard BGP queries into selective ones. + if _is_small_node(node.p1): + yield_from = _bind_join(node.p1, node.p2, tc, collection, limit) + elif _is_small_node(node.p2): + yield_from = _bind_join(node.p2, node.p1, tc, collection, limit) + else: + yield_from = _hash_join(node, tc, collection, limit) + + async for sol in yield_from: + yield sol + + +async def _hash_join(node, tc, collection, limit): + left = await materialise(node.p1, tc, collection, limit) + right = await materialise(node.p2, tc, collection, limit) + for sol in hash_join(left, right)[:limit]: + yield sol + + +async def _bind_join(small_node, big_node, tc, collection, limit): + """Iterate over the small side and inject bindings into the big side.""" + small_sols = await materialise(small_node, tc, collection, limit) + + count = 0 + for binding in small_sols: + async for sol in _evaluate_with_bindings( + big_node, binding, tc, collection, limit + ): + yield sol + count += 1 + if count >= limit: + return + + +def _merge_compatible(left, right): + """Merge two solutions if compatible (shared vars have equal values).""" + merged = dict(left) + for k, v in right.items(): + if k in merged: + if _term_key(merged[k]) != _term_key(v): + return None + else: + merged[k] = v + return merged + + +async def _evaluate_with_bindings(node, bindings, tc, collection, limit): + """Evaluate a node with pre-seeded variable bindings. + + For BGP nodes, the bindings are injected so _resolve_term sees them, + turning wildcard queries into selective ones. For other node types, + evaluate normally and merge/filter against the bindings. + """ + if isinstance(node, CompValue) and node.name == "BGP": + async for sol in _eval_bgp_with_bindings( + node, bindings, tc, collection, limit + ): + yield sol + else: + async for sol in evaluate(node, tc, collection, limit): + merged = _merge_compatible(bindings, sol) + if merged is not None: + yield merged + + +async def _eval_bgp_with_bindings(node, bindings, tc, collection, limit): + """Evaluate a BGP with pre-seeded bindings so variables resolve to terms.""" + triples = node.triples + if not triples: + yield dict(bindings) + return + + def selectivity(pattern): + score = 0 + for t in pattern: + if not isinstance(t, Variable): + score += 1 + elif str(t) in bindings: + score += 1 + return score + + sorted_patterns = sorted( + enumerate(triples), key=lambda x: -selectivity(x[1]) + ) + + solutions = [dict(bindings)] + + for pattern_idx, (_, pattern) in enumerate(sorted_patterns): + s_tmpl, p_tmpl, o_tmpl = pattern + is_last = (pattern_idx == len(sorted_patterns) - 1) + + if is_last: + count = 0 + for sol in solutions: + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) + + async for triple in tc.query_gen( + s=s_val, p=p_val, o=o_val, + limit=limit, collection=collection, + ): + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + yield binding + count += 1 + if count >= limit: + return + else: + new_solutions = [] + for sol in solutions: + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) + + async for triple in tc.query_gen( + s=s_val, p=p_val, o=o_val, + limit=limit, collection=collection, + ): + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + new_solutions.append(binding) + + solutions = new_solutions + if not solutions: + return + + +async def _eval_left_join(node, tc, collection, limit): + # Buffer right side for hash index; stream left through probe + left_sols = await materialise(node.p1, tc, collection, limit) + right_sols = await materialise(node.p2, tc, collection, limit) filter_fn = None if hasattr(node, "expr") and node.expr is not None: @@ -150,42 +320,35 @@ async def _eval_left_join(node, tc, workspace, collection, limit): evaluate_expression(expr, sol) ) - return left_join(left_sols, right_sols, filter_fn)[:limit] + for sol in left_join(left_sols, right_sols, filter_fn)[:limit]: + yield sol -async def _eval_union(node, tc, workspace, collection, limit): - """Evaluate a Union node.""" - left = await evaluate(node.p1, tc, workspace, collection, limit) - right = await evaluate(node.p2, tc, workspace, collection, limit) - return union(left, right)[:limit] +async def _eval_minus(node, tc, collection, limit): + left = await materialise(node.p1, tc, collection, limit) + right = await materialise(node.p2, tc, collection, limit) + for sol in minus(left, right): + yield sol -async def _eval_filter(node, tc, workspace, collection, limit): - """Evaluate a Filter node.""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) - expr = node.expr - return [ - sol for sol in solutions - if _effective_boolean(evaluate_expression(expr, sol)) - ] +async def _eval_distinct(node, tc, collection, limit): + seen = set() + async for sol in evaluate(node.p, tc, collection, limit): + key = tuple(sorted( + (k, _term_key(v)) for k, v in sol.items() + )) + if key not in seen: + seen.add(key) + yield sol -async def _eval_distinct(node, tc, workspace, collection, limit): - """Evaluate a Distinct node.""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) - return distinct(solutions) +async def _eval_reduced(node, tc, collection, limit): + async for sol in _eval_distinct(node, tc, collection, limit): + yield sol -async def _eval_reduced(node, tc, workspace, collection, limit): - """Evaluate a Reduced node (like Distinct but implementation-defined).""" - # Treat same as Distinct - solutions = await evaluate(node.p, tc, workspace, collection, limit) - return distinct(solutions) - - -async def _eval_order_by(node, tc, workspace, collection, limit): - """Evaluate an OrderBy node.""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) +async def _eval_order_by(node, tc, collection, limit): + solutions = await materialise(node.p, tc, collection, limit) key_fns = [] for cond in node.expr: @@ -197,36 +360,104 @@ async def _eval_order_by(node, tc, workspace, collection, limit): ascending, )) else: - # Simple variable or expression key_fns.append(( lambda sol, e=cond: evaluate_expression(e, sol), True, )) - return order_by(solutions, key_fns) + for sol in order_by(solutions, key_fns): + yield sol -async def _eval_slice(node, tc, workspace, collection, limit): - """Evaluate a Slice node (LIMIT/OFFSET).""" - # Pass tighter limit downstream if possible - inner_limit = limit - if node.length is not None: - offset = node.start or 0 - inner_limit = min(limit, offset + node.length) +# --- Streamable operators --- - solutions = await evaluate(node.p, tc, workspace, collection, inner_limit) - return slice_solutions(solutions, node.start or 0, node.length) +async def _eval_slice(node, tc, collection, limit): + offset = node.start or 0 + length = node.length + skipped = 0 + emitted = 0 + + async for sol in evaluate(node.p, tc, collection, limit): + if skipped < offset: + skipped += 1 + continue + yield sol + emitted += 1 + if length is not None and emitted >= length: + return -async def _eval_extend(node, tc, workspace, collection, limit): - """Evaluate an Extend node (BIND).""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) +async def _eval_union(node, tc, collection, limit): + async for sol in evaluate(node.p1, tc, collection, limit): + yield sol + async for sol in evaluate(node.p2, tc, collection, limit): + yield sol + + +async def _check_exists(graph_node, sol, tc, collection, limit): + """Evaluate an EXISTS graph pattern against a solution.""" + async for r in evaluate(graph_node, tc, collection, limit): + shared = set(sol.keys()) & set(r.keys()) + if all( + _term_key(sol[v]) == _term_key(r[v]) + for v in shared + if sol.get(v) is not None and r.get(v) is not None + ): + return True + return False + + +async def _pre_eval_exists(expr, sol, tc, collection, limit, cache): + """Walk an expression tree, pre-evaluate EXISTS/NOT EXISTS, cache results.""" + if not isinstance(expr, CompValue): + return + if expr.name in ("Builtin_EXISTS", "Builtin_NOTEXISTS"): + key = id(expr.graph), id(sol) + if key not in cache: + cache[key] = await _check_exists( + expr.graph, sol, tc, collection, limit + ) + return + for attr in ("expr", "other", "arg", "arg1", "arg2", "arg3"): + child = getattr(expr, attr, None) + if child is None: + continue + if isinstance(child, CompValue): + await _pre_eval_exists(child, sol, tc, collection, limit, cache) + elif isinstance(child, (list, tuple)): + for item in child: + if isinstance(item, CompValue): + await _pre_eval_exists( + item, sol, tc, collection, limit, cache + ) + + +async def _eval_filter(node, tc, collection, limit): + expr = node.expr + exists_cache = {} + + def exists_cb(graph_node, sol): + key = id(graph_node), id(sol) + return exists_cache.get(key, False) + + async for sol in evaluate(node.p, tc, collection, limit): + await _pre_eval_exists(expr, sol, tc, collection, limit, exists_cache) + if _effective_boolean(evaluate_expression(expr, sol, exists_cb=exists_cb)): + yield sol + + +async def _eval_extend(node, tc, collection, limit): var_name = str(node.var) expr = node.expr + exists_cache = {} - result = [] - for sol in solutions: - val = evaluate_expression(expr, sol) + def exists_cb(graph_node, sol): + key = id(graph_node), id(sol) + return exists_cache.get(key, False) + + async for sol in evaluate(node.p, tc, collection, limit): + await _pre_eval_exists(expr, sol, tc, collection, limit, exists_cache) + val = evaluate_expression(expr, sol, exists_cb=exists_cb) new_sol = dict(sol) if isinstance(val, Term): new_sol[var_name] = val @@ -241,16 +472,14 @@ async def _eval_extend(node, tc, workspace, collection, limit): ) elif val is not None: new_sol[var_name] = Term(type=LITERAL, value=str(val)) - result.append(new_sol) - - return result + yield new_sol -async def _eval_group(node, tc, workspace, collection, limit): - """Evaluate a Group node (GROUP BY with aggregation).""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) +# --- Aggregation (blocking) --- + +async def _eval_group(node, tc, collection, limit): + solutions = await materialise(node.p, tc, collection, limit) - # Extract grouping expressions group_exprs = [] if hasattr(node, "expr") and node.expr: for expr in node.expr: @@ -261,7 +490,6 @@ async def _eval_group(node, tc, workspace, collection, limit): else: group_exprs.append((expr, None)) - # Group solutions groups = defaultdict(list) for sol in solutions: key_parts = [] @@ -271,81 +499,72 @@ async def _eval_group(node, tc, workspace, collection, limit): groups[tuple(key_parts)].append(sol) if not group_exprs: - # No GROUP BY - entire result is one group groups[()].extend(solutions) - # Build grouped solutions (one per group) - result = [] for key, group_sols in groups.items(): sol = {} - # Include group key variables if group_sols: for (expr, var_name), k in zip(group_exprs, key): if var_name and group_sols: sol[var_name] = evaluate_expression(expr, group_sols[0]) sol["__group__"] = group_sols - result.append(sol) - - return result + yield sol -async def _eval_aggregate_join(node, tc, workspace, collection, limit): - """Evaluate an AggregateJoin (aggregation functions after GROUP BY).""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) - - result = [] - for sol in solutions: +async def _eval_aggregate_join(node, tc, collection, limit): + async for sol in evaluate(node.p, tc, collection, limit): group = sol.get("__group__", [sol]) new_sol = {k: v for k, v in sol.items() if k != "__group__"} - # Apply aggregate functions if hasattr(node, "A") and node.A: for agg in node.A: var_name = str(agg.res) agg_val = _compute_aggregate(agg, group) new_sol[var_name] = agg_val - result.append(new_sol) - - return result + yield new_sol -async def _eval_graph(node, tc, workspace, collection, limit): - """Evaluate a Graph node (GRAPH clause).""" +async def _eval_graph(node, tc, collection, limit): term = node.term if isinstance(term, URIRef): - # GRAPH { ... } — fixed graph - # We'd need to pass graph to triples queries - # For now, evaluate inner pattern normally logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired") - return await evaluate(node.p, tc, workspace, collection, limit) elif isinstance(term, Variable): - # GRAPH ?g { ... } — variable graph logger.info(f"GRAPH ?{term} clause - variable graph not yet wired") - return await evaluate(node.p, tc, workspace, collection, limit) - else: - return await evaluate(node.p, tc, workspace, collection, limit) + + async for sol in evaluate(node.p, tc, collection, limit): + yield sol -async def _eval_values(node, tc, workspace, collection, limit): - """Evaluate a VALUES clause (inline data).""" +async def _eval_values(node, tc, collection, limit): + # rdflib has two representations for VALUES: + # 1. var=[Variable...], value=[[val, ...], ...] — positional + # 2. var=None, res=[{Variable: val, ...}, ...] — dict-based + if hasattr(node, "res") and node.res: + for row in node.res: + sol = {} + for var, val in row.items(): + if val is not None and str(val) != "UNDEF": + sol[str(var)] = rdflib_term_to_term(val) + yield sol + return + + if not node.var or not node.value: + yield {} + return variables = [str(v) for v in node.var] - solutions = [] - for row in node.value: sol = {} for var_name, val in zip(variables, row): if val is not None and str(val) != "UNDEF": sol[var_name] = rdflib_term_to_term(val) - solutions.append(sol) - - return solutions + yield sol -async def _eval_to_multiset(node, tc, workspace, collection, limit): - """Evaluate a ToMultiSet node (subquery).""" - return await evaluate(node.p, tc, workspace, collection, limit) +async def _eval_to_multiset(node, tc, collection, limit): + async for sol in evaluate(node.p, tc, collection, limit): + yield sol # --- Aggregate computation --- @@ -354,7 +573,6 @@ def _compute_aggregate(agg, group): """Compute a single aggregate function over a group of solutions.""" agg_name = agg.name if hasattr(agg, "name") else "" - # Get the expression to aggregate expr = agg.vars if hasattr(agg, "vars") else None if agg_name == "Aggregate_Count": @@ -487,7 +705,7 @@ def _resolve_term(tmpl, solution): return rdflib_term_to_term(tmpl) -async def _query_pattern(tc, s, p, o, workspace, collection, limit): +async def _query_pattern(tc, s, p, o, collection, limit): """ Issue a streaming triple pattern query via TriplesClient. @@ -496,7 +714,6 @@ async def _query_pattern(tc, s, p, o, workspace, collection, limit): results = await tc.query( s=s, p=p, o=o, limit=limit, - workspace=workspace, collection=collection, ) return results @@ -527,6 +744,7 @@ _HANDLERS = { "Join": _eval_join, "LeftJoin": _eval_left_join, "Union": _eval_union, + "Minus": _eval_minus, "Filter": _eval_filter, "Distinct": _eval_distinct, "Reduced": _eval_reduced, diff --git a/trustgraph-flow/trustgraph/query/sparql/expressions.py b/trustgraph-flow/trustgraph/query/sparql/expressions.py index eac1199c..608eeff2 100644 --- a/trustgraph-flow/trustgraph/query/sparql/expressions.py +++ b/trustgraph-flow/trustgraph/query/sparql/expressions.py @@ -5,9 +5,15 @@ Evaluates rdflib algebra expression nodes against a solution (variable binding) to produce a value or boolean result. """ +import hashlib +import math +import random import re import logging import operator +import uuid +from datetime import datetime, date, timezone +from urllib.parse import quote from rdflib.term import Variable, URIRef, Literal, BNode from rdflib.plugins.sparql.parserutils import CompValue @@ -17,23 +23,31 @@ from . parser import rdflib_term_to_term logger = logging.getLogger(__name__) +_exists_callback = None + class ExpressionError(Exception): """Raised when a SPARQL expression cannot be evaluated.""" pass -def evaluate_expression(expr, solution): +def evaluate_expression(expr, solution, exists_cb=None): """ Evaluate a SPARQL expression against a solution binding. Args: expr: rdflib algebra expression node solution: dict mapping variable names to Term values + exists_cb: optional callback(graph_node, solution) -> bool for + EXISTS/NOT EXISTS evaluation; provided by algebra.py Returns: The result value (Term, bool, number, string, or None) """ + global _exists_callback + if exists_cb is not None: + _exists_callback = exists_cb + if expr is None: return True @@ -111,6 +125,13 @@ def _evaluate_comp_value(node, solution): if name == "MultiplicativeExpression": return _eval_multiplicative(node, solution) + # IN / NOT IN — must be checked before the generic Builtin_ dispatch + if name == "Builtin_IN": + return _eval_in(node, solution) + + if name == "Builtin_NOTIN": + return not _eval_in(node, solution) + # SPARQL built-in functions if name.startswith("Builtin_"): return _eval_builtin(name, node, solution) @@ -119,27 +140,10 @@ def _evaluate_comp_value(node, solution): if name == "Function": return _eval_function(node, solution) - # Exists / NotExists - if name == "Builtin_EXISTS": - # EXISTS requires graph pattern evaluation - not handled here - logger.warning("EXISTS not supported in filter expressions") - return True - - if name == "Builtin_NOTEXISTS": - logger.warning("NOT EXISTS not supported in filter expressions") - return True - # TrueFilter (used with OPTIONAL) if name == "TrueFilter": return True - # IN / NOT IN - if name == "Builtin_IN": - return _eval_in(node, solution) - - if name == "Builtin_NOTIN": - return not _eval_in(node, solution) - logger.warning(f"Unknown CompValue expression: {name}") return None @@ -165,6 +169,22 @@ def _eval_relational(node, solution): ">=": operator.ge, } + if str(op) == "IN": + items = node.other if isinstance(node.other, list) else [node.other] + for item in items: + other_val = evaluate_expression(item, solution) + if _comparable_value(left) == _comparable_value(other_val): + return True + return False + + if str(op) == "NOT IN": + items = node.other if isinstance(node.other, list) else [node.other] + for item in items: + other_val = evaluate_expression(item, solution) + if _comparable_value(left) == _comparable_value(other_val): + return False + return True + op_fn = ops.get(str(op)) if op_fn is None: logger.warning(f"Unknown relational operator: {op}") @@ -335,6 +355,197 @@ def _eval_builtin(name, node, solution): return val return None + if builtin == "YEAR": + dt = _to_datetime(evaluate_expression(node.arg, solution)) + return dt.year if dt is not None else None + + if builtin == "MONTH": + dt = _to_datetime(evaluate_expression(node.arg, solution)) + return dt.month if dt is not None else None + + if builtin == "DAY": + dt = _to_datetime(evaluate_expression(node.arg, solution)) + return dt.day if dt is not None else None + + if builtin == "HOURS": + dt = _to_datetime(evaluate_expression(node.arg, solution)) + if dt is None: + return None + return dt.hour if isinstance(dt, datetime) else 0 + + if builtin == "MINUTES": + dt = _to_datetime(evaluate_expression(node.arg, solution)) + if dt is None: + return None + return dt.minute if isinstance(dt, datetime) else 0 + + if builtin == "SECONDS": + dt = _to_datetime(evaluate_expression(node.arg, solution)) + if dt is None: + return None + return dt.second if isinstance(dt, datetime) else 0 + + if builtin == "FLOOR": + val = _to_numeric(evaluate_expression(node.arg, solution)) + if val is None: + return None + return int(math.floor(val)) + + if builtin == "CEIL": + val = _to_numeric(evaluate_expression(node.arg, solution)) + if val is None: + return None + return int(math.ceil(val)) + + if builtin == "ABS": + val = _to_numeric(evaluate_expression(node.arg, solution)) + if val is None: + return None + return abs(val) + + if builtin == "ROUND": + val = _to_numeric(evaluate_expression(node.arg, solution)) + if val is None: + return None + return round(val) + + if builtin == "STRBEFORE": + string = _to_string(evaluate_expression(node.arg1, solution)) + sep = _to_string(evaluate_expression(node.arg2, solution)) + idx = string.find(sep) + if idx < 0: + return Term(type=LITERAL, value="") + return Term(type=LITERAL, value=string[:idx]) + + if builtin == "STRAFTER": + string = _to_string(evaluate_expression(node.arg1, solution)) + sep = _to_string(evaluate_expression(node.arg2, solution)) + idx = string.find(sep) + if idx < 0: + return Term(type=LITERAL, value="") + return Term(type=LITERAL, value=string[idx + len(sep):]) + + if builtin == "ENCODE_FOR_URI": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term(type=LITERAL, value=quote(val, safe="")) + + if builtin == "REPLACE": + string = _to_string(evaluate_expression(node.arg, solution)) + pattern = _to_string(evaluate_expression(node.pattern, solution)) + replacement = _to_string( + evaluate_expression(node.replacement, solution) + ) + flags_str = "" + if hasattr(node, "flags") and node.flags is not None: + flags_str = _to_string(evaluate_expression(node.flags, solution)) + re_flags = 0 + if "i" in flags_str: + re_flags |= re.IGNORECASE + if "m" in flags_str: + re_flags |= re.MULTILINE + if "s" in flags_str: + re_flags |= re.DOTALL + try: + result = re.sub(pattern, replacement, string, flags=re_flags) + return Term(type=LITERAL, value=result) + except re.error: + return None + + if builtin == "SUBSTR": + string = _to_string(evaluate_expression(node.arg, solution)) + start = _to_numeric(evaluate_expression(node.start, solution)) + if start is None: + return None + start_idx = max(int(start) - 1, 0) + if hasattr(node, "length") and node.length is not None: + length = _to_numeric(evaluate_expression(node.length, solution)) + if length is None: + return None + return Term( + type=LITERAL, value=string[start_idx:start_idx + int(length)] + ) + return Term(type=LITERAL, value=string[start_idx:]) + + if builtin == "EXISTS": + if _exists_callback is not None: + return _exists_callback(node.graph, solution) + logger.warning("EXISTS requires an exists_cb; not available") + return True + + if builtin == "NOTEXISTS": + if _exists_callback is not None: + return not _exists_callback(node.graph, solution) + logger.warning("NOT EXISTS requires an exists_cb; not available") + return True + + if builtin == "LANGMATCHES": + tag = _to_string(evaluate_expression(node.arg1, solution)) + rng = _to_string(evaluate_expression(node.arg2, solution)) + if rng == "*": + return len(tag) > 0 + return tag.lower().startswith(rng.lower()) + + if builtin == "IRI" or builtin == "URI": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term(type=IRI, iri=val) + + if builtin == "BNODE": + if hasattr(node, "arg") and node.arg is not None: + label = _to_string(evaluate_expression(node.arg, solution)) + return Term(type=BLANK, id=label) + return Term(type=BLANK, id=str(uuid.uuid4())) + + if builtin == "NOW": + now = datetime.now(timezone.utc) + return Term( + type=LITERAL, + value=now.strftime("%Y-%m-%dT%H:%M:%S%z"), + datatype="http://www.w3.org/2001/XMLSchema#dateTime", + ) + + if builtin == "TZ": + dt = _to_datetime(evaluate_expression(node.arg, solution)) + if dt is None: + return Term(type=LITERAL, value="") + if dt.tzinfo is not None: + offset = dt.strftime("%z") + if offset: + return Term(type=LITERAL, value=offset[:3] + ":" + offset[3:]) + return Term(type=LITERAL, value="") + + if builtin == "RAND": + return random.random() + + if builtin == "UUID": + return Term(type=IRI, iri="urn:uuid:" + str(uuid.uuid4())) + + if builtin == "STRUUID": + return Term(type=LITERAL, value=str(uuid.uuid4())) + + if builtin == "MD5": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term( + type=LITERAL, value=hashlib.md5(val.encode()).hexdigest() + ) + + if builtin == "SHA1": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term( + type=LITERAL, value=hashlib.sha1(val.encode()).hexdigest() + ) + + if builtin == "SHA256": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term( + type=LITERAL, value=hashlib.sha256(val.encode()).hexdigest() + ) + + if builtin == "SHA512": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term( + type=LITERAL, value=hashlib.sha512(val.encode()).hexdigest() + ) + if builtin == "sameTerm": left = evaluate_expression(node.arg1, solution) right = evaluate_expression(node.arg2, solution) @@ -454,6 +665,27 @@ def _to_numeric(val): return None +def _to_datetime(val): + """Convert a value to a date or datetime object.""" + if val is None: + return None + s = _to_string(val) + for fmt in ( + "%Y-%m-%dT%H:%M:%S", "%Y-%m-%dT%H:%M:%S%z", + "%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%dT%H:%M", + "%Y-%m-%d", + ): + try: + return datetime.strptime(s, fmt) + except ValueError: + continue + try: + return datetime.fromisoformat(s) + except ValueError: + pass + return None + + def _comparable_value(val): """ Convert a value to a form suitable for comparison. diff --git a/trustgraph-flow/trustgraph/query/sparql/service.py b/trustgraph-flow/trustgraph/query/sparql/service.py index 983cd4f6..bbe375f0 100644 --- a/trustgraph-flow/trustgraph/query/sparql/service.py +++ b/trustgraph-flow/trustgraph/query/sparql/service.py @@ -12,7 +12,7 @@ from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import TriplesClientSpec from . parser import parse_sparql, ParseError -from . algebra import evaluate, EvaluationError +from . algebra import evaluate, materialise, EvaluationError logger = logging.getLogger(__name__) @@ -66,11 +66,10 @@ class Processor(FlowProcessor): logger.debug(f"Handling SPARQL query request {id}...") - response = await self.execute_sparql(request, flow) - - if request.streaming and response.query_type == "select": - await self.send_streaming(response, flow, id, request) + if request.streaming: + await self.execute_sparql_streaming(request, flow, id) else: + response = await self.execute_sparql(request, flow) await flow("response").send( response, properties={"id": id} ) @@ -92,37 +91,77 @@ class Processor(FlowProcessor): await flow("response").send(r, properties={"id": id}) - async def send_streaming(self, response, flow, id, request): - """Send SELECT results in batches.""" + async def execute_sparql_streaming(self, request, flow, id): + """Execute a SPARQL query and stream results as they arrive.""" - bindings = response.bindings + try: + parsed = parse_sparql(request.query) + except ParseError as e: + await flow("response").send( + SparqlQueryResponse( + error=Error( + type="sparql-parse-error", + message=str(e), + ), + ), + properties={"id": id} + ) + return + + if parsed.query_type != "select": + response = await self._execute_non_select(parsed, request, flow) + await flow("response").send(response, properties={"id": id}) + return + + triples_client = flow("triples-request") + variables = parsed.variables batch_size = request.batch_size if request.batch_size > 0 else 20 + batch = [] - for i in range(0, len(bindings), batch_size): - batch = bindings[i:i + batch_size] - is_final = (i + batch_size >= len(bindings)) - r = SparqlQueryResponse( - query_type=response.query_type, - variables=response.variables, - bindings=batch, - is_final=is_final, - ) - await flow("response").send(r, properties={"id": id}) + try: + async for sol in evaluate( + parsed.algebra, + triples_client, + collection=request.collection or "default", + limit=request.limit or 10000, + ): + values = [sol.get(v) for v in variables] + batch.append(SparqlBinding(values=values)) - # Handle empty results - if len(bindings) == 0: - r = SparqlQueryResponse( - query_type=response.query_type, - variables=response.variables, - bindings=[], - is_final=True, + if len(batch) >= batch_size: + r = SparqlQueryResponse( + query_type="select", + variables=variables, + bindings=batch, + is_final=False, + ) + await flow("response").send(r, properties={"id": id}) + batch = [] + + except EvaluationError as e: + await flow("response").send( + SparqlQueryResponse( + error=Error( + type="sparql-evaluation-error", + message=str(e), + ), + ), + properties={"id": id} ) - await flow("response").send(r, properties={"id": id}) + return + + # Final batch (may be empty for zero results) + r = SparqlQueryResponse( + query_type="select", + variables=variables, + bindings=batch, + is_final=True, + ) + await flow("response").send(r, properties={"id": id}) async def execute_sparql(self, request, flow): - """Parse and evaluate a SPARQL query.""" + """Parse and evaluate a SPARQL query (non-streaming).""" - # Parse the SPARQL query try: parsed = parse_sparql(request.query) except ParseError as e: @@ -133,15 +172,33 @@ class Processor(FlowProcessor): ), ) - # Get the triples client from the flow - triples_client = flow("triples-request") + if parsed.query_type == "select": + triples_client = flow("triples-request") + try: + solutions = await materialise( + parsed.algebra, + triples_client, + collection=request.collection or "default", + limit=request.limit or 10000, + ) + except EvaluationError as e: + return SparqlQueryResponse( + error=Error( + type="sparql-evaluation-error", + message=str(e), + ), + ) + return self._build_select_response(parsed, solutions) - # Evaluate the algebra + return await self._execute_non_select(parsed, request, flow) + + async def _execute_non_select(self, parsed, request, flow): + """Execute ASK, CONSTRUCT, or DESCRIBE queries.""" + triples_client = flow("triples-request") try: - solutions = await evaluate( + solutions = await materialise( parsed.algebra, triples_client, - workspace=flow.workspace, collection=request.collection or "default", limit=request.limit or 10000, ) @@ -153,10 +210,7 @@ class Processor(FlowProcessor): ), ) - # Build response based on query type - if parsed.query_type == "select": - return self._build_select_response(parsed, solutions) - elif parsed.query_type == "ask": + if parsed.query_type == "ask": return self._build_ask_response(solutions) elif parsed.query_type == "construct": return self._build_construct_response(parsed, solutions) diff --git a/trustgraph-flow/trustgraph/query/sparql/solutions.py b/trustgraph-flow/trustgraph/query/sparql/solutions.py index d1ea8373..edf3401d 100644 --- a/trustgraph-flow/trustgraph/query/sparql/solutions.py +++ b/trustgraph-flow/trustgraph/query/sparql/solutions.py @@ -150,6 +150,30 @@ def left_join(left, right, filter_fn=None): return results +def minus(left, right): + """ + MINUS operation: remove left solutions that are compatible with any + right solution sharing at least one variable. + """ + if not right: + return list(left) + + right_vars = set() + for sol in right: + right_vars.update(sol.keys()) + + results = [] + for sol_l in left: + shared = set(sol_l.keys()) & right_vars + if not shared: + results.append(sol_l) + continue + if not any(_compatible(sol_l, sol_r) for sol_r in right): + results.append(sol_l) + + return results + + def union(left, right): """Union two solution sequences (concatenation).""" return list(left) + list(right) @@ -177,6 +201,28 @@ def distinct(solutions): return results +def _sort_comparable(val): + """Convert a value to a form suitable for sort ordering.""" + if val is None: + return (0, "") + if isinstance(val, (int, float)): + return (2, val) + if isinstance(val, Term): + if val.type == LITERAL: + try: + if "." in val.value: + return (2, float(val.value)) + return (2, int(val.value)) + except (ValueError, TypeError): + pass + return (3, val.value) + elif val.type == IRI: + return (4, val.iri) + elif val.type == BLANK: + return (5, val.id) + return (6, str(val)) + + def order_by(solutions, key_fns): """ Sort solutions by the given key functions. @@ -191,14 +237,7 @@ def order_by(solutions, key_fns): keys = [] for fn, ascending in key_fns: val = fn(sol) - # Convert to comparable form - if val is None: - comparable = ("", "") - elif isinstance(val, Term): - comparable = _term_key(val) - else: - comparable = ("v", str(val)) - keys.append(comparable) + keys.append(_sort_comparable(val)) return keys # Handle ascending/descending @@ -224,10 +263,8 @@ def _mixed_sort(solutions, key_fns): def compare(a, b): for fn, ascending in key_fns: - va = fn(a) - vb = fn(b) - ka = _term_key(va) if isinstance(va, Term) else ("v", str(va)) if va is not None else ("", "") - kb = _term_key(vb) if isinstance(vb, Term) else ("v", str(vb)) if vb is not None else ("", "") + ka = _sort_comparable(fn(a)) + kb = _sort_comparable(fn(b)) if ka < kb: return -1 if ascending else 1 diff --git a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py index a9bdbbac..822dba25 100755 --- a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py @@ -6,8 +6,8 @@ null. Output is a list of quads. import asyncio import logging - import json + from cassandra.query import SimpleStatement from .... direct.cassandra_kg import ( @@ -176,45 +176,42 @@ class Processor(TriplesQueryService): self.cassandra_host = hosts self.cassandra_username = username self.cassandra_password = password - self.table = None - def ensure_connection(self, workspace): - """Ensure we have a connection to the correct keyspace.""" - if workspace != self.table: - KGClass = EntityCentricKnowledgeGraph + self._connections = {} + self._conn_lock = asyncio.Lock() - if self.cassandra_username and self.cassandra_password: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - username=self.cassandra_username, - password=self.cassandra_password - ) - else: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - ) - self.table = workspace + async def _get_connection(self, workspace): + async with self._conn_lock: + if workspace not in self._connections: + if self.cassandra_username and self.cassandra_password: + tg = await asyncio.to_thread( + EntityCentricKnowledgeGraph, + hosts=self.cassandra_host, + keyspace=workspace, + username=self.cassandra_username, + password=self.cassandra_password, + ) + else: + tg = await asyncio.to_thread( + EntityCentricKnowledgeGraph, + hosts=self.cassandra_host, + keyspace=workspace, + ) + self._connections[workspace] = tg + return self._connections[workspace] async def query_triples(self, workspace, query): try: - # ensure_connection may construct a fresh - # EntityCentricKnowledgeGraph which does sync schema - # setup against Cassandra. Push it to a worker thread - # so the event loop doesn't block on first-use per workspace. - await asyncio.to_thread(self.ensure_connection, workspace) - - # Extract values from query s_val = get_term_value(query.s) p_val = get_term_value(query.p) o_val = get_term_value(query.o) - g_val = query.g # Already a string or None + g_val = query.g + + tg = await self._get_connection(workspace) def get_object_metadata(row): - """Extract term type metadata from result row""" return ( getattr(row, 'otype', None), getattr(row, 'dtype', None), @@ -223,33 +220,21 @@ class Processor(TriplesQueryService): quads = [] - # All self.tg.get_* calls below are sync wrappers around - # cassandra session.execute. Materialise inside a worker - # thread so iteration never triggers sync paging back on - # the event loop. - - # Route to appropriate query method based on which fields are specified if s_val is not None: if p_val is not None: if o_val is not None: - # SPO specified - find matching graphs - resp = await asyncio.to_thread( - lambda: list(self.tg.get_spo( - query.collection, s_val, p_val, o_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_spo( + query.collection, s_val, p_val, o_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH term_type, datatype, language = get_object_metadata(t) quads.append((s_val, p_val, o_val, g, term_type, datatype, language)) else: - # SP specified - resp = await asyncio.to_thread( - lambda: list(self.tg.get_sp( - query.collection, s_val, p_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_sp( + query.collection, s_val, p_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH @@ -257,24 +242,18 @@ class Processor(TriplesQueryService): quads.append((s_val, p_val, t.o, g, term_type, datatype, language)) else: if o_val is not None: - # SO specified - resp = await asyncio.to_thread( - lambda: list(self.tg.get_os( - query.collection, o_val, s_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_os( + query.collection, o_val, s_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH term_type, datatype, language = get_object_metadata(t) quads.append((s_val, t.p, o_val, g, term_type, datatype, language)) else: - # S only - resp = await asyncio.to_thread( - lambda: list(self.tg.get_s( - query.collection, s_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_s( + query.collection, s_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH @@ -283,24 +262,18 @@ class Processor(TriplesQueryService): else: if p_val is not None: if o_val is not None: - # PO specified - resp = await asyncio.to_thread( - lambda: list(self.tg.get_po( - query.collection, p_val, o_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_po( + query.collection, p_val, o_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH term_type, datatype, language = get_object_metadata(t) quads.append((t.s, p_val, o_val, g, term_type, datatype, language)) else: - # P only - resp = await asyncio.to_thread( - lambda: list(self.tg.get_p( - query.collection, p_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_p( + query.collection, p_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH @@ -308,40 +281,26 @@ class Processor(TriplesQueryService): quads.append((t.s, p_val, t.o, g, term_type, datatype, language)) else: if o_val is not None: - # O only - resp = await asyncio.to_thread( - lambda: list(self.tg.get_o( - query.collection, o_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_o( + query.collection, o_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH term_type, datatype, language = get_object_metadata(t) quads.append((t.s, t.p, o_val, g, term_type, datatype, language)) else: - # Nothing specified - get all - resp = await asyncio.to_thread( - lambda: list(self.tg.get_all( - query.collection, limit=query.limit, - )) + resp = await tg.async_get_all( + query.collection, limit=query.limit, ) for t in resp: - # Note: quads_by_collection uses 'd' for graph field g = t.d if hasattr(t, 'd') else DEFAULT_GRAPH - # Filter by graph - # g_val=None means all graphs (no filter) - # g_val="" means default graph only - # otherwise filter to specific named graph if g_val is not None: if g != g_val: continue term_type, datatype, language = get_object_metadata(t) quads.append((t.s, t.p, t.o, g, term_type, datatype, language)) - # Convert to Triple objects (with g field) - # s and p are always IRIs in RDF - # Object uses term_type/datatype/language metadata from database triples = [ Triple( s=create_term(q[0], term_type='u'), @@ -365,51 +324,41 @@ class Processor(TriplesQueryService): Uses Cassandra's paging to fetch results incrementally. """ try: - await asyncio.to_thread(self.ensure_connection, workspace) batch_size = query.batch_size if query.batch_size > 0 else 20 limit = query.limit if query.limit > 0 else 10000 - # Extract query pattern s_val = get_term_value(query.s) p_val = get_term_value(query.p) o_val = get_term_value(query.o) g_val = query.g def get_object_metadata(row): - """Extract term type metadata from result row""" return ( getattr(row, 'otype', None), getattr(row, 'dtype', None), getattr(row, 'lang', None), ) - # For streaming, we need to execute with fetch_size - # Use the collection table for get_all queries (most common streaming case) - - # Determine which query to use based on pattern if s_val is None and p_val is None and o_val is None: - # Get all - use collection table with paging - cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {self.tg.collection_table} WHERE collection = %s" + + tg = await self._get_connection(workspace) + + cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {tg.collection_table} WHERE collection = %s" params = [query.collection] + statement = SimpleStatement(cql, fetch_size=batch_size) + # async_execute only materialises the first page; + # this query needs all pages, so use sync execute + # in a worker thread where page iteration can block. + result_set = await asyncio.to_thread( + lambda: list(tg.session.execute(statement, params)) + ) + else: - # For specific patterns, fall back to non-streaming - # (these typically return small result sets anyway) async for batch, is_final in self._fallback_stream(workspace, query, batch_size): yield batch, is_final return - # Materialise in a worker thread. We lose true streaming - # paging (the driver fetches all pages eagerly inside the - # thread) but the event loop stays responsive, and result - # sets at this layer are typically small enough that this - # is acceptable. If true async paging is needed later, - # revisit using ResponseFuture page callbacks. - statement = SimpleStatement(cql, fetch_size=batch_size) - result_set = await asyncio.to_thread( - lambda: list(self.tg.session.execute(statement, params)) - ) - batch = [] count = 0 diff --git a/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py b/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py index 986558ec..5e71a0e5 100644 --- a/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py +++ b/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py @@ -1,130 +1,140 @@ import asyncio import logging import uuid -from typing import Dict, Any, Optional -from trustgraph.messaging import TranslatorRegistry +from typing import Dict, Any, Optional, Callable, Awaitable from ..gateway.dispatch.manager import DispatcherManager logger = logging.getLogger("dispatcher") -logger.setLevel(logging.INFO) -class WebSocketResponder: - """Simple responder that captures response for websocket return""" - def __init__(self): - self.response = None - self.completed = False - - async def send(self, data): - """Capture the response data""" - self.response = data - self.completed = True - - async def __call__(self, data, final=False): - """Make the responder callable for compatibility with requestor""" - await self.send(data) - if final: - self.completed = True + +class _TokenShim: + def __init__(self, token): + self.headers = ( + {"Authorization": f"Bearer {token}"} if token else {} + ) + class MessageDispatcher: - def __init__(self, max_workers: int = 10, config_receiver=None, backend=None): + def __init__(self, max_workers=10, config_receiver=None, backend=None, + auth=None, timeout=120): self.max_workers = max_workers self.semaphore = asyncio.Semaphore(max_workers) self.active_tasks = set() self.backend = backend + self.auth = auth - # Use DispatcherManager for flow and service management - if backend and config_receiver: - self.dispatcher_manager = DispatcherManager(backend, config_receiver, prefix="rev-gateway") + if backend and config_receiver and auth: + self.dispatcher_manager = DispatcherManager( + backend, config_receiver, + auth=auth, + prefix="rev-gateway", + timeout=timeout, + ) else: self.dispatcher_manager = None - logger.warning("No backend or config_receiver provided - using fallback mode") - - # Service name mapping from websocket protocol to translator registry + logger.warning( + "Missing backend, config_receiver, or auth " + "— using fallback mode" + ) + self.service_mapping = { "text-completion": "text-completion", - "graph-rag": "graph-rag", + "graph-rag": "graph-rag", "agent": "agent", "embeddings": "embeddings", "graph-embeddings": "graph-embeddings", "triples": "triples", "document-load": "document", - "text-load": "text-document", + "text-load": "text-document", "flow": "flow", "knowledge": "knowledge", "config": "config", "librarian": "librarian", - "document-rag": "document-rag" + "document-rag": "document-rag", } - - async def handle_message(self, message: Dict[Any, Any]) -> Optional[Dict[Any, Any]]: + + async def handle_message( + self, message: Dict[Any, Any], + sender: Callable[[dict], Awaitable[None]], + ): async with self.semaphore: - task = asyncio.create_task(self._process_message(message)) + task = asyncio.create_task( + self._process_message(message, sender) + ) self.active_tasks.add(task) - + try: - result = await task - return result + await task finally: self.active_tasks.discard(task) - - async def _process_message(self, message: Dict[Any, Any]) -> Dict[Any, Any]: + + async def _authenticate(self, token): + if not self.auth: + raise RuntimeError("Auth not configured") + return await self.auth.authenticate(_TokenShim(token)) + + async def _process_message( + self, message: Dict[Any, Any], + sender: Callable[[dict], Awaitable[None]], + ): request_id = message.get('id', str(uuid.uuid4())) service = message.get('service') request_data = message.get('request', {}) - flow_id = message.get('flow', 'default') # Default flow - - logger.info(f"Processing message {request_id} for service {service} on flow {flow_id}") - + token = message.get('token', '') + flow_id = message.get('flow', 'default') + + logger.info( + f"Processing message {request_id} for service " + f"{service} on flow {flow_id}" + ) + try: if not self.dispatcher_manager: - raise RuntimeError("DispatcherManager not available - backend and config_receiver required") - - # Use DispatcherManager for flow-based processing - responder = WebSocketResponder() - - # Map websocket service name to dispatcher service name + raise RuntimeError( + "DispatcherManager not available" + ) + + identity = await self._authenticate(token) + workspace = identity.workspace + + async def responder(resp, fin): + await sender({ + "id": request_id, + "response": resp, + "complete": fin, + }) + dispatcher_service = self.service_mapping.get(service, service) - - # Check if this is a global service or flow service + from ..gateway.dispatch.manager import global_dispatchers if dispatcher_service in global_dispatchers: - # Use global service dispatcher await self.dispatcher_manager.invoke_global_service( - request_data, responder, dispatcher_service + request_data, responder, dispatcher_service, + workspace=workspace, ) else: - # Use DispatcherManager to process the request through Pulsar queues await self.dispatcher_manager.invoke_flow_service( - request_data, responder, flow_id, dispatcher_service + request_data, responder, workspace, flow_id, + dispatcher_service, ) - - # Get the response from the responder - if responder.completed: - response_data = responder.response - else: - response_data = {'error': 'No response received'} - - response = { - 'id': request_id, - 'response': response_data - } - + except Exception as e: logger.error(f"Error processing message {request_id}: {e}") - response = { - 'id': request_id, - 'response': {'error': str(e)} - } - + await sender({ + "id": request_id, + "error": {"message": str(e), "type": "error"}, + "complete": True, + }) + logger.info(f"Completed processing message {request_id}") - return response - - + async def shutdown(self): if self.active_tasks: - logger.info(f"Waiting for {len(self.active_tasks)} active tasks to complete") + logger.info( + f"Waiting for {len(self.active_tasks)} active " + f"tasks to complete" + ) await asyncio.gather(*self.active_tasks, return_exceptions=True) - - # DispatcherManager handles its own cleanup + logger.info("Dispatcher shutdown complete") diff --git a/trustgraph-flow/trustgraph/rev_gateway/service.py b/trustgraph-flow/trustgraph/rev_gateway/service.py index cc905172..39180e2c 100644 --- a/trustgraph-flow/trustgraph/rev_gateway/service.py +++ b/trustgraph-flow/trustgraph/rev_gateway/service.py @@ -1,3 +1,9 @@ +""" +Reverse gateway. Initiates outbound WebSocket connections to a remote +relay and dispatches incoming requests through the same DispatcherManager +pipeline as api-gateway. +""" + import asyncio import argparse import logging @@ -9,82 +15,86 @@ from typing import Optional from urllib.parse import urlparse, urlunparse from .dispatcher import MessageDispatcher +from ..gateway.auth import IamAuth from ..gateway.config.receiver import ConfigReceiver -from ..base import get_pubsub +from ..base.pubsub import get_pubsub, add_pubsub_args +from ..base.logging import setup_logging, add_logging_args logger = logging.getLogger("rev_gateway") -logger.setLevel(logging.INFO) default_websocket = "ws://localhost:7650/out" +default_timeout = 600 class ReverseGateway: - - def __init__(self, websocket_uri: str = None, max_workers: int = 10, - pulsar_host: str = None, pulsar_api_key: str = None, - pulsar_listener: str = None): - # Set default WebSocket URI with environment variable support + + def __init__(self, **config): + websocket_uri = config.get("websocket_uri") if websocket_uri is None: websocket_uri = os.getenv("WEBSOCKET_URI", default_websocket) - - # Parse and validate the WebSocket URI + parsed_uri = urlparse(websocket_uri) if parsed_uri.scheme not in ('ws', 'wss'): - raise ValueError(f"WebSocket URI must use ws:// or wss:// scheme, got: {parsed_uri.scheme}") + raise ValueError( + f"WebSocket URI must use ws:// or wss:// scheme, " + f"got: {parsed_uri.scheme}" + ) if not parsed_uri.netloc: - raise ValueError(f"WebSocket URI must include hostname, got: {websocket_uri}") - - # Store parsed components for debugging/logging + raise ValueError( + f"WebSocket URI must include hostname, " + f"got: {websocket_uri}" + ) + self.websocket_uri = websocket_uri self.host = parsed_uri.hostname self.port = parsed_uri.port self.scheme = parsed_uri.scheme self.path = parsed_uri.path or "/ws" - - # Construct the full URL (in case path was missing) + if not parsed_uri.path: self.url = f"{self.scheme}://{parsed_uri.netloc}/ws" else: self.url = websocket_uri - - self.max_workers = max_workers + + self.max_workers = int(config.get("max_workers", 10)) + self.timeout = int(config.get("timeout", default_timeout)) self.ws: Optional[ClientWebSocketResponse] = None self.session: Optional[ClientSession] = None self.running = False self.reconnect_delay = 3.0 - - # Pulsar configuration - self.pulsar_host = pulsar_host or os.getenv("PULSAR_HOST", "pulsar://pulsar:6650") - self.pulsar_api_key = pulsar_api_key or os.getenv("PULSAR_API_KEY", None) - self.pulsar_listener = pulsar_listener - # Create backend using factory - backend_params = { - 'pulsar_host': self.pulsar_host, - 'pulsar_api_key': self.pulsar_api_key, - 'pulsar_listener': self.pulsar_listener, - } - self.backend = get_pubsub(**backend_params) + self.backend = get_pubsub(**config) - # Initialize config receiver - self.config_receiver = ConfigReceiver(self.backend) + self.auth = IamAuth( + backend=self.backend, + id=config.get("id", "rev-gateway"), + ) + + self.config_receiver = ConfigReceiver( + self.backend, auth=self.auth, + ) + + self.dispatcher = MessageDispatcher( + self.max_workers, self.config_receiver, self.backend, + auth=self.auth, timeout=self.timeout, + ) - # Initialize dispatcher with config_receiver and backend - must be created after config_receiver - self.dispatcher = MessageDispatcher(max_workers, self.config_receiver, self.backend) - async def connect(self) -> bool: try: if self.session is None: self.session = ClientSession() - + logger.info(f"Connecting to {self.url}") self.ws = await self.session.ws_connect(self.url) - logger.info(f"WebSocket connection established to {self.host}:{self.port or 'default'}") + logger.info( + f"WebSocket connection established to " + f"{self.host}:{self.port or 'default'}" + ) return True - + except Exception as e: logger.error(f"Failed to connect to {self.url}: {e}") return False - + async def disconnect(self): if self.ws and not self.ws.closed: await self.ws.close() @@ -92,32 +102,31 @@ class ReverseGateway: await self.session.close() self.ws = None self.session = None - + async def send_message(self, message: dict): if self.ws and not self.ws.closed: try: await self.ws.send_str(json.dumps(message)) except Exception as e: logger.error(f"Failed to send message: {e}") - + async def handle_message(self, message: str): try: logger.debug(f"Received message: {message}") - + msg_data = json.loads(message) - response = await self.dispatcher.handle_message(msg_data) - - if response: - await self.send_message(response) - + await self.dispatcher.handle_message( + msg_data, self.send_message, + ) + except Exception as e: logger.error(f"Error handling message: {e}") - + async def listen(self): while self.running and self.ws and not self.ws.closed: try: msg = await self.ws.receive() - + if msg.type == WSMsgType.TEXT: await self.handle_message(msg.data) elif msg.type == WSMsgType.BINARY: @@ -125,31 +134,33 @@ class ReverseGateway: elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR): logger.warning("WebSocket closed or error occurred") break - + except Exception as e: logger.error(f"Error in listen loop: {e}") break - + async def run(self): self.running = True logger.info("Starting reverse gateway") - - # Start config receiver - logger.info("Starting config receiver") + + await self.auth.start() await self.config_receiver.start() - + while self.running: try: if await self.connect(): await self.listen() else: - logger.warning(f"Connection failed, retrying in {self.reconnect_delay} seconds") - + logger.warning( + f"Connection failed, retrying in " + f"{self.reconnect_delay} seconds" + ) + await self.disconnect() - + if self.running: await asyncio.sleep(self.reconnect_delay) - + except KeyboardInterrupt: logger.info("Shutdown requested") break @@ -157,77 +168,69 @@ class ReverseGateway: logger.error(f"Unexpected error: {e}") if self.running: await asyncio.sleep(self.reconnect_delay) - + await self.shutdown() - + async def shutdown(self): logger.info("Shutting down reverse gateway") self.running = False await self.dispatcher.shutdown() await self.disconnect() - # Close backend if hasattr(self, 'backend'): self.backend.close() - + def stop(self): self.running = False -def parse_args(): + +def run(): + parser = argparse.ArgumentParser( prog="reverse-gateway", - description="TrustGraph Reverse Gateway - WebSocket to Pulsar bridge" + description=__doc__, ) - + + parser.add_argument( + '--id', + default='rev-gateway', + help='Service identifier (default: rev-gateway)', + ) + parser.add_argument( '--websocket-uri', default=None, - help=f'WebSocket URI to connect to (default: {default_websocket} or WEBSOCKET_URI env var)' + help=f'WebSocket URI to connect to (default: {default_websocket})', ) - + parser.add_argument( '--max-workers', type=int, default=10, - help='Maximum concurrent message handlers (default: 10)' + help='Maximum concurrent message handlers (default: 10)', ) - - parser.add_argument( - '-p', '--pulsar-host', - default=None, - help='Pulsar host URL (default: pulsar://pulsar:6650 or PULSAR_HOST env var)' - ) - - parser.add_argument( - '--pulsar-api-key', - default=None, - help='Pulsar API key for authentication (default: PULSAR_API_KEY env var)' - ) - - parser.add_argument( - '--pulsar-listener', - default=None, - help='Pulsar listener name' - ) - - return parser.parse_args() -def run(): - args = parse_args() - - gateway = ReverseGateway( - websocket_uri=args.websocket_uri, - max_workers=args.max_workers, - pulsar_host=args.pulsar_host, - pulsar_api_key=args.pulsar_api_key, - pulsar_listener=args.pulsar_listener + parser.add_argument( + '--timeout', + type=int, + default=default_timeout, + help=f'Request timeout in seconds (default: {default_timeout})', ) - + + add_pubsub_args(parser) + add_logging_args(parser) + + args = parser.parse_args() + args = vars(args) + + setup_logging(args) + + gateway = ReverseGateway(**args) + logger.info(f"Starting reverse gateway:") logger.info(f" WebSocket URI: {gateway.url}") - logger.info(f" Max workers: {args.max_workers}") - logger.info(f" Pulsar host: {gateway.pulsar_host}") - + logger.info(f" Max workers: {gateway.max_workers}") + try: asyncio.run(gateway.run()) except KeyboardInterrupt: diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index fb7166b5..2bfef99c 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -3,11 +3,13 @@ Accepts entity/vector pairs and writes them to a Qdrant store. """ +import asyncio +import uuid +import logging + from qdrant_client import QdrantClient from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams -import uuid -import logging from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer @@ -35,13 +37,35 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): ) self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + self._cache_lock = asyncio.Lock() + self._known_collections: set[str] = set() # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) + async def ensure_collection(self, collection_name, dim): + async with self._cache_lock: + if collection_name in self._known_collections: + return + exists = await asyncio.to_thread( + self.qdrant.collection_exists, collection_name + ) + if not exists: + logger.info( + f"Lazily creating Qdrant collection {collection_name} " + f"with dimension {dim}" + ) + await asyncio.to_thread( + self.qdrant.create_collection, + collection_name=collection_name, + vectors_config=VectorParams( + size=dim, distance=Distance.COSINE + ), + ) + self._known_collections.add(collection_name) + async def store_document_embeddings(self, workspace, message): - # Validate collection exists in config before processing if not self.collection_exists(workspace, message.metadata.collection): logger.warning( f"Collection {message.metadata.collection} for workspace {workspace} " @@ -60,24 +84,15 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): if not vec: continue - # Create collection name with dimension suffix for lazy creation dim = len(vec) collection = ( f"d_{workspace}_{message.metadata.collection}_{dim}" ) - # Lazily create collection if it doesn't exist (but only if authorized in config) - if not self.qdrant.collection_exists(collection): - logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, - distance=Distance.COSINE - ) - ) + await self.ensure_collection(collection, dim) - self.qdrant.upsert( + await asyncio.to_thread( + self.qdrant.upsert, collection_name=collection, points=[ PointStruct( @@ -87,7 +102,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): "chunk_id": chunk_id, } ) - ] + ], ) @staticmethod @@ -124,8 +139,9 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): try: prefix = f"d_{workspace}_{collection}_" - # Get all collections and filter for matches - all_collections = self.qdrant.get_collections().collections + all_collections = await asyncio.to_thread( + lambda: self.qdrant.get_collections().collections + ) matching_collections = [ coll.name for coll in all_collections if coll.name.startswith(prefix) @@ -135,7 +151,11 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): logger.info(f"No collections found matching prefix {prefix}") else: for collection_name in matching_collections: - self.qdrant.delete_collection(collection_name) + await asyncio.to_thread( + self.qdrant.delete_collection, collection_name + ) + async with self._cache_lock: + self._known_collections.discard(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}") diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index 391c2a04..13dcdba8 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -3,11 +3,13 @@ Accepts entity/vector pairs and writes them to a Qdrant store. """ +import asyncio +import uuid +import logging + from qdrant_client import QdrantClient from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams -import uuid -import logging from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer @@ -50,13 +52,35 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): ) self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + self._cache_lock = asyncio.Lock() + self._known_collections: set[str] = set() # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) + async def ensure_collection(self, collection_name, dim): + async with self._cache_lock: + if collection_name in self._known_collections: + return + exists = await asyncio.to_thread( + self.qdrant.collection_exists, collection_name + ) + if not exists: + logger.info( + f"Lazily creating Qdrant collection {collection_name} " + f"with dimension {dim}" + ) + await asyncio.to_thread( + self.qdrant.create_collection, + collection_name=collection_name, + vectors_config=VectorParams( + size=dim, distance=Distance.COSINE + ), + ) + self._known_collections.add(collection_name) + async def store_graph_embeddings(self, workspace, message): - # Validate collection exists in config before processing if not self.collection_exists(workspace, message.metadata.collection): logger.warning( f"Collection {message.metadata.collection} for workspace {workspace} " @@ -75,22 +99,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): if not vec: continue - # Create collection name with dimension suffix for lazy creation dim = len(vec) collection = ( f"t_{workspace}_{message.metadata.collection}_{dim}" ) - # Lazily create collection if it doesn't exist (but only if authorized in config) - if not self.qdrant.collection_exists(collection): - logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, - distance=Distance.COSINE - ) - ) + await self.ensure_collection(collection, dim) payload = { "entity": entity_value, @@ -98,7 +112,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): if entity.chunk_id: payload["chunk_id"] = entity.chunk_id - self.qdrant.upsert( + await asyncio.to_thread( + self.qdrant.upsert, collection_name=collection, points=[ PointStruct( @@ -106,7 +121,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): vector=vec, payload=payload, ) - ] + ], ) @staticmethod @@ -143,8 +158,9 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): try: prefix = f"t_{workspace}_{collection}_" - # Get all collections and filter for matches - all_collections = self.qdrant.get_collections().collections + all_collections = await asyncio.to_thread( + lambda: self.qdrant.get_collections().collections + ) matching_collections = [ coll.name for coll in all_collections if coll.name.startswith(prefix) @@ -154,7 +170,11 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): logger.info(f"No collections found matching prefix {prefix}") else: for collection_name in matching_collections: - self.qdrant.delete_collection(collection_name) + await asyncio.to_thread( + self.qdrant.delete_collection, collection_name + ) + async with self._cache_lock: + self._known_collections.discard(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}") diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py index 32d87871..a01629c5 100644 --- a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py @@ -16,10 +16,10 @@ Payload structure: - text: The text that was embedded (for debugging/display) """ +import asyncio import logging import re import uuid -from typing import Set, Tuple from qdrant_client import QdrantClient from qdrant_client.models import PointStruct, Distance, VectorParams @@ -63,11 +63,9 @@ class Processor(CollectionConfigHandler, FlowProcessor): # Register config handler for collection management self.register_config_handler(self.on_collection_config, types=["collection"]) - # Cache of created Qdrant collections - self.created_collections: Set[str] = set() - - # Qdrant client self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + self._cache_lock = asyncio.Lock() + self._known_collections: set[str] = set() def sanitize_name(self, name: str) -> str: """Sanitize names for Qdrant collection naming""" @@ -85,25 +83,28 @@ class Processor(CollectionConfigHandler, FlowProcessor): safe_schema = self.sanitize_name(schema_name) return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}" - def ensure_collection(self, collection_name: str, dimension: int): + async def ensure_collection(self, collection_name: str, dimension: int): """Create Qdrant collection if it doesn't exist""" - if collection_name in self.created_collections: - return - - if not self.qdrant.collection_exists(collection_name): - logger.info( - f"Creating Qdrant collection {collection_name} " - f"with dimension {dimension}" + async with self._cache_lock: + if collection_name in self._known_collections: + return + exists = await asyncio.to_thread( + self.qdrant.collection_exists, collection_name ) - self.qdrant.create_collection( - collection_name=collection_name, - vectors_config=VectorParams( - size=dimension, - distance=Distance.COSINE + if not exists: + logger.info( + f"Creating Qdrant collection {collection_name} " + f"with dimension {dimension}" ) - ) - - self.created_collections.add(collection_name) + await asyncio.to_thread( + self.qdrant.create_collection, + collection_name=collection_name, + vectors_config=VectorParams( + size=dimension, + distance=Distance.COSINE + ), + ) + self._known_collections.add(collection_name) async def on_embeddings(self, msg, consumer, flow): """Process incoming RowEmbeddings and write to Qdrant""" @@ -143,15 +144,14 @@ class Processor(CollectionConfigHandler, FlowProcessor): dimension = len(vector) - # Create/get collection name (lazily on first vector) if qdrant_collection is None: qdrant_collection = self.get_collection_name( workspace, collection, schema_name, dimension ) - self.ensure_collection(qdrant_collection, dimension) + await self.ensure_collection(qdrant_collection, dimension) - # Write to Qdrant - self.qdrant.upsert( + await asyncio.to_thread( + self.qdrant.upsert, collection_name=qdrant_collection, points=[ PointStruct( @@ -163,7 +163,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): "text": row_emb.text } ) - ] + ], ) embeddings_written += 1 @@ -181,8 +181,9 @@ class Processor(CollectionConfigHandler, FlowProcessor): try: prefix = f"rows_{self.sanitize_name(workspace)}_{self.sanitize_name(collection)}_" - # Get all collections and filter for matches - all_collections = self.qdrant.get_collections().collections + all_collections = await asyncio.to_thread( + lambda: self.qdrant.get_collections().collections + ) matching_collections = [ coll.name for coll in all_collections if coll.name.startswith(prefix) @@ -192,8 +193,11 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.info(f"No Qdrant collections found matching prefix {prefix}") else: for collection_name in matching_collections: - self.qdrant.delete_collection(collection_name) - self.created_collections.discard(collection_name) + await asyncio.to_thread( + self.qdrant.delete_collection, collection_name + ) + async with self._cache_lock: + self._known_collections.discard(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") logger.info( f"Deleted {len(matching_collections)} collection(s) " @@ -217,8 +221,9 @@ class Processor(CollectionConfigHandler, FlowProcessor): f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_" ) - # Get all collections and filter for matches - all_collections = self.qdrant.get_collections().collections + all_collections = await asyncio.to_thread( + lambda: self.qdrant.get_collections().collections + ) matching_collections = [ coll.name for coll in all_collections if coll.name.startswith(prefix) @@ -228,8 +233,11 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.info(f"No Qdrant collections found matching prefix {prefix}") else: for collection_name in matching_collections: - self.qdrant.delete_collection(collection_name) - self.created_collections.discard(collection_name) + await asyncio.to_thread( + self.qdrant.delete_collection, collection_name + ) + async with self._cache_lock: + self._known_collections.discard(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") except Exception as e: diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py index a5dad748..65eeee06 100755 --- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py @@ -82,7 +82,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): # Cache of known keyspaces and whether tables exist self.known_keyspaces: Set[str] = set() - self.tables_initialized: Set[str] = set() # keyspaces with rows/row_partitions tables + self.tables_initialized: Set[str] = set() # Cache of registered (collection, schema_name) pairs self.registered_partitions: Set[Tuple[str, str]] = set() @@ -94,6 +94,9 @@ class Processor(CollectionConfigHandler, FlowProcessor): self.cluster = None self.session = None + # Protects connection setup and cache mutations + self._setup_lock = asyncio.Lock() + def connect_cassandra(self): """Connect to Cassandra cluster""" if self.session: @@ -126,6 +129,11 @@ class Processor(CollectionConfigHandler, FlowProcessor): f"for workspace {workspace}" ) + async with self._setup_lock: + return await self._apply_schema_config(workspace, config, version) + + async def _apply_schema_config(self, workspace, config, version): + # Track which schemas changed in this workspace old_schemas = self.schemas.get(workspace, {}) old_schema_names = set(old_schemas.keys()) @@ -391,16 +399,12 @@ class Processor(CollectionConfigHandler, FlowProcessor): schema_name = obj.schema_name source = getattr(obj.metadata, 'source', '') or '' - # Ensure tables exist (sync DDL — push to a worker thread - # so the event loop stays responsive when running in a - # processor group sharing the loop with siblings). - await asyncio.to_thread(self.ensure_tables, keyspace) - - # Register partitions if first time seeing this (collection, schema_name) - await asyncio.to_thread( - self.register_partitions, - keyspace, collection, schema_name, workspace, - ) + async with self._setup_lock: + await asyncio.to_thread(self.ensure_tables, keyspace) + await asyncio.to_thread( + self.register_partitions, + keyspace, collection, schema_name, workspace, + ) safe_keyspace = self.sanitize_name(keyspace) @@ -461,35 +465,27 @@ class Processor(CollectionConfigHandler, FlowProcessor): async def create_collection(self, workspace: str, collection: str, metadata: dict): """Create/verify collection exists in Cassandra row store""" - # Connect if not already connected (sync, push to thread) - await asyncio.to_thread(self.connect_cassandra) - - # Ensure tables exist (sync DDL, push to thread) - await asyncio.to_thread(self.ensure_tables, workspace) + async with self._setup_lock: + await asyncio.to_thread(self.connect_cassandra) + await asyncio.to_thread(self.ensure_tables, workspace) logger.info(f"Collection {collection} ready for workspace {workspace}") async def delete_collection(self, workspace: str, collection: str): """Delete all data for a specific collection using partition tracking""" - # Connect if not already connected - await asyncio.to_thread(self.connect_cassandra) + async with self._setup_lock: + await asyncio.to_thread(self.connect_cassandra) + if workspace not in self.known_keyspaces: + safe_ks = self.sanitize_name(workspace) + check_cql = "SELECT keyspace_name FROM system_schema.keyspaces WHERE keyspace_name = %s" + result = await async_execute(self.session, check_cql, (safe_ks,)) + if not result: + logger.info(f"Keyspace {safe_ks} does not exist, nothing to delete") + return + self.known_keyspaces.add(workspace) safe_keyspace = self.sanitize_name(workspace) - # Check if keyspace exists - if workspace not in self.known_keyspaces: - check_keyspace_cql = """ - SELECT keyspace_name FROM system_schema.keyspaces - WHERE keyspace_name = %s - """ - result = await async_execute( - self.session, check_keyspace_cql, (safe_keyspace,) - ) - if not result: - logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete") - return - self.known_keyspaces.add(workspace) - # Discover all partitions for this collection select_partitions_cql = f""" SELECT schema_name, index_name FROM {safe_keyspace}.row_partitions @@ -540,11 +536,11 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.error(f"Failed to clean up row_partitions for {collection}: {e}") raise - # Clear from local cache - self.registered_partitions = { - (col, sch) for col, sch in self.registered_partitions - if col != collection - } + async with self._setup_lock: + self.registered_partitions = { + (col, sch) for col, sch in self.registered_partitions + if col != collection + } logger.info( f"Deleted collection {collection}: {partitions_deleted} partitions " @@ -553,8 +549,8 @@ class Processor(CollectionConfigHandler, FlowProcessor): async def delete_collection_schema(self, workspace: str, collection: str, schema_name: str): """Delete all data for a specific collection + schema combination""" - # Connect if not already connected - await asyncio.to_thread(self.connect_cassandra) + async with self._setup_lock: + await asyncio.to_thread(self.connect_cassandra) safe_keyspace = self.sanitize_name(workspace) @@ -614,8 +610,8 @@ class Processor(CollectionConfigHandler, FlowProcessor): ) raise - # Clear from local cache - self.registered_partitions.discard((collection, schema_name)) + async with self._setup_lock: + self.registered_partitions.discard((collection, schema_name)) logger.info( f"Deleted {collection}/{schema_name}: {partitions_deleted} partitions " diff --git a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py index 0774153b..79d6c549 100755 --- a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py @@ -4,12 +4,7 @@ Graph writer. Input is graph edge. Writes edges to Cassandra graph. """ import asyncio -import base64 -import os -import argparse -import time import logging -import json from .... direct.cassandra_kg import ( EntityCentricKnowledgeGraph, DEFAULT_GRAPH @@ -28,6 +23,8 @@ default_ident = "triples-write" def serialize_triple(triple): """Serialize a Triple object to JSON for storage.""" + import json + if triple is None: return None @@ -141,156 +138,84 @@ class Processor(CollectionConfigHandler, TriplesStoreService): self.cassandra_host = hosts self.cassandra_username = username self.cassandra_password = password - self.table = None - self.tg = None + + self._connections = {} + self._conn_lock = asyncio.Lock() # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) + async def _get_connection(self, workspace): + async with self._conn_lock: + if workspace not in self._connections: + if self.cassandra_username and self.cassandra_password: + tg = await asyncio.to_thread( + EntityCentricKnowledgeGraph, + hosts=self.cassandra_host, + keyspace=workspace, + username=self.cassandra_username, + password=self.cassandra_password, + ) + else: + tg = await asyncio.to_thread( + EntityCentricKnowledgeGraph, + hosts=self.cassandra_host, + keyspace=workspace, + ) + self._connections[workspace] = tg + return self._connections[workspace] + async def store_triples(self, workspace, message): - # The cassandra-driver work below — connection, schema - # setup, and per-triple inserts — is all synchronous. - # Wrap the whole batch in a worker thread so the event - # loop stays responsive for sibling processors when - # running in a processor group. + tg = await self._get_connection(workspace) - def _do_store(): + for t in message.triples: + s_val = get_term_value(t.s) + p_val = get_term_value(t.p) + o_val = get_term_value(t.o) + g_val = t.g if t.g is not None else DEFAULT_GRAPH - if self.table is None or self.table != workspace: + otype = get_term_otype(t.o) + dtype = get_term_dtype(t.o) + lang = get_term_lang(t.o) - self.tg = None - - # Use factory function to select implementation - KGClass = EntityCentricKnowledgeGraph - - try: - if self.cassandra_username and self.cassandra_password: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - username=self.cassandra_username, - password=self.cassandra_password, - ) - else: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - ) - except Exception as e: - logger.error(f"Exception: {e}", exc_info=True) - time.sleep(1) - raise e - - self.table = workspace - - for t in message.triples: - # Extract values from Term objects - s_val = get_term_value(t.s) - p_val = get_term_value(t.p) - o_val = get_term_value(t.o) - # t.g is None for default graph, or a graph IRI - g_val = t.g if t.g is not None else DEFAULT_GRAPH - - # Extract object type metadata for entity-centric storage - otype = get_term_otype(t.o) - dtype = get_term_dtype(t.o) - lang = get_term_lang(t.o) - - self.tg.insert( - message.metadata.collection, - s_val, - p_val, - o_val, - g=g_val, - otype=otype, - dtype=dtype, - lang=lang, - ) - - await asyncio.to_thread(_do_store) + await tg.async_insert( + message.metadata.collection, + s_val, + p_val, + o_val, + g=g_val, + otype=otype, + dtype=dtype, + lang=lang, + ) async def create_collection(self, workspace: str, collection: str, metadata: dict): """Create a collection in Cassandra triple store via config push""" + try: + tg = await self._get_connection(workspace) - def _do_create(): - # Create or reuse connection for this workspace's keyspace - if self.table is None or self.table != workspace: - self.tg = None - - # Use factory function to select implementation - KGClass = EntityCentricKnowledgeGraph - - try: - if self.cassandra_username and self.cassandra_password: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - username=self.cassandra_username, - password=self.cassandra_password, - ) - else: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - ) - except Exception as e: - logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}") - raise - - self.table = workspace - - # Create collection using the built-in method logger.info(f"Creating collection {collection} for workspace {workspace}") - if self.tg.collection_exists(collection): + exists = await tg.async_collection_exists(collection) + if exists: logger.info(f"Collection {collection} already exists") else: - self.tg.create_collection(collection) + await tg.async_create_collection(collection) logger.info(f"Created collection {collection}") - try: - await asyncio.to_thread(_do_create) except Exception as e: logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise async def delete_collection(self, workspace: str, collection: str): """Delete all data for a specific collection from the unified triples table""" + try: + tg = await self._get_connection(workspace) - def _do_delete(): - # Create or reuse connection for this workspace's keyspace - if self.table is None or self.table != workspace: - self.tg = None - - # Use factory function to select implementation - KGClass = EntityCentricKnowledgeGraph - - try: - if self.cassandra_username and self.cassandra_password: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - username=self.cassandra_username, - password=self.cassandra_password, - ) - else: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - ) - except Exception as e: - logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}") - raise - - self.table = workspace - - # Delete all triples for this collection using the built-in method - self.tg.delete_collection(collection) + await tg.async_delete_collection(collection) logger.info(f"Deleted all triples for collection {collection} from keyspace {workspace}") - try: - await asyncio.to_thread(_do_delete) except Exception as e: logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise diff --git a/trustgraph-flow/trustgraph/tables/iam.py b/trustgraph-flow/trustgraph/tables/iam.py index 8bf9c8b4..d7bf5e3d 100644 --- a/trustgraph-flow/trustgraph/tables/iam.py +++ b/trustgraph-flow/trustgraph/tables/iam.py @@ -435,3 +435,7 @@ class IamTableStore: async def any_workspace_exists(self): rows = await self.list_workspaces() return bool(rows) + + async def any_signing_key_exists(self): + rows = await self.list_signing_keys() + return bool(rows) diff --git a/trustgraph-flow/trustgraph/tables/knowledge.py b/trustgraph-flow/trustgraph/tables/knowledge.py index 5d45358d..cf085fdd 100644 --- a/trustgraph-flow/trustgraph/tables/knowledge.py +++ b/trustgraph-flow/trustgraph/tables/knowledge.py @@ -1,6 +1,7 @@ from .. schema import KnowledgeResponse, Triple, Triples, EntityEmbeddings from .. schema import Metadata, Term, IRI, LITERAL, GraphEmbeddings +from .. schema import DocumentEmbeddings, ChunkEmbeddings from cassandra.cluster import Cluster @@ -217,6 +218,16 @@ class KnowledgeTableStore: WHERE workspace = ? AND document_id = ? """) + self.delete_document_embeddings_stmt = self.cassandra.prepare(""" + DELETE FROM document_embeddings + WHERE workspace = ? AND document_id = ? + """) + + self.list_de_cores_stmt = self.cassandra.prepare(""" + SELECT DISTINCT workspace, document_id FROM document_embeddings + WHERE workspace = ? + """) + async def add_triples(self, workspace, m): when = int(time.time() * 1000) @@ -338,6 +349,50 @@ class KnowledgeTableStore: logger.error("Exception occurred", exc_info=True) raise + try: + await async_execute( + self.cassandra, + self.delete_document_embeddings_stmt, + (workspace, document_id), + ) + except Exception: + logger.error("Exception occurred", exc_info=True) + raise + + async def delete_document_embeddings(self, workspace, document_id): + + logger.debug("Delete document embeddings...") + + try: + await async_execute( + self.cassandra, + self.delete_document_embeddings_stmt, + (workspace, document_id), + ) + except Exception: + logger.error("Exception occurred", exc_info=True) + raise + + async def list_de_cores(self, workspace): + + logger.debug("List DE cores...") + + try: + rows = await async_execute( + self.cassandra, + self.list_de_cores_stmt, + (workspace,), + ) + except Exception: + logger.error("Exception occurred", exc_info=True) + raise + + lst = [row[1] for row in rows] + + logger.debug("Done") + + return lst + async def get_triples(self, workspace, document_id, receiver): logger.debug("Get triples...") @@ -417,3 +472,42 @@ class KnowledgeTableStore: logger.debug("Done") + async def get_document_embeddings(self, workspace, document_id, receiver): + + logger.debug("Get DE...") + + try: + rows = await async_execute( + self.cassandra, + self.get_document_embeddings_stmt, + (workspace, document_id), + ) + except Exception: + logger.error("Exception occurred", exc_info=True) + raise + + for row in rows: + + if row[3]: + chunks = [ + ChunkEmbeddings( + chunk_id=ch[0], + vector=ch[1], + ) + for ch in row[3] + ] + else: + chunks = [] + + await receiver( + DocumentEmbeddings( + metadata = Metadata( + id = document_id, + collection = "default", + ), + chunks = chunks + ) + ) + + logger.debug("Done") + diff --git a/trustgraph-flow/trustgraph/template/prompt_manager.py b/trustgraph-flow/trustgraph/template/prompt_manager.py index 546a7faf..976d3695 100644 --- a/trustgraph-flow/trustgraph/template/prompt_manager.py +++ b/trustgraph-flow/trustgraph/template/prompt_manager.py @@ -31,12 +31,12 @@ class PromptManager: try: system = json.loads(config["system"]) - except: + except (KeyError, TypeError, json.JSONDecodeError): system = "Be helpful." try: ix = json.loads(config["template-index"]) - except: + except (KeyError, TypeError, json.JSONDecodeError): ix = [] prompts = {} @@ -68,8 +68,8 @@ class PromptManager: try: self.system_template = ibis.Template(self.config.system_template) - except: - raise RuntimeError("Error in system template") + except Exception as e: + raise RuntimeError(f"Error in system template: {e}") self.templates = {} for k, v in self.prompts.items(): @@ -136,8 +136,6 @@ class PromptManager: terms = self.terms | self.prompts[id].terms | input - resp_type = self.prompts[id].response_type - return self.templates[id].render(terms) async def invoke(self, id, input, llm): @@ -161,7 +159,7 @@ class PromptManager: if resp_type == "json": try: obj = self.parse_json(resp) - except: + except (json.JSONDecodeError, TypeError): logger.error(f"JSON parse failed: {resp}") raise RuntimeError("JSON parse fail") @@ -195,4 +193,3 @@ class PromptManager: return objects raise RuntimeError(f"Response type {resp_type} not known") - diff --git a/trustgraph-ocr/pyproject.toml b/trustgraph-ocr/pyproject.toml index 1718258f..fa9f7cd4 100644 --- a/trustgraph-ocr/pyproject.toml +++ b/trustgraph-ocr/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.4,<2.5", + "trustgraph-base>=2.5,<2.6", "pulsar-client", "prometheus-client", "boto3", diff --git a/trustgraph-unstructured/pyproject.toml b/trustgraph-unstructured/pyproject.toml index 7169fc8b..f17b9812 100644 --- a/trustgraph-unstructured/pyproject.toml +++ b/trustgraph-unstructured/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.4,<2.5", + "trustgraph-base>=2.5,<2.6", "pulsar-client", "prometheus-client", "python-magic", diff --git a/trustgraph-vertexai/pyproject.toml b/trustgraph-vertexai/pyproject.toml index f43f154d..347594fe 100644 --- a/trustgraph-vertexai/pyproject.toml +++ b/trustgraph-vertexai/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.4,<2.5", + "trustgraph-base>=2.5,<2.6", "pulsar-client", "google-genai", "google-api-core", diff --git a/trustgraph/pyproject.toml b/trustgraph/pyproject.toml index dc896700..bcc72a41 100644 --- a/trustgraph/pyproject.toml +++ b/trustgraph/pyproject.toml @@ -10,13 +10,13 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.4,<2.5", - "trustgraph-bedrock>=2.4,<2.5", - "trustgraph-cli>=2.4,<2.5", - "trustgraph-embeddings-hf>=2.4,<2.5", - "trustgraph-flow>=2.4,<2.5", - "trustgraph-unstructured>=2.4,<2.5", - "trustgraph-vertexai>=2.4,<2.5", + "trustgraph-base>=2.5,<2.6", + "trustgraph-bedrock>=2.5,<2.6", + "trustgraph-cli>=2.5,<2.6", + "trustgraph-embeddings-hf>=2.5,<2.6", + "trustgraph-flow>=2.5,<2.6", + "trustgraph-unstructured>=2.5,<2.6", + "trustgraph-vertexai>=2.5,<2.6", ] classifiers = [ "Programming Language :: Python :: 3",