diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index 359a8c72..28b21772 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -15,14 +15,14 @@ jobs: runs-on: ubuntu-latest container: - image: python:3.12 + image: python:3.13 steps: - name: Checkout uses: actions/checkout@v3 - name: Setup packages - run: make update-package-versions VERSION=1.4.999 + run: make update-package-versions VERSION=1.5.999 - name: Setup environment run: python3 -m venv env diff --git a/Containerfile b/Containerfile deleted file mode 100644 index 7283a06b..00000000 --- a/Containerfile +++ /dev/null @@ -1,84 +0,0 @@ - -# ---------------------------------------------------------------------------- -# Build an AI container. This does the torch install which is huge, and I -# like to avoid re-doing this. -# ---------------------------------------------------------------------------- - -FROM docker.io/fedora:40 AS ai - -ENV PIP_BREAK_SYSTEM_PACKAGES=1 - -RUN dnf install -y python3 python3-pip python3-wheel python3-aiohttp \ - python3-rdflib - -RUN pip3 install torch==2.5.1+cpu \ - --index-url https://download.pytorch.org/whl/cpu - -RUN pip3 install \ - anthropic boto3 cohere mistralai openai google-cloud-aiplatform \ - ollama google-generativeai \ - langchain==0.3.13 langchain-core==0.3.28 langchain-huggingface==0.1.2 \ - langchain-text-splitters==0.3.4 \ - langchain-community==0.3.13 \ - sentence-transformers==3.4.0 transformers==4.47.1 \ - huggingface-hub==0.27.0 \ - pymilvus \ - pulsar-client==3.5.0 cassandra-driver pyyaml \ - neo4j tiktoken falkordb && \ - pip3 cache purge - -# Most commonly used embeddings model, just build it into the container -# image -RUN huggingface-cli download sentence-transformers/all-MiniLM-L6-v2 - -# ---------------------------------------------------------------------------- -# Build a container which contains the built Python packages. The build -# creates a bunch of left-over cruft, a separate phase means this is only -# needed to support package build -# ---------------------------------------------------------------------------- - -FROM ai AS build - -COPY trustgraph-base/ /root/build/trustgraph-base/ -COPY trustgraph-flow/ /root/build/trustgraph-flow/ -COPY trustgraph-vertexai/ /root/build/trustgraph-vertexai/ -COPY trustgraph-bedrock/ /root/build/trustgraph-bedrock/ -COPY trustgraph-embeddings-hf/ /root/build/trustgraph-embeddings-hf/ -COPY trustgraph-cli/ /root/build/trustgraph-cli/ -COPY trustgraph-ocr/ /root/build/trustgraph-ocr/ - -WORKDIR /root/build/ - -RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-base/ -RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-flow/ -RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-vertexai/ -RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-bedrock/ -RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-embeddings-hf/ -RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-cli/ -RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-ocr/ - -RUN ls /root/wheels - -# ---------------------------------------------------------------------------- -# Finally, the target container. Start with base and add the package. 
-# ---------------------------------------------------------------------------- - -FROM ai - -COPY --from=build /root/wheels /root/wheels - -RUN \ - pip3 install /root/wheels/trustgraph_base-* && \ - pip3 install /root/wheels/trustgraph_flow-* && \ - pip3 install /root/wheels/trustgraph_vertexai-* && \ - pip3 install /root/wheels/trustgraph_bedrock-* && \ - pip3 install /root/wheels/trustgraph_embeddings_hf-* && \ - pip3 install /root/wheels/trustgraph_cli-* && \ - pip3 install /root/wheels/trustgraph_ocr-* && \ - pip3 cache purge && \ - rm -rf /root/wheels - -WORKDIR / - -CMD sleep 1000000 - diff --git a/containers/Containerfile.base b/containers/Containerfile.base index 067b4c2c..fa5d653c 100644 --- a/containers/Containerfile.base +++ b/containers/Containerfile.base @@ -8,8 +8,8 @@ FROM docker.io/fedora:42 AS base ENV PIP_BREAK_SYSTEM_PACKAGES=1 -RUN dnf install -y python3.12 && \ - alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \ +RUN dnf install -y python3.13 && \ + alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \ python -m ensurepip --upgrade && \ pip3 install --no-cache-dir build wheel aiohttp && \ pip3 install --no-cache-dir pulsar-client==3.7.0 && \ diff --git a/containers/Containerfile.bedrock b/containers/Containerfile.bedrock index a35d12ad..b9ab99ac 100644 --- a/containers/Containerfile.bedrock +++ b/containers/Containerfile.bedrock @@ -8,8 +8,8 @@ FROM docker.io/fedora:42 AS base ENV PIP_BREAK_SYSTEM_PACKAGES=1 -RUN dnf install -y python3.12 && \ - alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \ +RUN dnf install -y python3.13 && \ + alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \ python -m ensurepip --upgrade && \ pip3 install --no-cache-dir build wheel aiohttp && \ pip3 install --no-cache-dir pulsar-client==3.7.0 && \ diff --git a/containers/Containerfile.flow b/containers/Containerfile.flow index 2ffa17d3..8b20050d 100644 --- a/containers/Containerfile.flow +++ b/containers/Containerfile.flow @@ -8,8 +8,8 @@ FROM docker.io/fedora:42 AS base ENV PIP_BREAK_SYSTEM_PACKAGES=1 -RUN dnf install -y python3.12 && \ - alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \ +RUN dnf install -y python3.13 && \ + alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \ python -m ensurepip --upgrade && \ pip3 install --no-cache-dir build wheel aiohttp rdflib && \ pip3 install --no-cache-dir pulsar-client==3.7.0 && \ @@ -22,7 +22,7 @@ RUN pip3 install --no-cache-dir \ langchain-text-splitters==0.3.8 \ langchain-community==0.3.24 \ pymilvus \ - pulsar-client==3.7.0 cassandra-driver pyyaml \ + pulsar-client==3.7.0 scylla-driver pyyaml \ neo4j tiktoken falkordb && \ pip3 cache purge diff --git a/containers/Containerfile.hf b/containers/Containerfile.hf index b76179ff..351300ae 100644 --- a/containers/Containerfile.hf +++ b/containers/Containerfile.hf @@ -8,8 +8,8 @@ FROM docker.io/fedora:42 AS ai ENV PIP_BREAK_SYSTEM_PACKAGES=1 -RUN dnf install -y python3.12 && \ - alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \ +RUN dnf install -y python3.13 && \ + alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \ python -m ensurepip --upgrade && \ pip3 install --no-cache-dir build wheel aiohttp && \ pip3 install --no-cache-dir pulsar-client==3.7.0 && \ diff --git a/containers/Containerfile.mcp b/containers/Containerfile.mcp index 2377a663..389b919e 100644 --- a/containers/Containerfile.mcp +++ b/containers/Containerfile.mcp @@ 
-8,8 +8,8 @@ FROM docker.io/fedora:42 AS base ENV PIP_BREAK_SYSTEM_PACKAGES=1 -RUN dnf install -y python3.12 && \ - alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \ +RUN dnf install -y python3.13 && \ + alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \ python -m ensurepip --upgrade && \ pip3 install --no-cache-dir mcp websockets && \ dnf clean all diff --git a/containers/Containerfile.ocr b/containers/Containerfile.ocr index bb1f3ae2..41655e42 100644 --- a/containers/Containerfile.ocr +++ b/containers/Containerfile.ocr @@ -8,9 +8,9 @@ FROM docker.io/fedora:42 AS base ENV PIP_BREAK_SYSTEM_PACKAGES=1 -RUN dnf install -y python3.12 && \ +RUN dnf install -y python3.13 && \ dnf install -y tesseract poppler-utils && \ - alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \ + alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \ python -m ensurepip --upgrade && \ pip3 install --no-cache-dir build wheel aiohttp && \ pip3 install --no-cache-dir pulsar-client==3.7.0 && \ diff --git a/containers/Containerfile.vertexai b/containers/Containerfile.vertexai index 9a4bd15f..bf0e55da 100644 --- a/containers/Containerfile.vertexai +++ b/containers/Containerfile.vertexai @@ -8,8 +8,8 @@ FROM docker.io/fedora:42 AS base ENV PIP_BREAK_SYSTEM_PACKAGES=1 -RUN dnf install -y python3.12 && \ - alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \ +RUN dnf install -y python3.13 && \ + alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \ python -m ensurepip --upgrade && \ pip3 install --no-cache-dir build wheel aiohttp && \ pip3 install --no-cache-dir pulsar-client==3.7.0 && \ diff --git a/docs/cli/tg-set-mcp-tool.md b/docs/cli/tg-set-mcp-tool.md index 6d693e6e..90f137a0 100644 --- a/docs/cli/tg-set-mcp-tool.md +++ b/docs/cli/tg-set-mcp-tool.md @@ -3,12 +3,12 @@ ## Synopsis ``` -tg-set-mcp-tool [OPTIONS] --name NAME --tool-url URL +tg-set-mcp-tool [OPTIONS] --id ID --tool-url URL [--auth-token TOKEN] ``` ## Description -The `tg-set-mcp-tool` command configures and registers MCP (Model Control Protocol) tools in the TrustGraph system. It allows defining MCP tool configurations with name and URL. Tools are stored in the 'mcp' configuration group for discovery and execution. +The `tg-set-mcp-tool` command configures and registers MCP (Model Context Protocol) tools in the TrustGraph system. It allows defining MCP tool configurations with id, URL, and optional authentication token. Tools are stored in the 'mcp' configuration group for discovery and execution.
This command is useful for: - Registering MCP tool endpoints for agent use @@ -25,16 +25,27 @@ The command stores MCP tool configurations in the 'mcp' configuration group, sep - Default: `http://localhost:8088/` (or `TRUSTGRAPH_URL` environment variable) - Should point to a running TrustGraph API instance -- `--name NAME` - - **Required.** MCP tool name identifier +- `-i, --id ID` + - **Required.** MCP tool identifier - Used to reference the MCP tool in configurations - Must be unique within the MCP tool registry +- `-r, --remote-name NAME` + - **Optional.** Remote MCP tool name used by the MCP server + - If not specified, defaults to the value of `--id` + - Use when the MCP server expects a different tool name + - `--tool-url URL` - **Required.** MCP tool URL endpoint - Should point to the MCP server endpoint providing the tool functionality - Must be a valid URL accessible by the TrustGraph system +- `--auth-token TOKEN` + - **Optional.** Bearer token for authentication + - Used to authenticate with secured MCP endpoints + - Token is sent as `Authorization: Bearer {TOKEN}` header + - Stored in plaintext in configuration (see Security Considerations) + - `-h, --help` - Show help message and exit @@ -44,54 +55,96 @@ The command stores MCP tool configurations in the 'mcp' configuration group, sep Register a weather service MCP tool: ```bash -tg-set-mcp-tool --name weather --tool-url "http://localhost:3000/weather" +tg-set-mcp-tool --id weather --tool-url "http://localhost:3000/weather" ``` ### Calculator MCP Tool Register a calculator MCP tool: ```bash -tg-set-mcp-tool --name calculator --tool-url "http://mcp-tools.example.com/calc" +tg-set-mcp-tool --id calculator --tool-url "http://mcp-tools.example.com/calc" ``` ### Remote MCP Service Register a remote MCP service: ```bash -tg-set-mcp-tool --name document-processor \ +tg-set-mcp-tool --id document-processor \ --tool-url "https://api.example.com/mcp/documents" ``` +### Secured MCP Tool with Authentication + +Register an MCP tool that requires bearer token authentication: +```bash +tg-set-mcp-tool --id secure-tool \ + --tool-url "https://api.example.com/mcp" \ + --auth-token "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." +``` + +### MCP Tool with Remote Name + +Register an MCP tool where the server uses a different name: +```bash +tg-set-mcp-tool --id my-weather \ + --remote-name weather_v2 \ + --tool-url "http://weather-server:3000/api" +``` + ### Custom API URL Register MCP tool with custom TrustGraph API: ```bash tg-set-mcp-tool -u http://trustgraph.example.com:8088/ \ - --name custom-mcp --tool-url "http://custom.mcp.com/api" + --id custom-mcp --tool-url "http://custom.mcp.com/api" ``` ### Local Development Setup Register MCP tools for local development: ```bash -tg-set-mcp-tool --name dev-tool --tool-url "http://localhost:8080/mcp" +tg-set-mcp-tool --id dev-tool --tool-url "http://localhost:8080/mcp" +``` + +### Production Setup with Authentication + +Register authenticated MCP tools for production: +```bash +# Using environment variable for token +export MCP_AUTH_TOKEN="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." 
+tg-set-mcp-tool --id prod-tool \ + --tool-url "https://prod-mcp.example.com/api" \ + --auth-token "$MCP_AUTH_TOKEN" ``` ## MCP Tool Configuration -MCP tools are configured with minimal metadata: +MCP tools are configured with the following metadata: -- **name**: Unique identifier for the tool +- **id**: Unique identifier for the tool (configuration key) +- **remote-name**: Name used by the MCP server (optional, defaults to id) - **url**: Endpoint URL for the MCP server +- **auth-token**: Bearer token for authentication (optional) The configuration is stored as JSON in the 'mcp' configuration group: + +**Basic configuration:** ```json { - "name": "weather", + "remote-name": "weather", "url": "http://localhost:3000/weather" } ``` +**Configuration with authentication:** +```json +{ + "remote-name": "secure-tool", + "url": "https://api.example.com/mcp", + "auth-token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." +} +``` + ## Advanced Usage ### Updating Existing MCP Tools @@ -99,7 +152,15 @@ The configuration is stored as JSON in the 'mcp' configuration group: Update an existing MCP tool configuration: ```bash # Update MCP tool URL -tg-set-mcp-tool --name weather --tool-url "http://new-weather-server:3000/api" +tg-set-mcp-tool --id weather --tool-url "http://new-weather-server:3000/api" + +# Add authentication to existing tool +tg-set-mcp-tool --id weather \ + --tool-url "http://weather-server:3000/api" \ + --auth-token "new-token-here" + +# Remove authentication (by setting tool without auth-token) +tg-set-mcp-tool --id weather --tool-url "http://weather-server:3000/api" ``` ### Batch MCP Tool Registration @@ -108,22 +169,33 @@ Register multiple MCP tools in a script: ```bash #!/bin/bash # Register a suite of MCP tools -tg-set-mcp-tool --name search --tool-url "http://search-mcp:3000/api" -tg-set-mcp-tool --name translate --tool-url "http://translate-mcp:3000/api" -tg-set-mcp-tool --name summarize --tool-url "http://summarize-mcp:3000/api" +tg-set-mcp-tool --id search --tool-url "http://search-mcp:3000/api" +tg-set-mcp-tool --id translate --tool-url "http://translate-mcp:3000/api" +tg-set-mcp-tool --id summarize --tool-url "http://summarize-mcp:3000/api" + +# Register secured tools with authentication +tg-set-mcp-tool --id secure-search \ + --tool-url "https://secure-search:3000/api" \ + --auth-token "$SEARCH_TOKEN" +tg-set-mcp-tool --id secure-translate \ + --tool-url "https://secure-translate:3000/api" \ + --auth-token "$TRANSLATE_TOKEN" ``` ### Environment-Specific Configuration Configure MCP tools for different environments: ```bash -# Development environment +# Development environment (no auth) export TRUSTGRAPH_URL="http://dev.trustgraph.com:8088/" -tg-set-mcp-tool --name dev-mcp --tool-url "http://dev.mcp.com/api" +tg-set-mcp-tool --id dev-mcp --tool-url "http://dev.mcp.com/api" -# Production environment +# Production environment (with auth) export TRUSTGRAPH_URL="http://prod.trustgraph.com:8088/" -tg-set-mcp-tool --name prod-mcp --tool-url "http://prod.mcp.com/api" +export PROD_MCP_TOKEN="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." 
+tg-set-mcp-tool --id prod-mcp \ + --tool-url "https://prod.mcp.com/api" \ + --auth-token "$PROD_MCP_TOKEN" ``` ### MCP Tool Validation @@ -131,10 +203,10 @@ tg-set-mcp-tool --name prod-mcp --tool-url "http://prod.mcp.com/api" Verify MCP tool registration: ```bash # Register MCP tool and verify -tg-set-mcp-tool --name test-mcp --tool-url "http://test.mcp.com/api" +tg-set-mcp-tool --id test-mcp --tool-url "http://test.mcp.com/api" -# Check if MCP tool was registered -tg-show-mcp-tools | grep test-mcp +# Check if MCP tool was registered and view auth status +tg-show-mcp-tools ``` ## Error Handling @@ -149,15 +221,15 @@ The command handles various error conditions: Common error scenarios: ```bash # Missing required field -tg-set-mcp-tool --name tool1 +tg-set-mcp-tool --id tool1 # Output: Exception: Must specify --tool-url for MCP tool -# Missing name +# Missing id tg-set-mcp-tool --tool-url "http://example.com/mcp" -# Output: Exception: Must specify --name for MCP tool +# Output: Exception: Must specify --id for MCP tool # Invalid API URL -tg-set-mcp-tool -u "invalid-url" --name tool1 --tool-url "http://mcp.com" +tg-set-mcp-tool -u "invalid-url" --id tool1 --tool-url "http://mcp.com" # Output: Exception: [API connection error] ``` @@ -168,9 +240,9 @@ tg-set-mcp-tool -u "invalid-url" --name tool1 --tool-url "http://mcp.com" View registered MCP tools: ```bash # Register MCP tool -tg-set-mcp-tool --name new-mcp --tool-url "http://new.mcp.com/api" +tg-set-mcp-tool --id new-mcp --tool-url "http://new.mcp.com/api" -# View all MCP tools +# View all MCP tools (shows auth status) tg-show-mcp-tools ``` @@ -178,11 +250,13 @@ tg-show-mcp-tools Use MCP tools in agent workflows: ```bash -# Register MCP tool -tg-set-mcp-tool --name weather --tool-url "http://weather.mcp.com/api" +# Register MCP tool with authentication +tg-set-mcp-tool --id weather \ + --tool-url "https://weather.mcp.com/api" \ + --auth-token "$WEATHER_TOKEN" -# Invoke MCP tool directly -tg-invoke-mcp-tool --name weather --input "location=London" +# Invoke MCP tool directly (auth handled automatically) +tg-invoke-mcp-tool --name weather --parameters '{"location": "London"}' ``` ### With Configuration Management @@ -190,21 +264,24 @@ tg-invoke-mcp-tool --name weather --input "location=London" MCP tools integrate with configuration management: ```bash # Register MCP tool -tg-set-mcp-tool --name config-mcp --tool-url "http://config.mcp.com/api" +tg-set-mcp-tool --id config-mcp --tool-url "http://config.mcp.com/api" -# View configuration including MCP tools -tg-show-config +# View all MCP tool configurations +tg-show-mcp-tools ``` ## Best Practices -1. **Clear Naming**: Use descriptive, unique MCP tool names +1. **Clear Naming**: Use descriptive, unique MCP tool identifiers 2. **Reliable URLs**: Ensure MCP endpoints are stable and accessible -3. **Health Checks**: Verify MCP endpoints are operational before registration -4. **Documentation**: Document MCP tool capabilities and usage -5. **Error Handling**: Implement proper error handling for MCP endpoints -6. **Security**: Use secure URLs (HTTPS) when possible -7. **Monitoring**: Monitor MCP tool availability and performance +3. **Use HTTPS**: Always use HTTPS URLs when authentication is required +4. **Secure Tokens**: Store auth tokens in environment variables, not in scripts +5. **Token Rotation**: Regularly rotate authentication tokens +6. **Health Checks**: Verify MCP endpoints are operational before registration +7. **Documentation**: Document MCP tool capabilities and usage +8. 
**Error Handling**: Implement proper error handling for MCP endpoints +9. **Monitoring**: Monitor MCP tool availability and performance +10. **Access Control**: Restrict access to configuration system containing tokens ## Troubleshooting @@ -248,10 +325,45 @@ The Model Control Protocol (MCP) is a standardized interface for AI model tools: When registering MCP tools: 1. **URL Validation**: Ensure URLs are legitimate and secure -2. **Network Security**: Use HTTPS when possible -3. **Access Control**: Implement proper authentication for MCP endpoints -4. **Input Validation**: Validate all inputs to MCP tools -5. **Error Handling**: Don't expose sensitive information in error messages +2. **Network Security**: Always use HTTPS for authenticated endpoints +3. **Token Storage**: Auth tokens are stored in plaintext in the configuration system + - Ensure proper access control on the configuration storage + - Use short-lived tokens when possible + - Implement token rotation policies +4. **Token Transmission**: Use HTTPS to prevent token interception +5. **Access Control**: Implement proper authentication for MCP endpoints +6. **Token Exposure**: + - Use environment variables to pass tokens to the command + - Don't hardcode tokens in scripts or commit them to version control + - The `tg-show-mcp-tools` command masks token values for security +7. **Input Validation**: Validate all inputs to MCP tools +8. **Error Handling**: Don't expose sensitive information in error messages +9. **Least Privilege**: Grant tokens minimum required permissions +10. **Audit Logging**: Monitor configuration changes for security events + +### Authentication Best Practices + +When using the `--auth-token` parameter: + +- **Store tokens securely**: Use environment variables or secrets management systems +- **Use HTTPS**: Always use HTTPS URLs when providing authentication tokens +- **Rotate regularly**: Implement a token rotation schedule +- **Monitor usage**: Track which services are accessing authenticated endpoints +- **Revoke on compromise**: Have a process to quickly revoke and rotate compromised tokens + +Example secure workflow: +```bash +# Store token in environment variable (not in script) +export MCP_TOKEN=$(cat /secure/path/to/token) + +# Use HTTPS for authenticated endpoints +tg-set-mcp-tool --id secure-service \ + --tool-url "https://secure.example.com/mcp" \ + --auth-token "$MCP_TOKEN" + +# Clear token from environment after use +unset MCP_TOKEN +``` ## Related Commands diff --git a/docs/tech-specs/mcp-tool-bearer-token.md b/docs/tech-specs/mcp-tool-bearer-token.md new file mode 100644 index 00000000..f3f75d29 --- /dev/null +++ b/docs/tech-specs/mcp-tool-bearer-token.md @@ -0,0 +1,554 @@ +# MCP Tool Bearer Token Authentication Specification + +> **⚠️ IMPORTANT: SINGLE-TENANT ONLY** +> +> This specification describes a **basic, service-level authentication mechanism** for MCP tools. It is **NOT** a complete authentication solution and is **NOT suitable** for: +> - Multi-user environments +> - Multi-tenant deployments +> - Federated authentication +> - User context propagation +> - Per-user authorization +> +> This feature provides **one static token per MCP tool**, shared across all users and sessions. If you need per-user or per-tenant authentication, this is not the right solution. 
+ +## Overview +**Feature Name**: MCP Tool Bearer Token Authentication Support +**Author**: Claude Code Assistant +**Date**: 2025-11-11 +**Status**: In Development + +### Executive Summary + +Enable MCP tool configurations to specify optional bearer tokens for authenticating with protected MCP servers. This allows TrustGraph to securely invoke MCP tools hosted on servers that require authentication, without modifying the agent or tool invocation interfaces. + +**IMPORTANT**: This is a basic authentication mechanism designed for single-tenant, service-to-service authentication scenarios. It is **NOT** suitable for: +- Multi-user environments where different users need different credentials +- Multi-tenant deployments requiring per-tenant isolation +- Federated authentication scenarios +- User-level authentication or authorization +- Dynamic credential management or token refresh + +This feature provides a static, system-wide bearer token per MCP tool configuration, shared across all users and invocations of that tool. + +### Problem Statement + +Currently, MCP tools can only connect to publicly accessible MCP servers. Many production MCP deployments require authentication via bearer tokens for security. Without authentication support: +- MCP tools cannot connect to secured MCP servers +- Users must either expose MCP servers publicly or implement reverse proxies +- No standardized way to pass credentials to MCP connections +- Security best practices cannot be enforced on MCP endpoints + +### Goals + +- [ ] Allow MCP tool configurations to specify optional `auth-token` parameter +- [ ] Update MCP tool service to use bearer tokens when connecting to MCP servers +- [ ] Update CLI tools to support setting/displaying auth tokens +- [ ] Maintain backward compatibility with unauthenticated MCP configurations +- [ ] Document security considerations for token storage + +### Non-Goals +- Dynamic token refresh or OAuth flows (static tokens only) +- Encryption of stored tokens (configuration system security is out of scope) +- Alternative authentication methods (Basic auth, API keys, etc.) +- Token validation or expiration checking +- **Per-user authentication**: This feature does NOT support user-specific credentials +- **Multi-tenant isolation**: This feature does NOT provide per-tenant token management +- **Federated authentication**: This feature does NOT integrate with identity providers (SSO, OAuth, SAML, etc.) +- **Context-aware authentication**: Tokens are not passed based on user context or session + +## Background and Context + +### Current State +MCP tool configurations are stored in the `mcp` configuration group with this structure: +```json +{ + "remote-name": "tool_name", + "url": "http://mcp-server:3000/api" +} +``` + +The MCP tool service connects to servers using `streamablehttp_client(url)` without any authentication headers. + +### Limitations + +**Current System Limitations:** +1. **No authentication support**: Cannot connect to secured MCP servers +2. **Security exposure**: MCP servers must be publicly accessible or use network-level security only +3. **Production deployment issues**: Cannot follow security best practices for API endpoints + +**Limitations of This Solution:** +1. **Single-tenant only**: One static token per MCP tool, shared across all users +2. **No per-user credentials**: Cannot authenticate as different users or pass user context +3. **No multi-tenant support**: Cannot isolate credentials by tenant or organization +4. 
**Static tokens only**: No support for token refresh, rotation, or expiration handling +5. **Service-level authentication**: Authenticates the TrustGraph service, not individual users +6. **Shared security context**: All invocations of an MCP tool use the same credential + +### Use Case Applicability + +**✅ Appropriate Use Cases:** +- Single-tenant TrustGraph deployments +- Service-to-service authentication (TrustGraph → MCP Server) +- Development and testing environments +- Internal MCP tools accessed by the TrustGraph system +- Scenarios where all users share the same MCP tool access level +- Static, long-lived service credentials + +**❌ Inappropriate Use Cases:** +- Multi-user systems requiring per-user authentication +- Multi-tenant SaaS deployments with tenant isolation requirements +- Federated authentication scenarios (SSO, OAuth, SAML) +- Systems requiring user context propagation to MCP servers +- Environments needing dynamic token refresh or short-lived tokens +- Applications where different users need different permission levels +- Compliance requirements for user-level audit trails + +**Example Appropriate Scenario:** +A single-organization TrustGraph deployment where all employees use the same internal MCP tool (e.g., company database lookup). The MCP server requires authentication to prevent external access, but all internal users have the same access level. + +**Example Inappropriate Scenario:** +A multi-tenant TrustGraph SaaS platform where Tenant A and Tenant B each need to access their own isolated MCP servers with separate credentials. This feature does NOT support per-tenant token management. + +### Related Components +- **trustgraph-flow/trustgraph/agent/mcp_tool/service.py**: MCP tool invocation service +- **trustgraph-cli/trustgraph/cli/set_mcp_tool.py**: CLI tool for creating/updating MCP configurations +- **trustgraph-cli/trustgraph/cli/show_mcp_tools.py**: CLI tool for displaying MCP configurations +- **MCP Python SDK**: `streamablehttp_client` from `mcp.client.streamable_http` + +## Requirements + +### Functional Requirements + +1. **MCP Configuration Auth Token**: MCP tool configurations MUST support an optional `auth-token` field +2. **Bearer Token Usage**: MCP tool service MUST send `Authorization: Bearer {token}` header when auth-token is configured +3. **CLI Support**: `tg-set-mcp-tool` MUST accept optional `--auth-token` parameter +4. **Token Display**: `tg-show-mcp-tools` MUST indicate when auth-token is configured (masked for security) +5. **Backward Compatibility**: Existing MCP tool configurations without auth-token MUST continue to work + +### Non-Functional Requirements +1. **Backward Compatibility**: Zero breaking changes for existing MCP tool configurations +2. **Performance**: No significant performance impact on MCP tool invocation +3. **Security**: Tokens stored in configuration (document security implications) + +### User Stories + +1. As a **DevOps engineer**, I want to configure bearer tokens for MCP tools so that I can secure MCP server endpoints +2. As a **CLI user**, I want to set auth tokens when creating MCP tools so that I can connect to protected servers +3. As a **system administrator**, I want to see which MCP tools have authentication configured so that I can audit security settings + +## Design + +### High-Level Architecture +Extend MCP tool configuration and service to support bearer token authentication: +1. Add optional `auth-token` field to MCP tool configuration schema +2. 
Modify MCP tool service to read auth-token and pass to HTTP client +3. Update CLI tools to support setting and displaying auth tokens +4. Document security considerations and best practices + +### Configuration Schema + +**Current Schema**: +```json +{ + "remote-name": "tool_name", + "url": "http://mcp-server:3000/api" +} +``` + +**New Schema** (with optional auth-token): +```json +{ + "remote-name": "tool_name", + "url": "http://mcp-server:3000/api", + "auth-token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." +} +``` + +**Field Descriptions**: +- `remote-name` (optional): Name used by MCP server (defaults to config key) +- `url` (required): MCP server endpoint URL +- `auth-token` (optional): Bearer token for authentication + +### Data Flow + +1. **Configuration Storage**: User runs `tg-set-mcp-tool --id my-tool --tool-url http://server/api --auth-token xyz123` +2. **Config Loading**: MCP tool service receives config update via `on_mcp_config()` callback +3. **Tool Invocation**: When tool is invoked: + - Service reads `auth-token` from config (if present) + - Creates headers dict: `{"Authorization": "Bearer {token}"}` + - Passes headers to `streamablehttp_client(url, headers=headers)` + - MCP server validates token and processes request + +### API Changes +No external API changes - configuration schema extension only. + +### Component Details + +#### Component 1: service.py (MCP Tool Service) +**File**: `trustgraph-flow/trustgraph/agent/mcp_tool/service.py` + +**Purpose**: Invoke MCP tools on remote servers + +**Changes Required** (in `invoke_tool()` method): +1. Check for `auth-token` in `self.mcp_services[name]` config +2. Build headers dict with Authorization header if token exists +3. Pass headers to `streamablehttp_client(url, headers=headers)` + +**Current Code** (lines 42-89): +```python +async def invoke_tool(self, name, parameters): + try: + if name not in self.mcp_services: + raise RuntimeError(f"MCP service {name} not known") + if "url" not in self.mcp_services[name]: + raise RuntimeError(f"MCP service {name} URL not defined") + + url = self.mcp_services[name]["url"] + + if "remote-name" in self.mcp_services[name]: + remote_name = self.mcp_services[name]["remote-name"] + else: + remote_name = name + + logger.info(f"Invoking {remote_name} at {url}") + + # Connect to a streamable HTTP server + async with streamablehttp_client(url) as ( + read_stream, + write_stream, + _, + ): + # ... rest of method +``` + +**Modified Code**: +```python +async def invoke_tool(self, name, parameters): + try: + if name not in self.mcp_services: + raise RuntimeError(f"MCP service {name} not known") + if "url" not in self.mcp_services[name]: + raise RuntimeError(f"MCP service {name} URL not defined") + + url = self.mcp_services[name]["url"] + + if "remote-name" in self.mcp_services[name]: + remote_name = self.mcp_services[name]["remote-name"] + else: + remote_name = name + + # Build headers with optional bearer token + headers = {} + if "auth-token" in self.mcp_services[name]: + token = self.mcp_services[name]["auth-token"] + headers["Authorization"] = f"Bearer {token}" + + logger.info(f"Invoking {remote_name} at {url}") + + # Connect to a streamable HTTP server with headers + async with streamablehttp_client(url, headers=headers) as ( + read_stream, + write_stream, + _, + ): + # ... 
rest of method (unchanged) +``` + +#### Component 2: set_mcp_tool.py (CLI Configuration Tool) +**File**: `trustgraph-cli/trustgraph/cli/set_mcp_tool.py` + +**Purpose**: Create/update MCP tool configurations + +**Changes Required**: +1. Add `--auth-token` optional argument to argparse +2. Include `auth-token` in configuration JSON when provided + +**Current Arguments**: +- `--id` (required): MCP tool identifier +- `--remote-name` (optional): Remote MCP tool name +- `--tool-url` (required): MCP tool URL endpoint +- `-u, --api-url` (optional): TrustGraph API URL + +**New Argument**: +- `--auth-token` (optional): Bearer token for authentication + +**Modified Configuration Building**: +```python +# Build configuration object +config = { + "url": args.tool_url, +} + +if args.remote_name: + config["remote-name"] = args.remote_name + +if args.auth_token: + config["auth-token"] = args.auth_token + +# Store configuration +api.config().put([ + ConfigValue(type="mcp", key=args.id, value=json.dumps(config)) +]) +``` + +#### Component 3: show_mcp_tools.py (CLI Display Tool) +**File**: `trustgraph-cli/trustgraph/cli/show_mcp_tools.py` + +**Purpose**: Display MCP tool configurations + +**Changes Required**: +1. Add "Auth" column to output table +2. Display "Yes" or "No" based on presence of auth-token +3. Do not display actual token value (security) + +**Current Output**: +``` +ID Remote Name URL +---------- ------------- ------------------------ +my-tool my-tool http://server:3000/api +``` + +**New Output**: +``` +ID Remote Name URL Auth +---------- ------------- ------------------------ ------ +my-tool my-tool http://server:3000/api Yes +other-tool other-tool http://other:3000/api No +``` + +#### Component 4: Documentation +**File**: `docs/cli/tg-set-mcp-tool.md` + +**Changes Required**: +1. Document new `--auth-token` parameter +2. Provide example usage with authentication +3. 
Document security considerations + +## Implementation Plan + +### Phase 1: Create Technical Specification +- [x] Write comprehensive tech spec documenting all changes + +### Phase 2: Update MCP Tool Service +- [ ] Modify `invoke_tool()` in `service.py` to read auth-token from config +- [ ] Build headers dict and pass to `streamablehttp_client` +- [ ] Test with authenticated MCP server + +### Phase 3: Update CLI Tools +- [ ] Add `--auth-token` argument to `set_mcp_tool.py` +- [ ] Include auth-token in configuration JSON +- [ ] Add "Auth" column to `show_mcp_tools.py` output +- [ ] Test CLI tool changes + +### Phase 4: Update Documentation +- [ ] Document `--auth-token` parameter in `tg-set-mcp-tool.md` +- [ ] Add security considerations section +- [ ] Provide example usage + +### Phase 5: Testing +- [ ] Test MCP tool with auth-token connects successfully +- [ ] Test backward compatibility (tools without auth-token still work) +- [ ] Test CLI tools accept and store auth-token correctly +- [ ] Test show command displays auth status correctly + +### Code Changes Summary +| File | Change Type | Lines | Description | +|------|------------|-------|-------------| +| `service.py` | Modified | ~52-66 | Add auth-token reading and header building | +| `set_mcp_tool.py` | Modified | ~30-60 | Add --auth-token argument and config storage | +| `show_mcp_tools.py` | Modified | ~40-70 | Add Auth column to display | +| `tg-set-mcp-tool.md` | Modified | Various | Document new parameter | + +## Testing Strategy + +### Unit Tests +- **Auth Token Reading**: Test `invoke_tool()` correctly reads auth-token from config +- **Header Building**: Test Authorization header is built correctly with Bearer prefix +- **Backward Compatibility**: Test tools without auth-token work unchanged +- **CLI Argument Parsing**: Test `--auth-token` argument is parsed correctly + +### Integration Tests +- **Authenticated Connection**: Test MCP tool service connects to authenticated server +- **End-to-End**: Test CLI → config storage → service invocation with auth token +- **Token Not Required**: Test connection to unauthenticated server still works + +### Manual Testing +- **Real MCP Server**: Test with actual MCP server requiring bearer token authentication +- **CLI Workflow**: Test complete workflow: set tool with auth → invoke tool → verify success +- **Display Masking**: Verify auth status shown but token value not exposed + +## Migration and Rollout + +### Migration Strategy +No migration required - this is purely additive functionality: +- Existing MCP tool configurations without `auth-token` continue to work unchanged +- New configurations can optionally include `auth-token` field +- CLI tools accept but don't require `--auth-token` parameter + +### Rollout Plan +1. **Phase 1**: Deploy core service changes to development/staging +2. **Phase 2**: Deploy CLI tool updates +3. **Phase 3**: Update documentation +4. 
**Phase 4**: Production rollout with monitoring + +### Rollback Plan +- Core changes are backward compatible - existing tools unaffected +- If issues arise, auth-token handling can be disabled by removing header building logic +- CLI changes are independent and can be rolled back separately + +## Security Considerations + +### ⚠️ Critical Limitation: Single-Tenant Authentication Only + +**This authentication mechanism is NOT suitable for multi-user or multi-tenant environments.** + +- **Shared credentials**: All users and invocations share the same token per MCP tool +- **No user context**: The MCP server cannot distinguish between different TrustGraph users +- **No tenant isolation**: All tenants share the same credential for each MCP tool +- **Audit trail limitation**: MCP server logs show all requests from the same credential +- **Permission scope**: Cannot enforce different permission levels for different users + +**Do NOT use this feature if:** +- Your TrustGraph deployment serves multiple organizations (multi-tenant) +- You need to track which user accessed which MCP tool +- Different users require different permission levels +- You need to comply with user-level audit requirements +- Your MCP server enforces per-user rate limits or quotas + +**Alternative solutions for multi-user/multi-tenant scenarios:** +- Implement user context propagation through custom headers +- Deploy separate TrustGraph instances per tenant +- Use network-level isolation (VPCs, service meshes) +- Implement a proxy layer that handles per-user authentication + +### Token Storage +**Risk**: Auth tokens stored in plaintext in configuration system + +**Mitigation**: +- Document that tokens are stored unencrypted +- Recommend using short-lived tokens when possible +- Recommend proper access control on configuration storage +- Consider future enhancement for encrypted token storage + +### Token Exposure +**Risk**: Tokens could be exposed in logs or CLI output + +**Mitigation**: +- Do not log token values (only log "auth configured: yes/no") +- CLI show command displays masked status only, not actual token +- Do not include tokens in error messages + +### Network Security +**Risk**: Tokens transmitted over unencrypted connections + +**Mitigation**: +- Document recommendation to use HTTPS URLs for MCP servers +- Warn users about plaintext transmission risk with HTTP + +### Configuration Access +**Risk**: Unauthorized access to configuration system exposes tokens + +**Mitigation**: +- Document importance of securing configuration system access +- Recommend principle of least privilege for configuration access +- Consider audit logging for configuration changes (future enhancement) + +### Multi-User Environments +**Risk**: In multi-user deployments, all users share the same MCP credentials + +**Understanding the Risk**: +- User A and User B both use the same token when accessing an MCP tool +- MCP server cannot distinguish between different TrustGraph users +- No way to enforce per-user permissions or rate limits +- Audit logs on MCP server show all requests from same credential +- If one user's session is compromised, attacker has same MCP access as all users + +**This is NOT a bug - it's a fundamental limitation of this design.** + +## Performance Impact +- **Minimal overhead**: Header building adds negligible processing time +- **Network impact**: Additional HTTP header adds ~50-200 bytes per request +- **Memory usage**: Negligible increase for storing token string in config + +## Documentation + +### User 
Documentation +- [ ] Update `tg-set-mcp-tool.md` with `--auth-token` parameter +- [ ] Add security considerations section +- [ ] Provide example usage with bearer token +- [ ] Document token storage implications + +### Developer Documentation +- [ ] Add inline comments for auth token handling in `service.py` +- [ ] Document header building logic +- [ ] Update MCP tool configuration schema documentation + +## Open Questions +1. **Token encryption**: Should we implement encrypted token storage in configuration system? +2. **Token refresh**: Future support for OAuth refresh flows or token rotation? +3. **Alternative auth methods**: Should we support Basic auth, API keys, or other methods? + +## Alternatives Considered + +1. **Environment variables for tokens**: Store tokens in env vars instead of config + - **Rejected**: Complicates deployment and configuration management + +2. **Separate secrets store**: Use dedicated secrets management system + - **Deferred**: Out of scope for initial implementation, consider future enhancement + +3. **Multiple auth methods**: Support Basic, API key, OAuth, etc. + - **Rejected**: Bearer tokens cover most use cases, keep initial implementation simple + +4. **Encrypted token storage**: Encrypt tokens in configuration system + - **Deferred**: Configuration system security is broader concern, defer to future work + +5. **Per-invocation tokens**: Allow tokens to be passed at invocation time + - **Rejected**: Violates separation of concerns, agent shouldn't handle credentials + +## References +- [MCP Protocol Specification](https://github.com/modelcontextprotocol/spec) +- [HTTP Bearer Authentication (RFC 6750)](https://tools.ietf.org/html/rfc6750) +- [Current MCP Tool Service](../trustgraph-flow/trustgraph/agent/mcp_tool/service.py) +- [MCP Tool Arguments Specification](./mcp-tool-arguments.md) + +## Appendix + +### Example Usage + +**Setting MCP tool with authentication**: +```bash +tg-set-mcp-tool \ + --id secure-tool \ + --tool-url https://secure-server.example.com/mcp \ + --auth-token eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9... +``` + +**Showing MCP tools**: +```bash +tg-show-mcp-tools + +ID Remote Name URL Auth +----------- ----------- ------------------------------------ ------ +secure-tool secure-tool https://secure-server.example.com/mcp Yes +public-tool public-tool http://localhost:3000/mcp No +``` + +### Configuration Example + +**Stored in configuration system**: +```json +{ + "type": "mcp", + "key": "secure-tool", + "value": "{\"url\": \"https://secure-server.example.com/mcp\", \"auth-token\": \"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...\"}" +} +``` + +### Security Best Practices + +1. **Use HTTPS**: Always use HTTPS URLs for MCP servers with authentication +2. **Short-lived tokens**: Use tokens with expiration when possible +3. **Least privilege**: Grant tokens minimum required permissions +4. **Access control**: Restrict access to configuration system +5. **Token rotation**: Rotate tokens regularly +6. **Audit logging**: Monitor configuration changes for security events diff --git a/docs/tech-specs/ontology.md b/docs/tech-specs/ontology.md new file mode 100644 index 00000000..61cc09e4 --- /dev/null +++ b/docs/tech-specs/ontology.md @@ -0,0 +1,286 @@ +# Ontology Structure Technical Specification + +## Overview + +This specification describes the structure and format of ontologies within the TrustGraph system. Ontologies provide formal knowledge models that define classes, properties, and relationships, supporting reasoning and inference capabilities. 
The system uses an OWL-inspired configuration format that broadly represents OWL/RDFS concepts while being optimized for TrustGraph's requirements. + +**Naming Convention**: This project uses kebab-case for all identifiers (configuration keys, API endpoints, module names, etc.) rather than snake_case. + +## Goals + +- **Class and Property Management**: Define OWL-like classes with properties, domains, ranges, and type constraints +- **Rich Semantic Support**: Enable comprehensive RDFS/OWL properties including labels, multi-language support, and formal constraints +- **Multi-Ontology Support**: Allow multiple ontologies to coexist and interoperate +- **Validation and Reasoning**: Ensure ontologies conform to OWL-like standards with consistency checking and inference support +- **Standard Compatibility**: Support import/export in standard formats (Turtle, RDF/XML, OWL/XML) while maintaining internal optimization + +## Background + +TrustGraph stores ontologies as configuration items in a flexible key-value system. While the format is inspired by OWL (Web Ontology Language), it is optimized for TrustGraph's specific use cases and does not strictly adhere to all OWL specifications. + +Ontologies in TrustGraph enable: +- Definition of formal object types and their properties +- Specification of property domains, ranges, and type constraints +- Logical reasoning and inference +- Complex relationships and cardinality constraints +- Multi-language support for internationalization + +## Ontology Structure + +### Configuration Storage + +Ontologies are stored as configuration items with the following pattern: +- **Type**: `ontology` +- **Key**: Unique ontology identifier (e.g., `natural-world`, `domain-model`) +- **Value**: Complete ontology in JSON format + +### JSON Structure + +The ontology JSON format consists of four main sections: + +#### 1. Metadata + +Contains administrative and descriptive information about the ontology: + +```json +{ + "metadata": { + "name": "The natural world", + "description": "Ontology covering the natural order", + "version": "1.0.0", + "created": "2025-09-20T12:07:37.068Z", + "modified": "2025-09-20T12:12:20.725Z", + "creator": "current-user", + "namespace": "http://trustgraph.ai/ontologies/natural-world", + "imports": ["http://www.w3.org/2002/07/owl#"] + } +} +``` + +**Fields:** +- `name`: Human-readable name of the ontology +- `description`: Brief description of the ontology's purpose +- `version`: Semantic version number +- `created`: ISO 8601 timestamp of creation +- `modified`: ISO 8601 timestamp of last modification +- `creator`: Identifier of the creating user/system +- `namespace`: Base URI for ontology elements +- `imports`: Array of imported ontology URIs + +#### 2. 
Classes + +Defines the object types and their hierarchical relationships: + +```json +{ + "classes": { + "animal": { + "uri": "http://trustgraph.ai/ontologies/natural-world#animal", + "type": "owl:Class", + "rdfs:label": [{"value": "Animal", "lang": "en"}], + "rdfs:comment": "An animal", + "rdfs:subClassOf": "lifeform", + "owl:equivalentClass": ["creature"], + "owl:disjointWith": ["plant"], + "dcterms:identifier": "ANI-001" + } + } +} +``` + +**Supported Properties:** +- `uri`: Full URI of the class +- `type`: Always `"owl:Class"` +- `rdfs:label`: Array of language-tagged labels +- `rdfs:comment`: Description of the class +- `rdfs:subClassOf`: Parent class identifier (single inheritance) +- `owl:equivalentClass`: Array of equivalent class identifiers +- `owl:disjointWith`: Array of disjoint class identifiers +- `dcterms:identifier`: Optional external reference identifier + +#### 3. Object Properties + +Properties that link instances to other instances: + +```json +{ + "objectProperties": { + "has-parent": { + "uri": "http://trustgraph.ai/ontologies/natural-world#has-parent", + "type": "owl:ObjectProperty", + "rdfs:label": [{"value": "has parent", "lang": "en"}], + "rdfs:comment": "Links an animal to its parent", + "rdfs:domain": "animal", + "rdfs:range": "animal", + "owl:inverseOf": "parent-of", + "owl:functionalProperty": false + } + } +} +``` + +**Supported Properties:** +- `uri`: Full URI of the property +- `type`: Always `"owl:ObjectProperty"` +- `rdfs:label`: Array of language-tagged labels +- `rdfs:comment`: Description of the property +- `rdfs:domain`: Class identifier that has this property +- `rdfs:range`: Class identifier for property values +- `owl:inverseOf`: Identifier of inverse property +- `owl:functionalProperty`: Boolean indicating at most one value +- `owl:inverseFunctionalProperty`: Boolean for unique identifying properties + +#### 4. 
Datatype Properties + +Properties that link instances to literal values: + +```json +{ + "datatypeProperties": { + "number-of-legs": { + "uri": "http://trustgraph.ai/ontologies/natural-world#number-of-legs", + "type": "owl:DatatypeProperty", + "rdfs:label": [{"value": "number of legs", "lang": "en"}], + "rdfs:comment": "Count of number of legs of the animal", + "rdfs:domain": "animal", + "rdfs:range": "xsd:nonNegativeInteger", + "owl:functionalProperty": true, + "owl:minCardinality": 0, + "owl:maxCardinality": 1 + } + } +} +``` + +**Supported Properties:** +- `uri`: Full URI of the property +- `type`: Always `"owl:DatatypeProperty"` +- `rdfs:label`: Array of language-tagged labels +- `rdfs:comment`: Description of the property +- `rdfs:domain`: Class identifier that has this property +- `rdfs:range`: XSD datatype for property values +- `owl:functionalProperty`: Boolean indicating at most one value +- `owl:minCardinality`: Minimum number of values (optional) +- `owl:maxCardinality`: Maximum number of values (optional) +- `owl:cardinality`: Exact number of values (optional) + +### Supported XSD Datatypes + +The following XML Schema datatypes are supported for datatype property ranges: + +- `xsd:string` - Text values +- `xsd:integer` - Integer numbers +- `xsd:nonNegativeInteger` - Non-negative integers +- `xsd:float` - Floating point numbers +- `xsd:double` - Double precision numbers +- `xsd:boolean` - True/false values +- `xsd:dateTime` - Date and time values +- `xsd:date` - Date values +- `xsd:anyURI` - URI references + +### Language Support + +Labels and comments support multiple languages using the W3C language tag format: + +```json +{ + "rdfs:label": [ + {"value": "Animal", "lang": "en"}, + {"value": "Tier", "lang": "de"}, + {"value": "Animal", "lang": "es"} + ] +} +``` + +## Example Ontology + +Here's a complete example of a simple ontology: + +```json +{ + "metadata": { + "name": "The natural world", + "description": "Ontology covering the natural order", + "version": "1.0.0", + "created": "2025-09-20T12:07:37.068Z", + "modified": "2025-09-20T12:12:20.725Z", + "creator": "current-user", + "namespace": "http://trustgraph.ai/ontologies/natural-world", + "imports": ["http://www.w3.org/2002/07/owl#"] + }, + "classes": { + "lifeform": { + "uri": "http://trustgraph.ai/ontologies/natural-world#lifeform", + "type": "owl:Class", + "rdfs:label": [{"value": "Lifeform", "lang": "en"}], + "rdfs:comment": "A living thing" + }, + "animal": { + "uri": "http://trustgraph.ai/ontologies/natural-world#animal", + "type": "owl:Class", + "rdfs:label": [{"value": "Animal", "lang": "en"}], + "rdfs:comment": "An animal", + "rdfs:subClassOf": "lifeform" + }, + "cat": { + "uri": "http://trustgraph.ai/ontologies/natural-world#cat", + "type": "owl:Class", + "rdfs:label": [{"value": "Cat", "lang": "en"}], + "rdfs:comment": "A cat", + "rdfs:subClassOf": "animal" + }, + "dog": { + "uri": "http://trustgraph.ai/ontologies/natural-world#dog", + "type": "owl:Class", + "rdfs:label": [{"value": "Dog", "lang": "en"}], + "rdfs:comment": "A dog", + "rdfs:subClassOf": "animal", + "owl:disjointWith": ["cat"] + } + }, + "objectProperties": {}, + "datatypeProperties": { + "number-of-legs": { + "uri": "http://trustgraph.ai/ontologies/natural-world#number-of-legs", + "type": "owl:DatatypeProperty", + "rdfs:label": [{"value": "number-of-legs", "lang": "en"}], + "rdfs:comment": "Count of number of legs of the animal", + "rdfs:range": "xsd:nonNegativeInteger", + "rdfs:domain": "animal" + } + } +} +``` + +## Validation Rules + +### 
Structural Validation + +1. **URI Consistency**: All URIs should follow the pattern `{namespace}#{identifier}` +2. **Class Hierarchy**: No circular inheritance in `rdfs:subClassOf` +3. **Property Domains/Ranges**: Must reference existing classes or valid XSD types +4. **Disjoint Classes**: Cannot be subclasses of each other +5. **Inverse Properties**: Must be bidirectional if specified + +### Semantic Validation + +1. **Unique Identifiers**: Class and property identifiers must be unique within an ontology +2. **Language Tags**: Must follow BCP 47 language tag format +3. **Cardinality Constraints**: `minCardinality` ≤ `maxCardinality` when both specified +4. **Functional Properties**: Cannot have `maxCardinality` > 1 + +## Import/Export Format Support + +While the internal format is JSON, the system supports conversion to/from standard ontology formats: + +- **Turtle (.ttl)** - Compact RDF serialization +- **RDF/XML (.rdf, .owl)** - W3C standard format +- **OWL/XML (.owx)** - OWL-specific XML format +- **JSON-LD (.jsonld)** - JSON for Linked Data + +## References + +- [OWL 2 Web Ontology Language](https://www.w3.org/TR/owl2-overview/) +- [RDF Schema 1.1](https://www.w3.org/TR/rdf-schema/) +- [XML Schema Datatypes](https://www.w3.org/TR/xmlschema-2/) +- [BCP 47 Language Tags](https://tools.ietf.org/html/bcp47) \ No newline at end of file diff --git a/docs/tech-specs/ontorag.md b/docs/tech-specs/ontorag.md new file mode 100644 index 00000000..47426dbe --- /dev/null +++ b/docs/tech-specs/ontorag.md @@ -0,0 +1,1067 @@ +# OntoRAG: Ontology-Based Knowledge Extraction and Query Technical Specification + +## Overview + +OntoRAG is an ontology-driven knowledge extraction and query system that enforces strict semantic consistency during both the extraction of knowledge triples from unstructured text and the querying of the resulting knowledge graph. Similar to GraphRAG but with formal ontology constraints, OntoRAG ensures all extracted triples conform to predefined ontological structures and provides semantically-aware querying capabilities. + +The system uses vector similarity matching to dynamically select relevant ontology subsets for both extraction and query operations, enabling focused and contextually appropriate processing while maintaining semantic validity. + +**Service Name**: `kg-extract-ontology` + +## Goals + +- **Ontology-Conformant Extraction**: Ensure all extracted triples strictly conform to loaded ontologies +- **Dynamic Context Selection**: Use embeddings to select relevant ontology subsets for each chunk +- **Semantic Consistency**: Maintain class hierarchies, property domains/ranges, and constraints +- **Efficient Processing**: Use in-memory vector stores for fast ontology element matching +- **Scalable Architecture**: Support multiple concurrent ontologies with different domains + +## Background + +Current knowledge extraction services (`kg-extract-definitions`, `kg-extract-relationships`) operate without formal constraints, potentially producing inconsistent or incompatible triples. OntoRAG addresses this by: + +1. Loading formal ontologies that define valid classes and properties +2. Using embeddings to match text content with relevant ontology elements +3. Constraining extraction to only produce ontology-conformant triples +4. Providing semantic validation of extracted knowledge + +This approach combines the flexibility of neural extraction with the rigor of formal knowledge representation. 
+ +## Technical Design + +### Architecture + +The OntoRAG system consists of the following components: + +``` +┌─────────────────┐ +│ Configuration │ +│ Service │ +└────────┬────────┘ + │ Ontologies + ▼ +┌─────────────────┐ ┌──────────────┐ +│ kg-extract- │────▶│ Embedding │ +│ ontology │ │ Service │ +└────────┬────────┘ └──────────────┘ + │ │ + ▼ ▼ +┌─────────────────┐ ┌──────────────┐ +│ In-Memory │◀────│ Ontology │ +│ Vector Store │ │ Embedder │ +└────────┬────────┘ └──────────────┘ + │ + ▼ +┌─────────────────┐ ┌──────────────┐ +│ Sentence │────▶│ Chunker │ +│ Splitter │ │ Service │ +└────────┬────────┘ └──────────────┘ + │ + ▼ +┌─────────────────┐ ┌──────────────┐ +│ Ontology │────▶│ Vector │ +│ Selector │ │ Search │ +└────────┬────────┘ └──────────────┘ + │ + ▼ +┌─────────────────┐ ┌──────────────┐ +│ Prompt │────▶│ Prompt │ +│ Constructor │ │ Service │ +└────────┬────────┘ └──────────────┘ + │ + ▼ +┌─────────────────┐ +│ Triple Output │ +└─────────────────┘ +``` + +### Component Details + +#### 1. Ontology Loader + +**Purpose**: Retrieves and parses ontology configurations from the configuration service using event-driven updates. + +**Implementation**: +The Ontology Loader uses TrustGraph's ConfigPush queue to receive event-driven ontology configuration updates. When a configuration element of type "ontology" is added or modified, the loader receives the update via the config-update queue and parses the JSON structure containing metadata, classes, object properties, and datatype properties. These parsed ontologies are stored in memory as structured objects that can be efficiently accessed during the extraction process. + +**Key Operations**: +- Subscribe to config-update queue for ontology-type configurations +- Parse JSON ontology structures into OntologyClass and OntologyProperty objects +- Validate ontology structure and consistency +- Cache parsed ontologies in memory for fast access +- Handle per-flow processing with flow-specific vector stores + +**Implementation Location**: `trustgraph-flow/trustgraph/extract/kg/ontology/ontology_loader.py` + +#### 2. Ontology Embedder + +**Purpose**: Creates vector embeddings for all ontology elements to enable semantic similarity matching. + +**Implementation**: +The Ontology Embedder processes each element in the loaded ontologies (classes, object properties, and datatype properties) and generates vector embeddings using the EmbeddingsClientSpec service. For each element, it combines the element's identifier, labels, and description (comment) to create a text representation. This text is then converted to a high-dimensional vector embedding that captures its semantic meaning. These embeddings are stored in a per-flow in-memory FAISS vector store along with metadata about the element type, source ontology, and full definition. The embedder automatically detects the embedding dimension from the first embedding response. + +**Key Operations**: +- Create text representations from element IDs, labels, and comments +- Generate embeddings via EmbeddingsClientSpec (using asyncio.gather for batch processing) +- Store embeddings with comprehensive metadata in FAISS vector store +- Index by ontology, element type, and element ID for efficient retrieval +- Auto-detect embedding dimensions for vector store initialization +- Handle per-flow embedding models with independent vector stores + +**Implementation Location**: `trustgraph-flow/trustgraph/extract/kg/ontology/ontology_embedder.py` + +#### 3. 
Text Processor (Sentence Splitter) + +**Purpose**: Decomposes text chunks into fine-grained segments for precise ontology matching. + +**Implementation**: +The Text Processor uses NLTK for sentence tokenization and POS tagging to break down incoming text chunks into sentences. It handles NLTK version compatibility by attempting to download `punkt_tab` and `averaged_perceptron_tagger_eng` with fallbacks to older versions if needed. Each text chunk is split into individual sentences that can be independently matched against ontology elements. + +**Key Operations**: +- Split text into sentences using NLTK sentence tokenization +- Handle NLTK version compatibility (punkt_tab vs punkt) +- Create TextSegment objects with text and position information +- Support both complete sentences and individual chunks + +**Implementation Location**: `trustgraph-flow/trustgraph/extract/kg/ontology/text_processor.py` + +#### 4. Ontology Selector + +**Purpose**: Identifies the most relevant subset of ontology elements for the current text chunk. + +**Implementation**: +The Ontology Selector performs semantic matching between text segments and ontology elements using FAISS vector similarity search. For each sentence from the text chunk, it generates an embedding and searches the vector store for the most similar ontology elements using cosine similarity with a configurable threshold (default 0.3). After collecting all relevant elements, it performs comprehensive dependency resolution: if a class is selected, its parent classes are included; if a property is selected, its domain and range classes are added. Additionally, for each selected class, it automatically includes **all properties that reference that class** in their domain or range. This ensures the extraction has access to all relevant relationship properties. + +**Key Operations**: +- Generate embeddings for each text segment (sentences) +- Perform k-nearest neighbor search in FAISS vector store (top_k=10, threshold=0.3) +- Apply similarity threshold to filter weak matches +- Resolve dependencies (parent classes, domains, ranges) +- **Auto-include all properties related to selected classes** (domain/range matching) +- Construct coherent ontology subset with all required relationships +- Deduplicate elements appearing multiple times + +**Implementation Location**: `trustgraph-flow/trustgraph/extract/kg/ontology/ontology_selector.py` + +#### 5. Prompt Construction + +**Purpose**: Creates structured prompts that guide the LLM to extract only ontology-conformant triples. + +**Implementation**: +The extraction service uses a Jinja2 template loaded from `ontology-prompt.md` which formats the ontology subset and text for LLM extraction. The template dynamically iterates over classes, object properties, and datatype properties using Jinja2 syntax, presenting each with their descriptions, domains, ranges, and hierarchical relationships. The prompt includes strict rules about using only the provided ontology elements and requests JSON output format for consistent parsing. 
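+
+As a rough illustration of the variables handed to the template (the real logic lives in the `build_extraction_variables` method; exact key names are assumptions), the assembly might look like:
+
+```python
+# Hedged sketch of assembling Jinja2 variables for the
+# "extract-with-ontologies" template. Key names are assumptions; the
+# actual implementation is build_extraction_variables() in extract.py.
+def build_extraction_variables(text, subset):
+    return {
+        "text": text,
+        "classes": [
+            {"id": c.id, "comment": c.comment,
+             "subclass_of": c.subclass_of}  # parent classes, if any
+            for c in subset.classes
+        ],
+        "object_properties": [
+            {"id": p.id, "comment": p.comment,
+             "domain": p.domain, "range": p.range}
+            for p in subset.object_properties
+        ],
+        "datatype_properties": [
+            {"id": p.id, "comment": p.comment,
+             "domain": p.domain, "range": p.range}
+            for p in subset.datatype_properties
+        ],
+    }
+```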
+ +**Key Operations**: +- Use Jinja2 template with loops over ontology elements +- Format classes with parent relationships (subclass_of) and comments +- Format properties with domain/range constraints and comments +- Include explicit extraction rules and output format requirements +- Call prompt service with template ID "extract-with-ontologies" + +**Template Location**: `ontology-prompt.md` +**Implementation Location**: `trustgraph-flow/trustgraph/extract/kg/ontology/extract.py` (build_extraction_variables method) + +#### 6. Main Extractor Service + +**Purpose**: Coordinates all components to perform end-to-end ontology-based triple extraction. + +**Implementation**: +The Main Extractor Service (KgExtractOntology) is the orchestration layer that manages the complete extraction workflow. It uses TrustGraph's FlowProcessor pattern with per-flow component initialization. When an ontology configuration update arrives, it initializes or updates the flow-specific components (ontology loader, embedder, text processor, selector). When a text chunk arrives for processing, it coordinates the pipeline: splitting the text into segments, finding relevant ontology elements through vector search, constructing a constrained prompt, calling the prompt service, parsing and validating the response, generating ontology definition triples, and emitting both content triples and entity contexts. + +**Extraction Pipeline**: +1. Receive text chunk via chunks-input queue +2. Initialize flow components if needed (on first chunk or config update) +3. Split text into sentences using NLTK +4. Search FAISS vector store to find relevant ontology concepts +5. Build ontology subset with automatic property inclusion +6. Construct Jinja2-templated prompt variables +7. Call prompt service with extract-with-ontologies template +8. Parse JSON response into structured triples +9. Validate triples and expand URIs to full ontology URIs +10. Generate ontology definition triples (classes and properties with labels/comments/domains/ranges) +11. Build entity contexts from all triples +12. Emit to triples and entity-contexts queues + +**Key Features**: +- Per-flow vector stores supporting different embedding models +- Event-driven ontology updates via config-update queue +- Automatic URI expansion using ontology URIs +- Ontology elements added to knowledge graph with full metadata +- Entity contexts include both content and ontology elements + +**Implementation Location**: `trustgraph-flow/trustgraph/extract/kg/ontology/extract.py` + +### Configuration + +The service uses TrustGraph's standard configuration approach with command-line arguments: + +```bash +kg-extract-ontology \ + --id kg-extract-ontology \ + --pulsar-host localhost:6650 \ + --input-queue chunks \ + --config-input-queue config-update \ + --output-queue triples \ + --entity-contexts-output-queue entity-contexts +``` + +**Key Configuration Parameters**: +- `similarity_threshold`: 0.3 (default, configurable in code) +- `top_k`: 10 (number of ontology elements to retrieve per segment) +- `vector_store`: Per-flow FAISS IndexFlatIP with auto-detected dimensions +- `text_processor`: NLTK with punkt_tab sentence tokenization +- `prompt_template`: "extract-with-ontologies" (Jinja2 template) + +**Ontology Configuration**: +Ontologies are loaded dynamically via the config-update queue with type="ontology". + +### Data Flow + +1. 
**Initialisation Phase** (per flow): + - Receive ontology configuration via config-update queue + - Parse ontology JSON into OntologyClass and OntologyProperty objects + - Generate embeddings for all ontology elements using EmbeddingsClientSpec + - Store embeddings in per-flow FAISS vector store + - Auto-detect embedding dimensions from first response + +2. **Extraction Phase** (per chunk): + - Receive chunk from chunks-input queue + - Split chunk into sentences using NLTK + - Compute embeddings for each sentence + - Search FAISS vector store for relevant ontology elements + - Build ontology subset with automatic property inclusion + - Construct Jinja2 template variables with text and ontology + - Call prompt service with extract-with-ontologies template + - Parse JSON response and validate triples + - Expand URIs using ontology URIs + - Generate ontology definition triples + - Build entity contexts from all triples + - Emit to triples and entity-contexts queues + +### In-Memory Vector Store + +**Purpose**: Provides fast, memory-based similarity search for ontology element matching. + +**Implementation: FAISS** + +The system uses **FAISS (Facebook AI Similarity Search)** with IndexFlatIP for exact cosine similarity search. Key features: + +- **IndexFlatIP**: Exact cosine similarity search using inner product +- **Auto-detection**: Dimension determined from first embedding response +- **Per-flow stores**: Each flow has independent vector store for different embedding models +- **Normalization**: All vectors normalized before indexing +- **Batch operations**: Efficient batch add for initial ontology loading + +**Implementation Location**: `trustgraph-flow/trustgraph/extract/kg/ontology/vector_store.py` + +### Ontology Subset Selection Algorithm + +**Purpose**: Dynamically selects the minimal relevant portion of the ontology for each text chunk. + +**Detailed Algorithm Steps**: + +1. **Text Segmentation**: + - Split the input chunk into sentences using NLP sentence detection + - Extract noun phrases, verb phrases, and named entities from each sentence + - Create a hierarchical structure of segments preserving context + +2. **Embedding Generation**: + - Generate vector embeddings for each text segment (sentences and phrases) + - Use the same embedding model as used for ontology elements + - Cache embeddings for repeated segments to improve performance + +3. **Similarity Search**: + - For each text segment embedding, search the vector store + - Retrieve top-k (e.g., 10) most similar ontology elements + - Apply similarity threshold (e.g., 0.7) to filter weak matches + - Aggregate results across all segments, tracking match frequencies + +4. **Dependency Resolution**: + - For each selected class, recursively include all parent classes up to root + - For each selected property, include its domain and range classes + - For inverse properties, ensure both directions are included + - Add equivalent classes if they exist in the ontology + +5. **Subset Construction**: + - Deduplicate collected elements while preserving relationships + - Organise into classes, object properties, and datatype properties + - Ensure all constraints and relationships are preserved + - Create a self-contained mini-ontology that is valid and complete + +**Example Walkthrough**: +Given text: "The brown dog chased the white cat up the tree." 
+- Segments: ["brown dog", "white cat", "tree", "chased"] +- Matched elements: [dog (class), cat (class), animal (parent), chases (property)] +- Dependencies: [animal (parent of dog and cat), lifeform (parent of animal)] +- Final subset: Complete mini-ontology with animal hierarchy and chase relationship + +### Triple Validation + +**Purpose**: Ensures all extracted triples strictly conform to ontology constraints. + +**Validation Algorithm**: + +1. **Class Validation**: + - Verify that subjects are instances of classes defined in the ontology subset + - For object properties, verify that objects are also valid class instances + - Check class names against the ontology's class dictionary + - Handle class hierarchies - instances of subclasses are valid for parent class constraints + +2. **Property Validation**: + - Confirm predicates correspond to properties in the ontology subset + - Distinguish between object properties (entity-to-entity) and datatype properties (entity-to-literal) + - Verify property names match exactly (considering namespace if present) + +3. **Domain/Range Checking**: + - For each property used as predicate, retrieve its domain and range + - Verify the subject's type matches or inherits from the property's domain + - Verify the object's type matches or inherits from the property's range + - For datatype properties, verify the object is a literal of the correct XSD type + +4. **Cardinality Validation**: + - Track property usage counts per subject + - Check minimum cardinality - ensure required properties are present + - Check maximum cardinality - ensure property isn't used too many times + - For functional properties, ensure at most one value per subject + +5. **Datatype Validation**: + - Parse literal values according to their declared XSD types + - Validate integers are valid numbers, dates are properly formatted, etc. + - Check string patterns if regex constraints are defined + - Ensure URIs are well-formed for xsd:anyURI types + +**Validation Example**: +Triple: ("Buddy", "has-owner", "John") +- Check "Buddy" is typed as a class that can have "has-owner" property +- Check "has-owner" exists in the ontology +- Verify domain constraint: subject must be of type "Pet" or subclass +- Verify range constraint: object must be of type "Person" or subclass +- If valid, add to output; if invalid, log violation and skip + +## Performance Considerations + +### Optimisation Strategies + +1. **Embedding Caching**: Cache embeddings for frequently used text segments +2. **Batch Processing**: Process multiple segments in parallel +3. **Vector Store Indexing**: Use approximate nearest neighbor algorithms for large ontologies +4. **Prompt Optimisation**: Minimise prompt size by including only essential ontology elements +5. **Result Caching**: Cache extraction results for identical chunks + +### Scalability + +- **Horizontal Scaling**: Multiple extractor instances with shared ontology cache +- **Ontology Partitioning**: Split large ontologies by domain +- **Streaming Processing**: Process chunks as they arrive without batching +- **Memory Management**: Periodic cleanup of unused embeddings + +## Error Handling + +### Failure Scenarios + +1. **Missing Ontologies**: Fallback to unconstrained extraction +2. **Embedding Service Failure**: Use cached embeddings or skip semantic matching +3. **Prompt Service Timeout**: Retry with exponential backoff +4. **Invalid Triple Format**: Log and skip malformed triples +5. 
**Ontology Inconsistencies**: Report conflicts and use most specific valid elements + +### Monitoring + +Key metrics to track: + +- Ontology load time and memory usage +- Embedding generation latency +- Vector search performance +- Prompt service response time +- Triple extraction accuracy +- Ontology conformance rate + +## Migration Path + +### From Existing Extractors + +1. **Parallel Operation**: Run alongside existing extractors initially +2. **Gradual Rollout**: Start with specific document types +3. **Quality Comparison**: Compare output quality with existing extractors +4. **Full Migration**: Replace existing extractors once quality verified + +### Ontology Development + +1. **Bootstrap from Existing**: Generate initial ontologies from existing knowledge +2. **Iterative Refinement**: Refine based on extraction patterns +3. **Domain Expert Review**: Validate with subject matter experts +4. **Continuous Improvement**: Update based on extraction feedback + +## Ontology-Sensitive Query Service + +### Overview + +The ontology-sensitive query service provides multiple query paths to support different backend graph stores. It leverages ontology knowledge for precise, semantically-aware question answering across both Cassandra (via SPARQL) and Cypher-based graph stores (Neo4j, Memgraph, FalkorDB). + +**Service Components**: +- `onto-query-sparql`: Converts natural language to SPARQL for Cassandra +- `sparql-cassandra`: SPARQL query layer for Cassandra using rdflib +- `onto-query-cypher`: Converts natural language to Cypher for graph databases +- `cypher-executor`: Cypher query execution for Neo4j/Memgraph/FalkorDB + +### Architecture + +``` + ┌─────────────────┐ + │ User Query │ + └────────┬────────┘ + │ + ▼ + ┌─────────────────┐ ┌──────────────┐ + │ Question │────▶│ Sentence │ + │ Analyser │ │ Splitter │ + └────────┬────────┘ └──────────────┘ + │ + ▼ + ┌─────────────────┐ ┌──────────────┐ + │ Ontology │────▶│ Vector │ + │ Matcher │ │ Store │ + └────────┬────────┘ └──────────────┘ + │ + ▼ + ┌─────────────────┐ + │ Backend Router │ + └────────┬────────┘ + │ + ┌───────────┴───────────┐ + │ │ + ▼ ▼ + ┌─────────────────┐ ┌─────────────────┐ + │ onto-query- │ │ onto-query- │ + │ sparql │ │ cypher │ + └────────┬────────┘ └────────┬────────┘ + │ │ + ▼ ▼ + ┌─────────────────┐ ┌─────────────────┐ + │ SPARQL │ │ Cypher │ + │ Generator │ │ Generator │ + └────────┬────────┘ └────────┬────────┘ + │ │ + ▼ ▼ + ┌─────────────────┐ ┌─────────────────┐ + │ sparql- │ │ cypher- │ + │ cassandra │ │ executor │ + └────────┬────────┘ └────────┬────────┘ + │ │ + ▼ ▼ + ┌─────────────────┐ ┌─────────────────┐ + │ Cassandra │ │ Neo4j/Memgraph/ │ + │ │ │ FalkorDB │ + └────────┬────────┘ └────────┬────────┘ + │ │ + └────────────┬───────────────┘ + │ + ▼ + ┌─────────────────┐ ┌──────────────┐ + │ Answer │────▶│ Prompt │ + │ Generator │ │ Service │ + └────────┬────────┘ └──────────────┘ + │ + ▼ + ┌─────────────────┐ + │ Final Answer │ + └─────────────────┘ +``` + +### Query Processing Pipeline + +#### 1. Question Analyser + +**Purpose**: Decomposes user questions into semantic components for ontology matching. + +**Algorithm Description**: +The Question Analyser takes the incoming natural language question and breaks it down into meaningful segments using the same sentence splitting approach as the extraction pipeline. It identifies key entities, relationships, and constraints mentioned in the question. Each segment is analysed for question type (factual, aggregation, comparison, etc.) and the expected answer format. 
This decomposition helps identify which parts of the ontology are most relevant for answering the question. + +**Key Operations**: +- Split question into sentences and phrases +- Identify question type and intent +- Extract mentioned entities and relationships +- Detect constraints and filters in the question +- Determine expected answer format + +#### 2. Ontology Matcher for Queries + +**Purpose**: Identifies the relevant ontology subset needed to answer the question. + +**Algorithm Description**: +Similar to the extraction pipeline's Ontology Selector, but optimised for question answering. The matcher generates embeddings for question segments and searches the vector store for relevant ontology elements. However, it focuses on finding concepts that would be useful for query construction rather than extraction. It expands the selection to include related properties that might be traversed during graph exploration, even if not explicitly mentioned in the question. For example, if asked about "employees," it might include properties like "works-for," "manages," and "reports-to" that could be relevant for finding employee information. + +**Matching Strategy**: +- Embed question segments +- Find directly mentioned ontology concepts +- Include properties that connect mentioned classes +- Add inverse and related properties for traversal +- Include parent/child classes for hierarchical queries +- Build query-focused ontology partition + +#### 3. Backend Router + +**Purpose**: Routes queries to the appropriate backend-specific query path based on configuration. + +**Algorithm Description**: +The Backend Router examines the system configuration to determine which graph backend is active (Cassandra or Cypher-based). It routes the question and ontology partition to the appropriate query generation service. The router can also support load balancing across multiple backends or fallback mechanisms if the primary backend is unavailable. + +**Routing Logic**: +- Check configured backend type from system settings +- Route to `onto-query-sparql` for Cassandra backends +- Route to `onto-query-cypher` for Neo4j/Memgraph/FalkorDB +- Support multi-backend configurations with query distribution +- Handle failover and load balancing scenarios + +#### 4. SPARQL Query Generation (`onto-query-sparql`) + +**Purpose**: Converts natural language questions to SPARQL queries for Cassandra execution. + +**Algorithm Description**: +The SPARQL query generator takes the question and ontology partition and constructs a SPARQL query optimised for execution against the Cassandra backend. It uses the prompt service with a SPARQL-specific template that includes RDF/OWL semantics. The generator understands SPARQL patterns like property paths, optional clauses, and filters that can efficiently translate to Cassandra operations. + +**SPARQL Generation Prompt Template**: +``` +Generate a SPARQL query for the following question using the provided ontology. + +ONTOLOGY CLASSES: +{classes} + +ONTOLOGY PROPERTIES: +{properties} + +RULES: +- Use proper RDF/OWL semantics +- Include relevant prefixes +- Use property paths for hierarchical queries +- Add FILTER clauses for constraints +- Optimise for Cassandra backend + +QUESTION: {question} + +SPARQL QUERY: +``` + +#### 5. Cypher Query Generation (`onto-query-cypher`) + +**Purpose**: Converts natural language questions to Cypher queries for graph databases. + +**Algorithm Description**: +The Cypher query generator creates native Cypher queries optimised for Neo4j, Memgraph, and FalkorDB. 
It maps ontology classes to node labels and properties to relationships, using Cypher's pattern matching syntax. The generator includes Cypher-specific optimisations like relationship direction hints, index usage, and query planning hints. + +**Cypher Generation Prompt Template**: +``` +Generate a Cypher query for the following question using the provided ontology. + +NODE LABELS (from classes): +{classes} + +RELATIONSHIP TYPES (from properties): +{properties} + +RULES: +- Use MATCH patterns for graph traversal +- Include WHERE clauses for filters +- Use aggregation functions when needed +- Optimise for graph database performance +- Consider index hints for large datasets + +QUESTION: {question} + +CYPHER QUERY: +``` + +#### 6. SPARQL-Cassandra Query Engine (`sparql-cassandra`) + +**Purpose**: Executes SPARQL queries against Cassandra using Python rdflib. + +**Algorithm Description**: +The SPARQL-Cassandra engine implements a SPARQL processor using Python's rdflib library with a custom Cassandra backend store. It translates SPARQL graph patterns into appropriate Cassandra CQL queries, handling joins, filters, and aggregations. The engine maintains an RDF-to-Cassandra mapping that preserves the semantic structure while optimising for Cassandra's column-family storage model. + +**Implementation Features**: +- rdflib Store interface implementation for Cassandra +- SPARQL 1.1 query support with common patterns +- Efficient translation of triple patterns to CQL +- Support for property paths and hierarchical queries +- Result streaming for large datasets +- Connection pooling and query caching + +**Example Translation**: +```sparql +SELECT ?animal WHERE { + ?animal rdf:type :Animal . + ?animal :hasOwner "John" . +} +``` +Translates to optimised Cassandra queries leveraging indexes and partition keys. + +#### 7. Cypher Query Executor (`cypher-executor`) + +**Purpose**: Executes Cypher queries against Neo4j, Memgraph, and FalkorDB. + +**Algorithm Description**: +The Cypher executor provides a unified interface for executing Cypher queries across different graph databases. It handles database-specific connection protocols, query optimisation hints, and result format normalisation. The executor includes retry logic, connection pooling, and transaction management appropriate for each database type. + +**Multi-Database Support**: +- **Neo4j**: Bolt protocol, transaction functions, index hints +- **Memgraph**: Custom protocol, streaming results, analytical queries +- **FalkorDB**: Redis protocol adaptation, in-memory optimisations + +**Execution Features**: +- Database-agnostic connection management +- Query validation and syntax checking +- Timeout and resource limit enforcement +- Result pagination and streaming +- Performance monitoring per database type +- Automatic failover between database instances + +#### 8. Answer Generator + +**Purpose**: Synthesises a natural language answer from query results. + +**Algorithm Description**: +The Answer Generator takes the structured query results and the original question, then uses the prompt service to generate a comprehensive answer. Unlike simple template-based responses, it uses an LLM to interpret the graph data in the context of the question, handling complex relationships, aggregations, and inferences. The generator can explain its reasoning by referencing the ontology structure and the specific triples retrieved from the graph. 
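+
+A minimal sketch of this step, assuming a simple prompt-service client and a hypothetical answer template (names are illustrative, not the actual API):
+
+```python
+# Hedged sketch: format query rows plus the original question into a
+# prompt service call. Client interface and template id are assumptions.
+async def generate_answer(prompt_client, question, rows, ontology_defs):
+    # Flatten structured results into a textual context block
+    context = "\n".join(
+        ", ".join(f"{k}={v}" for k, v in row.items()) for row in rows
+    )
+    return await prompt_client.generate(
+        template="answer-from-graph",      # hypothetical template id
+        variables={
+            "question": question,
+            "results": context,
+            "definitions": ontology_defs,  # relevant ontology definitions
+        },
+    )
+```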
+ +**Answer Generation Process**: +- Format query results into structured context +- Include relevant ontology definitions for clarity +- Construct prompt with question and results +- Generate natural language answer via LLM +- Validate answer against query intent +- Add citations to specific graph entities if needed + +### Integration with Existing Services + +#### Relationship with GraphRAG + +- **Complementary**: onto-query provides semantic precision while GraphRAG provides broad coverage +- **Shared Infrastructure**: Both use the same knowledge graph and prompt services +- **Query Routing**: System can route queries to most appropriate service based on question type +- **Hybrid Mode**: Can combine both approaches for comprehensive answers + +#### Relationship with OntoRAG Extraction + +- **Shared Ontologies**: Uses same ontology configurations loaded by kg-extract-ontology +- **Shared Vector Store**: Reuses the in-memory embeddings from extraction service +- **Consistent Semantics**: Queries operate on graphs built with same ontological constraints + +### Query Examples + +#### Example 1: Simple Entity Query +**Question**: "What animals are mammals?" +**Ontology Match**: [animal, mammal, subClassOf] +**Generated Query**: +```cypher +MATCH (a:animal)-[:subClassOf*]->(m:mammal) +RETURN a.name +``` + +#### Example 2: Relationship Query +**Question**: "Which documents were authored by John Smith?" +**Ontology Match**: [document, person, has-author] +**Generated Query**: +```cypher +MATCH (d:document)-[:has-author]->(p:person {name: "John Smith"}) +RETURN d.title, d.date +``` + +#### Example 3: Aggregation Query +**Question**: "How many legs do cats have?" +**Ontology Match**: [cat, number-of-legs (datatype property)] +**Generated Query**: +```cypher +MATCH (c:cat) +RETURN c.name, c.number_of_legs +``` + +### Configuration + +```yaml +onto-query: + embedding_model: "text-embedding-3-small" + vector_store: + shared_with_extractor: true # Reuse kg-extract-ontology's store + query_builder: + model: "gpt-4" + temperature: 0.1 + max_query_length: 1000 + graph_executor: + timeout: 30000 # ms + max_results: 1000 + answer_generator: + model: "gpt-4" + temperature: 0.3 + max_tokens: 500 +``` + +### Performance Optimisations + +#### Query Optimisation + +- **Ontology Pruning**: Only include necessary ontology elements in prompts +- **Query Caching**: Cache frequently asked questions and their queries +- **Result Caching**: Store results for identical queries within time window +- **Batch Processing**: Handle multiple related questions in single graph traversal + +#### Scalability Considerations + +- **Distributed Execution**: Parallelise subqueries across graph partitions +- **Incremental Results**: Stream results for large datasets +- **Load Balancing**: Distribute query load across multiple service instances +- **Resource Pools**: Manage connection pools to graph databases + +### Error Handling + +#### Failure Scenarios + +1. **Invalid Query Generation**: Fallback to GraphRAG or simple keyword search +2. **Ontology Mismatch**: Expand search to broader ontology subset +3. **Query Timeout**: Simplify query or increase timeout +4. **Empty Results**: Suggest query reformulation or related questions +5. 
**LLM Service Failure**: Use cached queries or template-based responses + +### Monitoring Metrics + +- Question complexity distribution +- Ontology partition sizes +- Query generation success rate +- Graph query execution time +- Answer quality scores +- Cache hit rates +- Error frequencies by type + +## Future Enhancements + +1. **Ontology Learning**: Automatically extend ontologies based on extraction patterns +2. **Confidence Scoring**: Assign confidence scores to extracted triples +3. **Explanation Generation**: Provide reasoning for triple extraction +4. **Active Learning**: Request human validation for uncertain extractions + +## Security Considerations + +1. **Prompt Injection Prevention**: Sanitise chunk text before prompt construction +2. **Resource Limits**: Cap memory usage for vector store +3. **Rate Limiting**: Limit extraction requests per client +4. **Audit Logging**: Track all extraction requests and results + +## Testing Strategy + +### Unit Testing + +- Ontology loader with various formats +- Embedding generation and storage +- Sentence splitting algorithms +- Vector similarity calculations +- Triple parsing and validation + +### Integration Testing + +- End-to-end extraction pipeline +- Configuration service integration +- Prompt service interaction +- Concurrent extraction handling + +### Performance Testing + +- Large ontology handling (1000+ classes) +- High-volume chunk processing +- Memory usage under load +- Latency benchmarks + +## Delivery Plan + +### Overview + +The OntoRAG system will be delivered in four major phases, with each phase providing incremental value while building toward the complete system. The plan focuses on establishing core extraction capabilities first, then adding query functionality, followed by optimizations and advanced features. + +### Phase 1: Foundation and Core Extraction + +**Goal**: Establish the basic ontology-driven extraction pipeline with simple vector matching. 
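+
+Step 1.2 below begins with a simple NumPy-based prototype before moving to FAISS. A minimal sketch of such a prototype, using exact cosine similarity with the configurable threshold and top-k described earlier (class and method names are illustrative):
+
+```python
+import numpy as np
+
+# Minimal NumPy prototype of the in-memory vector store (see Step 1.2):
+# exact cosine similarity over unit-normalised vectors.
+class SimpleVectorStore:
+
+    def __init__(self):
+        self.vectors = []   # unit-normalised embedding vectors
+        self.metadata = []  # one metadata dict per vector
+
+    def add(self, vector, meta):
+        v = np.asarray(vector, dtype=np.float32)
+        self.vectors.append(v / np.linalg.norm(v))
+        self.metadata.append(meta)
+
+    def search(self, query, top_k=10, threshold=0.3):
+        if not self.vectors:
+            return []
+        q = np.asarray(query, dtype=np.float32)
+        q = q / np.linalg.norm(q)
+        # Cosine similarity reduces to a dot product on normalised vectors
+        sims = np.stack(self.vectors) @ q
+        order = np.argsort(-sims)[:top_k]
+        return [(self.metadata[i], float(sims[i]))
+                for i in order if sims[i] >= threshold]
+```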
+ +#### Step 1.1: Ontology Management Foundation +- Implement ontology configuration loader (`OntologyLoader`) +- Parse and validate ontology JSON structures +- Create in-memory ontology storage and access patterns +- Implement ontology refresh mechanism + +**Success Criteria**: +- Successfully load and parse ontology configurations +- Validate ontology structure and consistency +- Handle multiple concurrent ontologies + +#### Step 1.2: Vector Store Implementation +- Implement simple NumPy-based vector store as initial prototype +- Add FAISS vector store implementation +- Create vector store interface abstraction +- Implement similarity search with configurable thresholds + +**Success Criteria**: +- Store and retrieve embeddings efficiently +- Perform similarity search with <100ms latency +- Support both NumPy and FAISS backends + +#### Step 1.3: Ontology Embedding Pipeline +- Integrate with embedding service +- Implement `OntologyEmbedder` component +- Generate embeddings for all ontology elements +- Store embeddings with metadata in vector store + +**Success Criteria**: +- Generate embeddings for classes and properties +- Store embeddings with proper metadata +- Rebuild embeddings on ontology updates + +#### Step 1.4: Text Processing Components +- Implement sentence splitter using NLTK/spaCy +- Extract phrases and named entities +- Create text segment hierarchy +- Generate embeddings for text segments + +**Success Criteria**: +- Accurately split text into sentences +- Extract meaningful phrases +- Maintain context relationships + +#### Step 1.5: Ontology Selection Algorithm +- Implement similarity matching between text and ontology +- Build dependency resolution for ontology elements +- Create minimal coherent ontology subsets +- Optimize subset generation performance + +**Success Criteria**: +- Select relevant ontology elements with >80% precision +- Include all necessary dependencies +- Generate subsets in <500ms + +#### Step 1.6: Basic Extraction Service +- Implement prompt construction for extraction +- Integrate with prompt service +- Parse and validate triple responses +- Create `kg-extract-ontology` service endpoint + +**Success Criteria**: +- Extract ontology-conformant triples +- Validate all triples against ontology +- Handle extraction errors gracefully + +### Phase 2: Query System Implementation + +**Goal**: Add ontology-aware query capabilities with support for multiple backends. 
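+
+Step 2.3 below implements the rdflib Store interface for Cassandra. A skeletal sketch of that interface follows, with the CQL translation elided; only the shape of the rdflib contract is shown.
+
+```python
+from rdflib import Graph
+from rdflib.store import Store
+
+# Skeleton of an rdflib Store backed by Cassandra (see Step 2.3).
+# The CQL translation is elided; signatures follow the rdflib contract.
+class CassandraStore(Store):
+
+    def __init__(self, configuration=None, identifier=None):
+        super().__init__(configuration)
+        self.identifier = identifier
+
+    def open(self, configuration, create=False):
+        # Would establish a Cassandra session / connection pool here
+        pass
+
+    def triples(self, pattern, context=None):
+        s, p, o = pattern  # any element may be None (a wildcard)
+        # Would translate the pattern to CQL, stream matching rows, and
+        # yield ((s, p, o), contexts) pairs as rdflib expects
+        yield from ()
+
+# rdflib's SPARQL engine then evaluates queries over the store, e.g.:
+#   g = Graph(store=CassandraStore(), identifier="kg")
+#   g.query('SELECT ?animal WHERE { ?animal :hasOwner "John" }')
+```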
+ +#### Step 2.1: Query Foundation Components +- Implement question analyzer +- Create ontology matcher for queries +- Adapt vector search for query context +- Build backend router component + +**Success Criteria**: +- Analyze questions into semantic components +- Match questions to relevant ontology elements +- Route queries to appropriate backend + +#### Step 2.2: SPARQL Path Implementation +- Implement `onto-query-sparql` service +- Create SPARQL query generator using LLM +- Develop prompt templates for SPARQL generation +- Validate generated SPARQL syntax + +**Success Criteria**: +- Generate valid SPARQL queries +- Use appropriate SPARQL patterns +- Handle complex query types + +#### Step 2.3: SPARQL-Cassandra Engine +- Implement rdflib Store interface for Cassandra +- Create CQL query translator +- Optimize triple pattern matching +- Handle SPARQL result formatting + +**Success Criteria**: +- Execute SPARQL queries on Cassandra +- Support common SPARQL patterns +- Return results in standard format + +#### Step 2.4: Cypher Path Implementation +- Implement `onto-query-cypher` service +- Create Cypher query generator using LLM +- Develop prompt templates for Cypher generation +- Validate generated Cypher syntax + +**Success Criteria**: +- Generate valid Cypher queries +- Use appropriate graph patterns +- Support Neo4j, Memgraph, FalkorDB + +#### Step 2.5: Cypher Executor +- Implement multi-database Cypher executor +- Support Bolt protocol (Neo4j/Memgraph) +- Support Redis protocol (FalkorDB) +- Handle result normalization + +**Success Criteria**: +- Execute Cypher on all target databases +- Handle database-specific differences +- Maintain connection pools efficiently + +#### Step 2.6: Answer Generation +- Implement answer generator component +- Create prompts for answer synthesis +- Format query results for LLM consumption +- Generate natural language answers + +**Success Criteria**: +- Generate accurate answers from query results +- Maintain context from original question +- Provide clear, concise responses + +### Phase 3: Optimization and Robustness + +**Goal**: Optimize performance, add caching, improve error handling, and enhance reliability. 
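+
+Step 3.2 below adds retry logic with exponential backoff, already called for in the extraction pipeline's error handling. A compact sketch of the pattern, with jitter so concurrent retries do not synchronise (parameter defaults are illustrative):
+
+```python
+import asyncio
+import random
+
+# Retry an arbitrary async call (e.g. a prompt service request) with
+# exponential backoff and jitter. Defaults here are illustrative.
+async def with_backoff(call, retries=5, base=0.5, cap=30.0):
+    for attempt in range(retries):
+        try:
+            return await call()
+        except Exception:
+            if attempt == retries - 1:
+                raise  # out of retries, surface the failure
+            # Double the delay each attempt, cap it, and add jitter
+            delay = min(cap, base * (2 ** attempt))
+            await asyncio.sleep(delay * random.uniform(0.5, 1.5))
+```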
+ +#### Step 3.1: Performance Optimization +- Implement embedding caching +- Add query result caching +- Optimize vector search with FAISS IVF indexes +- Implement batch processing for embeddings + +**Success Criteria**: +- Reduce average query latency by 50% +- Support 10x more concurrent requests +- Maintain sub-second response times + +#### Step 3.2: Advanced Error Handling +- Implement comprehensive error recovery +- Add fallback mechanisms between query paths +- Create retry logic with exponential backoff +- Improve error logging and diagnostics + +**Success Criteria**: +- Gracefully handle all failure scenarios +- Automatic failover between backends +- Detailed error reporting for debugging + +#### Step 3.3: Monitoring and Observability +- Add performance metrics collection +- Implement query tracing +- Create health check endpoints +- Add resource usage monitoring + +**Success Criteria**: +- Track all key performance indicators +- Identify bottlenecks quickly +- Monitor system health in real-time + +#### Step 3.4: Configuration Management +- Implement dynamic configuration updates +- Add configuration validation +- Create configuration templates +- Support environment-specific settings + +**Success Criteria**: +- Update configuration without restart +- Validate all configuration changes +- Support multiple deployment environments + +### Phase 4: Advanced Features + +**Goal**: Add sophisticated capabilities for production deployment and enhanced functionality. + +#### Step 4.1: Multi-Ontology Support +- Implement ontology selection logic +- Support cross-ontology queries +- Handle ontology versioning +- Create ontology merge capabilities + +**Success Criteria**: +- Query across multiple ontologies +- Handle ontology conflicts +- Support ontology evolution + +#### Step 4.2: Intelligent Query Routing +- Implement performance-based routing +- Add query complexity analysis +- Create adaptive routing algorithms +- Support A/B testing for paths + +**Success Criteria**: +- Route queries optimally +- Learn from query performance +- Improve routing over time + +#### Step 4.3: Advanced Extraction Features +- Add confidence scoring for triples +- Implement explanation generation +- Create feedback loops for improvement +- Support incremental learning + +**Success Criteria**: +- Provide confidence scores +- Explain extraction decisions +- Continuously improve accuracy + +#### Step 4.4: Production Hardening +- Add rate limiting +- Implement authentication/authorization +- Create deployment automation +- Add backup and recovery + +**Success Criteria**: +- Production-ready security +- Automated deployment pipeline +- Disaster recovery capability + +### Delivery Milestones + +1. **Milestone 1** (End of Phase 1): Basic ontology-driven extraction operational +2. **Milestone 2** (End of Phase 2): Full query system with both SPARQL and Cypher paths +3. **Milestone 3** (End of Phase 3): Optimized, robust system ready for staging +4. 
**Milestone 4** (End of Phase 4): Production-ready system with advanced features + +### Risk Mitigation + +#### Technical Risks +- **Vector Store Scalability**: Start with NumPy, migrate to FAISS gradually +- **Query Generation Accuracy**: Implement validation and fallback mechanisms +- **Backend Compatibility**: Test extensively with each database type +- **Performance Bottlenecks**: Profile early and often, optimize iteratively + +#### Operational Risks +- **Ontology Quality**: Implement validation and consistency checking +- **Service Dependencies**: Add circuit breakers and fallbacks +- **Resource Constraints**: Monitor and set appropriate limits +- **Data Consistency**: Implement proper transaction handling + +### Success Metrics + +#### Phase 1 Success Metrics +- Extraction accuracy: >90% ontology conformance +- Processing speed: <1 second per chunk +- Ontology load time: <10 seconds +- Vector search latency: <100ms + +#### Phase 2 Success Metrics +- Query success rate: >95% +- Query latency: <2 seconds end-to-end +- Backend compatibility: 100% for target databases +- Answer accuracy: >85% based on available data + +#### Phase 3 Success Metrics +- System uptime: >99.9% +- Error recovery rate: >95% +- Cache hit rate: >60% +- Concurrent users: >100 + +#### Phase 4 Success Metrics +- Multi-ontology queries: Fully supported +- Routing optimization: 30% latency reduction +- Confidence scoring accuracy: >90% +- Production deployment: Zero-downtime updates + +## References + +- [OWL 2 Web Ontology Language](https://www.w3.org/TR/owl2-overview/) +- [GraphRAG Architecture](https://github.com/microsoft/graphrag) +- [Sentence Transformers](https://www.sbert.net/) +- [FAISS Vector Search](https://github.com/facebookresearch/faiss) +- [spaCy NLP Library](https://spacy.io/) +- [rdflib Documentation](https://rdflib.readthedocs.io/) +- [Neo4j Bolt Protocol](https://neo4j.com/docs/bolt/current/) diff --git a/docs/tech-specs/vector-store-lifecycle.md b/docs/tech-specs/vector-store-lifecycle.md new file mode 100644 index 00000000..dcbb73e1 --- /dev/null +++ b/docs/tech-specs/vector-store-lifecycle.md @@ -0,0 +1,299 @@ +# Vector Store Lifecycle Management + +## Overview + +This document describes how TrustGraph manages vector store collections across different backend implementations (Qdrant, Pinecone, Milvus). The design addresses the challenge of supporting embeddings with different dimensions without hardcoding dimension values. + +## Problem Statement + +Vector stores require the embedding dimension to be specified when creating collections/indexes. However: +- Different embedding models produce different dimensions (e.g., 384, 768, 1536) +- The dimension is not known until the first embedding is generated +- A single TrustGraph collection may receive embeddings from multiple models +- Hardcoding a dimension (e.g., 384) causes failures with other embedding sizes + +## Design Principles + +1. **Lazy Creation**: Collections are created on-demand during first write, not during collection management operations +2. **Dimension-Based Naming**: Collection names include the embedding dimension as a suffix +3. **Graceful Degradation**: Queries against non-existent collections return empty results, not errors +4. 
**Multi-Dimension Support**: A single logical collection can have multiple physical collections (one per dimension) + +## Architecture + +### Collection Naming Convention + +Vector store collections use dimension suffixes to support multiple embedding sizes: + +**Document Embeddings:** +- Qdrant: `d_{user}_{collection}_{dimension}` +- Pinecone: `d-{user}-{collection}-{dimension}` +- Milvus: `doc_{user}_{collection}_{dimension}` + +**Graph Embeddings:** +- Qdrant: `t_{user}_{collection}_{dimension}` +- Pinecone: `t-{user}-{collection}-{dimension}` +- Milvus: `entity_{user}_{collection}_{dimension}` + +Examples: +- `d_alice_papers_384` - Alice's papers collection with 384-dimensional embeddings +- `d_alice_papers_768` - Same logical collection with 768-dimensional embeddings +- `t_bob_knowledge_1536` - Bob's knowledge graph with 1536-dimensional embeddings + +### Lifecycle Phases + +#### 1. Collection Creation Request + +**Request Flow:** +``` +User/System → Librarian → Storage Management Topic → Vector Stores +``` + +**Behavior:** +- The librarian broadcasts `create-collection` requests to all storage backends +- Vector store processors acknowledge the request but **do not create physical collections** +- Response is returned immediately with success +- Actual collection creation is deferred until first write + +**Rationale:** +- Dimension is unknown at creation time +- Avoids creating collections with wrong dimensions +- Simplifies collection management logic + +#### 2. Write Operations (Lazy Creation) + +**Write Flow:** +``` +Data → Storage Processor → Check Collection → Create if Needed → Insert +``` + +**Behavior:** +1. Extract embedding dimension from the vector: `dim = len(vector)` +2. Construct collection name with dimension suffix +3. Check if collection exists with that specific dimension +4. If not exists: + - Create collection with correct dimension + - Log: `"Lazily creating collection {name} with dimension {dim}"` +5. Insert the embedding into the dimension-specific collection + +**Example Scenario:** +``` +1. User creates collection "papers" + → No physical collections created yet + +2. First document with 384-dim embedding arrives + → Creates d_user_papers_384 + → Inserts data + +3. Second document with 768-dim embedding arrives + → Creates d_user_papers_768 + → Inserts data + +Result: Two physical collections for one logical collection +``` + +#### 3. Query Operations + +**Query Flow:** +``` +Query Vector → Determine Dimension → Check Collection → Search or Return Empty +``` + +**Behavior:** +1. Extract dimension from query vector: `dim = len(vector)` +2. Construct collection name with dimension suffix +3. Check if collection exists +4. If exists: + - Perform similarity search + - Return results +5. If not exists: + - Log: `"Collection {name} does not exist, returning empty results"` + - Return empty list (no error raised) + +**Multiple Dimensions in Same Query:** +- If query contains vectors of different dimensions +- Each dimension queries its corresponding collection +- Results are aggregated +- Missing collections are skipped (not treated as errors) + +**Rationale:** +- Querying an empty collection is a valid use case +- Returning empty results is semantically correct +- Avoids errors during system startup or before data ingestion + +#### 4. Collection Deletion + +**Delete Flow:** +``` +Delete Request → List All Collections → Filter by Prefix → Delete All Matches +``` + +**Behavior:** +1. Construct prefix pattern: `d_{user}_{collection}_` (note trailing underscore) +2. 
List all collections in the vector store +3. Filter collections matching the prefix +4. Delete all matching collections +5. Log each deletion: `"Deleted collection {name}"` +6. Summary log: `"Deleted {count} collection(s) for {user}/{collection}"` + +**Example:** +``` +Collections in store: +- d_alice_papers_384 +- d_alice_papers_768 +- d_alice_reports_384 +- d_bob_papers_384 + +Delete "papers" for alice: +→ Deletes: d_alice_papers_384, d_alice_papers_768 +→ Keeps: d_alice_reports_384, d_bob_papers_384 +``` + +**Rationale:** +- Ensures complete cleanup of all dimension variants +- Pattern matching prevents accidental deletion of unrelated collections +- Atomic operation from user perspective (all dimensions deleted together) + +## Behavioral Characteristics + +### Normal Operations + +**Collection Creation:** +- ✓ Returns success immediately +- ✓ No physical storage allocated +- ✓ Fast operation (no backend I/O) + +**First Write:** +- ✓ Creates collection with correct dimension +- ✓ Slightly slower due to collection creation overhead +- ✓ Subsequent writes to same dimension are fast + +**Queries Before Any Writes:** +- ✓ Returns empty results +- ✓ No errors or exceptions +- ✓ System remains stable + +**Mixed Dimension Writes:** +- ✓ Automatically creates separate collections per dimension +- ✓ Each dimension isolated in its own collection +- ✓ No dimension conflicts or schema errors + +**Collection Deletion:** +- ✓ Removes all dimension variants +- ✓ Complete cleanup +- ✓ No orphaned collections + +### Edge Cases + +**Multiple Embedding Models:** +``` +Scenario: User switches from model A (384-dim) to model B (768-dim) +Behavior: +- Both dimensions coexist in separate collections +- Old data (384-dim) remains queryable with 384-dim vectors +- New data (768-dim) queryable with 768-dim vectors +- Cross-dimension queries return results only for matching dimension +``` + +**Concurrent First Writes:** +``` +Scenario: Multiple processes write to same collection simultaneously +Behavior: +- Each process checks for existence before creating +- Most vector stores handle concurrent creation gracefully +- If race condition occurs, second create is typically idempotent +- Final state: Collection exists and both writes succeed +``` + +**Dimension Migration:** +``` +Scenario: User wants to migrate from 384-dim to 768-dim embeddings +Behavior: +- No automatic migration +- Old collection (384-dim) persists +- New collection (768-dim) created on first new write +- Both dimensions remain accessible +- Manual deletion of old dimension collections possible +``` + +**Empty Collection Queries:** +``` +Scenario: Query a collection that has never received data +Behavior: +- Collection doesn't exist (never created) +- Query returns empty list +- No error state +- System logs: "Collection does not exist, returning empty results" +``` + +## Implementation Notes + +### Storage Backend Specifics + +**Qdrant:** +- Uses `collection_exists()` for existence checks +- Uses `get_collections()` for listing during deletion +- Collection creation requires `VectorParams(size=dim, distance=Distance.COSINE)` + +**Pinecone:** +- Uses `has_index()` for existence checks +- Uses `list_indexes()` for listing during deletion +- Index creation requires waiting for "ready" status +- Serverless spec configured with cloud/region + +**Milvus:** +- Direct classes (`DocVectors`, `EntityVectors`) manage lifecycle +- Internal cache `self.collections[(dim, user, collection)]` for performance +- Collection names sanitized (alphanumeric + 
underscore only) +- Supports schema with auto-incrementing IDs + +### Performance Considerations + +**First Write Latency:** +- Additional overhead due to collection creation +- Qdrant: ~100-500ms +- Pinecone: ~10-30 seconds (serverless provisioning) +- Milvus: ~500-2000ms (includes indexing) + +**Query Performance:** +- Existence check adds minimal overhead (~1-10ms) +- No performance impact once collection exists +- Each dimension collection is independently optimized + +**Storage Overhead:** +- Minimal metadata per collection +- Main overhead is per-dimension storage +- Trade-off: Storage space vs. dimension flexibility + +## Future Considerations + +**Automatic Dimension Consolidation:** +- Could add background process to identify and merge unused dimension variants +- Would require re-embedding or dimension reduction + +**Dimension Discovery:** +- Could expose API to list all dimensions in use for a collection +- Useful for administration and monitoring + +**Default Dimension Preference:** +- Could track "primary" dimension per collection +- Use for queries when dimension context is unavailable + +**Storage Quotas:** +- May need per-collection dimension limits +- Prevent proliferation of dimension variants + +## Migration Notes + +**From Pre-Dimension-Suffix System:** +- Old collections: `d_{user}_{collection}` (no dimension suffix) +- New collections: `d_{user}_{collection}_{dim}` (with dimension suffix) +- No automatic migration - old collections remain accessible +- Consider manual migration script if needed +- Can run both naming schemes simultaneously + +## References + +- Collection Management: `docs/tech-specs/collection-management.md` +- Storage Schema: `trustgraph-base/trustgraph/schema/services/storage.py` +- Librarian Service: `trustgraph-flow/trustgraph/librarian/service.py` diff --git a/requirements.txt b/requirements.txt index 0d269066..68c21e1d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,7 @@ langchain-text-splitters langchain-community huggingface-hub requests -cassandra-driver +scylla-driver pulsar-client pypdf anthropic diff --git a/tests/unit/test_agent/test_agent_step_arguments.py b/tests/unit/test_agent/test_agent_step_arguments.py new file mode 100644 index 00000000..7243721d --- /dev/null +++ b/tests/unit/test_agent/test_agent_step_arguments.py @@ -0,0 +1,376 @@ +""" +Unit tests for AgentStep arguments type conversion + +Tests the fix for converting agent tool arguments to strings when creating +AgentStep records, ensuring compatibility with Pulsar schema that requires +Map(String()) for the arguments field. 
+""" + +import pytest +from unittest.mock import Mock, AsyncMock +from trustgraph.schema import AgentStep +from trustgraph.agent.react.types import Action + + +class TestAgentStepArgumentsConversion: + """Test cases for AgentStep arguments string conversion""" + + def test_agent_step_with_integer_arguments(self): + """Test that integer arguments are converted to strings""" + # Arrange + action = Action( + thought="Set volume to 10", + name="set_volume", + arguments={"volume_level": 10, "device": "speaker"}, + observation="Volume set successfully" + ) + + # Act - simulate the conversion that happens in service.py + agent_step = AgentStep( + thought=action.thought, + action=action.name, + arguments={k: str(v) for k, v in action.arguments.items()}, + observation=action.observation + ) + + # Assert + assert agent_step.arguments["volume_level"] == "10" + assert isinstance(agent_step.arguments["volume_level"], str) + assert agent_step.arguments["device"] == "speaker" + assert isinstance(agent_step.arguments["device"], str) + + def test_agent_step_with_float_arguments(self): + """Test that float arguments are converted to strings""" + # Arrange + action = Action( + thought="Set temperature", + name="set_temperature", + arguments={"temperature": 23.5, "unit": "celsius"}, + observation="Temperature set" + ) + + # Act + agent_step = AgentStep( + thought=action.thought, + action=action.name, + arguments={k: str(v) for k, v in action.arguments.items()}, + observation=action.observation + ) + + # Assert + assert agent_step.arguments["temperature"] == "23.5" + assert isinstance(agent_step.arguments["temperature"], str) + assert agent_step.arguments["unit"] == "celsius" + + def test_agent_step_with_boolean_arguments(self): + """Test that boolean arguments are converted to strings""" + # Arrange + action = Action( + thought="Enable feature", + name="toggle_feature", + arguments={"enabled": True, "feature_name": "dark_mode"}, + observation="Feature toggled" + ) + + # Act + agent_step = AgentStep( + thought=action.thought, + action=action.name, + arguments={k: str(v) for k, v in action.arguments.items()}, + observation=action.observation + ) + + # Assert + assert agent_step.arguments["enabled"] == "True" + assert isinstance(agent_step.arguments["enabled"], str) + assert agent_step.arguments["feature_name"] == "dark_mode" + + def test_agent_step_with_none_arguments(self): + """Test that None arguments are converted to strings""" + # Arrange + action = Action( + thought="Check status", + name="get_status", + arguments={"filter": None, "category": "all"}, + observation="Status retrieved" + ) + + # Act + agent_step = AgentStep( + thought=action.thought, + action=action.name, + arguments={k: str(v) for k, v in action.arguments.items()}, + observation=action.observation + ) + + # Assert + assert agent_step.arguments["filter"] == "None" + assert isinstance(agent_step.arguments["filter"], str) + assert agent_step.arguments["category"] == "all" + + def test_agent_step_with_mixed_type_arguments(self): + """Test that mixed type arguments are all converted to strings""" + # Arrange + action = Action( + thought="Configure device", + name="configure_device", + arguments={ + "name": "Hifi", + "volume_level": 10, + "bass_boost": 1.5, + "enabled": True, + "preset": None + }, + observation="Device configured" + ) + + # Act + agent_step = AgentStep( + thought=action.thought, + action=action.name, + arguments={k: str(v) for k, v in action.arguments.items()}, + observation=action.observation + ) + + # Assert - all values should 
be strings + assert all(isinstance(v, str) for v in agent_step.arguments.values()) + assert agent_step.arguments["name"] == "Hifi" + assert agent_step.arguments["volume_level"] == "10" + assert agent_step.arguments["bass_boost"] == "1.5" + assert agent_step.arguments["enabled"] == "True" + assert agent_step.arguments["preset"] == "None" + + def test_agent_step_with_string_arguments(self): + """Test that string arguments remain strings (no double conversion)""" + # Arrange + action = Action( + thought="Search for information", + name="search", + arguments={"query": "test query", "limit": "10"}, + observation="Search completed" + ) + + # Act + agent_step = AgentStep( + thought=action.thought, + action=action.name, + arguments={k: str(v) for k, v in action.arguments.items()}, + observation=action.observation + ) + + # Assert + assert agent_step.arguments["query"] == "test query" + assert agent_step.arguments["limit"] == "10" + assert all(isinstance(v, str) for v in agent_step.arguments.values()) + + def test_agent_step_with_empty_arguments(self): + """Test that empty arguments dict works correctly""" + # Arrange + action = Action( + thought="Perform action", + name="do_something", + arguments={}, + observation="Action completed" + ) + + # Act + agent_step = AgentStep( + thought=action.thought, + action=action.name, + arguments={k: str(v) for k, v in action.arguments.items()}, + observation=action.observation + ) + + # Assert + assert agent_step.arguments == {} + assert isinstance(agent_step.arguments, dict) + + def test_agent_step_with_numeric_string_values(self): + """Test arguments that are already strings containing numbers""" + # Arrange + action = Action( + thought="Process order", + name="process_order", + arguments={"order_id": "12345", "quantity": 10}, + observation="Order processed" + ) + + # Act + agent_step = AgentStep( + thought=action.thought, + action=action.name, + arguments={k: str(v) for k, v in action.arguments.items()}, + observation=action.observation + ) + + # Assert + assert agent_step.arguments["order_id"] == "12345" + assert agent_step.arguments["quantity"] == "10" + assert all(isinstance(v, str) for v in agent_step.arguments.values()) + + def test_agent_step_conversion_preserves_keys(self): + """Test that argument keys are preserved during conversion""" + # Arrange + action = Action( + thought="Test", + name="test_action", + arguments={ + "param1": 1, + "param_two": 2, + "PARAM_THREE": 3, + "param-four": 4 + }, + observation="Done" + ) + + # Act + agent_step = AgentStep( + thought=action.thought, + action=action.name, + arguments={k: str(v) for k, v in action.arguments.items()}, + observation=action.observation + ) + + # Assert - verify all keys are preserved + assert set(agent_step.arguments.keys()) == { + "param1", "param_two", "PARAM_THREE", "param-four" + } + # Verify values are converted + assert agent_step.arguments["param1"] == "1" + assert agent_step.arguments["param_two"] == "2" + assert agent_step.arguments["PARAM_THREE"] == "3" + assert agent_step.arguments["param-four"] == "4" + + def test_real_world_home_assistant_example(self): + """Test with real-world Home Assistant volume control example""" + # Arrange - this is the exact scenario from the bug report + action = Action( + thought='The user wants to set the volume of the Hifi. The `set_device_volume` tool can be used for this purpose. 
The device name is "Hifi" and the desired volume level is 10.', + name='set_device_volume', + arguments={'name': 'Hifi', 'volume_level': 10}, + observation='{"speech": {}, "response_type": "action_done", "data": {"targets": [], "success": [{"name": "Hifi", "type": "entity", "id": "media_player.hifi"}], "failed": []}}' + ) + + # Act - this should not raise TypeError + agent_step = AgentStep( + thought=action.thought, + action=action.name, + arguments={k: str(v) for k, v in action.arguments.items()}, + observation=action.observation + ) + + # Assert + assert agent_step.arguments["name"] == "Hifi" + assert agent_step.arguments["volume_level"] == "10" + assert isinstance(agent_step.arguments["volume_level"], str) + + def test_multiple_actions_in_history(self): + """Test converting multiple actions in history (as done in service.py)""" + # Arrange + history = [ + Action( + thought="First action", + name="action1", + arguments={"count": 5}, + observation="Done 1" + ), + Action( + thought="Second action", + name="action2", + arguments={"enabled": True, "name": "test"}, + observation="Done 2" + ), + Action( + thought="Third action", + name="action3", + arguments={"value": 3.14}, + observation="Done 3" + ) + ] + + # Act - simulate the list comprehension in service.py + agent_steps = [ + AgentStep( + thought=h.thought, + action=h.name, + arguments={k: str(v) for k, v in h.arguments.items()}, + observation=h.observation + ) + for h in history + ] + + # Assert + assert len(agent_steps) == 3 + + # First action + assert agent_steps[0].arguments["count"] == "5" + assert isinstance(agent_steps[0].arguments["count"], str) + + # Second action + assert agent_steps[1].arguments["enabled"] == "True" + assert agent_steps[1].arguments["name"] == "test" + assert all(isinstance(v, str) for v in agent_steps[1].arguments.values()) + + # Third action + assert agent_steps[2].arguments["value"] == "3.14" + assert isinstance(agent_steps[2].arguments["value"], str) + + def test_arguments_with_special_characters(self): + """Test arguments containing special characters are properly converted""" + # Arrange + action = Action( + thought="Process data", + name="process", + arguments={ + "text": "Hello, World!", + "path": "/home/user/file.txt", + "pattern": "test-*-pattern", + "count": 42 + }, + observation="Processed" + ) + + # Act + agent_step = AgentStep( + thought=action.thought, + action=action.name, + arguments={k: str(v) for k, v in action.arguments.items()}, + observation=action.observation + ) + + # Assert + assert agent_step.arguments["text"] == "Hello, World!" 
+ assert agent_step.arguments["path"] == "/home/user/file.txt" + assert agent_step.arguments["pattern"] == "test-*-pattern" + assert agent_step.arguments["count"] == "42" + assert all(isinstance(v, str) for v in agent_step.arguments.values()) + + def test_zero_and_negative_numbers(self): + """Test that zero and negative numbers are converted correctly""" + # Arrange + action = Action( + thought="Test edge cases", + name="edge_test", + arguments={ + "zero": 0, + "negative": -5, + "negative_float": -3.14, + "positive": 10 + }, + observation="Done" + ) + + # Act + agent_step = AgentStep( + thought=action.thought, + action=action.name, + arguments={k: str(v) for k, v in action.arguments.items()}, + observation=action.observation + ) + + # Assert + assert agent_step.arguments["zero"] == "0" + assert agent_step.arguments["negative"] == "-5" + assert agent_step.arguments["negative_float"] == "-3.14" + assert agent_step.arguments["positive"] == "10" + assert all(isinstance(v, str) for v in agent_step.arguments.values()) diff --git a/tests/unit/test_agent/test_mcp_tool_auth.py b/tests/unit/test_agent/test_mcp_tool_auth.py new file mode 100644 index 00000000..82877226 --- /dev/null +++ b/tests/unit/test_agent/test_mcp_tool_auth.py @@ -0,0 +1,233 @@ +""" +Unit tests for MCP tool bearer token authentication + +Tests the authentication feature added to MCP tool service that allows +configuring optional bearer tokens for MCP server connections. +""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch, MagicMock +import json + + +class TestMcpToolAuthentication: + """Test cases for MCP tool bearer token authentication""" + + def test_mcp_tool_with_auth_token_header_building(self): + """Test that auth token is correctly formatted in headers""" + # Arrange + mcp_config = { + "url": "https://secure.example.com/mcp", + "remote-name": "secure-tool", + "auth-token": "test-token-12345" + } + + # Act - simulate header building logic from service.py + headers = {} + if "auth-token" in mcp_config and mcp_config["auth-token"]: + token = mcp_config["auth-token"] + headers["Authorization"] = f"Bearer {token}" + + # Assert + assert "Authorization" in headers + assert headers["Authorization"] == "Bearer test-token-12345" + assert headers["Authorization"].startswith("Bearer ") + + def test_mcp_tool_without_auth_token_header_building(self): + """Test that no auth header is added when token is not present (backward compatibility)""" + # Arrange + mcp_config = { + "url": "http://public.example.com/mcp", + "remote-name": "public-tool" + # No auth-token field + } + + # Act - simulate header building logic from service.py + headers = {} + if "auth-token" in mcp_config and mcp_config["auth-token"]: + token = mcp_config["auth-token"] + headers["Authorization"] = f"Bearer {token}" + + # Assert + assert headers == {} + assert "Authorization" not in headers + + def test_mcp_config_with_auth_token(self): + """Test MCP configuration parsing with auth-token""" + # Arrange + config = { + "mcp": { + "secure-tool": json.dumps({ + "url": "https://secure.example.com/mcp", + "remote-name": "secure-tool", + "auth-token": "test-token-xyz" + }), + "public-tool": json.dumps({ + "url": "http://public.example.com/mcp", + "remote-name": "public-tool" + }) + } + } + + # Act - simulate on_mcp_config + mcp_services = { + k: json.loads(v) + for k, v in config["mcp"].items() + } + + # Assert + assert "secure-tool" in mcp_services + assert mcp_services["secure-tool"]["auth-token"] == "test-token-xyz" + assert 
mcp_services["secure-tool"]["url"] == "https://secure.example.com/mcp" + + assert "public-tool" in mcp_services + assert "auth-token" not in mcp_services["public-tool"] + assert mcp_services["public-tool"]["url"] == "http://public.example.com/mcp" + + def test_auth_token_with_empty_string(self): + """Test that empty auth-token string is treated as no auth""" + # Arrange + config_data = { + "url": "https://example.com/mcp", + "remote-name": "test-tool", + "auth-token": "" + } + + # Act - simulate header building logic + headers = {} + if "auth-token" in config_data and config_data["auth-token"]: + headers["Authorization"] = f"Bearer {config_data['auth-token']}" + + # Assert + assert headers == {}, "Empty auth-token should not add Authorization header" + + def test_auth_token_with_special_characters(self): + """Test auth token with special characters (JWT-like)""" + # Arrange + jwt_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" + + config_data = { + "url": "https://example.com/mcp", + "auth-token": jwt_token + } + + # Act - simulate header building + headers = {} + if "auth-token" in config_data and config_data["auth-token"]: + token = config_data["auth-token"] + headers["Authorization"] = f"Bearer {token}" + + # Assert + assert headers["Authorization"] == f"Bearer {jwt_token}" + assert "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9" in headers["Authorization"] + + def test_multiple_tools_with_different_auth_configs(self): + """Test handling multiple MCP tools with mixed auth configurations""" + # Arrange + mcp_services = { + "tool-a": { + "url": "https://a.example.com/mcp", + "auth-token": "token-a" + }, + "tool-b": { + "url": "https://b.example.com/mcp", + "auth-token": "token-b" + }, + "tool-c": { + "url": "http://c.example.com/mcp" + # No auth-token + } + } + + # Act - simulate header building for each tool + headers_a = {} + if "auth-token" in mcp_services["tool-a"] and mcp_services["tool-a"]["auth-token"]: + headers_a["Authorization"] = f"Bearer {mcp_services['tool-a']['auth-token']}" + + headers_b = {} + if "auth-token" in mcp_services["tool-b"] and mcp_services["tool-b"]["auth-token"]: + headers_b["Authorization"] = f"Bearer {mcp_services['tool-b']['auth-token']}" + + headers_c = {} + if "auth-token" in mcp_services["tool-c"] and mcp_services["tool-c"]["auth-token"]: + headers_c["Authorization"] = f"Bearer {mcp_services['tool-c']['auth-token']}" + + # Assert + assert headers_a == {"Authorization": "Bearer token-a"} + assert headers_b == {"Authorization": "Bearer token-b"} + assert headers_c == {} + + def test_auth_token_not_logged(self): + """Test that auth tokens are not exposed in logs""" + # This is more of a guideline test - in real implementation, + # we should ensure tokens are never logged + + # Arrange + auth_token = "super-secret-token-123" + config = { + "url": "https://secure.example.com/mcp", + "auth-token": auth_token + } + + # Act - simulate log-safe representation + def get_log_safe_config(cfg): + """Return config with sensitive data masked""" + safe_config = cfg.copy() + if "auth-token" in safe_config and safe_config["auth-token"]: + safe_config["auth-token"] = "****" + return safe_config + + log_safe = get_log_safe_config(config) + + # Assert + assert log_safe["auth-token"] == "****" + assert auth_token not in str(log_safe) + assert "url" in log_safe + assert log_safe["url"] == "https://secure.example.com/mcp" + + def 
test_auth_token_with_remote_name_configuration(self): + """Test MCP tool configuration with both auth-token and remote-name""" + # Arrange + mcp_config = { + "url": "https://server.example.com/mcp", + "remote-name": "actual_tool_name", + "auth-token": "my-token-456" + } + + # Act - simulate header building and remote name extraction + headers = {} + if "auth-token" in mcp_config and mcp_config["auth-token"]: + token = mcp_config["auth-token"] + headers["Authorization"] = f"Bearer {token}" + + remote_name = mcp_config.get("remote-name", "default-name") + + # Assert + assert headers["Authorization"] == "Bearer my-token-456" + assert remote_name == "actual_tool_name" + assert "url" in mcp_config + assert mcp_config["url"] == "https://server.example.com/mcp" + + def test_bearer_token_format(self): + """Test that Bearer token format is correct""" + # Arrange + tokens = [ + "simple-token", + "token_with_underscore", + "token-with-dash", + "TokenWithMixedCase123", + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.payload.signature" + ] + + # Act & Assert + for token in tokens: + headers = {} + if token: + headers["Authorization"] = f"Bearer {token}" + + # Verify format is "Bearer " with single space + assert headers["Authorization"].startswith("Bearer ") + assert headers["Authorization"] == f"Bearer {token}" + # Verify no extra spaces + assert headers["Authorization"].count("Bearer") == 1 + assert headers["Authorization"].split("Bearer ")[1] == token diff --git a/tests/unit/test_direct/test_milvus_user_collection_integration.py b/tests/unit/test_direct/test_milvus_user_collection_integration.py index cc45524c..90b80e8f 100644 --- a/tests/unit/test_direct/test_milvus_user_collection_integration.py +++ b/tests/unit/test_direct/test_milvus_user_collection_integration.py @@ -30,14 +30,16 @@ class TestMilvusUserCollectionIntegration: for user, collection, vector in test_cases: doc_vectors.insert(vector, "test document", user, collection) - + expected_collection_name = make_safe_collection_name( user, collection, "doc" ) - - # Verify collection was created with correct name + # Add dimension suffix to expected name + expected_collection_name_with_dim = f"{expected_collection_name}_{len(vector)}" + + # Verify collection was created with correct name (including dimension) assert (len(vector), user, collection) in doc_vectors.collections - assert doc_vectors.collections[(len(vector), user, collection)] == expected_collection_name + assert doc_vectors.collections[(len(vector), user, collection)] == expected_collection_name_with_dim @patch('trustgraph.direct.milvus_graph_embeddings.MilvusClient') def test_entity_vectors_collection_creation_with_user_collection(self, mock_milvus_client): @@ -56,14 +58,16 @@ class TestMilvusUserCollectionIntegration: for user, collection, vector in test_cases: entity_vectors.insert(vector, "test entity", user, collection) - + expected_collection_name = make_safe_collection_name( user, collection, "entity" ) - - # Verify collection was created with correct name + # Add dimension suffix to expected name + expected_collection_name_with_dim = f"{expected_collection_name}_{len(vector)}" + + # Verify collection was created with correct name (including dimension) assert (len(vector), user, collection) in entity_vectors.collections - assert entity_vectors.collections[(len(vector), user, collection)] == expected_collection_name + assert entity_vectors.collections[(len(vector), user, collection)] == expected_collection_name_with_dim @patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient') 
def test_doc_vectors_search_uses_correct_collection(self, mock_milvus_client): @@ -88,11 +92,12 @@ class TestMilvusUserCollectionIntegration: # Now search result = doc_vectors.search(vector, user, collection, limit=5) - # Verify search was called with correct collection name + # Verify search was called with correct collection name (including dimension) expected_collection_name = make_safe_collection_name(user, collection, "doc") + expected_collection_name_with_dim = f"{expected_collection_name}_{len(vector)}" mock_client.search.assert_called_once() search_call = mock_client.search.call_args - assert search_call[1]["collection_name"] == expected_collection_name + assert search_call[1]["collection_name"] == expected_collection_name_with_dim @patch('trustgraph.direct.milvus_graph_embeddings.MilvusClient') def test_entity_vectors_search_uses_correct_collection(self, mock_milvus_client): @@ -117,11 +122,12 @@ class TestMilvusUserCollectionIntegration: # Now search result = entity_vectors.search(vector, user, collection, limit=5) - # Verify search was called with correct collection name + # Verify search was called with correct collection name (including dimension) expected_collection_name = make_safe_collection_name(user, collection, "entity") + expected_collection_name_with_dim = f"{expected_collection_name}_{len(vector)}" mock_client.search.assert_called_once() search_call = mock_client.search.call_args - assert search_call[1]["collection_name"] == expected_collection_name + assert search_call[1]["collection_name"] == expected_collection_name_with_dim @patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient') def test_doc_vectors_collection_isolation(self, mock_milvus_client): @@ -141,10 +147,11 @@ class TestMilvusUserCollectionIntegration: assert len(doc_vectors.collections) == 3 collection_names = set(doc_vectors.collections.values()) + # All vectors are 3-dimensional, so all names should have _3 suffix expected_names = { - "doc_user1_collection1", - "doc_user2_collection2", - "doc_user1_collection2" + "doc_user1_collection1_3", + "doc_user2_collection2_3", + "doc_user1_collection2_3" } assert collection_names == expected_names @@ -166,10 +173,11 @@ class TestMilvusUserCollectionIntegration: assert len(entity_vectors.collections) == 3 collection_names = set(entity_vectors.collections.values()) + # All vectors are 3-dimensional, so all names should have _3 suffix expected_names = { - "entity_user1_collection1", - "entity_user2_collection2", - "entity_user1_collection2" + "entity_user1_collection1_3", + "entity_user2_collection2_3", + "entity_user1_collection2_3" } assert collection_names == expected_names @@ -191,16 +199,16 @@ class TestMilvusUserCollectionIntegration: # Verify three separate collections were created for different dimensions assert len(doc_vectors.collections) == 3 - + collection_names = set(doc_vectors.collections.values()) + # Different dimensions now create different collections with dimension suffixes expected_names = { - "doc_test_user_test_collection", # Same name for all dimensions - "doc_test_user_test_collection", # now stored per dimension in key - "doc_test_user_test_collection" # but collection name is the same + "doc_test_user_test_collection_2", # 2D vector + "doc_test_user_test_collection_3", # 3D vector + "doc_test_user_test_collection_4" # 4D vector } - # Note: Now all dimensions use the same collection name, they are differentiated by the key - assert len(collection_names) == 1 # Only one unique collection name - assert "doc_test_user_test_collection" 
in collection_names + # Each dimension gets its own collection + assert len(collection_names) == 3 # Three unique collection names assert collection_names == expected_names @patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient') @@ -222,8 +230,9 @@ class TestMilvusUserCollectionIntegration: # Verify only one collection was created assert len(doc_vectors.collections) == 1 - - expected_collection_name = "doc_test_user_test_collection" + + # Collection name now includes dimension suffix + expected_collection_name = "doc_test_user_test_collection_3" assert doc_vectors.collections[(3, user, collection)] == expected_collection_name @patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient') @@ -235,19 +244,20 @@ class TestMilvusUserCollectionIntegration: doc_vectors = DocVectors(uri="http://test:19530", prefix="doc") # Test various special character combinations + # All expected names now include dimension suffix _3 test_cases = [ - ("user@domain.com", "test-collection.v1", "doc_user_domain_com_test_collection_v1"), - ("user_123", "collection_456", "doc_user_123_collection_456"), - ("user with spaces", "collection with spaces", "doc_user_with_spaces_collection_with_spaces"), - ("user@@@test", "collection---test", "doc_user_test_collection_test"), + ("user@domain.com", "test-collection.v1", "doc_user_domain_com_test_collection_v1_3"), + ("user_123", "collection_456", "doc_user_123_collection_456_3"), + ("user with spaces", "collection with spaces", "doc_user_with_spaces_collection_with_spaces_3"), + ("user@@@test", "collection---test", "doc_user_test_collection_test_3"), ] - + vector = [0.1, 0.2, 0.3] - + for user, collection, expected_name in test_cases: doc_vectors_instance = DocVectors(uri="http://test:19530", prefix="doc") doc_vectors_instance.insert(vector, "test doc", user, collection) - + assert doc_vectors_instance.collections[(3, user, collection)] == expected_name def test_collection_name_backward_compatibility(self): diff --git a/tests/unit/test_embeddings/test_embeddings_service_contract.py b/tests/unit/test_embeddings/test_embeddings_service_contract.py new file mode 100644 index 00000000..e53faf81 --- /dev/null +++ b/tests/unit/test_embeddings/test_embeddings_service_contract.py @@ -0,0 +1,157 @@ +""" +Contract tests for EmbeddingsService base class + +Tests the contract between the EmbeddingsService base class and its +implementations, ensuring proper integration of the model parameter handling. 
+""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch +from unittest import IsolatedAsyncioTestCase +from trustgraph.base import EmbeddingsService +from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error + + +class ConcreteEmbeddingsService(EmbeddingsService): + """Concrete implementation for testing the abstract base class""" + + def __init__(self, **params): + self.on_embeddings_calls = [] + self.default_model = params.get("model", "default-test-model") + # Don't call super().__init__ to avoid taskgroup requirements in tests + # We're only testing the on_embeddings interface + + async def on_embeddings(self, text, model=None): + """Implementation that tracks calls""" + self.on_embeddings_calls.append({ + "text": text, + "model": model + }) + # Return a simple embedding + return [[0.1, 0.2, 0.3]] + + +class TestEmbeddingsServiceModelParameterContract(IsolatedAsyncioTestCase): + """Test the model parameter contract in embeddings implementations""" + + async def test_on_embeddings_accepts_model_parameter(self): + """Test that on_embeddings method accepts optional model parameter""" + # Arrange + service = ConcreteEmbeddingsService(model="default-model") + + # Act + result1 = await service.on_embeddings("test text") + result2 = await service.on_embeddings("test text", model="custom-model") + result3 = await service.on_embeddings("test text", model=None) + + # Assert + assert len(service.on_embeddings_calls) == 3 + assert service.on_embeddings_calls[0]["model"] is None # No model specified + assert service.on_embeddings_calls[1]["model"] == "custom-model" + assert service.on_embeddings_calls[2]["model"] is None + + async def test_implementation_tracks_model_changes(self): + """Test that implementations properly track which model is requested""" + # Arrange + service = ConcreteEmbeddingsService(model="default-model") + + # Act - multiple requests with different models + await service.on_embeddings("text1", model="model-a") + await service.on_embeddings("text2", model="model-b") + await service.on_embeddings("text3") # Use default (None passed) + await service.on_embeddings("text4", model="model-a") + + # Assert + assert len(service.on_embeddings_calls) == 4 + assert service.on_embeddings_calls[0]["model"] == "model-a" + assert service.on_embeddings_calls[1]["model"] == "model-b" + assert service.on_embeddings_calls[2]["model"] is None + assert service.on_embeddings_calls[3]["model"] == "model-a" + + async def test_model_parameter_with_various_text_inputs(self): + """Test model parameter works with different text inputs""" + # Arrange + service = ConcreteEmbeddingsService(model="default-model") + + test_cases = [ + ("Simple text", "model-1"), + ("", "model-2"), + ("Unicode: 世界 🌍", "model-3"), + ("Very " * 100 + "long text", None), + ] + + # Act + for text, model in test_cases: + await service.on_embeddings(text, model=model) + + # Assert + assert len(service.on_embeddings_calls) == len(test_cases) + for i, (text, model) in enumerate(test_cases): + assert service.on_embeddings_calls[i]["text"] == text + assert service.on_embeddings_calls[i]["model"] == model + + async def test_embeddings_return_format(self): + """Test that embeddings are returned in correct format""" + # Arrange + service = ConcreteEmbeddingsService(model="default-model") + + # Act + result = await service.on_embeddings("test text", model="test-model") + + # Assert + assert isinstance(result, list) + assert len(result) > 0 + assert isinstance(result[0], list) + assert all(isinstance(x, float) 
for x in result[0]) + + +class TestEmbeddingsResponseSchema: + """Test the EmbeddingsResponse schema contract""" + + def test_success_response(self): + """Test creating success response""" + # Act + response = EmbeddingsResponse( + error=None, + vectors=[[0.1, 0.2, 0.3]] + ) + + # Assert + assert response.error is None + assert response.vectors == [[0.1, 0.2, 0.3]] + + def test_error_response(self): + """Test creating error response""" + # Act + error = Error(type="test-error", message="Test message") + response = EmbeddingsResponse( + error=error, + vectors=None + ) + + # Assert + assert response.error is not None + assert response.error.type == "test-error" + assert response.error.message == "Test message" + assert response.vectors is None + + def test_response_with_multiple_vectors(self): + """Test response with multiple embedding vectors""" + # Act + vectors = [ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + [0.7, 0.8, 0.9] + ] + response = EmbeddingsResponse( + error=None, + vectors=vectors + ) + + # Assert + assert len(response.vectors) == 3 + assert response.vectors[0] == [0.1, 0.2, 0.3] + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/unit/test_embeddings/test_fastembed_dynamic_model.py b/tests/unit/test_embeddings/test_fastembed_dynamic_model.py new file mode 100644 index 00000000..1c1fb883 --- /dev/null +++ b/tests/unit/test_embeddings/test_fastembed_dynamic_model.py @@ -0,0 +1,216 @@ +""" +Unit tests for FastEmbed dynamic model loading + +Tests the model caching and dynamic loading functionality for FastEmbed +embeddings service. +""" + +import pytest +from unittest.mock import Mock, patch, AsyncMock +from unittest import IsolatedAsyncioTestCase +from trustgraph.embeddings.fastembed.processor import Processor + + +class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase): + """Test FastEmbed dynamic model loading and caching""" + + @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_default_model_loaded_on_init(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class): + """Test that default model is loaded during initialization""" + # Arrange + mock_fastembed_instance = Mock() + mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])] + mock_text_embedding_class.return_value = mock_fastembed_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + base_params = { + "id": "test-embeddings", + "concurrency": 1, + "model": "test-model", + "taskgroup": AsyncMock() + } + + # Act + processor = Processor(**base_params) + + # Assert + mock_text_embedding_class.assert_called_once_with(model_name="test-model") + assert processor.default_model == "test-model" + assert processor.cached_model_name == "test-model" + assert processor.embeddings is not None + + @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_model_caching_avoids_reload(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class): + """Test that using the same model doesn't reload it""" + # Arrange + mock_fastembed_instance = Mock() + mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])] + 
mock_text_embedding_class.return_value = mock_fastembed_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + mock_text_embedding_class.reset_mock() + + # Act - use same model multiple times + processor._load_model("test-model") + processor._load_model("test-model") + processor._load_model("test-model") + + # Assert - model should not be reloaded + mock_text_embedding_class.assert_not_called() + assert processor.cached_model_name == "test-model" + + @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_model_reload_on_name_change(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class): + """Test that changing model name triggers reload""" + # Arrange + mock_fastembed_instance = Mock() + mock_text_embedding_class.return_value = mock_fastembed_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + mock_text_embedding_class.reset_mock() + + # Act - switch to different model + processor._load_model("different-model") + + # Assert + mock_text_embedding_class.assert_called_once_with(model_name="different-model") + assert processor.cached_model_name == "different-model" + + @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class): + """Test that on_embeddings uses default model when no model specified""" + # Arrange + mock_fastembed_instance = Mock() + mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])] + mock_text_embedding_class.return_value = mock_fastembed_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + mock_text_embedding_class.reset_mock() + + # Act + result = await processor.on_embeddings("test text") + + # Assert + mock_fastembed_instance.embed.assert_called_once_with(["test text"]) + assert processor.cached_model_name == "test-model" # Still using default + assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] + + @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class): + """Test that on_embeddings uses specified model when provided""" + # Arrange + mock_fastembed_instance = Mock() + mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])] + mock_text_embedding_class.return_value = mock_fastembed_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + mock_text_embedding_class.reset_mock() + + # Act + result = await processor.on_embeddings("test text", 
model="custom-model") + + # Assert + mock_text_embedding_class.assert_called_once_with(model_name="custom-model") + assert processor.cached_model_name == "custom-model" + mock_fastembed_instance.embed.assert_called_once_with(["test text"]) + + @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class): + """Test switching between multiple models""" + # Arrange + mock_fastembed_instance = Mock() + mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])] + mock_text_embedding_class.return_value = mock_fastembed_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + initial_call_count = mock_text_embedding_class.call_count + + # Act - switch between models + await processor.on_embeddings("text1", model="model-a") + call_count_after_a = mock_text_embedding_class.call_count + + await processor.on_embeddings("text2", model="model-a") # Same, no reload + call_count_after_a_repeat = mock_text_embedding_class.call_count + + await processor.on_embeddings("text3", model="model-b") # Different, reload + call_count_after_b = mock_text_embedding_class.call_count + + await processor.on_embeddings("text4", model="model-a") # Back to A, reload + call_count_after_a_again = mock_text_embedding_class.call_count + + # Assert + assert call_count_after_a == initial_call_count + 1 # First load + assert call_count_after_a_repeat == initial_call_count + 1 # No reload + assert call_count_after_b == initial_call_count + 2 # Reload for model-b + assert call_count_after_a_again == initial_call_count + 3 # Reload back to model-a + + @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class): + """Test that None model parameter falls back to default""" + # Arrange + mock_fastembed_instance = Mock() + mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])] + mock_text_embedding_class.return_value = mock_fastembed_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + initial_count = mock_text_embedding_class.call_count + + # Act + result = await processor.on_embeddings("test text", model=None) + + # Assert + # No reload, using cached default + assert mock_text_embedding_class.call_count == initial_count + assert processor.cached_model_name == "test-model" + + @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class): + """Test initialization without model parameter uses module default""" + # Arrange + mock_fastembed_instance = Mock() + mock_text_embedding_class.return_value = mock_fastembed_instance 
+ mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + # Act + processor = Processor(id="test-embeddings", concurrency=1, taskgroup=AsyncMock()) + + # Assert + # Should use default_model from module + expected_default = "sentence-transformers/all-MiniLM-L6-v2" + mock_text_embedding_class.assert_called_once_with(model_name=expected_default) + assert processor.default_model == expected_default + assert processor.cached_model_name == expected_default + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/unit/test_embeddings/test_huggingface_dynamic_model.py b/tests/unit/test_embeddings/test_huggingface_dynamic_model.py new file mode 100644 index 00000000..aef6fc92 --- /dev/null +++ b/tests/unit/test_embeddings/test_huggingface_dynamic_model.py @@ -0,0 +1,213 @@ +""" +Unit tests for HuggingFace dynamic model loading + +Tests the model caching and dynamic loading functionality for HuggingFace +embeddings service using LangChain's HuggingFaceEmbeddings. +""" + +import pytest +from unittest.mock import Mock, patch, AsyncMock +from unittest import IsolatedAsyncioTestCase + +# Skip all tests in this module if trustgraph.embeddings.hf is not installed +pytest.importorskip("trustgraph.embeddings.hf") + +from trustgraph.embeddings.hf.hf import Processor + + +class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase): + """Test HuggingFace dynamic model loading and caching""" + + @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_default_model_loaded_on_init(self, mock_embeddings_init, mock_async_init, mock_hf_class): + """Test that default model is loaded during initialization""" + # Arrange + mock_hf_instance = Mock() + mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_hf_class.return_value = mock_hf_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + # Act + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + + # Assert + mock_hf_class.assert_called_once_with(model_name="test-model") + assert processor.default_model == "test-model" + assert processor.cached_model_name == "test-model" + assert processor.embeddings is not None + + @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_model_caching_avoids_reload(self, mock_embeddings_init, mock_async_init, mock_hf_class): + """Test that using the same model doesn't reload it""" + # Arrange + mock_hf_instance = Mock() + mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_hf_class.return_value = mock_hf_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + mock_hf_class.reset_mock() + + # Act - use same model multiple times + processor._load_model("test-model") + processor._load_model("test-model") + processor._load_model("test-model") + + # Assert - model should not be reloaded + mock_hf_class.assert_not_called() + assert processor.cached_model_name == "test-model" + + @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + 
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_model_reload_on_name_change(self, mock_embeddings_init, mock_async_init, mock_hf_class): + """Test that changing model name triggers reload""" + # Arrange + mock_hf_instance = Mock() + mock_hf_class.return_value = mock_hf_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + mock_hf_class.reset_mock() + + # Act - switch to different model + processor._load_model("different-model") + + # Assert + mock_hf_class.assert_called_once_with(model_name="different-model") + assert processor.cached_model_name == "different-model" + + @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_hf_class): + """Test that on_embeddings uses default model when no model specified""" + # Arrange + mock_hf_instance = Mock() + mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_hf_class.return_value = mock_hf_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + mock_hf_class.reset_mock() + + # Act + result = await processor.on_embeddings("test text") + + # Assert + mock_hf_instance.embed_documents.assert_called_once_with(["test text"]) + assert processor.cached_model_name == "test-model" # Still using default + assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] + + @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_hf_class): + """Test that on_embeddings uses specified model when provided""" + # Arrange + mock_hf_instance = Mock() + mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_hf_class.return_value = mock_hf_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + mock_hf_class.reset_mock() + + # Act + result = await processor.on_embeddings("test text", model="custom-model") + + # Assert + mock_hf_class.assert_called_once_with(model_name="custom-model") + assert processor.cached_model_name == "custom-model" + mock_hf_instance.embed_documents.assert_called_once_with(["test text"]) + + @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_hf_class): + """Test switching between multiple models""" + # Arrange + mock_hf_instance = Mock() + mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_hf_class.return_value = mock_hf_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", 
concurrency=1, model="test-model", taskgroup=AsyncMock()) + initial_call_count = mock_hf_class.call_count + + # Act - switch between models + await processor.on_embeddings("text1", model="model-a") + call_count_after_a = mock_hf_class.call_count + + await processor.on_embeddings("text2", model="model-a") # Same, no reload + call_count_after_a_repeat = mock_hf_class.call_count + + await processor.on_embeddings("text3", model="model-b") # Different, reload + call_count_after_b = mock_hf_class.call_count + + await processor.on_embeddings("text4", model="model-a") # Back to A, reload + call_count_after_a_again = mock_hf_class.call_count + + # Assert + assert call_count_after_a == initial_call_count + 1 # First load + assert call_count_after_a_repeat == initial_call_count + 1 # No reload + assert call_count_after_b == initial_call_count + 2 # Reload for model-b + assert call_count_after_a_again == initial_call_count + 3 # Reload back to model-a + + @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_hf_class): + """Test that None model parameter falls back to default""" + # Arrange + mock_hf_instance = Mock() + mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_hf_class.return_value = mock_hf_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + initial_count = mock_hf_class.call_count + + # Act + result = await processor.on_embeddings("test text", model=None) + + # Assert + # No reload, using cached default + assert mock_hf_class.call_count == initial_count + assert processor.cached_model_name == "test-model" + + @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_hf_class): + """Test initialization without model parameter uses module default""" + # Arrange + mock_hf_instance = Mock() + mock_hf_class.return_value = mock_hf_instance + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + # Act + processor = Processor(id="test-embeddings", concurrency=1, taskgroup=AsyncMock()) + + # Assert + # Should use default_model from module + expected_default = "all-MiniLM-L6-v2" + mock_hf_class.assert_called_once_with(model_name=expected_default) + assert processor.default_model == expected_default + assert processor.cached_model_name == expected_default + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/unit/test_embeddings/test_ollama_dynamic_model.py b/tests/unit/test_embeddings/test_ollama_dynamic_model.py new file mode 100644 index 00000000..ca0f44bf --- /dev/null +++ b/tests/unit/test_embeddings/test_ollama_dynamic_model.py @@ -0,0 +1,167 @@ +""" +Unit tests for Ollama dynamic model loading + +Tests the dynamic model selection functionality for Ollama embeddings service. +Since Ollama is server-side, no model caching is needed on the client side. 
+""" + +import pytest +from unittest.mock import Mock, patch, AsyncMock +from unittest import IsolatedAsyncioTestCase +from trustgraph.embeddings.ollama.processor import Processor + + +class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase): + """Test Ollama dynamic model selection""" + + @patch('trustgraph.embeddings.ollama.processor.Client') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_client_initialized_with_host(self, mock_embeddings_init, mock_async_init, mock_client_class): + """Test that Ollama client is initialized with correct host""" + # Arrange + mock_ollama_client = Mock() + mock_response = Mock() + mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_ollama_client.embed.return_value = mock_response + mock_client_class.return_value = mock_ollama_client + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + # Act + processor = Processor(id="test", concurrency=1, model="test-model", + ollama="http://localhost:11434", taskgroup=AsyncMock()) + + # Assert + mock_client_class.assert_called_once_with(host="http://localhost:11434") + assert processor.default_model == "test-model" + + @patch('trustgraph.embeddings.ollama.processor.Client') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_client_class): + """Test that on_embeddings uses default model when no model specified""" + # Arrange + mock_ollama_client = Mock() + mock_response = Mock() + mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_ollama_client.embed.return_value = mock_response + mock_client_class.return_value = mock_ollama_client + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + + # Act + result = await processor.on_embeddings("test text") + + # Assert + mock_ollama_client.embed.assert_called_once_with( + model="test-model", + input="test text" + ) + assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] + + @patch('trustgraph.embeddings.ollama.processor.Client') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_client_class): + """Test that on_embeddings uses specified model when provided""" + # Arrange + mock_ollama_client = Mock() + mock_response = Mock() + mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_ollama_client.embed.return_value = mock_response + mock_client_class.return_value = mock_ollama_client + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + + # Act + result = await processor.on_embeddings("test text", model="custom-model") + + # Assert + mock_ollama_client.embed.assert_called_once_with( + model="custom-model", + input="test text" + ) + assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] + + @patch('trustgraph.embeddings.ollama.processor.Client') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def 
test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_client_class): + """Test switching between multiple models""" + # Arrange + mock_ollama_client = Mock() + mock_response = Mock() + mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_ollama_client.embed.return_value = mock_response + mock_client_class.return_value = mock_ollama_client + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + + # Act - switch between different models + await processor.on_embeddings("text1", model="model-a") + await processor.on_embeddings("text2", model="model-b") + await processor.on_embeddings("text3", model="model-a") + await processor.on_embeddings("text4") # Use default + + # Assert + calls = mock_ollama_client.embed.call_args_list + assert len(calls) == 4 + assert calls[0][1]['model'] == "model-a" + assert calls[1][1]['model'] == "model-b" + assert calls[2][1]['model'] == "model-a" + assert calls[3][1]['model'] == "test-model" # Default + + @patch('trustgraph.embeddings.ollama.processor.Client') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class): + """Test that None model parameter falls back to default""" + # Arrange + mock_ollama_client = Mock() + mock_response = Mock() + mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]] + mock_ollama_client.embed.return_value = mock_response + mock_client_class.return_value = mock_ollama_client + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) + + # Act + result = await processor.on_embeddings("test text", model=None) + + # Assert + mock_ollama_client.embed.assert_called_once_with( + model="test-model", + input="test text" + ) + + @patch('trustgraph.embeddings.ollama.processor.Client') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') + async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class): + """Test initialization without model parameter uses module default""" + # Arrange + mock_ollama_client = Mock() + mock_client_class.return_value = mock_ollama_client + mock_async_init.return_value = None + mock_embeddings_init.return_value = None + + # Act + processor = Processor(id="test-embeddings", concurrency=1, taskgroup=AsyncMock()) + + # Assert + # Should use default_model from module + expected_default = "mxbai-embed-large" + assert processor.default_model == expected_default + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/unit/test_extract/test_ontology/README.md b/tests/unit/test_extract/test_ontology/README.md new file mode 100644 index 00000000..e3f0a164 --- /dev/null +++ b/tests/unit/test_extract/test_ontology/README.md @@ -0,0 +1,148 @@ +# Ontology Extractor Unit Tests + +Comprehensive unit tests for the OntoRAG ontology extraction system. + +## Test Coverage + +### 1. `test_ontology_selector.py` - Auto-Include Properties Feature + +Tests the critical dependency resolution that automatically includes all properties related to selected classes. 
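+
+As a mental model, the dependency resolution these tests pin down can be sketched as a single pass over the ontology's property table. The helper below is purely illustrative: `select_with_dependencies` and its dict shapes are assumptions made for the sketch, not the actual `OntologySelector` API.
+
+```python
+# Illustrative sketch only -- not the real OntologySelector implementation.
+def select_with_dependencies(selected_classes, properties):
+    """selected_classes: iterable of class names.
+    properties: {name: {"domain": <class>, "range": <class or XSD type>}}.
+    Returns the expanded class set and the auto-included properties."""
+    classes = set(selected_classes)
+    included = {}
+    for name, prop in properties.items():
+        # A property is pulled in when a selected class appears as its
+        # domain *or* its range...
+        if prop.get("domain") in classes or prop.get("range") in classes:
+            included[name] = prop
+            # ...and the classes on both ends are added too, so the
+            # selected subset stays connected. (XSD datatype ranges would
+            # be filtered out in practice.)
+            classes.add(prop.get("domain"))
+            classes.add(prop.get("range"))
+    classes.discard(None)
+    return classes, included
+```
+
+Selecting only `Recipe` against the sample Food Ontology would then pull in `ingredients`, `method`, `produces` and `serves`, plus the classes they point at, which is the behaviour asserted by `test_auto_include_properties_for_recipe_class` and `test_auto_include_adds_domain_and_range_classes` below.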
+ +**Key Tests:** +- `test_auto_include_properties_for_recipe_class` - Verifies Recipe class auto-includes `ingredients`, `method`, `produces`, `serves` +- `test_auto_include_properties_for_ingredient_class` - Verifies Ingredient class auto-includes `food` property +- `test_auto_include_properties_for_range_class` - Tests properties are included when class appears in range +- `test_auto_include_adds_domain_and_range_classes` - Ensures related classes are added too +- `test_multiple_classes_get_all_related_properties` - Tests combining multiple class selections +- `test_no_duplicate_properties_added` - Ensures properties aren't duplicated + +### 2. `test_uri_expansion.py` - URI Expansion + +Tests that URIs are properly expanded using ontology definitions instead of constructed fallback URIs. + +**Key Tests:** +- `test_expand_class_uri_from_ontology` - Class names expand to ontology URIs +- `test_expand_object_property_uri_from_ontology` - Object properties use ontology URIs +- `test_expand_datatype_property_uri_from_ontology` - Datatype properties use ontology URIs +- `test_expand_rdf_prefix` - Standard RDF prefixes expand correctly +- `test_expand_rdfs_prefix`, `test_expand_owl_prefix`, `test_expand_xsd_prefix` - Other standard prefixes +- `test_fallback_uri_for_instance` - Entity instances get constructed URIs +- `test_already_full_uri_unchanged` - Full URIs pass through +- `test_dict_access_not_object_attribute` - **Critical test** verifying dict access works (not object attributes) + +### 3. `test_ontology_triples.py` - Ontology Triple Generation + +Tests that ontology elements (classes and properties) are properly converted to RDF triples with labels, comments, domains, and ranges. + +**Key Tests:** +- `test_generates_class_type_triples` - Classes get `rdf:type owl:Class` triples +- `test_generates_class_labels` - Classes get `rdfs:label` triples +- `test_generates_class_comments` - Classes get `rdfs:comment` triples +- `test_generates_object_property_type_triples` - Object properties get proper type triples +- `test_generates_object_property_labels` - Object properties get labels +- `test_generates_object_property_domain` - Object properties get `rdfs:domain` triples +- `test_generates_object_property_range` - Object properties get `rdfs:range` triples +- `test_generates_datatype_property_type_triples` - Datatype properties get proper type triples +- `test_generates_datatype_property_range` - Datatype properties get XSD type ranges +- `test_uses_dict_field_names_not_rdf_names` - **Critical test** verifying dict field names work +- `test_total_triple_count_is_reasonable` - Validates expected number of triples + +### 4. `test_text_processing.py` - Text Processing and Segmentation + +Tests that text is properly split into sentences for ontology matching, including NLTK tokenization and TextSegment creation. + +**Key Tests:** +- `test_segment_single_sentence` - Single sentence produces one segment +- `test_segment_multiple_sentences` - Multiple sentences split correctly +- `test_segment_positions` - Segment start/end positions are correct +- `test_segment_complex_punctuation` - Handles abbreviations (Dr., U.S.A., Mr.) 
+- `test_segment_question_and_exclamation` - Different sentence terminators
+- `test_segment_preserves_original_text` - Segments can reconstruct original
+- `test_text_segment_non_overlapping` - Segments don't overlap
+- `test_nltk_punkt_availability` - NLTK tokenizer is available
+- `test_unicode_text` - Handles unicode characters
+- `test_quoted_text` - Handles quoted text correctly
+
+### 5. `test_prompt_and_extraction.py` - LLM Prompt Construction and Triple Extraction
+
+Tests that the system correctly constructs prompts with ontology constraints and extracts/validates triples from LLM responses.
+
+**Key Tests:**
+- `test_build_extraction_variables_includes_text` - Prompt includes input text
+- `test_build_extraction_variables_includes_classes` - Prompt includes ontology classes
+- `test_build_extraction_variables_includes_properties` - Prompt includes properties
+- `test_validates_rdf_type_triple_with_valid_class` - Validates rdf:type against ontology
+- `test_rejects_rdf_type_triple_with_invalid_class` - Rejects invalid classes
+- `test_validates_object_property_triple` - Validates object properties
+- `test_rejects_unknown_property` - Rejects properties not in ontology
+- `test_parse_simple_triple_dict` - Parses triple from dict format
+- `test_filters_invalid_triples` - Filters out invalid triples
+- `test_expands_uris_in_parsed_triples` - Expands URIs using ontology
+- `test_creates_proper_triple_objects` - Creates Triple objects with Value subjects/predicates/objects
+
+### 6. `test_embedding_and_similarity.py` - Ontology Embedding and Similarity Matching
+
+Tests that ontology elements are properly embedded and matched against input text using vector similarity.
+
+**Key Tests:**
+- `test_create_text_from_class_with_id` - Text representation includes class ID
+- `test_create_text_from_class_with_labels` - Includes labels in text
+- `test_create_text_from_class_with_comment` - Includes comments in text
+- `test_create_text_from_property_with_domain_range` - Includes domain/range in property text
+- `test_normalizes_id_with_underscores` - Normalizes IDs (underscores to spaces)
+- `test_includes_subclass_info_for_classes` - Includes subclass relationships
+- `test_vector_store_api_structure` - Vector store has expected API
+- `test_selector_handles_text_segments` - Selector processes text segments
+- `test_merge_subsets_combines_elements` - Merging combines ontology elements
+- `test_ontology_element_metadata_structure` - Metadata structure is correct
+
+## Running the Tests
+
+### Run all ontology extractor tests:
+```bash
+cd /path/to/trustgraph   # the repository root
+pytest tests/unit/test_extract/test_ontology/ -v
+```
+
+### Run a specific test file:
+```bash
+pytest tests/unit/test_extract/test_ontology/test_ontology_selector.py -v
+pytest tests/unit/test_extract/test_ontology/test_uri_expansion.py -v
+pytest tests/unit/test_extract/test_ontology/test_ontology_triples.py -v
+pytest tests/unit/test_extract/test_ontology/test_text_processing.py -v
+pytest tests/unit/test_extract/test_ontology/test_prompt_and_extraction.py -v
+pytest tests/unit/test_extract/test_ontology/test_embedding_and_similarity.py -v
+```
+
+### Run a specific test:
+```bash
+pytest tests/unit/test_extract/test_ontology/test_ontology_selector.py::TestOntologySelector::test_auto_include_properties_for_recipe_class -v
+```
+
+### Run with coverage:
+```bash
+pytest tests/unit/test_extract/test_ontology/ --cov=trustgraph.extract.kg.ontology --cov-report=html
+```
+
+## Test Fixtures
+
+- `sample_ontology` -
Complete Food Ontology with Recipe, Ingredient, Food, Method classes +- `ontology_loader_with_sample` - Mock OntologyLoader with the sample ontology +- `ontology_embedder` - Mock embedder for testing +- `mock_embedding_service` - Mock service for generating deterministic embeddings +- `vector_store` - InMemoryVectorStore for testing +- `extractor` - Processor instance for URI expansion tests +- `ontology_subset_with_uris` - OntologySubset with proper URIs defined +- `sample_ontology_subset` - OntologySubset for testing triple generation +- `text_processor` - TextProcessor instance for text segmentation tests +- `sample_ontology_class` - Sample OntologyClass for testing +- `sample_ontology_property` - Sample OntologyProperty for testing + +## Implementation Notes + +These tests verify the fixes made to address: +1. **Disconnected graph problem** - Auto-include properties feature ensures all relevant relationships are available +2. **Wrong URIs problem** - URI expansion using ontology definitions instead of constructed fallbacks +3. **Dict vs object attribute problem** - URI expansion works with dicts (from `cls.__dict__`) not object attributes +4. **Ontology visibility in KG** - Ontology elements themselves appear in the knowledge graph with proper metadata +5. **Text segmentation** - Proper sentence splitting for ontology matching using NLTK diff --git a/tests/unit/test_extract/test_ontology/__init__.py b/tests/unit/test_extract/test_ontology/__init__.py new file mode 100644 index 00000000..22e958af --- /dev/null +++ b/tests/unit/test_extract/test_ontology/__init__.py @@ -0,0 +1 @@ +"""Tests for ontology-based extraction.""" diff --git a/tests/unit/test_extract/test_ontology/test_embedding_and_similarity.py b/tests/unit/test_extract/test_ontology/test_embedding_and_similarity.py new file mode 100644 index 00000000..fe6d3c5f --- /dev/null +++ b/tests/unit/test_extract/test_ontology/test_embedding_and_similarity.py @@ -0,0 +1,423 @@ +""" +Unit tests for ontology embedding and similarity matching. + +Tests that ontology elements are properly embedded and matched against +input text using vector similarity. 
+""" + +import pytest +import numpy as np +from unittest.mock import AsyncMock, MagicMock +from trustgraph.extract.kg.ontology.ontology_embedder import ( + OntologyEmbedder, + OntologyElementMetadata +) +from trustgraph.extract.kg.ontology.ontology_loader import ( + Ontology, + OntologyClass, + OntologyProperty +) +from trustgraph.extract.kg.ontology.vector_store import InMemoryVectorStore, SearchResult +from trustgraph.extract.kg.ontology.text_processor import TextSegment +from trustgraph.extract.kg.ontology.ontology_selector import OntologySelector, OntologySubset + + +@pytest.fixture +def mock_embedding_service(): + """Create a mock embedding service.""" + service = AsyncMock() + # Return deterministic embeddings for testing + async def mock_embed(text): + # Simple hash-based embedding for deterministic tests + hash_val = hash(text) % 1000 + return np.array([hash_val / 1000.0, (1000 - hash_val) / 1000.0]) + service.embed = mock_embed + return service + + +@pytest.fixture +def vector_store(): + """Create an empty vector store.""" + return InMemoryVectorStore() + + +@pytest.fixture +def ontology_embedder(mock_embedding_service, vector_store): + """Create an ontology embedder with mock service.""" + return OntologyEmbedder( + embedding_service=mock_embedding_service, + vector_store=vector_store + ) + + +@pytest.fixture +def sample_ontology_class(): + """Create a sample ontology class.""" + return OntologyClass( + uri="http://purl.org/ontology/fo/Recipe", + type="owl:Class", + labels=[{"value": "Recipe", "lang": "en-gb"}], + comment="A Recipe is a combination of ingredients and a method.", + subclass_of=None + ) + + +@pytest.fixture +def sample_ontology_property(): + """Create a sample ontology property.""" + return OntologyProperty( + uri="http://purl.org/ontology/fo/ingredients", + type="owl:ObjectProperty", + labels=[{"value": "ingredients", "lang": "en-gb"}], + comment="The ingredients property relates a recipe to an ingredient list.", + domain="Recipe", + range="IngredientList" + ) + + +class TestTextRepresentation: + """Test suite for creating text representations of ontology elements.""" + + def test_create_text_from_class_with_id(self, ontology_embedder, sample_ontology_class): + """Test that class ID is included in text representation.""" + text = ontology_embedder._create_text_representation( + "Recipe", + sample_ontology_class, + "class" + ) + + assert "Recipe" in text, "Should include class ID" + + def test_create_text_from_class_with_labels(self, ontology_embedder, sample_ontology_class): + """Test that class labels are included in text representation.""" + text = ontology_embedder._create_text_representation( + "Recipe", + sample_ontology_class, + "class" + ) + + assert "Recipe" in text, "Should include label value" + + def test_create_text_from_class_with_comment(self, ontology_embedder, sample_ontology_class): + """Test that class comments are included in text representation.""" + text = ontology_embedder._create_text_representation( + "Recipe", + sample_ontology_class, + "class" + ) + + assert "combination of ingredients" in text, "Should include comment" + + def test_create_text_from_property_with_domain_range(self, ontology_embedder, sample_ontology_property): + """Test that property domain and range are included in text.""" + text = ontology_embedder._create_text_representation( + "ingredients", + sample_ontology_property, + "objectProperty" + ) + + assert "domain: Recipe" in text, "Should include domain" + assert "range: IngredientList" in text, "Should include range" + 
+ def test_normalizes_id_with_underscores(self, ontology_embedder): + """Test that IDs with underscores are normalized.""" + mock_element = MagicMock() + mock_element.labels = [] + mock_element.comment = None + + text = ontology_embedder._create_text_representation( + "some_property_name", + mock_element, + "objectProperty" + ) + + assert "some property name" in text, "Should replace underscores with spaces" + + def test_normalizes_id_with_hyphens(self, ontology_embedder): + """Test that IDs with hyphens are normalized.""" + mock_element = MagicMock() + mock_element.labels = [] + mock_element.comment = None + + text = ontology_embedder._create_text_representation( + "some-property-name", + mock_element, + "objectProperty" + ) + + assert "some property name" in text, "Should replace hyphens with spaces" + + def test_handles_element_without_labels(self, ontology_embedder): + """Test handling of elements without labels.""" + mock_element = MagicMock() + mock_element.labels = None + mock_element.comment = "Test comment" + + text = ontology_embedder._create_text_representation( + "TestElement", + mock_element, + "class" + ) + + assert "TestElement" in text, "Should still include ID" + assert "Test comment" in text, "Should include comment" + + def test_includes_subclass_info_for_classes(self, ontology_embedder): + """Test that subclass information is included for classes.""" + mock_class = MagicMock() + mock_class.labels = [] + mock_class.comment = None + mock_class.subclass_of = "ParentClass" + + text = ontology_embedder._create_text_representation( + "ChildClass", + mock_class, + "class" + ) + + assert "subclass of ParentClass" in text, "Should include subclass relationship" + + +class TestVectorStoreOperations: + """Test suite for vector store operations.""" + + def test_vector_store_starts_empty(self, vector_store): + """Test that vector store initializes empty.""" + assert vector_store.size() == 0, "New vector store should be empty" + + def test_vector_store_api_structure(self, vector_store): + """Test that vector store has expected API methods.""" + assert hasattr(vector_store, 'add'), "Should have add method" + assert hasattr(vector_store, 'add_batch'), "Should have add_batch method" + assert hasattr(vector_store, 'search'), "Should have search method" + assert hasattr(vector_store, 'size'), "Should have size method" + + def test_search_result_class_structure(self): + """Test that SearchResult has expected structure.""" + # Create a sample SearchResult + result = SearchResult(id="test-1", score=0.95, metadata={"element": "Test"}) + + assert hasattr(result, 'id'), "Should have id attribute" + assert hasattr(result, 'score'), "Should have score attribute" + assert hasattr(result, 'metadata'), "Should have metadata attribute" + assert result.id == "test-1" + assert result.score == 0.95 + assert result.metadata["element"] == "Test" + + +class TestOntologySelectorIntegration: + """Test suite for ontology selector with embeddings.""" + + @pytest.fixture + def sample_ontology(self): + """Create a sample ontology for testing.""" + return Ontology( + id="food", + classes={ + "Recipe": OntologyClass( + uri="http://purl.org/ontology/fo/Recipe", + type="owl:Class", + labels=[{"value": "Recipe", "lang": "en-gb"}], + comment="A Recipe is a combination of ingredients and a method." + ), + "Ingredient": OntologyClass( + uri="http://purl.org/ontology/fo/Ingredient", + type="owl:Class", + labels=[{"value": "Ingredient", "lang": "en-gb"}], + comment="An Ingredient combines a quantity and a food." 
+ ) + }, + object_properties={ + "ingredients": OntologyProperty( + uri="http://purl.org/ontology/fo/ingredients", + type="owl:ObjectProperty", + labels=[{"value": "ingredients", "lang": "en-gb"}], + comment="Relates a recipe to its ingredients.", + domain="Recipe", + range="IngredientList" + ) + }, + datatype_properties={}, + metadata={"name": "Food Ontology"} + ) + + @pytest.fixture + def ontology_loader_mock(self, sample_ontology): + """Create a mock ontology loader.""" + loader = MagicMock() + loader.get_ontology.return_value = sample_ontology + loader.get_all_ontology_ids.return_value = ["food"] + return loader + + async def test_selector_handles_text_segments( + self, ontology_embedder, ontology_loader_mock + ): + """Test that selector can process text segments.""" + # Create selector + selector = OntologySelector( + ontology_embedder=ontology_embedder, + ontology_loader=ontology_loader_mock, + top_k=5, + similarity_threshold=0.3 + ) + + # Create text segments + segments = [ + TextSegment(text="Recipe for cornish pasty", type="sentence", position=0), + TextSegment(text="ingredients needed", type="sentence", position=1) + ] + + # Select ontology subset (will be empty since we haven't embedded anything) + subsets = await selector.select_ontology_subset(segments) + + # Should return a list (even if empty) + assert isinstance(subsets, list), "Should return a list of subsets" + + async def test_selector_with_no_embedding_service(self, vector_store, ontology_loader_mock): + """Test that selector handles missing embedding service gracefully.""" + embedder = OntologyEmbedder(embedding_service=None, vector_store=vector_store) + + selector = OntologySelector( + ontology_embedder=embedder, + ontology_loader=ontology_loader_mock, + top_k=5, + similarity_threshold=0.7 + ) + + segments = [ + TextSegment(text="Test text", type="sentence", position=0) + ] + + # Should return empty results without crashing + subsets = await selector.select_ontology_subset(segments) + assert isinstance(subsets, list), "Should return a list even without embeddings" + + def test_merge_subsets_combines_elements(self, ontology_loader_mock, ontology_embedder): + """Test that merging subsets combines all elements.""" + selector = OntologySelector( + ontology_embedder=ontology_embedder, + ontology_loader=ontology_loader_mock, + top_k=5, + similarity_threshold=0.7 + ) + + # Create two subsets from same ontology + subset1 = OntologySubset( + ontology_id="food", + classes={"Recipe": {"uri": "http://example.com/Recipe"}}, + object_properties={}, + datatype_properties={}, + metadata={}, + relevance_score=0.8 + ) + + subset2 = OntologySubset( + ontology_id="food", + classes={"Ingredient": {"uri": "http://example.com/Ingredient"}}, + object_properties={"ingredients": {"uri": "http://example.com/ingredients"}}, + datatype_properties={}, + metadata={}, + relevance_score=0.9 + ) + + merged = selector.merge_subsets([subset1, subset2]) + + assert len(merged.classes) == 2, "Should combine classes" + # Keys may be prefixed with ontology id + assert any("Recipe" in key for key in merged.classes.keys()) + assert any("Ingredient" in key for key in merged.classes.keys()) + assert len(merged.object_properties) == 1, "Should include properties" + + +class TestEmbeddingEdgeCases: + """Test suite for edge cases in embedding.""" + + async def test_embed_element_with_no_labels(self, ontology_embedder): + """Test embedding element without labels.""" + mock_element = MagicMock() + mock_element.labels = None + mock_element.comment = "Test element" + + 
text = ontology_embedder._create_text_representation( + "TestElement", + mock_element, + "class" + ) + + # Should not crash and should include ID and comment + assert "TestElement" in text + assert "Test element" in text + + async def test_embed_element_with_empty_comment(self, ontology_embedder): + """Test embedding element with empty comment.""" + mock_element = MagicMock() + mock_element.labels = [{"value": "Label"}] + mock_element.comment = None + + text = ontology_embedder._create_text_representation( + "TestElement", + mock_element, + "class" + ) + + # Should not crash + assert "Label" in text + + def test_ontology_element_metadata_structure(self): + """Test OntologyElementMetadata structure.""" + metadata = OntologyElementMetadata( + type="class", + ontology="food", + element="Recipe", + definition={"uri": "http://example.com/Recipe"}, + text="Recipe A combination of ingredients" + ) + + assert metadata.type == "class" + assert metadata.ontology == "food" + assert metadata.element == "Recipe" + assert "uri" in metadata.definition + + def test_vector_store_search_on_empty_store(self): + """Test searching empty vector store.""" + # Need a non-empty store for faiss to work + # This test verifies the store can be created but searching requires dimension + store = InMemoryVectorStore() + assert store.size() == 0, "Empty store should have size 0" + + +class TestOntologySubsetStructure: + """Test suite for OntologySubset structure.""" + + def test_ontology_subset_creation(self): + """Test creating an OntologySubset.""" + subset = OntologySubset( + ontology_id="test", + classes={"Recipe": {}}, + object_properties={"produces": {}}, + datatype_properties={"serves": {}}, + metadata={"name": "Test"}, + relevance_score=0.85 + ) + + assert subset.ontology_id == "test" + assert len(subset.classes) == 1 + assert len(subset.object_properties) == 1 + assert len(subset.datatype_properties) == 1 + assert subset.relevance_score == 0.85 + + def test_ontology_subset_default_score(self): + """Test that OntologySubset has default score.""" + subset = OntologySubset( + ontology_id="test", + classes={}, + object_properties={}, + datatype_properties={}, + metadata={} + ) + + assert subset.relevance_score == 0.0, "Should have default score of 0.0" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit/test_extract/test_ontology/test_entity_contexts.py b/tests/unit/test_extract/test_ontology/test_entity_contexts.py new file mode 100644 index 00000000..c867b05a --- /dev/null +++ b/tests/unit/test_extract/test_ontology/test_entity_contexts.py @@ -0,0 +1,353 @@ +""" +Unit tests for entity context building. + +Tests that entity contexts are properly created from extracted triples, +collecting labels and definitions for entity embedding and retrieval. 
+""" + +import pytest +from trustgraph.extract.kg.ontology.extract import Processor +from trustgraph.schema.core.primitives import Triple, Value +from trustgraph.schema.knowledge.graph import EntityContext + + +@pytest.fixture +def processor(): + """Create a Processor instance for testing.""" + processor = object.__new__(Processor) + return processor + + +class TestEntityContextBuilding: + """Test suite for entity context building from triples.""" + + def test_builds_context_from_label(self, processor): + """Test that entity context is built from rdfs:label.""" + triples = [ + Triple( + s=Value(value="https://example.com/entity/cornish-pasty", is_uri=True), + p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), + o=Value(value="Cornish Pasty", is_uri=False) + ) + ] + + contexts = processor.build_entity_contexts(triples) + + assert len(contexts) == 1, "Should create one entity context" + assert isinstance(contexts[0], EntityContext) + assert contexts[0].entity.value == "https://example.com/entity/cornish-pasty" + assert "Label: Cornish Pasty" in contexts[0].context + + def test_builds_context_from_definition(self, processor): + """Test that entity context includes definitions.""" + triples = [ + Triple( + s=Value(value="https://example.com/entity/pasty", is_uri=True), + p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True), + o=Value(value="A baked pastry filled with savory ingredients", is_uri=False) + ) + ] + + contexts = processor.build_entity_contexts(triples) + + assert len(contexts) == 1 + assert "A baked pastry filled with savory ingredients" in contexts[0].context + + def test_combines_label_and_definition(self, processor): + """Test that label and definition are combined in context.""" + triples = [ + Triple( + s=Value(value="https://example.com/entity/recipe1", is_uri=True), + p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), + o=Value(value="Pasty Recipe", is_uri=False) + ), + Triple( + s=Value(value="https://example.com/entity/recipe1", is_uri=True), + p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True), + o=Value(value="Traditional Cornish pastry recipe", is_uri=False) + ) + ] + + contexts = processor.build_entity_contexts(triples) + + assert len(contexts) == 1 + context_text = contexts[0].context + assert "Label: Pasty Recipe" in context_text + assert "Traditional Cornish pastry recipe" in context_text + assert ". 
" in context_text, "Should join parts with period and space" + + def test_uses_first_label_only(self, processor): + """Test that only the first label is used in context.""" + triples = [ + Triple( + s=Value(value="https://example.com/entity/food1", is_uri=True), + p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), + o=Value(value="First Label", is_uri=False) + ), + Triple( + s=Value(value="https://example.com/entity/food1", is_uri=True), + p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), + o=Value(value="Second Label", is_uri=False) + ) + ] + + contexts = processor.build_entity_contexts(triples) + + assert len(contexts) == 1 + assert "Label: First Label" in contexts[0].context + assert "Second Label" not in contexts[0].context + + def test_includes_all_definitions(self, processor): + """Test that all definitions are included in context.""" + triples = [ + Triple( + s=Value(value="https://example.com/entity/food1", is_uri=True), + p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True), + o=Value(value="First definition", is_uri=False) + ), + Triple( + s=Value(value="https://example.com/entity/food1", is_uri=True), + p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True), + o=Value(value="Second definition", is_uri=False) + ) + ] + + contexts = processor.build_entity_contexts(triples) + + assert len(contexts) == 1 + context_text = contexts[0].context + assert "First definition" in context_text + assert "Second definition" in context_text + + def test_supports_schema_org_description(self, processor): + """Test that schema.org description is treated as definition.""" + triples = [ + Triple( + s=Value(value="https://example.com/entity/food1", is_uri=True), + p=Value(value="https://schema.org/description", is_uri=True), + o=Value(value="A delicious food item", is_uri=False) + ) + ] + + contexts = processor.build_entity_contexts(triples) + + assert len(contexts) == 1 + assert "A delicious food item" in contexts[0].context + + def test_handles_multiple_entities(self, processor): + """Test that contexts are created for multiple entities.""" + triples = [ + Triple( + s=Value(value="https://example.com/entity/entity1", is_uri=True), + p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), + o=Value(value="Entity One", is_uri=False) + ), + Triple( + s=Value(value="https://example.com/entity/entity2", is_uri=True), + p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), + o=Value(value="Entity Two", is_uri=False) + ), + Triple( + s=Value(value="https://example.com/entity/entity3", is_uri=True), + p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), + o=Value(value="Entity Three", is_uri=False) + ) + ] + + contexts = processor.build_entity_contexts(triples) + + assert len(contexts) == 3, "Should create context for each entity" + entity_uris = [ctx.entity.value for ctx in contexts] + assert "https://example.com/entity/entity1" in entity_uris + assert "https://example.com/entity/entity2" in entity_uris + assert "https://example.com/entity/entity3" in entity_uris + + def test_ignores_uri_literals(self, processor): + """Test that URI objects are ignored (only literal labels/definitions).""" + triples = [ + Triple( + s=Value(value="https://example.com/entity/food1", is_uri=True), + p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), + o=Value(value="https://example.com/some/uri", is_uri=True) # URI, not literal + ) + ] + + contexts = 
processor.build_entity_contexts(triples) + + # Should not create context since label is URI + assert len(contexts) == 0, "Should not create context for URI labels" + + def test_ignores_non_label_non_definition_triples(self, processor): + """Test that other predicates are ignored.""" + triples = [ + Triple( + s=Value(value="https://example.com/entity/food1", is_uri=True), + p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True), + o=Value(value="http://example.com/Food", is_uri=True) + ), + Triple( + s=Value(value="https://example.com/entity/food1", is_uri=True), + p=Value(value="http://example.com/produces", is_uri=True), + o=Value(value="https://example.com/entity/food2", is_uri=True) + ) + ] + + contexts = processor.build_entity_contexts(triples) + + # Should not create context since no labels or definitions + assert len(contexts) == 0, "Should not create context without labels/definitions" + + def test_handles_empty_triple_list(self, processor): + """Test handling of empty triple list.""" + triples = [] + + contexts = processor.build_entity_contexts(triples) + + assert len(contexts) == 0, "Empty triple list should return empty contexts" + + def test_entity_context_has_value_object(self, processor): + """Test that EntityContext.entity is a Value object.""" + triples = [ + Triple( + s=Value(value="https://example.com/entity/test", is_uri=True), + p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), + o=Value(value="Test Entity", is_uri=False) + ) + ] + + contexts = processor.build_entity_contexts(triples) + + assert len(contexts) == 1 + assert isinstance(contexts[0].entity, Value), "Entity should be Value object" + assert contexts[0].entity.is_uri, "Entity should be marked as URI" + + def test_entity_context_text_is_string(self, processor): + """Test that EntityContext.context is a string.""" + triples = [ + Triple( + s=Value(value="https://example.com/entity/test", is_uri=True), + p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), + o=Value(value="Test Entity", is_uri=False) + ) + ] + + contexts = processor.build_entity_contexts(triples) + + assert len(contexts) == 1 + assert isinstance(contexts[0].context, str), "Context should be string" + + def test_only_creates_contexts_with_meaningful_info(self, processor): + """Test that contexts are only created when there's meaningful information.""" + triples = [ + # Entity with label - should create context + Triple( + s=Value(value="https://example.com/entity/entity1", is_uri=True), + p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), + o=Value(value="Entity One", is_uri=False) + ), + # Entity with only rdf:type - should NOT create context + Triple( + s=Value(value="https://example.com/entity/entity2", is_uri=True), + p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True), + o=Value(value="http://example.com/Food", is_uri=True) + ) + ] + + contexts = processor.build_entity_contexts(triples) + + assert len(contexts) == 1, "Should only create context for entity with label/definition" + assert contexts[0].entity.value == "https://example.com/entity/entity1" + + +class TestEntityContextEdgeCases: + """Test suite for edge cases in entity context building.""" + + def test_handles_unicode_in_labels(self, processor): + """Test handling of unicode characters in labels.""" + triples = [ + Triple( + s=Value(value="https://example.com/entity/café", is_uri=True), + p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), + 
o=Value(value="Café Spécial", is_uri=False) + ) + ] + + contexts = processor.build_entity_contexts(triples) + + assert len(contexts) == 1 + assert "Café Spécial" in contexts[0].context + + def test_handles_long_definitions(self, processor): + """Test handling of very long definitions.""" + long_def = "This is a very long definition " * 50 + triples = [ + Triple( + s=Value(value="https://example.com/entity/test", is_uri=True), + p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True), + o=Value(value=long_def, is_uri=False) + ) + ] + + contexts = processor.build_entity_contexts(triples) + + assert len(contexts) == 1 + assert long_def in contexts[0].context + + def test_handles_special_characters_in_context(self, processor): + """Test handling of special characters in context text.""" + triples = [ + Triple( + s=Value(value="https://example.com/entity/test", is_uri=True), + p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), + o=Value(value="Test & Entity \"quotes\"", is_uri=False) + ) + ] + + contexts = processor.build_entity_contexts(triples) + + assert len(contexts) == 1 + assert "Test & Entity \"quotes\"" in contexts[0].context + + def test_mixed_relevant_and_irrelevant_triples(self, processor): + """Test extracting contexts from mixed triple types.""" + triples = [ + # Label - relevant + Triple( + s=Value(value="https://example.com/entity/recipe1", is_uri=True), + p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), + o=Value(value="Cornish Pasty Recipe", is_uri=False) + ), + # Type - irrelevant + Triple( + s=Value(value="https://example.com/entity/recipe1", is_uri=True), + p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True), + o=Value(value="http://example.com/Recipe", is_uri=True) + ), + # Property - irrelevant + Triple( + s=Value(value="https://example.com/entity/recipe1", is_uri=True), + p=Value(value="http://example.com/produces", is_uri=True), + o=Value(value="https://example.com/entity/pasty", is_uri=True) + ), + # Definition - relevant + Triple( + s=Value(value="https://example.com/entity/recipe1", is_uri=True), + p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True), + o=Value(value="Traditional British pastry recipe", is_uri=False) + ) + ] + + contexts = processor.build_entity_contexts(triples) + + assert len(contexts) == 1 + context_text = contexts[0].context + # Should include label and definition + assert "Label: Cornish Pasty Recipe" in context_text + assert "Traditional British pastry recipe" in context_text + # Should not include type or property info + assert "Recipe" not in context_text or "Cornish Pasty Recipe" in context_text # Only in label + assert "produces" not in context_text + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit/test_extract/test_ontology/test_ontology_loading.py b/tests/unit/test_extract/test_ontology/test_ontology_loading.py new file mode 100644 index 00000000..27e34e1f --- /dev/null +++ b/tests/unit/test_extract/test_ontology/test_ontology_loading.py @@ -0,0 +1,518 @@ +""" +Unit tests for ontology loading and configuration. + +Tests that ontologies are properly loaded from configuration, +parsed, validated, and managed by the OntologyLoader. 
+""" + +import pytest +from trustgraph.extract.kg.ontology.ontology_loader import ( + OntologyLoader, + Ontology, + OntologyClass, + OntologyProperty +) + + +@pytest.fixture +def ontology_loader(): + """Create an OntologyLoader instance.""" + return OntologyLoader() + + +@pytest.fixture +def sample_ontology_config(): + """Create a sample ontology configuration.""" + return { + "food": { + "metadata": { + "name": "Food Ontology", + "namespace": "http://purl.org/ontology/fo/" + }, + "classes": { + "Recipe": { + "uri": "http://purl.org/ontology/fo/Recipe", + "type": "owl:Class", + "rdfs:label": [{"value": "Recipe", "lang": "en-gb"}], + "rdfs:comment": "A Recipe is a combination of ingredients and a method." + }, + "Ingredient": { + "uri": "http://purl.org/ontology/fo/Ingredient", + "type": "owl:Class", + "rdfs:label": [{"value": "Ingredient", "lang": "en-gb"}], + "rdfs:comment": "An Ingredient combines a quantity and a food." + }, + "Food": { + "uri": "http://purl.org/ontology/fo/Food", + "type": "owl:Class", + "rdfs:label": [{"value": "Food", "lang": "en-gb"}], + "rdfs:comment": "A Food is something that can be eaten.", + "rdfs:subClassOf": "EdibleThing" + } + }, + "objectProperties": { + "ingredients": { + "uri": "http://purl.org/ontology/fo/ingredients", + "type": "owl:ObjectProperty", + "rdfs:label": [{"value": "ingredients", "lang": "en-gb"}], + "rdfs:domain": "Recipe", + "rdfs:range": "IngredientList" + }, + "produces": { + "uri": "http://purl.org/ontology/fo/produces", + "type": "owl:ObjectProperty", + "rdfs:label": [{"value": "produces", "lang": "en-gb"}], + "rdfs:domain": "Recipe", + "rdfs:range": "Food" + } + }, + "datatypeProperties": { + "serves": { + "uri": "http://purl.org/ontology/fo/serves", + "type": "owl:DatatypeProperty", + "rdfs:label": [{"value": "serves", "lang": "en-gb"}], + "rdfs:domain": "Recipe", + "rdfs:range": "xsd:string" + } + } + } + } + + +class TestOntologyLoaderInitialization: + """Test suite for OntologyLoader initialization.""" + + def test_loader_starts_empty(self, ontology_loader): + """Test that loader initializes with no ontologies.""" + assert len(ontology_loader.get_all_ontologies()) == 0 + + def test_loader_get_nonexistent_ontology(self, ontology_loader): + """Test getting non-existent ontology returns None.""" + result = ontology_loader.get_ontology("nonexistent") + assert result is None + + +class TestOntologyLoading: + """Test suite for loading ontologies from configuration.""" + + def test_loads_single_ontology(self, ontology_loader, sample_ontology_config): + """Test loading a single ontology.""" + ontology_loader.update_ontologies(sample_ontology_config) + + ontologies = ontology_loader.get_all_ontologies() + assert len(ontologies) == 1 + assert "food" in ontologies + + def test_loaded_ontology_has_correct_id(self, ontology_loader, sample_ontology_config): + """Test that loaded ontology has correct ID.""" + ontology_loader.update_ontologies(sample_ontology_config) + + ontology = ontology_loader.get_ontology("food") + assert ontology is not None + assert ontology.id == "food" + + def test_loaded_ontology_has_metadata(self, ontology_loader, sample_ontology_config): + """Test that loaded ontology includes metadata.""" + ontology_loader.update_ontologies(sample_ontology_config) + + ontology = ontology_loader.get_ontology("food") + assert ontology.metadata["name"] == "Food Ontology" + assert ontology.metadata["namespace"] == "http://purl.org/ontology/fo/" + + def test_loaded_ontology_has_classes(self, ontology_loader, sample_ontology_config): + 
"""Test that loaded ontology includes classes.""" + ontology_loader.update_ontologies(sample_ontology_config) + + ontology = ontology_loader.get_ontology("food") + assert len(ontology.classes) == 3 + assert "Recipe" in ontology.classes + assert "Ingredient" in ontology.classes + assert "Food" in ontology.classes + + def test_loaded_classes_have_correct_properties(self, ontology_loader, sample_ontology_config): + """Test that loaded classes have correct properties.""" + ontology_loader.update_ontologies(sample_ontology_config) + + ontology = ontology_loader.get_ontology("food") + recipe = ontology.get_class("Recipe") + + assert isinstance(recipe, OntologyClass) + assert recipe.uri == "http://purl.org/ontology/fo/Recipe" + assert recipe.type == "owl:Class" + assert len(recipe.labels) == 1 + assert recipe.labels[0]["value"] == "Recipe" + assert "combination of ingredients" in recipe.comment + + def test_loaded_ontology_has_object_properties(self, ontology_loader, sample_ontology_config): + """Test that loaded ontology includes object properties.""" + ontology_loader.update_ontologies(sample_ontology_config) + + ontology = ontology_loader.get_ontology("food") + assert len(ontology.object_properties) == 2 + assert "ingredients" in ontology.object_properties + assert "produces" in ontology.object_properties + + def test_loaded_properties_have_domain_and_range(self, ontology_loader, sample_ontology_config): + """Test that loaded properties have domain and range.""" + ontology_loader.update_ontologies(sample_ontology_config) + + ontology = ontology_loader.get_ontology("food") + produces = ontology.get_property("produces") + + assert isinstance(produces, OntologyProperty) + assert produces.domain == "Recipe" + assert produces.range == "Food" + + def test_loaded_ontology_has_datatype_properties(self, ontology_loader, sample_ontology_config): + """Test that loaded ontology includes datatype properties.""" + ontology_loader.update_ontologies(sample_ontology_config) + + ontology = ontology_loader.get_ontology("food") + assert len(ontology.datatype_properties) == 1 + assert "serves" in ontology.datatype_properties + + def test_loads_multiple_ontologies(self, ontology_loader): + """Test loading multiple ontologies.""" + config = { + "food": { + "metadata": {"name": "Food Ontology"}, + "classes": {"Recipe": {"uri": "http://example.com/Recipe"}}, + "objectProperties": {}, + "datatypeProperties": {} + }, + "music": { + "metadata": {"name": "Music Ontology"}, + "classes": {"Song": {"uri": "http://example.com/Song"}}, + "objectProperties": {}, + "datatypeProperties": {} + } + } + + ontology_loader.update_ontologies(config) + + ontologies = ontology_loader.get_all_ontologies() + assert len(ontologies) == 2 + assert "food" in ontologies + assert "music" in ontologies + + def test_update_replaces_existing_ontologies(self, ontology_loader, sample_ontology_config): + """Test that update replaces existing ontologies.""" + # Load initial ontologies + ontology_loader.update_ontologies(sample_ontology_config) + assert len(ontology_loader.get_all_ontologies()) == 1 + + # Update with different config + new_config = { + "music": { + "metadata": {"name": "Music Ontology"}, + "classes": {}, + "objectProperties": {}, + "datatypeProperties": {} + } + } + ontology_loader.update_ontologies(new_config) + + # Old ontologies should be replaced + ontologies = ontology_loader.get_all_ontologies() + assert len(ontologies) == 1 + assert "music" in ontologies + assert "food" not in ontologies + + +class TestOntologyRetrieval: + 
"""Test suite for retrieving ontologies.""" + + def test_get_ontology_by_id(self, ontology_loader, sample_ontology_config): + """Test retrieving ontology by ID.""" + ontology_loader.update_ontologies(sample_ontology_config) + + ontology = ontology_loader.get_ontology("food") + assert ontology is not None + assert isinstance(ontology, Ontology) + + def test_get_all_ontologies(self, ontology_loader): + """Test retrieving all ontologies.""" + config = { + "food": { + "metadata": {}, + "classes": {}, + "objectProperties": {}, + "datatypeProperties": {} + }, + "music": { + "metadata": {}, + "classes": {}, + "objectProperties": {}, + "datatypeProperties": {} + } + } + ontology_loader.update_ontologies(config) + + ontologies = ontology_loader.get_all_ontologies() + assert isinstance(ontologies, dict) + assert len(ontologies) == 2 + + def test_get_all_ontology_ids(self, ontology_loader): + """Test retrieving all ontology IDs.""" + config = { + "food": { + "metadata": {}, + "classes": {}, + "objectProperties": {}, + "datatypeProperties": {} + }, + "music": { + "metadata": {}, + "classes": {}, + "objectProperties": {}, + "datatypeProperties": {} + } + } + ontology_loader.update_ontologies(config) + + ontologies = ontology_loader.get_all_ontologies() + ids = list(ontologies.keys()) + assert len(ids) == 2 + assert "food" in ids + assert "music" in ids + + +class TestOntologyClassMethods: + """Test suite for Ontology helper methods.""" + + def test_get_class(self, ontology_loader, sample_ontology_config): + """Test getting a class from ontology.""" + ontology_loader.update_ontologies(sample_ontology_config) + ontology = ontology_loader.get_ontology("food") + + recipe = ontology.get_class("Recipe") + assert recipe is not None + assert recipe.uri == "http://purl.org/ontology/fo/Recipe" + + def test_get_nonexistent_class(self, ontology_loader, sample_ontology_config): + """Test getting non-existent class returns None.""" + ontology_loader.update_ontologies(sample_ontology_config) + ontology = ontology_loader.get_ontology("food") + + result = ontology.get_class("NonExistent") + assert result is None + + def test_get_property(self, ontology_loader, sample_ontology_config): + """Test getting a property from ontology.""" + ontology_loader.update_ontologies(sample_ontology_config) + ontology = ontology_loader.get_ontology("food") + + produces = ontology.get_property("produces") + assert produces is not None + assert produces.domain == "Recipe" + + def test_get_property_checks_both_types(self, ontology_loader, sample_ontology_config): + """Test that get_property checks both object and datatype properties.""" + ontology_loader.update_ontologies(sample_ontology_config) + ontology = ontology_loader.get_ontology("food") + + # Object property + produces = ontology.get_property("produces") + assert produces is not None + + # Datatype property + serves = ontology.get_property("serves") + assert serves is not None + + def test_get_parent_classes(self, ontology_loader, sample_ontology_config): + """Test getting parent classes following subClassOf.""" + ontology_loader.update_ontologies(sample_ontology_config) + ontology = ontology_loader.get_ontology("food") + + parents = ontology.get_parent_classes("Food") + assert "EdibleThing" in parents + + def test_get_parent_classes_empty_for_root(self, ontology_loader, sample_ontology_config): + """Test that root classes have no parents.""" + ontology_loader.update_ontologies(sample_ontology_config) + ontology = ontology_loader.get_ontology("food") + + parents = 
ontology.get_parent_classes("Recipe") + assert len(parents) == 0 + + +class TestOntologyValidation: + """Test suite for ontology validation.""" + + def test_validates_property_domain_exists(self, ontology_loader): + """Test validation of property domain.""" + config = { + "test": { + "metadata": {}, + "classes": { + "Recipe": {"uri": "http://example.com/Recipe"} + }, + "objectProperties": { + "produces": { + "uri": "http://example.com/produces", + "type": "owl:ObjectProperty", + "rdfs:domain": "NonExistentClass", # Invalid + "rdfs:range": "Food" + } + }, + "datatypeProperties": {} + } + } + + ontology_loader.update_ontologies(config) + ontology = ontology_loader.get_ontology("test") + + issues = ontology.validate_structure() + assert len(issues) > 0 + assert any("unknown domain" in issue.lower() for issue in issues) + + def test_validates_object_property_range_exists(self, ontology_loader): + """Test validation of object property range.""" + config = { + "test": { + "metadata": {}, + "classes": { + "Recipe": {"uri": "http://example.com/Recipe"} + }, + "objectProperties": { + "produces": { + "uri": "http://example.com/produces", + "type": "owl:ObjectProperty", + "rdfs:domain": "Recipe", + "rdfs:range": "NonExistentClass" # Invalid + } + }, + "datatypeProperties": {} + } + } + + ontology_loader.update_ontologies(config) + ontology = ontology_loader.get_ontology("test") + + issues = ontology.validate_structure() + assert len(issues) > 0 + assert any("unknown range" in issue.lower() for issue in issues) + + def test_detects_circular_inheritance(self, ontology_loader): + """Test detection of circular inheritance.""" + config = { + "test": { + "metadata": {}, + "classes": { + "A": { + "uri": "http://example.com/A", + "rdfs:subClassOf": "B" + }, + "B": { + "uri": "http://example.com/B", + "rdfs:subClassOf": "C" + }, + "C": { + "uri": "http://example.com/C", + "rdfs:subClassOf": "A" # Circular! 
+ } + }, + "objectProperties": {}, + "datatypeProperties": {} + } + } + + ontology_loader.update_ontologies(config) + ontology = ontology_loader.get_ontology("test") + + issues = ontology.validate_structure() + assert len(issues) > 0 + assert any("circular" in issue.lower() for issue in issues) + + def test_valid_ontology_has_no_issues(self, ontology_loader, sample_ontology_config): + """Test that valid ontology passes validation.""" + # Modify config to have valid references + config = sample_ontology_config.copy() + config["food"]["classes"]["EdibleThing"] = { + "uri": "http://purl.org/ontology/fo/EdibleThing" + } + config["food"]["classes"]["IngredientList"] = { + "uri": "http://purl.org/ontology/fo/IngredientList" + } + + ontology_loader.update_ontologies(config) + ontology = ontology_loader.get_ontology("food") + + issues = ontology.validate_structure() + # Should have minimal or no issues for valid ontology + assert isinstance(issues, list) + + +class TestEdgeCases: + """Test suite for edge cases in ontology loading.""" + + def test_handles_empty_config(self, ontology_loader): + """Test handling of empty configuration.""" + ontology_loader.update_ontologies({}) + + ontologies = ontology_loader.get_all_ontologies() + assert len(ontologies) == 0 + + def test_handles_ontology_without_classes(self, ontology_loader): + """Test handling of ontology with no classes.""" + config = { + "minimal": { + "metadata": {"name": "Minimal"}, + "classes": {}, + "objectProperties": {}, + "datatypeProperties": {} + } + } + + ontology_loader.update_ontologies(config) + ontology = ontology_loader.get_ontology("minimal") + + assert ontology is not None + assert len(ontology.classes) == 0 + + def test_handles_ontology_without_properties(self, ontology_loader): + """Test handling of ontology with no properties.""" + config = { + "test": { + "metadata": {}, + "classes": { + "Recipe": {"uri": "http://example.com/Recipe"} + }, + "objectProperties": {}, + "datatypeProperties": {} + } + } + + ontology_loader.update_ontologies(config) + ontology = ontology_loader.get_ontology("test") + + assert ontology is not None + assert len(ontology.object_properties) == 0 + assert len(ontology.datatype_properties) == 0 + + def test_handles_missing_optional_fields(self, ontology_loader): + """Test handling of missing optional fields.""" + config = { + "test": { + "metadata": {}, + "classes": { + "Simple": { + "uri": "http://example.com/Simple" + # No labels, comments, subclass, etc. + } + }, + "objectProperties": {}, + "datatypeProperties": {} + } + } + + ontology_loader.update_ontologies(config) + ontology = ontology_loader.get_ontology("test") + + simple = ontology.get_class("Simple") + assert simple is not None + assert simple.uri == "http://example.com/Simple" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit/test_extract/test_ontology/test_ontology_selector.py b/tests/unit/test_extract/test_ontology/test_ontology_selector.py new file mode 100644 index 00000000..37526d74 --- /dev/null +++ b/tests/unit/test_extract/test_ontology/test_ontology_selector.py @@ -0,0 +1,336 @@ +""" +Unit tests for OntologySelector component. + +Tests the critical auto-include properties feature that automatically pulls in +all properties related to selected classes. 
+""" + +import pytest +from unittest.mock import Mock, AsyncMock +from trustgraph.extract.kg.ontology.ontology_selector import ( + OntologySelector, + OntologySubset +) +from trustgraph.extract.kg.ontology.ontology_loader import ( + Ontology, + OntologyClass, + OntologyProperty +) +from trustgraph.extract.kg.ontology.text_processor import TextSegment + + +@pytest.fixture +def sample_ontology(): + """Create a sample food ontology for testing.""" + # Create classes + recipe_class = OntologyClass( + uri="http://purl.org/ontology/fo/Recipe", + type="owl:Class", + labels=[{"value": "Recipe", "lang": "en-gb"}], + comment="A Recipe is a combination of ingredients and a method." + ) + + ingredient_class = OntologyClass( + uri="http://purl.org/ontology/fo/Ingredient", + type="owl:Class", + labels=[{"value": "Ingredient", "lang": "en-gb"}], + comment="An Ingredient is a combination of a quantity and a food." + ) + + food_class = OntologyClass( + uri="http://purl.org/ontology/fo/Food", + type="owl:Class", + labels=[{"value": "Food", "lang": "en-gb"}], + comment="A Food is something that can be eaten." + ) + + method_class = OntologyClass( + uri="http://purl.org/ontology/fo/Method", + type="owl:Class", + labels=[{"value": "Method", "lang": "en-gb"}], + comment="A Method is the way in which ingredients are combined." + ) + + # Create object properties + ingredients_prop = OntologyProperty( + uri="http://purl.org/ontology/fo/ingredients", + type="owl:ObjectProperty", + labels=[{"value": "ingredients", "lang": "en-gb"}], + comment="The ingredients property relates a recipe to an ingredient list.", + domain="Recipe", + range="IngredientList" + ) + + food_prop = OntologyProperty( + uri="http://purl.org/ontology/fo/food", + type="owl:ObjectProperty", + labels=[{"value": "food", "lang": "en-gb"}], + comment="The food property relates an ingredient to the food that is required.", + domain="Ingredient", + range="Food" + ) + + method_prop = OntologyProperty( + uri="http://purl.org/ontology/fo/method", + type="owl:ObjectProperty", + labels=[{"value": "method", "lang": "en-gb"}], + comment="The method property relates a recipe to the method used.", + domain="Recipe", + range="Method" + ) + + produces_prop = OntologyProperty( + uri="http://purl.org/ontology/fo/produces", + type="owl:ObjectProperty", + labels=[{"value": "produces", "lang": "en-gb"}], + comment="The produces property relates a recipe to the food it produces.", + domain="Recipe", + range="Food" + ) + + # Create datatype properties + serves_prop = OntologyProperty( + uri="http://purl.org/ontology/fo/serves", + type="owl:DatatypeProperty", + labels=[{"value": "serves", "lang": "en-gb"}], + comment="The serves property indicates what the recipe is intended to serve.", + domain="Recipe", + range="xsd:string" + ) + + # Build ontology + ontology = Ontology( + id="food", + metadata={ + "name": "Food Ontology", + "namespace": "http://purl.org/ontology/fo/" + }, + classes={ + "Recipe": recipe_class, + "Ingredient": ingredient_class, + "Food": food_class, + "Method": method_class + }, + object_properties={ + "ingredients": ingredients_prop, + "food": food_prop, + "method": method_prop, + "produces": produces_prop + }, + datatype_properties={ + "serves": serves_prop + } + ) + + return ontology + + +@pytest.fixture +def ontology_loader_with_sample(sample_ontology): + """Create an OntologyLoader with the sample ontology.""" + loader = Mock() + loader.get_ontology = Mock(return_value=sample_ontology) + loader.ontologies = {"food": sample_ontology} + return loader 
+ + +@pytest.fixture +def ontology_embedder(): + """Create a mock OntologyEmbedder.""" + embedder = Mock() + embedder.embed_text = AsyncMock(return_value=[0.1, 0.2, 0.3]) # Mock embedding + + # Mock vector store with search results + vector_store = Mock() + embedder.get_vector_store = Mock(return_value=vector_store) + + return embedder + + +class TestOntologySelector: + """Test suite for OntologySelector.""" + + def test_auto_include_properties_for_recipe_class( + self, ontology_loader_with_sample, ontology_embedder, sample_ontology + ): + """Test that selecting Recipe class automatically includes all related properties.""" + selector = OntologySelector( + ontology_embedder=ontology_embedder, + ontology_loader=ontology_loader_with_sample, + top_k=10, + similarity_threshold=0.3 + ) + + # Create a subset with only Recipe class initially selected + subset = OntologySubset( + ontology_id="food", + classes={"Recipe": sample_ontology.classes["Recipe"].__dict__}, + object_properties={}, + datatype_properties={}, + metadata=sample_ontology.metadata, + relevance_score=0.8 + ) + + # Resolve dependencies (this is where auto-include happens) + selector._resolve_dependencies(subset) + + # Assert that properties with Recipe in domain are included + assert "ingredients" in subset.object_properties, \ + "ingredients property should be auto-included (Recipe in domain)" + assert "method" in subset.object_properties, \ + "method property should be auto-included (Recipe in domain)" + assert "produces" in subset.object_properties, \ + "produces property should be auto-included (Recipe in domain)" + assert "serves" in subset.datatype_properties, \ + "serves property should be auto-included (Recipe in domain)" + + # Assert that unrelated property is NOT included + assert "food" not in subset.object_properties, \ + "food property should NOT be included (Recipe not in domain/range)" + + def test_auto_include_properties_for_ingredient_class( + self, ontology_loader_with_sample, ontology_embedder, sample_ontology + ): + """Test that selecting Ingredient class includes properties with Ingredient in domain.""" + selector = OntologySelector( + ontology_embedder=ontology_embedder, + ontology_loader=ontology_loader_with_sample + ) + + subset = OntologySubset( + ontology_id="food", + classes={"Ingredient": sample_ontology.classes["Ingredient"].__dict__}, + object_properties={}, + datatype_properties={}, + metadata=sample_ontology.metadata + ) + + selector._resolve_dependencies(subset) + + # Ingredient has 'food' property in domain + assert "food" in subset.object_properties, \ + "food property should be auto-included (Ingredient in domain)" + + # Recipe-related properties should NOT be included + assert "ingredients" not in subset.object_properties + assert "method" not in subset.object_properties + + def test_auto_include_properties_for_range_class( + self, ontology_loader_with_sample, ontology_embedder, sample_ontology + ): + """Test that selecting a class includes properties with that class in range.""" + selector = OntologySelector( + ontology_embedder=ontology_embedder, + ontology_loader=ontology_loader_with_sample + ) + + subset = OntologySubset( + ontology_id="food", + classes={"Food": sample_ontology.classes["Food"].__dict__}, + object_properties={}, + datatype_properties={}, + metadata=sample_ontology.metadata + ) + + selector._resolve_dependencies(subset) + + # Food appears in range of 'food' and 'produces' properties + assert "food" in subset.object_properties, \ + "food property should be auto-included (Food 
in range)" + assert "produces" in subset.object_properties, \ + "produces property should be auto-included (Food in range)" + + def test_auto_include_adds_domain_and_range_classes( + self, ontology_loader_with_sample, ontology_embedder, sample_ontology + ): + """Test that auto-included properties also add their domain/range classes.""" + selector = OntologySelector( + ontology_embedder=ontology_embedder, + ontology_loader=ontology_loader_with_sample + ) + + # Start with only Recipe class + subset = OntologySubset( + ontology_id="food", + classes={"Recipe": sample_ontology.classes["Recipe"].__dict__}, + object_properties={}, + datatype_properties={}, + metadata=sample_ontology.metadata + ) + + selector._resolve_dependencies(subset) + + # Should auto-include 'produces' property (Recipe → Food) + assert "produces" in subset.object_properties + + # Should also add Food class (range of produces) + assert "Food" in subset.classes, \ + "Food class should be added (range of auto-included produces property)" + + # Should also add Method class (range of method property) + assert "Method" in subset.classes, \ + "Method class should be added (range of auto-included method property)" + + def test_multiple_classes_get_all_related_properties( + self, ontology_loader_with_sample, ontology_embedder, sample_ontology + ): + """Test that selecting multiple classes includes all their related properties.""" + selector = OntologySelector( + ontology_embedder=ontology_embedder, + ontology_loader=ontology_loader_with_sample + ) + + # Select both Recipe and Ingredient classes + subset = OntologySubset( + ontology_id="food", + classes={ + "Recipe": sample_ontology.classes["Recipe"].__dict__, + "Ingredient": sample_ontology.classes["Ingredient"].__dict__ + }, + object_properties={}, + datatype_properties={}, + metadata=sample_ontology.metadata + ) + + selector._resolve_dependencies(subset) + + # Should include Recipe-related properties + assert "ingredients" in subset.object_properties + assert "method" in subset.object_properties + assert "produces" in subset.object_properties + assert "serves" in subset.datatype_properties + + # Should also include Ingredient-related properties + assert "food" in subset.object_properties + + def test_no_duplicate_properties_added( + self, ontology_loader_with_sample, ontology_embedder, sample_ontology + ): + """Test that properties aren't added multiple times.""" + selector = OntologySelector( + ontology_embedder=ontology_embedder, + ontology_loader=ontology_loader_with_sample + ) + + # Start with Recipe and Food (both related to 'produces') + subset = OntologySubset( + ontology_id="food", + classes={ + "Recipe": sample_ontology.classes["Recipe"].__dict__, + "Food": sample_ontology.classes["Food"].__dict__ + }, + object_properties={}, + datatype_properties={}, + metadata=sample_ontology.metadata + ) + + selector._resolve_dependencies(subset) + + # 'produces' should be included once (not duplicated) + assert "produces" in subset.object_properties + # Count would be 1 - dict keys are unique, so this is guaranteed + # but worth documenting the expected behavior + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit/test_extract/test_ontology/test_ontology_triples.py b/tests/unit/test_extract/test_ontology/test_ontology_triples.py new file mode 100644 index 00000000..70ade79d --- /dev/null +++ b/tests/unit/test_extract/test_ontology/test_ontology_triples.py @@ -0,0 +1,300 @@ +""" +Unit tests for ontology triple generation. 
+
+Tests that ontology elements (classes and properties) are properly converted
+to RDF triples with labels, comments, domains, and ranges so they appear in
+the knowledge graph.
+"""
+
+import pytest
+from trustgraph.extract.kg.ontology.extract import Processor
+from trustgraph.extract.kg.ontology.ontology_selector import OntologySubset
+from trustgraph.schema.core.primitives import Triple, Value
+
+
+@pytest.fixture
+def extractor():
+    """Create a Processor instance for testing."""
+    extractor = object.__new__(Processor)
+    return extractor
+
+
+@pytest.fixture
+def sample_ontology_subset():
+    """Create a sample ontology subset with classes and properties."""
+    return OntologySubset(
+        ontology_id="food",
+        classes={
+            "Recipe": {
+                "uri": "http://purl.org/ontology/fo/Recipe",
+                "type": "owl:Class",
+                "labels": [{"value": "Recipe", "lang": "en-gb"}],
+                "comment": "A Recipe is a combination of ingredients and a method.",
+                "subclass_of": None
+            },
+            "Ingredient": {
+                "uri": "http://purl.org/ontology/fo/Ingredient",
+                "type": "owl:Class",
+                "labels": [{"value": "Ingredient", "lang": "en-gb"}],
+                "comment": "An Ingredient combines a quantity and a food.",
+                "subclass_of": None
+            },
+            "Food": {
+                "uri": "http://purl.org/ontology/fo/Food",
+                "type": "owl:Class",
+                "labels": [{"value": "Food", "lang": "en-gb"}],
+                "comment": "A Food is something that can be eaten.",
+                "subclass_of": None
+            }
+        },
+        object_properties={
+            "ingredients": {
+                "uri": "http://purl.org/ontology/fo/ingredients",
+                "type": "owl:ObjectProperty",
+                "labels": [{"value": "ingredients", "lang": "en-gb"}],
+                "comment": "The ingredients property relates a recipe to an ingredient list.",
+                "domain": "Recipe",
+                "range": "IngredientList"
+            },
+            "produces": {
+                "uri": "http://purl.org/ontology/fo/produces",
+                "type": "owl:ObjectProperty",
+                "labels": [{"value": "produces", "lang": "en-gb"}],
+                "comment": "The produces property relates a recipe to the food it produces.",
+                "domain": "Recipe",
+                "range": "Food"
+            }
+        },
+        datatype_properties={
+            "serves": {
+                "uri": "http://purl.org/ontology/fo/serves",
+                "type": "owl:DatatypeProperty",
+                "labels": [{"value": "serves", "lang": "en-gb"}],
+                "comment": "The serves property indicates serving size.",
+                "domain": "Recipe",
+                "range": "xsd:string"  # dict field name, consistent with the other properties
+            }
+        },
+        metadata={
+            "name": "Food Ontology",
+            "namespace": "http://purl.org/ontology/fo/"
+        }
+    )
+
+
+class TestOntologyTripleGeneration:
+    """Test suite for ontology triple generation."""
+
+    def test_generates_class_type_triples(self, extractor, sample_ontology_subset):
+        """Test that classes get rdf:type owl:Class triples."""
+        triples = extractor.build_ontology_triples(sample_ontology_subset)
+
+        # Find type triples for Recipe class
+        recipe_type_triples = [
+            t for t in triples
+            if t.s.value == "http://purl.org/ontology/fo/Recipe"
+            and t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
+        ]
+
+        assert len(recipe_type_triples) == 1, "Should generate exactly one type triple per class"
+        assert recipe_type_triples[0].o.value == "http://www.w3.org/2002/07/owl#Class", \
+            "Class type should be owl:Class"
+
+    def test_generates_class_labels(self, extractor, sample_ontology_subset):
+        """Test that classes get rdfs:label triples."""
+        triples = extractor.build_ontology_triples(sample_ontology_subset)
+
+        # Find label triples for Recipe class
+        recipe_label_triples = [
+            t for t in triples
+            if t.s.value == "http://purl.org/ontology/fo/Recipe"
+            and t.p.value == "http://www.w3.org/2000/01/rdf-schema#label"
+        ]
+
+        assert 
len(recipe_label_triples) == 1, "Should generate label triple for class" + assert recipe_label_triples[0].o.value == "Recipe", \ + "Label should match class label from ontology" + assert not recipe_label_triples[0].o.is_uri, \ + "Label should be a literal, not URI" + + def test_generates_class_comments(self, extractor, sample_ontology_subset): + """Test that classes get rdfs:comment triples.""" + triples = extractor.build_ontology_triples(sample_ontology_subset) + + # Find comment triples for Recipe class + recipe_comment_triples = [ + t for t in triples + if t.s.value == "http://purl.org/ontology/fo/Recipe" + and t.p.value == "http://www.w3.org/2000/01/rdf-schema#comment" + ] + + assert len(recipe_comment_triples) == 1, "Should generate comment triple for class" + assert "combination of ingredients and a method" in recipe_comment_triples[0].o.value, \ + "Comment should match class description from ontology" + + def test_generates_object_property_type_triples(self, extractor, sample_ontology_subset): + """Test that object properties get rdf:type owl:ObjectProperty triples.""" + triples = extractor.build_ontology_triples(sample_ontology_subset) + + # Find type triples for ingredients property + ingredients_type_triples = [ + t for t in triples + if t.s.value == "http://purl.org/ontology/fo/ingredients" + and t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" + ] + + assert len(ingredients_type_triples) == 1, \ + "Should generate exactly one type triple per object property" + assert ingredients_type_triples[0].o.value == "http://www.w3.org/2002/07/owl#ObjectProperty", \ + "Object property type should be owl:ObjectProperty" + + def test_generates_object_property_labels(self, extractor, sample_ontology_subset): + """Test that object properties get rdfs:label triples.""" + triples = extractor.build_ontology_triples(sample_ontology_subset) + + # Find label triples for ingredients property + ingredients_label_triples = [ + t for t in triples + if t.s.value == "http://purl.org/ontology/fo/ingredients" + and t.p.value == "http://www.w3.org/2000/01/rdf-schema#label" + ] + + assert len(ingredients_label_triples) == 1, \ + "Should generate label triple for object property" + assert ingredients_label_triples[0].o.value == "ingredients", \ + "Label should match property label from ontology" + + def test_generates_object_property_domain(self, extractor, sample_ontology_subset): + """Test that object properties get rdfs:domain triples.""" + triples = extractor.build_ontology_triples(sample_ontology_subset) + + # Find domain triples for ingredients property + ingredients_domain_triples = [ + t for t in triples + if t.s.value == "http://purl.org/ontology/fo/ingredients" + and t.p.value == "http://www.w3.org/2000/01/rdf-schema#domain" + ] + + assert len(ingredients_domain_triples) == 1, \ + "Should generate domain triple for object property" + assert ingredients_domain_triples[0].o.value == "http://purl.org/ontology/fo/Recipe", \ + "Domain should be Recipe class URI" + assert ingredients_domain_triples[0].o.is_uri, \ + "Domain should be a URI reference" + + def test_generates_object_property_range(self, extractor, sample_ontology_subset): + """Test that object properties get rdfs:range triples.""" + triples = extractor.build_ontology_triples(sample_ontology_subset) + + # Find range triples for produces property + produces_range_triples = [ + t for t in triples + if t.s.value == "http://purl.org/ontology/fo/produces" + and t.p.value == "http://www.w3.org/2000/01/rdf-schema#range" + ] + + assert 
len(produces_range_triples) == 1, \ + "Should generate range triple for object property" + assert produces_range_triples[0].o.value == "http://purl.org/ontology/fo/Food", \ + "Range should be Food class URI" + + def test_generates_datatype_property_type_triples(self, extractor, sample_ontology_subset): + """Test that datatype properties get rdf:type owl:DatatypeProperty triples.""" + triples = extractor.build_ontology_triples(sample_ontology_subset) + + # Find type triples for serves property + serves_type_triples = [ + t for t in triples + if t.s.value == "http://purl.org/ontology/fo/serves" + and t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" + ] + + assert len(serves_type_triples) == 1, \ + "Should generate exactly one type triple per datatype property" + assert serves_type_triples[0].o.value == "http://www.w3.org/2002/07/owl#DatatypeProperty", \ + "Datatype property type should be owl:DatatypeProperty" + + def test_generates_datatype_property_range(self, extractor, sample_ontology_subset): + """Test that datatype properties get rdfs:range triples with XSD types.""" + triples = extractor.build_ontology_triples(sample_ontology_subset) + + # Find range triples for serves property + serves_range_triples = [ + t for t in triples + if t.s.value == "http://purl.org/ontology/fo/serves" + and t.p.value == "http://www.w3.org/2000/01/rdf-schema#range" + ] + + assert len(serves_range_triples) == 1, \ + "Should generate range triple for datatype property" + assert serves_range_triples[0].o.value == "http://www.w3.org/2001/XMLSchema#string", \ + "Range should be XSD type URI (xsd:string expanded)" + + def test_generates_triples_for_all_classes(self, extractor, sample_ontology_subset): + """Test that triples are generated for all classes in the subset.""" + triples = extractor.build_ontology_triples(sample_ontology_subset) + + # Count unique class subjects + class_subjects = set( + t.s.value for t in triples + if t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" + and t.o.value == "http://www.w3.org/2002/07/owl#Class" + ) + + assert len(class_subjects) == 3, \ + "Should generate triples for all 3 classes (Recipe, Ingredient, Food)" + + def test_generates_triples_for_all_properties(self, extractor, sample_ontology_subset): + """Test that triples are generated for all properties in the subset.""" + triples = extractor.build_ontology_triples(sample_ontology_subset) + + # Count unique property subjects (object + datatype properties) + property_subjects = set( + t.s.value for t in triples + if t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" + and ("ObjectProperty" in t.o.value or "DatatypeProperty" in t.o.value) + ) + + assert len(property_subjects) == 3, \ + "Should generate triples for all 3 properties (ingredients, produces, serves)" + + def test_uses_dict_field_names_not_rdf_names(self, extractor, sample_ontology_subset): + """Test that triple generation works with dict field names (labels, comment, domain, range). + + This is critical - the ontology subset has dicts with Python field names, + not RDF property names. 
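+        For example, a label is read as class_def["labels"][0]["value"],
+        not class_def.labels[0].value.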
+ """ + # Verify the subset uses dict field names + recipe_def = sample_ontology_subset.classes["Recipe"] + assert isinstance(recipe_def, dict), "Class definitions should be dicts" + assert "labels" in recipe_def, "Should use 'labels' not 'rdfs:label'" + assert "comment" in recipe_def, "Should use 'comment' not 'rdfs:comment'" + + # Now verify triple generation works + triples = extractor.build_ontology_triples(sample_ontology_subset) + + # Should still generate proper RDF triples despite dict field names + label_triples = [ + t for t in triples + if t.p.value == "http://www.w3.org/2000/01/rdf-schema#label" + ] + assert len(label_triples) > 0, \ + "Should generate rdfs:label triples from dict 'labels' field" + + def test_total_triple_count_is_reasonable(self, extractor, sample_ontology_subset): + """Test that we generate a reasonable number of triples.""" + triples = extractor.build_ontology_triples(sample_ontology_subset) + + # Each class gets: type, label, comment (3 triples) + # Each object property gets: type, label, comment, domain, range (5 triples) + # Each datatype property gets: type, label, comment, domain, range (5 triples) + # Expected: 3 classes * 3 + 2 object props * 5 + 1 datatype prop * 5 = 9 + 10 + 5 = 24 + + assert len(triples) >= 20, \ + "Should generate substantial number of triples for ontology elements" + assert len(triples) < 50, \ + "Should not generate excessive duplicate triples" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit/test_extract/test_ontology/test_prompt_and_extraction.py b/tests/unit/test_extract/test_ontology/test_prompt_and_extraction.py new file mode 100644 index 00000000..e6d5bf36 --- /dev/null +++ b/tests/unit/test_extract/test_ontology/test_prompt_and_extraction.py @@ -0,0 +1,414 @@ +""" +Unit tests for LLM prompt construction and triple extraction. + +Tests that the system correctly constructs prompts with ontology constraints +and extracts/validates triples from LLM responses. +""" + +import pytest +from trustgraph.extract.kg.ontology.extract import Processor +from trustgraph.extract.kg.ontology.ontology_selector import OntologySubset +from trustgraph.schema.core.primitives import Triple, Value + + +@pytest.fixture +def extractor(): + """Create a Processor instance for testing.""" + extractor = object.__new__(Processor) + extractor.URI_PREFIXES = { + "rdf:": "http://www.w3.org/1999/02/22-rdf-syntax-ns#", + "rdfs:": "http://www.w3.org/2000/01/rdf-schema#", + "owl:": "http://www.w3.org/2002/07/owl#", + "xsd:": "http://www.w3.org/2001/XMLSchema#", + } + return extractor + + +@pytest.fixture +def sample_ontology_subset(): + """Create a sample ontology subset for extraction testing.""" + return OntologySubset( + ontology_id="food", + classes={ + "Recipe": { + "uri": "http://purl.org/ontology/fo/Recipe", + "type": "owl:Class", + "labels": [{"value": "Recipe", "lang": "en-gb"}], + "comment": "A Recipe is a combination of ingredients and a method." + }, + "Ingredient": { + "uri": "http://purl.org/ontology/fo/Ingredient", + "type": "owl:Class", + "labels": [{"value": "Ingredient", "lang": "en-gb"}], + "comment": "An Ingredient combines a quantity and a food." + }, + "Food": { + "uri": "http://purl.org/ontology/fo/Food", + "type": "owl:Class", + "labels": [{"value": "Food", "lang": "en-gb"}], + "comment": "A Food is something that can be eaten." 
+ } + }, + object_properties={ + "ingredients": { + "uri": "http://purl.org/ontology/fo/ingredients", + "type": "owl:ObjectProperty", + "labels": [{"value": "ingredients", "lang": "en-gb"}], + "comment": "The ingredients property relates a recipe to an ingredient list.", + "domain": "Recipe", + "range": "IngredientList" + }, + "food": { + "uri": "http://purl.org/ontology/fo/food", + "type": "owl:ObjectProperty", + "labels": [{"value": "food", "lang": "en-gb"}], + "comment": "The food property relates an ingredient to food.", + "domain": "Ingredient", + "range": "Food" + }, + "produces": { + "uri": "http://purl.org/ontology/fo/produces", + "type": "owl:ObjectProperty", + "labels": [{"value": "produces", "lang": "en-gb"}], + "comment": "The produces property relates a recipe to the food it produces.", + "domain": "Recipe", + "range": "Food" + } + }, + datatype_properties={ + "serves": { + "uri": "http://purl.org/ontology/fo/serves", + "type": "owl:DatatypeProperty", + "labels": [{"value": "serves", "lang": "en-gb"}], + "comment": "The serves property indicates serving size.", + "domain": "Recipe", + "rdfs:range": "xsd:string" + } + }, + metadata={ + "name": "Food Ontology", + "namespace": "http://purl.org/ontology/fo/" + } + ) + + +class TestPromptConstruction: + """Test suite for LLM prompt construction.""" + + def test_build_extraction_variables_includes_text(self, extractor, sample_ontology_subset): + """Test that extraction variables include the input text.""" + chunk = "Cornish pasty is a traditional British pastry filled with meat and vegetables." + + variables = extractor.build_extraction_variables(chunk, sample_ontology_subset) + + assert "text" in variables, "Should include text key" + assert variables["text"] == chunk, "Text should match input chunk" + + def test_build_extraction_variables_includes_classes(self, extractor, sample_ontology_subset): + """Test that extraction variables include ontology classes.""" + chunk = "Test text" + + variables = extractor.build_extraction_variables(chunk, sample_ontology_subset) + + assert "classes" in variables, "Should include classes key" + assert len(variables["classes"]) == 3, "Should include all classes from subset" + assert "Recipe" in variables["classes"] + assert "Ingredient" in variables["classes"] + assert "Food" in variables["classes"] + + def test_build_extraction_variables_includes_properties(self, extractor, sample_ontology_subset): + """Test that extraction variables include ontology properties.""" + chunk = "Test text" + + variables = extractor.build_extraction_variables(chunk, sample_ontology_subset) + + assert "object_properties" in variables, "Should include object_properties key" + assert "datatype_properties" in variables, "Should include datatype_properties key" + assert len(variables["object_properties"]) == 3 + assert len(variables["datatype_properties"]) == 1 + + def test_build_extraction_variables_structure(self, extractor, sample_ontology_subset): + """Test the overall structure of extraction variables.""" + chunk = "Test text" + + variables = extractor.build_extraction_variables(chunk, sample_ontology_subset) + + # Should have exactly 4 keys + assert set(variables.keys()) == {"text", "classes", "object_properties", "datatype_properties"} + + def test_build_extraction_variables_with_empty_subset(self, extractor): + """Test building variables with minimal ontology subset.""" + minimal_subset = OntologySubset( + ontology_id="minimal", + classes={}, + object_properties={}, + datatype_properties={}, + metadata={} + ) + 
chunk = "Test text" + + variables = extractor.build_extraction_variables(chunk, minimal_subset) + + assert variables["text"] == chunk + assert len(variables["classes"]) == 0 + assert len(variables["object_properties"]) == 0 + assert len(variables["datatype_properties"]) == 0 + + +class TestTripleValidation: + """Test suite for triple validation against ontology.""" + + def test_validates_rdf_type_triple_with_valid_class(self, extractor, sample_ontology_subset): + """Test that rdf:type triples are validated against ontology classes.""" + subject = "cornish-pasty" + predicate = "rdf:type" + object_val = "Recipe" + + is_valid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset) + + assert is_valid, "rdf:type with valid class should be valid" + + def test_rejects_rdf_type_triple_with_invalid_class(self, extractor, sample_ontology_subset): + """Test that rdf:type triples with non-existent classes are rejected.""" + subject = "cornish-pasty" + predicate = "rdf:type" + object_val = "NonExistentClass" + + is_valid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset) + + assert not is_valid, "rdf:type with invalid class should be rejected" + + def test_validates_rdfs_label_triple(self, extractor, sample_ontology_subset): + """Test that rdfs:label triples are always valid.""" + subject = "cornish-pasty" + predicate = "rdfs:label" + object_val = "Cornish Pasty" + + is_valid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset) + + assert is_valid, "rdfs:label should always be valid" + + def test_validates_object_property_triple(self, extractor, sample_ontology_subset): + """Test that object property triples are validated.""" + subject = "cornish-pasty-recipe" + predicate = "produces" + object_val = "cornish-pasty" + + is_valid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset) + + assert is_valid, "Valid object property should be accepted" + + def test_validates_datatype_property_triple(self, extractor, sample_ontology_subset): + """Test that datatype property triples are validated.""" + subject = "cornish-pasty-recipe" + predicate = "serves" + object_val = "4-6 people" + + is_valid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset) + + assert is_valid, "Valid datatype property should be accepted" + + def test_rejects_unknown_property(self, extractor, sample_ontology_subset): + """Test that unknown properties are rejected.""" + subject = "cornish-pasty" + predicate = "unknownProperty" + object_val = "some value" + + is_valid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset) + + assert not is_valid, "Unknown property should be rejected" + + def test_validates_multiple_valid_properties(self, extractor, sample_ontology_subset): + """Test validation of different property types.""" + test_cases = [ + ("recipe1", "produces", "food1", True), + ("ingredient1", "food", "food1", True), + ("recipe1", "serves", "4", True), + ("recipe1", "invalidProp", "value", False), + ] + + for subject, predicate, object_val, expected in test_cases: + is_valid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset) + assert is_valid == expected, f"Validation of {predicate} should be {expected}" + + +class TestTripleParsing: + """Test suite for parsing triples from LLM responses.""" + + def test_parse_simple_triple_dict(self, extractor, sample_ontology_subset): + """Test parsing a simple triple from dict 
format.""" + triples_response = [ + { + "subject": "cornish-pasty", + "predicate": "rdf:type", + "object": "Recipe" + } + ] + + validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset) + + assert len(validated) == 1, "Should parse one valid triple" + assert validated[0].s.value == "https://trustgraph.ai/food/cornish-pasty" + assert validated[0].p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" + assert validated[0].o.value == "http://purl.org/ontology/fo/Recipe" + + def test_parse_multiple_triples(self, extractor, sample_ontology_subset): + """Test parsing multiple triples.""" + triples_response = [ + {"subject": "cornish-pasty", "predicate": "rdf:type", "object": "Recipe"}, + {"subject": "cornish-pasty", "predicate": "rdfs:label", "object": "Cornish Pasty"}, + {"subject": "cornish-pasty", "predicate": "serves", "object": "1-2 people"} + ] + + validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset) + + assert len(validated) == 3, "Should parse all valid triples" + + def test_filters_invalid_triples(self, extractor, sample_ontology_subset): + """Test that invalid triples are filtered out.""" + triples_response = [ + {"subject": "cornish-pasty", "predicate": "rdf:type", "object": "Recipe"}, # Valid + {"subject": "cornish-pasty", "predicate": "invalidProp", "object": "value"}, # Invalid + {"subject": "cornish-pasty", "predicate": "produces", "object": "food1"} # Valid + ] + + validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset) + + assert len(validated) == 2, "Should filter out invalid triple" + + def test_handles_missing_fields(self, extractor, sample_ontology_subset): + """Test that triples with missing fields are skipped.""" + triples_response = [ + {"subject": "cornish-pasty", "predicate": "rdf:type"}, # Missing object + {"subject": "cornish-pasty", "object": "Recipe"}, # Missing predicate + {"predicate": "rdf:type", "object": "Recipe"}, # Missing subject + {"subject": "cornish-pasty", "predicate": "rdf:type", "object": "Recipe"} # Valid + ] + + validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset) + + assert len(validated) == 1, "Should skip triples with missing fields" + + def test_handles_empty_response(self, extractor, sample_ontology_subset): + """Test handling of empty LLM response.""" + triples_response = [] + + validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset) + + assert len(validated) == 0, "Empty response should return no triples" + + def test_expands_uris_in_parsed_triples(self, extractor, sample_ontology_subset): + """Test that URIs are properly expanded in parsed triples.""" + triples_response = [ + {"subject": "recipe1", "predicate": "produces", "object": "Food"} + ] + + validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset) + + assert len(validated) == 1 + # Subject should be expanded to entity URI + assert validated[0].s.value.startswith("https://trustgraph.ai/food/") + # Predicate should be expanded to ontology URI + assert validated[0].p.value == "http://purl.org/ontology/fo/produces" + # Object should be expanded to class URI + assert validated[0].o.value == "http://purl.org/ontology/fo/Food" + + def test_creates_proper_triple_objects(self, extractor, sample_ontology_subset): + """Test that Triple objects are properly created.""" + triples_response = [ + {"subject": "cornish-pasty", "predicate": "rdfs:label", "object": "Cornish Pasty"} + ] + 
+
+        validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
+
+        assert len(validated) == 1
+        triple = validated[0]
+        assert isinstance(triple, Triple), "Should create Triple objects"
+        assert isinstance(triple.s, Value), "Subject should be Value object"
+        assert isinstance(triple.p, Value), "Predicate should be Value object"
+        assert isinstance(triple.o, Value), "Object should be Value object"
+        assert triple.s.is_uri, "Subject should be marked as URI"
+        assert triple.p.is_uri, "Predicate should be marked as URI"
+        assert not triple.o.is_uri, "Object literal should not be marked as URI"
+
+
+class TestURIExpansionInExtraction:
+    """Test suite for URI expansion during triple extraction."""
+
+    def test_expands_class_names_in_objects(self, extractor, sample_ontology_subset):
+        """Test that class names in object position are expanded."""
+        triples_response = [
+            {"subject": "entity1", "predicate": "rdf:type", "object": "Recipe"}
+        ]
+
+        validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
+
+        assert validated[0].o.value == "http://purl.org/ontology/fo/Recipe"
+        assert validated[0].o.is_uri, "Class reference should be URI"
+
+    def test_expands_property_names(self, extractor, sample_ontology_subset):
+        """Test that property names are expanded to full URIs."""
+        triples_response = [
+            {"subject": "recipe1", "predicate": "produces", "object": "food1"}
+        ]
+
+        validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
+
+        assert validated[0].p.value == "http://purl.org/ontology/fo/produces"
+
+    def test_expands_entity_instances(self, extractor, sample_ontology_subset):
+        """Test that entity instances get constructed URIs."""
+        triples_response = [
+            {"subject": "my-special-recipe", "predicate": "rdf:type", "object": "Recipe"}
+        ]
+
+        validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
+
+        assert validated[0].s.value.startswith("https://trustgraph.ai/food/")
+        assert "my-special-recipe" in validated[0].s.value
+
+
+class TestEdgeCases:
+    """Test suite for edge cases in extraction."""
+
+    def test_handles_non_dict_response_items(self, extractor, sample_ontology_subset):
+        """Test that non-dict items in response are skipped."""
+        triples_response = [
+            {"subject": "entity1", "predicate": "rdf:type", "object": "Recipe"},  # Valid
+            "invalid string item",  # Invalid
+            None,  # Invalid
+            {"subject": "entity2", "predicate": "rdf:type", "object": "Food"}  # Valid
+        ]
+
+        validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
+
+        # Non-dict items should be skipped gracefully, leaving the two valid triples
+        assert len(validated) == 2, "Should skip non-dict items and keep the valid triples"
+
+    def test_handles_empty_string_values(self, extractor, sample_ontology_subset):
+        """Test that empty string values are skipped."""
+        triples_response = [
+            {"subject": "", "predicate": "rdf:type", "object": "Recipe"},
+            {"subject": "entity1", "predicate": "", "object": "Recipe"},
+            {"subject": "entity1", "predicate": "rdf:type", "object": ""},
+            {"subject": "entity1", "predicate": "rdf:type", "object": "Recipe"}  # Valid
+        ]
+
+        validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
+
+        assert len(validated) == 1, "Should skip triples with empty strings"
+
+    def test_handles_unicode_in_literals(self, extractor, sample_ontology_subset):
+        """Test handling of unicode characters in literal values."""
+        triples_response = [
+            {"subject": "café-recipe",
"predicate": "rdfs:label", "object": "Café Spécial"} + ] + + validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset) + + assert len(validated) == 1 + assert "Café Spécial" in validated[0].o.value + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit/test_extract/test_ontology/test_text_processing.py b/tests/unit/test_extract/test_ontology/test_text_processing.py new file mode 100644 index 00000000..67686297 --- /dev/null +++ b/tests/unit/test_extract/test_ontology/test_text_processing.py @@ -0,0 +1,290 @@ +""" +Unit tests for text processing and segmentation. + +Tests that text is properly split into sentences for ontology matching, +including NLTK tokenization and TextSegment creation. +""" + +import pytest +from trustgraph.extract.kg.ontology.text_processor import TextProcessor, TextSegment + + +@pytest.fixture +def text_processor(): + """Create a TextProcessor instance for testing.""" + return TextProcessor() + + +class TestTextSegmentation: + """Test suite for text segmentation functionality.""" + + def test_segment_single_sentence(self, text_processor): + """Test segmentation of a single sentence.""" + text = "This is a simple sentence." + + segments = text_processor.process_chunk(text, extract_phrases=False) + + # Filter to only sentences + sentences = [s for s in segments if s.type == 'sentence'] + assert len(sentences) == 1, "Single sentence should produce one sentence segment" + assert text in sentences[0].text, "Segment text should contain input" + + def test_segment_multiple_sentences(self, text_processor): + """Test segmentation of multiple sentences.""" + text = "First sentence. Second sentence. Third sentence." + + segments = text_processor.process_chunk(text, extract_phrases=False) + + # Filter to only sentences + sentences = [s for s in segments if s.type == 'sentence'] + assert len(sentences) == 3, "Should create three sentence segments for three sentences" + assert "First sentence" in sentences[0].text + assert "Second sentence" in sentences[1].text + assert "Third sentence" in sentences[2].text + + def test_segment_positions(self, text_processor): + """Test that segment positions are tracked.""" + text = "First sentence. Second sentence." + + segments = text_processor.process_chunk(text, extract_phrases=False) + + # Filter to only sentences + sentences = [s for s in segments if s.type == 'sentence'] + assert len(sentences) == 2 + assert sentences[0].position == 0 + assert sentences[1].position > 0 + + def test_segment_empty_text(self, text_processor): + """Test handling of empty text.""" + text = "" + + segments = text_processor.process_chunk(text, extract_phrases=False) + + assert len(segments) == 0, "Empty text should produce no segments" + + def test_segment_whitespace_only(self, text_processor): + """Test handling of whitespace-only text.""" + text = " \n\t " + + segments = text_processor.process_chunk(text, extract_phrases=False) + + # May produce empty segments or no segments depending on implementation + assert len(segments) <= 1, "Whitespace-only text should produce minimal segments" + + def test_segment_with_newlines(self, text_processor): + """Test segmentation of text with newlines.""" + text = "First sentence.\nSecond sentence." 
+ + segments = text_processor.process_chunk(text, extract_phrases=False) + + # Filter to only sentences + sentences = [s for s in segments if s.type == 'sentence'] + assert len(sentences) == 2 + assert "First sentence" in sentences[0].text + assert "Second sentence" in sentences[1].text + + def test_segment_complex_punctuation(self, text_processor): + """Test segmentation with complex punctuation.""" + text = "Dr. Smith went to the U.S.A. yesterday. He met Mr. Jones." + + segments = text_processor.process_chunk(text, extract_phrases=False) + + # Filter to only sentences + sentences = [s for s in segments if s.type == 'sentence'] + # NLTK should handle abbreviations correctly + assert len(sentences) == 2, "Should recognize abbreviations and not split on them" + assert "Dr. Smith" in sentences[0].text + assert "Mr. Jones" in sentences[1].text + + def test_segment_question_and_exclamation(self, text_processor): + """Test segmentation with different sentence terminators.""" + text = "Is this working? Yes, it is! Great news." + + segments = text_processor.process_chunk(text, extract_phrases=False) + + # Filter to only sentences + sentences = [s for s in segments if s.type == 'sentence'] + assert len(sentences) == 3 + assert "Is this working?" in sentences[0].text + assert "Yes, it is!" in sentences[1].text + assert "Great news" in sentences[2].text + + def test_segment_long_paragraph(self, text_processor): + """Test segmentation of a longer paragraph.""" + text = ( + "The recipe requires several ingredients. " + "First, gather flour and sugar. " + "Then, add eggs and milk. " + "Finally, mix everything together." + ) + + segments = text_processor.process_chunk(text, extract_phrases=False) + + # Filter to only sentences + sentences = [s for s in segments if s.type == 'sentence'] + assert len(sentences) == 4, "Should split paragraph into individual sentences" + assert all(isinstance(seg, TextSegment) for seg in sentences) + + def test_extract_phrases_option(self, text_processor): + """Test that phrase extraction can be enabled.""" + text = "The recipe requires several ingredients." + + # With phrases + segments_with_phrases = text_processor.process_chunk(text, extract_phrases=True) + # Without phrases + segments_without_phrases = text_processor.process_chunk(text, extract_phrases=False) + + # With phrases should have more segments (sentences + phrases) + assert len(segments_with_phrases) >= len(segments_without_phrases) + + +class TestTextSegmentCreation: + """Test suite for TextSegment object creation.""" + + def test_text_segment_attributes(self, text_processor): + """Test that TextSegment objects have correct attributes.""" + text = "This is a test sentence." + + segments = text_processor.process_chunk(text, extract_phrases=False) + + assert len(segments) >= 1 + segment = segments[0] + + assert hasattr(segment, 'text'), "Segment should have text attribute" + assert hasattr(segment, 'type'), "Segment should have type attribute" + assert hasattr(segment, 'position'), "Segment should have position attribute" + assert segment.type in ['sentence', 'phrase', 'noun_phrase', 'verb_phrase'] + + def test_text_segment_types(self, text_processor): + """Test that different segment types are created correctly.""" + text = "The recipe requires several ingredients." 
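+        # Illustrative expectation only (exact phrase segments depend on the
+        # parser): the sentence itself should always appear as
+        #   TextSegment(text="The recipe requires several ingredients.",
+        #               type="sentence", position=0)
+        # with optional noun_phrase/verb_phrase segments alongside it when
+        # extract_phrases=True.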
+ + # Without phrases + segments = text_processor.process_chunk(text, extract_phrases=False) + types = set(s.type for s in segments) + assert 'sentence' in types, "Should create sentence segments" + + # With phrases + segments = text_processor.process_chunk(text, extract_phrases=True) + types = set(s.type for s in segments) + assert 'sentence' in types, "Should create sentence segments" + # May also have phrase types + + def test_text_segment_sentence_tracking(self, text_processor): + """Test that segments track their parent sentence.""" + text = "This is a test sentence." + + segments = text_processor.process_chunk(text, extract_phrases=True) + + # Phrases should reference their parent sentence + phrases = [s for s in segments if s.type != 'sentence'] + if phrases: + for phrase in phrases: + # parent_sentence may be set for phrases + assert hasattr(phrase, 'parent_sentence') + + +class TestNLTKCompatibility: + """Test suite for NLTK version compatibility.""" + + def test_nltk_punkt_availability(self, text_processor): + """Test that NLTK punkt tokenizer is available.""" + # This test verifies the text_processor can use NLTK + # If punkt/punkt_tab is not available, this will fail during setup + import nltk + + # Try to use sentence tokenizer + text = "Test sentence. Another sentence." + + try: + from nltk.tokenize import sent_tokenize + result = sent_tokenize(text) + assert len(result) > 0, "NLTK sentence tokenizer should work" + except LookupError: + pytest.fail("NLTK punkt tokenizer not available") + + def test_text_processor_uses_nltk(self, text_processor): + """Test that TextProcessor successfully uses NLTK for segmentation.""" + # This verifies the integration works + text = "First sentence. Second sentence." + + segments = text_processor.process_chunk(text, extract_phrases=False) + + # Should successfully segment using NLTK + sentences = [s for s in segments if s.type == 'sentence'] + assert len(sentences) >= 1, "Should successfully segment text using NLTK" + + +class TestEdgeCases: + """Test suite for edge cases in text processing.""" + + def test_sentence_with_only_punctuation(self, text_processor): + """Test handling of unusual punctuation patterns.""" + text = "...!?!" + + segments = text_processor.process_chunk(text, extract_phrases=False) + + # Should handle gracefully (NLTK may split this oddly, that's ok) + assert len(segments) <= 3, "Should handle punctuation-only text gracefully" + + def test_very_long_sentence(self, text_processor): + """Test handling of very long sentences.""" + # Create a long sentence with many clauses + text = ( + "This is a very long sentence with many clauses, " + "including subordinate clauses, coordinate clauses, " + "and various other grammatical structures that make it " + "quite lengthy but still technically a single sentence." + ) + + segments = text_processor.process_chunk(text, extract_phrases=False) + + sentences = [s for s in segments if s.type == 'sentence'] + assert len(sentences) == 1, "Long sentence should still be one sentence segment" + assert len(sentences[0].text) > 100 + + def test_unicode_text(self, text_processor): + """Test handling of unicode characters.""" + text = "Café serves crêpes. The recipe is français." 
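+        # NLTK's punkt tokenizer is unicode-aware, so accented characters
+        # such as "é" and "ç" should pass through segmentation unchanged;
+        # the assertions below rely on that.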
+ + segments = text_processor.process_chunk(text, extract_phrases=False) + + sentences = [s for s in segments if s.type == 'sentence'] + assert len(sentences) == 2 + assert "Café" in sentences[0].text + assert "français" in sentences[1].text + + def test_numbers_and_dates(self, text_processor): + """Test handling of numbers and dates in text.""" + text = "The recipe was created on Jan. 1, 2024. It serves 4-6 people." + + segments = text_processor.process_chunk(text, extract_phrases=False) + + sentences = [s for s in segments if s.type == 'sentence'] + assert len(sentences) == 2 + assert "2024" in sentences[0].text + assert "4-6" in sentences[1].text + + def test_ellipsis_handling(self, text_processor): + """Test handling of ellipsis in text.""" + text = "First sentence... Second sentence." + + segments = text_processor.process_chunk(text, extract_phrases=False) + + # NLTK may handle ellipsis differently + assert len(segments) >= 1, "Should produce at least one segment" + # The exact behavior depends on NLTK version + + def test_quoted_text(self, text_processor): + """Test handling of quoted text.""" + text = 'He said "Hello world." Then he left.' + + segments = text_processor.process_chunk(text, extract_phrases=False) + + sentences = [s for s in segments if s.type == 'sentence'] + assert len(sentences) == 2 + assert '"Hello world."' in sentences[0].text or "Hello world" in sentences[0].text + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit/test_extract/test_ontology/test_uri_expansion.py b/tests/unit/test_extract/test_ontology/test_uri_expansion.py new file mode 100644 index 00000000..dec33ff1 --- /dev/null +++ b/tests/unit/test_extract/test_ontology/test_uri_expansion.py @@ -0,0 +1,258 @@ +""" +Unit tests for URI expansion functionality. + +Tests that URIs are properly expanded using ontology definitions instead of +constructed fallback URIs. +""" + +import pytest +from trustgraph.extract.kg.ontology.extract import Processor +from trustgraph.extract.kg.ontology.ontology_selector import OntologySubset + + +class MockParams: + """Mock parameters for Processor.""" + def get(self, key, default=None): + return default + + +@pytest.fixture +def extractor(): + """Create a Processor instance for testing.""" + params = MockParams() + # We only need the expand_uri method, so minimal initialization + extractor = object.__new__(Processor) + extractor.URI_PREFIXES = { + "rdf:": "http://www.w3.org/1999/02/22-rdf-syntax-ns#", + "rdfs:": "http://www.w3.org/2000/01/rdf-schema#", + "owl:": "http://www.w3.org/2002/07/owl#", + "xsd:": "http://www.w3.org/2001/XMLSchema#", + } + return extractor + + +@pytest.fixture +def ontology_subset_with_uris(): + """Create an ontology subset with proper URIs defined.""" + return OntologySubset( + ontology_id="food", + classes={ + "Recipe": { + "uri": "http://purl.org/ontology/fo/Recipe", + "type": "owl:Class", + "labels": [{"value": "Recipe", "lang": "en-gb"}], + "comment": "A Recipe is a combination of ingredients and a method." + }, + "Ingredient": { + "uri": "http://purl.org/ontology/fo/Ingredient", + "type": "owl:Class", + "labels": [{"value": "Ingredient", "lang": "en-gb"}], + "comment": "An Ingredient combines a quantity and a food." + }, + "Food": { + "uri": "http://purl.org/ontology/fo/Food", + "type": "owl:Class", + "labels": [{"value": "Food", "lang": "en-gb"}], + "comment": "A Food is something that can be eaten." 
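+                # The "uri" key is what expand_uri is expected to read;
+                # definitions without it fall back to a constructed URI
+                # (exercised by the fallback tests below)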
+ } + }, + object_properties={ + "ingredients": { + "uri": "http://purl.org/ontology/fo/ingredients", + "type": "owl:ObjectProperty", + "labels": [{"value": "ingredients", "lang": "en-gb"}], + "domain": "Recipe", + "range": "IngredientList" + }, + "food": { + "uri": "http://purl.org/ontology/fo/food", + "type": "owl:ObjectProperty", + "labels": [{"value": "food", "lang": "en-gb"}], + "domain": "Ingredient", + "range": "Food" + }, + "produces": { + "uri": "http://purl.org/ontology/fo/produces", + "type": "owl:ObjectProperty", + "labels": [{"value": "produces", "lang": "en-gb"}], + "domain": "Recipe", + "range": "Food" + } + }, + datatype_properties={ + "serves": { + "uri": "http://purl.org/ontology/fo/serves", + "type": "owl:DatatypeProperty", + "labels": [{"value": "serves", "lang": "en-gb"}], + "domain": "Recipe", + "range": "xsd:string" + } + }, + metadata={ + "name": "Food Ontology", + "namespace": "http://purl.org/ontology/fo/" + } + ) + + +class TestURIExpansion: + """Test suite for URI expansion functionality.""" + + def test_expand_class_uri_from_ontology(self, extractor, ontology_subset_with_uris): + """Test that class names are expanded to their ontology URIs.""" + result = extractor.expand_uri("Recipe", ontology_subset_with_uris, "food") + + assert result == "http://purl.org/ontology/fo/Recipe", \ + "Recipe should expand to its ontology URI" + + def test_expand_object_property_uri_from_ontology(self, extractor, ontology_subset_with_uris): + """Test that object properties are expanded to their ontology URIs.""" + result = extractor.expand_uri("ingredients", ontology_subset_with_uris, "food") + + assert result == "http://purl.org/ontology/fo/ingredients", \ + "ingredients property should expand to its ontology URI" + + def test_expand_datatype_property_uri_from_ontology(self, extractor, ontology_subset_with_uris): + """Test that datatype properties are expanded to their ontology URIs.""" + result = extractor.expand_uri("serves", ontology_subset_with_uris, "food") + + assert result == "http://purl.org/ontology/fo/serves", \ + "serves property should expand to its ontology URI" + + def test_expand_multiple_classes(self, extractor, ontology_subset_with_uris): + """Test expansion of multiple different classes.""" + recipe_uri = extractor.expand_uri("Recipe", ontology_subset_with_uris, "food") + ingredient_uri = extractor.expand_uri("Ingredient", ontology_subset_with_uris, "food") + food_uri = extractor.expand_uri("Food", ontology_subset_with_uris, "food") + + assert recipe_uri == "http://purl.org/ontology/fo/Recipe" + assert ingredient_uri == "http://purl.org/ontology/fo/Ingredient" + assert food_uri == "http://purl.org/ontology/fo/Food" + + def test_expand_rdf_prefix(self, extractor, ontology_subset_with_uris): + """Test that standard RDF prefixes are expanded.""" + result = extractor.expand_uri("rdf:type", ontology_subset_with_uris, "food") + + assert result == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type", \ + "rdf:type should expand to full RDF namespace URI" + + def test_expand_rdfs_prefix(self, extractor, ontology_subset_with_uris): + """Test that RDFS prefixes are expanded.""" + result = extractor.expand_uri("rdfs:label", ontology_subset_with_uris, "food") + + assert result == "http://www.w3.org/2000/01/rdf-schema#label", \ + "rdfs:label should expand to full RDFS namespace URI" + + def test_expand_owl_prefix(self, extractor, ontology_subset_with_uris): + """Test that OWL prefixes are expanded.""" + result = extractor.expand_uri("owl:Class", ontology_subset_with_uris, 
"food") + + assert result == "http://www.w3.org/2002/07/owl#Class", \ + "owl:Class should expand to full OWL namespace URI" + + def test_expand_xsd_prefix(self, extractor, ontology_subset_with_uris): + """Test that XSD prefixes are expanded.""" + result = extractor.expand_uri("xsd:string", ontology_subset_with_uris, "food") + + assert result == "http://www.w3.org/2001/XMLSchema#string", \ + "xsd:string should expand to full XSD namespace URI" + + def test_fallback_uri_for_instance(self, extractor, ontology_subset_with_uris): + """Test that entity instances get constructed URIs when not in ontology.""" + result = extractor.expand_uri("recipe:cornish-pasty", ontology_subset_with_uris, "food") + + # Should construct a URI for the instance + assert result.startswith("https://trustgraph.ai/food/"), \ + "Entity instance should get constructed URI under trustgraph.ai domain" + assert "cornish-pasty" in result.lower(), \ + "Instance URI should include normalized entity name" + + def test_already_full_uri_unchanged(self, extractor, ontology_subset_with_uris): + """Test that full URIs are returned unchanged.""" + full_uri = "http://example.com/custom/entity" + result = extractor.expand_uri(full_uri, ontology_subset_with_uris, "food") + + assert result == full_uri, \ + "Full URIs should be returned unchanged" + + def test_https_uri_unchanged(self, extractor, ontology_subset_with_uris): + """Test that HTTPS URIs are returned unchanged.""" + full_uri = "https://example.com/custom/entity" + result = extractor.expand_uri(full_uri, ontology_subset_with_uris, "food") + + assert result == full_uri, \ + "HTTPS URIs should be returned unchanged" + + def test_class_without_uri_gets_fallback(self, extractor): + """Test that classes without URI definitions get constructed fallback URIs.""" + # Create subset with class that has no URI + subset_no_uri = OntologySubset( + ontology_id="test", + classes={ + "SomeClass": { + "type": "owl:Class", + "labels": [{"value": "Some Class"}], + # No 'uri' field + } + }, + object_properties={}, + datatype_properties={}, + metadata={} + ) + + result = extractor.expand_uri("SomeClass", subset_no_uri, "test") + + assert result == "https://trustgraph.ai/ontology/test#SomeClass", \ + "Class without URI should get fallback constructed URI" + + def test_property_without_uri_gets_fallback(self, extractor): + """Test that properties without URI definitions get constructed fallback URIs.""" + subset_no_uri = OntologySubset( + ontology_id="test", + classes={}, + object_properties={ + "someProperty": { + "type": "owl:ObjectProperty", + # No 'uri' field + } + }, + datatype_properties={}, + metadata={} + ) + + result = extractor.expand_uri("someProperty", subset_no_uri, "test") + + assert result == "https://trustgraph.ai/ontology/test#someProperty", \ + "Property without URI should get fallback constructed URI" + + def test_entity_normalization_in_constructed_uri(self, extractor, ontology_subset_with_uris): + """Test that entity names are normalized when constructing URIs.""" + # Entity with spaces and mixed case + result = extractor.expand_uri("Cornish Pasty Recipe", ontology_subset_with_uris, "food") + + # Should be normalized: lowercase, spaces to hyphens + assert result == "https://trustgraph.ai/food/cornish-pasty-recipe", \ + "Entity names should be normalized (lowercase, spaces to hyphens)" + + def test_dict_access_not_object_attribute(self, extractor, ontology_subset_with_uris): + """Test that URI expansion works with dict access (not object attributes). 
+ + This is the key fix - ontology_selector stores cls.__dict__ which means + we get dicts, not objects, so we must use dict key access. + """ + # The ontology_subset_with_uris uses dicts (with 'uri' key) + # This test verifies we can access it correctly + class_def = ontology_subset_with_uris.classes["Recipe"] + + # Verify it's a dict + assert isinstance(class_def, dict), "Class definitions should be dicts" + assert "uri" in class_def, "Dict should have 'uri' key" + + # Now test expansion works + result = extractor.expand_uri("Recipe", ontology_subset_with_uris, "food") + assert result == class_def["uri"], \ + "URI expansion must work with dict access (not object attributes)" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit/test_query/test_doc_embeddings_pinecone_query.py b/tests/unit/test_query/test_doc_embeddings_pinecone_query.py index ce2a7431..069546fb 100644 --- a/tests/unit/test_query/test_doc_embeddings_pinecone_query.py +++ b/tests/unit/test_query/test_doc_embeddings_pinecone_query.py @@ -119,8 +119,8 @@ class TestPineconeDocEmbeddingsQueryProcessor: chunks = await processor.query_document_embeddings(message) - # Verify index was accessed correctly - expected_index_name = "d-test_user-test_collection" + # Verify index was accessed correctly (with dimension suffix) + expected_index_name = "d-test_user-test_collection-3" # 3 dimensions processor.pinecone.Index.assert_called_once_with(expected_index_name) # Verify query parameters @@ -264,10 +264,12 @@ class TestPineconeDocEmbeddingsQueryProcessor: chunks = await processor.query_document_embeddings(message) - # Verify same index used for both vectors - expected_index_name = "d-test_user-test_collection" + # Verify different indexes used for different dimensions assert processor.pinecone.Index.call_count == 2 - processor.pinecone.Index.assert_called_with(expected_index_name) + index_calls = processor.pinecone.Index.call_args_list + index_names = [call[0][0] for call in index_calls] + assert "d-test_user-test_collection-2" in index_names # 2D vector + assert "d-test_user-test_collection-4" in index_names # 4D vector # Verify both queries were made assert mock_index.query.call_count == 2 diff --git a/tests/unit/test_query/test_doc_embeddings_qdrant_query.py b/tests/unit/test_query/test_doc_embeddings_qdrant_query.py index f4a1d977..ad337f73 100644 --- a/tests/unit/test_query/test_doc_embeddings_qdrant_query.py +++ b/tests/unit/test_query/test_doc_embeddings_qdrant_query.py @@ -103,8 +103,8 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): result = await processor.query_document_embeddings(mock_message) # Assert - # Verify query was called with correct parameters - expected_collection = 'd_test_user_test_collection' + # Verify query was called with correct parameters (with dimension suffix) + expected_collection = 'd_test_user_test_collection_3' # 3 dimensions mock_qdrant_instance.query_points.assert_called_once_with( collection_name=expected_collection, query=[0.1, 0.2, 0.3], @@ -164,9 +164,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Assert # Verify query was called twice assert mock_qdrant_instance.query_points.call_count == 2 - - # Verify both collections were queried - expected_collection = 'd_multi_user_multi_collection' + + # Verify both collections were queried (both 2-dimensional vectors) + expected_collection = 'd_multi_user_multi_collection_2' # 2 dimensions calls = mock_qdrant_instance.query_points.call_args_list assert calls[0][1]['collection_name'] 
== expected_collection assert calls[1][1]['collection_name'] == expected_collection @@ -301,13 +301,13 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Verify query was called twice with different collections assert mock_qdrant_instance.query_points.call_count == 2 calls = mock_qdrant_instance.query_points.call_args_list - + # First call should use 2D collection - assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection' + assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2' # 2 dimensions assert calls[0][1]['query'] == [0.1, 0.2] - + # Second call should use 3D collection - assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection' + assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3' # 3 dimensions assert calls[1][1]['query'] == [0.3, 0.4, 0.5] # Verify results diff --git a/tests/unit/test_query/test_graph_embeddings_pinecone_query.py b/tests/unit/test_query/test_graph_embeddings_pinecone_query.py index dbe9b9fc..930334c7 100644 --- a/tests/unit/test_query/test_graph_embeddings_pinecone_query.py +++ b/tests/unit/test_query/test_graph_embeddings_pinecone_query.py @@ -147,8 +147,8 @@ class TestPineconeGraphEmbeddingsQueryProcessor: entities = await processor.query_graph_embeddings(message) - # Verify index was accessed correctly - expected_index_name = "t-test_user-test_collection" + # Verify index was accessed correctly (with dimension suffix) + expected_index_name = "t-test_user-test_collection-3" # 3 dimensions processor.pinecone.Index.assert_called_once_with(expected_index_name) # Verify query parameters @@ -290,10 +290,12 @@ class TestPineconeGraphEmbeddingsQueryProcessor: entities = await processor.query_graph_embeddings(message) - # Verify same index used for both vectors - expected_index_name = "t-test_user-test_collection" + # Verify different indexes used for different dimensions assert processor.pinecone.Index.call_count == 2 - processor.pinecone.Index.assert_called_with(expected_index_name) + index_calls = processor.pinecone.Index.call_args_list + index_names = [call[0][0] for call in index_calls] + assert "t-test_user-test_collection-2" in index_names # 2D vector + assert "t-test_user-test_collection-4" in index_names # 4D vector # Verify both queries were made assert mock_index.query.call_count == 2 diff --git a/tests/unit/test_query/test_graph_embeddings_qdrant_query.py b/tests/unit/test_query/test_graph_embeddings_qdrant_query.py index 0dd0e94e..ab22c9df 100644 --- a/tests/unit/test_query/test_graph_embeddings_qdrant_query.py +++ b/tests/unit/test_query/test_graph_embeddings_qdrant_query.py @@ -175,8 +175,8 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): result = await processor.query_graph_embeddings(mock_message) # Assert - # Verify query was called with correct parameters - expected_collection = 't_test_user_test_collection' + # Verify query was called with correct parameters (with dimension suffix) + expected_collection = 't_test_user_test_collection_3' # 3 dimensions mock_qdrant_instance.query_points.assert_called_once_with( collection_name=expected_collection, query=[0.1, 0.2, 0.3], @@ -234,9 +234,9 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Assert # Verify query was called twice assert mock_qdrant_instance.query_points.call_count == 2 - - # Verify both collections were queried - expected_collection = 't_multi_user_multi_collection' + + # Verify both collections were queried (both 2-dimensional vectors) + expected_collection = 
't_multi_user_multi_collection_2' # 2 dimensions calls = mock_qdrant_instance.query_points.call_args_list assert calls[0][1]['collection_name'] == expected_collection assert calls[1][1]['collection_name'] == expected_collection @@ -372,13 +372,13 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Verify query was called twice with different collections assert mock_qdrant_instance.query_points.call_count == 2 calls = mock_qdrant_instance.query_points.call_args_list - + # First call should use 2D collection - assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection' + assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2' # 2 dimensions assert calls[0][1]['query'] == [0.1, 0.2] - + # Second call should use 3D collection - assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection' + assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection_3' # 3 dimensions assert calls[1][1]['query'] == [0.3, 0.4, 0.5] # Verify results diff --git a/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py b/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py index 848916f5..41f786d0 100644 --- a/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py @@ -134,8 +134,8 @@ class TestPineconeDocEmbeddingsStorageProcessor: with patch('uuid.uuid4', side_effect=['id1', 'id2']): await processor.store_document_embeddings(message) - # Verify index name and operations - expected_index_name = "d-test_user-test_collection" + # Verify index name and operations (with dimension suffix) + expected_index_name = "d-test_user-test_collection-3" # 3 dimensions processor.pinecone.Index.assert_called_with(expected_index_name) # Verify upsert was called for each vector @@ -179,7 +179,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: @pytest.mark.asyncio async def test_store_document_embeddings_index_validation(self, processor): - """Test that writing to non-existent index raises ValueError""" + """Test that writing to non-existent index creates it lazily""" message = MagicMock() message.metadata = MagicMock() message.metadata.user = 'test_user' @@ -191,12 +191,24 @@ class TestPineconeDocEmbeddingsStorageProcessor: ) message.chunks = [chunk] - # Mock index doesn't exist + # Mock index doesn't exist initially processor.pinecone.has_index.return_value = False + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index - with pytest.raises(ValueError, match="Collection .* does not exist"): + with patch('uuid.uuid4', return_value='test-id'): await processor.store_document_embeddings(message) + # Verify index was created with correct dimension + expected_index_name = "d-test_user-test_collection-3" # 3 dimensions + processor.pinecone.create_index.assert_called_once() + create_call = processor.pinecone.create_index.call_args + assert create_call[1]['name'] == expected_index_name + assert create_call[1]['dimension'] == 3 + + # Verify upsert was still called + mock_index.upsert.assert_called_once() + @pytest.mark.asyncio async def test_store_document_embeddings_empty_chunk(self, processor): """Test storing document embeddings with empty chunk (should be skipped)""" @@ -345,7 +357,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: @pytest.mark.asyncio async def test_store_document_embeddings_validation_before_creation(self, processor): - """Test that validation error occurs before creation attempts""" + """Test that lazy creation happens when index doesn't 
exist""" message = MagicMock() message.metadata = MagicMock() message.metadata.user = 'test_user' @@ -359,13 +371,18 @@ class TestPineconeDocEmbeddingsStorageProcessor: # Mock index doesn't exist processor.pinecone.has_index.return_value = False + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index - with pytest.raises(ValueError, match="Collection .* does not exist"): + with patch('uuid.uuid4', return_value='test-id'): await processor.store_document_embeddings(message) + # Verify index was created + processor.pinecone.create_index.assert_called_once() + @pytest.mark.asyncio async def test_store_document_embeddings_validates_before_timeout(self, processor): - """Test that validation error occurs before timeout checks""" + """Test that lazy creation works correctly""" message = MagicMock() message.metadata = MagicMock() message.metadata.user = 'test_user' @@ -379,10 +396,16 @@ class TestPineconeDocEmbeddingsStorageProcessor: # Mock index doesn't exist processor.pinecone.has_index.return_value = False + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index - with pytest.raises(ValueError, match="Collection .* does not exist"): + with patch('uuid.uuid4', return_value='test-id'): await processor.store_document_embeddings(message) + # Verify index was created and used + processor.pinecone.create_index.assert_called_once() + mock_index.upsert.assert_called_once() + @pytest.mark.asyncio async def test_store_document_embeddings_unicode_content(self, processor): """Test storing document embeddings with Unicode content""" diff --git a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py index ef00c3f9..f99d9883 100644 --- a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py @@ -103,8 +103,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): await processor.store_document_embeddings(mock_message) # Assert - # Verify collection existence was checked - expected_collection = 'd_test_user_test_collection' + # Verify collection existence was checked (with dimension suffix) + expected_collection = 'd_test_user_test_collection_3' # 3 dimensions in vector [0.1, 0.2, 0.3] mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection) # Verify upsert was called @@ -112,7 +112,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Verify upsert parameters upsert_call_args = mock_qdrant_instance.upsert.call_args - assert upsert_call_args[1]['collection_name'] == expected_collection + assert upsert_call_args[1]['collection_name'] == 'd_test_user_test_collection_3' assert len(upsert_call_args[1]['points']) == 1 point = upsert_call_args[1]['points'][0] @@ -272,18 +272,21 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Assert # Should not call upsert for empty chunks mock_qdrant_instance.upsert.assert_not_called() - # But collection_exists should be called for validation - mock_qdrant_instance.collection_exists.assert_called_once() + # collection_exists should NOT be called since we return early for empty chunks + mock_qdrant_instance.collection_exists.assert_not_called() @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') + @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') - async def test_collection_creation_when_not_exists(self, mock_base_init, 
mock_qdrant_client): - """Test that writing to non-existent collection raises ValueError""" + async def test_collection_creation_when_not_exists(self, mock_base_init, mock_uuid, mock_qdrant_client): + """Test that writing to non-existent collection creates it lazily""" # Arrange mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist mock_qdrant_client.return_value = mock_qdrant_instance + mock_uuid.uuid4.return_value = MagicMock() + mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid') config = { 'store_uri': 'http://localhost:6333', @@ -305,19 +308,36 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.chunks = [mock_chunk] - # Act & Assert - with pytest.raises(ValueError, match="Collection .* does not exist"): - await processor.store_document_embeddings(mock_message) + # Act + await processor.store_document_embeddings(mock_message) + + # Assert - collection should be lazily created + expected_collection = 'd_new_user_new_collection_5' # 5 dimensions + mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection) + mock_qdrant_instance.create_collection.assert_called_once() + + # Verify create_collection was called with correct parameters + create_call = mock_qdrant_instance.create_collection.call_args + assert create_call[1]['collection_name'] == expected_collection + assert create_call[1]['vectors_config'].size == 5 + + # Verify upsert was still called + mock_qdrant_instance.upsert.assert_called_once() @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') + @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') - async def test_collection_creation_exception(self, mock_base_init, mock_qdrant_client): - """Test that validation error occurs before connection errors""" + async def test_collection_creation_exception(self, mock_base_init, mock_uuid, mock_qdrant_client): + """Test that collection creation errors are propagated""" # Arrange mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist + # Simulate creation failure + mock_qdrant_instance.create_collection.side_effect = Exception("Connection error") mock_qdrant_client.return_value = mock_qdrant_instance + mock_uuid.uuid4.return_value = MagicMock() + mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid') config = { 'store_uri': 'http://localhost:6333', @@ -339,8 +359,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_message.chunks = [mock_chunk] - # Act & Assert - with pytest.raises(ValueError, match="Collection .* does not exist"): + # Act & Assert - should propagate the creation error + with pytest.raises(Exception, match="Connection error"): await processor.store_document_embeddings(mock_message) @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @@ -398,7 +418,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): await processor.store_document_embeddings(mock_message2) # Assert - expected_collection = 'd_cache_user_cache_collection' + expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions # Verify collection existence is checked on each write mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection) @@ -407,15 +427,18 @@ class 
TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_qdrant_instance.upsert.assert_called_once() @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') + @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') - async def test_different_dimensions_different_collections(self, mock_base_init, mock_qdrant_client): + async def test_different_dimensions_different_collections(self, mock_base_init, mock_uuid, mock_qdrant_client): """Test that different vector dimensions create different collections""" # Arrange mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = True mock_qdrant_client.return_value = mock_qdrant_instance - + mock_uuid.uuid4.return_value = MagicMock() + mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid') + config = { 'store_uri': 'http://localhost:6333', 'api_key': 'test-api-key', @@ -424,35 +447,39 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - + # Create mock message with different dimension vectors mock_message = MagicMock() mock_message.metadata.user = 'dim_user' mock_message.metadata.collection = 'dim_collection' - + mock_chunk = MagicMock() mock_chunk.chunk.decode.return_value = 'dimension test chunk' mock_chunk.vectors = [ [0.1, 0.2], # 2 dimensions [0.3, 0.4, 0.5] # 3 dimensions ] - + mock_message.chunks = [mock_chunk] - + # Act await processor.store_document_embeddings(mock_message) # Assert - # Should check existence of the same collection (dimensions no longer create separate collections) - expected_collection = 'd_dim_user_dim_collection' - mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection) + # Should check existence of DIFFERENT collections for each dimension + assert mock_qdrant_instance.collection_exists.call_count == 2 - # Should upsert to the same collection for both vectors + # Verify the two different collection names were checked + collection_exists_calls = [call[0][0] for call in mock_qdrant_instance.collection_exists.call_args_list] + assert 'd_dim_user_dim_collection_2' in collection_exists_calls # 2-dim vector + assert 'd_dim_user_dim_collection_3' in collection_exists_calls # 3-dim vector + + # Should upsert to different collections for each vector assert mock_qdrant_instance.upsert.call_count == 2 upsert_calls = mock_qdrant_instance.upsert.call_args_list - assert upsert_calls[0][1]['collection_name'] == expected_collection - assert upsert_calls[1][1]['collection_name'] == expected_collection + assert upsert_calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2' + assert upsert_calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3' @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') diff --git a/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py b/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py index 854a03b2..74260c1b 100644 --- a/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py +++ b/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py @@ -134,8 +134,8 @@ class TestPineconeGraphEmbeddingsStorageProcessor: with patch('uuid.uuid4', side_effect=['id1', 'id2']): await processor.store_graph_embeddings(message) - # Verify index name and operations - expected_index_name = "t-test_user-test_collection" + # Verify index name 
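# [Editor's sketch] The Qdrant tests above all expect dimension-suffixed
# collection names of the form "<prefix>_<user>_<collection>_<dim>". A
# minimal sketch of that naming rule; the helper name is hypothetical,
# but the resulting strings match the assertions in the tests.
def qdrant_collection_name(prefix, user, collection, vector):
    return f"{prefix}_{user}_{collection}_{len(vector)}"

assert qdrant_collection_name(
    "d", "test_user", "test_collection", [0.1, 0.2, 0.3]
) == "d_test_user_test_collection_3"
assert qdrant_collection_name(
    "d", "dim_user", "dim_collection", [0.1, 0.2]
) == "d_dim_user_dim_collection_2"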
and operations (with dimension suffix) + expected_index_name = "t-test_user-test_collection-3" # 3 dimensions processor.pinecone.Index.assert_called_with(expected_index_name) # Verify upsert was called for each vector @@ -179,7 +179,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: @pytest.mark.asyncio async def test_store_graph_embeddings_index_validation(self, processor): - """Test that writing to non-existent index raises ValueError""" + """Test that writing to non-existent index creates it lazily""" message = MagicMock() message.metadata = MagicMock() message.metadata.user = 'test_user' @@ -193,10 +193,22 @@ class TestPineconeGraphEmbeddingsStorageProcessor: # Mock index doesn't exist processor.pinecone.has_index.return_value = False + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index - with pytest.raises(ValueError, match="Collection .* does not exist"): + with patch('uuid.uuid4', return_value='test-id'): await processor.store_graph_embeddings(message) + # Verify index was created with correct dimension + expected_index_name = "t-test_user-test_collection-3" # 3 dimensions + processor.pinecone.create_index.assert_called_once() + create_call = processor.pinecone.create_index.call_args + assert create_call[1]['name'] == expected_index_name + assert create_call[1]['dimension'] == 3 + + # Verify upsert was still called + mock_index.upsert.assert_called_once() + @pytest.mark.asyncio async def test_store_graph_embeddings_empty_entity_value(self, processor): """Test storing graph embeddings with empty entity value (should be skipped)""" @@ -267,11 +279,16 @@ class TestPineconeGraphEmbeddingsStorageProcessor: with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']): await processor.store_graph_embeddings(message) - # Verify same index was used for all dimensions - expected_index_name = 't-test_user-test_collection' - processor.pinecone.Index.assert_called_with(expected_index_name) + # Verify different indexes were used for different dimensions + index_calls = processor.pinecone.Index.call_args_list + assert len(index_calls) == 3 + # Extract index names from calls + index_names = [call[0][0] for call in index_calls] + assert 't-test_user-test_collection-2' in index_names # 2D vector + assert 't-test_user-test_collection-4' in index_names # 4D vector + assert 't-test_user-test_collection-3' in index_names # 3D vector - # Verify all vectors were upserted to the same index + # Verify all vectors were upserted (to their respective indexes) assert mock_index.upsert.call_count == 3 @pytest.mark.asyncio @@ -316,7 +333,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: @pytest.mark.asyncio async def test_store_graph_embeddings_validation_before_creation(self, processor): - """Test that validation error occurs before any creation attempts""" + """Test that lazy creation happens when index doesn't exist""" message = MagicMock() message.metadata = MagicMock() message.metadata.user = 'test_user' @@ -330,13 +347,18 @@ class TestPineconeGraphEmbeddingsStorageProcessor: # Mock index doesn't exist processor.pinecone.has_index.return_value = False + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index - with pytest.raises(ValueError, match="Collection .* does not exist"): + with patch('uuid.uuid4', return_value='test-id'): await processor.store_graph_embeddings(message) + # Verify index was created + processor.pinecone.create_index.assert_called_once() + @pytest.mark.asyncio async def test_store_graph_embeddings_validates_before_timeout(self, 
processor): - """Test that validation error occurs before timeout checks""" + """Test that lazy creation works correctly""" message = MagicMock() message.metadata = MagicMock() message.metadata.user = 'test_user' @@ -350,10 +372,16 @@ class TestPineconeGraphEmbeddingsStorageProcessor: # Mock index doesn't exist processor.pinecone.has_index.return_value = False + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index - with pytest.raises(ValueError, match="Collection .* does not exist"): + with patch('uuid.uuid4', return_value='test-id'): await processor.store_graph_embeddings(message) + # Verify index was created and used + processor.pinecone.create_index.assert_called_once() + mock_index.upsert.assert_called_once() + def test_add_args_method(self): """Test that add_args properly configures argument parser""" from argparse import ArgumentParser diff --git a/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py index 4e7b492d..c4b603c9 100644 --- a/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py @@ -44,29 +44,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): assert hasattr(processor, 'qdrant') assert processor.qdrant == mock_qdrant_instance - @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient') - @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__') - async def test_get_collection_validates_existence(self, mock_base_init, mock_qdrant_client): - """Test get_collection validates that collection exists""" - # Arrange - mock_base_init.return_value = None - mock_qdrant_instance = MagicMock() - mock_qdrant_instance.collection_exists.return_value = False - mock_qdrant_client.return_value = mock_qdrant_instance - - config = { - 'store_uri': 'http://localhost:6333', - 'api_key': 'test-api-key', - 'taskgroup': AsyncMock(), - 'id': 'test-qdrant-processor' - } - - processor = Processor(**config) - - # Act & Assert - with pytest.raises(ValueError, match="Collection .* does not exist"): - processor.get_collection(user='test_user', collection='test_collection') - @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid') @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__') @@ -103,114 +80,22 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): await processor.store_graph_embeddings(mock_message) # Assert - # Verify collection existence was checked - expected_collection = 't_test_user_test_collection' + # Verify collection existence was checked (with dimension suffix) + expected_collection = 't_test_user_test_collection_3' # 3 dimensions in vector [0.1, 0.2, 0.3] mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection) - + # Verify upsert was called mock_qdrant_instance.upsert.assert_called_once() - + # Verify upsert parameters upsert_call_args = mock_qdrant_instance.upsert.call_args - assert upsert_call_args[1]['collection_name'] == expected_collection + assert upsert_call_args[1]['collection_name'] == 't_test_user_test_collection_3' assert len(upsert_call_args[1]['points']) == 1 point = upsert_call_args[1]['points'][0] assert point.vector == [0.1, 0.2, 0.3] assert point.payload['entity'] == 'test_entity' - @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient') - @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__') - async 
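# [Editor's sketch] The Pinecone graph-embeddings tests above assert
# per-dimension index routing: vectors of length 2, 3 and 4 land in
# "t-<user>-<collection>-2", "-3" and "-4" respectively, with missing
# indexes created lazily. A minimal sketch of the routing step only;
# route_by_dimension is a hypothetical helper, not the implementation.
from collections import defaultdict

def route_by_dimension(vectors, user="test_user", collection="test_collection"):
    routed = defaultdict(list)
    for v in vectors:
        # One index name per vector dimension, as the tests expect
        routed[f"t-{user}-{collection}-{len(v)}"].append(v)
    return routed

routed = route_by_dimension([[0.1, 0.2], [0.1, 0.2, 0.3], [0.1, 0.2, 0.3, 0.4]])
assert sorted(routed) == [
    "t-test_user-test_collection-2",
    "t-test_user-test_collection-3",
    "t-test_user-test_collection-4",
]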
def test_get_collection_uses_existing_collection(self, mock_base_init, mock_qdrant_client): - """Test get_collection uses existing collection without creating new one""" - # Arrange - mock_base_init.return_value = None - mock_qdrant_instance = MagicMock() - mock_qdrant_instance.collection_exists.return_value = True # Collection exists - mock_qdrant_client.return_value = mock_qdrant_instance - - config = { - 'store_uri': 'http://localhost:6333', - 'api_key': 'test-api-key', - 'taskgroup': AsyncMock(), - 'id': 'test-qdrant-processor' - } - - processor = Processor(**config) - - # Act - collection_name = processor.get_collection(user='existing_user', collection='existing_collection') - - # Assert - expected_name = 't_existing_user_existing_collection' - assert collection_name == expected_name - - # Verify collection existence check was performed - mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name) - # Verify create_collection was NOT called - mock_qdrant_instance.create_collection.assert_not_called() - - @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient') - @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__') - async def test_get_collection_validates_on_each_call(self, mock_base_init, mock_qdrant_client): - """Test get_collection validates collection existence on each call""" - # Arrange - mock_base_init.return_value = None - mock_qdrant_instance = MagicMock() - mock_qdrant_instance.collection_exists.return_value = True - mock_qdrant_client.return_value = mock_qdrant_instance - - config = { - 'store_uri': 'http://localhost:6333', - 'api_key': 'test-api-key', - 'taskgroup': AsyncMock(), - 'id': 'test-qdrant-processor' - } - - processor = Processor(**config) - - # First call - collection_name1 = processor.get_collection(user='cache_user', collection='cache_collection') - - # Reset mock to track second call - mock_qdrant_instance.reset_mock() - mock_qdrant_instance.collection_exists.return_value = True - - # Act - Second call with same parameters - collection_name2 = processor.get_collection(user='cache_user', collection='cache_collection') - - # Assert - expected_name = 't_cache_user_cache_collection' - assert collection_name1 == expected_name - assert collection_name2 == expected_name - - # Verify collection existence check happens on each call - mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name) - mock_qdrant_instance.create_collection.assert_not_called() - - @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient') - @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__') - async def test_get_collection_creation_exception(self, mock_base_init, mock_qdrant_client): - """Test get_collection raises ValueError when collection doesn't exist""" - # Arrange - mock_base_init.return_value = None - mock_qdrant_instance = MagicMock() - mock_qdrant_instance.collection_exists.return_value = False - mock_qdrant_client.return_value = mock_qdrant_instance - - config = { - 'store_uri': 'http://localhost:6333', - 'api_key': 'test-api-key', - 'taskgroup': AsyncMock(), - 'id': 'test-qdrant-processor' - } - - processor = Processor(**config) - - # Act & Assert - with pytest.raises(ValueError, match="Collection .* does not exist"): - processor.get_collection(user='error_user', collection='error_collection') - @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid') @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__') diff 
--git a/tests/unit/test_text_completion/test_openai_processor.py b/tests/unit/test_text_completion/test_openai_processor.py index a9a43b37..352af062 100644 --- a/tests/unit/test_text_completion/test_openai_processor.py +++ b/tests/unit/test_text_completion/test_openai_processor.py @@ -102,11 +102,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase): }] }], temperature=0.0, - max_tokens=4096, - top_p=1, - frequency_penalty=0, - presence_penalty=0, - response_format={"type": "text"} + max_tokens=4096 ) @patch('trustgraph.model.text_completion.openai.llm.OpenAI') @@ -385,10 +381,6 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase): assert call_args[1]['model'] == 'gpt-3.5-turbo' assert call_args[1]['temperature'] == 0.5 assert call_args[1]['max_tokens'] == 1024 - assert call_args[1]['top_p'] == 1 - assert call_args[1]['frequency_penalty'] == 0 - assert call_args[1]['presence_penalty'] == 0 - assert call_args[1]['response_format'] == {"type": "text"} @patch('trustgraph.model.text_completion.openai.llm.OpenAI') diff --git a/trustgraph-base/trustgraph/base/embeddings_service.py b/trustgraph-base/trustgraph/base/embeddings_service.py index 556d32ff..a1442d41 100644 --- a/trustgraph-base/trustgraph/base/embeddings_service.py +++ b/trustgraph-base/trustgraph/base/embeddings_service.py @@ -9,7 +9,7 @@ from prometheus_client import Histogram from .. schema import EmbeddingsRequest, EmbeddingsResponse, Error from .. exceptions import TooManyRequests -from .. base import FlowProcessor, ConsumerSpec, ProducerSpec +from .. base import FlowProcessor, ConsumerSpec, ProducerSpec, ParameterSpec # Module logger logger = logging.getLogger(__name__) @@ -45,6 +45,12 @@ class EmbeddingsService(FlowProcessor): ) ) + self.register_specification( + ParameterSpec( + name = "model", + ) + ) + async def on_request(self, msg, consumer, flow): try: @@ -57,7 +63,9 @@ class EmbeddingsService(FlowProcessor): logger.debug(f"Handling embeddings request {id}...") - vectors = await self.on_embeddings(request.text) + # Pass model from request if specified (non-empty), otherwise use default + model = flow("model") + vectors = await self.on_embeddings(request.text, model=model) await flow("response").send( EmbeddingsResponse( diff --git a/trustgraph-base/trustgraph/rdf.py b/trustgraph-base/trustgraph/rdf.py index ef1da183..32799b8d 100644 --- a/trustgraph-base/trustgraph/rdf.py +++ b/trustgraph-base/trustgraph/rdf.py @@ -1,4 +1,5 @@ +RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" RDF_LABEL = "http://www.w3.org/2000/01/rdf-schema#label" DEFINITION = "http://www.w3.org/2004/02/skos/core#definition" SUBJECT_OF = "https://schema.org/subjectOf" diff --git a/trustgraph-bedrock/pyproject.toml b/trustgraph-bedrock/pyproject.toml index 8f23081c..b90edac6 100644 --- a/trustgraph-bedrock/pyproject.toml +++ b/trustgraph-bedrock/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=1.4,<1.5", + "trustgraph-base>=1.5,<1.6", "pulsar-client", "prometheus-client", "boto3", diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index ae69abdd..3b9f197b 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=1.4,<1.5", + "trustgraph-base>=1.5,<1.6", 
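# [Editor's sketch] The embeddings_service.py change above threads an
# optional "model" flow parameter through to on_embeddings; the provider
# diffs below all adopt the same fallback pattern. A minimal sketch
# under that assumption; ExampleEmbeddings and the model names are
# illustrative stand-ins, not real classes or models from the patch.
import asyncio

class ExampleEmbeddings:
    def __init__(self, model):
        self.default_model = model

    async def on_embeddings(self, text, model=None):
        use_model = model or self.default_model
        return use_model  # a real provider would embed text with use_model

svc = ExampleEmbeddings("all-MiniLM-L6-v2")
assert asyncio.run(svc.on_embeddings("hi")) == "all-MiniLM-L6-v2"
assert asyncio.run(svc.on_embeddings("hi", model="other-model")) == "other-model"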
"requests", "pulsar-client", "aiohttp", diff --git a/trustgraph-cli/trustgraph/cli/set_mcp_tool.py b/trustgraph-cli/trustgraph/cli/set_mcp_tool.py index b48c6d86..05e3823c 100644 --- a/trustgraph-cli/trustgraph/cli/set_mcp_tool.py +++ b/trustgraph-cli/trustgraph/cli/set_mcp_tool.py @@ -7,6 +7,7 @@ specification. This script stores MCP tool configurations with: - id: Unique identifier for the tool - remote-name: Name used by the MCP server (defaults to id) - url: MCP server endpoint URL +- auth-token: Optional bearer token for authentication Configurations are stored in the 'mcp' configuration group and can be referenced by agent tools using the 'mcp-tool' type. @@ -25,17 +26,24 @@ def set_mcp_tool( id : str, remote_name : str, tool_url : str, + auth_token : str = None, ): api = Api(url).config() + # Build the MCP tool configuration + config = { + "remote-name": remote_name, + "url": tool_url, + } + + if auth_token: + config["auth-token"] = auth_token + # Store the MCP tool configuration in the 'mcp' group values = api.put([ ConfigValue( - type="mcp", key=id, value=json.dumps({ - "remote-name": remote_name, - "url": tool_url, - }) + type="mcp", key=id, value=json.dumps(config) ) ]) @@ -45,12 +53,15 @@ def main(): prog='tg-set-mcp-tool', description=__doc__, epilog=textwrap.dedent(''' - MCP tools are configured with just a name and URL. The URL should point + MCP tools are configured with a name and URL. The URL should point to the MCP server endpoint that provides the tool functionality. - + Optionally, an auth-token can be provided for secured endpoints. + Examples: %(prog)s --id weather --tool-url "http://localhost:3000/weather" %(prog)s --id calculator --tool-url "http://mcp-tools.example.com/calc" + %(prog)s --id secure-tool --tool-url "https://api.example.com/mcp" \\ + --auth-token "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." ''').strip(), formatter_class=argparse.RawDescriptionHelpFormatter ) @@ -79,6 +90,12 @@ def main(): help='MCP tool URL endpoint', ) + parser.add_argument( + '--auth-token', + required=False, + help='Bearer token for authentication (optional)', + ) + args = parser.parse_args() try: @@ -98,7 +115,8 @@ def main(): url=args.api_url, id=args.id, remote_name=remote_name, - tool_url=args.tool_url + tool_url=args.tool_url, + auth_token=args.auth_token ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_mcp_tools.py b/trustgraph-cli/trustgraph/cli/show_mcp_tools.py index c22b69ed..da0154ed 100644 --- a/trustgraph-cli/trustgraph/cli/show_mcp_tools.py +++ b/trustgraph-cli/trustgraph/cli/show_mcp_tools.py @@ -27,6 +27,12 @@ def show_config(url): table.append(("remote-name", data["remote-name"])) table.append(("url", data["url"])) + # Display auth status (masked for security) + if "auth-token" in data and data["auth-token"]: + table.append(("auth", "Yes (configured)")) + else: + table.append(("auth", "No")) + print() print(tabulate.tabulate( diff --git a/trustgraph-embeddings-hf/pyproject.toml b/trustgraph-embeddings-hf/pyproject.toml index c1d105c5..39e03aff 100644 --- a/trustgraph-embeddings-hf/pyproject.toml +++ b/trustgraph-embeddings-hf/pyproject.toml @@ -10,8 +10,8 @@ description = "HuggingFace embeddings support for TrustGraph." 
readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=1.4,<1.5", - "trustgraph-flow>=1.4,<1.5", + "trustgraph-base>=1.5,<1.6", + "trustgraph-flow>=1.5,<1.6", "torch", "urllib3", "transformers", diff --git a/trustgraph-embeddings-hf/trustgraph/embeddings/hf/hf.py b/trustgraph-embeddings-hf/trustgraph/embeddings/hf/hf.py index f1abbfae..8c4d571b 100755 --- a/trustgraph-embeddings-hf/trustgraph/embeddings/hf/hf.py +++ b/trustgraph-embeddings-hf/trustgraph/embeddings/hf/hf.py @@ -26,10 +26,31 @@ class Processor(EmbeddingsService): **params | { "model": model } ) - logger.info(f"Loading HuggingFace embeddings model: {model}") - self.embeddings = HuggingFaceEmbeddings(model_name=model) + self.default_model = model - async def on_embeddings(self, text): + # Cache for currently loaded model + self.cached_model_name = None + self.embeddings = None + + # Load the default model + self._load_model(model) + + def _load_model(self, model_name): + """Load a model, caching it for reuse""" + if self.cached_model_name != model_name: + logger.info(f"Loading HuggingFace embeddings model: {model_name}") + self.embeddings = HuggingFaceEmbeddings(model_name=model_name) + self.cached_model_name = model_name + logger.info(f"HuggingFace model {model_name} loaded successfully") + else: + logger.debug(f"Using cached model: {model_name}") + + async def on_embeddings(self, text, model=None): + + use_model = model or self.default_model + + # Reload model if it has changed + self._load_model(use_model) embeds = self.embeddings.embed_documents([text]) logger.debug("Embeddings generation complete") diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index c1ecd346..452ebddf 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -10,12 +10,13 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=1.4,<1.5", + "trustgraph-base>=1.5,<1.6", "aiohttp", "anthropic", - "cassandra-driver", + "scylla-driver", "cohere", "cryptography", + "faiss-cpu", "falkordb", "fastembed", "google-genai", @@ -29,6 +30,7 @@ dependencies = [ "minio", "mistralai", "neo4j", + "nltk", "ollama", "openai", "pinecone[grpc]", @@ -82,6 +84,7 @@ kg-extract-definitions = "trustgraph.extract.kg.definitions:run" kg-extract-objects = "trustgraph.extract.kg.objects:run" kg-extract-relationships = "trustgraph.extract.kg.relationships:run" kg-extract-topics = "trustgraph.extract.kg.topics:run" +kg-extract-ontology = "trustgraph.extract.kg.ontology:run" kg-manager = "trustgraph.cores:run" kg-store = "trustgraph.storage.knowledge:run" librarian = "trustgraph.librarian:run" diff --git a/trustgraph-flow/trustgraph/agent/mcp_tool/service.py b/trustgraph-flow/trustgraph/agent/mcp_tool/service.py index 96ff73f7..3858d06b 100755 --- a/trustgraph-flow/trustgraph/agent/mcp_tool/service.py +++ b/trustgraph-flow/trustgraph/agent/mcp_tool/service.py @@ -56,10 +56,16 @@ class Service(ToolService): else: remote_name = name + # Build headers with optional bearer token + headers = {} + if "auth-token" in self.mcp_services[name]: + token = self.mcp_services[name]["auth-token"] + headers["Authorization"] = f"Bearer {token}" + logger.info(f"Invoking {remote_name} at {url}") - # Connect to a streamable HTTP server - async with streamablehttp_client(url) as ( + # Connect to a streamable HTTP server with headers + async with streamablehttp_client(url, headers=headers) as ( read_stream, 
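# [Editor's sketch] End-to-end shape of the auth-token feature added in
# this patch: tg-set-mcp-tool stores an optional "auth-token" key in the
# tool's JSON config, and the MCP tool service turns it into a Bearer
# Authorization header. Values here are illustrative only.
import json

stored = json.dumps({
    "remote-name": "secure-tool",
    "url": "https://api.example.com/mcp",
    "auth-token": "example-token",  # only present when --auth-token was given
})

config = json.loads(stored)
headers = {}
if "auth-token" in config:
    headers["Authorization"] = f"Bearer {config['auth-token']}"

assert headers == {"Authorization": "Bearer example-token"}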
write_stream, _, diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index 06bf7610..30b2df7a 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py +++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -317,7 +317,7 @@ class Processor(AgentService): AgentStep( thought=h.thought, action=h.name, - arguments=h.arguments, + arguments={k: str(v) for k, v in h.arguments.items()}, observation=h.observation ) for h in history diff --git a/trustgraph-flow/trustgraph/direct/cassandra_kg.py b/trustgraph-flow/trustgraph/direct/cassandra_kg.py index dc0f2c2b..116abe02 100644 --- a/trustgraph-flow/trustgraph/direct/cassandra_kg.py +++ b/trustgraph-flow/trustgraph/direct/cassandra_kg.py @@ -334,13 +334,14 @@ class KnowledgeGraph: count += 1 - # Execute batch every 100 triples to avoid oversized batches - if count % 100 == 0: + # Execute batch every 25 triples to avoid oversized batches + # (Each triple adds ~4 statements, so 25 triples = ~100 statements) + if count % 25 == 0: self.session.execute(batch) batch = BatchStatement() # Execute remaining deletions - if count % 100 != 0: + if count % 25 != 0: self.session.execute(batch) # Step 3: Delete collection metadata diff --git a/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py b/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py index a96d06df..4047a9e3 100644 --- a/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py +++ b/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py @@ -50,24 +50,26 @@ class DocVectors: logger.debug(f"Reload at {self.next_reload}") def collection_exists(self, user, collection): - """Check if collection exists (dimension-independent check)""" - collection_name = make_safe_collection_name(user, collection, self.prefix) - return self.client.has_collection(collection_name) + """ + Check if any collection exists for this user/collection combination. + Since collections are dimension-specific, this checks if ANY dimension variant exists. + """ + base_name = make_safe_collection_name(user, collection, self.prefix) + prefix = f"{base_name}_" + all_collections = self.client.list_collections() + return any(coll.startswith(prefix) for coll in all_collections) def create_collection(self, user, collection, dimension=384): - """Create collection with default dimension""" - collection_name = make_safe_collection_name(user, collection, self.prefix) - - if self.client.has_collection(collection_name): - logger.info(f"Collection {collection_name} already exists") - return - - self.init_collection(dimension, user, collection) - logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}") + """ + No-op for explicit collection creation. + Collections are created lazily on first insert with actual dimension. 
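# [Editor's note] Worked arithmetic behind the cassandra_kg.py batch-size
# change above: each triple expands to roughly 4 delete statements, so
# flushing every 25 triples keeps batches near the ~100-statement target
# that the old every-100-triples threshold exceeded fourfold.
STATEMENTS_PER_TRIPLE = 4   # approximate, per the comment in the diff
FLUSH_EVERY = 25            # triples per batch after the change

assert FLUSH_EVERY * STATEMENTS_PER_TRIPLE == 100   # new: ~100 statements/batch
assert 100 * STATEMENTS_PER_TRIPLE == 400           # old: ~400 statements/batch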
+ """ + logger.info(f"Collection creation requested for {user}/{collection} - will be created lazily on first insert") def init_collection(self, dimension, user, collection): - collection_name = make_safe_collection_name(user, collection, self.prefix) + base_name = make_safe_collection_name(user, collection, self.prefix) + collection_name = f"{base_name}_{dimension}" pkey_field = FieldSchema( name="id", @@ -115,6 +117,7 @@ class DocVectors: ) self.collections[(dimension, user, collection)] = collection_name + logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}") def insert(self, embeds, doc, user, collection): @@ -139,8 +142,15 @@ class DocVectors: dim = len(embeds) + # Check if collection exists - return empty if not if (dim, user, collection) not in self.collections: - self.init_collection(dim, user, collection) + base_name = make_safe_collection_name(user, collection, self.prefix) + collection_name = f"{base_name}_{dim}" + if not self.client.has_collection(collection_name): + logger.info(f"Collection {collection_name} does not exist, returning empty results") + return [] + # Collection exists but not in cache, add it + self.collections[(dim, user, collection)] = collection_name coll = self.collections[(dim, user, collection)] @@ -172,19 +182,27 @@ class DocVectors: return res def delete_collection(self, user, collection): - """Delete a collection for the given user and collection""" - collection_name = make_safe_collection_name(user, collection, self.prefix) + """ + Delete all dimension variants of the collection for the given user/collection. + Since collections are created with dimension suffixes, we need to find and delete all. + """ + base_name = make_safe_collection_name(user, collection, self.prefix) + prefix = f"{base_name}_" - # Check if collection exists - if self.client.has_collection(collection_name): - # Drop the collection - self.client.drop_collection(collection_name) - logger.info(f"Deleted Milvus collection: {collection_name}") + # Get all collections and filter for matches + all_collections = self.client.list_collections() + matching_collections = [coll for coll in all_collections if coll.startswith(prefix)] - # Remove from our local cache - keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection] - for key in keys_to_remove: - del self.collections[key] + if not matching_collections: + logger.info(f"No collections found matching prefix {prefix}") else: - logger.info(f"Collection {collection_name} does not exist, nothing to delete") + for collection_name in matching_collections: + self.client.drop_collection(collection_name) + logger.info(f"Deleted Milvus collection: {collection_name}") + logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}") + + # Remove from our local cache + keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection] + for key in keys_to_remove: + del self.collections[key] diff --git a/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py b/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py index b3ed2a9f..4a106f27 100644 --- a/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py +++ b/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py @@ -50,24 +50,26 @@ class EntityVectors: logger.debug(f"Reload at {self.next_reload}") def collection_exists(self, user, collection): - """Check if collection exists (dimension-independent check)""" - collection_name = 
make_safe_collection_name(user, collection, self.prefix) - return self.client.has_collection(collection_name) + """ + Check if any collection exists for this user/collection combination. + Since collections are dimension-specific, this checks if ANY dimension variant exists. + """ + base_name = make_safe_collection_name(user, collection, self.prefix) + prefix = f"{base_name}_" + all_collections = self.client.list_collections() + return any(coll.startswith(prefix) for coll in all_collections) def create_collection(self, user, collection, dimension=384): - """Create collection with default dimension""" - collection_name = make_safe_collection_name(user, collection, self.prefix) - - if self.client.has_collection(collection_name): - logger.info(f"Collection {collection_name} already exists") - return - - self.init_collection(dimension, user, collection) - logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}") + """ + No-op for explicit collection creation. + Collections are created lazily on first insert with actual dimension. + """ + logger.info(f"Collection creation requested for {user}/{collection} - will be created lazily on first insert") def init_collection(self, dimension, user, collection): - collection_name = make_safe_collection_name(user, collection, self.prefix) + base_name = make_safe_collection_name(user, collection, self.prefix) + collection_name = f"{base_name}_{dimension}" pkey_field = FieldSchema( name="id", @@ -115,6 +117,7 @@ class EntityVectors: ) self.collections[(dimension, user, collection)] = collection_name + logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}") def insert(self, embeds, entity, user, collection): @@ -139,8 +142,15 @@ class EntityVectors: dim = len(embeds) + # Check if collection exists - return empty if not if (dim, user, collection) not in self.collections: - self.init_collection(dim, user, collection) + base_name = make_safe_collection_name(user, collection, self.prefix) + collection_name = f"{base_name}_{dim}" + if not self.client.has_collection(collection_name): + logger.info(f"Collection {collection_name} does not exist, returning empty results") + return [] + # Collection exists but not in cache, add it + self.collections[(dim, user, collection)] = collection_name coll = self.collections[(dim, user, collection)] @@ -172,19 +182,27 @@ class EntityVectors: return res def delete_collection(self, user, collection): - """Delete a collection for the given user and collection""" - collection_name = make_safe_collection_name(user, collection, self.prefix) + """ + Delete all dimension variants of the collection for the given user/collection. + Since collections are created with dimension suffixes, we need to find and delete all. 
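# [Editor's sketch] The Milvus existence/deletion logic above works by
# prefix-scanning dimension variants: "exists" means any "<base>_<dim>"
# collection exists, and delete drops every variant. A minimal sketch of
# that scan with a stubbed collection list; the real code calls
# self.client.list_collections().
def any_variant_exists(all_collections, base_name):
    prefix = f"{base_name}_"
    return any(c.startswith(prefix) for c in all_collections)

collections = ["d_alice_docs_384", "d_alice_docs_768", "d_bob_docs_384"]
assert any_variant_exists(collections, "d_alice_docs")

to_drop = [c for c in collections if c.startswith("d_alice_docs_")]
assert to_drop == ["d_alice_docs_384", "d_alice_docs_768"]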
+ """ + base_name = make_safe_collection_name(user, collection, self.prefix) + prefix = f"{base_name}_" - # Check if collection exists - if self.client.has_collection(collection_name): - # Drop the collection - self.client.drop_collection(collection_name) - logger.info(f"Deleted Milvus collection: {collection_name}") + # Get all collections and filter for matches + all_collections = self.client.list_collections() + matching_collections = [coll for coll in all_collections if coll.startswith(prefix)] - # Remove from our local cache - keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection] - for key in keys_to_remove: - del self.collections[key] + if not matching_collections: + logger.info(f"No collections found matching prefix {prefix}") else: - logger.info(f"Collection {collection_name} does not exist, nothing to delete") + for collection_name in matching_collections: + self.client.drop_collection(collection_name) + logger.info(f"Deleted Milvus collection: {collection_name}") + logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}") + + # Remove from our local cache + keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection] + for key in keys_to_remove: + del self.collections[key] diff --git a/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py index 0357e4a3..d1ce93ca 100755 --- a/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py +++ b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py @@ -27,10 +27,31 @@ class Processor(EmbeddingsService): **params | { "model": model } ) - logger.info("Loading FastEmbed model...") - self.embeddings = TextEmbedding(model_name = model) + self.default_model = model - async def on_embeddings(self, text): + # Cache for currently loaded model + self.cached_model_name = None + self.embeddings = None + + # Load the default model + self._load_model(model) + + def _load_model(self, model_name): + """Load a model, caching it for reuse""" + if self.cached_model_name != model_name: + logger.info(f"Loading FastEmbed model: {model_name}") + self.embeddings = TextEmbedding(model_name=model_name) + self.cached_model_name = model_name + logger.info(f"FastEmbed model {model_name} loaded successfully") + else: + logger.debug(f"Using cached model: {model_name}") + + async def on_embeddings(self, text, model=None): + + use_model = model or self.default_model + + # Reload model if it has changed + self._load_model(use_model) vecs = self.embeddings.embed([text]) diff --git a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py index 3c0776f9..c951252e 100755 --- a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py +++ b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py @@ -28,12 +28,14 @@ class Processor(EmbeddingsService): ) self.client = Client(host=ollama) - self.model = model + self.default_model = model - async def on_embeddings(self, text): + async def on_embeddings(self, text, model=None): + + use_model = model or self.default_model embeds = self.client.embed( - model = self.model, + model = use_model, input = text ) diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/__init__.py b/trustgraph-flow/trustgraph/extract/kg/ontology/__init__.py new file mode 100644 index 00000000..102255a1 --- /dev/null +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/__init__.py @@ -0,0 +1 
@@ +from . extract import * \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py new file mode 100644 index 00000000..12832eaf --- /dev/null +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py @@ -0,0 +1,848 @@ +""" +OntoRAG: Ontology-based knowledge extraction service. +Extracts ontology-conformant triples from text chunks. +""" + +import json +import logging +import asyncio +from typing import List, Dict, Any, Optional + +from .... schema import Chunk, Triple, Triples, Metadata, Value +from .... schema import EntityContext, EntityContexts +from .... schema import PromptRequest, PromptResponse +from .... rdf import TRUSTGRAPH_ENTITIES, RDF_TYPE, RDF_LABEL, DEFINITION +from .... base import FlowProcessor, ConsumerSpec, ProducerSpec +from .... base import PromptClientSpec, EmbeddingsClientSpec + +from .ontology_loader import OntologyLoader +from .ontology_embedder import OntologyEmbedder +from .vector_store import InMemoryVectorStore +from .text_processor import TextProcessor +from .ontology_selector import OntologySelector, OntologySubset + +logger = logging.getLogger(__name__) + +default_ident = "kg-extract-ontology" +default_concurrency = 1 + +# URI prefix mappings for common namespaces +URI_PREFIXES = { + "rdf:": "http://www.w3.org/1999/02/22-rdf-syntax-ns#", + "rdfs:": "http://www.w3.org/2000/01/rdf-schema#", + "owl:": "http://www.w3.org/2002/07/owl#", + "skos:": "http://www.w3.org/2004/02/skos/core#", + "schema:": "https://schema.org/", + "xsd:": "http://www.w3.org/2001/XMLSchema#", +} + + +class Processor(FlowProcessor): + """Main OntoRAG extraction processor.""" + + def __init__(self, **params): + id = params.get("id", default_ident) + concurrency = params.get("concurrency", default_concurrency) + + super(Processor, self).__init__( + **params | { + "id": id, + "concurrency": concurrency, + } + ) + + # Register specifications + self.register_specification( + ConsumerSpec( + name="input", + schema=Chunk, + handler=self.on_message, + concurrency=concurrency, + ) + ) + + self.register_specification( + PromptClientSpec( + request_name="prompt-request", + response_name="prompt-response", + ) + ) + + self.register_specification( + EmbeddingsClientSpec( + request_name="embeddings-request", + response_name="embeddings-response" + ) + ) + + self.register_specification( + ProducerSpec( + name="triples", + schema=Triples + ) + ) + + self.register_specification( + ProducerSpec( + name="entity-contexts", + schema=EntityContexts + ) + ) + + # Register config handler for ontology updates + self.register_config_handler(self.on_ontology_config) + + # Shared components (not flow-specific) + self.ontology_loader = OntologyLoader() + self.text_processor = TextProcessor() + + # Per-flow components (each flow gets its own embedder/vector store/selector) + self.flow_components = {} # flow_id -> {embedder, vector_store, selector} + + # Configuration + self.top_k = params.get("top_k", 10) + self.similarity_threshold = params.get("similarity_threshold", 0.3) + + # Track loaded ontology version + self.current_ontology_version = None + self.loaded_ontology_ids = set() + + async def initialize_flow_components(self, flow): + """Initialize per-flow OntoRAG components. + + Each flow gets its own vector store and embedder to support + different embedding models across flows. The vector store dimension + is auto-detected from the embeddings service. 
+ + Args: + flow: Flow object for this processing context + + Returns: + flow_id: Identifier for this flow's components + """ + # Use flow object as identifier + flow_id = id(flow) + + if flow_id in self.flow_components: + return flow_id # Already initialized for this flow + + try: + logger.info(f"Initializing components for flow {flow_id}") + + # Use embeddings client directly (no wrapper needed) + embeddings_client = flow("embeddings-request") + + # Detect embedding dimension by embedding a test string + logger.info("Detecting embedding dimension from embeddings service...") + test_embedding_response = await embeddings_client.embed("test") + test_embedding = test_embedding_response[0] # Extract from [[vector]] + dimension = len(test_embedding) + logger.info(f"Detected embedding dimension: {dimension}") + + # Initialize vector store with detected dimension + vector_store = InMemoryVectorStore( + dimension=dimension, + index_type='flat' + ) + + ontology_embedder = OntologyEmbedder( + embedding_service=embeddings_client, + vector_store=vector_store + ) + + # Embed all loaded ontologies for this flow + if self.ontology_loader.get_all_ontologies(): + logger.info(f"Embedding ontologies for flow {flow_id}") + for ont_id, ontology in self.ontology_loader.get_all_ontologies().items(): + await ontology_embedder.embed_ontology(ontology) + logger.info(f"Embedded {ontology_embedder.get_embedded_count()} ontology elements for flow {flow_id}") + + # Initialize ontology selector + ontology_selector = OntologySelector( + ontology_embedder=ontology_embedder, + ontology_loader=self.ontology_loader, + top_k=self.top_k, + similarity_threshold=self.similarity_threshold + ) + + # Store flow-specific components + self.flow_components[flow_id] = { + 'embedder': ontology_embedder, + 'vector_store': vector_store, + 'selector': ontology_selector, + 'dimension': dimension + } + + logger.info(f"Flow {flow_id} components initialized successfully (dimension={dimension})") + return flow_id + + except Exception as e: + logger.error(f"Failed to initialize flow {flow_id} components: {e}", exc_info=True) + raise + + async def on_ontology_config(self, config, version): + """ + Handle ontology configuration updates from ConfigPush queue. + + Parses and stores ontologies. Embedding happens per-flow on first message. 
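# [Editor's sketch] The per-flow initialisation above detects the vector
# store dimension by embedding a probe string and measuring the result.
# A minimal sketch of that step, assuming the embeddings client returns
# [[vector]] as the code above does; FakeEmbeddingsClient is a stand-in.
import asyncio

class FakeEmbeddingsClient:
    async def embed(self, text):
        return [[0.0] * 384]  # pretend 384-dim model

async def detect_dimension(client):
    response = await client.embed("test")
    return len(response[0])  # extract vector from [[vector]], take its length

assert asyncio.run(detect_dimension(FakeEmbeddingsClient())) == 384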
+ + Called automatically when: + - Processor starts (gets full config history via start_of_messages=True) + - Config service pushes updates (immediate event-driven notification) + + Args: + config: Full configuration map - config[type][key] = value + version: Config version number (monotonically increasing) + """ + try: + logger.info(f"Received ontology config update, version={version}") + + # Skip if we've already processed this version + if version == self.current_ontology_version: + logger.debug(f"Already at version {version}, skipping") + return + + # Extract ontology configurations + if "ontology" not in config: + logger.warning("No 'ontology' section in config") + return + + ontology_configs = config["ontology"] + + # Parse ontology definitions + ontologies = {} + for ont_id, ont_json in ontology_configs.items(): + try: + ontologies[ont_id] = json.loads(ont_json) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse ontology '{ont_id}': {e}") + continue + + logger.info(f"Loaded {len(ontologies)} ontology definitions") + + # Determine what changed (for incremental updates) + new_ids = set(ontologies.keys()) + added_ids = new_ids - self.loaded_ontology_ids + removed_ids = self.loaded_ontology_ids - new_ids + updated_ids = new_ids & self.loaded_ontology_ids # May have changed content + + if added_ids: + logger.info(f"New ontologies: {added_ids}") + if removed_ids: + logger.info(f"Removed ontologies: {removed_ids}") + if updated_ids: + logger.info(f"Updated ontologies: {updated_ids}") + + # Update ontology loader's internal state + self.ontology_loader.update_ontologies(ontologies) + + # Clear all flow components to force re-embedding with new ontologies + if added_ids or removed_ids or updated_ids: + logger.info("Clearing flow components to trigger re-embedding") + self.flow_components.clear() + + # Update tracking + self.current_ontology_version = version + self.loaded_ontology_ids = new_ids + + logger.info(f"Ontology config update complete, version={version}") + + except Exception as e: + logger.error(f"Failed to process ontology config: {e}", exc_info=True) + + async def on_message(self, msg, consumer, flow): + """Process incoming chunk message.""" + v = msg.value() + logger.info(f"Extracting ontology-based triples from {v.metadata.id}...") + + # Initialize flow-specific components if needed + flow_id = await self.initialize_flow_components(flow) + components = self.flow_components[flow_id] + + chunk = v.chunk.decode("utf-8") + logger.debug(f"Processing chunk: {chunk[:200]}...") + + try: + # Process text into segments + segments = self.text_processor.process_chunk(chunk, extract_phrases=True) + logger.debug(f"Split chunk into {len(segments)} segments") + + # Select relevant ontology subset (using flow-specific selector) + ontology_subsets = await components['selector'].select_ontology_subset(segments) + + if not ontology_subsets: + logger.warning("No relevant ontology elements found for chunk") + # Emit empty outputs + await self.emit_triples( + flow("triples"), + v.metadata, + [] + ) + await self.emit_entity_contexts( + flow("entity-contexts"), + v.metadata, + [] + ) + return + + # Merge subsets if multiple ontologies matched + if len(ontology_subsets) > 1: + ontology_subset = components['selector'].merge_subsets(ontology_subsets) + else: + ontology_subset = ontology_subsets[0] + + logger.debug(f"Selected ontology subset with {len(ontology_subset.classes)} classes, " + f"{len(ontology_subset.object_properties)} object properties, " + 
f"{len(ontology_subset.datatype_properties)} datatype properties") + + # Build extraction prompt variables + prompt_variables = self.build_extraction_variables(chunk, ontology_subset) + + # Call prompt service for extraction + try: + # Use prompt() method with extract-with-ontologies prompt ID + triples_response = await flow("prompt-request").prompt( + id="extract-with-ontologies", + variables=prompt_variables + ) + logger.debug(f"Extraction response: {triples_response}") + + if not isinstance(triples_response, list): + logger.error("Expected list of triples from prompt service") + triples_response = [] + + except Exception as e: + logger.error(f"Prompt service error: {e}", exc_info=True) + triples_response = [] + + # Parse and validate triples + triples = self.parse_and_validate_triples(triples_response, ontology_subset) + + # Add metadata triples + for t in v.metadata.metadata: + triples.append(t) + + # Generate ontology definition triples + ontology_triples = self.build_ontology_triples(ontology_subset) + + # Combine extracted triples with ontology triples + all_triples = triples + ontology_triples + + # Build entity contexts from all triples (including ontology elements) + entity_contexts = self.build_entity_contexts(all_triples) + + # Emit all triples (extracted + ontology definitions) + await self.emit_triples( + flow("triples"), + v.metadata, + all_triples + ) + + # Emit entity contexts + await self.emit_entity_contexts( + flow("entity-contexts"), + v.metadata, + entity_contexts + ) + + logger.info(f"Extracted {len(triples)} content triples + {len(ontology_triples)} ontology triples " + f"= {len(all_triples)} total triples and {len(entity_contexts)} entity contexts") + + except Exception as e: + logger.error(f"OntoRAG extraction exception: {e}", exc_info=True) + # Emit empty outputs on error + await self.emit_triples( + flow("triples"), + v.metadata, + [] + ) + await self.emit_entity_contexts( + flow("entity-contexts"), + v.metadata, + [] + ) + + def build_extraction_variables(self, chunk: str, ontology_subset: OntologySubset) -> Dict[str, Any]: + """Build variables for ontology-based extraction prompt template. 
+ + Args: + chunk: Text chunk to extract from + ontology_subset: Relevant ontology elements + + Returns: + Dict with template variables: text, classes, object_properties, datatype_properties + """ + return { + "text": chunk, + "classes": ontology_subset.classes, + "object_properties": ontology_subset.object_properties, + "datatype_properties": ontology_subset.datatype_properties + } + + def parse_and_validate_triples(self, triples_response: List[Any], + ontology_subset: OntologySubset) -> List[Triple]: + """Parse and validate extracted triples against ontology.""" + validated_triples = [] + ontology_id = ontology_subset.ontology_id + + for triple_data in triples_response: + try: + if isinstance(triple_data, dict): + subject = triple_data.get('subject', '') + predicate = triple_data.get('predicate', '') + object_val = triple_data.get('object', '') + + if not subject or not predicate or not object_val: + continue + + # Validate against ontology + if self.is_valid_triple(subject, predicate, object_val, ontology_subset): + # Expand URIs before creating Value objects + subject_uri = self.expand_uri(subject, ontology_subset, ontology_id) + predicate_uri = self.expand_uri(predicate, ontology_subset, ontology_id) + + # Object might be URI or literal - check before expanding + if self.is_uri(object_val) or self.should_expand_as_uri(object_val, ontology_subset): + object_uri = self.expand_uri(object_val, ontology_subset, ontology_id) + is_object_uri = True + else: + object_uri = object_val + is_object_uri = False + + # Create Triple object with expanded URIs + s_value = Value(value=subject_uri, is_uri=True) + p_value = Value(value=predicate_uri, is_uri=True) + o_value = Value(value=object_uri, is_uri=is_object_uri) + + validated_triples.append(Triple( + s=s_value, + p=p_value, + o=o_value + )) + else: + logger.debug(f"Invalid triple: ({subject}, {predicate}, {object_val})") + + except Exception as e: + logger.error(f"Error parsing triple: {e}") + + return validated_triples + + def should_expand_as_uri(self, value: str, ontology_subset: OntologySubset) -> bool: + """Check if a value should be treated as URI (not literal). + + Returns True if value is a class name, property name, or entity reference. 
+ """ + # Check if it's a class or property from ontology + if value in ontology_subset.classes: + return True + if value in ontology_subset.object_properties: + return True + if value in ontology_subset.datatype_properties: + return True + # Check if it starts with a known prefix + for prefix in URI_PREFIXES.keys(): + if value.startswith(prefix): + return True + # Check if it looks like an entity reference (e.g., "recipe:cornish-pasty") + if ":" in value and not value.startswith("http"): + return True + return False + + def is_valid_triple(self, subject: str, predicate: str, object_val: str, + ontology_subset: OntologySubset) -> bool: + """Validate triple against ontology constraints.""" + # Special case for rdf:type + if predicate == "rdf:type" or predicate == str(RDF_TYPE): + # Check if object is a valid class + return object_val in ontology_subset.classes + + # Special case for rdfs:label + if predicate == "rdfs:label" or predicate == str(RDF_LABEL): + return True # Labels are always valid + + # Check if predicate is a valid property + is_obj_prop = predicate in ontology_subset.object_properties + is_dt_prop = predicate in ontology_subset.datatype_properties + + if not is_obj_prop and not is_dt_prop: + return False # Unknown property + + # TODO: Add more sophisticated validation (domain/range checking) + return True + + def expand_uri(self, value: str, ontology_subset: OntologySubset, ontology_id: str = "unknown") -> str: + """Expand prefix notation or short names to full URIs. + + Args: + value: Value to expand (e.g., "rdf:type", "Recipe", "has_ingredient") + ontology_subset: Ontology subset for class/property lookup + ontology_id: ID of the ontology for constructing instance URIs + + Returns: + Full URI string + """ + # Already a full URI + if value.startswith("http://") or value.startswith("https://"): + return value + + # Check standard prefixes (rdf:, rdfs:, etc.) 
+ for prefix, namespace in URI_PREFIXES.items(): + if value.startswith(prefix): + return namespace + value[len(prefix):] + + # Check if it's an ontology class + if value in ontology_subset.classes: + class_def = ontology_subset.classes[value] + # class_def is a dict (from cls.__dict__ in ontology_selector) + if isinstance(class_def, dict) and 'uri' in class_def and class_def['uri']: + return class_def['uri'] + # Fallback: construct URI + return f"https://trustgraph.ai/ontology/{ontology_id}#{value}" + + # Check if it's an ontology property + if value in ontology_subset.object_properties: + prop_def = ontology_subset.object_properties[value] + # prop_def is a dict (from prop.__dict__ in ontology_selector) + if isinstance(prop_def, dict) and 'uri' in prop_def and prop_def['uri']: + return prop_def['uri'] + return f"https://trustgraph.ai/ontology/{ontology_id}#{value}" + + if value in ontology_subset.datatype_properties: + prop_def = ontology_subset.datatype_properties[value] + # prop_def is a dict (from prop.__dict__ in ontology_selector) + if isinstance(prop_def, dict) and 'uri' in prop_def and prop_def['uri']: + return prop_def['uri'] + return f"https://trustgraph.ai/ontology/{ontology_id}#{value}" + + # Otherwise, treat as entity instance - construct unique URI + # Normalize the value for URI (lowercase, replace spaces with hyphens) + normalized = value.replace(" ", "-").lower() + return f"https://trustgraph.ai/{ontology_id}/{normalized}" + + def is_uri(self, value: str) -> bool: + """Check if value is already a full URI.""" + return value.startswith("http://") or value.startswith("https://") + + async def emit_triples(self, pub, metadata: Metadata, triples: List[Triple]): + """Emit triples to output.""" + t = Triples( + metadata=Metadata( + id=metadata.id, + metadata=[], + user=metadata.user, + collection=metadata.collection, + ), + triples=triples, + ) + await pub.send(t) + + async def emit_entity_contexts(self, pub, metadata: Metadata, entities: List[EntityContext]): + """Emit entity contexts to output.""" + ec = EntityContexts( + metadata=Metadata( + id=metadata.id, + metadata=[], + user=metadata.user, + collection=metadata.collection, + ), + entities=entities, + ) + await pub.send(ec) + + def build_ontology_triples(self, ontology_subset: OntologySubset) -> List[Triple]: + """Build triples describing the ontology elements themselves. + + Generates triples for classes and properties so they exist in the knowledge graph. 
+ + Args: + ontology_subset: The ontology subset used for extraction + + Returns: + List of Triple objects describing ontology elements + """ + ontology_triples = [] + + # Generate triples for classes + for class_id, class_def in ontology_subset.classes.items(): + # Get URI for class + if isinstance(class_def, dict) and 'uri' in class_def and class_def['uri']: + class_uri = class_def['uri'] + else: + # Fallback to constructed URI + class_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{class_id}" + + # rdf:type owl:Class + ontology_triples.append(Triple( + s=Value(value=class_uri, is_uri=True), + p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True), + o=Value(value="http://www.w3.org/2002/07/owl#Class", is_uri=True) + )) + + # rdfs:label (stored as 'labels' in OntologyClass.__dict__) + if isinstance(class_def, dict) and 'labels' in class_def: + labels = class_def['labels'] + if isinstance(labels, list) and labels: + label_val = labels[0].get('value', class_id) if isinstance(labels[0], dict) else str(labels[0]) + ontology_triples.append(Triple( + s=Value(value=class_uri, is_uri=True), + p=Value(value=RDF_LABEL, is_uri=True), + o=Value(value=label_val, is_uri=False) + )) + + # rdfs:comment (stored as 'comment' in OntologyClass.__dict__) + if isinstance(class_def, dict) and 'comment' in class_def and class_def['comment']: + comment = class_def['comment'] + ontology_triples.append(Triple( + s=Value(value=class_uri, is_uri=True), + p=Value(value="http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True), + o=Value(value=comment, is_uri=False) + )) + + # rdfs:subClassOf (stored as 'subclass_of' in OntologyClass.__dict__) + if isinstance(class_def, dict) and 'subclass_of' in class_def and class_def['subclass_of']: + parent = class_def['subclass_of'] + # Get parent URI + if parent in ontology_subset.classes: + parent_class_def = ontology_subset.classes[parent] + if isinstance(parent_class_def, dict) and 'uri' in parent_class_def and parent_class_def['uri']: + parent_uri = parent_class_def['uri'] + else: + parent_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{parent}" + else: + parent_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{parent}" + + ontology_triples.append(Triple( + s=Value(value=class_uri, is_uri=True), + p=Value(value="http://www.w3.org/2000/01/rdf-schema#subClassOf", is_uri=True), + o=Value(value=parent_uri, is_uri=True) + )) + + # Generate triples for object properties + for prop_id, prop_def in ontology_subset.object_properties.items(): + # Get URI for property + if isinstance(prop_def, dict) and 'uri' in prop_def and prop_def['uri']: + prop_uri = prop_def['uri'] + else: + prop_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{prop_id}" + + # rdf:type owl:ObjectProperty + ontology_triples.append(Triple( + s=Value(value=prop_uri, is_uri=True), + p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True), + o=Value(value="http://www.w3.org/2002/07/owl#ObjectProperty", is_uri=True) + )) + + # rdfs:label (stored as 'labels' in OntologyProperty.__dict__) + if isinstance(prop_def, dict) and 'labels' in prop_def: + labels = prop_def['labels'] + if isinstance(labels, list) and labels: + label_val = labels[0].get('value', prop_id) if isinstance(labels[0], dict) else str(labels[0]) + ontology_triples.append(Triple( + s=Value(value=prop_uri, is_uri=True), + p=Value(value=RDF_LABEL, is_uri=True), + o=Value(value=label_val, is_uri=False) + )) + + # rdfs:comment 
(stored as 'comment' in OntologyProperty.__dict__) + if isinstance(prop_def, dict) and 'comment' in prop_def and prop_def['comment']: + comment = prop_def['comment'] + ontology_triples.append(Triple( + s=Value(value=prop_uri, is_uri=True), + p=Value(value="http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True), + o=Value(value=comment, is_uri=False) + )) + + # rdfs:domain (stored as 'domain' in OntologyProperty.__dict__) + if isinstance(prop_def, dict) and 'domain' in prop_def and prop_def['domain']: + domain = prop_def['domain'] + # Get domain class URI + if domain in ontology_subset.classes: + domain_class_def = ontology_subset.classes[domain] + if isinstance(domain_class_def, dict) and 'uri' in domain_class_def and domain_class_def['uri']: + domain_uri = domain_class_def['uri'] + else: + domain_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{domain}" + else: + domain_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{domain}" + + ontology_triples.append(Triple( + s=Value(value=prop_uri, is_uri=True), + p=Value(value="http://www.w3.org/2000/01/rdf-schema#domain", is_uri=True), + o=Value(value=domain_uri, is_uri=True) + )) + + # rdfs:range (stored as 'range' in OntologyProperty.__dict__) + if isinstance(prop_def, dict) and 'range' in prop_def and prop_def['range']: + range_val = prop_def['range'] + # Get range class URI + if range_val in ontology_subset.classes: + range_class_def = ontology_subset.classes[range_val] + if isinstance(range_class_def, dict) and 'uri' in range_class_def and range_class_def['uri']: + range_uri = range_class_def['uri'] + else: + range_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{range_val}" + else: + range_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{range_val}" + + ontology_triples.append(Triple( + s=Value(value=prop_uri, is_uri=True), + p=Value(value="http://www.w3.org/2000/01/rdf-schema#range", is_uri=True), + o=Value(value=range_uri, is_uri=True) + )) + + # Generate triples for datatype properties + for prop_id, prop_def in ontology_subset.datatype_properties.items(): + # Get URI for property + if isinstance(prop_def, dict) and 'uri' in prop_def and prop_def['uri']: + prop_uri = prop_def['uri'] + else: + prop_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{prop_id}" + + # rdf:type owl:DatatypeProperty + ontology_triples.append(Triple( + s=Value(value=prop_uri, is_uri=True), + p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True), + o=Value(value="http://www.w3.org/2002/07/owl#DatatypeProperty", is_uri=True) + )) + + # rdfs:label (stored as 'labels' in OntologyProperty.__dict__) + if isinstance(prop_def, dict) and 'labels' in prop_def: + labels = prop_def['labels'] + if isinstance(labels, list) and labels: + label_val = labels[0].get('value', prop_id) if isinstance(labels[0], dict) else str(labels[0]) + ontology_triples.append(Triple( + s=Value(value=prop_uri, is_uri=True), + p=Value(value=RDF_LABEL, is_uri=True), + o=Value(value=label_val, is_uri=False) + )) + + # rdfs:comment (stored as 'comment' in OntologyProperty.__dict__) + if isinstance(prop_def, dict) and 'comment' in prop_def and prop_def['comment']: + comment = prop_def['comment'] + ontology_triples.append(Triple( + s=Value(value=prop_uri, is_uri=True), + p=Value(value="http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True), + o=Value(value=comment, is_uri=False) + )) + + # rdfs:domain (stored as 'domain' in OntologyProperty.__dict__) + if 
isinstance(prop_def, dict) and 'domain' in prop_def and prop_def['domain']:
+                domain = prop_def['domain']
+                # Get domain class URI
+                if domain in ontology_subset.classes:
+                    domain_class_def = ontology_subset.classes[domain]
+                    if isinstance(domain_class_def, dict) and 'uri' in domain_class_def and domain_class_def['uri']:
+                        domain_uri = domain_class_def['uri']
+                    else:
+                        domain_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{domain}"
+                else:
+                    domain_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{domain}"
+
+                ontology_triples.append(Triple(
+                    s=Value(value=prop_uri, is_uri=True),
+                    p=Value(value="http://www.w3.org/2000/01/rdf-schema#domain", is_uri=True),
+                    o=Value(value=domain_uri, is_uri=True)
+                ))
+
+            # rdfs:range (datatype, stored as 'range' in OntologyProperty.__dict__)
+            if isinstance(prop_def, dict) and 'range' in prop_def and prop_def['range']:
+                range_val = prop_def['range']
+                # Range for datatype properties is usually xsd:string, xsd:int, etc.
+                if range_val.startswith('xsd:'):
+                    range_uri = f"http://www.w3.org/2001/XMLSchema#{range_val[4:]}"
+                else:
+                    range_uri = range_val
+
+                ontology_triples.append(Triple(
+                    s=Value(value=prop_uri, is_uri=True),
+                    p=Value(value="http://www.w3.org/2000/01/rdf-schema#range", is_uri=True),
+                    o=Value(value=range_uri, is_uri=True)
+                ))
+
+        logger.info(f"Generated {len(ontology_triples)} triples describing ontology elements")
+        return ontology_triples
+
+    def build_entity_contexts(self, triples: List[Triple]) -> List[EntityContext]:
+        """Build entity contexts from extracted triples.
+
+        Collects rdfs:label and definition properties for each entity to create
+        contextual descriptions for embedding.
+
+        Args:
+            triples: List of extracted triples
+
+        Returns:
+            List of EntityContext objects
+        """
+        # Group triples by subject to collect entity information
+        entity_data = {}  # subject_uri -> {labels: [], definitions: []}
+
+        for triple in triples:
+            subject_uri = triple.s.value
+            predicate_uri = triple.p.value
+            object_val = triple.o.value
+
+            # Initialize entity data if not exists
+            if subject_uri not in entity_data:
+                entity_data[subject_uri] = {'labels': [], 'definitions': []}
+
+            # Collect labels (rdfs:label)
+            if predicate_uri == RDF_LABEL:
+                if not triple.o.is_uri:  # Labels are literals
+                    entity_data[subject_uri]['labels'].append(object_val)
+
+            # Collect definitions (skos:definition, schema:description)
+            elif predicate_uri == DEFINITION or predicate_uri == "https://schema.org/description":
+                if not triple.o.is_uri:
+                    entity_data[subject_uri]['definitions'].append(object_val)
+
+        # Build EntityContext objects
+        entity_contexts = []
+        for subject_uri, data in entity_data.items():
+            # Build context text from labels and definitions
+            context_parts = []
+
+            if data['labels']:
+                context_parts.append(f"Label: {data['labels'][0]}")
+
+            if data['definitions']:
+                context_parts.extend(data['definitions'])
+
+            # Only create EntityContext if we have meaningful context
+            if context_parts:
+                context_text = ". 
".join(context_parts) + entity_contexts.append(EntityContext( + entity=Value(value=subject_uri, is_uri=True), + context=context_text + )) + + logger.debug(f"Built {len(entity_contexts)} entity contexts from {len(triples)} triples") + return entity_contexts + + @staticmethod + def add_args(parser): + """Add command-line arguments.""" + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Concurrent processing threads (default: {default_concurrency})' + ) + parser.add_argument( + '--top-k', + type=int, + default=10, + help='Number of top ontology elements to retrieve (default: 10)' + ) + parser.add_argument( + '--similarity-threshold', + type=float, + default=0.3, + help='Similarity threshold for ontology matching (default: 0.3, range: 0.0-1.0)' + ) + FlowProcessor.add_args(parser) + + +def run(): + """Launch the OntoRAG extraction service.""" + Processor.launch(default_ident, __doc__) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_embedder.py b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_embedder.py new file mode 100644 index 00000000..8eee76b4 --- /dev/null +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_embedder.py @@ -0,0 +1,310 @@ +""" +Ontology embedder component for OntoRAG system. +Generates and stores embeddings for ontology elements. +""" + +import asyncio +import logging +import numpy as np +from typing import Dict, List, Any, Optional +from dataclasses import dataclass + +from .ontology_loader import Ontology, OntologyClass, OntologyProperty +from .vector_store import InMemoryVectorStore + +logger = logging.getLogger(__name__) + + +@dataclass +class OntologyElementMetadata: + """Metadata for an embedded ontology element.""" + type: str # 'class', 'objectProperty', 'datatypeProperty' + ontology: str # Ontology ID + element: str # Element ID + definition: Dict[str, Any] # Full element definition + text: str # Text used for embedding + + +class OntologyEmbedder: + """Generates embeddings for ontology elements and stores them in vector store.""" + + def __init__(self, embedding_service=None, vector_store: Optional[InMemoryVectorStore] = None): + """Initialize the ontology embedder. + + Args: + embedding_service: Service for generating embeddings + vector_store: Vector store instance (InMemoryVectorStore) + """ + self.embedding_service = embedding_service + self.vector_store = vector_store or InMemoryVectorStore() + self.embedded_ontologies = set() + + def _create_text_representation(self, element_id: str, element: Any, + element_type: str) -> str: + """Create text representation of an ontology element for embedding. 
+ + Args: + element_id: ID of the element + element: The element object (OntologyClass or OntologyProperty) + element_type: Type of element + + Returns: + Text representation for embedding + """ + parts = [] + + # Add the element ID (often meaningful) + parts.append(element_id.replace('-', ' ').replace('_', ' ')) + + # Add labels + if hasattr(element, 'labels') and element.labels: + for label in element.labels: + if isinstance(label, dict): + parts.append(label.get('value', '')) + else: + parts.append(str(label)) + + # Add comment/description + if hasattr(element, 'comment') and element.comment: + parts.append(element.comment) + + # Add type-specific information + if element_type == 'class': + if hasattr(element, 'subclass_of') and element.subclass_of: + parts.append(f"subclass of {element.subclass_of}") + elif element_type in ['objectProperty', 'datatypeProperty']: + if hasattr(element, 'domain') and element.domain: + parts.append(f"domain: {element.domain}") + if hasattr(element, 'range') and element.range: + parts.append(f"range: {element.range}") + + # Join all parts with spaces + text = ' '.join(filter(None, parts)) + return text + + async def embed_ontology(self, ontology: Ontology) -> int: + """Generate and store embeddings for all elements in an ontology. + + Args: + ontology: The ontology to embed + + Returns: + Number of elements embedded + """ + if not self.embedding_service: + logger.warning("No embedding service available, skipping embedding") + return 0 + + embedded_count = 0 + batch_size = 50 # Process embeddings in batches + + # Collect all elements to embed + elements_to_embed = [] + + # Process classes + for class_id, class_def in ontology.classes.items(): + text = self._create_text_representation(class_id, class_def, 'class') + elements_to_embed.append({ + 'id': f"{ontology.id}:class:{class_id}", + 'text': text, + 'metadata': OntologyElementMetadata( + type='class', + ontology=ontology.id, + element=class_id, + definition=class_def.__dict__, + text=text + ).__dict__ + }) + + # Process object properties + for prop_id, prop_def in ontology.object_properties.items(): + text = self._create_text_representation(prop_id, prop_def, 'objectProperty') + elements_to_embed.append({ + 'id': f"{ontology.id}:objectProperty:{prop_id}", + 'text': text, + 'metadata': OntologyElementMetadata( + type='objectProperty', + ontology=ontology.id, + element=prop_id, + definition=prop_def.__dict__, + text=text + ).__dict__ + }) + + # Process datatype properties + for prop_id, prop_def in ontology.datatype_properties.items(): + text = self._create_text_representation(prop_id, prop_def, 'datatypeProperty') + elements_to_embed.append({ + 'id': f"{ontology.id}:datatypeProperty:{prop_id}", + 'text': text, + 'metadata': OntologyElementMetadata( + type='datatypeProperty', + ontology=ontology.id, + element=prop_id, + definition=prop_def.__dict__, + text=text + ).__dict__ + }) + + # Process in batches + for i in range(0, len(elements_to_embed), batch_size): + batch = elements_to_embed[i:i + batch_size] + + # Get embeddings for batch + texts = [elem['text'] for elem in batch] + try: + # Call embedding service for each text + # Note: embed() returns 2D array [[vector]], so extract first element + embedding_tasks = [self.embedding_service.embed(text) for text in texts] + embeddings_responses = await asyncio.gather(*embedding_tasks) + + # Extract vectors from responses (each is [[vector]]) + embeddings_list = [resp[0] for resp in embeddings_responses] + + # Convert to numpy array + embeddings = 
np.array(embeddings_list) + + # Log embedding shape for debugging + logger.debug(f"Embeddings shape: {embeddings.shape}, expected: ({len(batch)}, {self.vector_store.dimension})") + + # Store in vector store + ids = [elem['id'] for elem in batch] + metadata_list = [elem['metadata'] for elem in batch] + + self.vector_store.add_batch(ids, embeddings, metadata_list) + embedded_count += len(batch) + + logger.debug(f"Embedded batch of {len(batch)} elements from ontology {ontology.id}") + + except Exception as e: + logger.error(f"Failed to embed batch for ontology {ontology.id}: {e}", exc_info=True) + + self.embedded_ontologies.add(ontology.id) + logger.info(f"Embedded {embedded_count} elements from ontology {ontology.id}") + return embedded_count + + async def embed_ontologies(self, ontologies: Dict[str, Ontology]) -> int: + """Generate and store embeddings for multiple ontologies. + + Args: + ontologies: Dictionary of ontology ID to Ontology objects + + Returns: + Total number of elements embedded + """ + total_embedded = 0 + + for ont_id, ontology in ontologies.items(): + if ont_id not in self.embedded_ontologies: + count = await self.embed_ontology(ontology) + total_embedded += count + else: + logger.debug(f"Ontology {ont_id} already embedded, skipping") + + logger.info(f"Total embedded elements: {total_embedded} from {len(ontologies)} ontologies") + return total_embedded + + async def embed_text(self, text: str) -> Optional[np.ndarray]: + """Generate embedding for a single text. + + Args: + text: Text to embed + + Returns: + Embedding vector or None if failed + """ + if not self.embedding_service: + logger.warning("No embedding service available") + return None + + try: + # embed() returns 2D array [[vector]], extract first element + embedding_response = await self.embedding_service.embed(text) + return np.array(embedding_response[0]) + except Exception as e: + logger.error(f"Failed to embed text: {e}") + return None + + async def embed_texts(self, texts: List[str]) -> Optional[np.ndarray]: + """Generate embeddings for multiple texts. + + Args: + texts: List of texts to embed + + Returns: + Array of embeddings or None if failed + """ + if not self.embedding_service: + logger.warning("No embedding service available") + return None + + try: + # Call embed() for each text (returns [[vector]] per call) + embedding_tasks = [self.embedding_service.embed(text) for text in texts] + embeddings_responses = await asyncio.gather(*embedding_tasks) + # Extract first vector from each response + embeddings_list = [resp[0] for resp in embeddings_responses] + return np.array(embeddings_list) + except Exception as e: + logger.error(f"Failed to embed texts: {e}") + return None + + def remove_ontology(self, ontology_id: str): + """Remove all embeddings for a specific ontology. + + Note: FAISS doesn't support efficient deletion, so this currently + requires rebuilding the entire index without the removed ontology. 
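# A possible rebuild strategy for the FAISS limitation noted above (sketch only;
# it assumes the caller still holds the Ontology objects, and the function name
# is illustrative):
async def rebuild_without(embedder: "OntologyEmbedder", ontologies: dict, removed_id: str):
    # drop every stored vector, then re-embed all remaining ontologies
    embedder.get_vector_store().clear()
    embedder.embedded_ontologies.clear()
    for ont_id, ontology in ontologies.items():
        if ont_id != removed_id:
            await embedder.embed_ontology(ontology)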
+ + Args: + ontology_id: ID of ontology to remove + """ + if ontology_id not in self.embedded_ontologies: + logger.debug(f"Ontology '{ontology_id}' not embedded, nothing to remove") + return + + # FAISS doesn't support selective deletion, so we'd need to rebuild the index + # For now, just remove from tracking set + # TODO: Implement index rebuilding if selective removal is needed + self.embedded_ontologies.discard(ontology_id) + logger.info(f"Removed ontology '{ontology_id}' from embedded set (note: vectors still in store)") + + def clear_embeddings(self, ontology_id: Optional[str] = None): + """Clear embeddings from vector store. + + Args: + ontology_id: If provided, only clear embeddings for this ontology + Otherwise, clear all embeddings + """ + if ontology_id: + self.remove_ontology(ontology_id) + else: + self.vector_store.clear() + self.embedded_ontologies.clear() + logger.info("Cleared all embeddings from vector store") + + def get_vector_store(self) -> InMemoryVectorStore: + """Get the vector store instance. + + Returns: + The vector store being used + """ + return self.vector_store + + def get_embedded_count(self) -> int: + """Get the number of embedded elements. + + Returns: + Number of elements in the vector store + """ + return self.vector_store.size() + + def is_ontology_embedded(self, ontology_id: str) -> bool: + """Check if an ontology has been embedded. + + Args: + ontology_id: ID of the ontology + + Returns: + True if the ontology has been embedded + """ + return ontology_id in self.embedded_ontologies \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_loader.py b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_loader.py new file mode 100644 index 00000000..710108b6 --- /dev/null +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_loader.py @@ -0,0 +1,247 @@ +""" +Ontology loader component for OntoRAG system. +Loads and manages ontologies from configuration service. 
+""" + +import json +import logging +from typing import Dict, Any, Optional, List +from dataclasses import dataclass, field + +logger = logging.getLogger(__name__) + + +@dataclass +class OntologyClass: + """Represents an OWL-like class in the ontology.""" + uri: str + type: str = "owl:Class" + labels: List[Dict[str, str]] = field(default_factory=list) + comment: Optional[str] = None + subclass_of: Optional[str] = None + equivalent_classes: List[str] = field(default_factory=list) + disjoint_with: List[str] = field(default_factory=list) + identifier: Optional[str] = None + + @staticmethod + def from_dict(class_id: str, data: Dict[str, Any]) -> 'OntologyClass': + """Create OntologyClass from dictionary representation.""" + labels = data.get('rdfs:label', []) + if isinstance(labels, list): + labels = labels + else: + labels = [labels] if labels else [] + + return OntologyClass( + uri=data.get('uri', ''), + type=data.get('type', 'owl:Class'), + labels=labels, + comment=data.get('rdfs:comment'), + subclass_of=data.get('rdfs:subClassOf'), + equivalent_classes=data.get('owl:equivalentClass', []), + disjoint_with=data.get('owl:disjointWith', []), + identifier=data.get('dcterms:identifier') + ) + + +@dataclass +class OntologyProperty: + """Represents a property (object or datatype) in the ontology.""" + uri: str + type: str + labels: List[Dict[str, str]] = field(default_factory=list) + comment: Optional[str] = None + domain: Optional[str] = None + range: Optional[str] = None + inverse_of: Optional[str] = None + functional: bool = False + inverse_functional: bool = False + min_cardinality: Optional[int] = None + max_cardinality: Optional[int] = None + cardinality: Optional[int] = None + + @staticmethod + def from_dict(prop_id: str, data: Dict[str, Any]) -> 'OntologyProperty': + """Create OntologyProperty from dictionary representation.""" + labels = data.get('rdfs:label', []) + if isinstance(labels, list): + labels = labels + else: + labels = [labels] if labels else [] + + return OntologyProperty( + uri=data.get('uri', ''), + type=data.get('type', ''), + labels=labels, + comment=data.get('rdfs:comment'), + domain=data.get('rdfs:domain'), + range=data.get('rdfs:range'), + inverse_of=data.get('owl:inverseOf'), + functional=data.get('owl:functionalProperty', False), + inverse_functional=data.get('owl:inverseFunctionalProperty', False), + min_cardinality=data.get('owl:minCardinality'), + max_cardinality=data.get('owl:maxCardinality'), + cardinality=data.get('owl:cardinality') + ) + + +@dataclass +class Ontology: + """Represents a complete ontology with metadata, classes, and properties.""" + id: str + metadata: Dict[str, Any] + classes: Dict[str, OntologyClass] + object_properties: Dict[str, OntologyProperty] + datatype_properties: Dict[str, OntologyProperty] + + def get_class(self, class_id: str) -> Optional[OntologyClass]: + """Get a class by ID.""" + return self.classes.get(class_id) + + def get_property(self, prop_id: str) -> Optional[OntologyProperty]: + """Get a property (object or datatype) by ID.""" + prop = self.object_properties.get(prop_id) + if prop is None: + prop = self.datatype_properties.get(prop_id) + return prop + + def get_parent_classes(self, class_id: str) -> List[str]: + """Get all parent classes (following subClassOf hierarchy).""" + parents = [] + current = class_id + visited = set() + + while current and current not in visited: + visited.add(current) + cls = self.get_class(current) + if cls and cls.subclass_of: + parents.append(cls.subclass_of) + current = cls.subclass_of + else: 
+ break + + return parents + + def validate_structure(self) -> List[str]: + """Validate ontology structure and return list of issues.""" + issues = [] + + # Check for circular inheritance + for class_id in self.classes: + visited = set() + current = class_id + while current: + if current in visited: + issues.append(f"Circular inheritance detected for class {class_id}") + break + visited.add(current) + cls = self.get_class(current) + if cls: + current = cls.subclass_of + else: + break + + # Check property domains and ranges exist + for prop_id, prop in {**self.object_properties, **self.datatype_properties}.items(): + if prop.domain and prop.domain not in self.classes: + issues.append(f"Property {prop_id} has unknown domain {prop.domain}") + if prop.type == "owl:ObjectProperty" and prop.range and prop.range not in self.classes: + issues.append(f"Object property {prop_id} has unknown range class {prop.range}") + + # Check disjoint classes + for class_id, cls in self.classes.items(): + for disjoint_id in cls.disjoint_with: + if disjoint_id not in self.classes: + issues.append(f"Class {class_id} disjoint with unknown class {disjoint_id}") + + return issues + + +class OntologyLoader: + """Manages ontologies received via event-driven config updates. + + No direct database access - receives ontologies via config handler. + """ + + def __init__(self): + """Initialize empty ontology store.""" + self.ontologies: Dict[str, Ontology] = {} + + def update_ontologies(self, ontology_configs: Dict[str, Any]): + """Update ontology definitions from config. + + Args: + ontology_configs: Dict mapping ontology_id -> ontology_definition (parsed dicts) + """ + self.ontologies.clear() + + for ont_id, ont_data in ontology_configs.items(): + try: + # Parse classes + classes = {} + for class_id, class_data in ont_data.get('classes', {}).items(): + classes[class_id] = OntologyClass.from_dict(class_id, class_data) + + # Parse object properties + object_props = {} + for prop_id, prop_data in ont_data.get('objectProperties', {}).items(): + object_props[prop_id] = OntologyProperty.from_dict(prop_id, prop_data) + + # Parse datatype properties + datatype_props = {} + for prop_id, prop_data in ont_data.get('datatypeProperties', {}).items(): + datatype_props[prop_id] = OntologyProperty.from_dict(prop_id, prop_data) + + # Create ontology + ontology = Ontology( + id=ont_id, + metadata=ont_data.get('metadata', {}), + classes=classes, + object_properties=object_props, + datatype_properties=datatype_props + ) + + # Validate structure + issues = ontology.validate_structure() + if issues: + logger.warning(f"Ontology {ont_id} has validation issues: {issues}") + + self.ontologies[ont_id] = ontology + logger.info(f"Loaded ontology {ont_id} with {len(classes)} classes, " + f"{len(object_props)} object properties, " + f"{len(datatype_props)} datatype properties") + + except Exception as e: + logger.error(f"Failed to load ontology {ont_id}: {e}", exc_info=True) + + def get_ontology(self, ont_id: str) -> Optional[Ontology]: + """Get a specific ontology by ID. + + Args: + ont_id: Ontology identifier + + Returns: + Ontology object or None if not found + """ + return self.ontologies.get(ont_id) + + def get_all_ontologies(self) -> Dict[str, Ontology]: + """Get all loaded ontologies. + + Returns: + Dictionary of ontology ID to Ontology objects + """ + return self.ontologies + + def list_ontology_ids(self) -> List[str]: + """Get list of loaded ontology IDs. 
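# Minimal sketch of the config shape update_ontologies() accepts; the ontology
# id, URIs, class and property names below are invented for illustration.
loader = OntologyLoader()
loader.update_ontologies({
    "doc-ont": {
        "metadata": {"title": "Example ontology"},
        "classes": {
            "Person": {
                "uri": "https://example.org/ont#Person",
                "rdfs:label": [{"value": "Person"}],
                "rdfs:comment": "A human being",
            },
        },
        "objectProperties": {
            "knows": {
                "uri": "https://example.org/ont#knows",
                "type": "owl:ObjectProperty",
                "rdfs:domain": "Person",
                "rdfs:range": "Person",
            },
        },
        "datatypeProperties": {},
    },
})
# loader.list_ontology_ids() -> ["doc-ont"]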
+ + Returns: + List of ontology IDs + """ + return list(self.ontologies.keys()) + + def clear(self): + """Clear all loaded ontologies.""" + self.ontologies.clear() + logger.info("Cleared all loaded ontologies") \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_selector.py b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_selector.py new file mode 100644 index 00000000..5111529a --- /dev/null +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_selector.py @@ -0,0 +1,356 @@ +""" +Ontology selection algorithm for OntoRAG system. +Selects relevant ontology subsets based on text similarity. +""" + +import logging +from typing import List, Dict, Any, Set, Optional, Tuple +from dataclasses import dataclass +from collections import defaultdict + +from .ontology_loader import Ontology, OntologyLoader +from .ontology_embedder import OntologyEmbedder +from .text_processor import TextSegment +from .vector_store import SearchResult + +logger = logging.getLogger(__name__) + + +@dataclass +class OntologySubset: + """Represents a subset of an ontology relevant to a text chunk.""" + ontology_id: str + classes: Dict[str, Any] + object_properties: Dict[str, Any] + datatype_properties: Dict[str, Any] + metadata: Dict[str, Any] + relevance_score: float = 0.0 + + +class OntologySelector: + """Selects relevant ontology elements for text segments using vector similarity.""" + + def __init__(self, ontology_embedder: OntologyEmbedder, + ontology_loader: OntologyLoader, + top_k: int = 10, + similarity_threshold: float = 0.7): + """Initialize the ontology selector. + + Args: + ontology_embedder: Embedder with vector store + ontology_loader: Loader with ontology definitions + top_k: Number of top results to retrieve per segment + similarity_threshold: Minimum similarity score + """ + self.embedder = ontology_embedder + self.loader = ontology_loader + self.top_k = top_k + self.similarity_threshold = similarity_threshold + + async def select_ontology_subset(self, segments: List[TextSegment]) -> List[OntologySubset]: + """Select relevant ontology subsets for text segments. + + Args: + segments: List of text segments to match + + Returns: + List of ontology subsets with relevant elements + """ + # Collect all relevant elements + relevant_elements = await self._find_relevant_elements(segments) + + # Group by ontology and build subsets + ontology_subsets = self._build_ontology_subsets(relevant_elements) + + # Resolve dependencies + for subset in ontology_subsets: + self._resolve_dependencies(subset) + + logger.info(f"Selected {len(ontology_subsets)} ontology subsets") + return ontology_subsets + + async def _find_relevant_elements(self, segments: List[TextSegment]) -> Set[Tuple[str, str, str, Dict]]: + """Find relevant ontology elements for text segments. 
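# Usage sketch (segment text invented; the embedder and loader are assumed to
# have been populated elsewhere):
async def select_for_chunk(embedder: OntologyEmbedder, loader: OntologyLoader):
    selector = OntologySelector(embedder, loader, top_k=10, similarity_threshold=0.3)
    segments = [TextSegment(text="Alice works for Acme Corp", type="sentence", position=0)]
    # returns one OntologySubset per ontology that had matching elements
    return await selector.select_ontology_subset(segments)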
+ + Args: + segments: Text segments to match + + Returns: + Set of (ontology_id, element_type, element_id, definition) tuples + """ + relevant_elements = set() + element_scores = defaultdict(float) + + # Check if vector store has any elements + vector_store = self.embedder.get_vector_store() + store_size = vector_store.size() + logger.info(f"Vector store size: {store_size} elements") + + if store_size == 0: + logger.warning("Vector store is empty - no ontology elements embedded") + return relevant_elements + + # Process each segment (log first few for debugging) + for i, segment in enumerate(segments): + # Get embedding for segment + embedding = await self.embedder.embed_text(segment.text) + if embedding is None: + logger.warning(f"Failed to embed segment: {segment.text[:50]}...") + continue + + # Search vector store with no threshold to see all scores + all_results = vector_store.search( + embedding=embedding, + top_k=self.top_k, + threshold=0.0 # Get all results to see scores + ) + + # Log top scores for first 3 segments to debug + if i < 3 and all_results: + top_scores = [r.score for r in all_results[:3]] + top_elements = [r.metadata['element'] for r in all_results[:3]] + logger.info(f"Segment {i}: '{segment.text[:60]}...'") + logger.info(f" Top 3 scores: {top_scores} (threshold={self.similarity_threshold})") + logger.info(f" Top 3 elements: {top_elements}") + + # Filter by threshold + results = [r for r in all_results if r.score >= self.similarity_threshold] + + # Process results + for result in results: + metadata = result.metadata + element_key = ( + metadata['ontology'], + metadata['type'], + metadata['element'], + str(metadata['definition']) # Convert dict to string for hashability + ) + relevant_elements.add(element_key) + # Track scores for ranking + element_scores[element_key] = max(element_scores[element_key], result.score) + + logger.info(f"Found {len(relevant_elements)} relevant elements from {len(segments)} segments") + return relevant_elements + + def _build_ontology_subsets(self, relevant_elements: Set[Tuple[str, str, str, Dict]]) -> List[OntologySubset]: + """Build ontology subsets from relevant elements. 
+ + Args: + relevant_elements: Set of relevant element tuples + + Returns: + List of ontology subsets + """ + # Group elements by ontology + ontology_groups = defaultdict(lambda: { + 'classes': {}, + 'object_properties': {}, + 'datatype_properties': {}, + 'scores': [] + }) + + for ont_id, elem_type, elem_id, definition in relevant_elements: + # Parse definition back from string if needed + if isinstance(definition, str): + import json + try: + definition = json.loads(definition.replace("'", '"')) + except: + definition = eval(definition) # Fallback for dict-like strings + + # Get the actual ontology and element + ontology = self.loader.get_ontology(ont_id) + if not ontology: + logger.warning(f"Ontology {ont_id} not found in loader") + continue + + # Add element to appropriate category + if elem_type == 'class': + cls = ontology.get_class(elem_id) + if cls: + ontology_groups[ont_id]['classes'][elem_id] = cls.__dict__ + elif elem_type == 'objectProperty': + prop = ontology.object_properties.get(elem_id) + if prop: + ontology_groups[ont_id]['object_properties'][elem_id] = prop.__dict__ + elif elem_type == 'datatypeProperty': + prop = ontology.datatype_properties.get(elem_id) + if prop: + ontology_groups[ont_id]['datatype_properties'][elem_id] = prop.__dict__ + + # Create OntologySubset objects + subsets = [] + for ont_id, elements in ontology_groups.items(): + ontology = self.loader.get_ontology(ont_id) + if ontology: + subset = OntologySubset( + ontology_id=ont_id, + classes=elements['classes'], + object_properties=elements['object_properties'], + datatype_properties=elements['datatype_properties'], + metadata=ontology.metadata, + relevance_score=sum(elements['scores']) / len(elements['scores']) if elements['scores'] else 0.0 + ) + subsets.append(subset) + + return subsets + + def _resolve_dependencies(self, subset: OntologySubset): + """Resolve dependencies for ontology subset elements. 
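# Worked example for the resolution below (names invented): if only the class
# "Person" was matched, and the ontology also defines
#
#     worksFor  - object property,   domain: Person, range: Organization
#     birthDate - datatype property, domain: Person
#
# then the subset gains worksFor, birthDate and the class "Organization", plus
# any ancestors of "Person" reachable through subclass_of.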
+ + Args: + subset: Ontology subset to resolve dependencies for + """ + ontology = self.loader.get_ontology(subset.ontology_id) + if not ontology: + return + + # Track classes to add + classes_to_add = set() + + # Resolve class hierarchies + for class_id in list(subset.classes.keys()): + # Add parent classes + parents = ontology.get_parent_classes(class_id) + for parent_id in parents: + parent_class = ontology.get_class(parent_id) + if parent_class and parent_id not in subset.classes: + classes_to_add.add(parent_id) + + # Resolve property domains and ranges + for prop_id, prop_def in subset.object_properties.items(): + # Add domain class + if 'domain' in prop_def and prop_def['domain']: + domain_id = prop_def['domain'] + if domain_id not in subset.classes: + domain_class = ontology.get_class(domain_id) + if domain_class: + classes_to_add.add(domain_id) + + # Add range class + if 'range' in prop_def and prop_def['range']: + range_id = prop_def['range'] + if range_id not in subset.classes: + range_class = ontology.get_class(range_id) + if range_class: + classes_to_add.add(range_id) + + # Resolve datatype property domains + for prop_id, prop_def in subset.datatype_properties.items(): + if 'domain' in prop_def and prop_def['domain']: + domain_id = prop_def['domain'] + if domain_id not in subset.classes: + domain_class = ontology.get_class(domain_id) + if domain_class: + classes_to_add.add(domain_id) + + # Add inverse properties + for prop_id, prop_def in list(subset.object_properties.items()): + if 'inverse_of' in prop_def and prop_def['inverse_of']: + inverse_id = prop_def['inverse_of'] + if inverse_id not in subset.object_properties: + inverse_prop = ontology.object_properties.get(inverse_id) + if inverse_prop: + subset.object_properties[inverse_id] = inverse_prop.__dict__ + + # NEW: Auto-include properties related to selected classes + # For each selected class, find all properties that reference it in domain or range + properties_added = 0 + datatype_properties_added = 0 + + for class_id in list(subset.classes.keys()): + # Check all object properties in the ontology + for prop_id, prop_def in ontology.object_properties.items(): + if prop_id not in subset.object_properties: + # Check if this class is in the property's domain or range + prop_domain = getattr(prop_def, 'domain', None) + prop_range = getattr(prop_def, 'range', None) + + if prop_domain == class_id or prop_range == class_id: + subset.object_properties[prop_id] = prop_def.__dict__ + properties_added += 1 + + # Also add the other class (domain or range) if not already present + if prop_domain and prop_domain != class_id and prop_domain not in subset.classes: + other_class = ontology.get_class(prop_domain) + if other_class: + classes_to_add.add(prop_domain) + if prop_range and prop_range != class_id and prop_range not in subset.classes: + other_class = ontology.get_class(prop_range) + if other_class: + classes_to_add.add(prop_range) + + # Check all datatype properties in the ontology + for prop_id, prop_def in ontology.datatype_properties.items(): + if prop_id not in subset.datatype_properties: + # Check if this class is in the property's domain + prop_domain = getattr(prop_def, 'domain', None) + + if prop_domain == class_id: + subset.datatype_properties[prop_id] = prop_def.__dict__ + datatype_properties_added += 1 + + # Add collected classes + for class_id in classes_to_add: + cls = ontology.get_class(class_id) + if cls: + subset.classes[class_id] = cls.__dict__ + + logger.debug(f"Resolved dependencies for subset 
{subset.ontology_id}: " + f"added {len(classes_to_add)} classes, " + f"{properties_added} object properties, " + f"{datatype_properties_added} datatype properties") + + def merge_subsets(self, subsets: List[OntologySubset]) -> OntologySubset: + """Merge multiple ontology subsets into one. + + Args: + subsets: List of subsets to merge + + Returns: + Merged ontology subset + """ + if not subsets: + return None + if len(subsets) == 1: + return subsets[0] + + # Use first subset as base + merged = OntologySubset( + ontology_id="merged", + classes={}, + object_properties={}, + datatype_properties={}, + metadata={}, + relevance_score=0.0 + ) + + # Merge all subsets + total_score = 0.0 + for subset in subsets: + # Merge classes + for class_id, class_def in subset.classes.items(): + key = f"{subset.ontology_id}:{class_id}" + merged.classes[key] = class_def + + # Merge object properties + for prop_id, prop_def in subset.object_properties.items(): + key = f"{subset.ontology_id}:{prop_id}" + merged.object_properties[key] = prop_def + + # Merge datatype properties + for prop_id, prop_def in subset.datatype_properties.items(): + key = f"{subset.ontology_id}:{prop_id}" + merged.datatype_properties[key] = prop_def + + total_score += subset.relevance_score + + # Average relevance score + merged.relevance_score = total_score / len(subsets) + + logger.info(f"Merged {len(subsets)} subsets into one with " + f"{len(merged.classes)} classes, " + f"{len(merged.object_properties)} object properties, " + f"{len(merged.datatype_properties)} datatype properties") + + return merged \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/run.py b/trustgraph-flow/trustgraph/extract/kg/ontology/run.py new file mode 100644 index 00000000..c0a6143b --- /dev/null +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/run.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python3 + +""" +OntoRAG extraction service launcher. +""" + +from . extract import run + +if __name__ == "__main__": + run() \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/text_processor.py b/trustgraph-flow/trustgraph/extract/kg/ontology/text_processor.py new file mode 100644 index 00000000..685699d1 --- /dev/null +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/text_processor.py @@ -0,0 +1,240 @@ +""" +Text processing components for OntoRAG system. +Splits text into sentences and extracts phrases for granular matching. 
+""" + +import logging +import re +from typing import List, Dict, Any, Optional +from dataclasses import dataclass +import nltk +from nltk.corpus import stopwords + +logger = logging.getLogger(__name__) + +# Ensure required NLTK data is downloaded +try: + nltk.data.find('tokenizers/punkt_tab') +except LookupError: + try: + nltk.download('punkt_tab', quiet=True) + except: + # Fallback to older punkt if punkt_tab not available + try: + nltk.download('punkt', quiet=True) + except: + pass + +try: + nltk.data.find('taggers/averaged_perceptron_tagger_eng') +except LookupError: + try: + nltk.download('averaged_perceptron_tagger_eng', quiet=True) + except: + # Fallback to older name + try: + nltk.download('averaged_perceptron_tagger', quiet=True) + except: + pass + +try: + nltk.data.find('corpora/stopwords') +except LookupError: + nltk.download('stopwords', quiet=True) + + +@dataclass +class TextSegment: + """Represents a segment of text (sentence or phrase).""" + text: str + type: str # 'sentence', 'phrase', 'noun_phrase', 'verb_phrase' + position: int + parent_sentence: Optional[str] = None + metadata: Dict[str, Any] = None + + +class SentenceSplitter: + """Splits text into sentences using NLTK.""" + + def __init__(self): + """Initialize sentence splitter.""" + try: + # Try newer punkt_tab first + self.sent_detector = nltk.data.load('tokenizers/punkt_tab/english/') + logger.info("Using NLTK sentence tokenizer (punkt_tab)") + except: + # Fallback to older punkt + self.sent_detector = nltk.data.load('tokenizers/punkt/english.pickle') + logger.info("Using NLTK sentence tokenizer (punkt)") + + def split(self, text: str) -> List[str]: + """Split text into sentences. + + Args: + text: Text to split + + Returns: + List of sentences + """ + sentences = self.sent_detector.tokenize(text) + return sentences + + +class PhraseExtractor: + """Extracts meaningful phrases from sentences using NLTK.""" + + def __init__(self): + """Initialize phrase extractor.""" + logger.info("Using NLTK phrase extraction") + + def extract(self, sentence: str) -> List[Dict[str, str]]: + """Extract phrases from a sentence. 
+ + Args: + sentence: Sentence to extract phrases from + + Returns: + List of phrases with their types + """ + phrases = [] + + # Tokenize and POS tag + tokens = nltk.word_tokenize(sentence) + pos_tags = nltk.pos_tag(tokens) + + # Extract noun phrases (simple pattern) + noun_phrase = [] + for word, pos in pos_tags: + if pos.startswith('NN') or pos.startswith('JJ'): + noun_phrase.append(word) + elif noun_phrase: + if len(noun_phrase) > 1: + phrases.append({ + 'text': ' '.join(noun_phrase), + 'type': 'noun_phrase' + }) + noun_phrase = [] + + # Add last noun phrase if exists + if noun_phrase and len(noun_phrase) > 1: + phrases.append({ + 'text': ' '.join(noun_phrase), + 'type': 'noun_phrase' + }) + + # Extract verb phrases (simple pattern) + verb_phrase = [] + for word, pos in pos_tags: + if pos.startswith('VB') or pos.startswith('RB'): + verb_phrase.append(word) + elif verb_phrase: + if len(verb_phrase) > 1: + phrases.append({ + 'text': ' '.join(verb_phrase), + 'type': 'verb_phrase' + }) + verb_phrase = [] + + # Add last verb phrase if exists + if verb_phrase and len(verb_phrase) > 1: + phrases.append({ + 'text': ' '.join(verb_phrase), + 'type': 'verb_phrase' + }) + + return phrases + + +class TextProcessor: + """Main text processing class that coordinates sentence splitting and phrase extraction.""" + + def __init__(self): + """Initialize text processor.""" + self.sentence_splitter = SentenceSplitter() + self.phrase_extractor = PhraseExtractor() + + def process_chunk(self, chunk_text: str, extract_phrases: bool = True) -> List[TextSegment]: + """Process a text chunk into segments. + + Args: + chunk_text: Text chunk to process + extract_phrases: Whether to extract phrases from sentences + + Returns: + List of TextSegment objects + """ + segments = [] + position = 0 + + # Split into sentences + sentences = self.sentence_splitter.split(chunk_text) + + for sentence in sentences: + # Add sentence segment + segments.append(TextSegment( + text=sentence, + type='sentence', + position=position + )) + position += 1 + + # Extract phrases if requested + if extract_phrases: + phrases = self.phrase_extractor.extract(sentence) + for phrase_data in phrases: + segments.append(TextSegment( + text=phrase_data['text'], + type=phrase_data['type'], + position=position, + parent_sentence=sentence + )) + position += 1 + + logger.debug(f"Processed chunk into {len(segments)} segments") + return segments + + def extract_key_terms(self, text: str) -> List[str]: + """Extract key terms from text for matching. + + Args: + text: Text to extract terms from + + Returns: + List of key terms + """ + terms = [] + + # Split on word boundaries + words = re.findall(r'\b\w+\b', text.lower()) + + # Use NLTK stopwords + stop_words = set(stopwords.words('english')) + + # Filter stopwords and short words + terms = [w for w in words if w not in stop_words and len(w) > 2] + + # Also extract multi-word terms (bigrams) + for i in range(len(words) - 1): + if words[i] not in stop_words and words[i+1] not in stop_words: + bigram = f"{words[i]} {words[i+1]}" + terms.append(bigram) + + return terms + + def normalize_text(self, text: str) -> str: + """Normalize text for consistent processing. 
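# Illustration of the segments process_chunk() produces for a short chunk
# (exact phrase boundaries depend on which NLTK models are installed):
#
#   TextProcessor().process_chunk("Alice founded Acme Corp. She lives in Berlin.")
#
# yields, roughly:
#
#   TextSegment(text="Alice founded Acme Corp.", type="sentence", position=0)
#   TextSegment(text="Acme Corp", type="noun_phrase", position=1,
#               parent_sentence="Alice founded Acme Corp.")
#   TextSegment(text="She lives in Berlin.", type="sentence", position=2)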
+ + Args: + text: Text to normalize + + Returns: + Normalized text + """ + # Remove extra whitespace + text = re.sub(r'\s+', ' ', text) + # Remove leading/trailing whitespace + text = text.strip() + # Normalize quotes + text = text.replace('"', '"').replace('"', '"') + text = text.replace(''', "'").replace(''', "'") + return text diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/vector_store.py b/trustgraph-flow/trustgraph/extract/kg/ontology/vector_store.py new file mode 100644 index 00000000..6f456861 --- /dev/null +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/vector_store.py @@ -0,0 +1,130 @@ +""" +Vector store implementation for OntoRAG system. +Provides FAISS-based vector storage for ontology embeddings. +""" + +import logging +import numpy as np +from typing import List, Dict, Any +from dataclasses import dataclass +import faiss + +logger = logging.getLogger(__name__) + + +@dataclass +class SearchResult: + """Result from vector similarity search.""" + id: str + score: float + metadata: Dict[str, Any] + + +class InMemoryVectorStore: + """FAISS-based vector store implementation for ontology embeddings.""" + + def __init__(self, dimension: int = 1536, index_type: str = 'flat'): + """Initialize FAISS vector store. + + Args: + dimension: Embedding dimension (1536 for text-embedding-3-small) + index_type: 'flat' for exact search, 'ivf' for larger datasets + """ + self.dimension = dimension + self.metadata = [] + self.ids = [] + + if index_type == 'flat': + # Exact search - best for ontologies with <10k elements + self.index = faiss.IndexFlatIP(dimension) + logger.info(f"Created FAISS flat index with dimension {dimension}") + else: + # Approximate search - for larger ontologies + quantizer = faiss.IndexFlatIP(dimension) + self.index = faiss.IndexIVFFlat(quantizer, dimension, 100) + # Train with random vectors for initialization + training_data = np.random.randn(1000, dimension).astype('float32') + training_data = training_data / np.linalg.norm( + training_data, axis=1, keepdims=True + ) + self.index.train(training_data) + logger.info(f"Created FAISS IVF index with dimension {dimension}") + + def add(self, id: str, embedding: np.ndarray, metadata: Dict[str, Any]): + """Add single embedding with metadata.""" + # Normalize for cosine similarity + embedding = embedding / np.linalg.norm(embedding) + self.index.add(np.array([embedding], dtype=np.float32)) + self.metadata.append(metadata) + self.ids.append(id) + + def add_batch(self, ids: List[str], embeddings: np.ndarray, + metadata_list: List[Dict[str, Any]]): + """Batch add for initial ontology loading.""" + # Normalize all embeddings + norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + normalized = embeddings / norms + self.index.add(normalized.astype(np.float32)) + self.metadata.extend(metadata_list) + self.ids.extend(ids) + logger.debug(f"Added batch of {len(ids)} embeddings to FAISS index") + + def search(self, embedding: np.ndarray, top_k: int = 10, + threshold: float = 0.0) -> List[SearchResult]: + """Search for similar vectors.""" + # Normalize query + embedding = embedding / np.linalg.norm(embedding) + + # Search + scores, indices = self.index.search( + np.array([embedding], dtype=np.float32), + min(top_k, self.index.ntotal) + ) + + # Filter by threshold and format results + results = [] + for score, idx in zip(scores[0], indices[0]): + if idx >= 0 and score >= threshold: # FAISS returns -1 for empty slots + results.append(SearchResult( + id=self.ids[idx], + score=float(score), + metadata=self.metadata[idx] + 
)) + + return results + + def clear(self): + """Reset the store.""" + self.index.reset() + self.metadata = [] + self.ids = [] + logger.info("Cleared FAISS vector store") + + def size(self) -> int: + """Return number of stored vectors.""" + return self.index.ntotal + + +# Utility functions for vector operations +def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: + """Compute cosine similarity between two vectors.""" + return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)) + + +def batch_cosine_similarity(queries: np.ndarray, targets: np.ndarray) -> np.ndarray: + """Compute cosine similarity between query vectors and target vectors. + + Args: + queries: Array of shape (n_queries, dimension) + targets: Array of shape (n_targets, dimension) + + Returns: + Array of shape (n_queries, n_targets) with similarity scores + """ + # Normalize queries and targets + queries_norm = queries / np.linalg.norm(queries, axis=1, keepdims=True) + targets_norm = targets / np.linalg.norm(targets, axis=1, keepdims=True) + + # Compute dot product + similarities = np.dot(queries_norm, targets_norm.T) + return similarities diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py index 74ed2353..d2698589 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py @@ -87,12 +87,6 @@ class Processor(LlmService): ], temperature=effective_temperature, max_tokens=self.max_output, - top_p=1, - frequency_penalty=0, - presence_penalty=0, - response_format={ - "type": "text" - } ) inputtokens = resp.usage.prompt_tokens diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py index 4ec91dfe..f0d66021 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py @@ -47,39 +47,6 @@ class Processor(DocumentEmbeddingsQueryService): } ) - self.last_index_name = None - - def ensure_index_exists(self, index_name, dim): - """Ensure index exists, create if it doesn't""" - if index_name != self.last_index_name: - if not self.pinecone.has_index(index_name): - try: - self.pinecone.create_index( - name=index_name, - dimension=dim, - metric="cosine", - spec=ServerlessSpec( - cloud="aws", - region="us-east-1", - ) - ) - logger.info(f"Created index: {index_name}") - - # Wait for index to be ready - import time - for i in range(0, 1000): - if self.pinecone.describe_index(index_name).status["ready"]: - break - time.sleep(1) - - if not self.pinecone.describe_index(index_name).status["ready"]: - raise RuntimeError("Gave up waiting for index creation") - - except Exception as e: - logger.error(f"Pinecone index creation failed: {e}") - raise e - self.last_index_name = index_name - async def query_document_embeddings(self, msg): try: @@ -94,11 +61,13 @@ class Processor(DocumentEmbeddingsQueryService): dim = len(vec) - index_name = ( - "d-" + msg.user + "-" + msg.collection - ) + # Use dimension suffix in index name + index_name = f"d-{msg.user}-{msg.collection}-{dim}" - self.ensure_index_exists(index_name, dim) + # Check if index exists - skip if not + if not self.pinecone.has_index(index_name): + logger.info(f"Index {index_name} does not exist, skipping this vector") + continue index = self.pinecone.Index(index_name) diff --git 
a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py index f4e04e98..46e9e687 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py @@ -71,16 +71,17 @@ class Processor(DocumentEmbeddingsQueryService): chunks = [] - collection = ( - "d_" + msg.user + "_" + msg.collection - ) - - # Check if collection exists - return empty if not - if not self.collection_exists(collection): - logger.info(f"Collection {collection} does not exist, returning empty results") - return [] - for vec in msg.vectors: + + # Use dimension suffix in collection name + dim = len(vec) + collection = f"d_{msg.user}_{msg.collection}_{dim}" + + # Check if collection exists - return empty if not + if not self.collection_exists(collection): + logger.info(f"Collection {collection} does not exist, returning empty results") + continue + search_result = self.qdrant.query_points( collection_name=collection, query=vec, diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py index 30e24bd8..f6277e4f 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py @@ -49,39 +49,6 @@ class Processor(GraphEmbeddingsQueryService): } ) - self.last_index_name = None - - def ensure_index_exists(self, index_name, dim): - """Ensure index exists, create if it doesn't""" - if index_name != self.last_index_name: - if not self.pinecone.has_index(index_name): - try: - self.pinecone.create_index( - name=index_name, - dimension=dim, - metric="cosine", - spec=ServerlessSpec( - cloud="aws", - region="us-east-1", - ) - ) - logger.info(f"Created index: {index_name}") - - # Wait for index to be ready - import time - for i in range(0, 1000): - if self.pinecone.describe_index(index_name).status["ready"]: - break - time.sleep(1) - - if not self.pinecone.describe_index(index_name).status["ready"]: - raise RuntimeError("Gave up waiting for index creation") - - except Exception as e: - logger.error(f"Pinecone index creation failed: {e}") - raise e - self.last_index_name = index_name - def create_value(self, ent): if ent.startswith("http://") or ent.startswith("https://"): return Value(value=ent, is_uri=True) @@ -103,11 +70,13 @@ class Processor(GraphEmbeddingsQueryService): dim = len(vec) - index_name = ( - "t-" + msg.user + "-" + msg.collection - ) + # Use dimension suffix in index name + index_name = f"t-{msg.user}-{msg.collection}-{dim}" - self.ensure_index_exists(index_name, dim) + # Check if index exists - skip if not + if not self.pinecone.has_index(index_name): + logger.info(f"Index {index_name} does not exist, skipping this vector") + continue index = self.pinecone.Index(index_name) diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py index 6e6be420..513fd2e4 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py @@ -78,17 +78,17 @@ class Processor(GraphEmbeddingsQueryService): entity_set = set() entities = [] - collection = ( - "t_" + msg.user + "_" + msg.collection - ) - - # Check if collection exists - return empty if not - if not self.collection_exists(collection): - 
logger.info(f"Collection {collection} does not exist, returning empty results") - return [] - for vec in msg.vectors: + # Use dimension suffix in collection name + dim = len(vec) + collection = f"t_{msg.user}_{msg.collection}_{dim}" + + # Check if collection exists - return empty if not + if not self.collection_exists(collection): + logger.info(f"Collection {collection} does not exist, skipping this vector") + continue + # Heuristic hack, get (2*limit), so that we have more chance # of getting (limit) entities search_result = self.qdrant.query_points( diff --git a/trustgraph-flow/trustgraph/query/ontology/__init__.py b/trustgraph-flow/trustgraph/query/ontology/__init__.py new file mode 100644 index 00000000..60557ea9 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/__init__.py @@ -0,0 +1,54 @@ +""" +OntoRAG Query System. + +Ontology-driven natural language query processing with multi-backend support. +Provides semantic query understanding, ontology matching, and answer generation. +""" + +from .query_service import OntoRAGQueryService, QueryRequest, QueryResponse +from .question_analyzer import QuestionAnalyzer, QuestionComponents, QuestionType +from .ontology_matcher import OntologyMatcher, QueryOntologySubset +from .backend_router import BackendRouter, BackendType, QueryRoute +from .sparql_generator import SPARQLGenerator, SPARQLQuery +from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult +from .cypher_generator import CypherGenerator, CypherQuery +from .cypher_executor import CypherExecutor, CypherResult +from .answer_generator import AnswerGenerator, GeneratedAnswer, AnswerMetadata + +__all__ = [ + # Main service + 'OntoRAGQueryService', + 'QueryRequest', + 'QueryResponse', + + # Question analysis + 'QuestionAnalyzer', + 'QuestionComponents', + 'QuestionType', + + # Ontology matching + 'OntologyMatcher', + 'QueryOntologySubset', + + # Backend routing + 'BackendRouter', + 'BackendType', + 'QueryRoute', + + # SPARQL components + 'SPARQLGenerator', + 'SPARQLQuery', + 'SPARQLCassandraEngine', + 'SPARQLResult', + + # Cypher components + 'CypherGenerator', + 'CypherQuery', + 'CypherExecutor', + 'CypherResult', + + # Answer generation + 'AnswerGenerator', + 'GeneratedAnswer', + 'AnswerMetadata', +] \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/answer_generator.py b/trustgraph-flow/trustgraph/query/ontology/answer_generator.py new file mode 100644 index 00000000..9b4b6ba7 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/answer_generator.py @@ -0,0 +1,521 @@ +""" +Answer generator for natural language responses. +Converts query results into natural language answers using LLM assistance. 
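# Usage sketch (argument values illustrative; prompt_service may be None, in
# which case the template fallback defined below is used):
async def answer_question(question_components, query_results, ontology_subset):
    generator = AnswerGenerator(prompt_service=None)
    return await generator.generate_answer(
        question_components, query_results, ontology_subset,
        backend_used="cassandra",
    )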
+""" + +import logging +from typing import Dict, Any, List, Optional, Union +from dataclasses import dataclass +from datetime import datetime + +from .question_analyzer import QuestionComponents, QuestionType +from .ontology_matcher import QueryOntologySubset +from .sparql_cassandra import SPARQLResult +from .cypher_executor import CypherResult + +logger = logging.getLogger(__name__) + + +@dataclass +class AnswerMetadata: + """Metadata about answer generation.""" + query_type: str + backend_used: str + execution_time: float + result_count: int + confidence: float + explanation: str + sources: List[str] + + +@dataclass +class GeneratedAnswer: + """Generated natural language answer.""" + answer: str + metadata: AnswerMetadata + supporting_facts: List[str] + raw_results: Union[SPARQLResult, CypherResult] + generation_time: float + + +class AnswerGenerator: + """Generates natural language answers from query results.""" + + def __init__(self, prompt_service=None): + """Initialize answer generator. + + Args: + prompt_service: Service for LLM-based answer generation + """ + self.prompt_service = prompt_service + + # Answer templates for different question types + self.templates = { + 'count': "There are {count} {entity_type}.", + 'boolean_true': "Yes, {statement} is true.", + 'boolean_false': "No, {statement} is not true.", + 'list': "The {entity_type} are: {items}.", + 'single': "The {property} of {entity} is {value}.", + 'none': "No results were found for your query.", + 'error': "I encountered an error processing your query: {error}" + } + + async def generate_answer(self, + question_components: QuestionComponents, + query_results: Union[SPARQLResult, CypherResult], + ontology_subset: QueryOntologySubset, + backend_used: str) -> GeneratedAnswer: + """Generate natural language answer from query results. + + Args: + question_components: Original question analysis + query_results: Results from query execution + ontology_subset: Ontology subset used + backend_used: Backend that executed the query + + Returns: + Generated answer with metadata + """ + start_time = datetime.now() + + try: + # Try LLM-based generation first + if self.prompt_service: + llm_answer = await self._generate_with_llm( + question_components, query_results, ontology_subset + ) + if llm_answer: + execution_time = (datetime.now() - start_time).total_seconds() + return self._build_answer_response( + llm_answer, question_components, query_results, + backend_used, execution_time + ) + + # Fall back to template-based generation + template_answer = self._generate_with_template( + question_components, query_results, ontology_subset + ) + + execution_time = (datetime.now() - start_time).total_seconds() + return self._build_answer_response( + template_answer, question_components, query_results, + backend_used, execution_time + ) + + except Exception as e: + logger.error(f"Answer generation failed: {e}") + execution_time = (datetime.now() - start_time).total_seconds() + error_answer = self.templates['error'].format(error=str(e)) + return self._build_answer_response( + error_answer, question_components, query_results, + backend_used, execution_time, confidence=0.0 + ) + + async def _generate_with_llm(self, + question_components: QuestionComponents, + query_results: Union[SPARQLResult, CypherResult], + ontology_subset: QueryOntologySubset) -> Optional[str]: + """Generate answer using LLM. 
+ + Args: + question_components: Question analysis + query_results: Query results + ontology_subset: Ontology subset + + Returns: + Generated answer or None if failed + """ + try: + prompt = self._build_answer_prompt( + question_components, query_results, ontology_subset + ) + response = await self.prompt_service.generate_answer(prompt=prompt) + + if response and isinstance(response, dict): + return response.get('answer', '').strip() + elif isinstance(response, str): + return response.strip() + + except Exception as e: + logger.error(f"LLM answer generation failed: {e}") + + return None + + def _generate_with_template(self, + question_components: QuestionComponents, + query_results: Union[SPARQLResult, CypherResult], + ontology_subset: QueryOntologySubset) -> str: + """Generate answer using templates. + + Args: + question_components: Question analysis + query_results: Query results + ontology_subset: Ontology subset + + Returns: + Template-based answer + """ + # Handle empty results + if not self._has_results(query_results): + return self.templates['none'] + + # Handle boolean queries + if question_components.question_type == QuestionType.BOOLEAN: + if hasattr(query_results, 'ask_result'): + # SPARQL ASK result + statement = self._extract_boolean_statement(question_components) + if query_results.ask_result: + return self.templates['boolean_true'].format(statement=statement) + else: + return self.templates['boolean_false'].format(statement=statement) + else: + # Cypher boolean (check if any results) + has_results = len(query_results.records) > 0 + statement = self._extract_boolean_statement(question_components) + if has_results: + return self.templates['boolean_true'].format(statement=statement) + else: + return self.templates['boolean_false'].format(statement=statement) + + # Handle count queries + if question_components.question_type == QuestionType.AGGREGATION: + count = self._extract_count(query_results) + entity_type = self._infer_entity_type(question_components, ontology_subset) + return self.templates['count'].format(count=count, entity_type=entity_type) + + # Handle retrieval queries + if question_components.question_type == QuestionType.RETRIEVAL: + items = self._extract_items(query_results) + if len(items) == 1: + # Single result + entity = question_components.entities[0] if question_components.entities else "entity" + property_name = "value" + return self.templates['single'].format( + property=property_name, entity=entity, value=items[0] + ) + else: + # Multiple results + entity_type = self._infer_entity_type(question_components, ontology_subset) + items_str = ", ".join(items) + return self.templates['list'].format(entity_type=entity_type, items=items_str) + + # Handle factual queries + if question_components.question_type == QuestionType.FACTUAL: + facts = self._extract_facts(query_results) + return ". ".join(facts) if facts else self.templates['none'] + + # Default fallback + items = self._extract_items(query_results) + if items: + return f"Found: {', '.join(items[:5])}" + ("..." if len(items) > 5 else "") + else: + return self.templates['none'] + + def _build_answer_prompt(self, + question_components: QuestionComponents, + query_results: Union[SPARQLResult, CypherResult], + ontology_subset: QueryOntologySubset) -> str: + """Build prompt for LLM answer generation. 
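To make the template dispatch above concrete, the templates defined in __init__ render as follows (counts and names are illustrative):

templates = {
    'count': "There are {count} {entity_type}.",
    'list': "The {entity_type} are: {items}.",
    'boolean_true': "Yes, {statement} is true.",
}

print(templates['count'].format(count=12, entity_type="animals"))
# -> There are 12 animals.
print(templates['list'].format(entity_type="documents", items="Report A, Report B"))
# -> The documents are: Report A, Report B.
print(templates['boolean_true'].format(statement="the dog eats meat"))
# -> Yes, the dog eats meat is true.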
+ + Args: + question_components: Question analysis + query_results: Query results + ontology_subset: Ontology subset + + Returns: + Formatted prompt string + """ + # Format results for prompt + results_str = self._format_results_for_prompt(query_results) + + # Extract ontology context + context_classes = list(ontology_subset.classes.keys())[:5] + context_properties = list(ontology_subset.object_properties.keys())[:5] + + prompt = f"""Generate a natural language answer for the following question based on the query results. + +ORIGINAL QUESTION: {question_components.original_question} + +QUESTION TYPE: {question_components.question_type.value} +EXPECTED ANSWER: {question_components.expected_answer_type} + +ONTOLOGY CONTEXT: +- Classes: {', '.join(context_classes)} +- Properties: {', '.join(context_properties)} + +QUERY RESULTS: +{results_str} + +INSTRUCTIONS: +- Provide a clear, concise answer in natural language +- Use the original question's tone and style +- Include specific facts from the results +- If no results, explain that no information was found +- Be accurate and don't make assumptions beyond the data +- Limit response to 2-3 sentences unless the question requires more detail + +ANSWER:""" + + return prompt + + def _format_results_for_prompt(self, query_results: Union[SPARQLResult, CypherResult]) -> str: + """Format query results for prompt inclusion. + + Args: + query_results: Query results to format + + Returns: + Formatted results string + """ + if isinstance(query_results, SPARQLResult): + if hasattr(query_results, 'ask_result') and query_results.ask_result is not None: + return f"Boolean result: {query_results.ask_result}" + + if not query_results.bindings: + return "No results found" + + # Format SPARQL bindings + lines = [] + for binding in query_results.bindings[:10]: # Limit to first 10 + formatted = [] + for var, value in binding.items(): + if isinstance(value, dict): + formatted.append(f"{var}: {value.get('value', value)}") + else: + formatted.append(f"{var}: {value}") + lines.append("- " + ", ".join(formatted)) + + if len(query_results.bindings) > 10: + lines.append(f"... and {len(query_results.bindings) - 10} more results") + + return "\n".join(lines) + + else: # CypherResult + if not query_results.records: + return "No results found" + + # Format Cypher records + lines = [] + for record in query_results.records[:10]: # Limit to first 10 + if isinstance(record, dict): + formatted = [f"{k}: {v}" for k, v in record.items()] + lines.append("- " + ", ".join(formatted)) + else: + lines.append(f"- {record}") + + if len(query_results.records) > 10: + lines.append(f"... and {len(query_results.records) - 10} more results") + + return "\n".join(lines) + + def _has_results(self, query_results: Union[SPARQLResult, CypherResult]) -> bool: + """Check if query results contain data. + + Args: + query_results: Query results to check + + Returns: + True if results contain data + """ + if isinstance(query_results, SPARQLResult): + return bool(query_results.bindings) or query_results.ask_result is not None + else: # CypherResult + return bool(query_results.records) + + def _extract_count(self, query_results: Union[SPARQLResult, CypherResult]) -> int: + """Extract count from aggregation query results. 
+ + Args: + query_results: Query results + + Returns: + Count value + """ + if isinstance(query_results, SPARQLResult): + if query_results.bindings: + binding = query_results.bindings[0] + # Look for count variable + for var, value in binding.items(): + if 'count' in var.lower(): + if isinstance(value, dict): + return int(value.get('value', 0)) + return int(value) + return len(query_results.bindings) + else: # CypherResult + if query_results.records: + record = query_results.records[0] + if isinstance(record, dict): + # Look for count key + for key, value in record.items(): + if 'count' in key.lower(): + return int(value) + elif isinstance(record, (int, float)): + return int(record) + return len(query_results.records) + + def _extract_items(self, query_results: Union[SPARQLResult, CypherResult]) -> List[str]: + """Extract items from query results. + + Args: + query_results: Query results + + Returns: + List of extracted items + """ + items = [] + + if isinstance(query_results, SPARQLResult): + for binding in query_results.bindings: + for var, value in binding.items(): + if isinstance(value, dict): + item_value = value.get('value', str(value)) + else: + item_value = str(value) + + # Clean up URIs + if item_value.startswith('http'): + item_value = item_value.split('/')[-1].split('#')[-1] + + items.append(item_value) + break # Take first value per binding + + else: # CypherResult + for record in query_results.records: + if isinstance(record, dict): + # Take first value from record + for key, value in record.items(): + items.append(str(value)) + break + else: + items.append(str(record)) + + return items + + def _extract_facts(self, query_results: Union[SPARQLResult, CypherResult]) -> List[str]: + """Extract facts from query results. + + Args: + query_results: Query results + + Returns: + List of facts + """ + facts = [] + + if isinstance(query_results, SPARQLResult): + for binding in query_results.bindings: + fact_parts = [] + for var, value in binding.items(): + if isinstance(value, dict): + val_str = value.get('value', str(value)) + else: + val_str = str(value) + + # Clean up URIs + if val_str.startswith('http'): + val_str = val_str.split('/')[-1].split('#')[-1] + + fact_parts.append(f"{var}: {val_str}") + + facts.append(", ".join(fact_parts)) + + else: # CypherResult + for record in query_results.records: + if isinstance(record, dict): + fact_parts = [f"{k}: {v}" for k, v in record.items()] + facts.append(", ".join(fact_parts)) + else: + facts.append(str(record)) + + return facts + + def _extract_boolean_statement(self, question_components: QuestionComponents) -> str: + """Extract statement for boolean answer. + + Args: + question_components: Question analysis + + Returns: + Statement string + """ + # Extract the key assertion from the question + question = question_components.original_question.lower() + + # Remove question words + statement = question.replace('is ', '').replace('are ', '').replace('does ', '') + statement = statement.replace('?', '').strip() + + return statement + + def _infer_entity_type(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset) -> str: + """Infer entity type from question and ontology. 
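The URI clean-up used by _extract_items and _extract_facts above keeps only the local name of a URI; a quick check of that expression with an illustrative URI:

value = "http://example.org/ontology#Animal"
print(value.split('/')[-1].split('#')[-1])   # -> Animal

value = "http://example.org/resource/JohnSmith"
print(value.split('/')[-1].split('#')[-1])   # -> JohnSmith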
+ + Args: + question_components: Question analysis + ontology_subset: Ontology subset + + Returns: + Entity type string + """ + # Try to match entities to ontology classes + for entity in question_components.entities: + entity_lower = entity.lower() + for class_id in ontology_subset.classes: + if class_id.lower() == entity_lower or entity_lower in class_id.lower(): + return class_id + + # Fallback to first entity or generic term + if question_components.entities: + return question_components.entities[0] + else: + return "entities" + + def _build_answer_response(self, + answer: str, + question_components: QuestionComponents, + query_results: Union[SPARQLResult, CypherResult], + backend_used: str, + execution_time: float, + confidence: float = 0.8) -> GeneratedAnswer: + """Build final answer response. + + Args: + answer: Generated answer text + question_components: Question analysis + query_results: Query results + backend_used: Backend used for query + execution_time: Answer generation time + confidence: Confidence score + + Returns: + Complete answer response + """ + # Extract supporting facts + supporting_facts = self._extract_facts(query_results) + + # Build metadata + result_count = 0 + if isinstance(query_results, SPARQLResult): + result_count = len(query_results.bindings) + else: # CypherResult + result_count = len(query_results.records) + + metadata = AnswerMetadata( + query_type=question_components.question_type.value, + backend_used=backend_used, + execution_time=execution_time, + result_count=result_count, + confidence=confidence, + explanation=f"Generated answer using {backend_used} backend", + sources=[] # Could be populated with data source information + ) + + return GeneratedAnswer( + answer=answer, + metadata=metadata, + supporting_facts=supporting_facts[:5], # Limit to top 5 + raw_results=query_results, + generation_time=execution_time + ) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/backend_router.py b/trustgraph-flow/trustgraph/query/ontology/backend_router.py new file mode 100644 index 00000000..cbd23530 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/backend_router.py @@ -0,0 +1,350 @@ +""" +Backend router for ontology query system. +Routes queries to appropriate backend based on configuration. +""" + +import logging +from typing import Dict, Any, Optional, List +from dataclasses import dataclass +from enum import Enum + +from .question_analyzer import QuestionComponents +from .ontology_matcher import QueryOntologySubset + +logger = logging.getLogger(__name__) + + +class BackendType(Enum): + """Supported backend types.""" + CASSANDRA = "cassandra" + NEO4J = "neo4j" + MEMGRAPH = "memgraph" + FALKORDB = "falkordb" + + +@dataclass +class BackendConfig: + """Configuration for a backend.""" + type: BackendType + priority: int = 0 + enabled: bool = True + config: Dict[str, Any] = None + + +@dataclass +class QueryRoute: + """Routing decision for a query.""" + backend_type: BackendType + query_language: str # 'sparql' or 'cypher' + confidence: float + reasoning: str + + +class BackendRouter: + """Routes queries to appropriate backends based on configuration and heuristics.""" + + def __init__(self, config: Dict[str, Any]): + """Initialize backend router. 
+ + Args: + config: Router configuration + """ + self.config = config + self.backends = self._parse_backend_config(config) + self.routing_strategy = config.get('routing_strategy', 'priority') + self.enable_fallback = config.get('enable_fallback', True) + + def _parse_backend_config(self, config: Dict[str, Any]) -> Dict[BackendType, BackendConfig]: + """Parse backend configuration. + + Args: + config: Configuration dictionary + + Returns: + Dictionary of backend type to configuration + """ + backends = {} + + # Parse primary backend + primary = config.get('primary', 'cassandra') + if primary: + try: + backend_type = BackendType(primary) + backends[backend_type] = BackendConfig( + type=backend_type, + priority=100, + enabled=True, + config=config.get(primary, {}) + ) + except ValueError: + logger.warning(f"Unknown primary backend type: {primary}") + + # Parse fallback backends + fallbacks = config.get('fallback', []) + for i, fallback in enumerate(fallbacks): + try: + backend_type = BackendType(fallback) + backends[backend_type] = BackendConfig( + type=backend_type, + priority=50 - i * 10, # Decreasing priority + enabled=True, + config=config.get(fallback, {}) + ) + except ValueError: + logger.warning(f"Unknown fallback backend type: {fallback}") + + return backends + + def route_query(self, + question_components: QuestionComponents, + ontology_subsets: List[QueryOntologySubset]) -> QueryRoute: + """Route a query to the best backend. + + Args: + question_components: Analyzed question + ontology_subsets: Relevant ontology subsets + + Returns: + QueryRoute with routing decision + """ + if self.routing_strategy == 'priority': + return self._route_by_priority() + elif self.routing_strategy == 'adaptive': + return self._route_adaptive(question_components, ontology_subsets) + elif self.routing_strategy == 'round_robin': + return self._route_round_robin() + else: + return self._route_by_priority() + + def _route_by_priority(self) -> QueryRoute: + """Route based on backend priority. + + Returns: + QueryRoute to highest priority backend + """ + # Find highest priority enabled backend + best_backend = None + best_priority = -1 + + for backend_type, backend_config in self.backends.items(): + if backend_config.enabled and backend_config.priority > best_priority: + best_backend = backend_type + best_priority = backend_config.priority + + if best_backend is None: + raise RuntimeError("No enabled backends available") + + # Determine query language + query_language = 'sparql' if best_backend == BackendType.CASSANDRA else 'cypher' + + return QueryRoute( + backend_type=best_backend, + query_language=query_language, + confidence=1.0, + reasoning=f"Priority routing to {best_backend.value}" + ) + + def _route_adaptive(self, + question_components: QuestionComponents, + ontology_subsets: List[QueryOntologySubset]) -> QueryRoute: + """Route based on question characteristics and ontology complexity. 
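A hedged example of the configuration shape _parse_backend_config reads; only routing_strategy, enable_fallback, primary, fallback and the per-backend sub-dictionaries are consumed by the router, and the connection settings shown are illustrative assumptions:

router_config = {
    "routing_strategy": "adaptive",          # 'priority', 'adaptive' or 'round_robin'
    "enable_fallback": True,
    "primary": "cassandra",                  # registered with priority 100
    "fallback": ["neo4j", "falkordb"],       # registered with priorities 50, 40
    "cassandra": {"hosts": ["localhost"]},   # illustrative per-backend settings
    "neo4j": {"uri": "bolt://localhost:7687"},
    "falkordb": {"host": "localhost", "port": 6379},
}

router = BackendRouter(router_config)
# route = router.route_query(components, ontology_subsets)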
+ + Args: + question_components: Analyzed question + ontology_subsets: Relevant ontology subsets + + Returns: + QueryRoute with adaptive decision + """ + scores = {} + + for backend_type, backend_config in self.backends.items(): + if not backend_config.enabled: + continue + + score = self._calculate_backend_score( + backend_type, question_components, ontology_subsets + ) + scores[backend_type] = score + + if not scores: + raise RuntimeError("No enabled backends available") + + # Select backend with highest score + best_backend = max(scores.keys(), key=lambda k: scores[k]) + best_score = scores[best_backend] + + # Determine query language + query_language = 'sparql' if best_backend == BackendType.CASSANDRA else 'cypher' + + return QueryRoute( + backend_type=best_backend, + query_language=query_language, + confidence=best_score, + reasoning=f"Adaptive routing: {best_backend.value} scored {best_score:.2f}" + ) + + def _calculate_backend_score(self, + backend_type: BackendType, + question_components: QuestionComponents, + ontology_subsets: List[QueryOntologySubset]) -> float: + """Calculate score for a backend based on query characteristics. + + Args: + backend_type: Backend to score + question_components: Question analysis + ontology_subsets: Ontology subsets + + Returns: + Score (0.0 to 1.0) + """ + score = 0.0 + + # Base priority score + backend_config = self.backends[backend_type] + score += backend_config.priority / 100.0 + + # Question type preferences + if backend_type == BackendType.CASSANDRA: + # SPARQL is good for hierarchical and complex reasoning + if question_components.question_type.value in ['factual', 'aggregation']: + score += 0.3 + # Good for ontology-heavy queries + if len(ontology_subsets) > 1: + score += 0.2 + else: + # Cypher is good for graph traversal and relationships + if question_components.question_type.value in ['relationship', 'retrieval']: + score += 0.3 + # Good for simple graph patterns + if len(question_components.relationships) > 0: + score += 0.2 + + # Complexity considerations + total_elements = sum( + len(subset.classes) + len(subset.object_properties) + len(subset.datatype_properties) + for subset in ontology_subsets + ) + + if backend_type == BackendType.CASSANDRA: + # SPARQL handles complex ontologies well + if total_elements > 20: + score += 0.2 + else: + # Cypher is efficient for simpler queries + if total_elements <= 10: + score += 0.2 + + # Aggregation considerations + if question_components.aggregations: + if backend_type == BackendType.CASSANDRA: + score += 0.1 # SPARQL has built-in aggregation + else: + score += 0.2 # Cypher has excellent aggregation + + return min(score, 1.0) + + def _route_round_robin(self) -> QueryRoute: + """Route using round-robin strategy. + + Returns: + QueryRoute using round-robin selection + """ + # Simple round-robin implementation + enabled_backends = [ + bt for bt, bc in self.backends.items() if bc.enabled + ] + + if not enabled_backends: + raise RuntimeError("No enabled backends available") + + # For simplicity, just return the first enabled backend + # In a real implementation, you'd track state + backend_type = enabled_backends[0] + query_language = 'sparql' if backend_type == BackendType.CASSANDRA else 'cypher' + + return QueryRoute( + backend_type=backend_type, + query_language=query_language, + confidence=0.8, + reasoning=f"Round-robin routing to {backend_type.value}" + ) + + def get_fallback_route(self, failed_backend: BackendType) -> Optional[QueryRoute]: + """Get fallback route when a backend fails. 
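A rough worked example of the adaptive scoring above, assuming Neo4j is configured as the second fallback (priority 40) and the question is a retrieval question mentioning one relationship over a 15-element ontology subset:

score = 40 / 100.0      # base priority contribution            -> 0.4
score += 0.3            # retrieval question favours Cypher     -> 0.7
score += 0.2            # at least one relationship detected    -> 0.9
                        # 15 ontology elements: no size bonus for either backend
                        # no aggregations requested
print(min(score, 1.0))  # -> 0.9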
+ + Args: + failed_backend: Backend that failed + + Returns: + Fallback route or None if no fallback available + """ + if not self.enable_fallback: + return None + + # Find next best backend + fallback_backends = [ + (bt, bc) for bt, bc in self.backends.items() + if bc.enabled and bt != failed_backend + ] + + if not fallback_backends: + return None + + # Sort by priority + fallback_backends.sort(key=lambda x: x[1].priority, reverse=True) + fallback_type = fallback_backends[0][0] + + query_language = 'sparql' if fallback_type == BackendType.CASSANDRA else 'cypher' + + return QueryRoute( + backend_type=fallback_type, + query_language=query_language, + confidence=0.7, + reasoning=f"Fallback from {failed_backend.value} to {fallback_type.value}" + ) + + def get_available_backends(self) -> List[BackendType]: + """Get list of available backends. + + Returns: + List of enabled backend types + """ + return [bt for bt, bc in self.backends.items() if bc.enabled] + + def is_backend_enabled(self, backend_type: BackendType) -> bool: + """Check if a backend is enabled. + + Args: + backend_type: Backend to check + + Returns: + True if backend is enabled + """ + backend_config = self.backends.get(backend_type) + return backend_config is not None and backend_config.enabled + + def update_backend_status(self, backend_type: BackendType, enabled: bool): + """Update backend enabled status. + + Args: + backend_type: Backend to update + enabled: New enabled status + """ + if backend_type in self.backends: + self.backends[backend_type].enabled = enabled + logger.info(f"Backend {backend_type.value} {'enabled' if enabled else 'disabled'}") + else: + logger.warning(f"Unknown backend type: {backend_type}") + + def get_backend_config(self, backend_type: BackendType) -> Optional[Dict[str, Any]]: + """Get configuration for a backend. + + Args: + backend_type: Backend type + + Returns: + Configuration dictionary or None + """ + backend_config = self.backends.get(backend_type) + return backend_config.config if backend_config else None \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/cache.py b/trustgraph-flow/trustgraph/query/ontology/cache.py new file mode 100644 index 00000000..266bd805 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/cache.py @@ -0,0 +1,651 @@ +""" +Caching system for OntoRAG query results and computations. +Provides multiple cache backends and intelligent cache management. 
+""" + +import logging +import time +import json +import pickle +from typing import Dict, Any, List, Optional, Union, Tuple +from dataclasses import dataclass, asdict +from datetime import datetime, timedelta +from abc import ABC, abstractmethod +from pathlib import Path +import threading + +logger = logging.getLogger(__name__) + + +@dataclass +class CacheEntry: + """Cache entry with metadata.""" + key: str + value: Any + created_at: datetime + accessed_at: datetime + access_count: int + ttl_seconds: Optional[int] = None + tags: List[str] = None + size_bytes: int = 0 + + def is_expired(self) -> bool: + """Check if cache entry is expired.""" + if self.ttl_seconds is None: + return False + return (datetime.now() - self.created_at).total_seconds() > self.ttl_seconds + + def touch(self): + """Update access time and count.""" + self.accessed_at = datetime.now() + self.access_count += 1 + + +@dataclass +class CacheStats: + """Cache performance statistics.""" + hits: int = 0 + misses: int = 0 + evictions: int = 0 + total_entries: int = 0 + total_size_bytes: int = 0 + hit_rate: float = 0.0 + + def update_hit_rate(self): + """Update hit rate calculation.""" + total_requests = self.hits + self.misses + self.hit_rate = self.hits / total_requests if total_requests > 0 else 0.0 + + +class CacheBackend(ABC): + """Abstract base class for cache backends.""" + + @abstractmethod + def get(self, key: str) -> Optional[CacheEntry]: + """Get cache entry by key.""" + pass + + @abstractmethod + def set(self, key: str, value: Any, ttl_seconds: Optional[int] = None, tags: List[str] = None): + """Set cache entry.""" + pass + + @abstractmethod + def delete(self, key: str) -> bool: + """Delete cache entry.""" + pass + + @abstractmethod + def clear(self, tags: Optional[List[str]] = None): + """Clear cache entries.""" + pass + + @abstractmethod + def get_stats(self) -> CacheStats: + """Get cache statistics.""" + pass + + @abstractmethod + def cleanup_expired(self): + """Clean up expired entries.""" + pass + + +class InMemoryCache(CacheBackend): + """In-memory cache backend.""" + + def __init__(self, max_size: int = 1000, max_size_bytes: int = 100 * 1024 * 1024): + """Initialize in-memory cache. 
+ + Args: + max_size: Maximum number of entries + max_size_bytes: Maximum total size in bytes + """ + self.max_size = max_size + self.max_size_bytes = max_size_bytes + self.entries: Dict[str, CacheEntry] = {} + self.stats = CacheStats() + self._lock = threading.RLock() + + def get(self, key: str) -> Optional[CacheEntry]: + """Get cache entry by key.""" + with self._lock: + entry = self.entries.get(key) + if entry is None: + self.stats.misses += 1 + self.stats.update_hit_rate() + return None + + if entry.is_expired(): + del self.entries[key] + self.stats.misses += 1 + self.stats.evictions += 1 + self.stats.total_entries -= 1 + self.stats.total_size_bytes -= entry.size_bytes + self.stats.update_hit_rate() + return None + + entry.touch() + self.stats.hits += 1 + self.stats.update_hit_rate() + return entry + + def set(self, key: str, value: Any, ttl_seconds: Optional[int] = None, tags: List[str] = None): + """Set cache entry.""" + with self._lock: + # Calculate size + try: + size_bytes = len(pickle.dumps(value)) + except Exception: + size_bytes = len(str(value).encode('utf-8')) + + # Create entry + now = datetime.now() + entry = CacheEntry( + key=key, + value=value, + created_at=now, + accessed_at=now, + access_count=1, + ttl_seconds=ttl_seconds, + tags=tags or [], + size_bytes=size_bytes + ) + + # Check if we need to evict + self._ensure_capacity(size_bytes) + + # Store entry + old_entry = self.entries.get(key) + if old_entry: + self.stats.total_size_bytes -= old_entry.size_bytes + else: + self.stats.total_entries += 1 + + self.entries[key] = entry + self.stats.total_size_bytes += size_bytes + + def delete(self, key: str) -> bool: + """Delete cache entry.""" + with self._lock: + entry = self.entries.pop(key, None) + if entry: + self.stats.total_entries -= 1 + self.stats.total_size_bytes -= entry.size_bytes + self.stats.evictions += 1 + return True + return False + + def clear(self, tags: Optional[List[str]] = None): + """Clear cache entries.""" + with self._lock: + if tags is None: + # Clear all + self.stats.evictions += len(self.entries) + self.entries.clear() + self.stats.total_entries = 0 + self.stats.total_size_bytes = 0 + else: + # Clear by tags + to_delete = [] + for key, entry in self.entries.items(): + if any(tag in entry.tags for tag in tags): + to_delete.append(key) + + for key in to_delete: + self.delete(key) + + def get_stats(self) -> CacheStats: + """Get cache statistics.""" + with self._lock: + return CacheStats( + hits=self.stats.hits, + misses=self.stats.misses, + evictions=self.stats.evictions, + total_entries=self.stats.total_entries, + total_size_bytes=self.stats.total_size_bytes, + hit_rate=self.stats.hit_rate + ) + + def cleanup_expired(self): + """Clean up expired entries.""" + with self._lock: + to_delete = [] + for key, entry in self.entries.items(): + if entry.is_expired(): + to_delete.append(key) + + for key in to_delete: + self.delete(key) + + def _ensure_capacity(self, new_size_bytes: int): + """Ensure cache has capacity for new entry.""" + # Check size limit + if self.stats.total_size_bytes + new_size_bytes > self.max_size_bytes: + self._evict_by_size(new_size_bytes) + + # Check count limit + if len(self.entries) >= self.max_size: + self._evict_by_count() + + def _evict_by_size(self, needed_bytes: int): + """Evict entries to free up space.""" + # Sort by access time (LRU) + sorted_entries = sorted( + self.entries.items(), + key=lambda x: (x[1].accessed_at, x[1].access_count) + ) + + freed_bytes = 0 + for key, entry in sorted_entries: + if freed_bytes >= 
needed_bytes: + break + freed_bytes += entry.size_bytes + del self.entries[key] + self.stats.total_entries -= 1 + self.stats.total_size_bytes -= entry.size_bytes + self.stats.evictions += 1 + + def _evict_by_count(self): + """Evict least recently used entry.""" + if not self.entries: + return + + # Find LRU entry + lru_key = min( + self.entries.keys(), + key=lambda k: (self.entries[k].accessed_at, self.entries[k].access_count) + ) + self.delete(lru_key) + + +class FileCache(CacheBackend): + """File-based cache backend.""" + + def __init__(self, cache_dir: str, max_files: int = 10000): + """Initialize file cache. + + Args: + cache_dir: Directory to store cache files + max_files: Maximum number of cache files + """ + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(parents=True, exist_ok=True) + self.max_files = max_files + self.stats = CacheStats() + self._lock = threading.RLock() + + # Load existing stats + self._load_stats() + + def get(self, key: str) -> Optional[CacheEntry]: + """Get cache entry by key.""" + with self._lock: + cache_file = self.cache_dir / f"{self._safe_key(key)}.cache" + if not cache_file.exists(): + self.stats.misses += 1 + self.stats.update_hit_rate() + return None + + try: + with open(cache_file, 'rb') as f: + entry = pickle.load(f) + + if entry.is_expired(): + cache_file.unlink() + self.stats.misses += 1 + self.stats.evictions += 1 + self.stats.total_entries -= 1 + self.stats.update_hit_rate() + return None + + entry.touch() + # Update file modification time + cache_file.touch() + + self.stats.hits += 1 + self.stats.update_hit_rate() + return entry + + except Exception as e: + logger.error(f"Error reading cache file {cache_file}: {e}") + cache_file.unlink(missing_ok=True) + self.stats.misses += 1 + self.stats.update_hit_rate() + return None + + def set(self, key: str, value: Any, ttl_seconds: Optional[int] = None, tags: List[str] = None): + """Set cache entry.""" + with self._lock: + cache_file = self.cache_dir / f"{self._safe_key(key)}.cache" + + # Create entry + now = datetime.now() + entry = CacheEntry( + key=key, + value=value, + created_at=now, + accessed_at=now, + access_count=1, + ttl_seconds=ttl_seconds, + tags=tags or [] + ) + + try: + # Ensure capacity + self._ensure_capacity() + + # Write to file + with open(cache_file, 'wb') as f: + pickle.dump(entry, f) + + entry.size_bytes = cache_file.stat().st_size + + if not cache_file.exists(): + self.stats.total_entries += 1 + + self.stats.total_size_bytes += entry.size_bytes + self._save_stats() + + except Exception as e: + logger.error(f"Error writing cache file {cache_file}: {e}") + + def delete(self, key: str) -> bool: + """Delete cache entry.""" + with self._lock: + cache_file = self.cache_dir / f"{self._safe_key(key)}.cache" + if cache_file.exists(): + size = cache_file.stat().st_size + cache_file.unlink() + self.stats.total_entries -= 1 + self.stats.total_size_bytes -= size + self.stats.evictions += 1 + self._save_stats() + return True + return False + + def clear(self, tags: Optional[List[str]] = None): + """Clear cache entries.""" + with self._lock: + if tags is None: + # Clear all + for cache_file in self.cache_dir.glob("*.cache"): + cache_file.unlink() + self.stats.evictions += self.stats.total_entries + self.stats.total_entries = 0 + self.stats.total_size_bytes = 0 + else: + # Clear by tags + for cache_file in self.cache_dir.glob("*.cache"): + try: + with open(cache_file, 'rb') as f: + entry = pickle.load(f) + if any(tag in entry.tags for tag in tags): + size = cache_file.stat().st_size + 
cache_file.unlink() + self.stats.total_entries -= 1 + self.stats.total_size_bytes -= size + self.stats.evictions += 1 + except Exception: + continue + + self._save_stats() + + def get_stats(self) -> CacheStats: + """Get cache statistics.""" + with self._lock: + return CacheStats( + hits=self.stats.hits, + misses=self.stats.misses, + evictions=self.stats.evictions, + total_entries=self.stats.total_entries, + total_size_bytes=self.stats.total_size_bytes, + hit_rate=self.stats.hit_rate + ) + + def cleanup_expired(self): + """Clean up expired entries.""" + with self._lock: + for cache_file in self.cache_dir.glob("*.cache"): + try: + with open(cache_file, 'rb') as f: + entry = pickle.load(f) + if entry.is_expired(): + size = cache_file.stat().st_size + cache_file.unlink() + self.stats.total_entries -= 1 + self.stats.total_size_bytes -= size + self.stats.evictions += 1 + except Exception: + # Remove corrupted files + cache_file.unlink() + + self._save_stats() + + def _safe_key(self, key: str) -> str: + """Convert key to safe filename.""" + import hashlib + return hashlib.md5(key.encode()).hexdigest() + + def _ensure_capacity(self): + """Ensure cache has capacity for new entry.""" + cache_files = list(self.cache_dir.glob("*.cache")) + if len(cache_files) >= self.max_files: + # Remove oldest file + oldest_file = min(cache_files, key=lambda f: f.stat().st_mtime) + size = oldest_file.stat().st_size + oldest_file.unlink() + self.stats.total_entries -= 1 + self.stats.total_size_bytes -= size + self.stats.evictions += 1 + + def _load_stats(self): + """Load statistics from file.""" + stats_file = self.cache_dir / "stats.json" + if stats_file.exists(): + try: + with open(stats_file, 'r') as f: + data = json.load(f) + self.stats = CacheStats(**data) + except Exception: + pass + + def _save_stats(self): + """Save statistics to file.""" + stats_file = self.cache_dir / "stats.json" + try: + with open(stats_file, 'w') as f: + json.dump(asdict(self.stats), f, default=str) + except Exception: + pass + + +class CacheManager: + """Cache manager with multiple backends and intelligent caching strategies.""" + + def __init__(self, config: Dict[str, Any]): + """Initialize cache manager. + + Args: + config: Cache configuration + """ + self.config = config + self.backends: Dict[str, CacheBackend] = {} + self.default_backend = config.get('default_backend', 'memory') + self.default_ttl = config.get('default_ttl_seconds', 3600) # 1 hour + + # Initialize backends + self._init_backends() + + # Start cleanup task + self.cleanup_interval = config.get('cleanup_interval_seconds', 300) # 5 minutes + self._start_cleanup_task() + + def _init_backends(self): + """Initialize cache backends.""" + backends_config = self.config.get('backends', {}) + + # Memory backend + if 'memory' in backends_config or self.default_backend == 'memory': + memory_config = backends_config.get('memory', {}) + self.backends['memory'] = InMemoryCache( + max_size=memory_config.get('max_size', 1000), + max_size_bytes=memory_config.get('max_size_bytes', 100 * 1024 * 1024) + ) + + # File backend + if 'file' in backends_config or self.default_backend == 'file': + file_config = backends_config.get('file', {}) + self.backends['file'] = FileCache( + cache_dir=file_config.get('cache_dir', './cache'), + max_files=file_config.get('max_files', 10000) + ) + + def get(self, key: str, backend: Optional[str] = None) -> Optional[Any]: + """Get value from cache. 
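A hedged example of the configuration CacheManager and _init_backends expect; the limits and paths shown are illustrative defaults rather than values set anywhere in this changeset:

cache_config = {
    "default_backend": "memory",
    "default_ttl_seconds": 3600,
    "cleanup_interval_seconds": 300,
    "backends": {
        "memory": {"max_size": 1000, "max_size_bytes": 100 * 1024 * 1024},
        "file": {"cache_dir": "./cache", "max_files": 10000},
    },
}

cache = CacheManager(cache_config)
cache.set("ontology:animals", {"classes": ["Animal"]}, tags=["ontology"])
print(cache.get("ontology:animals"))
cache.clear(tags=["ontology"])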
+ + Args: + key: Cache key + backend: Backend name (optional) + + Returns: + Cached value or None + """ + backend_name = backend or self.default_backend + cache_backend = self.backends.get(backend_name) + + if cache_backend is None: + logger.warning(f"Cache backend '{backend_name}' not found") + return None + + entry = cache_backend.get(key) + return entry.value if entry else None + + def set(self, + key: str, + value: Any, + ttl_seconds: Optional[int] = None, + tags: Optional[List[str]] = None, + backend: Optional[str] = None): + """Set value in cache. + + Args: + key: Cache key + value: Value to cache + ttl_seconds: Time to live in seconds + tags: Cache tags + backend: Backend name (optional) + """ + backend_name = backend or self.default_backend + cache_backend = self.backends.get(backend_name) + + if cache_backend is None: + logger.warning(f"Cache backend '{backend_name}' not found") + return + + ttl = ttl_seconds if ttl_seconds is not None else self.default_ttl + cache_backend.set(key, value, ttl, tags) + + def delete(self, key: str, backend: Optional[str] = None) -> bool: + """Delete value from cache. + + Args: + key: Cache key + backend: Backend name (optional) + + Returns: + True if deleted + """ + backend_name = backend or self.default_backend + cache_backend = self.backends.get(backend_name) + + if cache_backend is None: + return False + + return cache_backend.delete(key) + + def clear(self, tags: Optional[List[str]] = None, backend: Optional[str] = None): + """Clear cache entries. + + Args: + tags: Tags to clear (optional) + backend: Backend name (optional) + """ + if backend: + cache_backend = self.backends.get(backend) + if cache_backend: + cache_backend.clear(tags) + else: + # Clear all backends + for cache_backend in self.backends.values(): + cache_backend.clear(tags) + + def get_stats(self) -> Dict[str, CacheStats]: + """Get statistics for all backends. + + Returns: + Dictionary of backend name to statistics + """ + return {name: backend.get_stats() for name, backend in self.backends.items()} + + def cleanup_expired(self): + """Clean up expired entries in all backends.""" + for backend in self.backends.values(): + try: + backend.cleanup_expired() + except Exception as e: + logger.error(f"Error cleaning up cache backend: {e}") + + def _start_cleanup_task(self): + """Start periodic cleanup task.""" + def cleanup_worker(): + while True: + try: + time.sleep(self.cleanup_interval) + self.cleanup_expired() + except Exception as e: + logger.error(f"Cache cleanup error: {e}") + + cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True) + cleanup_thread.start() + + +# Cache decorators and utilities + +def cache_result(cache_manager: CacheManager, + key_func: Optional[callable] = None, + ttl_seconds: Optional[int] = None, + tags: Optional[List[str]] = None, + backend: Optional[str] = None): + """Decorator to cache function results. 
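For example, the decorator can wrap a synchronous lookup like this; the function, TTL and tag are illustrative, and the default cache key is derived from the function name plus a hash of its arguments:

cache = CacheManager({"default_backend": "memory"})

@cache_result(cache, ttl_seconds=600, tags=["ontology"])
def expensive_lookup(term: str) -> list:
    # stand-in for a slow ontology or database lookup
    return [term.upper()]

print(expensive_lookup("animal"))   # computed on the first call, then cached
print(expensive_lookup("animal"))   # served from the cache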
+ + Args: + cache_manager: Cache manager instance + key_func: Function to generate cache key + ttl_seconds: Time to live + tags: Cache tags + backend: Backend name + """ + def decorator(func): + def wrapper(*args, **kwargs): + # Generate cache key + if key_func: + cache_key = key_func(*args, **kwargs) + else: + cache_key = f"{func.__name__}:{hash((args, tuple(sorted(kwargs.items()))))}" + + # Try to get from cache + cached_result = cache_manager.get(cache_key, backend) + if cached_result is not None: + return cached_result + + # Execute function + result = func(*args, **kwargs) + + # Cache result + cache_manager.set(cache_key, result, ttl_seconds, tags, backend) + + return result + + return wrapper + return decorator \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/cypher_executor.py b/trustgraph-flow/trustgraph/query/ontology/cypher_executor.py new file mode 100644 index 00000000..56e4c829 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/cypher_executor.py @@ -0,0 +1,610 @@ +""" +Cypher executor for multiple graph databases. +Executes Cypher queries against Neo4j, Memgraph, and FalkorDB. +""" + +import logging +import asyncio +from typing import Dict, Any, List, Optional, Union +from dataclasses import dataclass +from abc import ABC, abstractmethod + +from .cypher_generator import CypherQuery + +logger = logging.getLogger(__name__) + +# Try to import various database drivers +try: + from neo4j import GraphDatabase, Driver as Neo4jDriver + NEO4J_AVAILABLE = True +except ImportError: + NEO4J_AVAILABLE = False + Neo4jDriver = None + +try: + import redis + REDIS_AVAILABLE = True +except ImportError: + REDIS_AVAILABLE = False + + +@dataclass +class CypherResult: + """Result from Cypher query execution.""" + records: List[Dict[str, Any]] + summary: Dict[str, Any] + execution_time: float + database_type: str + query_plan: Optional[Dict[str, Any]] = None + + +class CypherExecutorBase(ABC): + """Abstract base class for Cypher executors.""" + + @abstractmethod + async def execute(self, cypher_query: CypherQuery) -> CypherResult: + """Execute Cypher query.""" + pass + + @abstractmethod + async def close(self): + """Close database connection.""" + pass + + @abstractmethod + def is_connected(self) -> bool: + """Check if connected to database.""" + pass + + +class Neo4jExecutor(CypherExecutorBase): + """Cypher executor for Neo4j database.""" + + def __init__(self, config: Dict[str, Any]): + """Initialize Neo4j executor. 
+ + Args: + config: Neo4j configuration + """ + if not NEO4J_AVAILABLE: + raise RuntimeError("Neo4j driver not available") + + self.config = config + self.driver: Optional[Neo4jDriver] = None + self._connection_pool_size = config.get('connection_pool_size', 10) + + async def connect(self): + """Connect to Neo4j database.""" + try: + uri = self.config.get('uri', 'bolt://localhost:7687') + username = self.config.get('username') + password = self.config.get('password') + + auth = (username, password) if username and password else None + + # Create driver with connection pool + self.driver = GraphDatabase.driver( + uri, + auth=auth, + max_connection_pool_size=self._connection_pool_size, + connection_timeout=self.config.get('connection_timeout', 30), + max_retry_time=self.config.get('max_retry_time', 15) + ) + + # Verify connectivity + await asyncio.get_event_loop().run_in_executor( + None, self.driver.verify_connectivity + ) + + logger.info(f"Connected to Neo4j at {uri}") + + except Exception as e: + logger.error(f"Failed to connect to Neo4j: {e}") + raise + + async def execute(self, cypher_query: CypherQuery) -> CypherResult: + """Execute Cypher query against Neo4j. + + Args: + cypher_query: Cypher query to execute + + Returns: + Query results + """ + if not self.driver: + await self.connect() + + import time + start_time = time.time() + + try: + # Execute query in a session + records = await asyncio.get_event_loop().run_in_executor( + None, self._execute_sync, cypher_query + ) + + execution_time = time.time() - start_time + + return CypherResult( + records=records, + summary={'record_count': len(records)}, + execution_time=execution_time, + database_type='neo4j' + ) + + except Exception as e: + logger.error(f"Neo4j query execution error: {e}") + execution_time = time.time() - start_time + return CypherResult( + records=[], + summary={'error': str(e)}, + execution_time=execution_time, + database_type='neo4j' + ) + + def _execute_sync(self, cypher_query: CypherQuery) -> List[Dict[str, Any]]: + """Execute query synchronously in thread executor. + + Args: + cypher_query: Cypher query to execute + + Returns: + List of record dictionaries + """ + with self.driver.session() as session: + result = session.run(cypher_query.query, cypher_query.parameters) + records = [] + for record in result: + record_dict = {} + for key in record.keys(): + value = record[key] + record_dict[key] = self._format_neo4j_value(value) + records.append(record_dict) + return records + + def _format_neo4j_value(self, value): + """Format Neo4j value for JSON serialization. 
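A hedged connection-and-query sketch for the Neo4j executor above, reusing the CypherQuery dataclass from cypher_generator; the URI, credentials and query are placeholders:

async def run_once():
    executor = Neo4jExecutor({
        "uri": "bolt://localhost:7687",
        "username": "neo4j",
        "password": "secret",            # placeholder credentials
        "connection_pool_size": 10,
    })
    await executor.connect()
    query = CypherQuery(
        query="MATCH (n:Animal) RETURN n.name AS name LIMIT 5",
        parameters={},
        variables=["name"],
        explanation="demo query",
        complexity_score=0.1,
    )
    result = await executor.execute(query)
    print(result.records, result.execution_time)
    await executor.close()

# asyncio.run(run_once())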
+ + Args: + value: Neo4j value + + Returns: + JSON-serializable value + """ + # Handle Neo4j node objects + if hasattr(value, 'labels') and hasattr(value, 'items'): + return { + 'labels': list(value.labels), + 'properties': dict(value.items()) + } + # Handle Neo4j relationship objects + elif hasattr(value, 'type') and hasattr(value, 'items'): + return { + 'type': value.type, + 'properties': dict(value.items()) + } + # Handle Neo4j path objects + elif hasattr(value, 'nodes') and hasattr(value, 'relationships'): + return { + 'nodes': [self._format_neo4j_value(n) for n in value.nodes], + 'relationships': [self._format_neo4j_value(r) for r in value.relationships] + } + else: + return value + + async def close(self): + """Close Neo4j connection.""" + if self.driver: + await asyncio.get_event_loop().run_in_executor( + None, self.driver.close + ) + self.driver = None + logger.info("Neo4j connection closed") + + def is_connected(self) -> bool: + """Check if connected to Neo4j.""" + return self.driver is not None + + +class MemgraphExecutor(CypherExecutorBase): + """Cypher executor for Memgraph database.""" + + def __init__(self, config: Dict[str, Any]): + """Initialize Memgraph executor. + + Args: + config: Memgraph configuration + """ + if not NEO4J_AVAILABLE: # Memgraph uses Neo4j driver + raise RuntimeError("Neo4j driver required for Memgraph") + + self.config = config + self.driver: Optional[Neo4jDriver] = None + + async def connect(self): + """Connect to Memgraph database.""" + try: + uri = self.config.get('uri', 'bolt://localhost:7688') + username = self.config.get('username') + password = self.config.get('password') + + auth = (username, password) if username and password else None + + # Memgraph uses Neo4j driver but with different defaults + self.driver = GraphDatabase.driver( + uri, + auth=auth, + max_connection_pool_size=self.config.get('connection_pool_size', 5), + connection_timeout=self.config.get('connection_timeout', 10) + ) + + # Verify connectivity + await asyncio.get_event_loop().run_in_executor( + None, self.driver.verify_connectivity + ) + + logger.info(f"Connected to Memgraph at {uri}") + + except Exception as e: + logger.error(f"Failed to connect to Memgraph: {e}") + raise + + async def execute(self, cypher_query: CypherQuery) -> CypherResult: + """Execute Cypher query against Memgraph. + + Args: + cypher_query: Cypher query to execute + + Returns: + Query results + """ + if not self.driver: + await self.connect() + + import time + start_time = time.time() + + try: + # Execute query with Memgraph-specific optimizations + records = await asyncio.get_event_loop().run_in_executor( + None, self._execute_memgraph_sync, cypher_query + ) + + execution_time = time.time() - start_time + + return CypherResult( + records=records, + summary={ + 'record_count': len(records), + 'engine': 'memgraph' + }, + execution_time=execution_time, + database_type='memgraph' + ) + + except Exception as e: + logger.error(f"Memgraph query execution error: {e}") + execution_time = time.time() - start_time + return CypherResult( + records=[], + summary={'error': str(e)}, + execution_time=execution_time, + database_type='memgraph' + ) + + def _execute_memgraph_sync(self, cypher_query: CypherQuery) -> List[Dict[str, Any]]: + """Execute query synchronously for Memgraph. 
+ + Args: + cypher_query: Cypher query to execute + + Returns: + List of record dictionaries + """ + with self.driver.session() as session: + # Add Memgraph-specific query hints if available + query = cypher_query.query + if cypher_query.database_hints and cypher_query.database_hints.get('memory_limit'): + # Memgraph supports memory limits + query = f"// Memory limit: {cypher_query.database_hints['memory_limit']}\n{query}" + + result = session.run(query, cypher_query.parameters) + records = [] + for record in result: + record_dict = {} + for key in record.keys(): + record_dict[key] = record[key] + records.append(record_dict) + return records + + async def close(self): + """Close Memgraph connection.""" + if self.driver: + await asyncio.get_event_loop().run_in_executor( + None, self.driver.close + ) + self.driver = None + logger.info("Memgraph connection closed") + + def is_connected(self) -> bool: + """Check if connected to Memgraph.""" + return self.driver is not None + + +class FalkorDBExecutor(CypherExecutorBase): + """Cypher executor for FalkorDB (Redis-based graph database).""" + + def __init__(self, config: Dict[str, Any]): + """Initialize FalkorDB executor. + + Args: + config: FalkorDB configuration + """ + if not REDIS_AVAILABLE: + raise RuntimeError("Redis driver required for FalkorDB") + + self.config = config + self.redis_client: Optional[redis.Redis] = None + self.graph_name = config.get('graph_name', 'knowledge_graph') + + async def connect(self): + """Connect to FalkorDB (Redis).""" + try: + self.redis_client = redis.Redis( + host=self.config.get('host', 'localhost'), + port=self.config.get('port', 6379), + password=self.config.get('password'), + db=self.config.get('db', 0), + decode_responses=True, + socket_connect_timeout=self.config.get('connection_timeout', 10), + socket_timeout=self.config.get('socket_timeout', 10) + ) + + # Test connection + await asyncio.get_event_loop().run_in_executor( + None, self.redis_client.ping + ) + + logger.info(f"Connected to FalkorDB at {self.config.get('host', 'localhost')}") + + except Exception as e: + logger.error(f"Failed to connect to FalkorDB: {e}") + raise + + async def execute(self, cypher_query: CypherQuery) -> CypherResult: + """Execute Cypher query against FalkorDB. + + Args: + cypher_query: Cypher query to execute + + Returns: + Query results + """ + if not self.redis_client: + await self.connect() + + import time + start_time = time.time() + + try: + # Execute query using FalkorDB's GRAPH.QUERY command + records = await asyncio.get_event_loop().run_in_executor( + None, self._execute_falkordb_sync, cypher_query + ) + + execution_time = time.time() - start_time + + return CypherResult( + records=records, + summary={ + 'record_count': len(records), + 'engine': 'falkordb' + }, + execution_time=execution_time, + database_type='falkordb' + ) + + except Exception as e: + logger.error(f"FalkorDB query execution error: {e}") + execution_time = time.time() - start_time + return CypherResult( + records=[], + summary={'error': str(e)}, + execution_time=execution_time, + database_type='falkordb' + ) + + def _execute_falkordb_sync(self, cypher_query: CypherQuery) -> List[Dict[str, Any]]: + """Execute query synchronously for FalkorDB. 
+ + Args: + cypher_query: Cypher query to execute + + Returns: + List of record dictionaries + """ + # Substitute parameters in query (FalkorDB parameter handling) + query = cypher_query.query + for param, value in cypher_query.parameters.items(): + if isinstance(value, str): + query = query.replace(f'${param}', f'"{value}"') + else: + query = query.replace(f'${param}', str(value)) + + # Execute using FalkorDB GRAPH.QUERY command + result = self.redis_client.execute_command( + 'GRAPH.QUERY', self.graph_name, query + ) + + # Parse FalkorDB result format + records = [] + if result and len(result) > 1: + # FalkorDB returns [header, data rows, statistics] + headers = result[0] if result[0] else [] + data_rows = result[1] if len(result) > 1 else [] + + for row in data_rows: + record = {} + for i, header in enumerate(headers): + if i < len(row): + record[header] = self._format_falkordb_value(row[i]) + records.append(record) + + return records + + def _format_falkordb_value(self, value): + """Format FalkorDB value for JSON serialization. + + Args: + value: FalkorDB value + + Returns: + JSON-serializable value + """ + # FalkorDB returns values in specific formats + if isinstance(value, list) and len(value) == 3: + # Check if it's a node/relationship representation + if value[0] == 1: # Node + return { + 'type': 'node', + 'labels': value[1], + 'properties': value[2] + } + elif value[0] == 2: # Relationship + return { + 'type': 'relationship', + 'rel_type': value[1], + 'properties': value[2] + } + + return value + + async def close(self): + """Close FalkorDB connection.""" + if self.redis_client: + await asyncio.get_event_loop().run_in_executor( + None, self.redis_client.close + ) + self.redis_client = None + logger.info("FalkorDB connection closed") + + def is_connected(self) -> bool: + """Check if connected to FalkorDB.""" + return self.redis_client is not None + + +class CypherExecutor: + """Multi-database Cypher executor with automatic routing.""" + + def __init__(self, config: Dict[str, Any]): + """Initialize multi-database executor. + + Args: + config: Configuration for all database types + """ + self.config = config + self.executors: Dict[str, CypherExecutorBase] = {} + + # Initialize available executors + self._initialize_executors() + + def _initialize_executors(self): + """Initialize database executors based on configuration.""" + # Neo4j executor + if 'neo4j' in self.config and NEO4J_AVAILABLE: + try: + self.executors['neo4j'] = Neo4jExecutor(self.config['neo4j']) + logger.info("Neo4j executor initialized") + except Exception as e: + logger.error(f"Failed to initialize Neo4j executor: {e}") + + # Memgraph executor + if 'memgraph' in self.config and NEO4J_AVAILABLE: + try: + self.executors['memgraph'] = MemgraphExecutor(self.config['memgraph']) + logger.info("Memgraph executor initialized") + except Exception as e: + logger.error(f"Failed to initialize Memgraph executor: {e}") + + # FalkorDB executor + if 'falkordb' in self.config and REDIS_AVAILABLE: + try: + self.executors['falkordb'] = FalkorDBExecutor(self.config['falkordb']) + logger.info("FalkorDB executor initialized") + except Exception as e: + logger.error(f"Failed to initialize FalkorDB executor: {e}") + + if not self.executors: + raise RuntimeError("No database executors could be initialized") + + async def execute_cypher(self, cypher_query: CypherQuery, + database_type: str) -> CypherResult: + """Execute Cypher query on specified database. 
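A hedged example of the multi-backend configuration the wrapper class above accepts; each top-level key enables the matching executor when its driver can be imported, and the connection details are illustrative:

executor = CypherExecutor({
    "neo4j": {"uri": "bolt://localhost:7687", "username": "neo4j", "password": "secret"},
    "memgraph": {"uri": "bolt://localhost:7688"},
    "falkordb": {"host": "localhost", "port": 6379, "graph_name": "knowledge_graph"},
})

print(executor.get_available_databases())
# result = await executor.execute_cypher(cypher_query, "neo4j")
# all_results = await executor.execute_on_all(cypher_query)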
+ + Args: + cypher_query: Cypher query to execute + database_type: Target database type + + Returns: + Query results + """ + if database_type not in self.executors: + raise ValueError(f"Database type {database_type} not available. " + f"Available: {list(self.executors.keys())}") + + executor = self.executors[database_type] + + # Ensure connection + if not executor.is_connected(): + await executor.connect() + + # Execute query + return await executor.execute(cypher_query) + + async def execute_on_all(self, cypher_query: CypherQuery) -> Dict[str, CypherResult]: + """Execute query on all available databases. + + Args: + cypher_query: Cypher query to execute + + Returns: + Results from all databases + """ + results = {} + tasks = [] + + for db_type, executor in self.executors.items(): + task = asyncio.create_task( + self.execute_cypher(cypher_query, db_type), + name=f"cypher_query_{db_type}" + ) + tasks.append((db_type, task)) + + # Wait for all tasks to complete + for db_type, task in tasks: + try: + results[db_type] = await task + except Exception as e: + logger.error(f"Query failed on {db_type}: {e}") + results[db_type] = CypherResult( + records=[], + summary={'error': str(e)}, + execution_time=0.0, + database_type=db_type + ) + + return results + + def get_available_databases(self) -> List[str]: + """Get list of available database types. + + Returns: + List of available database type names + """ + return list(self.executors.keys()) + + async def close_all(self): + """Close all database connections.""" + for executor in self.executors.values(): + await executor.close() + logger.info("All Cypher executor connections closed") \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/cypher_generator.py b/trustgraph-flow/trustgraph/query/ontology/cypher_generator.py new file mode 100644 index 00000000..8c43e964 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/cypher_generator.py @@ -0,0 +1,628 @@ +""" +Cypher query generator for ontology-sensitive queries. +Converts natural language questions to Cypher queries for graph databases. +""" + +import logging +from typing import Dict, Any, List, Optional +from dataclasses import dataclass + +from .question_analyzer import QuestionComponents, QuestionType +from .ontology_matcher import QueryOntologySubset + +logger = logging.getLogger(__name__) + + +@dataclass +class CypherQuery: + """Generated Cypher query with metadata.""" + query: str + parameters: Dict[str, Any] + variables: List[str] + explanation: str + complexity_score: float + database_hints: Dict[str, Any] = None # Database-specific optimization hints + + +class CypherGenerator: + """Generates Cypher queries from natural language questions using LLM assistance.""" + + def __init__(self, prompt_service=None): + """Initialize Cypher generator. 
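A minimal sketch of driving the generator from an async caller; components and subset are placeholder analysis objects, and with prompt_service=None only the template and fallback paths below are exercised:

generator = CypherGenerator(prompt_service=None)

cypher = await generator.generate_cypher(
    question_components=components,    # placeholder QuestionComponents
    ontology_subset=subset,            # placeholder QueryOntologySubset
    database_type="neo4j",
)
print(cypher.query)
print(cypher.parameters, cypher.complexity_score)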
+ + Args: + prompt_service: Service for LLM-based query generation + """ + self.prompt_service = prompt_service + + # Cypher query templates for common patterns + self.templates = { + 'simple_node_query': """ +MATCH (n:{node_label}) +RETURN n.name AS name, n.{property} AS {property} +LIMIT {limit}""", + + 'relationship_query': """ +MATCH (a:{source_label})-[r:{relationship}]->(b:{target_label}) +WHERE a.name = $source_name +RETURN b.name AS name, r.{rel_property} AS property""", + + 'path_query': """ +MATCH path = (start:{start_label})-[*1..{max_depth}]->(end:{end_label}) +WHERE start.name = $start_name +RETURN path, length(path) AS path_length +ORDER BY path_length""", + + 'count_query': """ +MATCH (n:{node_label}) +{where_clause} +RETURN count(n) AS count""", + + 'aggregation_query': """ +MATCH (n:{node_label}) +{where_clause} +RETURN + count(n) AS count, + avg(n.{numeric_property}) AS average, + sum(n.{numeric_property}) AS total""", + + 'boolean_query': """ +MATCH (a:{source_label})-[:{relationship}]->(b:{target_label}) +WHERE a.name = $source_name AND b.name = $target_name +RETURN count(*) > 0 AS exists""", + + 'hierarchy_query': """ +MATCH (child:{child_label})-[:SUBCLASS_OF*]->(parent:{parent_label}) +WHERE parent.name = $parent_name +RETURN child.name AS child_name, parent.name AS parent_name""", + + 'property_filter_query': """ +MATCH (n:{node_label}) +WHERE n.{property} {operator} ${property}_value +RETURN n.name AS name, n.{property} AS {property} +ORDER BY n.{property}""" + } + + async def generate_cypher(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset, + database_type: str = "neo4j") -> CypherQuery: + """Generate Cypher query for a question. + + Args: + question_components: Analyzed question components + ontology_subset: Relevant ontology subset + database_type: Target database (neo4j, memgraph, falkordb) + + Returns: + Generated Cypher query + """ + # Try template-based generation first + template_query = self._try_template_generation( + question_components, ontology_subset, database_type + ) + if template_query: + logger.debug("Generated Cypher using template") + return template_query + + # Fall back to LLM-based generation + if self.prompt_service: + llm_query = await self._generate_with_llm( + question_components, ontology_subset, database_type + ) + if llm_query: + logger.debug("Generated Cypher using LLM") + return llm_query + + # Final fallback to simple pattern + logger.warning("Falling back to simple Cypher pattern") + return self._generate_fallback_query(question_components, ontology_subset) + + def _try_template_generation(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset, + database_type: str) -> Optional[CypherQuery]: + """Try to generate query using templates. + + Args: + question_components: Question analysis + ontology_subset: Ontology subset + database_type: Target database type + + Returns: + Generated query or None if no template matches + """ + # Simple node query (What are the animals?) 
+ if (question_components.question_type == QuestionType.RETRIEVAL and + len(question_components.entities) == 1): + + node_label = self._find_matching_node_label( + question_components.entities[0], ontology_subset + ) + if node_label: + query = self.templates['simple_node_query'].format( + node_label=node_label, + property='name', + limit=100 + ) + return CypherQuery( + query=query, + parameters={}, + variables=['name'], + explanation=f"Retrieve all nodes of type {node_label}", + complexity_score=0.2, + database_hints=self._get_database_hints(database_type, 'simple') + ) + + # Count query (How many animals are there?) + if (question_components.question_type == QuestionType.AGGREGATION and + 'count' in question_components.aggregations): + + node_label = self._find_matching_node_label( + question_components.entities[0] if question_components.entities else 'Entity', + ontology_subset + ) + if node_label: + where_clause = self._build_where_clause(question_components) + query = self.templates['count_query'].format( + node_label=node_label, + where_clause=where_clause + ) + return CypherQuery( + query=query, + parameters=self._extract_parameters(question_components), + variables=['count'], + explanation=f"Count nodes of type {node_label}", + complexity_score=0.3, + database_hints=self._get_database_hints(database_type, 'aggregation') + ) + + # Relationship query (Which documents were authored by John Smith?) + if (question_components.question_type == QuestionType.RETRIEVAL and + len(question_components.entities) >= 2): + + source_label = self._find_matching_node_label( + question_components.entities[1], ontology_subset + ) + target_label = self._find_matching_node_label( + question_components.entities[0], ontology_subset + ) + relationship = self._find_matching_relationship( + question_components, ontology_subset + ) + + if source_label and target_label and relationship: + query = self.templates['relationship_query'].format( + source_label=source_label, + target_label=target_label, + relationship=relationship, + rel_property='name' + ) + return CypherQuery( + query=query, + parameters={'source_name': question_components.entities[1]}, + variables=['name'], + explanation=f"Find {target_label} related to {source_label} via {relationship}", + complexity_score=0.4, + database_hints=self._get_database_hints(database_type, 'relationship') + ) + + # Boolean query (Is X related to Y?) 
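+        # (Illustrative: "Is report-1 authored by John Smith?" binds the two
+        # entity names as $source_name / $target_name in boolean_query below.)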
+ if question_components.question_type == QuestionType.BOOLEAN: + if len(question_components.entities) >= 2: + source_label = self._find_matching_node_label( + question_components.entities[0], ontology_subset + ) + target_label = self._find_matching_node_label( + question_components.entities[1], ontology_subset + ) + relationship = self._find_matching_relationship( + question_components, ontology_subset + ) + + if source_label and target_label and relationship: + query = self.templates['boolean_query'].format( + source_label=source_label, + target_label=target_label, + relationship=relationship + ) + return CypherQuery( + query=query, + parameters={ + 'source_name': question_components.entities[0], + 'target_name': question_components.entities[1] + }, + variables=['exists'], + explanation="Boolean check for relationship existence", + complexity_score=0.3, + database_hints=self._get_database_hints(database_type, 'boolean') + ) + + return None + + async def _generate_with_llm(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset, + database_type: str) -> Optional[CypherQuery]: + """Generate Cypher using LLM. + + Args: + question_components: Question analysis + ontology_subset: Ontology subset + database_type: Target database type + + Returns: + Generated query or None if failed + """ + try: + prompt = self._build_cypher_prompt( + question_components, ontology_subset, database_type + ) + response = await self.prompt_service.generate_cypher(prompt=prompt) + + if response and isinstance(response, dict): + query = response.get('query', '').strip() + if query.upper().startswith(('MATCH', 'CREATE', 'MERGE', 'DELETE', 'RETURN')): + return CypherQuery( + query=query, + parameters=response.get('parameters', {}), + variables=self._extract_variables(query), + explanation=response.get('explanation', 'Generated by LLM'), + complexity_score=self._calculate_complexity(query), + database_hints=self._get_database_hints(database_type, 'complex') + ) + + except Exception as e: + logger.error(f"LLM Cypher generation failed: {e}") + + return None + + def _build_cypher_prompt(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset, + database_type: str) -> str: + """Build prompt for LLM Cypher generation. + + Args: + question_components: Question analysis + ontology_subset: Ontology subset + database_type: Target database type + + Returns: + Formatted prompt string + """ + # Format ontology elements as node labels and relationships + node_labels = self._format_node_labels(ontology_subset.classes) + relationships = self._format_relationships( + ontology_subset.object_properties, + ontology_subset.datatype_properties + ) + + prompt = f"""Generate a Cypher query for the following question using the provided ontology. 
+ +QUESTION: {question_components.original_question} + +TARGET DATABASE: {database_type} + +AVAILABLE NODE LABELS (from classes): +{node_labels} + +AVAILABLE RELATIONSHIP TYPES (from properties): +{relationships} + +RULES: +- Use MATCH patterns for graph traversal +- Include WHERE clauses for filters +- Use aggregation functions when needed (COUNT, SUM, AVG) +- Optimize for {database_type} performance +- Consider index hints for large datasets +- Use parameters for values (e.g., $name) + +QUERY TYPE HINTS: +- Question type: {question_components.question_type.value} +- Expected answer: {question_components.expected_answer_type} +- Entities mentioned: {', '.join(question_components.entities)} +- Aggregations: {', '.join(question_components.aggregations)} + +DATABASE-SPECIFIC OPTIMIZATIONS: +{self._get_database_specific_hints(database_type)} + +Generate a complete Cypher query with parameters:""" + + return prompt + + def _generate_fallback_query(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset) -> CypherQuery: + """Generate simple fallback query. + + Args: + question_components: Question analysis + ontology_subset: Ontology subset + + Returns: + Basic Cypher query + """ + # Very basic MATCH query + first_class = list(ontology_subset.classes.keys())[0] if ontology_subset.classes else 'Entity' + + query = f"""MATCH (n:{first_class}) +WHERE n.name CONTAINS $keyword +RETURN n.name AS name, labels(n) AS types +LIMIT 10""" + + return CypherQuery( + query=query, + parameters={'keyword': question_components.keywords[0] if question_components.keywords else 'entity'}, + variables=['name', 'types'], + explanation="Fallback query for basic pattern matching", + complexity_score=0.1 + ) + + def _find_matching_node_label(self, entity: str, ontology_subset: QueryOntologySubset) -> Optional[str]: + """Find matching node label in ontology subset. + + Args: + entity: Entity string to match + ontology_subset: Ontology subset + + Returns: + Matching node label or None + """ + entity_lower = entity.lower() + + # Direct match + for class_id in ontology_subset.classes: + if class_id.lower() == entity_lower: + return class_id + + # Label match + for class_id, class_def in ontology_subset.classes.items(): + labels = class_def.get('labels', []) + for label in labels: + if isinstance(label, dict): + label_value = label.get('value', '').lower() + if label_value == entity_lower: + return class_id + + # Partial match + for class_id in ontology_subset.classes: + if entity_lower in class_id.lower() or class_id.lower() in entity_lower: + return class_id + + return None + + def _find_matching_relationship(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset) -> Optional[str]: + """Find matching relationship type. 
+ + Args: + question_components: Question analysis + ontology_subset: Ontology subset + + Returns: + Matching relationship type or None + """ + # Look for relationship keywords + for keyword in question_components.keywords: + keyword_lower = keyword.lower() + + # Check object properties + for prop_id in ontology_subset.object_properties: + if keyword_lower in prop_id.lower() or prop_id.lower() in keyword_lower: + return prop_id.upper().replace('-', '_') + + # Common relationship mappings + relationship_mappings = { + 'author': 'AUTHORED_BY', + 'created': 'CREATED_BY', + 'owns': 'OWNS', + 'has': 'HAS', + 'contains': 'CONTAINS', + 'parent': 'PARENT_OF', + 'child': 'CHILD_OF', + 'related': 'RELATED_TO' + } + + for keyword in question_components.keywords: + if keyword.lower() in relationship_mappings: + return relationship_mappings[keyword.lower()] + + # Default relationship + return 'RELATED_TO' + + def _build_where_clause(self, question_components: QuestionComponents) -> str: + """Build WHERE clause for Cypher query. + + Args: + question_components: Question analysis + + Returns: + WHERE clause string + """ + conditions = [] + + for constraint in question_components.constraints: + if 'greater than' in constraint.lower(): + import re + numbers = re.findall(r'\d+', constraint) + if numbers: + conditions.append(f"n.value > {numbers[0]}") + elif 'less than' in constraint.lower(): + numbers = re.findall(r'\d+', constraint) + if numbers: + conditions.append(f"n.value < {numbers[0]}") + + if conditions: + return f"WHERE {' AND '.join(conditions)}" + return "" + + def _extract_parameters(self, question_components: QuestionComponents) -> Dict[str, Any]: + """Extract parameters from question components. + + Args: + question_components: Question analysis + + Returns: + Parameters dictionary + """ + parameters = {} + + # Extract numeric values + import re + for constraint in question_components.constraints: + numbers = re.findall(r'\d+', constraint) + for i, number in enumerate(numbers): + parameters[f'value_{i}'] = int(number) + + return parameters + + def _format_node_labels(self, classes: Dict[str, Any]) -> str: + """Format classes as node labels for prompt. + + Args: + classes: Classes dictionary + + Returns: + Formatted node labels string + """ + if not classes: + return "None" + + lines = [] + for class_id, definition in classes.items(): + comment = definition.get('comment', '') + lines.append(f"- :{class_id} - {comment}") + + return '\n'.join(lines) + + def _format_relationships(self, + object_props: Dict[str, Any], + datatype_props: Dict[str, Any]) -> str: + """Format properties as relationships for prompt. + + Args: + object_props: Object properties + datatype_props: Datatype properties + + Returns: + Formatted relationships string + """ + lines = [] + + for prop_id, definition in object_props.items(): + domain = definition.get('domain', 'Any') + range_val = definition.get('range', 'Any') + comment = definition.get('comment', '') + rel_type = prop_id.upper().replace('-', '_') + lines.append(f"- :{rel_type} ({domain} -> {range_val}) - {comment}") + + return '\n'.join(lines) if lines else "None" + + def _extract_variables(self, query: str) -> List[str]: + """Extract variables from Cypher query. 
+
+        Args:
+            query: Cypher query string
+
+        Returns:
+            List of variable names
+        """
+        import re
+        # Extract RETURN clause variables
+        return_match = re.search(r'RETURN\s+(.+?)(?:ORDER|LIMIT|$)', query, re.IGNORECASE | re.DOTALL)
+        if return_match:
+            return_clause = return_match.group(1)
+            variables = re.findall(r'(\w+)(?:\s+AS\s+(\w+))?', return_clause)
+            return [var[1] if var[1] else var[0] for var in variables]
+        return []
+
+    def _calculate_complexity(self, query: str) -> float:
+        """Calculate complexity score for Cypher query.
+
+        Args:
+            query: Cypher query string
+
+        Returns:
+            Complexity score (0.0 to 1.0)
+        """
+        import re
+
+        complexity = 0.0
+        query_upper = query.upper()
+
+        # Count different Cypher features
+        if 'JOIN' in query_upper or 'UNION' in query_upper:
+            complexity += 0.3
+        if 'WHERE' in query_upper:
+            complexity += 0.2
+        if 'OPTIONAL' in query_upper:
+            complexity += 0.1
+        if 'ORDER BY' in query_upper:
+            complexity += 0.1
+        if '*' in query:  # Variable length paths
+            complexity += 0.2
+        if any(agg in query_upper for agg in ['COUNT', 'SUM', 'AVG', 'MAX', 'MIN']):
+            complexity += 0.2
+
+        # Count path length
+        path_matches = re.findall(r'\[.*?\*(\d+)\.\.(\d+).*?\]', query)
+        for start, end in path_matches:
+            complexity += (int(end) - int(start)) * 0.05
+
+        return min(complexity, 1.0)
+
+    def _get_database_hints(self, database_type: str, query_category: str) -> Dict[str, Any]:
+        """Get database-specific optimization hints.
+
+        Args:
+            database_type: Target database
+            query_category: Category of query
+
+        Returns:
+            Optimization hints
+        """
+        hints = {}
+
+        if database_type == "neo4j":
+            hints.update({
+                'use_index': True,
+                'explain_plan': 'EXPLAIN',
+                'profile_query': 'PROFILE'
+            })
+        elif database_type == "memgraph":
+            hints.update({
+                'use_index': True,
+                'explain_plan': 'EXPLAIN',
+                'memory_limit': '1GB'
+            })
+        elif database_type == "falkordb":
+            hints.update({
+                'use_index': False,  # Redis-based, different indexing
+                'cache_result': True
+            })
+
+        return hints
+
+    def _get_database_specific_hints(self, database_type: str) -> str:
+        """Get database-specific optimization hints as text.
+
+        Args:
+            database_type: Target database
+
+        Returns:
+            Hints as formatted string
+        """
+        if database_type == "neo4j":
+            return """- Use USING INDEX hints for large datasets
+- Consider PROFILE for query optimization
+- Prefer MERGE over CREATE when appropriate"""
+        elif database_type == "memgraph":
+            return """- Leverage in-memory processing advantages
+- Use streaming for large result sets
+- Consider query parallelization"""
+        elif database_type == "falkordb":
+            return """- Optimize for Redis memory constraints
+- Use simple patterns for best performance
+- Leverage Redis data structures when possible"""
+        else:
+            return "- Use standard Cypher optimization patterns"
\ No newline at end of file
diff --git a/trustgraph-flow/trustgraph/query/ontology/error_handling.py b/trustgraph-flow/trustgraph/query/ontology/error_handling.py
new file mode 100644
index 00000000..cc047787
--- /dev/null
+++ b/trustgraph-flow/trustgraph/query/ontology/error_handling.py
@@ -0,0 +1,557 @@
+"""
+Error handling and recovery mechanisms for OntoRAG.
+Provides comprehensive error handling, retry logic, and graceful degradation.
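+
+A rough usage sketch (names illustrative):
+
+    strategy = ErrorRecoveryStrategy(config={"circuit_breaker_threshold": 10})
+
+    @with_error_handling(ErrorCategory.QUERY_EXECUTION, "executor", "run_query")
+    async def run_query(query):
+        ...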
+""" + +import logging +import time +import asyncio +from typing import Dict, Any, List, Optional, Callable, Union, Type +from dataclasses import dataclass +from enum import Enum +from functools import wraps +import traceback + +logger = logging.getLogger(__name__) + + +class ErrorSeverity(Enum): + """Error severity levels.""" + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + CRITICAL = "critical" + + +class ErrorCategory(Enum): + """Error categories for better handling.""" + ONTOLOGY_LOADING = "ontology_loading" + QUESTION_ANALYSIS = "question_analysis" + QUERY_GENERATION = "query_generation" + QUERY_EXECUTION = "query_execution" + ANSWER_GENERATION = "answer_generation" + BACKEND_CONNECTION = "backend_connection" + CACHE_ERROR = "cache_error" + VALIDATION_ERROR = "validation_error" + TIMEOUT_ERROR = "timeout_error" + AUTHENTICATION_ERROR = "authentication_error" + + +@dataclass +class ErrorContext: + """Context information for an error.""" + category: ErrorCategory + severity: ErrorSeverity + component: str + operation: str + user_message: Optional[str] = None + technical_details: Optional[str] = None + suggestion: Optional[str] = None + retry_count: int = 0 + max_retries: int = 3 + metadata: Dict[str, Any] = None + + +class OntoRAGError(Exception): + """Base exception for OntoRAG system.""" + + def __init__(self, + message: str, + context: Optional[ErrorContext] = None, + cause: Optional[Exception] = None): + """Initialize OntoRAG error. + + Args: + message: Error message + context: Error context + cause: Original exception that caused this error + """ + super().__init__(message) + self.message = message + self.context = context or ErrorContext( + category=ErrorCategory.VALIDATION_ERROR, + severity=ErrorSeverity.MEDIUM, + component="unknown", + operation="unknown" + ) + self.cause = cause + self.timestamp = time.time() + + +class OntologyLoadingError(OntoRAGError): + """Error loading ontology.""" + pass + + +class QuestionAnalysisError(OntoRAGError): + """Error analyzing question.""" + pass + + +class QueryGenerationError(OntoRAGError): + """Error generating query.""" + pass + + +class QueryExecutionError(OntoRAGError): + """Error executing query.""" + pass + + +class AnswerGenerationError(OntoRAGError): + """Error generating answer.""" + pass + + +class BackendConnectionError(OntoRAGError): + """Error connecting to backend.""" + pass + + +class TimeoutError(OntoRAGError): + """Operation timeout error.""" + pass + + +@dataclass +class RetryConfig: + """Configuration for retry logic.""" + max_retries: int = 3 + base_delay: float = 1.0 + max_delay: float = 60.0 + exponential_backoff: bool = True + jitter: bool = True + retry_on_exceptions: List[Type[Exception]] = None + + +class ErrorRecoveryStrategy: + """Strategy for handling and recovering from errors.""" + + def __init__(self, config: Dict[str, Any] = None): + """Initialize error recovery strategy. 
+ + Args: + config: Recovery configuration + """ + self.config = config or {} + self.retry_configs = self._build_retry_configs() + self.fallback_strategies = self._build_fallback_strategies() + self.error_counters: Dict[str, int] = {} + self.circuit_breakers: Dict[str, Dict[str, Any]] = {} + + def _build_retry_configs(self) -> Dict[ErrorCategory, RetryConfig]: + """Build retry configurations for different error categories.""" + return { + ErrorCategory.BACKEND_CONNECTION: RetryConfig( + max_retries=5, + base_delay=2.0, + retry_on_exceptions=[BackendConnectionError, ConnectionError, TimeoutError] + ), + ErrorCategory.QUERY_EXECUTION: RetryConfig( + max_retries=3, + base_delay=1.0, + retry_on_exceptions=[QueryExecutionError, TimeoutError] + ), + ErrorCategory.ONTOLOGY_LOADING: RetryConfig( + max_retries=2, + base_delay=0.5, + retry_on_exceptions=[OntologyLoadingError, IOError] + ), + ErrorCategory.QUESTION_ANALYSIS: RetryConfig( + max_retries=2, + base_delay=1.0, + retry_on_exceptions=[QuestionAnalysisError, TimeoutError] + ), + ErrorCategory.ANSWER_GENERATION: RetryConfig( + max_retries=2, + base_delay=1.0, + retry_on_exceptions=[AnswerGenerationError, TimeoutError] + ) + } + + def _build_fallback_strategies(self) -> Dict[ErrorCategory, Callable]: + """Build fallback strategies for different error categories.""" + return { + ErrorCategory.QUESTION_ANALYSIS: self._fallback_question_analysis, + ErrorCategory.QUERY_GENERATION: self._fallback_query_generation, + ErrorCategory.QUERY_EXECUTION: self._fallback_query_execution, + ErrorCategory.ANSWER_GENERATION: self._fallback_answer_generation, + ErrorCategory.BACKEND_CONNECTION: self._fallback_backend_connection + } + + async def handle_error(self, + error: Exception, + context: ErrorContext, + operation: Callable, + *args, + **kwargs) -> Any: + """Handle error with recovery strategies. 
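+
+        Illustrative call (names hypothetical); the failing operation is
+        passed in so it can be retried with backoff:
+
+            result = await strategy.handle_error(
+                exc, context, executor.execute, cypher_query)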
+ + Args: + error: The exception that occurred + context: Error context + operation: Function to retry + *args: Operation arguments + **kwargs: Operation keyword arguments + + Returns: + Result of successful operation or fallback + """ + logger.error(f"Handling error in {context.component}.{context.operation}: {error}") + + # Update error counters + error_key = f"{context.category.value}:{context.component}" + self.error_counters[error_key] = self.error_counters.get(error_key, 0) + 1 + + # Check circuit breaker + if self._is_circuit_open(error_key): + return await self._execute_fallback(context, *args, **kwargs) + + # Try retry if configured + retry_config = self.retry_configs.get(context.category) + if retry_config and context.retry_count < retry_config.max_retries: + if any(isinstance(error, exc_type) for exc_type in retry_config.retry_on_exceptions or []): + return await self._retry_operation( + operation, context, retry_config, *args, **kwargs + ) + + # Execute fallback strategy + return await self._execute_fallback(context, *args, **kwargs) + + async def _retry_operation(self, + operation: Callable, + context: ErrorContext, + retry_config: RetryConfig, + *args, + **kwargs) -> Any: + """Retry operation with backoff.""" + context.retry_count += 1 + + # Calculate delay + delay = retry_config.base_delay + if retry_config.exponential_backoff: + delay *= (2 ** (context.retry_count - 1)) + delay = min(delay, retry_config.max_delay) + + # Add jitter + if retry_config.jitter: + import random + delay *= (0.5 + random.random()) + + logger.info(f"Retrying {context.component}.{context.operation} " + f"(attempt {context.retry_count}) after {delay:.2f}s") + + await asyncio.sleep(delay) + + try: + if asyncio.iscoroutinefunction(operation): + return await operation(*args, **kwargs) + else: + return operation(*args, **kwargs) + except Exception as e: + return await self.handle_error(e, context, operation, *args, **kwargs) + + async def _execute_fallback(self, + context: ErrorContext, + *args, + **kwargs) -> Any: + """Execute fallback strategy.""" + fallback_func = self.fallback_strategies.get(context.category) + if fallback_func: + logger.info(f"Executing fallback for {context.category.value}") + try: + if asyncio.iscoroutinefunction(fallback_func): + return await fallback_func(context, *args, **kwargs) + else: + return fallback_func(context, *args, **kwargs) + except Exception as e: + logger.error(f"Fallback strategy failed: {e}") + + # Default fallback + return self._default_fallback(context) + + def _is_circuit_open(self, error_key: str) -> bool: + """Check if circuit breaker is open.""" + circuit = self.circuit_breakers.get(error_key, {}) + error_count = self.error_counters.get(error_key, 0) + error_threshold = self.config.get('circuit_breaker_threshold', 10) + window_seconds = self.config.get('circuit_breaker_window', 300) # 5 minutes + + current_time = time.time() + window_start = circuit.get('window_start', current_time) + + # Reset window if expired + if current_time - window_start > window_seconds: + self.circuit_breakers[error_key] = {'window_start': current_time} + self.error_counters[error_key] = 0 + return False + + return error_count >= error_threshold + + def _default_fallback(self, context: ErrorContext) -> Any: + """Default fallback response.""" + if context.category == ErrorCategory.ANSWER_GENERATION: + return "I'm sorry, I encountered an error while processing your question. Please try again." 
+        elif context.category == ErrorCategory.QUERY_EXECUTION:
+            return {"error": "Query execution failed", "results": []}
+        else:
+            return None
+
+    # Fallback strategy implementations
+
+    async def _fallback_question_analysis(self, context: ErrorContext, question: str, **kwargs):
+        """Fallback for question analysis."""
+        from .question_analyzer import QuestionComponents, QuestionType
+
+        # Simple keyword-based analysis
+        question_lower = question.lower()
+
+        # Determine question type
+        if any(word in question_lower for word in ['how many', 'count', 'number']):
+            question_type = QuestionType.AGGREGATION
+        elif question_lower.startswith(('is', 'are', 'does', 'can')):
+            question_type = QuestionType.BOOLEAN
+        elif any(word in question_lower for word in ['what', 'which', 'who', 'where']):
+            question_type = QuestionType.RETRIEVAL
+        else:
+            question_type = QuestionType.FACTUAL
+
+        # Extract simple entities (nouns)
+        import re
+        words = re.findall(r'\b[a-zA-Z]+\b', question)
+        entities = [word for word in words if len(word) > 3 and word.lower() not in
+                    {'what', 'which', 'where', 'when', 'who', 'how', 'does', 'are', 'the'}]
+
+        return QuestionComponents(
+            original_question=question,
+            normalized_question=question.lower(),
+            question_type=question_type,
+            entities=entities[:3],  # Limit to 3 entities
+            keywords=words[:5],  # Limit to 5 keywords
+            relationships=[],
+            constraints=[],
+            aggregations=['count'] if question_type == QuestionType.AGGREGATION else [],
+            expected_answer_type='text'
+        )
+
+    async def _fallback_query_generation(self, context: ErrorContext, **kwargs):
+        """Fallback for query generation."""
+        # Generate simple query based on available information
+        if 'sparql' in (context.metadata or {}).get('query_language', '').lower():
+            query = """
+PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
+PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
+
+SELECT ?subject ?predicate ?object WHERE {
+    ?subject ?predicate ?object .
+}
+LIMIT 10
+"""
+            from .sparql_generator import SPARQLQuery
+            return SPARQLQuery(
+                query=query,
+                variables=['subject', 'predicate', 'object'],
+                query_type='SELECT',
+                explanation='Fallback SPARQL query',
+                complexity_score=0.1
+            )
+        else:
+            query = "MATCH (n) RETURN n LIMIT 10"
+            from .cypher_generator import CypherQuery
+            return CypherQuery(
+                query=query,
+                parameters={},
+                variables=['n'],
+                explanation='Fallback Cypher query',
+                complexity_score=0.1
+            )
+
+    async def _fallback_query_execution(self, context: ErrorContext, **kwargs):
+        """Fallback for query execution."""
+        # Return empty results
+        if 'sparql' in (context.metadata or {}).get('query_language', '').lower():
+            from .sparql_cassandra import SPARQLResult
+            return SPARQLResult(
+                bindings=[],
+                variables=[],
+                execution_time=0.0
+            )
+        else:
+            from .cypher_executor import CypherResult
+            return CypherResult(
+                records=[],
+                summary={'type': 'fallback'},
+                metadata={'query': 'fallback'},
+                execution_time=0.0
+            )
+
+    async def _fallback_answer_generation(self, context: ErrorContext, question: str = None, **kwargs):
+        """Fallback for answer generation."""
+        fallback_messages = [
+            "I'm experiencing some technical difficulties. Please try rephrasing your question.",
+            "I couldn't process your question at the moment. Could you try asking it differently?",
+            "There seems to be an issue with my analysis. Please try again in a moment.",
+            "I'm having trouble understanding your question right now. Please try again."
+ ] + + import random + return random.choice(fallback_messages) + + async def _fallback_backend_connection(self, context: ErrorContext, **kwargs): + """Fallback for backend connection.""" + logger.warning(f"Backend connection failed for {context.component}") + # Could switch to alternative backend here + return None + + +def with_error_handling(category: ErrorCategory, + component: str, + operation: str, + severity: ErrorSeverity = ErrorSeverity.MEDIUM): + """Decorator for automatic error handling. + + Args: + category: Error category + component: Component name + operation: Operation name + severity: Error severity + """ + def decorator(func): + @wraps(func) + async def async_wrapper(*args, **kwargs): + try: + if asyncio.iscoroutinefunction(func): + return await func(*args, **kwargs) + else: + return func(*args, **kwargs) + except Exception as e: + context = ErrorContext( + category=category, + severity=severity, + component=component, + operation=operation, + technical_details=str(e), + metadata={'args': str(args), 'kwargs': str(kwargs)} + ) + + # Get error recovery strategy from first argument if it's available + error_strategy = None + if args and hasattr(args[0], '_error_strategy'): + error_strategy = args[0]._error_strategy + + if error_strategy: + return await error_strategy.handle_error(e, context, func, *args, **kwargs) + else: + # Re-raise as OntoRAG error + raise OntoRAGError( + f"Error in {component}.{operation}: {str(e)}", + context=context, + cause=e + ) + + @wraps(func) + def sync_wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + context = ErrorContext( + category=category, + severity=severity, + component=component, + operation=operation, + technical_details=str(e), + metadata={'args': str(args), 'kwargs': str(kwargs)} + ) + + raise OntoRAGError( + f"Error in {component}.{operation}: {str(e)}", + context=context, + cause=e + ) + + if asyncio.iscoroutinefunction(func): + return async_wrapper + else: + return sync_wrapper + + return decorator + + +class ErrorReporter: + """Reports and tracks errors for monitoring and debugging.""" + + def __init__(self, config: Dict[str, Any] = None): + """Initialize error reporter. + + Args: + config: Reporter configuration + """ + self.config = config or {} + self.error_log: List[Dict[str, Any]] = [] + self.max_log_size = self.config.get('max_log_size', 1000) + + def report_error(self, error: OntoRAGError): + """Report an error for tracking. 
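+
+        Illustrative usage:
+
+            reporter.report_error(OntoRAGError("query failed", context=ctx))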
+ + Args: + error: The error to report + """ + error_entry = { + 'timestamp': error.timestamp, + 'message': error.message, + 'category': error.context.category.value, + 'severity': error.context.severity.value, + 'component': error.context.component, + 'operation': error.context.operation, + 'retry_count': error.context.retry_count, + 'technical_details': error.context.technical_details, + 'stack_trace': traceback.format_exc() if error.cause else None + } + + self.error_log.append(error_entry) + + # Trim log if too large + if len(self.error_log) > self.max_log_size: + self.error_log = self.error_log[-self.max_log_size:] + + # Log based on severity + if error.context.severity == ErrorSeverity.CRITICAL: + logger.critical(f"CRITICAL ERROR: {error.message}") + elif error.context.severity == ErrorSeverity.HIGH: + logger.error(f"HIGH SEVERITY: {error.message}") + elif error.context.severity == ErrorSeverity.MEDIUM: + logger.warning(f"MEDIUM SEVERITY: {error.message}") + else: + logger.info(f"LOW SEVERITY: {error.message}") + + def get_error_summary(self) -> Dict[str, Any]: + """Get summary of recent errors. + + Returns: + Error summary statistics + """ + if not self.error_log: + return {'total_errors': 0} + + recent_errors = [ + e for e in self.error_log + if time.time() - e['timestamp'] < 3600 # Last hour + ] + + category_counts = {} + severity_counts = {} + component_counts = {} + + for error in recent_errors: + category_counts[error['category']] = category_counts.get(error['category'], 0) + 1 + severity_counts[error['severity']] = severity_counts.get(error['severity'], 0) + 1 + component_counts[error['component']] = component_counts.get(error['component'], 0) + 1 + + return { + 'total_errors': len(self.error_log), + 'recent_errors': len(recent_errors), + 'category_breakdown': category_counts, + 'severity_breakdown': severity_counts, + 'component_breakdown': component_counts, + 'most_recent_error': self.error_log[-1] if self.error_log else None + } \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/monitoring.py b/trustgraph-flow/trustgraph/query/ontology/monitoring.py new file mode 100644 index 00000000..3eac4175 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/monitoring.py @@ -0,0 +1,737 @@ +""" +Performance monitoring and metrics collection for OntoRAG. +Provides comprehensive monitoring of system performance, query patterns, and resource usage. 
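+
+A rough usage sketch (values illustrative):
+
+    monitor = PerformanceMonitor({"enabled": True, "retention_hours": 24})
+    monitor.record_request("cypher_executor", "execute", duration=0.12)
+    report = monitor.get_performance_report()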
+""" + +import logging +import time +import asyncio +import threading +from typing import Dict, Any, List, Optional, Callable +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from collections import defaultdict, deque +import statistics +from enum import Enum + +logger = logging.getLogger(__name__) + + +class MetricType(Enum): + """Types of metrics to collect.""" + COUNTER = "counter" + GAUGE = "gauge" + HISTOGRAM = "histogram" + TIMER = "timer" + + +@dataclass +class Metric: + """Individual metric data point.""" + name: str + value: float + timestamp: datetime + labels: Dict[str, str] = field(default_factory=dict) + metric_type: MetricType = MetricType.GAUGE + + +@dataclass +class TimerMetric: + """Timer metric for measuring duration.""" + name: str + start_time: float + labels: Dict[str, str] = field(default_factory=dict) + + def stop(self) -> float: + """Stop timer and return duration.""" + return time.time() - self.start_time + + +@dataclass +class PerformanceStats: + """Performance statistics for a component.""" + total_requests: int = 0 + successful_requests: int = 0 + failed_requests: int = 0 + avg_response_time: float = 0.0 + min_response_time: float = float('inf') + max_response_time: float = 0.0 + p95_response_time: float = 0.0 + p99_response_time: float = 0.0 + throughput_per_second: float = 0.0 + error_rate: float = 0.0 + + +@dataclass +class SystemHealth: + """Overall system health metrics.""" + status: str = "healthy" # healthy, degraded, unhealthy + uptime_seconds: float = 0.0 + cpu_usage_percent: float = 0.0 + memory_usage_percent: float = 0.0 + active_connections: int = 0 + queue_size: int = 0 + cache_hit_rate: float = 0.0 + error_rate: float = 0.0 + + +class MetricsCollector: + """Collects and stores metrics data.""" + + def __init__(self, max_metrics: int = 10000, retention_hours: int = 24): + """Initialize metrics collector. + + Args: + max_metrics: Maximum number of metrics to retain + retention_hours: Hours to retain metrics + """ + self.max_metrics = max_metrics + self.retention_hours = retention_hours + self.metrics: Dict[str, deque] = defaultdict(lambda: deque(maxlen=max_metrics)) + self.counters: Dict[str, float] = defaultdict(float) + self.gauges: Dict[str, float] = defaultdict(float) + self.timers: Dict[str, List[float]] = defaultdict(list) + self._lock = threading.RLock() + + def increment(self, name: str, value: float = 1.0, labels: Dict[str, str] = None): + """Increment a counter metric. + + Args: + name: Metric name + value: Value to increment by + labels: Metric labels + """ + with self._lock: + metric_key = self._build_key(name, labels) + self.counters[metric_key] += value + self._add_metric(name, value, MetricType.COUNTER, labels) + + def set_gauge(self, name: str, value: float, labels: Dict[str, str] = None): + """Set a gauge metric value. + + Args: + name: Metric name + value: Gauge value + labels: Metric labels + """ + with self._lock: + metric_key = self._build_key(name, labels) + self.gauges[metric_key] = value + self._add_metric(name, value, MetricType.GAUGE, labels) + + def record_timer(self, name: str, duration: float, labels: Dict[str, str] = None): + """Record a timer measurement. 
+ + Args: + name: Metric name + duration: Duration in seconds + labels: Metric labels + """ + with self._lock: + metric_key = self._build_key(name, labels) + self.timers[metric_key].append(duration) + + # Keep only recent measurements + max_timer_values = 1000 + if len(self.timers[metric_key]) > max_timer_values: + self.timers[metric_key] = self.timers[metric_key][-max_timer_values:] + + self._add_metric(name, duration, MetricType.TIMER, labels) + + def start_timer(self, name: str, labels: Dict[str, str] = None) -> TimerMetric: + """Start a timer. + + Args: + name: Metric name + labels: Metric labels + + Returns: + Timer metric object + """ + return TimerMetric(name=name, start_time=time.time(), labels=labels or {}) + + def stop_timer(self, timer: TimerMetric): + """Stop a timer and record the measurement. + + Args: + timer: Timer metric to stop + """ + duration = timer.stop() + self.record_timer(timer.name, duration, timer.labels) + return duration + + def get_counter(self, name: str, labels: Dict[str, str] = None) -> float: + """Get counter value. + + Args: + name: Metric name + labels: Metric labels + + Returns: + Counter value + """ + metric_key = self._build_key(name, labels) + return self.counters.get(metric_key, 0.0) + + def get_gauge(self, name: str, labels: Dict[str, str] = None) -> float: + """Get gauge value. + + Args: + name: Metric name + labels: Metric labels + + Returns: + Gauge value + """ + metric_key = self._build_key(name, labels) + return self.gauges.get(metric_key, 0.0) + + def get_timer_stats(self, name: str, labels: Dict[str, str] = None) -> Dict[str, float]: + """Get timer statistics. + + Args: + name: Metric name + labels: Metric labels + + Returns: + Timer statistics + """ + metric_key = self._build_key(name, labels) + values = self.timers.get(metric_key, []) + + if not values: + return {} + + sorted_values = sorted(values) + return { + 'count': len(values), + 'sum': sum(values), + 'avg': statistics.mean(values), + 'min': min(values), + 'max': max(values), + 'p50': sorted_values[int(len(sorted_values) * 0.5)], + 'p95': sorted_values[int(len(sorted_values) * 0.95)], + 'p99': sorted_values[int(len(sorted_values) * 0.99)] + } + + def get_metrics(self, + name_pattern: Optional[str] = None, + since: Optional[datetime] = None) -> List[Metric]: + """Get metrics matching pattern and time range. 
+ + Args: + name_pattern: Pattern to match metric names + since: Only return metrics since this time + + Returns: + List of matching metrics + """ + with self._lock: + results = [] + cutoff_time = since or datetime.now() - timedelta(hours=self.retention_hours) + + for metric_name, metric_queue in self.metrics.items(): + if name_pattern and name_pattern not in metric_name: + continue + + for metric in metric_queue: + if metric.timestamp >= cutoff_time: + results.append(metric) + + return sorted(results, key=lambda m: m.timestamp) + + def cleanup_old_metrics(self): + """Remove old metrics beyond retention period.""" + with self._lock: + cutoff_time = datetime.now() - timedelta(hours=self.retention_hours) + + for metric_name in list(self.metrics.keys()): + metric_queue = self.metrics[metric_name] + # Remove old metrics + while metric_queue and metric_queue[0].timestamp < cutoff_time: + metric_queue.popleft() + + # Remove empty queues + if not metric_queue: + del self.metrics[metric_name] + + def _add_metric(self, name: str, value: float, metric_type: MetricType, labels: Dict[str, str]): + """Add metric to storage.""" + metric = Metric( + name=name, + value=value, + timestamp=datetime.now(), + labels=labels or {}, + metric_type=metric_type + ) + self.metrics[name].append(metric) + + def _build_key(self, name: str, labels: Dict[str, str]) -> str: + """Build metric key from name and labels.""" + if not labels: + return name + + label_str = ','.join(f"{k}={v}" for k, v in sorted(labels.items())) + return f"{name}{{{label_str}}}" + + +class PerformanceMonitor: + """Monitors system performance and component health.""" + + def __init__(self, config: Dict[str, Any] = None): + """Initialize performance monitor. + + Args: + config: Monitor configuration + """ + self.config = config or {} + self.metrics_collector = MetricsCollector( + max_metrics=self.config.get('max_metrics', 10000), + retention_hours=self.config.get('retention_hours', 24) + ) + + self.component_stats: Dict[str, PerformanceStats] = {} + self.start_time = time.time() + self.monitoring_enabled = self.config.get('enabled', True) + + # Start background monitoring tasks + if self.monitoring_enabled: + self._start_background_tasks() + + def record_request(self, + component: str, + operation: str, + duration: float, + success: bool = True, + labels: Dict[str, str] = None): + """Record a request completion. + + Args: + component: Component name + operation: Operation name + duration: Request duration in seconds + success: Whether request was successful + labels: Additional labels + """ + if not self.monitoring_enabled: + return + + base_labels = {'component': component, 'operation': operation} + if labels: + base_labels.update(labels) + + # Record metrics + self.metrics_collector.increment('requests_total', labels=base_labels) + self.metrics_collector.record_timer('request_duration', duration, base_labels) + + if success: + self.metrics_collector.increment('requests_successful', labels=base_labels) + else: + self.metrics_collector.increment('requests_failed', labels=base_labels) + + # Update component stats + self._update_component_stats(component, duration, success) + + def record_query_complexity(self, + complexity_score: float, + query_type: str, + backend: str): + """Record query complexity metrics. 
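+
+        Illustrative usage:
+
+            monitor.record_query_complexity(0.4, query_type="cypher", backend="neo4j")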
+ + Args: + complexity_score: Query complexity score (0.0 to 1.0) + query_type: Type of query (SPARQL, Cypher) + backend: Backend used + """ + if not self.monitoring_enabled: + return + + labels = {'query_type': query_type, 'backend': backend} + self.metrics_collector.set_gauge('query_complexity', complexity_score, labels) + + def record_cache_access(self, hit: bool, cache_type: str = 'default'): + """Record cache access. + + Args: + hit: Whether it was a cache hit + cache_type: Type of cache + """ + if not self.monitoring_enabled: + return + + labels = {'cache_type': cache_type} + self.metrics_collector.increment('cache_requests_total', labels=labels) + + if hit: + self.metrics_collector.increment('cache_hits_total', labels=labels) + else: + self.metrics_collector.increment('cache_misses_total', labels=labels) + + def record_ontology_selection(self, + selected_elements: int, + total_elements: int, + ontology_id: str): + """Record ontology selection metrics. + + Args: + selected_elements: Number of selected ontology elements + total_elements: Total ontology elements + ontology_id: Ontology identifier + """ + if not self.monitoring_enabled: + return + + labels = {'ontology_id': ontology_id} + self.metrics_collector.set_gauge('ontology_elements_selected', selected_elements, labels) + self.metrics_collector.set_gauge('ontology_elements_total', total_elements, labels) + + selection_ratio = selected_elements / total_elements if total_elements > 0 else 0 + self.metrics_collector.set_gauge('ontology_selection_ratio', selection_ratio, labels) + + def get_component_stats(self, component: str) -> Optional[PerformanceStats]: + """Get performance statistics for a component. + + Args: + component: Component name + + Returns: + Performance statistics or None + """ + return self.component_stats.get(component) + + def get_system_health(self) -> SystemHealth: + """Get overall system health status. + + Returns: + System health metrics + """ + # Calculate uptime + uptime = time.time() - self.start_time + + # Get error rate + total_requests = self.metrics_collector.get_counter('requests_total') + failed_requests = self.metrics_collector.get_counter('requests_failed') + error_rate = failed_requests / total_requests if total_requests > 0 else 0.0 + + # Get cache hit rate + cache_hits = self.metrics_collector.get_counter('cache_hits_total') + cache_requests = self.metrics_collector.get_counter('cache_requests_total') + cache_hit_rate = cache_hits / cache_requests if cache_requests > 0 else 0.0 + + # Determine status + status = "healthy" + if error_rate > 0.1: # More than 10% error rate + status = "degraded" + if error_rate > 0.3: # More than 30% error rate + status = "unhealthy" + + return SystemHealth( + status=status, + uptime_seconds=uptime, + error_rate=error_rate, + cache_hit_rate=cache_hit_rate + ) + + def get_performance_report(self) -> Dict[str, Any]: + """Get comprehensive performance report. 
+ + Returns: + Performance report + """ + report = { + 'system_health': self.get_system_health(), + 'component_stats': {}, + 'top_slow_operations': [], + 'error_patterns': {}, + 'cache_performance': {}, + 'ontology_usage': {} + } + + # Component statistics + for component, stats in self.component_stats.items(): + report['component_stats'][component] = stats + + # Top slow operations + timer_stats = {} + for metric_name in self.metrics_collector.timers.keys(): + if 'request_duration' in metric_name: + stats = self.metrics_collector.get_timer_stats(metric_name) + if stats: + timer_stats[metric_name] = stats + + # Sort by p95 latency + slow_ops = sorted( + timer_stats.items(), + key=lambda x: x[1].get('p95', 0), + reverse=True + )[:10] + + report['top_slow_operations'] = [ + {'operation': op, 'stats': stats} for op, stats in slow_ops + ] + + # Cache performance + cache_types = set() + for metric_name in self.metrics_collector.counters.keys(): + if 'cache_type=' in metric_name: + cache_type = metric_name.split('cache_type=')[1].split(',')[0].split('}')[0] + cache_types.add(cache_type) + + for cache_type in cache_types: + labels = {'cache_type': cache_type} + hits = self.metrics_collector.get_counter('cache_hits_total', labels) + requests = self.metrics_collector.get_counter('cache_requests_total', labels) + hit_rate = hits / requests if requests > 0 else 0.0 + + report['cache_performance'][cache_type] = { + 'hit_rate': hit_rate, + 'total_requests': requests, + 'total_hits': hits + } + + return report + + def _update_component_stats(self, component: str, duration: float, success: bool): + """Update component performance statistics.""" + if component not in self.component_stats: + self.component_stats[component] = PerformanceStats() + + stats = self.component_stats[component] + stats.total_requests += 1 + + if success: + stats.successful_requests += 1 + else: + stats.failed_requests += 1 + + # Update response time stats + stats.min_response_time = min(stats.min_response_time, duration) + stats.max_response_time = max(stats.max_response_time, duration) + + # Get timer stats for percentiles + timer_stats = self.metrics_collector.get_timer_stats( + 'request_duration', {'component': component} + ) + + if timer_stats: + stats.avg_response_time = timer_stats.get('avg', 0.0) + stats.p95_response_time = timer_stats.get('p95', 0.0) + stats.p99_response_time = timer_stats.get('p99', 0.0) + + # Calculate rates + stats.error_rate = stats.failed_requests / stats.total_requests + + # Calculate throughput (requests per second over last minute) + recent_requests = len([ + m for m in self.metrics_collector.get_metrics('requests_total') + if m.labels.get('component') == component and + m.timestamp > datetime.now() - timedelta(minutes=1) + ]) + stats.throughput_per_second = recent_requests / 60.0 + + def _start_background_tasks(self): + """Start background monitoring tasks.""" + def cleanup_worker(): + """Worker to clean up old metrics.""" + while True: + try: + time.sleep(300) # 5 minutes + self.metrics_collector.cleanup_old_metrics() + except Exception as e: + logger.error(f"Metrics cleanup error: {e}") + + cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True) + cleanup_thread.start() + + +# Monitoring decorators + +def monitor_performance(component: str, + operation: str, + monitor: Optional[PerformanceMonitor] = None): + """Decorator to monitor function performance. 
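+
+    Illustrative usage (monitor is a PerformanceMonitor instance):
+
+        @monitor_performance("cypher_executor", "execute", monitor=monitor)
+        async def execute(query):
+            ...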
+ + Args: + component: Component name + operation: Operation name + monitor: Performance monitor instance + """ + def decorator(func): + def wrapper(*args, **kwargs): + if not monitor or not monitor.monitoring_enabled: + return func(*args, **kwargs) + + timer = monitor.metrics_collector.start_timer( + 'request_duration', + {'component': component, 'operation': operation} + ) + + success = True + try: + result = func(*args, **kwargs) + return result + except Exception as e: + success = False + raise + finally: + duration = monitor.metrics_collector.stop_timer(timer) + monitor.record_request(component, operation, duration, success) + + async def async_wrapper(*args, **kwargs): + if not monitor or not monitor.monitoring_enabled: + if asyncio.iscoroutinefunction(func): + return await func(*args, **kwargs) + else: + return func(*args, **kwargs) + + timer = monitor.metrics_collector.start_timer( + 'request_duration', + {'component': component, 'operation': operation} + ) + + success = True + try: + if asyncio.iscoroutinefunction(func): + result = await func(*args, **kwargs) + else: + result = func(*args, **kwargs) + return result + except Exception as e: + success = False + raise + finally: + duration = monitor.metrics_collector.stop_timer(timer) + monitor.record_request(component, operation, duration, success) + + if asyncio.iscoroutinefunction(func): + return async_wrapper + else: + return wrapper + + return decorator + + +class QueryPatternAnalyzer: + """Analyzes query patterns for optimization insights.""" + + def __init__(self, monitor: PerformanceMonitor): + """Initialize query pattern analyzer. + + Args: + monitor: Performance monitor instance + """ + self.monitor = monitor + self.query_patterns: Dict[str, List[Dict[str, Any]]] = defaultdict(list) + + def record_query_pattern(self, + question_type: str, + entities: List[str], + complexity: float, + backend: str, + duration: float, + success: bool): + """Record a query pattern for analysis. + + Args: + question_type: Type of question + entities: Entities in question + complexity: Query complexity score + backend: Backend used + duration: Query duration + success: Whether query succeeded + """ + pattern = { + 'timestamp': datetime.now(), + 'question_type': question_type, + 'entity_count': len(entities), + 'entities': entities, + 'complexity': complexity, + 'backend': backend, + 'duration': duration, + 'success': success + } + + pattern_key = f"{question_type}:{len(entities)}" + self.query_patterns[pattern_key].append(pattern) + + # Keep only recent patterns + cutoff_time = datetime.now() - timedelta(hours=24) + self.query_patterns[pattern_key] = [ + p for p in self.query_patterns[pattern_key] + if p['timestamp'] > cutoff_time + ] + + def get_optimization_insights(self) -> Dict[str, Any]: + """Get insights for query optimization. 
+ + Returns: + Optimization insights and recommendations + """ + insights = { + 'slow_patterns': [], + 'common_failures': [], + 'backend_performance': {}, + 'complexity_analysis': {}, + 'recommendations': [] + } + + # Analyze slow patterns + for pattern_key, patterns in self.query_patterns.items(): + if not patterns: + continue + + avg_duration = statistics.mean([p['duration'] for p in patterns]) + success_rate = sum(1 for p in patterns if p['success']) / len(patterns) + + if avg_duration > 5.0: # Slow queries > 5 seconds + insights['slow_patterns'].append({ + 'pattern': pattern_key, + 'avg_duration': avg_duration, + 'count': len(patterns), + 'success_rate': success_rate + }) + + if success_rate < 0.8: # Low success rate + insights['common_failures'].append({ + 'pattern': pattern_key, + 'success_rate': success_rate, + 'count': len(patterns) + }) + + # Analyze backend performance + backend_stats = defaultdict(list) + for patterns in self.query_patterns.values(): + for pattern in patterns: + backend_stats[pattern['backend']].append(pattern['duration']) + + for backend, durations in backend_stats.items(): + insights['backend_performance'][backend] = { + 'avg_duration': statistics.mean(durations), + 'p95_duration': sorted(durations)[int(len(durations) * 0.95)], + 'query_count': len(durations) + } + + # Generate recommendations + recommendations = [] + + # Slow pattern recommendations + for slow_pattern in insights['slow_patterns']: + recommendations.append( + f"Consider optimizing {slow_pattern['pattern']} queries - " + f"average duration {slow_pattern['avg_duration']:.2f}s" + ) + + # Backend recommendations + if len(insights['backend_performance']) > 1: + fastest_backend = min( + insights['backend_performance'].items(), + key=lambda x: x[1]['avg_duration'] + )[0] + recommendations.append( + f"Consider routing more queries to {fastest_backend} " + f"for better performance" + ) + + insights['recommendations'] = recommendations + + return insights \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/multi_language.py b/trustgraph-flow/trustgraph/query/ontology/multi_language.py new file mode 100644 index 00000000..d7b7883a --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/multi_language.py @@ -0,0 +1,656 @@ +""" +Multi-language support for OntoRAG. +Provides language detection, translation, and multilingual query processing. +""" + +import logging +from typing import Dict, Any, List, Optional, Tuple +from dataclasses import dataclass +from enum import Enum + +logger = logging.getLogger(__name__) + + +class Language(Enum): + """Supported languages.""" + ENGLISH = "en" + SPANISH = "es" + FRENCH = "fr" + GERMAN = "de" + ITALIAN = "it" + PORTUGUESE = "pt" + CHINESE = "zh" + JAPANESE = "ja" + KOREAN = "ko" + ARABIC = "ar" + RUSSIAN = "ru" + DUTCH = "nl" + + +@dataclass +class LanguageDetectionResult: + """Language detection result.""" + language: Language + confidence: float + detected_text: str + alternative_languages: List[Tuple[Language, float]] = None + + +@dataclass +class TranslationResult: + """Translation result.""" + original_text: str + translated_text: str + source_language: Language + target_language: Language + confidence: float + + +class LanguageDetector: + """Detects language of input text.""" + + def __init__(self, config: Dict[str, Any] = None): + """Initialize language detector. 
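+
+        Example config (keys read below; values illustrative):
+
+            {"default_language": "en", "confidence_threshold": 0.7}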
+ + Args: + config: Detector configuration + """ + self.config = config or {} + self.default_language = Language(self.config.get('default_language', 'en')) + self.confidence_threshold = self.config.get('confidence_threshold', 0.7) + + # Try to import language detection libraries + self.detector = None + self._init_detector() + + def _init_detector(self): + """Initialize language detection backend.""" + try: + # Try langdetect first + import langdetect + self.detector = 'langdetect' + logger.info("Using langdetect for language detection") + except ImportError: + try: + # Try textblob as fallback + from textblob import TextBlob + self.detector = 'textblob' + logger.info("Using TextBlob for language detection") + except ImportError: + logger.warning("No language detection library available, using rule-based detection") + self.detector = 'rule_based' + + def detect_language(self, text: str) -> LanguageDetectionResult: + """Detect language of input text. + + Args: + text: Text to analyze + + Returns: + Language detection result + """ + if not text or not text.strip(): + return LanguageDetectionResult( + language=self.default_language, + confidence=0.0, + detected_text=text + ) + + try: + if self.detector == 'langdetect': + return self._detect_with_langdetect(text) + elif self.detector == 'textblob': + return self._detect_with_textblob(text) + else: + return self._detect_with_rules(text) + + except Exception as e: + logger.error(f"Language detection failed: {e}") + return LanguageDetectionResult( + language=self.default_language, + confidence=0.0, + detected_text=text + ) + + def _detect_with_langdetect(self, text: str) -> LanguageDetectionResult: + """Detect language using langdetect library.""" + import langdetect + from langdetect.lang_detect_exception import LangDetectException + + try: + # Get detailed detection results + probabilities = langdetect.detect_langs(text) + + if not probabilities: + return LanguageDetectionResult( + language=self.default_language, + confidence=0.0, + detected_text=text + ) + + best_match = probabilities[0] + detected_lang_code = best_match.lang + confidence = best_match.prob + + # Map to our Language enum + try: + detected_language = Language(detected_lang_code) + except ValueError: + # Map common variations + lang_mapping = { + 'ca': Language.SPANISH, # Catalan -> Spanish + 'eu': Language.SPANISH, # Basque -> Spanish + 'gl': Language.SPANISH, # Galician -> Spanish + 'zh-cn': Language.CHINESE, + 'zh-tw': Language.CHINESE, + } + detected_language = lang_mapping.get(detected_lang_code, self.default_language) + + # Get alternatives + alternatives = [] + for lang_prob in probabilities[1:3]: # Top 3 alternatives + try: + alt_lang = Language(lang_prob.lang) + alternatives.append((alt_lang, lang_prob.prob)) + except ValueError: + continue + + return LanguageDetectionResult( + language=detected_language, + confidence=confidence, + detected_text=text, + alternative_languages=alternatives + ) + + except LangDetectException: + return LanguageDetectionResult( + language=self.default_language, + confidence=0.0, + detected_text=text + ) + + def _detect_with_textblob(self, text: str) -> LanguageDetectionResult: + """Detect language using TextBlob.""" + from textblob import TextBlob + + try: + blob = TextBlob(text) + detected_lang_code = blob.detect_language() + + try: + detected_language = Language(detected_lang_code) + except ValueError: + detected_language = self.default_language + + # TextBlob doesn't provide confidence, so estimate based on text length + confidence = 
min(0.8, len(text) / 100.0) if len(text) > 10 else 0.5 + + return LanguageDetectionResult( + language=detected_language, + confidence=confidence, + detected_text=text + ) + + except Exception: + return LanguageDetectionResult( + language=self.default_language, + confidence=0.0, + detected_text=text + ) + + def _detect_with_rules(self, text: str) -> LanguageDetectionResult: + """Rule-based language detection fallback.""" + text_lower = text.lower() + + # Simple keyword-based detection + language_keywords = { + Language.SPANISH: ['qué', 'cuál', 'cuándo', 'dónde', 'cómo', 'por qué', 'cuántos'], + Language.FRENCH: ['que', 'quel', 'quand', 'où', 'comment', 'pourquoi', 'combien'], + Language.GERMAN: ['was', 'welche', 'wann', 'wo', 'wie', 'warum', 'wieviele'], + Language.ITALIAN: ['che', 'quale', 'quando', 'dove', 'come', 'perché', 'quanti'], + Language.PORTUGUESE: ['que', 'qual', 'quando', 'onde', 'como', 'por que', 'quantos'], + Language.DUTCH: ['wat', 'welke', 'wanneer', 'waar', 'hoe', 'waarom', 'hoeveel'] + } + + best_match = self.default_language + best_score = 0 + + for language, keywords in language_keywords.items(): + score = sum(1 for keyword in keywords if keyword in text_lower) + if score > best_score: + best_score = score + best_match = language + + confidence = min(0.8, best_score / 3.0) if best_score > 0 else 0.1 + + return LanguageDetectionResult( + language=best_match, + confidence=confidence, + detected_text=text + ) + + +class TextTranslator: + """Translates text between languages.""" + + def __init__(self, config: Dict[str, Any] = None): + """Initialize text translator. + + Args: + config: Translator configuration + """ + self.config = config or {} + self.translator = None + self._init_translator() + + def _init_translator(self): + """Initialize translation backend.""" + try: + # Try Google Translate first + from googletrans import Translator + self.translator = Translator() + self.backend = 'googletrans' + logger.info("Using Google Translate for translation") + except ImportError: + try: + # Try TextBlob as fallback + from textblob import TextBlob + self.backend = 'textblob' + logger.info("Using TextBlob for translation") + except ImportError: + logger.warning("No translation library available") + self.backend = None + + def translate(self, + text: str, + target_language: Language, + source_language: Optional[Language] = None) -> TranslationResult: + """Translate text to target language. 
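+
+        Illustrative usage:
+
+            result = translator.translate("How many documents are there?",
+                                          Language.SPANISH)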
+ + Args: + text: Text to translate + target_language: Target language + source_language: Source language (auto-detect if None) + + Returns: + Translation result + """ + if not text or not text.strip(): + return TranslationResult( + original_text=text, + translated_text=text, + source_language=source_language or Language.ENGLISH, + target_language=target_language, + confidence=0.0 + ) + + try: + if self.backend == 'googletrans': + return self._translate_with_googletrans(text, target_language, source_language) + elif self.backend == 'textblob': + return self._translate_with_textblob(text, target_language, source_language) + else: + # No translation available + return TranslationResult( + original_text=text, + translated_text=text, + source_language=source_language or Language.ENGLISH, + target_language=target_language, + confidence=0.0 + ) + + except Exception as e: + logger.error(f"Translation failed: {e}") + return TranslationResult( + original_text=text, + translated_text=text, + source_language=source_language or Language.ENGLISH, + target_language=target_language, + confidence=0.0 + ) + + def _translate_with_googletrans(self, + text: str, + target_language: Language, + source_language: Optional[Language]) -> TranslationResult: + """Translate using Google Translate.""" + try: + src_code = source_language.value if source_language else 'auto' + dest_code = target_language.value + + result = self.translator.translate(text, src=src_code, dest=dest_code) + + detected_source = Language(result.src) if result.src != 'auto' else Language.ENGLISH + confidence = 0.9 # Google Translate is generally reliable + + return TranslationResult( + original_text=text, + translated_text=result.text, + source_language=detected_source, + target_language=target_language, + confidence=confidence + ) + + except Exception as e: + logger.error(f"Google Translate error: {e}") + raise + + def _translate_with_textblob(self, + text: str, + target_language: Language, + source_language: Optional[Language]) -> TranslationResult: + """Translate using TextBlob.""" + from textblob import TextBlob + + try: + blob = TextBlob(text) + + if not source_language: + # Auto-detect source language + detected_lang = blob.detect_language() + try: + source_language = Language(detected_lang) + except ValueError: + source_language = Language.ENGLISH + + translated_blob = blob.translate(to=target_language.value) + translated_text = str(translated_blob) + + # TextBlob confidence estimation + confidence = 0.7 if len(text) > 10 else 0.5 + + return TranslationResult( + original_text=text, + translated_text=translated_text, + source_language=source_language, + target_language=target_language, + confidence=confidence + ) + + except Exception as e: + logger.error(f"TextBlob translation error: {e}") + raise + + +class MultiLanguageQueryProcessor: + """Processes queries in multiple languages.""" + + def __init__(self, config: Dict[str, Any] = None): + """Initialize multi-language query processor. + + Args: + config: Processor configuration + """ + self.config = config or {} + self.language_detector = LanguageDetector(config.get('language_detection', {})) + self.translator = TextTranslator(config.get('translation', {})) + self.supported_languages = [Language(lang) for lang in config.get('supported_languages', ['en'])] + self.primary_language = Language(config.get('primary_language', 'en')) + + async def process_multilingual_query(self, question: str) -> Dict[str, Any]: + """Process a query in any supported language. 
+ + Args: + question: Question in any language + + Returns: + Processing result with language information + """ + # Step 1: Detect language + detection_result = self.language_detector.detect_language(question) + detected_language = detection_result.language + + logger.info(f"Detected language: {detected_language.value} " + f"(confidence: {detection_result.confidence:.2f})") + + # Step 2: Translate to primary language if needed + translated_question = question + translation_result = None + + if detected_language != self.primary_language: + if detection_result.confidence >= self.language_detector.confidence_threshold: + translation_result = self.translator.translate( + question, self.primary_language, detected_language + ) + translated_question = translation_result.translated_text + logger.info(f"Translated question: {translated_question}") + else: + logger.warning(f"Low confidence language detection, processing in {self.primary_language.value}") + + # Step 3: Return processing information + return { + 'original_question': question, + 'translated_question': translated_question, + 'detected_language': detected_language, + 'detection_confidence': detection_result.confidence, + 'translation_result': translation_result, + 'processing_language': self.primary_language, + 'alternative_languages': detection_result.alternative_languages + } + + async def translate_answer(self, + answer: str, + target_language: Language) -> TranslationResult: + """Translate answer back to target language. + + Args: + answer: Answer in primary language + target_language: Target language for answer + + Returns: + Translation result + """ + if target_language == self.primary_language: + # No translation needed + return TranslationResult( + original_text=answer, + translated_text=answer, + source_language=self.primary_language, + target_language=target_language, + confidence=1.0 + ) + + return self.translator.translate(answer, target_language, self.primary_language) + + def get_language_specific_ontology_terms(self, + ontology_subset: Dict[str, Any], + language: Language) -> Dict[str, Any]: + """Get language-specific terms from ontology. + + Args: + ontology_subset: Ontology subset + language: Target language + + Returns: + Language-specific ontology terms + """ + # Extract language-specific labels and descriptions + lang_code = language.value + result = {} + + # Process classes + if 'classes' in ontology_subset: + result['classes'] = {} + for class_id, class_def in ontology_subset['classes'].items(): + lang_labels = [] + if 'labels' in class_def: + for label in class_def['labels']: + if isinstance(label, dict) and label.get('language') == lang_code: + lang_labels.append(label['value']) + elif isinstance(label, str): + lang_labels.append(label) + + result['classes'][class_id] = { + **class_def, + 'language_labels': lang_labels + } + + # Process properties + for prop_type in ['object_properties', 'datatype_properties']: + if prop_type in ontology_subset: + result[prop_type] = {} + for prop_id, prop_def in ontology_subset[prop_type].items(): + lang_labels = [] + if 'labels' in prop_def: + for label in prop_def['labels']: + if isinstance(label, dict) and label.get('language') == lang_code: + lang_labels.append(label['value']) + elif isinstance(label, str): + lang_labels.append(label) + + result[prop_type][prop_id] = { + **prop_def, + 'language_labels': lang_labels + } + + return result + + def is_language_supported(self, language: Language) -> bool: + """Check if language is supported. 
+ + Args: + language: Language to check + + Returns: + True if language is supported + """ + return language in self.supported_languages + + def get_supported_languages(self) -> List[Language]: + """Get list of supported languages. + + Returns: + List of supported languages + """ + return self.supported_languages.copy() + + def add_language_support(self, language: Language): + """Add support for a new language. + + Args: + language: Language to add support for + """ + if language not in self.supported_languages: + self.supported_languages.append(language) + logger.info(f"Added support for language: {language.value}") + + def remove_language_support(self, language: Language): + """Remove support for a language. + + Args: + language: Language to remove support for + """ + if language in self.supported_languages and language != self.primary_language: + self.supported_languages.remove(language) + logger.info(f"Removed support for language: {language.value}") + else: + logger.warning(f"Cannot remove primary language or unsupported language: {language.value}") + + +class LanguageSpecificTemplates: + """Manages language-specific query and answer templates.""" + + def __init__(self): + """Initialize language-specific templates.""" + self.question_templates = { + Language.ENGLISH: { + 'count': ['how many', 'count of', 'number of'], + 'boolean': ['is', 'are', 'does', 'can', 'will'], + 'retrieval': ['what', 'which', 'who', 'where'], + 'factual': ['tell me about', 'describe', 'explain'] + }, + Language.SPANISH: { + 'count': ['cuántos', 'cuántas', 'número de', 'cantidad de'], + 'boolean': ['es', 'son', 'está', 'están', 'puede', 'pueden'], + 'retrieval': ['qué', 'cuál', 'cuáles', 'quién', 'dónde'], + 'factual': ['dime sobre', 'describe', 'explica'] + }, + Language.FRENCH: { + 'count': ['combien', 'nombre de', 'quantité de'], + 'boolean': ['est', 'sont', 'peut', 'peuvent'], + 'retrieval': ['que', 'quel', 'quelle', 'qui', 'où'], + 'factual': ['dis-moi sur', 'décris', 'explique'] + }, + Language.GERMAN: { + 'count': ['wie viele', 'anzahl der', 'zahl der'], + 'boolean': ['ist', 'sind', 'kann', 'können'], + 'retrieval': ['was', 'welche', 'wer', 'wo'], + 'factual': ['erzähl mir über', 'beschreibe', 'erkläre'] + } + } + + self.answer_templates = { + Language.ENGLISH: { + 'count': 'There are {count} {entity}.', + 'boolean_true': 'Yes, {statement}.', + 'boolean_false': 'No, {statement}.', + 'not_found': 'No information found.', + 'error': 'Sorry, I encountered an error.' + }, + Language.SPANISH: { + 'count': 'Hay {count} {entity}.', + 'boolean_true': 'Sí, {statement}.', + 'boolean_false': 'No, {statement}.', + 'not_found': 'No se encontró información.', + 'error': 'Lo siento, encontré un error.' + }, + Language.FRENCH: { + 'count': 'Il y a {count} {entity}.', + 'boolean_true': 'Oui, {statement}.', + 'boolean_false': 'Non, {statement}.', + 'not_found': 'Aucune information trouvée.', + 'error': 'Désolé, j\'ai rencontré une erreur.' + }, + Language.GERMAN: { + 'count': 'Es gibt {count} {entity}.', + 'boolean_true': 'Ja, {statement}.', + 'boolean_false': 'Nein, {statement}.', + 'not_found': 'Keine Informationen gefunden.', + 'error': 'Entschuldigung, ich bin auf einen Fehler gestoßen.' + } + } + + def get_question_patterns(self, language: Language) -> Dict[str, List[str]]: + """Get question patterns for a language. 
+ + Args: + language: Target language + + Returns: + Dictionary of question patterns + """ + return self.question_templates.get(language, self.question_templates[Language.ENGLISH]) + + def get_answer_template(self, language: Language, template_type: str) -> str: + """Get answer template for a language and type. + + Args: + language: Target language + template_type: Template type + + Returns: + Answer template string + """ + templates = self.answer_templates.get(language, self.answer_templates[Language.ENGLISH]) + return templates.get(template_type, templates.get('error', 'Error')) + + def format_answer(self, + language: Language, + template_type: str, + **kwargs) -> str: + """Format answer using language-specific template. + + Args: + language: Target language + template_type: Template type + **kwargs: Template variables + + Returns: + Formatted answer + """ + template = self.get_answer_template(language, template_type) + try: + return template.format(**kwargs) + except KeyError as e: + logger.error(f"Missing template variable: {e}") + return self.get_answer_template(language, 'error') \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/ontology_matcher.py b/trustgraph-flow/trustgraph/query/ontology/ontology_matcher.py new file mode 100644 index 00000000..895856f3 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/ontology_matcher.py @@ -0,0 +1,256 @@ +""" +Ontology matcher for query system. +Identifies relevant ontology subsets for answering questions. +""" + +import logging +from typing import List, Dict, Any, Set, Optional +from dataclasses import dataclass + +from ...extract.kg.ontology.ontology_loader import Ontology, OntologyLoader +from ...extract.kg.ontology.ontology_embedder import OntologyEmbedder +from ...extract.kg.ontology.text_processor import TextSegment +from ...extract.kg.ontology.ontology_selector import OntologySelector, OntologySubset +from .question_analyzer import QuestionComponents, QuestionType + +logger = logging.getLogger(__name__) + + +@dataclass +class QueryOntologySubset(OntologySubset): + """Extended ontology subset for query processing.""" + traversal_properties: Dict[str, Any] = None # Additional properties for graph traversal + inference_rules: List[Dict[str, Any]] = None # Inference rules for reasoning + + +class OntologyMatcherForQueries(OntologySelector): + """ + Specialized ontology matcher for question answering. + Extends OntologySelector with query-specific logic. + """ + + def __init__(self, ontology_embedder: OntologyEmbedder, + ontology_loader: OntologyLoader, + top_k: int = 15, # Higher k for queries + similarity_threshold: float = 0.6): # Lower threshold for broader coverage + """Initialize query-specific ontology matcher. + + Args: + ontology_embedder: Embedder with vector store + ontology_loader: Loader with ontology definitions + top_k: Number of top results to retrieve + similarity_threshold: Minimum similarity score + """ + super().__init__(ontology_embedder, ontology_loader, top_k, similarity_threshold) + + async def match_question_to_ontology(self, + question_components: QuestionComponents, + question_segments: List[str]) -> List[QueryOntologySubset]: + """Match question components to relevant ontology elements. 
+ + Args: + question_components: Analyzed question components + question_segments: Text segments from question + + Returns: + List of query-optimized ontology subsets + """ + # Convert question segments to TextSegment objects + text_segments = [ + TextSegment(text=seg, type='question', position=i) + for i, seg in enumerate(question_segments) + ] + + # Get base ontology subsets using parent class method + base_subsets = await self.select_ontology_subset(text_segments) + + # Enhance subsets for query processing + query_subsets = [] + for subset in base_subsets: + query_subset = self._enhance_for_query(subset, question_components) + query_subsets.append(query_subset) + + return query_subsets + + def _enhance_for_query(self, subset: OntologySubset, + question_components: QuestionComponents) -> QueryOntologySubset: + """Enhance ontology subset with query-specific elements. + + Args: + subset: Base ontology subset + question_components: Analyzed question components + + Returns: + Enhanced query ontology subset + """ + # Create query subset + query_subset = QueryOntologySubset( + ontology_id=subset.ontology_id, + classes=dict(subset.classes), + object_properties=dict(subset.object_properties), + datatype_properties=dict(subset.datatype_properties), + metadata=subset.metadata, + relevance_score=subset.relevance_score, + traversal_properties={}, + inference_rules=[] + ) + + # Add traversal properties based on question type + self._add_traversal_properties(query_subset, question_components) + + # Add related properties for exploration + self._add_related_properties(query_subset) + + # Add inference rules if needed + self._add_inference_rules(query_subset, question_components) + + return query_subset + + def _add_traversal_properties(self, subset: QueryOntologySubset, + question_components: QuestionComponents): + """Add properties useful for graph traversal. 
+ + Args: + subset: Query ontology subset to enhance + question_components: Question analysis + """ + ontology = self.loader.get_ontology(subset.ontology_id) + if not ontology: + return + + # For relationship questions, add all properties connecting mentioned classes + if question_components.question_type == QuestionType.RELATIONSHIP: + for prop_id, prop_def in ontology.object_properties.items(): + domain = prop_def.domain + range_val = prop_def.range + + # Check if property connects relevant classes + if domain in subset.classes or range_val in subset.classes: + if prop_id not in subset.object_properties: + subset.traversal_properties[prop_id] = prop_def.__dict__ + logger.debug(f"Added traversal property: {prop_id}") + + # For retrieval questions, add properties that might filter results + elif question_components.question_type == QuestionType.RETRIEVAL: + # Add all properties with domains in our classes + for prop_id, prop_def in ontology.object_properties.items(): + if prop_def.domain in subset.classes: + if prop_id not in subset.object_properties: + subset.traversal_properties[prop_id] = prop_def.__dict__ + + for prop_id, prop_def in ontology.datatype_properties.items(): + if prop_def.domain in subset.classes: + if prop_id not in subset.datatype_properties: + subset.traversal_properties[prop_id] = prop_def.__dict__ + + # For aggregation questions, ensure we have counting properties + elif question_components.question_type == QuestionType.AGGREGATION: + # Add properties that might be counted + for prop_id, prop_def in ontology.datatype_properties.items(): + if 'count' in prop_id.lower() or 'number' in prop_id.lower(): + if prop_id not in subset.datatype_properties: + subset.traversal_properties[prop_id] = prop_def.__dict__ + + def _add_related_properties(self, subset: QueryOntologySubset): + """Add properties related to already selected ones. + + Args: + subset: Query ontology subset to enhance + """ + ontology = self.loader.get_ontology(subset.ontology_id) + if not ontology: + return + + # Add inverse properties + for prop_id in list(subset.object_properties.keys()): + prop = ontology.object_properties.get(prop_id) + if prop and prop.inverse_of: + inverse_prop = ontology.object_properties.get(prop.inverse_of) + if inverse_prop and prop.inverse_of not in subset.object_properties: + subset.object_properties[prop.inverse_of] = inverse_prop.__dict__ + logger.debug(f"Added inverse property: {prop.inverse_of}") + + # Add sibling properties (same domain) + domains_in_subset = set() + for prop_def in subset.object_properties.values(): + if 'domain' in prop_def and prop_def['domain']: + domains_in_subset.add(prop_def['domain']) + + for domain in domains_in_subset: + for prop_id, prop_def in ontology.object_properties.items(): + if prop_def.domain == domain and prop_id not in subset.object_properties: + # Add up to 3 sibling properties + if len(subset.traversal_properties) < 3: + subset.traversal_properties[prop_id] = prop_def.__dict__ + + def _add_inference_rules(self, subset: QueryOntologySubset, + question_components: QuestionComponents): + """Add inference rules for reasoning. 
+ + Args: + subset: Query ontology subset to enhance + question_components: Question analysis + """ + # Add transitivity rules for subclass relationships + if any(cls.get('subclass_of') for cls in subset.classes.values()): + subset.inference_rules.append({ + 'type': 'transitivity', + 'property': 'rdfs:subClassOf', + 'description': 'Subclass relationships are transitive' + }) + + # Add symmetry rules for equivalent classes + if any(cls.get('equivalent_classes') for cls in subset.classes.values()): + subset.inference_rules.append({ + 'type': 'symmetry', + 'property': 'owl:equivalentClass', + 'description': 'Equivalent class relationships are symmetric' + }) + + # Add inverse property rules + for prop_id, prop_def in subset.object_properties.items(): + if 'inverse_of' in prop_def and prop_def['inverse_of']: + subset.inference_rules.append({ + 'type': 'inverse', + 'property': prop_id, + 'inverse': prop_def['inverse_of'], + 'description': f'{prop_id} is inverse of {prop_def["inverse_of"]}' + }) + + def expand_for_hierarchical_queries(self, subset: QueryOntologySubset) -> QueryOntologySubset: + """Expand subset to include full class hierarchies. + + Args: + subset: Query ontology subset + + Returns: + Expanded subset with complete hierarchies + """ + ontology = self.loader.get_ontology(subset.ontology_id) + if not ontology: + return subset + + # Add all parent and child classes + classes_to_add = set() + for class_id in list(subset.classes.keys()): + # Add all parents + parents = ontology.get_parent_classes(class_id) + for parent_id in parents: + if parent_id not in subset.classes: + parent_class = ontology.get_class(parent_id) + if parent_class: + classes_to_add.add(parent_id) + + # Add all children + for other_class_id, other_class in ontology.classes.items(): + if other_class.subclass_of == class_id and other_class_id not in subset.classes: + classes_to_add.add(other_class_id) + + # Add collected classes + for class_id in classes_to_add: + cls = ontology.get_class(class_id) + if cls: + subset.classes[class_id] = cls.__dict__ + + logger.debug(f"Expanded hierarchy: added {len(classes_to_add)} classes") + return subset \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/query_explanation.py b/trustgraph-flow/trustgraph/query/ontology/query_explanation.py new file mode 100644 index 00000000..bd72aedc --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/query_explanation.py @@ -0,0 +1,640 @@ +""" +Query explanation system for OntoRAG. +Provides detailed explanations of how queries are processed and results are derived. 
+""" + +import logging +from typing import Dict, Any, List, Optional, Union +from dataclasses import dataclass, field +from datetime import datetime + +from .question_analyzer import QuestionComponents, QuestionType +from .ontology_matcher import QueryOntologySubset +from .sparql_generator import SPARQLQuery +from .cypher_generator import CypherQuery +from .sparql_cassandra import SPARQLResult +from .cypher_executor import CypherResult + +logger = logging.getLogger(__name__) + + +@dataclass +class ExplanationStep: + """Individual step in query explanation.""" + step_number: int + component: str + operation: str + input_data: Dict[str, Any] + output_data: Dict[str, Any] + explanation: str + duration_ms: float + success: bool + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class QueryExplanation: + """Complete explanation of query processing.""" + query_id: str + original_question: str + processing_steps: List[ExplanationStep] + final_answer: str + confidence_score: float + total_duration_ms: float + ontologies_used: List[str] + backend_used: str + reasoning_chain: List[str] + technical_details: Dict[str, Any] + user_friendly_explanation: str + + +class QueryExplainer: + """Generates explanations for query processing.""" + + def __init__(self, config: Dict[str, Any] = None): + """Initialize query explainer. + + Args: + config: Explainer configuration + """ + self.config = config or {} + self.explanation_level = self.config.get('explanation_level', 'detailed') # basic, detailed, technical + self.include_technical_details = self.config.get('include_technical_details', True) + self.max_reasoning_steps = self.config.get('max_reasoning_steps', 10) + + # Templates for different explanation types + self.step_templates = { + 'question_analysis': { + 'basic': "I analyzed your question to understand what you're asking.", + 'detailed': "I analyzed your question '{question}' and identified it as a {question_type} query about {entities}.", + 'technical': "Question analysis: Type={question_type}, Entities={entities}, Keywords={keywords}, Expected answer={answer_type}" + }, + 'ontology_matching': { + 'basic': "I found relevant knowledge about {entities} in the available ontologies.", + 'detailed': "I searched through {ontology_count} ontologies and found {selected_elements} relevant concepts related to your question.", + 'technical': "Ontology matching: Selected {classes} classes, {properties} properties from {ontologies}" + }, + 'query_generation': { + 'basic': "I generated a query to search for the information.", + 'detailed': "I created a {query_type} query using {query_language} to search the {backend} database.", + 'technical': "Query generation: {query_language} query with {variables} variables, complexity score {complexity}" + }, + 'query_execution': { + 'basic': "I searched the database and found {result_count} results.", + 'detailed': "I executed the query against the {backend} database and retrieved {result_count} results in {duration}ms.", + 'technical': "Query execution: {backend} backend, {result_count} results, execution time {duration}ms" + }, + 'answer_generation': { + 'basic': "I generated a natural language answer from the results.", + 'detailed': "I processed {result_count} results and generated an answer with {confidence}% confidence.", + 'technical': "Answer generation: {result_count} input results, {generation_method} method, confidence {confidence}" + } + } + + self.reasoning_templates = { + 'entity_identification': "I identified '{entity}' as a key concept 
in your question.", + 'ontology_selection': "I selected the '{ontology}' ontology because it contains relevant information about {concepts}.", + 'query_strategy': "I chose a {strategy} query approach because {reason}.", + 'result_filtering': "I filtered the results to show only the most relevant {count} items.", + 'confidence_assessment': "I'm {confidence}% confident in this answer because {reasoning}." + } + + def explain_query_processing(self, + question: str, + question_components: QuestionComponents, + ontology_subsets: List[QueryOntologySubset], + generated_query: Union[SPARQLQuery, CypherQuery], + query_results: Union[SPARQLResult, CypherResult], + final_answer: str, + processing_metadata: Dict[str, Any]) -> QueryExplanation: + """Generate comprehensive explanation of query processing. + + Args: + question: Original question + question_components: Analyzed question components + ontology_subsets: Selected ontology subsets + generated_query: Generated query + query_results: Query execution results + final_answer: Final generated answer + processing_metadata: Processing metadata + + Returns: + Complete query explanation + """ + query_id = processing_metadata.get('query_id', f"query_{datetime.now().timestamp()}") + start_time = processing_metadata.get('start_time', datetime.now()) + + # Build explanation steps + steps = [] + step_number = 1 + + # Step 1: Question Analysis + steps.append(self._explain_question_analysis( + step_number, question, question_components + )) + step_number += 1 + + # Step 2: Ontology Matching + steps.append(self._explain_ontology_matching( + step_number, question_components, ontology_subsets + )) + step_number += 1 + + # Step 3: Query Generation + steps.append(self._explain_query_generation( + step_number, generated_query, processing_metadata + )) + step_number += 1 + + # Step 4: Query Execution + steps.append(self._explain_query_execution( + step_number, generated_query, query_results, processing_metadata + )) + step_number += 1 + + # Step 5: Answer Generation + steps.append(self._explain_answer_generation( + step_number, query_results, final_answer, processing_metadata + )) + + # Build reasoning chain + reasoning_chain = self._build_reasoning_chain( + question_components, ontology_subsets, generated_query, processing_metadata + ) + + # Calculate overall confidence + confidence_score = self._calculate_explanation_confidence( + question_components, query_results, processing_metadata + ) + + # Generate user-friendly explanation + user_friendly_explanation = self._generate_user_friendly_explanation( + question, question_components, ontology_subsets, final_answer + ) + + # Calculate total duration + total_duration = processing_metadata.get('total_duration_ms', 0) + + return QueryExplanation( + query_id=query_id, + original_question=question, + processing_steps=steps, + final_answer=final_answer, + confidence_score=confidence_score, + total_duration_ms=total_duration, + ontologies_used=[subset.metadata.get('ontology_id', 'unknown') for subset in ontology_subsets], + backend_used=processing_metadata.get('backend_used', 'unknown'), + reasoning_chain=reasoning_chain, + technical_details=self._extract_technical_details(processing_metadata), + user_friendly_explanation=user_friendly_explanation + ) + + def _explain_question_analysis(self, + step_number: int, + question: str, + question_components: QuestionComponents) -> ExplanationStep: + """Explain question analysis step.""" + template = self.step_templates['question_analysis'][self.explanation_level] + + if 
self.explanation_level == 'basic': + explanation = template + elif self.explanation_level == 'detailed': + explanation = template.format( + question=question, + question_type=question_components.question_type.value.replace('_', ' '), + entities=', '.join(question_components.entities[:3]) + ) + else: # technical + explanation = template.format( + question_type=question_components.question_type.value, + entities=question_components.entities, + keywords=question_components.keywords, + answer_type=question_components.expected_answer_type + ) + + return ExplanationStep( + step_number=step_number, + component="question_analyzer", + operation="analyze_question", + input_data={"question": question}, + output_data={ + "question_type": question_components.question_type.value, + "entities": question_components.entities, + "keywords": question_components.keywords + }, + explanation=explanation, + duration_ms=0.0, # Would be tracked in actual implementation + success=True + ) + + def _explain_ontology_matching(self, + step_number: int, + question_components: QuestionComponents, + ontology_subsets: List[QueryOntologySubset]) -> ExplanationStep: + """Explain ontology matching step.""" + template = self.step_templates['ontology_matching'][self.explanation_level] + + total_elements = sum( + len(subset.classes) + len(subset.object_properties) + len(subset.datatype_properties) + for subset in ontology_subsets + ) + + if self.explanation_level == 'basic': + explanation = template.format( + entities=', '.join(question_components.entities[:3]) + ) + elif self.explanation_level == 'detailed': + explanation = template.format( + ontology_count=len(ontology_subsets), + selected_elements=total_elements + ) + else: # technical + total_classes = sum(len(subset.classes) for subset in ontology_subsets) + total_properties = sum( + len(subset.object_properties) + len(subset.datatype_properties) + for subset in ontology_subsets + ) + ontology_names = [subset.metadata.get('ontology_id', 'unknown') for subset in ontology_subsets] + + explanation = template.format( + classes=total_classes, + properties=total_properties, + ontologies=', '.join(ontology_names) + ) + + return ExplanationStep( + step_number=step_number, + component="ontology_matcher", + operation="select_relevant_subset", + input_data={"entities": question_components.entities}, + output_data={ + "ontology_count": len(ontology_subsets), + "total_elements": total_elements + }, + explanation=explanation, + duration_ms=0.0, + success=True + ) + + def _explain_query_generation(self, + step_number: int, + generated_query: Union[SPARQLQuery, CypherQuery], + metadata: Dict[str, Any]) -> ExplanationStep: + """Explain query generation step.""" + template = self.step_templates['query_generation'][self.explanation_level] + + query_language = "SPARQL" if isinstance(generated_query, SPARQLQuery) else "Cypher" + backend = metadata.get('backend_used', 'unknown') + + if self.explanation_level == 'basic': + explanation = template + elif self.explanation_level == 'detailed': + explanation = template.format( + query_type=generated_query.query_type, + query_language=query_language, + backend=backend + ) + else: # technical + explanation = template.format( + query_language=query_language, + variables=len(generated_query.variables), + complexity=f"{generated_query.complexity_score:.2f}" + ) + + return ExplanationStep( + step_number=step_number, + component="query_generator", + operation="generate_query", + input_data={"query_type": generated_query.query_type}, + output_data={ + 
"query_language": query_language, + "variables": generated_query.variables, + "complexity": generated_query.complexity_score + }, + explanation=explanation, + duration_ms=0.0, + success=True, + metadata={"generated_query": generated_query.query} + ) + + def _explain_query_execution(self, + step_number: int, + generated_query: Union[SPARQLQuery, CypherQuery], + query_results: Union[SPARQLResult, CypherResult], + metadata: Dict[str, Any]) -> ExplanationStep: + """Explain query execution step.""" + template = self.step_templates['query_execution'][self.explanation_level] + + backend = metadata.get('backend_used', 'unknown') + duration = getattr(query_results, 'execution_time', 0) * 1000 # Convert to ms + + if isinstance(query_results, SPARQLResult): + result_count = len(query_results.bindings) + else: # CypherResult + result_count = len(query_results.records) + + if self.explanation_level == 'basic': + explanation = template.format(result_count=result_count) + elif self.explanation_level == 'detailed': + explanation = template.format( + backend=backend, + result_count=result_count, + duration=f"{duration:.1f}" + ) + else: # technical + explanation = template.format( + backend=backend, + result_count=result_count, + duration=f"{duration:.1f}" + ) + + return ExplanationStep( + step_number=step_number, + component="query_executor", + operation="execute_query", + input_data={"query": generated_query.query}, + output_data={ + "result_count": result_count, + "execution_time_ms": duration + }, + explanation=explanation, + duration_ms=duration, + success=result_count >= 0 + ) + + def _explain_answer_generation(self, + step_number: int, + query_results: Union[SPARQLResult, CypherResult], + final_answer: str, + metadata: Dict[str, Any]) -> ExplanationStep: + """Explain answer generation step.""" + template = self.step_templates['answer_generation'][self.explanation_level] + + if isinstance(query_results, SPARQLResult): + result_count = len(query_results.bindings) + else: # CypherResult + result_count = len(query_results.records) + + confidence = metadata.get('answer_confidence', 0.8) * 100 # Convert to percentage + + if self.explanation_level == 'basic': + explanation = template + elif self.explanation_level == 'detailed': + explanation = template.format( + result_count=result_count, + confidence=f"{confidence:.0f}" + ) + else: # technical + generation_method = metadata.get('generation_method', 'template_based') + explanation = template.format( + result_count=result_count, + generation_method=generation_method, + confidence=f"{confidence:.1f}" + ) + + return ExplanationStep( + step_number=step_number, + component="answer_generator", + operation="generate_answer", + input_data={"result_count": result_count}, + output_data={ + "answer": final_answer, + "confidence": confidence / 100 + }, + explanation=explanation, + duration_ms=0.0, + success=bool(final_answer) + ) + + def _build_reasoning_chain(self, + question_components: QuestionComponents, + ontology_subsets: List[QueryOntologySubset], + generated_query: Union[SPARQLQuery, CypherQuery], + metadata: Dict[str, Any]) -> List[str]: + """Build reasoning chain explaining the decision process.""" + reasoning = [] + + # Entity identification reasoning + if question_components.entities: + for entity in question_components.entities[:3]: + reasoning.append( + self.reasoning_templates['entity_identification'].format(entity=entity) + ) + + # Ontology selection reasoning + if ontology_subsets: + primary_ontology = ontology_subsets[0] + ontology_id = 
primary_ontology.metadata.get('ontology_id', 'primary') + concepts = list(primary_ontology.classes.keys())[:3] + reasoning.append( + self.reasoning_templates['ontology_selection'].format( + ontology=ontology_id, + concepts=', '.join(concepts) + ) + ) + + # Query strategy reasoning + query_language = "SPARQL" if isinstance(generated_query, SPARQLQuery) else "Cypher" + if question_components.question_type == QuestionType.AGGREGATION: + strategy = "aggregation" + reason = "you asked for a count or sum" + elif question_components.question_type == QuestionType.BOOLEAN: + strategy = "boolean" + reason = "you asked a yes/no question" + else: + strategy = "retrieval" + reason = "you asked for specific information" + + reasoning.append( + self.reasoning_templates['query_strategy'].format( + strategy=strategy, + reason=reason + ) + ) + + # Confidence assessment + confidence = metadata.get('answer_confidence', 0.8) * 100 + if confidence > 90: + confidence_reason = "the query matched well with available data" + elif confidence > 70: + confidence_reason = "the query found relevant information with some uncertainty" + else: + confidence_reason = "the available data partially matches your question" + + reasoning.append( + self.reasoning_templates['confidence_assessment'].format( + confidence=f"{confidence:.0f}", + reasoning=confidence_reason + ) + ) + + return reasoning[:self.max_reasoning_steps] + + def _calculate_explanation_confidence(self, + question_components: QuestionComponents, + query_results: Union[SPARQLResult, CypherResult], + metadata: Dict[str, Any]) -> float: + """Calculate confidence score for the explanation.""" + confidence = 0.8 # Base confidence + + # Adjust based on result count + if isinstance(query_results, SPARQLResult): + result_count = len(query_results.bindings) + else: + result_count = len(query_results.records) + + if result_count > 0: + confidence += 0.1 + if result_count > 5: + confidence += 0.05 + + # Adjust based on question complexity + if len(question_components.entities) > 0: + confidence += 0.05 + + # Adjust based on processing success + if metadata.get('all_steps_successful', True): + confidence += 0.05 + + return min(confidence, 1.0) + + def _generate_user_friendly_explanation(self, + question: str, + question_components: QuestionComponents, + ontology_subsets: List[QueryOntologySubset], + final_answer: str) -> str: + """Generate user-friendly explanation of the process.""" + explanation_parts = [] + + # Introduction + explanation_parts.append(f"To answer your question '{question}', I followed these steps:") + + # Process summary + if question_components.question_type == QuestionType.AGGREGATION: + explanation_parts.append("1. I recognized this as a counting or aggregation question") + elif question_components.question_type == QuestionType.BOOLEAN: + explanation_parts.append("1. I recognized this as a yes/no question") + else: + explanation_parts.append("1. I analyzed your question to understand what information you need") + + # Ontology usage + if ontology_subsets: + ontology_count = len(ontology_subsets) + if ontology_count == 1: + explanation_parts.append("2. I searched through the relevant knowledge base") + else: + explanation_parts.append(f"2. I searched through {ontology_count} knowledge bases") + + # Result processing + explanation_parts.append("3. 
I found the relevant information and generated your answer") + + # Conclusion + explanation_parts.append(f"The answer is: {final_answer}") + + return " ".join(explanation_parts) + + def _extract_technical_details(self, metadata: Dict[str, Any]) -> Dict[str, Any]: + """Extract technical details for debugging and optimization.""" + return { + 'query_optimization': metadata.get('query_optimization', {}), + 'backend_performance': metadata.get('backend_performance', {}), + 'cache_usage': metadata.get('cache_usage', {}), + 'error_handling': metadata.get('error_handling', {}), + 'routing_decision': metadata.get('routing_decision', {}) + } + + def format_explanation_for_display(self, + explanation: QueryExplanation, + format_type: str = 'html') -> str: + """Format explanation for display. + + Args: + explanation: Query explanation + format_type: Output format ('html', 'markdown', 'text') + + Returns: + Formatted explanation + """ + if format_type == 'html': + return self._format_html_explanation(explanation) + elif format_type == 'markdown': + return self._format_markdown_explanation(explanation) + else: + return self._format_text_explanation(explanation) + + def _format_html_explanation(self, explanation: QueryExplanation) -> str: + """Format explanation as HTML.""" + html_parts = [ + f"

<h3>Query Explanation: {explanation.query_id}</h3>", + f"<p>Question: {explanation.original_question}</p>", + f"<p>Answer: {explanation.final_answer}</p>", + f"<p>Confidence: {explanation.confidence_score:.1%}</p>", + "<h4>Processing Steps:</h4>", + "<ol>" + ] + + for step in explanation.processing_steps: + html_parts.append(f"<li>{step.component}: {step.explanation}</li>") + + html_parts.extend([ + "</ol>", + "<h4>Reasoning:</h4>", + "<ul>" + ]) + + for reasoning in explanation.reasoning_chain: + html_parts.append(f"<li>{reasoning}</li>") + + html_parts.append("</ul>
") + + return "".join(html_parts) + + def _format_markdown_explanation(self, explanation: QueryExplanation) -> str: + """Format explanation as Markdown.""" + md_parts = [ + f"## Query Explanation: {explanation.query_id}", + f"**Question:** {explanation.original_question}", + f"**Answer:** {explanation.final_answer}", + f"**Confidence:** {explanation.confidence_score:.1%}", + "", + "### Processing Steps:", + "" + ] + + for i, step in enumerate(explanation.processing_steps, 1): + md_parts.append(f"{i}. **{step.component}**: {step.explanation}") + + md_parts.extend([ + "", + "### Reasoning:", + "" + ]) + + for reasoning in explanation.reasoning_chain: + md_parts.append(f"- {reasoning}") + + return "\n".join(md_parts) + + def _format_text_explanation(self, explanation: QueryExplanation) -> str: + """Format explanation as plain text.""" + text_parts = [ + f"Query Explanation: {explanation.query_id}", + f"Question: {explanation.original_question}", + f"Answer: {explanation.final_answer}", + f"Confidence: {explanation.confidence_score:.1%}", + "", + "Processing Steps:", + ] + + for i, step in enumerate(explanation.processing_steps, 1): + text_parts.append(f" {i}. {step.component}: {step.explanation}") + + text_parts.extend([ + "", + "Reasoning:", + ]) + + for reasoning in explanation.reasoning_chain: + text_parts.append(f" - {reasoning}") + + return "\n".join(text_parts) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/query_optimizer.py b/trustgraph-flow/trustgraph/query/ontology/query_optimizer.py new file mode 100644 index 00000000..5d8f36ec --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/query_optimizer.py @@ -0,0 +1,519 @@ +""" +Query optimization module for OntoRAG. +Optimizes SPARQL and Cypher queries for better performance and accuracy. +""" + +import logging +from typing import Dict, Any, List, Optional, Union, Tuple +from dataclasses import dataclass +from enum import Enum +import re + +from .question_analyzer import QuestionComponents, QuestionType +from .ontology_matcher import QueryOntologySubset +from .sparql_generator import SPARQLQuery +from .cypher_generator import CypherQuery + +logger = logging.getLogger(__name__) + + +class OptimizationStrategy(Enum): + """Query optimization strategies.""" + PERFORMANCE = "performance" + ACCURACY = "accuracy" + BALANCED = "balanced" + + +@dataclass +class OptimizationHint: + """Optimization hint for query processing.""" + strategy: OptimizationStrategy + max_results: Optional[int] = None + timeout_seconds: Optional[int] = None + use_indices: bool = True + enable_parallel: bool = False + cache_results: bool = True + + +@dataclass +class QueryPlan: + """Query execution plan with optimization metadata.""" + original_query: str + optimized_query: str + estimated_cost: float + optimization_notes: List[str] + index_hints: List[str] + execution_order: List[str] + + +class QueryOptimizer: + """Optimizes SPARQL and Cypher queries for performance and accuracy.""" + + def __init__(self, config: Dict[str, Any] = None): + """Initialize query optimizer. 
+ + Args: + config: Optimizer configuration + """ + self.config = config or {} + self.default_strategy = OptimizationStrategy( + self.config.get('default_strategy', 'balanced') + ) + self.max_query_complexity = self.config.get('max_query_complexity', 10) + self.enable_query_rewriting = self.config.get('enable_query_rewriting', True) + + # Performance thresholds + self.large_result_threshold = self.config.get('large_result_threshold', 1000) + self.complex_join_threshold = self.config.get('complex_join_threshold', 3) + + def optimize_sparql(self, + sparql_query: SPARQLQuery, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset, + optimization_hint: Optional[OptimizationHint] = None) -> Tuple[SPARQLQuery, QueryPlan]: + """Optimize SPARQL query. + + Args: + sparql_query: Original SPARQL query + question_components: Question analysis + ontology_subset: Ontology subset + optimization_hint: Optimization hints + + Returns: + Optimized SPARQL query and execution plan + """ + hint = optimization_hint or OptimizationHint(strategy=self.default_strategy) + + optimized_query = sparql_query.query + optimization_notes = [] + index_hints = [] + execution_order = [] + + # Apply optimizations based on strategy + if hint.strategy in [OptimizationStrategy.PERFORMANCE, OptimizationStrategy.BALANCED]: + optimized_query, perf_notes, perf_hints = self._optimize_sparql_performance( + optimized_query, question_components, ontology_subset, hint + ) + optimization_notes.extend(perf_notes) + index_hints.extend(perf_hints) + + if hint.strategy in [OptimizationStrategy.ACCURACY, OptimizationStrategy.BALANCED]: + optimized_query, acc_notes = self._optimize_sparql_accuracy( + optimized_query, question_components, ontology_subset + ) + optimization_notes.extend(acc_notes) + + # Estimate query cost + estimated_cost = self._estimate_sparql_cost(optimized_query, ontology_subset) + + # Build execution plan + query_plan = QueryPlan( + original_query=sparql_query.query, + optimized_query=optimized_query, + estimated_cost=estimated_cost, + optimization_notes=optimization_notes, + index_hints=index_hints, + execution_order=execution_order + ) + + # Create optimized query object + optimized_sparql = SPARQLQuery( + query=optimized_query, + variables=sparql_query.variables, + query_type=sparql_query.query_type, + explanation=f"Optimized: {sparql_query.explanation}", + complexity_score=min(sparql_query.complexity_score * 0.8, 1.0) # Assume optimization reduces complexity + ) + + return optimized_sparql, query_plan + + def optimize_cypher(self, + cypher_query: CypherQuery, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset, + optimization_hint: Optional[OptimizationHint] = None) -> Tuple[CypherQuery, QueryPlan]: + """Optimize Cypher query. 
+ + Args: + cypher_query: Original Cypher query + question_components: Question analysis + ontology_subset: Ontology subset + optimization_hint: Optimization hints + + Returns: + Optimized Cypher query and execution plan + """ + hint = optimization_hint or OptimizationHint(strategy=self.default_strategy) + + optimized_query = cypher_query.query + optimization_notes = [] + index_hints = [] + execution_order = [] + + # Apply optimizations based on strategy + if hint.strategy in [OptimizationStrategy.PERFORMANCE, OptimizationStrategy.BALANCED]: + optimized_query, perf_notes, perf_hints = self._optimize_cypher_performance( + optimized_query, question_components, ontology_subset, hint + ) + optimization_notes.extend(perf_notes) + index_hints.extend(perf_hints) + + if hint.strategy in [OptimizationStrategy.ACCURACY, OptimizationStrategy.BALANCED]: + optimized_query, acc_notes = self._optimize_cypher_accuracy( + optimized_query, question_components, ontology_subset + ) + optimization_notes.extend(acc_notes) + + # Estimate query cost + estimated_cost = self._estimate_cypher_cost(optimized_query, ontology_subset) + + # Build execution plan + query_plan = QueryPlan( + original_query=cypher_query.query, + optimized_query=optimized_query, + estimated_cost=estimated_cost, + optimization_notes=optimization_notes, + index_hints=index_hints, + execution_order=execution_order + ) + + # Create optimized query object + optimized_cypher = CypherQuery( + query=optimized_query, + variables=cypher_query.variables, + query_type=cypher_query.query_type, + explanation=f"Optimized: {cypher_query.explanation}", + complexity_score=min(cypher_query.complexity_score * 0.8, 1.0) + ) + + return optimized_cypher, query_plan + + def _optimize_sparql_performance(self, + query: str, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset, + hint: OptimizationHint) -> Tuple[str, List[str], List[str]]: + """Apply performance optimizations to SPARQL query. 
+ + Args: + query: SPARQL query string + question_components: Question analysis + ontology_subset: Ontology subset + hint: Optimization hints + + Returns: + Optimized query, optimization notes, and index hints + """ + optimized = query + notes = [] + index_hints = [] + + # Add LIMIT if not present and large results expected + if hint.max_results and 'LIMIT' not in optimized.upper(): + optimized = f"{optimized.rstrip()}\nLIMIT {hint.max_results}" + notes.append(f"Added LIMIT {hint.max_results} to prevent large result sets") + + # Optimize OPTIONAL clauses (move to end) + optional_pattern = re.compile(r'OPTIONAL\s*\{[^}]+\}', re.IGNORECASE | re.DOTALL) + optionals = optional_pattern.findall(optimized) + if optionals: + # Remove optionals from current position + for optional in optionals: + optimized = optimized.replace(optional, '') + + # Add them at the end (before ORDER BY/LIMIT) + insert_point = optimized.rfind('ORDER BY') + if insert_point == -1: + insert_point = optimized.rfind('LIMIT') + if insert_point == -1: + insert_point = len(optimized.rstrip()) + + for optional in optionals: + optimized = optimized[:insert_point] + f"\n {optional}" + optimized[insert_point:] + + notes.append("Moved OPTIONAL clauses to end for better performance") + + # Add index hints for Cassandra + if 'WHERE' in optimized.upper(): + # Suggest indices for common patterns + if '?subject rdf:type' in optimized: + index_hints.append("type_index") + if 'rdfs:subClassOf' in optimized: + index_hints.append("hierarchy_index") + + # Optimize FILTER clauses (move closer to variable bindings) + filter_pattern = re.compile(r'FILTER\s*\([^)]+\)', re.IGNORECASE) + filters = filter_pattern.findall(optimized) + if filters: + notes.append("FILTER clauses present - ensure they're positioned optimally") + + return optimized, notes, index_hints + + def _optimize_sparql_accuracy(self, + query: str, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset) -> Tuple[str, List[str]]: + """Apply accuracy optimizations to SPARQL query. + + Args: + query: SPARQL query string + question_components: Question analysis + ontology_subset: Ontology subset + + Returns: + Optimized query and optimization notes + """ + optimized = query + notes = [] + + # Add missing namespace checks + if question_components.question_type == QuestionType.RETRIEVAL: + # Ensure we're not mixing namespaces inappropriately + if 'http://' in optimized and '?' in optimized: + notes.append("Verified namespace consistency for accuracy") + + # Add type constraints for better precision + if '?entity' in optimized and 'rdf:type' not in optimized: + # Find a good insertion point + where_clause = re.search(r'WHERE\s*\{(.+)\}', optimized, re.DOTALL | re.IGNORECASE) + if where_clause and ontology_subset.classes: + # Add type constraint for the most relevant class + main_class = list(ontology_subset.classes.keys())[0] + type_constraint = f"\n ?entity rdf:type :{main_class} ." 
+ + # Insert after the WHERE { + where_start = where_clause.start(1) + optimized = optimized[:where_start] + type_constraint + optimized[where_start:] + notes.append(f"Added type constraint for {main_class} to improve accuracy") + + # Add DISTINCT if not present for retrieval queries + if (question_components.question_type == QuestionType.RETRIEVAL and + 'DISTINCT' not in optimized.upper() and + 'SELECT' in optimized.upper()): + optimized = optimized.replace('SELECT ', 'SELECT DISTINCT ', 1) + notes.append("Added DISTINCT to eliminate duplicate results") + + return optimized, notes + + def _optimize_cypher_performance(self, + query: str, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset, + hint: OptimizationHint) -> Tuple[str, List[str], List[str]]: + """Apply performance optimizations to Cypher query. + + Args: + query: Cypher query string + question_components: Question analysis + ontology_subset: Ontology subset + hint: Optimization hints + + Returns: + Optimized query, optimization notes, and index hints + """ + optimized = query + notes = [] + index_hints = [] + + # Add LIMIT if not present + if hint.max_results and 'LIMIT' not in optimized.upper(): + optimized = f"{optimized.rstrip()}\nLIMIT {hint.max_results}" + notes.append(f"Added LIMIT {hint.max_results} to prevent large result sets") + + # Use parameters for literals to enable query plan caching + if "'" in optimized or '"' in optimized: + notes.append("Consider using parameters for literal values to enable query plan caching") + + # Suggest indices based on query patterns + if 'MATCH (n:' in optimized: + label_match = re.search(r'MATCH \(n:(\w+)\)', optimized) + if label_match: + label = label_match.group(1) + index_hints.append(f"node_label_index:{label}") + + if 'WHERE' in optimized.upper() and '.' in optimized: + # Property access patterns + property_pattern = re.compile(r'\.(\w+)', re.IGNORECASE) + properties = property_pattern.findall(optimized) + for prop in set(properties): + index_hints.append(f"property_index:{prop}") + + # Optimize relationship traversals + if '-[' in optimized and '*' in optimized: + notes.append("Variable length path detected - consider adding relationship type filters") + + # Early filtering optimization + if 'WHERE' in optimized.upper(): + # Move WHERE clauses closer to MATCH clauses + notes.append("WHERE clauses present - ensure early filtering for performance") + + return optimized, notes, index_hints + + def _optimize_cypher_accuracy(self, + query: str, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset) -> Tuple[str, List[str]]: + """Apply accuracy optimizations to Cypher query. + + Args: + query: Cypher query string + question_components: Question analysis + ontology_subset: Ontology subset + + Returns: + Optimized query and optimization notes + """ + optimized = query + notes = [] + + # Add DISTINCT if not present for retrieval queries + if (question_components.question_type == QuestionType.RETRIEVAL and + 'DISTINCT' not in optimized.upper() and + 'RETURN' in optimized.upper()): + optimized = re.sub(r'RETURN\s+', 'RETURN DISTINCT ', optimized, count=1, flags=re.IGNORECASE) + notes.append("Added DISTINCT to eliminate duplicate results") + + # Ensure proper relationship direction + if '-[' in optimized and question_components.relationships: + notes.append("Verified relationship directions for semantic accuracy") + + # Add null checks for optional properties + if '?' 
in optimized or 'OPTIONAL' in optimized.upper(): + notes.append("Consider adding null checks for optional properties") + + return optimized, notes + + def _estimate_sparql_cost(self, query: str, ontology_subset: QueryOntologySubset) -> float: + """Estimate execution cost for SPARQL query. + + Args: + query: SPARQL query string + ontology_subset: Ontology subset + + Returns: + Estimated cost (0.0 to 1.0) + """ + cost = 0.0 + + # Basic query complexity + cost += len(query.split('\n')) * 0.01 + + # Join complexity + triple_patterns = len(re.findall(r'\?\w+\s+\?\w+\s+\?\w+', query)) + cost += triple_patterns * 0.1 + + # OPTIONAL clauses + optional_count = len(re.findall(r'OPTIONAL', query, re.IGNORECASE)) + cost += optional_count * 0.15 + + # FILTER clauses + filter_count = len(re.findall(r'FILTER', query, re.IGNORECASE)) + cost += filter_count * 0.1 + + # Property paths + path_count = len(re.findall(r'\*|\+', query)) + cost += path_count * 0.2 + + # Ontology subset size impact + total_elements = (len(ontology_subset.classes) + + len(ontology_subset.object_properties) + + len(ontology_subset.datatype_properties)) + cost += (total_elements / 100.0) * 0.1 + + return min(cost, 1.0) + + def _estimate_cypher_cost(self, query: str, ontology_subset: QueryOntologySubset) -> float: + """Estimate execution cost for Cypher query. + + Args: + query: Cypher query string + ontology_subset: Ontology subset + + Returns: + Estimated cost (0.0 to 1.0) + """ + cost = 0.0 + + # Basic query complexity + cost += len(query.split('\n')) * 0.01 + + # Pattern complexity + match_count = len(re.findall(r'MATCH', query, re.IGNORECASE)) + cost += match_count * 0.1 + + # Relationship traversals + rel_count = len(re.findall(r'-\[.*?\]-', query)) + cost += rel_count * 0.1 + + # Variable length paths + var_path_count = len(re.findall(r'\*\d*\.\.', query)) + cost += var_path_count * 0.3 + + # WHERE clauses + where_count = len(re.findall(r'WHERE', query, re.IGNORECASE)) + cost += where_count * 0.05 + + # Aggregation functions + agg_count = len(re.findall(r'COUNT|SUM|AVG|MIN|MAX', query, re.IGNORECASE)) + cost += agg_count * 0.1 + + # Ontology subset size impact + total_elements = (len(ontology_subset.classes) + + len(ontology_subset.object_properties) + + len(ontology_subset.datatype_properties)) + cost += (total_elements / 100.0) * 0.1 + + return min(cost, 1.0) + + def should_use_cache(self, + query: str, + question_components: QuestionComponents, + optimization_hint: OptimizationHint) -> bool: + """Determine if query results should be cached. + + Args: + query: Query string + question_components: Question analysis + optimization_hint: Optimization hints + + Returns: + True if results should be cached + """ + if not optimization_hint.cache_results: + return False + + # Cache simple retrieval and factual queries + if question_components.question_type in [QuestionType.RETRIEVAL, QuestionType.FACTUAL]: + return True + + # Cache expensive aggregation queries + if (question_components.question_type == QuestionType.AGGREGATION and + ('COUNT' in query.upper() or 'SUM' in query.upper())): + return True + + # Don't cache real-time or time-sensitive queries + if any(keyword in question_components.original_question.lower() + for keyword in ['now', 'current', 'latest', 'recent']): + return False + + return False + + def get_cache_key(self, + query: str, + ontology_subset: QueryOntologySubset) -> str: + """Generate cache key for query. 
+ + Args: + query: Query string + ontology_subset: Ontology subset + + Returns: + Cache key string + """ + import hashlib + + # Create stable representation + ontology_repr = f"{sorted(ontology_subset.classes.keys())}-{sorted(ontology_subset.object_properties.keys())}" + combined = f"{query.strip()}-{ontology_repr}" + + return hashlib.md5(combined.encode()).hexdigest() \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/query_service.py b/trustgraph-flow/trustgraph/query/ontology/query_service.py new file mode 100644 index 00000000..ec7884ed --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/query_service.py @@ -0,0 +1,438 @@ +""" +Main OntoRAG query service. +Orchestrates question analysis, ontology matching, query generation, execution, and answer generation. +""" + +import logging +from typing import Dict, Any, List, Optional, Union +from dataclasses import dataclass +from datetime import datetime + +from ....flow.flow_processor import FlowProcessor +from ....tables.config import ConfigTableStore +from ...extract.kg.ontology.ontology_loader import OntologyLoader +from ...extract.kg.ontology.vector_store import InMemoryVectorStore + +from .question_analyzer import QuestionAnalyzer, QuestionComponents +from .ontology_matcher import OntologyMatcher, QueryOntologySubset +from .backend_router import BackendRouter, QueryRoute, BackendType +from .sparql_generator import SPARQLGenerator, SPARQLQuery +from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult +from .cypher_generator import CypherGenerator, CypherQuery +from .cypher_executor import CypherExecutor, CypherResult +from .answer_generator import AnswerGenerator, GeneratedAnswer + +logger = logging.getLogger(__name__) + + +@dataclass +class QueryRequest: + """Query request from user.""" + question: str + context: Optional[str] = None + ontology_hint: Optional[str] = None + max_results: int = 10 + confidence_threshold: float = 0.7 + + +@dataclass +class QueryResponse: + """Complete query response.""" + answer: str + confidence: float + execution_time: float + question_analysis: QuestionComponents + ontology_subsets: List[QueryOntologySubset] + query_route: QueryRoute + generated_query: Union[SPARQLQuery, CypherQuery] + raw_results: Union[SPARQLResult, CypherResult] + supporting_facts: List[str] + metadata: Dict[str, Any] + + +class OntoRAGQueryService(FlowProcessor): + """Main OntoRAG query service orchestrating all components.""" + + def __init__(self, config: Dict[str, Any]): + """Initialize OntoRAG query service. 
+ + Args: + config: Service configuration + """ + super().__init__(config) + self.config = config + + # Initialize components + self.config_store = None + self.ontology_loader = None + self.vector_store = None + self.question_analyzer = None + self.ontology_matcher = None + self.backend_router = None + self.sparql_generator = None + self.sparql_engine = None + self.cypher_generator = None + self.cypher_executor = None + self.answer_generator = None + + # Cache for loaded ontologies + self.ontology_cache = {} + + async def init(self): + """Initialize all components.""" + await super().init() + + # Initialize configuration store + self.config_store = ConfigTableStore(self.config.get('config_store', {})) + + # Initialize ontology components + self.ontology_loader = OntologyLoader(self.config_store) + + # Initialize vector store + vector_config = self.config.get('vector_store', {}) + self.vector_store = InMemoryVectorStore.create( + store_type=vector_config.get('type', 'numpy'), + dimension=vector_config.get('dimension', 384), + similarity_threshold=vector_config.get('similarity_threshold', 0.7) + ) + + # Initialize question analyzer + analyzer_config = self.config.get('question_analyzer', {}) + self.question_analyzer = QuestionAnalyzer( + prompt_service=self.prompt_service, + config=analyzer_config + ) + + # Initialize ontology matcher + matcher_config = self.config.get('ontology_matcher', {}) + self.ontology_matcher = OntologyMatcher( + vector_store=self.vector_store, + embedding_service=self.embedding_service, + config=matcher_config + ) + + # Initialize backend router + router_config = self.config.get('backend_router', {}) + self.backend_router = BackendRouter(router_config) + + # Initialize query generators + self.sparql_generator = SPARQLGenerator(prompt_service=self.prompt_service) + self.cypher_generator = CypherGenerator(prompt_service=self.prompt_service) + + # Initialize executors + sparql_config = self.config.get('sparql_executor', {}) + if self.backend_router.is_backend_enabled(BackendType.CASSANDRA): + cassandra_config = self.backend_router.get_backend_config(BackendType.CASSANDRA) + if cassandra_config: + self.sparql_engine = SPARQLCassandraEngine(cassandra_config) + await self.sparql_engine.initialize() + + cypher_config = self.config.get('cypher_executor', {}) + enabled_graph_backends = [ + bt for bt in [BackendType.NEO4J, BackendType.MEMGRAPH, BackendType.FALKORDB] + if self.backend_router.is_backend_enabled(bt) + ] + if enabled_graph_backends: + self.cypher_executor = CypherExecutor(cypher_config) + await self.cypher_executor.initialize() + + # Initialize answer generator + self.answer_generator = AnswerGenerator(prompt_service=self.prompt_service) + + logger.info("OntoRAG query service initialized") + + async def process(self, request: QueryRequest) -> QueryResponse: + """Process a natural language query. 
+ + Args: + request: Query request + + Returns: + Complete query response + """ + start_time = datetime.now() + + try: + logger.info(f"Processing query: {request.question}") + + # Step 1: Analyze question + question_components = await self.question_analyzer.analyze_question( + request.question, context=request.context + ) + logger.debug(f"Question analysis: {question_components.question_type}") + + # Step 2: Load and match ontologies + ontology_subsets = await self._load_and_match_ontologies( + question_components, request.ontology_hint + ) + logger.debug(f"Found {len(ontology_subsets)} relevant ontology subsets") + + # Step 3: Route to appropriate backend + query_route = self.backend_router.route_query( + question_components, ontology_subsets + ) + logger.debug(f"Routed to {query_route.backend_type.value} backend") + + # Step 4: Generate and execute query + if query_route.query_language == 'sparql': + query_results = await self._execute_sparql_path( + question_components, ontology_subsets, query_route + ) + else: # cypher + query_results = await self._execute_cypher_path( + question_components, ontology_subsets, query_route + ) + + # Step 5: Generate natural language answer + generated_answer = await self.answer_generator.generate_answer( + question_components, + query_results['raw_results'], + ontology_subsets[0] if ontology_subsets else None, + query_route.backend_type.value + ) + + # Build response + execution_time = (datetime.now() - start_time).total_seconds() + + response = QueryResponse( + answer=generated_answer.answer, + confidence=min(query_route.confidence, generated_answer.metadata.confidence), + execution_time=execution_time, + question_analysis=question_components, + ontology_subsets=ontology_subsets, + query_route=query_route, + generated_query=query_results['generated_query'], + raw_results=query_results['raw_results'], + supporting_facts=generated_answer.supporting_facts, + metadata={ + 'backend_used': query_route.backend_type.value, + 'query_language': query_route.query_language, + 'ontology_count': len(ontology_subsets), + 'result_count': generated_answer.metadata.result_count, + 'routing_reasoning': query_route.reasoning, + 'generation_time': generated_answer.generation_time + } + ) + + logger.info(f"Query processed successfully in {execution_time:.2f}s") + return response + + except Exception as e: + logger.error(f"Query processing failed: {e}") + execution_time = (datetime.now() - start_time).total_seconds() + + # Return error response + return QueryResponse( + answer=f"I encountered an error processing your query: {str(e)}", + confidence=0.0, + execution_time=execution_time, + question_analysis=QuestionComponents( + original_question=request.question, + normalized_question=request.question, + question_type=None, + entities=[], keywords=[], relationships=[], constraints=[], + aggregations=[], expected_answer_type="unknown" + ), + ontology_subsets=[], + query_route=None, + generated_query=None, + raw_results=None, + supporting_facts=[], + metadata={'error': str(e), 'execution_time': execution_time} + ) + + async def _load_and_match_ontologies(self, + question_components: QuestionComponents, + ontology_hint: Optional[str] = None) -> List[QueryOntologySubset]: + """Load ontologies and find relevant subsets. 
+ + Args: + question_components: Analyzed question + ontology_hint: Optional ontology hint + + Returns: + List of relevant ontology subsets + """ + try: + # Load available ontologies + if ontology_hint: + # Load specific ontology + ontologies = [await self.ontology_loader.load_ontology(ontology_hint)] + else: + # Load all available ontologies + available_ontologies = await self.ontology_loader.list_available_ontologies() + ontologies = [] + for ontology_id in available_ontologies[:5]: # Limit to 5 for performance + try: + ontology = await self.ontology_loader.load_ontology(ontology_id) + ontologies.append(ontology) + except Exception as e: + logger.warning(f"Failed to load ontology {ontology_id}: {e}") + + if not ontologies: + logger.warning("No ontologies loaded") + return [] + + # Extract relevant subsets + ontology_subsets = [] + for ontology in ontologies: + subset = await self.ontology_matcher.select_relevant_subset( + question_components, ontology + ) + if subset and (subset.classes or subset.object_properties or subset.datatype_properties): + ontology_subsets.append(subset) + + return ontology_subsets + + except Exception as e: + logger.error(f"Failed to load and match ontologies: {e}") + return [] + + async def _execute_sparql_path(self, + question_components: QuestionComponents, + ontology_subsets: List[QueryOntologySubset], + query_route: QueryRoute) -> Dict[str, Any]: + """Execute SPARQL query path. + + Args: + question_components: Question analysis + ontology_subsets: Ontology subsets + query_route: Query route + + Returns: + Query execution results + """ + if not self.sparql_engine: + raise RuntimeError("SPARQL engine not initialized") + + # Generate SPARQL query + primary_subset = ontology_subsets[0] if ontology_subsets else None + sparql_query = await self.sparql_generator.generate_sparql( + question_components, primary_subset + ) + + logger.debug(f"Generated SPARQL: {sparql_query.query}") + + # Execute query + sparql_results = self.sparql_engine.execute_sparql(sparql_query.query) + + return { + 'generated_query': sparql_query, + 'raw_results': sparql_results + } + + async def _execute_cypher_path(self, + question_components: QuestionComponents, + ontology_subsets: List[QueryOntologySubset], + query_route: QueryRoute) -> Dict[str, Any]: + """Execute Cypher query path. + + Args: + question_components: Question analysis + ontology_subsets: Ontology subsets + query_route: Query route + + Returns: + Query execution results + """ + if not self.cypher_executor: + raise RuntimeError("Cypher executor not initialized") + + # Generate Cypher query + primary_subset = ontology_subsets[0] if ontology_subsets else None + cypher_query = await self.cypher_generator.generate_cypher( + question_components, primary_subset + ) + + logger.debug(f"Generated Cypher: {cypher_query.query}") + + # Execute query + database_type = query_route.backend_type.value + cypher_results = await self.cypher_executor.execute_query( + cypher_query.query, database_type=database_type + ) + + return { + 'generated_query': cypher_query, + 'raw_results': cypher_results + } + + async def get_supported_backends(self) -> List[str]: + """Get list of supported and enabled backends. + + Returns: + List of backend names + """ + return [bt.value for bt in self.backend_router.get_available_backends()] + + async def get_available_ontologies(self) -> List[str]: + """Get list of available ontologies. 
+ + Returns: + List of ontology identifiers + """ + if self.ontology_loader: + return await self.ontology_loader.list_available_ontologies() + return [] + + async def health_check(self) -> Dict[str, Any]: + """Perform health check on all components. + + Returns: + Health status of all components + """ + health = { + 'service': 'healthy', + 'components': {}, + 'backends': {}, + 'ontologies': {} + } + + try: + # Check ontology loader + if self.ontology_loader: + ontologies = await self.ontology_loader.list_available_ontologies() + health['components']['ontology_loader'] = 'healthy' + health['ontologies']['count'] = len(ontologies) + else: + health['components']['ontology_loader'] = 'not_initialized' + + # Check vector store + if self.vector_store: + health['components']['vector_store'] = 'healthy' + health['components']['vector_store_type'] = type(self.vector_store).__name__ + else: + health['components']['vector_store'] = 'not_initialized' + + # Check backends + for backend_type in self.backend_router.get_available_backends(): + if backend_type == BackendType.CASSANDRA and self.sparql_engine: + health['backends']['cassandra'] = 'healthy' + elif backend_type in [BackendType.NEO4J, BackendType.MEMGRAPH, BackendType.FALKORDB] and self.cypher_executor: + health['backends'][backend_type.value] = 'healthy' + else: + health['backends'][backend_type.value] = 'configured_but_not_initialized' + + except Exception as e: + health['service'] = 'degraded' + health['error'] = str(e) + + return health + + async def close(self): + """Close all connections and cleanup resources.""" + try: + if self.sparql_engine: + self.sparql_engine.close() + + if self.cypher_executor: + await self.cypher_executor.close() + + if self.config_store: + # ConfigTableStore cleanup if needed + pass + + logger.info("OntoRAG query service closed") + + except Exception as e: + logger.error(f"Error closing OntoRAG query service: {e}") \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/question_analyzer.py b/trustgraph-flow/trustgraph/query/ontology/question_analyzer.py new file mode 100644 index 00000000..3e48ac78 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/question_analyzer.py @@ -0,0 +1,364 @@ +""" +Question analyzer for ontology-sensitive query system. +Decomposes user questions into semantic components. +""" + +import logging +import re +from typing import List, Dict, Any, Optional, Tuple +from dataclasses import dataclass +from enum import Enum + +logger = logging.getLogger(__name__) + + +class QuestionType(Enum): + """Types of questions that can be asked.""" + FACTUAL = "factual" # What is X? + RETRIEVAL = "retrieval" # Find all X + AGGREGATION = "aggregation" # How many X? + COMPARISON = "comparison" # Is X better than Y? + RELATIONSHIP = "relationship" # How is X related to Y? + BOOLEAN = "boolean" # Yes/no questions + PROCESS = "process" # How to do X? + TEMPORAL = "temporal" # When did X happen? + SPATIAL = "spatial" # Where is X? 
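For orientation, here is a minimal sketch of how the QuestionType categories above are meant to be applied: the question is lowercased and tested against a per-type regex table, falling back to FACTUAL when nothing matches. This sketch is illustrative only and not part of the change; the trimmed-down pattern table and the names QType, PATTERNS and classify are stand-ins for the fuller question_patterns dictionary and _identify_question_type method defined later in this file.

import re
from enum import Enum

class QType(Enum):
    FACTUAL = "factual"
    RETRIEVAL = "retrieval"
    AGGREGATION = "aggregation"

# Trimmed-down pattern table; the analyzer in this file covers all nine types.
PATTERNS = {
    QType.AGGREGATION: [r'^how\s+many', r'^count\s+'],
    QType.RETRIEVAL: [r'^find\s+', r'^list\s+', r'^show\s+'],
    QType.FACTUAL: [r'^what\s+(?:is|are)', r'^who\s+(?:is|are)'],
}

def classify(question: str) -> QType:
    # Normalize, then return the first type whose patterns match.
    q = question.lower().strip()
    for q_type, patterns in PATTERNS.items():
        if any(re.search(p, q) for p in patterns):
            return q_type
    return QType.FACTUAL  # default, as in the real analyzer

assert classify("How many animals are there?") is QType.AGGREGATION
assert classify("List all mammals") is QType.RETRIEVAL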
+ + +@dataclass +class QuestionComponents: + """Components extracted from a question.""" + original_question: str + question_type: QuestionType + entities: List[str] + relationships: List[str] + constraints: List[str] + aggregations: List[str] + expected_answer_type: str + keywords: List[str] + + +class QuestionAnalyzer: + """Analyzes natural language questions to extract semantic components.""" + + def __init__(self): + """Initialize question analyzer.""" + # Question word patterns + self.question_patterns = { + QuestionType.FACTUAL: [ + r'^what\s+(?:is|are)', + r'^who\s+(?:is|are)', + r'^which\s+', + ], + QuestionType.RETRIEVAL: [ + r'^find\s+', + r'^list\s+', + r'^show\s+', + r'^get\s+', + r'^retrieve\s+', + ], + QuestionType.AGGREGATION: [ + r'^how\s+many', + r'^count\s+', + r'^what\s+(?:is|are)\s+the\s+(?:number|total|sum)', + ], + QuestionType.COMPARISON: [ + r'(?:better|worse|more|less|greater|smaller)\s+than', + r'compare\s+', + r'difference\s+between', + ], + QuestionType.RELATIONSHIP: [ + r'^how\s+(?:is|are).*related', + r'relationship\s+between', + r'connection\s+between', + ], + QuestionType.BOOLEAN: [ + r'^(?:is|are|was|were|do|does|did|can|could|will|would|should)', + r'^has\s+', + r'^have\s+', + ], + QuestionType.PROCESS: [ + r'^how\s+(?:to|do)', + r'^explain\s+how', + ], + QuestionType.TEMPORAL: [ + r'^when\s+', + r'what\s+time', + r'what\s+date', + ], + QuestionType.SPATIAL: [ + r'^where\s+', + r'location\s+of', + ], + } + + # Aggregation keywords + self.aggregation_keywords = [ + 'count', 'sum', 'total', 'average', 'mean', 'median', + 'maximum', 'minimum', 'max', 'min', 'number of' + ] + + # Constraint patterns + self.constraint_patterns = [ + r'(?:with|having|where)\s+(.+?)(?:\s+and|\s+or|$)', + r'(?:greater|less|more|fewer)\s+than\s+(\d+)', + r'(?:between|from)\s+(.+?)\s+(?:and|to)\s+(.+)', + r'(?:before|after|since|until)\s+(.+)', + ] + + def analyze(self, question: str) -> QuestionComponents: + """Analyze a question to extract components. + + Args: + question: Natural language question + + Returns: + QuestionComponents with extracted information + """ + # Normalize question + question_lower = question.lower().strip() + + # Determine question type + question_type = self._identify_question_type(question_lower) + + # Extract entities + entities = self._extract_entities(question) + + # Extract relationships + relationships = self._extract_relationships(question_lower) + + # Extract constraints + constraints = self._extract_constraints(question_lower) + + # Extract aggregations + aggregations = self._extract_aggregations(question_lower) + + # Determine expected answer type + answer_type = self._determine_answer_type(question_type, aggregations) + + # Extract keywords + keywords = self._extract_keywords(question_lower) + + return QuestionComponents( + original_question=question, + question_type=question_type, + entities=entities, + relationships=relationships, + constraints=constraints, + aggregations=aggregations, + expected_answer_type=answer_type, + keywords=keywords + ) + + def _identify_question_type(self, question: str) -> QuestionType: + """Identify the type of question. + + Args: + question: Lowercase question text + + Returns: + QuestionType enum value + """ + for q_type, patterns in self.question_patterns.items(): + for pattern in patterns: + if re.search(pattern, question): + return q_type + + # Default to factual + return QuestionType.FACTUAL + + def _extract_entities(self, question: str) -> List[str]: + """Extract potential entities from question. 
+ + Args: + question: Original question text + + Returns: + List of entity strings + """ + entities = [] + + # Extract capitalized words/phrases (potential proper nouns) + # Pattern for consecutive capitalized words + pattern = r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b' + matches = re.findall(pattern, question) + entities.extend(matches) + + # Extract quoted strings + quoted = re.findall(r'"([^"]+)"', question) + entities.extend(quoted) + quoted = re.findall(r"'([^']+)'", question) + entities.extend(quoted) + + # Remove duplicates while preserving order + seen = set() + unique_entities = [] + for entity in entities: + if entity not in seen: + seen.add(entity) + unique_entities.append(entity) + + return unique_entities + + def _extract_relationships(self, question: str) -> List[str]: + """Extract relationship indicators from question. + + Args: + question: Lowercase question text + + Returns: + List of relationship strings + """ + relationships = [] + + # Common relationship patterns + rel_patterns = [ + r'(\w+)\s+(?:of|by|from|to|with|for)\s+', + r'has\s+(\w+)', + r'belongs?\s+to', + r'(?:created|written|authored|owned)\s+by', + r'related\s+to', + r'connected\s+to', + r'associated\s+with', + ] + + for pattern in rel_patterns: + matches = re.findall(pattern, question) + relationships.extend(matches) + + # Clean up + relationships = [r for r in relationships if len(r) > 2] + return list(set(relationships)) + + def _extract_constraints(self, question: str) -> List[str]: + """Extract constraints from question. + + Args: + question: Lowercase question text + + Returns: + List of constraint strings + """ + constraints = [] + + for pattern in self.constraint_patterns: + matches = re.findall(pattern, question) + if matches: + if isinstance(matches[0], tuple): + constraints.extend(list(matches[0])) + else: + constraints.extend(matches) + + # Clean up + constraints = [c.strip() for c in constraints if c and len(c.strip()) > 0] + return constraints + + def _extract_aggregations(self, question: str) -> List[str]: + """Extract aggregation operations from question. + + Args: + question: Lowercase question text + + Returns: + List of aggregation operations + """ + aggregations = [] + + for keyword in self.aggregation_keywords: + if keyword in question: + aggregations.append(keyword) + + return aggregations + + def _determine_answer_type(self, question_type: QuestionType, + aggregations: List[str]) -> str: + """Determine expected answer type. + + Args: + question_type: Type of question + aggregations: Aggregation operations found + + Returns: + Expected answer type string + """ + if aggregations: + if any(a in ['count', 'number of', 'total'] for a in aggregations): + return 'number' + elif any(a in ['average', 'mean', 'median'] for a in aggregations): + return 'number' + elif any(a in ['sum'] for a in aggregations): + return 'number' + + if question_type == QuestionType.BOOLEAN: + return 'boolean' + elif question_type == QuestionType.TEMPORAL: + return 'datetime' + elif question_type == QuestionType.SPATIAL: + return 'location' + elif question_type == QuestionType.RETRIEVAL: + return 'list' + elif question_type == QuestionType.COMPARISON: + return 'comparison' + else: + return 'text' + + def _extract_keywords(self, question: str) -> List[str]: + """Extract important keywords from question. 
+ + Args: + question: Lowercase question text + + Returns: + List of keywords + """ + # Remove common stop words + stop_words = { + 'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', + 'for', 'of', 'with', 'by', 'from', 'as', 'is', 'was', 'are', + 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'do', + 'does', 'did', 'will', 'would', 'could', 'should', 'may', + 'might', 'must', 'can', 'shall', 'what', 'which', 'who', + 'when', 'where', 'why', 'how' + } + + # Extract words + words = re.findall(r'\b\w+\b', question) + + # Filter stop words and short words + keywords = [w for w in words if w not in stop_words and len(w) > 2] + + # Remove duplicates while preserving order + seen = set() + unique_keywords = [] + for kw in keywords: + if kw not in seen: + seen.add(kw) + unique_keywords.append(kw) + + return unique_keywords + + def get_question_segments(self, question: str) -> List[str]: + """Split question into segments for embedding. + + Args: + question: Question text + + Returns: + List of question segments + """ + segments = [] + + # Add full question + segments.append(question) + + # Split by clauses + clauses = re.split(r'[,;]', question) + segments.extend([c.strip() for c in clauses if len(c.strip()) > 3]) + + # Extract key phrases + components = self.analyze(question) + segments.extend(components.entities) + segments.extend(components.keywords) + + # Remove duplicates + return list(dict.fromkeys(segments)) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py b/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py new file mode 100644 index 00000000..688e7371 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py @@ -0,0 +1,481 @@ +""" +SPARQL-Cassandra engine using Python rdflib. +Executes SPARQL queries against Cassandra using a custom Store implementation. +""" + +import logging +from typing import Dict, Any, List, Optional, Iterator, Tuple +from dataclasses import dataclass +import json + +# Try to import rdflib +try: + from rdflib import Graph, Namespace, URIRef, Literal, BNode + from rdflib.store import Store + from rdflib.plugins.sparql.processor import SPARQLResult + from rdflib.plugins.sparql import prepareQuery + from rdflib.term import Node + RDFLIB_AVAILABLE = True +except ImportError: + RDFLIB_AVAILABLE = False + +# Try to import Cassandra driver +try: + from cassandra.cluster import Cluster + from cassandra.auth import PlainTextAuthProvider + from cassandra.policies import DCAwareRoundRobinPolicy + CASSANDRA_AVAILABLE = True +except ImportError: + CASSANDRA_AVAILABLE = False + +from ....tables.config import ConfigTableStore + +logger = logging.getLogger(__name__) + + +@dataclass +class SPARQLResult: + """Result from SPARQL query execution.""" + bindings: List[Dict[str, Any]] + variables: List[str] + ask_result: Optional[bool] = None # For ASK queries + execution_time: float = 0.0 + query_plan: Optional[str] = None + + +class CassandraTripleStore(Store if RDFLIB_AVAILABLE else object): + """Custom rdflib Store implementation for Cassandra.""" + + def __init__(self, cassandra_config: Dict[str, Any]): + """Initialize Cassandra triple store. 
+ + Args: + cassandra_config: Cassandra connection configuration + """ + if not CASSANDRA_AVAILABLE: + raise RuntimeError("Cassandra driver not available") + if not RDFLIB_AVAILABLE: + raise RuntimeError("rdflib not available") + + super().__init__() + + self.cassandra_config = cassandra_config + self.cluster = None + self.session = None + self.keyspace = cassandra_config.get('keyspace', 'trustgraph') + + # Triple storage table structure + self.triple_table = f"{self.keyspace}.triples" + self.metadata_table = f"{self.keyspace}.triple_metadata" + + def open(self, configuration=None, create=False): + """Open connection to Cassandra.""" + try: + # Create authentication if provided + auth_provider = None + if 'username' in self.cassandra_config and 'password' in self.cassandra_config: + auth_provider = PlainTextAuthProvider( + username=self.cassandra_config['username'], + password=self.cassandra_config['password'] + ) + + # Create cluster + self.cluster = Cluster( + [self.cassandra_config.get('host', 'localhost')], + port=self.cassandra_config.get('port', 9042), + auth_provider=auth_provider, + load_balancing_policy=DCAwareRoundRobinPolicy() + ) + + # Connect + self.session = self.cluster.connect() + + # Ensure keyspace exists + if create: + self._create_schema() + + # Set keyspace + self.session.set_keyspace(self.keyspace) + + logger.info(f"Connected to Cassandra cluster: {self.cassandra_config.get('host')}") + return True + + except Exception as e: + logger.error(f"Failed to connect to Cassandra: {e}") + return False + + def close(self, commit_pending_transaction=True): + """Close Cassandra connection.""" + if self.session: + self.session.shutdown() + if self.cluster: + self.cluster.shutdown() + + def _create_schema(self): + """Create Cassandra schema for triple storage.""" + # Create keyspace + self.session.execute(f""" + CREATE KEYSPACE IF NOT EXISTS {self.keyspace} + WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}} + """) + + # Create triples table optimized for SPARQL queries + self.session.execute(f""" + CREATE TABLE IF NOT EXISTS {self.triple_table} ( + subject text, + predicate text, + object text, + object_datatype text, + object_language text, + is_literal boolean, + graph_id text, + PRIMARY KEY ((subject), predicate, object) + ) + """) + + # Create indexes for efficient querying + self.session.execute(f""" + CREATE INDEX IF NOT EXISTS ON {self.triple_table} (predicate) + """) + self.session.execute(f""" + CREATE INDEX IF NOT EXISTS ON {self.triple_table} (object) + """) + + # Metadata table for graph information + self.session.execute(f""" + CREATE TABLE IF NOT EXISTS {self.metadata_table} ( + graph_id text PRIMARY KEY, + created timestamp, + modified timestamp, + triple_count counter + ) + """) + + def triples(self, triple_pattern, context=None): + """Retrieve triples matching the given pattern. 
+ + Args: + triple_pattern: (subject, predicate, object) pattern with None for variables + context: Graph context (optional) + + Yields: + Matching triples as (subject, predicate, object) tuples + """ + if not self.session: + return + + subject, predicate, object_val = triple_pattern + + # Build CQL query based on pattern + cql_queries = self._pattern_to_cql(subject, predicate, object_val) + + for cql, params in cql_queries: + try: + rows = self.session.execute(cql, params) + for row in rows: + yield self._row_to_triple(row) + except Exception as e: + logger.error(f"Error executing CQL query: {e}") + + def _pattern_to_cql(self, subject, predicate, object_val) -> List[Tuple[str, List]]: + """Convert triple pattern to CQL queries. + + Args: + subject: Subject node or None + predicate: Predicate node or None + object_val: Object node or None + + Returns: + List of (CQL query, parameters) tuples + """ + queries = [] + + # Convert None to wildcard, nodes to strings + s_str = str(subject) if subject else None + p_str = str(predicate) if predicate else None + o_str = str(object_val) if object_val else None + + if s_str and p_str and o_str: + # Specific triple lookup + cql = f"SELECT * FROM {self.triple_table} WHERE subject = ? AND predicate = ? AND object = ?" + queries.append((cql, [s_str, p_str, o_str])) + + elif s_str and p_str: + # Subject and predicate known + cql = f"SELECT * FROM {self.triple_table} WHERE subject = ? AND predicate = ?" + queries.append((cql, [s_str, p_str])) + + elif s_str: + # Subject known + cql = f"SELECT * FROM {self.triple_table} WHERE subject = ?" + queries.append((cql, [s_str])) + + elif p_str: + # Predicate known (requires index scan) + cql = f"SELECT * FROM {self.triple_table} WHERE predicate = ? ALLOW FILTERING" + queries.append((cql, [p_str])) + + elif o_str: + # Object known (requires index scan) + cql = f"SELECT * FROM {self.triple_table} WHERE object = ? ALLOW FILTERING" + queries.append((cql, [o_str])) + + else: + # Full scan (should be avoided in production) + cql = f"SELECT * FROM {self.triple_table}" + queries.append((cql, [])) + + return queries + + def _row_to_triple(self, row): + """Convert Cassandra row to RDF triple. + + Args: + row: Cassandra row object + + Returns: + (subject, predicate, object) tuple with rdflib nodes + """ + # Convert to rdflib nodes + subject = URIRef(row.subject) if row.subject.startswith('http') else BNode(row.subject) + + predicate = URIRef(row.predicate) + + if row.is_literal: + # Create literal with datatype/language + if row.object_datatype: + object_node = Literal(row.object, datatype=URIRef(row.object_datatype)) + elif row.object_language: + object_node = Literal(row.object, lang=row.object_language) + else: + object_node = Literal(row.object) + else: + object_node = URIRef(row.object) if row.object.startswith('http') else BNode(row.object) + + return (subject, predicate, object_node) + + def add(self, triple, context=None, quoted=False): + """Add a triple to the store. 
+ + Args: + triple: (subject, predicate, object) tuple + context: Graph context + quoted: Whether triple is quoted + """ + if not self.session: + return + + subject, predicate, object_val = triple + + # Convert to storage format + s_str = str(subject) + p_str = str(predicate) + + is_literal = isinstance(object_val, Literal) + o_str = str(object_val) + o_datatype = str(object_val.datatype) if is_literal and object_val.datatype else None + o_language = object_val.language if is_literal and object_val.language else None + + # Insert into Cassandra + cql = f""" + INSERT INTO {self.triple_table} + (subject, predicate, object, object_datatype, object_language, is_literal, graph_id) + VALUES (?, ?, ?, ?, ?, ?, ?) + """ + + try: + self.session.execute(cql, [ + s_str, p_str, o_str, o_datatype, o_language, is_literal, + str(context) if context else 'default' + ]) + except Exception as e: + logger.error(f"Error adding triple: {e}") + + def remove(self, triple, context=None): + """Remove a triple from the store. + + Args: + triple: (subject, predicate, object) tuple + context: Graph context + """ + if not self.session: + return + + subject, predicate, object_val = triple + + cql = f""" + DELETE FROM {self.triple_table} + WHERE subject = ? AND predicate = ? AND object = ? + """ + + try: + self.session.execute(cql, [str(subject), str(predicate), str(object_val)]) + except Exception as e: + logger.error(f"Error removing triple: {e}") + + def __len__(self, context=None): + """Get number of triples in store. + + Args: + context: Graph context + + Returns: + Number of triples + """ + if not self.session: + return 0 + + try: + cql = f"SELECT COUNT(*) FROM {self.triple_table}" + result = self.session.execute(cql) + return result.one().count + except Exception as e: + logger.error(f"Error counting triples: {e}") + return 0 + + +class SPARQLCassandraEngine: + """SPARQL processor using Cassandra backend.""" + + def __init__(self, cassandra_config: Dict[str, Any]): + """Initialize SPARQL-Cassandra engine. + + Args: + cassandra_config: Cassandra configuration + """ + if not RDFLIB_AVAILABLE: + raise RuntimeError("rdflib is required for SPARQL processing") + if not CASSANDRA_AVAILABLE: + raise RuntimeError("Cassandra driver is required") + + self.cassandra_config = cassandra_config + self.store = CassandraTripleStore(cassandra_config) + self.graph = Graph(store=self.store) + + # Common namespaces + self.namespaces = { + 'rdf': Namespace('http://www.w3.org/1999/02/22-rdf-syntax-ns#'), + 'rdfs': Namespace('http://www.w3.org/2000/01/rdf-schema#'), + 'owl': Namespace('http://www.w3.org/2002/07/owl#'), + 'xsd': Namespace('http://www.w3.org/2001/XMLSchema#'), + } + + # Bind namespaces to graph + for prefix, namespace in self.namespaces.items(): + self.graph.bind(prefix, namespace) + + async def initialize(self, create_schema=False): + """Initialize the engine. + + Args: + create_schema: Whether to create Cassandra schema + """ + success = self.store.open(create=create_schema) + if not success: + raise RuntimeError("Failed to connect to Cassandra") + + logger.info("SPARQL-Cassandra engine initialized") + + def execute_sparql(self, sparql_query: str) -> SPARQLResult: + """Execute SPARQL query against Cassandra. 
+ + Args: + sparql_query: SPARQL query string + + Returns: + Query results + """ + import time + start_time = time.time() + + try: + # Prepare and execute query + prepared_query = prepareQuery(sparql_query) + result = self.graph.query(prepared_query) + + execution_time = time.time() - start_time + + # Format results based on query type + if sparql_query.strip().upper().startswith('ASK'): + return SPARQLResult( + bindings=[], + variables=[], + ask_result=bool(result), + execution_time=execution_time + ) + else: + # SELECT query + bindings = [] + variables = result.vars if hasattr(result, 'vars') else [] + + for row in result: + binding = {} + for i, var in enumerate(variables): + if i < len(row): + value = row[i] + binding[str(var)] = self._format_result_value(value) + bindings.append(binding) + + return SPARQLResult( + bindings=bindings, + variables=[str(v) for v in variables], + execution_time=execution_time + ) + + except Exception as e: + logger.error(f"SPARQL execution error: {e}") + return SPARQLResult( + bindings=[], + variables=[], + execution_time=time.time() - start_time + ) + + def _format_result_value(self, value): + """Format result value for output. + + Args: + value: RDF value (URIRef, Literal, BNode) + + Returns: + Formatted value + """ + if isinstance(value, URIRef): + return {'type': 'uri', 'value': str(value)} + elif isinstance(value, Literal): + result = {'type': 'literal', 'value': str(value)} + if value.datatype: + result['datatype'] = str(value.datatype) + if value.language: + result['language'] = value.language + return result + elif isinstance(value, BNode): + return {'type': 'bnode', 'value': str(value)} + else: + return {'type': 'unknown', 'value': str(value)} + + def load_triples_from_store(self, config_store: ConfigTableStore): + """Load triples from TrustGraph's storage into the RDF graph. + + Args: + config_store: Configuration store with triples + """ + # This would need to be implemented based on how triples are stored + # in TrustGraph's Cassandra tables + logger.info("Loading triples from TrustGraph store...") + + # Example implementation - would need to be adapted + # to actual TrustGraph storage format + try: + # Get all triple data + # This is a placeholder - actual implementation would need + # to query the appropriate TrustGraph tables + pass + + except Exception as e: + logger.error(f"Error loading triples: {e}") + + def close(self): + """Close the engine and connections.""" + if self.store: + self.store.close() + logger.info("SPARQL-Cassandra engine closed") \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/sparql_generator.py b/trustgraph-flow/trustgraph/query/ontology/sparql_generator.py new file mode 100644 index 00000000..44c7e0a1 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/sparql_generator.py @@ -0,0 +1,487 @@ +""" +SPARQL query generator for ontology-sensitive queries. +Converts natural language questions to SPARQL queries for Cassandra execution. 
+""" + +import logging +from typing import Dict, Any, List, Optional +from dataclasses import dataclass + +from .question_analyzer import QuestionComponents, QuestionType +from .ontology_matcher import QueryOntologySubset + +logger = logging.getLogger(__name__) + + +@dataclass +class SPARQLQuery: + """Generated SPARQL query with metadata.""" + query: str + variables: List[str] + query_type: str # SELECT, ASK, CONSTRUCT, DESCRIBE + explanation: str + complexity_score: float + + +class SPARQLGenerator: + """Generates SPARQL queries from natural language questions using LLM assistance.""" + + def __init__(self, prompt_service=None): + """Initialize SPARQL generator. + + Args: + prompt_service: Service for LLM-based query generation + """ + self.prompt_service = prompt_service + + # SPARQL query templates for common patterns + self.templates = { + 'simple_class_query': """ +PREFIX : <{namespace}> +PREFIX rdf: +PREFIX rdfs: + +SELECT ?entity ?label WHERE {{ + ?entity rdf:type :{class_name} . + OPTIONAL {{ ?entity rdfs:label ?label }} +}}""", + + 'property_query': """ +PREFIX : <{namespace}> +PREFIX rdf: +PREFIX rdfs: + +SELECT ?subject ?object WHERE {{ + ?subject :{property} ?object . + ?subject rdf:type :{subject_class} . +}}""", + + 'hierarchy_query': """ +PREFIX : <{namespace}> +PREFIX rdf: +PREFIX rdfs: + +SELECT ?subclass ?superclass WHERE {{ + ?subclass rdfs:subClassOf* ?superclass . + ?superclass rdf:type :{root_class} . +}}""", + + 'count_query': """ +PREFIX : <{namespace}> +PREFIX rdf: + +SELECT (COUNT(?entity) AS ?count) WHERE {{ + ?entity rdf:type :{class_name} . + {additional_constraints} +}}""", + + 'boolean_query': """ +PREFIX : <{namespace}> +PREFIX rdf: + +ASK {{ + {triple_pattern} +}}""" + } + + async def generate_sparql(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset) -> SPARQLQuery: + """Generate SPARQL query for a question. + + Args: + question_components: Analyzed question components + ontology_subset: Relevant ontology subset + + Returns: + Generated SPARQL query + """ + # Try template-based generation first + template_query = self._try_template_generation(question_components, ontology_subset) + if template_query: + logger.debug("Generated SPARQL using template") + return template_query + + # Fall back to LLM-based generation + if self.prompt_service: + llm_query = await self._generate_with_llm(question_components, ontology_subset) + if llm_query: + logger.debug("Generated SPARQL using LLM") + return llm_query + + # Final fallback to simple pattern + logger.warning("Falling back to simple SPARQL pattern") + return self._generate_fallback_query(question_components, ontology_subset) + + def _try_template_generation(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset) -> Optional[SPARQLQuery]: + """Try to generate query using templates. + + Args: + question_components: Question analysis + ontology_subset: Ontology subset + + Returns: + Generated query or None if no template matches + """ + namespace = ontology_subset.metadata.get('namespace', 'http://example.org/') + + # Simple class query (What are the animals?) 
+ if (question_components.question_type == QuestionType.RETRIEVAL and + len(question_components.entities) == 1 and + question_components.entities[0].lower() in [c.lower() for c in ontology_subset.classes]): + + class_name = self._find_matching_class(question_components.entities[0], ontology_subset) + if class_name: + query = self.templates['simple_class_query'].format( + namespace=namespace, + class_name=class_name + ) + return SPARQLQuery( + query=query, + variables=['entity', 'label'], + query_type='SELECT', + explanation=f"Retrieve all instances of {class_name}", + complexity_score=0.3 + ) + + # Count query (How many animals are there?) + if (question_components.question_type == QuestionType.AGGREGATION and + 'count' in question_components.aggregations and + len(question_components.entities) >= 1): + + class_name = self._find_matching_class(question_components.entities[0], ontology_subset) + if class_name: + query = self.templates['count_query'].format( + namespace=namespace, + class_name=class_name, + additional_constraints=self._build_constraints(question_components, ontology_subset) + ) + return SPARQLQuery( + query=query, + variables=['count'], + query_type='SELECT', + explanation=f"Count instances of {class_name}", + complexity_score=0.4 + ) + + # Boolean query (Is X a Y?) + if question_components.question_type == QuestionType.BOOLEAN: + triple_pattern = self._build_boolean_pattern(question_components, ontology_subset) + if triple_pattern: + query = self.templates['boolean_query'].format( + namespace=namespace, + triple_pattern=triple_pattern + ) + return SPARQLQuery( + query=query, + variables=[], + query_type='ASK', + explanation="Boolean query for fact checking", + complexity_score=0.2 + ) + + return None + + async def _generate_with_llm(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset) -> Optional[SPARQLQuery]: + """Generate SPARQL using LLM. + + Args: + question_components: Question analysis + ontology_subset: Ontology subset + + Returns: + Generated query or None if failed + """ + try: + prompt = self._build_sparql_prompt(question_components, ontology_subset) + response = await self.prompt_service.generate_sparql(prompt=prompt) + + if response and isinstance(response, dict): + query = response.get('query', '').strip() + if query.upper().startswith(('SELECT', 'ASK', 'CONSTRUCT', 'DESCRIBE')): + return SPARQLQuery( + query=query, + variables=self._extract_variables(query), + query_type=query.split()[0].upper(), + explanation=response.get('explanation', 'Generated by LLM'), + complexity_score=self._calculate_complexity(query) + ) + + except Exception as e: + logger.error(f"LLM SPARQL generation failed: {e}") + + return None + + def _build_sparql_prompt(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset) -> str: + """Build prompt for LLM SPARQL generation. + + Args: + question_components: Question analysis + ontology_subset: Ontology subset + + Returns: + Formatted prompt string + """ + namespace = ontology_subset.metadata.get('namespace', 'http://example.org/') + + # Format ontology elements + classes_str = self._format_classes_for_prompt(ontology_subset.classes, namespace) + props_str = self._format_properties_for_prompt( + ontology_subset.object_properties, + ontology_subset.datatype_properties, + namespace + ) + + prompt = f"""Generate a SPARQL query for the following question using the provided ontology. 
+
+QUESTION: {question_components.original_question}
+
+ONTOLOGY NAMESPACE: {namespace}
+
+AVAILABLE CLASSES:
+{classes_str}
+
+AVAILABLE PROPERTIES:
+{props_str}
+
+RULES:
+- Use proper SPARQL syntax
+- Include appropriate prefixes
+- Use property paths for hierarchical queries (rdfs:subClassOf*)
+- Add FILTER clauses for constraints
+- Optimize for Cassandra backend
+- Return both query and explanation
+
+QUERY TYPE HINTS:
+- Question type: {question_components.question_type.value}
+- Expected answer: {question_components.expected_answer_type}
+- Entities mentioned: {', '.join(question_components.entities)}
+- Aggregations: {', '.join(question_components.aggregations)}
+
+Generate a complete SPARQL query:"""
+
+        return prompt
+
+    def _generate_fallback_query(self,
+                                question_components: QuestionComponents,
+                                ontology_subset: QueryOntologySubset) -> SPARQLQuery:
+        """Generate simple fallback query.
+
+        Args:
+            question_components: Question analysis
+            ontology_subset: Ontology subset
+
+        Returns:
+            Basic SPARQL query
+        """
+        namespace = ontology_subset.metadata.get('namespace', 'http://example.org/')
+
+        # Very basic SELECT query
+        query = f"""PREFIX : <{namespace}>
+PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
+PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
+
+SELECT ?subject ?predicate ?object WHERE {{
+    ?subject ?predicate ?object .
+    FILTER(CONTAINS(STR(?subject), "{question_components.keywords[0] if question_components.keywords else 'entity'}"))
+}}
+LIMIT 10"""
+
+        return SPARQLQuery(
+            query=query,
+            variables=['subject', 'predicate', 'object'],
+            query_type='SELECT',
+            explanation="Fallback query for basic pattern matching",
+            complexity_score=0.1
+        )
+
+    def _find_matching_class(self, entity: str, ontology_subset: QueryOntologySubset) -> Optional[str]:
+        """Find matching class in ontology subset.
+
+        Args:
+            entity: Entity string to match
+            ontology_subset: Ontology subset
+
+        Returns:
+            Matching class name or None
+        """
+        entity_lower = entity.lower()
+
+        # Direct match
+        for class_id in ontology_subset.classes:
+            if class_id.lower() == entity_lower:
+                return class_id
+
+        # Label match
+        for class_id, class_def in ontology_subset.classes.items():
+            labels = class_def.get('labels', [])
+            for label in labels:
+                if isinstance(label, dict):
+                    label_value = label.get('value', '').lower()
+                    if label_value == entity_lower:
+                        return class_id
+
+        # Partial match
+        for class_id in ontology_subset.classes:
+            if entity_lower in class_id.lower() or class_id.lower() in entity_lower:
+                return class_id
+
+        return None
+
+    def _build_constraints(self,
+                          question_components: QuestionComponents,
+                          ontology_subset: QueryOntologySubset) -> str:
+        """Build constraint clauses for SPARQL.
+
+        Args:
+            question_components: Question analysis
+            ontology_subset: Ontology subset
+
+        Returns:
+            SPARQL constraint string
+        """
+        # Imported once here so both constraint branches below can use it
+        import re
+
+        constraints = []
+
+        for constraint in question_components.constraints:
+            # Simple constraint patterns
+            if 'greater than' in constraint.lower():
+                # Extract number
+                numbers = re.findall(r'\d+', constraint)
+                if numbers:
+                    constraints.append(f"FILTER(?value > {numbers[0]})")
+
+            elif 'less than' in constraint.lower():
+                numbers = re.findall(r'\d+', constraint)
+                if numbers:
+                    constraints.append(f"FILTER(?value < {numbers[0]})")
+
+        return '\n    '.join(constraints)
+
+    def _build_boolean_pattern(self,
+                              question_components: QuestionComponents,
+                              ontology_subset: QueryOntologySubset) -> Optional[str]:
+        """Build triple pattern for boolean queries.
+ + Args: + question_components: Question analysis + ontology_subset: Ontology subset + + Returns: + SPARQL triple pattern or None + """ + if len(question_components.entities) >= 2: + subject = question_components.entities[0] + object_val = question_components.entities[1] + + # Try to find connecting property + for prop_id in ontology_subset.object_properties: + return f":{subject} :{prop_id} :{object_val} ." + + # Fallback to type check + return f":{subject} rdf:type :{object_val} ." + + return None + + def _format_classes_for_prompt(self, classes: Dict[str, Any], namespace: str) -> str: + """Format classes for prompt. + + Args: + classes: Classes dictionary + namespace: Ontology namespace + + Returns: + Formatted classes string + """ + if not classes: + return "None" + + lines = [] + for class_id, definition in classes.items(): + comment = definition.get('comment', '') + parent = definition.get('subclass_of', 'Thing') + lines.append(f"- :{class_id} (subclass of :{parent}) - {comment}") + + return '\n'.join(lines) + + def _format_properties_for_prompt(self, + object_props: Dict[str, Any], + datatype_props: Dict[str, Any], + namespace: str) -> str: + """Format properties for prompt. + + Args: + object_props: Object properties + datatype_props: Datatype properties + namespace: Ontology namespace + + Returns: + Formatted properties string + """ + lines = [] + + for prop_id, definition in object_props.items(): + domain = definition.get('domain', 'Any') + range_val = definition.get('range', 'Any') + comment = definition.get('comment', '') + lines.append(f"- :{prop_id} (:{domain} -> :{range_val}) - {comment}") + + for prop_id, definition in datatype_props.items(): + domain = definition.get('domain', 'Any') + range_val = definition.get('range', 'xsd:string') + comment = definition.get('comment', '') + lines.append(f"- :{prop_id} (:{domain} -> {range_val}) - {comment}") + + return '\n'.join(lines) if lines else "None" + + def _extract_variables(self, query: str) -> List[str]: + """Extract variables from SPARQL query. + + Args: + query: SPARQL query string + + Returns: + List of variable names + """ + import re + variables = re.findall(r'\?(\w+)', query) + return list(set(variables)) + + def _calculate_complexity(self, query: str) -> float: + """Calculate complexity score for SPARQL query. 
+ + Args: + query: SPARQL query string + + Returns: + Complexity score (0.0 to 1.0) + """ + complexity = 0.0 + + # Count different SPARQL features + query_upper = query.upper() + + if 'JOIN' in query_upper or 'UNION' in query_upper: + complexity += 0.3 + if 'FILTER' in query_upper: + complexity += 0.2 + if 'OPTIONAL' in query_upper: + complexity += 0.1 + if 'GROUP BY' in query_upper: + complexity += 0.2 + if 'ORDER BY' in query_upper: + complexity += 0.1 + if '*' in query: # Property paths + complexity += 0.1 + + # Count variables + variables = self._extract_variables(query) + complexity += len(variables) * 0.05 + + return min(complexity, 1.0) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py index fae3a09a..012d91b7 100755 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py @@ -132,20 +132,20 @@ class Processor(DocumentEmbeddingsStoreService): await self.storage_response_producer.send(response) async def handle_create_collection(self, request): - """Create a Milvus collection for document embeddings""" + """ + No-op for collection creation - collections are created lazily on first write + with the correct dimension determined from the actual embeddings. + """ try: - if self.vecstore.collection_exists(request.user, request.collection): - logger.info(f"Collection {request.user}/{request.collection} already exists") - else: - self.vecstore.create_collection(request.user, request.collection) - logger.info(f"Created collection {request.user}/{request.collection}") + logger.info(f"Collection create request for {request.user}/{request.collection} - will be created lazily on first write") + self.vecstore.create_collection(request.user, request.collection) # Send success response response = StorageManagementResponse(error=None) await self.storage_response_producer.send(response) except Exception as e: - logger.error(f"Failed to create collection: {e}", exc_info=True) + logger.error(f"Failed to handle create collection request: {e}", exc_info=True) response = StorageManagementResponse( error=Error( type="creation_error", diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py index f940ce2a..4d3c43bb 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py @@ -123,19 +123,6 @@ class Processor(DocumentEmbeddingsStoreService): async def store_document_embeddings(self, message): - index_name = ( - "d-" + message.metadata.user + "-" + message.metadata.collection - ) - - # Validate collection exists before accepting writes - if not self.pinecone.has_index(index_name): - error_msg = ( - f"Collection {message.metadata.collection} does not exist. " - f"Create it first with tg-set-collection." 
- ) - logger.error(error_msg) - raise ValueError(error_msg) - for emb in message.chunks: if emb.chunk is None or emb.chunk == b"": continue @@ -145,6 +132,17 @@ class Processor(DocumentEmbeddingsStoreService): for vec in emb.vectors: + # Create index name with dimension suffix for lazy creation + dim = len(vec) + index_name = ( + f"d-{message.metadata.user}-{message.metadata.collection}-{dim}" + ) + + # Lazily create index if it doesn't exist + if not self.pinecone.has_index(index_name): + logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}") + self.create_index(index_name, dim) + index = self.pinecone.Index(index_name) # Generate unique ID for each vector @@ -220,23 +218,19 @@ class Processor(DocumentEmbeddingsStoreService): await self.storage_response_producer.send(response) async def handle_create_collection(self, request): - """Create a Pinecone index for document embeddings""" + """ + No-op for collection creation - indexes are created lazily on first write + with the correct dimension determined from the actual embeddings. + """ try: - index_name = f"d-{request.user}-{request.collection}" - - if self.pinecone.has_index(index_name): - logger.info(f"Pinecone index {index_name} already exists") - else: - # Create with default dimension - will need to be recreated if dimension doesn't match - self.create_index(index_name, dim=384) - logger.info(f"Created Pinecone index: {index_name}") + logger.info(f"Collection create request for {request.user}/{request.collection} - will be created lazily on first write") # Send success response response = StorageManagementResponse(error=None) await self.storage_response_producer.send(response) except Exception as e: - logger.error(f"Failed to create collection: {e}", exc_info=True) + logger.error(f"Failed to handle create collection request: {e}", exc_info=True) response = StorageManagementResponse( error=Error( type="creation_error", @@ -246,22 +240,34 @@ class Processor(DocumentEmbeddingsStoreService): await self.storage_response_producer.send(response) async def handle_delete_collection(self, request): - """Delete the collection for document embeddings""" + """ + Delete all dimension variants of the index for document embeddings. + Since indexes are created with dimension suffixes (e.g., d-user-coll-384), + we need to find and delete all matching indexes. 
+ """ try: - index_name = f"d-{request.user}-{request.collection}" + prefix = f"d-{request.user}-{request.collection}-" - if self.pinecone.has_index(index_name): - self.pinecone.delete_index(index_name) - logger.info(f"Deleted Pinecone index: {index_name}") + # Get all indexes and filter for matches + all_indexes = self.pinecone.list_indexes() + matching_indexes = [ + idx.name for idx in all_indexes + if idx.name.startswith(prefix) + ] + + if not matching_indexes: + logger.info(f"No indexes found matching prefix {prefix}") else: - logger.info(f"Index {index_name} does not exist, nothing to delete") + for index_name in matching_indexes: + self.pinecone.delete_index(index_name) + logger.info(f"Deleted Pinecone index: {index_name}") + logger.info(f"Deleted {len(matching_indexes)} index(es) for {request.user}/{request.collection}") # Send success response response = StorageManagementResponse( error=None # No error means success ) await self.storage_response_producer.send(response) - logger.info(f"Successfully deleted collection {request.user}/{request.collection}") except Exception as e: logger.error(f"Failed to delete collection: {e}") diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index dfb9980f..225beb9c 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -79,20 +79,6 @@ class Processor(DocumentEmbeddingsStoreService): async def store_document_embeddings(self, message): - # Validate collection exists before accepting writes - collection = ( - "d_" + message.metadata.user + "_" + - message.metadata.collection - ) - - if not self.qdrant.collection_exists(collection): - error_msg = ( - f"Collection {message.metadata.collection} does not exist. " - f"Create it first with tg-set-collection." - ) - logger.error(error_msg) - raise ValueError(error_msg) - for emb in message.chunks: chunk = emb.chunk.decode("utf-8") @@ -100,6 +86,23 @@ class Processor(DocumentEmbeddingsStoreService): for vec in emb.vectors: + # Create collection name with dimension suffix for lazy creation + dim = len(vec) + collection = ( + f"d_{message.metadata.user}_{message.metadata.collection}_{dim}" + ) + + # Lazily create collection if it doesn't exist + if not self.qdrant.collection_exists(collection): + logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") + self.qdrant.create_collection( + collection_name=collection, + vectors_config=VectorParams( + size=dim, + distance=Distance.COSINE + ) + ) + self.qdrant.upsert( collection_name=collection, points=[ @@ -160,30 +163,19 @@ class Processor(DocumentEmbeddingsStoreService): await self.storage_response_producer.send(response) async def handle_create_collection(self, request): - """Create a Qdrant collection for document embeddings""" + """ + No-op for collection creation - collections are created lazily on first write + with the correct dimension determined from the actual embeddings. 
+ """ try: - collection_name = f"d_{request.user}_{request.collection}" - - if self.qdrant.collection_exists(collection_name): - logger.info(f"Qdrant collection {collection_name} already exists") - else: - # Create collection with default dimension (will be recreated with correct dim on first write if needed) - # Using a placeholder dimension - actual dimension determined by first embedding - self.qdrant.create_collection( - collection_name=collection_name, - vectors_config=VectorParams( - size=384, # Default dimension, common for many models - distance=Distance.COSINE - ) - ) - logger.info(f"Created Qdrant collection: {collection_name}") + logger.info(f"Collection create request for {request.user}/{request.collection} - will be created lazily on first write") # Send success response response = StorageManagementResponse(error=None) await self.storage_response_producer.send(response) except Exception as e: - logger.error(f"Failed to create collection: {e}", exc_info=True) + logger.error(f"Failed to handle create collection request: {e}", exc_info=True) response = StorageManagementResponse( error=Error( type="creation_error", @@ -193,22 +185,34 @@ class Processor(DocumentEmbeddingsStoreService): await self.storage_response_producer.send(response) async def handle_delete_collection(self, request): - """Delete the collection for document embeddings""" + """ + Delete all dimension variants of the collection for document embeddings. + Since collections are created with dimension suffixes (e.g., d_user_coll_384), + we need to find and delete all matching collections. + """ try: - collection_name = f"d_{request.user}_{request.collection}" + prefix = f"d_{request.user}_{request.collection}_" - if self.qdrant.collection_exists(collection_name): - self.qdrant.delete_collection(collection_name) - logger.info(f"Deleted Qdrant collection: {collection_name}") + # Get all collections and filter for matches + all_collections = self.qdrant.get_collections().collections + matching_collections = [ + coll.name for coll in all_collections + if coll.name.startswith(prefix) + ] + + if not matching_collections: + logger.info(f"No collections found matching prefix {prefix}") else: - logger.info(f"Collection {collection_name} does not exist, nothing to delete") + for collection_name in matching_collections: + self.qdrant.delete_collection(collection_name) + logger.info(f"Deleted Qdrant collection: {collection_name}") + logger.info(f"Deleted {len(matching_collections)} collection(s) for {request.user}/{request.collection}") # Send success response response = StorageManagementResponse( error=None # No error means success ) await self.storage_response_producer.send(response) - logger.info(f"Successfully deleted collection {request.user}/{request.collection}") except Exception as e: logger.error(f"Failed to delete collection: {e}") diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py index 7ccd027b..cca0de95 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py @@ -128,20 +128,20 @@ class Processor(GraphEmbeddingsStoreService): await self.storage_response_producer.send(response) async def handle_create_collection(self, request): - """Create a Milvus collection for graph embeddings""" + """ + No-op for collection creation - collections are created lazily on first write + with the correct dimension determined from the actual 
diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py
index f97f2a46..30d3d3e5 100755
--- a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py
+++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py
@@ -123,19 +123,6 @@ class Processor(GraphEmbeddingsStoreService):
 
     async def store_graph_embeddings(self, message):
 
-        index_name = (
-            "t-" + message.metadata.user + "-" + message.metadata.collection
-        )
-
-        # Validate collection exists before accepting writes
-        if not self.pinecone.has_index(index_name):
-            error_msg = (
-                f"Collection {message.metadata.collection} does not exist. "
-                f"Create it first with tg-set-collection."
-            )
-            logger.error(error_msg)
-            raise ValueError(error_msg)
-
         for entity in message.entities:
 
             if entity.entity.value == "" or entity.entity.value is None:
@@ -143,6 +130,17 @@ class Processor(GraphEmbeddingsStoreService):
 
             for vec in entity.vectors:
 
+                # Create index name with dimension suffix for lazy creation
+                dim = len(vec)
+                index_name = (
+                    f"t-{message.metadata.user}-{message.metadata.collection}-{dim}"
+                )
+
+                # Lazily create index if it doesn't exist
+                if not self.pinecone.has_index(index_name):
+                    logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
+                    self.create_index(index_name, dim)
+
                 index = self.pinecone.Index(index_name)
 
                 # Generate unique ID for each vector
@@ -218,23 +216,19 @@ class Processor(GraphEmbeddingsStoreService):
         await self.storage_response_producer.send(response)
 
     async def handle_create_collection(self, request):
-        """Create a Pinecone index for graph embeddings"""
+        """
+        No-op for collection creation - indexes are created lazily on first write
+        with the correct dimension determined from the actual embeddings.
+        """
         try:
-            index_name = f"t-{request.user}-{request.collection}"
-
-            if self.pinecone.has_index(index_name):
-                logger.info(f"Pinecone index {index_name} already exists")
-            else:
-                # Create with default dimension - will need to be recreated if dimension doesn't match
-                self.create_index(index_name, dim=384)
-                logger.info(f"Created Pinecone index: {index_name}")
+            logger.info(f"Collection create request for {request.user}/{request.collection} - will be created lazily on first write")
 
             # Send success response
             response = StorageManagementResponse(error=None)
             await self.storage_response_producer.send(response)
 
         except Exception as e:
-            logger.error(f"Failed to create collection: {e}", exc_info=True)
+            logger.error(f"Failed to handle create collection request: {e}", exc_info=True)
             response = StorageManagementResponse(
                 error=Error(
                     type="creation_error",
@@ -244,22 +238,34 @@ class Processor(GraphEmbeddingsStoreService):
         await self.storage_response_producer.send(response)
 
     async def handle_delete_collection(self, request):
-        """Delete the collection for graph embeddings"""
+        """
+        Delete all dimension variants of the index for graph embeddings.
+        Since indexes are created with dimension suffixes (e.g., t-user-coll-384),
+        we need to find and delete all matching indexes.
+        """
         try:
-            index_name = f"t-{request.user}-{request.collection}"
+            prefix = f"t-{request.user}-{request.collection}-"
 
-            if self.pinecone.has_index(index_name):
-                self.pinecone.delete_index(index_name)
-                logger.info(f"Deleted Pinecone index: {index_name}")
+            # Get all indexes and filter for matches
+            all_indexes = self.pinecone.list_indexes()
+            matching_indexes = [
+                idx.name for idx in all_indexes
+                if idx.name.startswith(prefix)
+            ]
+
+            if not matching_indexes:
+                logger.info(f"No indexes found matching prefix {prefix}")
             else:
-                logger.info(f"Index {index_name} does not exist, nothing to delete")
+                for index_name in matching_indexes:
+                    self.pinecone.delete_index(index_name)
+                    logger.info(f"Deleted Pinecone index: {index_name}")
+                logger.info(f"Deleted {len(matching_indexes)} index(es) for {request.user}/{request.collection}")
 
             # Send success response
             response = StorageManagementResponse(
                 error=None  # No error means success
             )
             await self.storage_response_producer.send(response)
-            logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
 
         except Exception as e:
             logger.error(f"Failed to delete collection: {e}")
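Deletion now has to sweep every dimension variant, since one logical collection may back several physical indexes. A sketch of that sweep, with one extra guard worth considering: checking that the suffix is numeric prevents a sibling collection such as t-alice-graph-2 from being caught by the prefix t-alice-graph-. Client setup is assumed as in the earlier sketch:

    import os
    from pinecone import Pinecone

    pc = Pinecone(api_key=os.environ["PINECONE_API_KEY"])

    def delete_all_variants(user: str, collection: str) -> int:
        prefix = f"t-{user}-{collection}-"
        matching = [
            idx.name for idx in pc.list_indexes()
            if idx.name.startswith(prefix)
            and idx.name[len(prefix):].isdigit()  # only true dimension suffixes
        ]
        for name in matching:
            pc.delete_index(name)
        return len(matching)

    print(f"Deleted {delete_all_variants('alice', 'graph')} index(es)")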
+ """ try: - index_name = f"t-{request.user}-{request.collection}" - - if self.pinecone.has_index(index_name): - logger.info(f"Pinecone index {index_name} already exists") - else: - # Create with default dimension - will need to be recreated if dimension doesn't match - self.create_index(index_name, dim=384) - logger.info(f"Created Pinecone index: {index_name}") + logger.info(f"Collection create request for {request.user}/{request.collection} - will be created lazily on first write") # Send success response response = StorageManagementResponse(error=None) await self.storage_response_producer.send(response) except Exception as e: - logger.error(f"Failed to create collection: {e}", exc_info=True) + logger.error(f"Failed to handle create collection request: {e}", exc_info=True) response = StorageManagementResponse( error=Error( type="creation_error", @@ -244,22 +238,34 @@ class Processor(GraphEmbeddingsStoreService): await self.storage_response_producer.send(response) async def handle_delete_collection(self, request): - """Delete the collection for graph embeddings""" + """ + Delete all dimension variants of the index for graph embeddings. + Since indexes are created with dimension suffixes (e.g., t-user-coll-384), + we need to find and delete all matching indexes. + """ try: - index_name = f"t-{request.user}-{request.collection}" + prefix = f"t-{request.user}-{request.collection}-" - if self.pinecone.has_index(index_name): - self.pinecone.delete_index(index_name) - logger.info(f"Deleted Pinecone index: {index_name}") + # Get all indexes and filter for matches + all_indexes = self.pinecone.list_indexes() + matching_indexes = [ + idx.name for idx in all_indexes + if idx.name.startswith(prefix) + ] + + if not matching_indexes: + logger.info(f"No indexes found matching prefix {prefix}") else: - logger.info(f"Index {index_name} does not exist, nothing to delete") + for index_name in matching_indexes: + self.pinecone.delete_index(index_name) + logger.info(f"Deleted Pinecone index: {index_name}") + logger.info(f"Deleted {len(matching_indexes)} index(es) for {request.user}/{request.collection}") # Send success response response = StorageManagementResponse( error=None # No error means success ) await self.storage_response_producer.send(response) - logger.info(f"Successfully deleted collection {request.user}/{request.collection}") except Exception as e: logger.error(f"Failed to delete collection: {e}") diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index 6446f7f5..0b15996f 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -69,22 +69,6 @@ class Processor(GraphEmbeddingsStoreService): metrics=storage_response_metrics, ) - def get_collection(self, user, collection): - """Get collection name and validate it exists""" - cname = ( - "t_" + user + "_" + collection - ) - - if not self.qdrant.collection_exists(cname): - error_msg = ( - f"Collection {collection} does not exist. " - f"Create it first with tg-set-collection." 
diff --git a/trustgraph-ocr/pyproject.toml b/trustgraph-ocr/pyproject.toml
index 89aafbb3..8f1d4d2a 100644
--- a/trustgraph-ocr/pyproject.toml
+++ b/trustgraph-ocr/pyproject.toml
@@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
 readme = "README.md"
 requires-python = ">=3.8"
 dependencies = [
-    "trustgraph-base>=1.4,<1.5",
+    "trustgraph-base>=1.5,<1.6",
     "pulsar-client",
     "prometheus-client",
     "boto3",
diff --git a/trustgraph-vertexai/pyproject.toml b/trustgraph-vertexai/pyproject.toml
index 7c3fc82f..5e1f98ce 100644
--- a/trustgraph-vertexai/pyproject.toml
+++ b/trustgraph-vertexai/pyproject.toml
@@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
 readme = "README.md"
 requires-python = ">=3.8"
 dependencies = [
-    "trustgraph-base>=1.4,<1.5",
+    "trustgraph-base>=1.5,<1.6",
     "pulsar-client",
     "google-cloud-aiplatform",
     "prometheus-client",
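The dependency bumps pin the 1.5 series. As a quick illustration of what the ">=1.5,<1.6" range admits, using the standard packaging library:

    from packaging.specifiers import SpecifierSet
    from packaging.version import Version

    spec = SpecifierSet(">=1.5,<1.6")

    assert Version("1.5") in spec
    assert Version("1.5.999") in spec       # high patch releases still match
    assert Version("1.6.0") not in spec
    assert Version("1.4.9") not in spec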
+ """ try: - collection_name = f"t_{request.user}_{request.collection}" + prefix = f"t_{request.user}_{request.collection}_" - if self.qdrant.collection_exists(collection_name): - self.qdrant.delete_collection(collection_name) - logger.info(f"Deleted Qdrant collection: {collection_name}") + # Get all collections and filter for matches + all_collections = self.qdrant.get_collections().collections + matching_collections = [ + coll.name for coll in all_collections + if coll.name.startswith(prefix) + ] + + if not matching_collections: + logger.info(f"No collections found matching prefix {prefix}") else: - logger.info(f"Collection {collection_name} does not exist, nothing to delete") + for collection_name in matching_collections: + self.qdrant.delete_collection(collection_name) + logger.info(f"Deleted Qdrant collection: {collection_name}") + logger.info(f"Deleted {len(matching_collections)} collection(s) for {request.user}/{request.collection}") # Send success response response = StorageManagementResponse( error=None # No error means success ) await self.storage_response_producer.send(response) - logger.info(f"Successfully deleted collection {request.user}/{request.collection}") except Exception as e: logger.error(f"Failed to delete collection: {e}") diff --git a/trustgraph-ocr/pyproject.toml b/trustgraph-ocr/pyproject.toml index 89aafbb3..8f1d4d2a 100644 --- a/trustgraph-ocr/pyproject.toml +++ b/trustgraph-ocr/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=1.4,<1.5", + "trustgraph-base>=1.5,<1.6", "pulsar-client", "prometheus-client", "boto3", diff --git a/trustgraph-vertexai/pyproject.toml b/trustgraph-vertexai/pyproject.toml index 7c3fc82f..5e1f98ce 100644 --- a/trustgraph-vertexai/pyproject.toml +++ b/trustgraph-vertexai/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=1.4,<1.5", + "trustgraph-base>=1.5,<1.6", "pulsar-client", "google-cloud-aiplatform", "prometheus-client", diff --git a/trustgraph/pyproject.toml b/trustgraph/pyproject.toml index 1ee4fc88..8f4fcaf8 100644 --- a/trustgraph/pyproject.toml +++ b/trustgraph/pyproject.toml @@ -10,12 +10,12 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=1.4,<1.5", - "trustgraph-bedrock>=1.4,<1.5", - "trustgraph-cli>=1.4,<1.5", - "trustgraph-embeddings-hf>=1.4,<1.5", - "trustgraph-flow>=1.4,<1.5", - "trustgraph-vertexai>=1.4,<1.5", + "trustgraph-base>=1.5,<1.6", + "trustgraph-bedrock>=1.5,<1.6", + "trustgraph-cli>=1.5,<1.6", + "trustgraph-embeddings-hf>=1.5,<1.6", + "trustgraph-flow>=1.5,<1.6", + "trustgraph-vertexai>=1.5,<1.6", ] classifiers = [ "Programming Language :: Python :: 3",