Merge branch 'release/v1.5'

This commit is contained in:
Cyber MacGeddon 2025-11-24 09:58:45 +00:00
commit 85c8b175f2
95 changed files with 17496 additions and 729 deletions

View file

@ -15,14 +15,14 @@ jobs:
runs-on: ubuntu-latest
container:
image: python:3.12
image: python:3.13
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Setup packages
run: make update-package-versions VERSION=1.4.999
run: make update-package-versions VERSION=1.5.999
- name: Setup environment
run: python3 -m venv env

View file

@ -1,84 +0,0 @@
# ----------------------------------------------------------------------------
# Build an AI container. This does the torch install which is huge, and I
# like to avoid re-doing this.
# ----------------------------------------------------------------------------
FROM docker.io/fedora:40 AS ai
ENV PIP_BREAK_SYSTEM_PACKAGES=1
RUN dnf install -y python3 python3-pip python3-wheel python3-aiohttp \
python3-rdflib
RUN pip3 install torch==2.5.1+cpu \
--index-url https://download.pytorch.org/whl/cpu
RUN pip3 install \
anthropic boto3 cohere mistralai openai google-cloud-aiplatform \
ollama google-generativeai \
langchain==0.3.13 langchain-core==0.3.28 langchain-huggingface==0.1.2 \
langchain-text-splitters==0.3.4 \
langchain-community==0.3.13 \
sentence-transformers==3.4.0 transformers==4.47.1 \
huggingface-hub==0.27.0 \
pymilvus \
pulsar-client==3.5.0 cassandra-driver pyyaml \
neo4j tiktoken falkordb && \
pip3 cache purge
# Most commonly used embeddings model, just build it into the container
# image
RUN huggingface-cli download sentence-transformers/all-MiniLM-L6-v2
# ----------------------------------------------------------------------------
# Build a container which contains the built Python packages. The build
# creates a bunch of left-over cruft, a separate phase means this is only
# needed to support package build
# ----------------------------------------------------------------------------
FROM ai AS build
COPY trustgraph-base/ /root/build/trustgraph-base/
COPY trustgraph-flow/ /root/build/trustgraph-flow/
COPY trustgraph-vertexai/ /root/build/trustgraph-vertexai/
COPY trustgraph-bedrock/ /root/build/trustgraph-bedrock/
COPY trustgraph-embeddings-hf/ /root/build/trustgraph-embeddings-hf/
COPY trustgraph-cli/ /root/build/trustgraph-cli/
COPY trustgraph-ocr/ /root/build/trustgraph-ocr/
WORKDIR /root/build/
RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-base/
RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-flow/
RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-vertexai/
RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-bedrock/
RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-embeddings-hf/
RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-cli/
RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-ocr/
RUN ls /root/wheels
# ----------------------------------------------------------------------------
# Finally, the target container. Start with base and add the package.
# ----------------------------------------------------------------------------
FROM ai
COPY --from=build /root/wheels /root/wheels
RUN \
pip3 install /root/wheels/trustgraph_base-* && \
pip3 install /root/wheels/trustgraph_flow-* && \
pip3 install /root/wheels/trustgraph_vertexai-* && \
pip3 install /root/wheels/trustgraph_bedrock-* && \
pip3 install /root/wheels/trustgraph_embeddings_hf-* && \
pip3 install /root/wheels/trustgraph_cli-* && \
pip3 install /root/wheels/trustgraph_ocr-* && \
pip3 cache purge && \
rm -rf /root/wheels
WORKDIR /
CMD sleep 1000000

View file

@ -8,8 +8,8 @@ FROM docker.io/fedora:42 AS base
ENV PIP_BREAK_SYSTEM_PACKAGES=1
RUN dnf install -y python3.12 && \
alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \
RUN dnf install -y python3.13 && \
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
python -m ensurepip --upgrade && \
pip3 install --no-cache-dir build wheel aiohttp && \
pip3 install --no-cache-dir pulsar-client==3.7.0 && \

View file

@ -8,8 +8,8 @@ FROM docker.io/fedora:42 AS base
ENV PIP_BREAK_SYSTEM_PACKAGES=1
RUN dnf install -y python3.12 && \
alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \
RUN dnf install -y python3.13 && \
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
python -m ensurepip --upgrade && \
pip3 install --no-cache-dir build wheel aiohttp && \
pip3 install --no-cache-dir pulsar-client==3.7.0 && \

View file

@ -8,8 +8,8 @@ FROM docker.io/fedora:42 AS base
ENV PIP_BREAK_SYSTEM_PACKAGES=1
RUN dnf install -y python3.12 && \
alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \
RUN dnf install -y python3.13 && \
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
python -m ensurepip --upgrade && \
pip3 install --no-cache-dir build wheel aiohttp rdflib && \
pip3 install --no-cache-dir pulsar-client==3.7.0 && \
@ -22,7 +22,7 @@ RUN pip3 install --no-cache-dir \
langchain-text-splitters==0.3.8 \
langchain-community==0.3.24 \
pymilvus \
pulsar-client==3.7.0 cassandra-driver pyyaml \
pulsar-client==3.7.0 scylla-driver pyyaml \
neo4j tiktoken falkordb && \
pip3 cache purge

View file

@ -8,8 +8,8 @@ FROM docker.io/fedora:42 AS ai
ENV PIP_BREAK_SYSTEM_PACKAGES=1
RUN dnf install -y python3.12 && \
alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \
RUN dnf install -y python3.13 && \
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
python -m ensurepip --upgrade && \
pip3 install --no-cache-dir build wheel aiohttp && \
pip3 install --no-cache-dir pulsar-client==3.7.0 && \

View file

@ -8,8 +8,8 @@ FROM docker.io/fedora:42 AS base
ENV PIP_BREAK_SYSTEM_PACKAGES=1
RUN dnf install -y python3.12 && \
alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \
RUN dnf install -y python3.13 && \
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
python -m ensurepip --upgrade && \
pip3 install --no-cache-dir mcp websockets && \
dnf clean all

View file

@ -8,9 +8,9 @@ FROM docker.io/fedora:42 AS base
ENV PIP_BREAK_SYSTEM_PACKAGES=1
RUN dnf install -y python3.12 && \
RUN dnf install -y python3.13 && \
dnf install -y tesseract poppler-utils && \
alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
python -m ensurepip --upgrade && \
pip3 install --no-cache-dir build wheel aiohttp && \
pip3 install --no-cache-dir pulsar-client==3.7.0 && \

View file

@ -8,8 +8,8 @@ FROM docker.io/fedora:42 AS base
ENV PIP_BREAK_SYSTEM_PACKAGES=1
RUN dnf install -y python3.12 && \
alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \
RUN dnf install -y python3.13 && \
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
python -m ensurepip --upgrade && \
pip3 install --no-cache-dir build wheel aiohttp && \
pip3 install --no-cache-dir pulsar-client==3.7.0 && \

View file

@ -3,12 +3,12 @@
## Synopsis
```
tg-set-mcp-tool [OPTIONS] --name NAME --tool-url URL
tg-set-mcp-tool [OPTIONS] --id ID --tool-url URL [--auth-token TOKEN]
```
## Description
The `tg-set-mcp-tool` command configures and registers MCP (Model Control Protocol) tools in the TrustGraph system. It allows defining MCP tool configurations with name and URL. Tools are stored in the 'mcp' configuration group for discovery and execution.
The `tg-set-mcp-tool` command configures and registers MCP (Model Context Protocol) tools in the TrustGraph system. It allows defining MCP tool configurations with id, URL, and optional authentication token. Tools are stored in the 'mcp' configuration group for discovery and execution.
This command is useful for:
- Registering MCP tool endpoints for agent use
@ -25,16 +25,27 @@ The command stores MCP tool configurations in the 'mcp' configuration group, sep
- Default: `http://localhost:8088/` (or `TRUSTGRAPH_URL` environment variable)
- Should point to a running TrustGraph API instance
- `--name NAME`
- **Required.** MCP tool name identifier
- `-i, --id ID`
- **Required.** MCP tool identifier
- Used to reference the MCP tool in configurations
- Must be unique within the MCP tool registry
- `-r, --remote-name NAME`
- **Optional.** Remote MCP tool name used by the MCP server
- If not specified, defaults to the value of `--id`
- Use when the MCP server expects a different tool name
- `--tool-url URL`
- **Required.** MCP tool URL endpoint
- Should point to the MCP server endpoint providing the tool functionality
- Must be a valid URL accessible by the TrustGraph system
- `--auth-token TOKEN`
- **Optional.** Bearer token for authentication
- Used to authenticate with secured MCP endpoints
- Token is sent as `Authorization: Bearer {TOKEN}` header
- Stored in plaintext in configuration (see Security Considerations)
- `-h, --help`
- Show help message and exit
@ -44,54 +55,96 @@ The command stores MCP tool configurations in the 'mcp' configuration group, sep
Register a weather service MCP tool:
```bash
tg-set-mcp-tool --name weather --tool-url "http://localhost:3000/weather"
tg-set-mcp-tool --id weather --tool-url "http://localhost:3000/weather"
```
### Calculator MCP Tool
Register a calculator MCP tool:
```bash
tg-set-mcp-tool --name calculator --tool-url "http://mcp-tools.example.com/calc"
tg-set-mcp-tool --id calculator --tool-url "http://mcp-tools.example.com/calc"
```
### Remote MCP Service
Register a remote MCP service:
```bash
tg-set-mcp-tool --name document-processor \
tg-set-mcp-tool --id document-processor \
--tool-url "https://api.example.com/mcp/documents"
```
### Secured MCP Tool with Authentication
Register an MCP tool that requires bearer token authentication:
```bash
tg-set-mcp-tool --id secure-tool \
--tool-url "https://api.example.com/mcp" \
--auth-token "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
```
### MCP Tool with Remote Name
Register an MCP tool where the server uses a different name:
```bash
tg-set-mcp-tool --id my-weather \
--remote-name weather_v2 \
--tool-url "http://weather-server:3000/api"
```
### Custom API URL
Register MCP tool with custom TrustGraph API:
```bash
tg-set-mcp-tool -u http://trustgraph.example.com:8088/ \
--name custom-mcp --tool-url "http://custom.mcp.com/api"
--id custom-mcp --tool-url "http://custom.mcp.com/api"
```
### Local Development Setup
Register MCP tools for local development:
```bash
tg-set-mcp-tool --name dev-tool --tool-url "http://localhost:8080/mcp"
tg-set-mcp-tool --id dev-tool --tool-url "http://localhost:8080/mcp"
```
### Production Setup with Authentication
Register authenticated MCP tools for production:
```bash
# Using environment variable for token
export MCP_AUTH_TOKEN="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
tg-set-mcp-tool --id prod-tool \
--tool-url "https://prod-mcp.example.com/api" \
--auth-token "$MCP_AUTH_TOKEN"
```
## MCP Tool Configuration
MCP tools are configured with minimal metadata:
MCP tools are configured with the following metadata:
- **name**: Unique identifier for the tool
- **id**: Unique identifier for the tool (configuration key)
- **remote-name**: Name used by the MCP server (optional, defaults to id)
- **url**: Endpoint URL for the MCP server
- **auth-token**: Bearer token for authentication (optional)
The configuration is stored as JSON in the 'mcp' configuration group:
**Basic configuration:**
```json
{
"name": "weather",
"remote-name": "weather",
"url": "http://localhost:3000/weather"
}
```
**Configuration with authentication:**
```json
{
"remote-name": "secure-tool",
"url": "https://api.example.com/mcp",
"auth-token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
```
## Advanced Usage
### Updating Existing MCP Tools
@ -99,7 +152,15 @@ The configuration is stored as JSON in the 'mcp' configuration group:
Update an existing MCP tool configuration:
```bash
# Update MCP tool URL
tg-set-mcp-tool --name weather --tool-url "http://new-weather-server:3000/api"
tg-set-mcp-tool --id weather --tool-url "http://new-weather-server:3000/api"
# Add authentication to existing tool
tg-set-mcp-tool --id weather \
--tool-url "http://weather-server:3000/api" \
--auth-token "new-token-here"
# Remove authentication (by setting tool without auth-token)
tg-set-mcp-tool --id weather --tool-url "http://weather-server:3000/api"
```
### Batch MCP Tool Registration
@ -108,22 +169,33 @@ Register multiple MCP tools in a script:
```bash
#!/bin/bash
# Register a suite of MCP tools
tg-set-mcp-tool --name search --tool-url "http://search-mcp:3000/api"
tg-set-mcp-tool --name translate --tool-url "http://translate-mcp:3000/api"
tg-set-mcp-tool --name summarize --tool-url "http://summarize-mcp:3000/api"
tg-set-mcp-tool --id search --tool-url "http://search-mcp:3000/api"
tg-set-mcp-tool --id translate --tool-url "http://translate-mcp:3000/api"
tg-set-mcp-tool --id summarize --tool-url "http://summarize-mcp:3000/api"
# Register secured tools with authentication
tg-set-mcp-tool --id secure-search \
--tool-url "https://secure-search:3000/api" \
--auth-token "$SEARCH_TOKEN"
tg-set-mcp-tool --id secure-translate \
--tool-url "https://secure-translate:3000/api" \
--auth-token "$TRANSLATE_TOKEN"
```
### Environment-Specific Configuration
Configure MCP tools for different environments:
```bash
# Development environment
# Development environment (no auth)
export TRUSTGRAPH_URL="http://dev.trustgraph.com:8088/"
tg-set-mcp-tool --name dev-mcp --tool-url "http://dev.mcp.com/api"
tg-set-mcp-tool --id dev-mcp --tool-url "http://dev.mcp.com/api"
# Production environment
# Production environment (with auth)
export TRUSTGRAPH_URL="http://prod.trustgraph.com:8088/"
tg-set-mcp-tool --name prod-mcp --tool-url "http://prod.mcp.com/api"
export PROD_MCP_TOKEN="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
tg-set-mcp-tool --id prod-mcp \
--tool-url "https://prod.mcp.com/api" \
--auth-token "$PROD_MCP_TOKEN"
```
### MCP Tool Validation
@ -131,10 +203,10 @@ tg-set-mcp-tool --name prod-mcp --tool-url "http://prod.mcp.com/api"
Verify MCP tool registration:
```bash
# Register MCP tool and verify
tg-set-mcp-tool --name test-mcp --tool-url "http://test.mcp.com/api"
tg-set-mcp-tool --id test-mcp --tool-url "http://test.mcp.com/api"
# Check if MCP tool was registered
tg-show-mcp-tools | grep test-mcp
# Check if MCP tool was registered and view auth status
tg-show-mcp-tools
```
## Error Handling
@ -149,15 +221,15 @@ The command handles various error conditions:
Common error scenarios:
```bash
# Missing required field
tg-set-mcp-tool --name tool1
tg-set-mcp-tool --id tool1
# Output: Exception: Must specify --tool-url for MCP tool
# Missing name
# Missing id
tg-set-mcp-tool --tool-url "http://example.com/mcp"
# Output: Exception: Must specify --name for MCP tool
# Output: Exception: Must specify --id for MCP tool
# Invalid API URL
tg-set-mcp-tool -u "invalid-url" --name tool1 --tool-url "http://mcp.com"
tg-set-mcp-tool -u "invalid-url" --id tool1 --tool-url "http://mcp.com"
# Output: Exception: [API connection error]
```
@ -168,9 +240,9 @@ tg-set-mcp-tool -u "invalid-url" --name tool1 --tool-url "http://mcp.com"
View registered MCP tools:
```bash
# Register MCP tool
tg-set-mcp-tool --name new-mcp --tool-url "http://new.mcp.com/api"
tg-set-mcp-tool --id new-mcp --tool-url "http://new.mcp.com/api"
# View all MCP tools
# View all MCP tools (shows auth status)
tg-show-mcp-tools
```
@ -178,11 +250,13 @@ tg-show-mcp-tools
Use MCP tools in agent workflows:
```bash
# Register MCP tool
tg-set-mcp-tool --name weather --tool-url "http://weather.mcp.com/api"
# Register MCP tool with authentication
tg-set-mcp-tool --id weather \
--tool-url "https://weather.mcp.com/api" \
--auth-token "$WEATHER_TOKEN"
# Invoke MCP tool directly
tg-invoke-mcp-tool --name weather --input "location=London"
# Invoke MCP tool directly (auth handled automatically)
tg-invoke-mcp-tool --name weather --parameters '{"location": "London"}'
```
### With Configuration Management
@ -190,21 +264,24 @@ tg-invoke-mcp-tool --name weather --input "location=London"
MCP tools integrate with configuration management:
```bash
# Register MCP tool
tg-set-mcp-tool --name config-mcp --tool-url "http://config.mcp.com/api"
tg-set-mcp-tool --id config-mcp --tool-url "http://config.mcp.com/api"
# View configuration including MCP tools
tg-show-config
# View all MCP tool configurations
tg-show-mcp-tools
```
## Best Practices
1. **Clear Naming**: Use descriptive, unique MCP tool names
1. **Clear Naming**: Use descriptive, unique MCP tool identifiers
2. **Reliable URLs**: Ensure MCP endpoints are stable and accessible
3. **Health Checks**: Verify MCP endpoints are operational before registration
4. **Documentation**: Document MCP tool capabilities and usage
5. **Error Handling**: Implement proper error handling for MCP endpoints
6. **Security**: Use secure URLs (HTTPS) when possible
7. **Monitoring**: Monitor MCP tool availability and performance
3. **Use HTTPS**: Always use HTTPS URLs when authentication is required
4. **Secure Tokens**: Store auth tokens in environment variables, not in scripts
5. **Token Rotation**: Regularly rotate authentication tokens
6. **Health Checks**: Verify MCP endpoints are operational before registration
7. **Documentation**: Document MCP tool capabilities and usage
8. **Error Handling**: Implement proper error handling for MCP endpoints
9. **Monitoring**: Monitor MCP tool availability and performance
10. **Access Control**: Restrict access to configuration system containing tokens
## Troubleshooting
@ -248,10 +325,45 @@ The Model Context Protocol (MCP) is a standardized interface for AI model tools:
When registering MCP tools:
1. **URL Validation**: Ensure URLs are legitimate and secure
2. **Network Security**: Use HTTPS when possible
3. **Access Control**: Implement proper authentication for MCP endpoints
4. **Input Validation**: Validate all inputs to MCP tools
5. **Error Handling**: Don't expose sensitive information in error messages
2. **Network Security**: Always use HTTPS for authenticated endpoints
3. **Token Storage**: Auth tokens are stored in plaintext in the configuration system
- Ensure proper access control on the configuration storage
- Use short-lived tokens when possible
- Implement token rotation policies
4. **Token Transmission**: Use HTTPS to prevent token interception
5. **Access Control**: Implement proper authentication for MCP endpoints
6. **Token Exposure**:
- Use environment variables to pass tokens to the command
- Don't hardcode tokens in scripts or commit them to version control
- The `tg-show-mcp-tools` command masks token values for security
7. **Input Validation**: Validate all inputs to MCP tools
8. **Error Handling**: Don't expose sensitive information in error messages
9. **Least Privilege**: Grant tokens minimum required permissions
10. **Audit Logging**: Monitor configuration changes for security events
### Authentication Best Practices
When using the `--auth-token` parameter:
- **Store tokens securely**: Use environment variables or secrets management systems
- **Use HTTPS**: Always use HTTPS URLs when providing authentication tokens
- **Rotate regularly**: Implement a token rotation schedule
- **Monitor usage**: Track which services are accessing authenticated endpoints
- **Revoke on compromise**: Have a process to quickly revoke and rotate compromised tokens
Example secure workflow:
```bash
# Store token in environment variable (not in script)
export MCP_TOKEN=$(cat /secure/path/to/token)
# Use HTTPS for authenticated endpoints
tg-set-mcp-tool --id secure-service \
--tool-url "https://secure.example.com/mcp" \
--auth-token "$MCP_TOKEN"
# Clear token from environment after use
unset MCP_TOKEN
```
## Related Commands

View file

@ -0,0 +1,554 @@
# MCP Tool Bearer Token Authentication Specification
> **⚠️ IMPORTANT: SINGLE-TENANT ONLY**
>
> This specification describes a **basic, service-level authentication mechanism** for MCP tools. It is **NOT** a complete authentication solution and is **NOT suitable** for:
> - Multi-user environments
> - Multi-tenant deployments
> - Federated authentication
> - User context propagation
> - Per-user authorization
>
> This feature provides **one static token per MCP tool**, shared across all users and sessions. If you need per-user or per-tenant authentication, this is not the right solution.
## Overview
**Feature Name**: MCP Tool Bearer Token Authentication Support
**Author**: Claude Code Assistant
**Date**: 2025-11-11
**Status**: In Development
### Executive Summary
Enable MCP tool configurations to specify optional bearer tokens for authenticating with protected MCP servers. This allows TrustGraph to securely invoke MCP tools hosted on servers that require authentication, without modifying the agent or tool invocation interfaces.
**IMPORTANT**: This is a basic authentication mechanism designed for single-tenant, service-to-service authentication scenarios. It is **NOT** suitable for:
- Multi-user environments where different users need different credentials
- Multi-tenant deployments requiring per-tenant isolation
- Federated authentication scenarios
- User-level authentication or authorization
- Dynamic credential management or token refresh
This feature provides a static, system-wide bearer token per MCP tool configuration, shared across all users and invocations of that tool.
### Problem Statement
Currently, MCP tools can only connect to publicly accessible MCP servers. Many production MCP deployments require authentication via bearer tokens for security. Without authentication support:
- MCP tools cannot connect to secured MCP servers
- Users must either expose MCP servers publicly or implement reverse proxies
- No standardized way to pass credentials to MCP connections
- Security best practices cannot be enforced on MCP endpoints
### Goals
- [ ] Allow MCP tool configurations to specify optional `auth-token` parameter
- [ ] Update MCP tool service to use bearer tokens when connecting to MCP servers
- [ ] Update CLI tools to support setting/displaying auth tokens
- [ ] Maintain backward compatibility with unauthenticated MCP configurations
- [ ] Document security considerations for token storage
### Non-Goals
- Dynamic token refresh or OAuth flows (static tokens only)
- Encryption of stored tokens (configuration system security is out of scope)
- Alternative authentication methods (Basic auth, API keys, etc.)
- Token validation or expiration checking
- **Per-user authentication**: This feature does NOT support user-specific credentials
- **Multi-tenant isolation**: This feature does NOT provide per-tenant token management
- **Federated authentication**: This feature does NOT integrate with identity providers (SSO, OAuth, SAML, etc.)
- **Context-aware authentication**: Tokens are not passed based on user context or session
## Background and Context
### Current State
MCP tool configurations are stored in the `mcp` configuration group with this structure:
```json
{
"remote-name": "tool_name",
"url": "http://mcp-server:3000/api"
}
```
The MCP tool service connects to servers using `streamablehttp_client(url)` without any authentication headers.
### Limitations
**Current System Limitations:**
1. **No authentication support**: Cannot connect to secured MCP servers
2. **Security exposure**: MCP servers must be publicly accessible or use network-level security only
3. **Production deployment issues**: Cannot follow security best practices for API endpoints
**Limitations of This Solution:**
1. **Single-tenant only**: One static token per MCP tool, shared across all users
2. **No per-user credentials**: Cannot authenticate as different users or pass user context
3. **No multi-tenant support**: Cannot isolate credentials by tenant or organization
4. **Static tokens only**: No support for token refresh, rotation, or expiration handling
5. **Service-level authentication**: Authenticates the TrustGraph service, not individual users
6. **Shared security context**: All invocations of an MCP tool use the same credential
### Use Case Applicability
**✅ Appropriate Use Cases:**
- Single-tenant TrustGraph deployments
- Service-to-service authentication (TrustGraph → MCP Server)
- Development and testing environments
- Internal MCP tools accessed by the TrustGraph system
- Scenarios where all users share the same MCP tool access level
- Static, long-lived service credentials
**❌ Inappropriate Use Cases:**
- Multi-user systems requiring per-user authentication
- Multi-tenant SaaS deployments with tenant isolation requirements
- Federated authentication scenarios (SSO, OAuth, SAML)
- Systems requiring user context propagation to MCP servers
- Environments needing dynamic token refresh or short-lived tokens
- Applications where different users need different permission levels
- Compliance requirements for user-level audit trails
**Example Appropriate Scenario:**
A single-organization TrustGraph deployment where all employees use the same internal MCP tool (e.g., company database lookup). The MCP server requires authentication to prevent external access, but all internal users have the same access level.
**Example Inappropriate Scenario:**
A multi-tenant TrustGraph SaaS platform where Tenant A and Tenant B each need to access their own isolated MCP servers with separate credentials. This feature does NOT support per-tenant token management.
### Related Components
- **trustgraph-flow/trustgraph/agent/mcp_tool/service.py**: MCP tool invocation service
- **trustgraph-cli/trustgraph/cli/set_mcp_tool.py**: CLI tool for creating/updating MCP configurations
- **trustgraph-cli/trustgraph/cli/show_mcp_tools.py**: CLI tool for displaying MCP configurations
- **MCP Python SDK**: `streamablehttp_client` from `mcp.client.streamable_http`
## Requirements
### Functional Requirements
1. **MCP Configuration Auth Token**: MCP tool configurations MUST support an optional `auth-token` field
2. **Bearer Token Usage**: MCP tool service MUST send `Authorization: Bearer {token}` header when auth-token is configured
3. **CLI Support**: `tg-set-mcp-tool` MUST accept optional `--auth-token` parameter
4. **Token Display**: `tg-show-mcp-tools` MUST indicate when auth-token is configured (masked for security)
5. **Backward Compatibility**: Existing MCP tool configurations without auth-token MUST continue to work
### Non-Functional Requirements
1. **Backward Compatibility**: Zero breaking changes for existing MCP tool configurations
2. **Performance**: No significant performance impact on MCP tool invocation
3. **Security**: Tokens stored in configuration (document security implications)
### User Stories
1. As a **DevOps engineer**, I want to configure bearer tokens for MCP tools so that I can secure MCP server endpoints
2. As a **CLI user**, I want to set auth tokens when creating MCP tools so that I can connect to protected servers
3. As a **system administrator**, I want to see which MCP tools have authentication configured so that I can audit security settings
## Design
### High-Level Architecture
Extend MCP tool configuration and service to support bearer token authentication:
1. Add optional `auth-token` field to MCP tool configuration schema
2. Modify MCP tool service to read auth-token and pass to HTTP client
3. Update CLI tools to support setting and displaying auth tokens
4. Document security considerations and best practices
### Configuration Schema
**Current Schema**:
```json
{
"remote-name": "tool_name",
"url": "http://mcp-server:3000/api"
}
```
**New Schema** (with optional auth-token):
```json
{
"remote-name": "tool_name",
"url": "http://mcp-server:3000/api",
"auth-token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
```
**Field Descriptions**:
- `remote-name` (optional): Name used by MCP server (defaults to config key)
- `url` (required): MCP server endpoint URL
- `auth-token` (optional): Bearer token for authentication
### Data Flow
1. **Configuration Storage**: User runs `tg-set-mcp-tool --id my-tool --tool-url http://server/api --auth-token xyz123`
2. **Config Loading**: MCP tool service receives config update via `on_mcp_config()` callback
3. **Tool Invocation**: When tool is invoked:
- Service reads `auth-token` from config (if present)
- Creates headers dict: `{"Authorization": "Bearer {token}"}`
- Passes headers to `streamablehttp_client(url, headers=headers)`
- MCP server validates token and processes request
### API Changes
No external API changes - configuration schema extension only.
### Component Details
#### Component 1: service.py (MCP Tool Service)
**File**: `trustgraph-flow/trustgraph/agent/mcp_tool/service.py`
**Purpose**: Invoke MCP tools on remote servers
**Changes Required** (in `invoke_tool()` method):
1. Check for `auth-token` in `self.mcp_services[name]` config
2. Build headers dict with Authorization header if token exists
3. Pass headers to `streamablehttp_client(url, headers=headers)`
**Current Code** (lines 42-89):
```python
async def invoke_tool(self, name, parameters):
try:
if name not in self.mcp_services:
raise RuntimeError(f"MCP service {name} not known")
if "url" not in self.mcp_services[name]:
raise RuntimeError(f"MCP service {name} URL not defined")
url = self.mcp_services[name]["url"]
if "remote-name" in self.mcp_services[name]:
remote_name = self.mcp_services[name]["remote-name"]
else:
remote_name = name
logger.info(f"Invoking {remote_name} at {url}")
# Connect to a streamable HTTP server
async with streamablehttp_client(url) as (
read_stream,
write_stream,
_,
):
# ... rest of method
```
**Modified Code**:
```python
async def invoke_tool(self, name, parameters):
try:
if name not in self.mcp_services:
raise RuntimeError(f"MCP service {name} not known")
if "url" not in self.mcp_services[name]:
raise RuntimeError(f"MCP service {name} URL not defined")
url = self.mcp_services[name]["url"]
if "remote-name" in self.mcp_services[name]:
remote_name = self.mcp_services[name]["remote-name"]
else:
remote_name = name
# Build headers with optional bearer token
headers = {}
if "auth-token" in self.mcp_services[name]:
token = self.mcp_services[name]["auth-token"]
headers["Authorization"] = f"Bearer {token}"
logger.info(f"Invoking {remote_name} at {url}")
# Connect to a streamable HTTP server with headers
async with streamablehttp_client(url, headers=headers) as (
read_stream,
write_stream,
_,
):
# ... rest of method (unchanged)
```
#### Component 2: set_mcp_tool.py (CLI Configuration Tool)
**File**: `trustgraph-cli/trustgraph/cli/set_mcp_tool.py`
**Purpose**: Create/update MCP tool configurations
**Changes Required**:
1. Add `--auth-token` optional argument to argparse
2. Include `auth-token` in configuration JSON when provided
**Current Arguments**:
- `--id` (required): MCP tool identifier
- `--remote-name` (optional): Remote MCP tool name
- `--tool-url` (required): MCP tool URL endpoint
- `-u, --api-url` (optional): TrustGraph API URL
**New Argument**:
- `--auth-token` (optional): Bearer token for authentication
**Modified Configuration Building**:
```python
# Build configuration object
config = {
"url": args.tool_url,
}
if args.remote_name:
config["remote-name"] = args.remote_name
if args.auth_token:
config["auth-token"] = args.auth_token
# Store configuration
api.config().put([
ConfigValue(type="mcp", key=args.id, value=json.dumps(config))
])
```
#### Component 3: show_mcp_tools.py (CLI Display Tool)
**File**: `trustgraph-cli/trustgraph/cli/show_mcp_tools.py`
**Purpose**: Display MCP tool configurations
**Changes Required**:
1. Add "Auth" column to output table
2. Display "Yes" or "No" based on presence of auth-token
3. Do not display actual token value (security)
**Current Output**:
```
ID Remote Name URL
---------- ------------- ------------------------
my-tool my-tool http://server:3000/api
```
**New Output**:
```
ID Remote Name URL Auth
---------- ------------- ------------------------ ------
my-tool my-tool http://server:3000/api Yes
other-tool other-tool http://other:3000/api No
```
#### Component 4: Documentation
**File**: `docs/cli/tg-set-mcp-tool.md`
**Changes Required**:
1. Document new `--auth-token` parameter
2. Provide example usage with authentication
3. Document security considerations
## Implementation Plan
### Phase 1: Create Technical Specification
- [x] Write comprehensive tech spec documenting all changes
### Phase 2: Update MCP Tool Service
- [ ] Modify `invoke_tool()` in `service.py` to read auth-token from config
- [ ] Build headers dict and pass to `streamablehttp_client`
- [ ] Test with authenticated MCP server
### Phase 3: Update CLI Tools
- [ ] Add `--auth-token` argument to `set_mcp_tool.py`
- [ ] Include auth-token in configuration JSON
- [ ] Add "Auth" column to `show_mcp_tools.py` output
- [ ] Test CLI tool changes
### Phase 4: Update Documentation
- [ ] Document `--auth-token` parameter in `tg-set-mcp-tool.md`
- [ ] Add security considerations section
- [ ] Provide example usage
### Phase 5: Testing
- [ ] Test MCP tool with auth-token connects successfully
- [ ] Test backward compatibility (tools without auth-token still work)
- [ ] Test CLI tools accept and store auth-token correctly
- [ ] Test show command displays auth status correctly
### Code Changes Summary
| File | Change Type | Lines | Description |
|------|------------|-------|-------------|
| `service.py` | Modified | ~52-66 | Add auth-token reading and header building |
| `set_mcp_tool.py` | Modified | ~30-60 | Add --auth-token argument and config storage |
| `show_mcp_tools.py` | Modified | ~40-70 | Add Auth column to display |
| `tg-set-mcp-tool.md` | Modified | Various | Document new parameter |
## Testing Strategy
### Unit Tests
- **Auth Token Reading**: Test `invoke_tool()` correctly reads auth-token from config
- **Header Building**: Test Authorization header is built correctly with Bearer prefix
- **Backward Compatibility**: Test tools without auth-token work unchanged
- **CLI Argument Parsing**: Test `--auth-token` argument is parsed correctly
### Integration Tests
- **Authenticated Connection**: Test MCP tool service connects to authenticated server
- **End-to-End**: Test CLI → config storage → service invocation with auth token
- **Token Not Required**: Test connection to unauthenticated server still works
### Manual Testing
- **Real MCP Server**: Test with actual MCP server requiring bearer token authentication
- **CLI Workflow**: Test complete workflow: set tool with auth → invoke tool → verify success
- **Display Masking**: Verify auth status shown but token value not exposed
## Migration and Rollout
### Migration Strategy
No migration required - this is purely additive functionality:
- Existing MCP tool configurations without `auth-token` continue to work unchanged
- New configurations can optionally include `auth-token` field
- CLI tools accept but don't require `--auth-token` parameter
### Rollout Plan
1. **Phase 1**: Deploy core service changes to development/staging
2. **Phase 2**: Deploy CLI tool updates
3. **Phase 3**: Update documentation
4. **Phase 4**: Production rollout with monitoring
### Rollback Plan
- Core changes are backward compatible - existing tools unaffected
- If issues arise, auth-token handling can be disabled by removing header building logic
- CLI changes are independent and can be rolled back separately
## Security Considerations
### ⚠️ Critical Limitation: Single-Tenant Authentication Only
**This authentication mechanism is NOT suitable for multi-user or multi-tenant environments.**
- **Shared credentials**: All users and invocations share the same token per MCP tool
- **No user context**: The MCP server cannot distinguish between different TrustGraph users
- **No tenant isolation**: All tenants share the same credential for each MCP tool
- **Audit trail limitation**: MCP server logs show all requests from the same credential
- **Permission scope**: Cannot enforce different permission levels for different users
**Do NOT use this feature if:**
- Your TrustGraph deployment serves multiple organizations (multi-tenant)
- You need to track which user accessed which MCP tool
- Different users require different permission levels
- You need to comply with user-level audit requirements
- Your MCP server enforces per-user rate limits or quotas
**Alternative solutions for multi-user/multi-tenant scenarios:**
- Implement user context propagation through custom headers
- Deploy separate TrustGraph instances per tenant
- Use network-level isolation (VPCs, service meshes)
- Implement a proxy layer that handles per-user authentication
### Token Storage
**Risk**: Auth tokens stored in plaintext in configuration system
**Mitigation**:
- Document that tokens are stored unencrypted
- Recommend using short-lived tokens when possible
- Recommend proper access control on configuration storage
- Consider future enhancement for encrypted token storage
### Token Exposure
**Risk**: Tokens could be exposed in logs or CLI output
**Mitigation**:
- Do not log token values (only log "auth configured: yes/no")
- CLI show command displays masked status only, not actual token
- Do not include tokens in error messages
### Network Security
**Risk**: Tokens transmitted over unencrypted connections
**Mitigation**:
- Document recommendation to use HTTPS URLs for MCP servers
- Warn users about plaintext transmission risk with HTTP
### Configuration Access
**Risk**: Unauthorized access to configuration system exposes tokens
**Mitigation**:
- Document importance of securing configuration system access
- Recommend principle of least privilege for configuration access
- Consider audit logging for configuration changes (future enhancement)
### Multi-User Environments
**Risk**: In multi-user deployments, all users share the same MCP credentials
**Understanding the Risk**:
- User A and User B both use the same token when accessing an MCP tool
- MCP server cannot distinguish between different TrustGraph users
- No way to enforce per-user permissions or rate limits
- Audit logs on MCP server show all requests from same credential
- If one user's session is compromised, attacker has same MCP access as all users
**This is NOT a bug - it's a fundamental limitation of this design.**
## Performance Impact
- **Minimal overhead**: Header building adds negligible processing time
- **Network impact**: Additional HTTP header adds ~50-200 bytes per request
- **Memory usage**: Negligible increase for storing token string in config
## Documentation
### User Documentation
- [ ] Update `tg-set-mcp-tool.md` with `--auth-token` parameter
- [ ] Add security considerations section
- [ ] Provide example usage with bearer token
- [ ] Document token storage implications
### Developer Documentation
- [ ] Add inline comments for auth token handling in `service.py`
- [ ] Document header building logic
- [ ] Update MCP tool configuration schema documentation
## Open Questions
1. **Token encryption**: Should we implement encrypted token storage in configuration system?
2. **Token refresh**: Future support for OAuth refresh flows or token rotation?
3. **Alternative auth methods**: Should we support Basic auth, API keys, or other methods?
## Alternatives Considered
1. **Environment variables for tokens**: Store tokens in env vars instead of config
- **Rejected**: Complicates deployment and configuration management
2. **Separate secrets store**: Use dedicated secrets management system
- **Deferred**: Out of scope for initial implementation, consider future enhancement
3. **Multiple auth methods**: Support Basic, API key, OAuth, etc.
- **Rejected**: Bearer tokens cover most use cases, keep initial implementation simple
4. **Encrypted token storage**: Encrypt tokens in configuration system
- **Deferred**: Configuration system security is broader concern, defer to future work
5. **Per-invocation tokens**: Allow tokens to be passed at invocation time
- **Rejected**: Violates separation of concerns, agent shouldn't handle credentials
## References
- [MCP Protocol Specification](https://github.com/modelcontextprotocol/spec)
- [HTTP Bearer Authentication (RFC 6750)](https://tools.ietf.org/html/rfc6750)
- [Current MCP Tool Service](../trustgraph-flow/trustgraph/agent/mcp_tool/service.py)
- [MCP Tool Arguments Specification](./mcp-tool-arguments.md)
## Appendix
### Example Usage
**Setting MCP tool with authentication**:
```bash
tg-set-mcp-tool \
--id secure-tool \
--tool-url https://secure-server.example.com/mcp \
--auth-token eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
```
**Showing MCP tools**:
```bash
tg-show-mcp-tools
ID Remote Name URL Auth
----------- ----------- ------------------------------------ ------
secure-tool secure-tool https://secure-server.example.com/mcp Yes
public-tool public-tool http://localhost:3000/mcp No
```
### Configuration Example
**Stored in configuration system**:
```json
{
"type": "mcp",
"key": "secure-tool",
"value": "{\"url\": \"https://secure-server.example.com/mcp\", \"auth-token\": \"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...\"}"
}
```
### Security Best Practices
1. **Use HTTPS**: Always use HTTPS URLs for MCP servers with authentication
2. **Short-lived tokens**: Use tokens with expiration when possible
3. **Least privilege**: Grant tokens minimum required permissions
4. **Access control**: Restrict access to configuration system
5. **Token rotation**: Rotate tokens regularly
6. **Audit logging**: Monitor configuration changes for security events

286
docs/tech-specs/ontology.md Normal file
View file

@ -0,0 +1,286 @@
# Ontology Structure Technical Specification
## Overview
This specification describes the structure and format of ontologies within the TrustGraph system. Ontologies provide formal knowledge models that define classes, properties, and relationships, supporting reasoning and inference capabilities. The system uses an OWL-inspired configuration format that broadly represents OWL/RDFS concepts while being optimized for TrustGraph's requirements.
**Naming Convention**: This project uses kebab-case for all identifiers (configuration keys, API endpoints, module names, etc.) rather than snake_case.
## Goals
- **Class and Property Management**: Define OWL-like classes with properties, domains, ranges, and type constraints
- **Rich Semantic Support**: Enable comprehensive RDFS/OWL properties including labels, multi-language support, and formal constraints
- **Multi-Ontology Support**: Allow multiple ontologies to coexist and interoperate
- **Validation and Reasoning**: Ensure ontologies conform to OWL-like standards with consistency checking and inference support
- **Standard Compatibility**: Support import/export in standard formats (Turtle, RDF/XML, OWL/XML) while maintaining internal optimization
## Background
TrustGraph stores ontologies as configuration items in a flexible key-value system. While the format is inspired by OWL (Web Ontology Language), it is optimized for TrustGraph's specific use cases and does not strictly adhere to all OWL specifications.
Ontologies in TrustGraph enable:
- Definition of formal object types and their properties
- Specification of property domains, ranges, and type constraints
- Logical reasoning and inference
- Complex relationships and cardinality constraints
- Multi-language support for internationalization
## Ontology Structure
### Configuration Storage
Ontologies are stored as configuration items with the following pattern:
- **Type**: `ontology`
- **Key**: Unique ontology identifier (e.g., `natural-world`, `domain-model`)
- **Value**: Complete ontology in JSON format
### JSON Structure
The ontology JSON format consists of four main sections:
#### 1. Metadata
Contains administrative and descriptive information about the ontology:
```json
{
"metadata": {
"name": "The natural world",
"description": "Ontology covering the natural order",
"version": "1.0.0",
"created": "2025-09-20T12:07:37.068Z",
"modified": "2025-09-20T12:12:20.725Z",
"creator": "current-user",
"namespace": "http://trustgraph.ai/ontologies/natural-world",
"imports": ["http://www.w3.org/2002/07/owl#"]
}
}
```
**Fields:**
- `name`: Human-readable name of the ontology
- `description`: Brief description of the ontology's purpose
- `version`: Semantic version number
- `created`: ISO 8601 timestamp of creation
- `modified`: ISO 8601 timestamp of last modification
- `creator`: Identifier of the creating user/system
- `namespace`: Base URI for ontology elements
- `imports`: Array of imported ontology URIs
#### 2. Classes
Defines the object types and their hierarchical relationships:
```json
{
"classes": {
"animal": {
"uri": "http://trustgraph.ai/ontologies/natural-world#animal",
"type": "owl:Class",
"rdfs:label": [{"value": "Animal", "lang": "en"}],
"rdfs:comment": "An animal",
"rdfs:subClassOf": "lifeform",
"owl:equivalentClass": ["creature"],
"owl:disjointWith": ["plant"],
"dcterms:identifier": "ANI-001"
}
}
}
```
**Supported Properties:**
- `uri`: Full URI of the class
- `type`: Always `"owl:Class"`
- `rdfs:label`: Array of language-tagged labels
- `rdfs:comment`: Description of the class
- `rdfs:subClassOf`: Parent class identifier (single inheritance)
- `owl:equivalentClass`: Array of equivalent class identifiers
- `owl:disjointWith`: Array of disjoint class identifiers
- `dcterms:identifier`: Optional external reference identifier
#### 3. Object Properties
Properties that link instances to other instances:
```json
{
"objectProperties": {
"has-parent": {
"uri": "http://trustgraph.ai/ontologies/natural-world#has-parent",
"type": "owl:ObjectProperty",
"rdfs:label": [{"value": "has parent", "lang": "en"}],
"rdfs:comment": "Links an animal to its parent",
"rdfs:domain": "animal",
"rdfs:range": "animal",
"owl:inverseOf": "parent-of",
"owl:functionalProperty": false
}
}
}
```
**Supported Properties:**
- `uri`: Full URI of the property
- `type`: Always `"owl:ObjectProperty"`
- `rdfs:label`: Array of language-tagged labels
- `rdfs:comment`: Description of the property
- `rdfs:domain`: Class identifier that has this property
- `rdfs:range`: Class identifier for property values
- `owl:inverseOf`: Identifier of inverse property
- `owl:functionalProperty`: Boolean indicating at most one value
- `owl:inverseFunctionalProperty`: Boolean for unique identifying properties
#### 4. Datatype Properties
Properties that link instances to literal values:
```json
{
"datatypeProperties": {
"number-of-legs": {
"uri": "http://trustgraph.ai/ontologies/natural-world#number-of-legs",
"type": "owl:DatatypeProperty",
"rdfs:label": [{"value": "number of legs", "lang": "en"}],
"rdfs:comment": "Count of number of legs of the animal",
"rdfs:domain": "animal",
"rdfs:range": "xsd:nonNegativeInteger",
"owl:functionalProperty": true,
"owl:minCardinality": 0,
"owl:maxCardinality": 1
}
}
}
```
**Supported Properties:**
- `uri`: Full URI of the property
- `type`: Always `"owl:DatatypeProperty"`
- `rdfs:label`: Array of language-tagged labels
- `rdfs:comment`: Description of the property
- `rdfs:domain`: Class identifier that has this property
- `rdfs:range`: XSD datatype for property values
- `owl:functionalProperty`: Boolean indicating at most one value
- `owl:minCardinality`: Minimum number of values (optional)
- `owl:maxCardinality`: Maximum number of values (optional)
- `owl:cardinality`: Exact number of values (optional)
### Supported XSD Datatypes
The following XML Schema datatypes are supported for datatype property ranges:
- `xsd:string` - Text values
- `xsd:integer` - Integer numbers
- `xsd:nonNegativeInteger` - Non-negative integers
- `xsd:float` - Floating point numbers
- `xsd:double` - Double precision numbers
- `xsd:boolean` - True/false values
- `xsd:dateTime` - Date and time values
- `xsd:date` - Date values
- `xsd:anyURI` - URI references
### Language Support
Labels and comments support multiple languages using the W3C language tag format:
```json
{
"rdfs:label": [
{"value": "Animal", "lang": "en"},
{"value": "Tier", "lang": "de"},
{"value": "Animal", "lang": "es"}
]
}
```
## Example Ontology
Here's a complete example of a simple ontology:
```json
{
"metadata": {
"name": "The natural world",
"description": "Ontology covering the natural order",
"version": "1.0.0",
"created": "2025-09-20T12:07:37.068Z",
"modified": "2025-09-20T12:12:20.725Z",
"creator": "current-user",
"namespace": "http://trustgraph.ai/ontologies/natural-world",
"imports": ["http://www.w3.org/2002/07/owl#"]
},
"classes": {
"lifeform": {
"uri": "http://trustgraph.ai/ontologies/natural-world#lifeform",
"type": "owl:Class",
"rdfs:label": [{"value": "Lifeform", "lang": "en"}],
"rdfs:comment": "A living thing"
},
"animal": {
"uri": "http://trustgraph.ai/ontologies/natural-world#animal",
"type": "owl:Class",
"rdfs:label": [{"value": "Animal", "lang": "en"}],
"rdfs:comment": "An animal",
"rdfs:subClassOf": "lifeform"
},
"cat": {
"uri": "http://trustgraph.ai/ontologies/natural-world#cat",
"type": "owl:Class",
"rdfs:label": [{"value": "Cat", "lang": "en"}],
"rdfs:comment": "A cat",
"rdfs:subClassOf": "animal"
},
"dog": {
"uri": "http://trustgraph.ai/ontologies/natural-world#dog",
"type": "owl:Class",
"rdfs:label": [{"value": "Dog", "lang": "en"}],
"rdfs:comment": "A dog",
"rdfs:subClassOf": "animal",
"owl:disjointWith": ["cat"]
}
},
"objectProperties": {},
"datatypeProperties": {
"number-of-legs": {
"uri": "http://trustgraph.ai/ontologies/natural-world#number-of-legs",
"type": "owl:DatatypeProperty",
"rdfs:label": [{"value": "number-of-legs", "lang": "en"}],
"rdfs:comment": "Count of number of legs of the animal",
"rdfs:range": "xsd:nonNegativeInteger",
"rdfs:domain": "animal"
}
}
}
```
## Validation Rules
### Structural Validation
1. **URI Consistency**: All URIs should follow the pattern `{namespace}#{identifier}`
2. **Class Hierarchy**: No circular inheritance in `rdfs:subClassOf`
3. **Property Domains/Ranges**: Must reference existing classes or valid XSD types
4. **Disjoint Classes**: Cannot be subclasses of each other
5. **Inverse Properties**: Must be bidirectional if specified
### Semantic Validation
1. **Unique Identifiers**: Class and property identifiers must be unique within an ontology
2. **Language Tags**: Must follow BCP 47 language tag format
3. **Cardinality Constraints**: `minCardinality` ≤ `maxCardinality` when both specified
4. **Functional Properties**: Cannot have `maxCardinality` > 1
## Import/Export Format Support
While the internal format is JSON, the system supports conversion to/from standard ontology formats:
- **Turtle (.ttl)** - Compact RDF serialization
- **RDF/XML (.rdf, .owl)** - W3C standard format
- **OWL/XML (.owx)** - OWL-specific XML format
- **JSON-LD (.jsonld)** - JSON for Linked Data
## References
- [OWL 2 Web Ontology Language](https://www.w3.org/TR/owl2-overview/)
- [RDF Schema 1.1](https://www.w3.org/TR/rdf-schema/)
- [XML Schema Datatypes](https://www.w3.org/TR/xmlschema-2/)
- [BCP 47 Language Tags](https://tools.ietf.org/html/bcp47)

1067
docs/tech-specs/ontorag.md Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,299 @@
# Vector Store Lifecycle Management
## Overview
This document describes how TrustGraph manages vector store collections across different backend implementations (Qdrant, Pinecone, Milvus). The design addresses the challenge of supporting embeddings with different dimensions without hardcoding dimension values.
## Problem Statement
Vector stores require the embedding dimension to be specified when creating collections/indexes. However:
- Different embedding models produce different dimensions (e.g., 384, 768, 1536)
- The dimension is not known until the first embedding is generated
- A single TrustGraph collection may receive embeddings from multiple models
- Hardcoding a dimension (e.g., 384) causes failures with other embedding sizes
## Design Principles
1. **Lazy Creation**: Collections are created on-demand during first write, not during collection management operations
2. **Dimension-Based Naming**: Collection names include the embedding dimension as a suffix
3. **Graceful Degradation**: Queries against non-existent collections return empty results, not errors
4. **Multi-Dimension Support**: A single logical collection can have multiple physical collections (one per dimension)
## Architecture
### Collection Naming Convention
Vector store collections use dimension suffixes to support multiple embedding sizes:
**Document Embeddings:**
- Qdrant: `d_{user}_{collection}_{dimension}`
- Pinecone: `d-{user}-{collection}-{dimension}`
- Milvus: `doc_{user}_{collection}_{dimension}`
**Graph Embeddings:**
- Qdrant: `t_{user}_{collection}_{dimension}`
- Pinecone: `t-{user}-{collection}-{dimension}`
- Milvus: `entity_{user}_{collection}_{dimension}`
Examples:
- `d_alice_papers_384` - Alice's papers collection with 384-dimensional embeddings
- `d_alice_papers_768` - Same logical collection with 768-dimensional embeddings
- `t_bob_knowledge_1536` - Bob's knowledge graph with 1536-dimensional embeddings
### Lifecycle Phases
#### 1. Collection Creation Request
**Request Flow:**
```
User/System → Librarian → Storage Management Topic → Vector Stores
```
**Behavior:**
- The librarian broadcasts `create-collection` requests to all storage backends
- Vector store processors acknowledge the request but **do not create physical collections**
- Response is returned immediately with success
- Actual collection creation is deferred until first write
**Rationale:**
- Dimension is unknown at creation time
- Avoids creating collections with wrong dimensions
- Simplifies collection management logic
#### 2. Write Operations (Lazy Creation)
**Write Flow:**
```
Data → Storage Processor → Check Collection → Create if Needed → Insert
```
**Behavior:**
1. Extract embedding dimension from the vector: `dim = len(vector)`
2. Construct collection name with dimension suffix
3. Check if collection exists with that specific dimension
4. If not exists:
- Create collection with correct dimension
- Log: `"Lazily creating collection {name} with dimension {dim}"`
5. Insert the embedding into the dimension-specific collection
**Example Scenario:**
```
1. User creates collection "papers"
→ No physical collections created yet
2. First document with 384-dim embedding arrives
→ Creates d_user_papers_384
→ Inserts data
3. Second document with 768-dim embedding arrives
→ Creates d_user_papers_768
→ Inserts data
Result: Two physical collections for one logical collection
```
#### 3. Query Operations
**Query Flow:**
```
Query Vector → Determine Dimension → Check Collection → Search or Return Empty
```
**Behavior:**
1. Extract dimension from query vector: `dim = len(vector)`
2. Construct collection name with dimension suffix
3. Check if collection exists
4. If exists:
- Perform similarity search
- Return results
5. If not exists:
- Log: `"Collection {name} does not exist, returning empty results"`
- Return empty list (no error raised)
**Multiple Dimensions in Same Query:**
- If query contains vectors of different dimensions
- Each dimension queries its corresponding collection
- Results are aggregated
- Missing collections are skipped (not treated as errors)
**Rationale:**
- Querying an empty collection is a valid use case
- Returning empty results is semantically correct
- Avoids errors during system startup or before data ingestion
#### 4. Collection Deletion
**Delete Flow:**
```
Delete Request → List All Collections → Filter by Prefix → Delete All Matches
```
**Behavior:**
1. Construct prefix pattern: `d_{user}_{collection}_` (note trailing underscore)
2. List all collections in the vector store
3. Filter collections matching the prefix
4. Delete all matching collections
5. Log each deletion: `"Deleted collection {name}"`
6. Summary log: `"Deleted {count} collection(s) for {user}/{collection}"`
**Example:**
```
Collections in store:
- d_alice_papers_384
- d_alice_papers_768
- d_alice_reports_384
- d_bob_papers_384
Delete "papers" for alice:
→ Deletes: d_alice_papers_384, d_alice_papers_768
→ Keeps: d_alice_reports_384, d_bob_papers_384
```
**Rationale:**
- Ensures complete cleanup of all dimension variants
- Pattern matching prevents accidental deletion of unrelated collections
- Atomic operation from user perspective (all dimensions deleted together)
## Behavioral Characteristics
### Normal Operations
**Collection Creation:**
- ✓ Returns success immediately
- ✓ No physical storage allocated
- ✓ Fast operation (no backend I/O)
**First Write:**
- ✓ Creates collection with correct dimension
- ✓ Slightly slower due to collection creation overhead
- ✓ Subsequent writes to same dimension are fast
**Queries Before Any Writes:**
- ✓ Returns empty results
- ✓ No errors or exceptions
- ✓ System remains stable
**Mixed Dimension Writes:**
- ✓ Automatically creates separate collections per dimension
- ✓ Each dimension isolated in its own collection
- ✓ No dimension conflicts or schema errors
**Collection Deletion:**
- ✓ Removes all dimension variants
- ✓ Complete cleanup
- ✓ No orphaned collections
### Edge Cases
**Multiple Embedding Models:**
```
Scenario: User switches from model A (384-dim) to model B (768-dim)
Behavior:
- Both dimensions coexist in separate collections
- Old data (384-dim) remains queryable with 384-dim vectors
- New data (768-dim) queryable with 768-dim vectors
- Cross-dimension queries return results only for matching dimension
```
**Concurrent First Writes:**
```
Scenario: Multiple processes write to same collection simultaneously
Behavior:
- Each process checks for existence before creating
- Most vector stores handle concurrent creation gracefully
- If race condition occurs, second create is typically idempotent
- Final state: Collection exists and both writes succeed
```
**Dimension Migration:**
```
Scenario: User wants to migrate from 384-dim to 768-dim embeddings
Behavior:
- No automatic migration
- Old collection (384-dim) persists
- New collection (768-dim) created on first new write
- Both dimensions remain accessible
- Manual deletion of old dimension collections possible
```
**Empty Collection Queries:**
```
Scenario: Query a collection that has never received data
Behavior:
- Collection doesn't exist (never created)
- Query returns empty list
- No error state
- System logs: "Collection does not exist, returning empty results"
```
## Implementation Notes
### Storage Backend Specifics
**Qdrant:**
- Uses `collection_exists()` for existence checks
- Uses `get_collections()` for listing during deletion
- Collection creation requires `VectorParams(size=dim, distance=Distance.COSINE)`
**Pinecone:**
- Uses `has_index()` for existence checks
- Uses `list_indexes()` for listing during deletion
- Index creation requires waiting for "ready" status
- Serverless spec configured with cloud/region
**Milvus:**
- Direct classes (`DocVectors`, `EntityVectors`) manage lifecycle
- Internal cache `self.collections[(dim, user, collection)]` for performance
- Collection names sanitized (alphanumeric + underscore only)
- Supports schema with auto-incrementing IDs
### Performance Considerations
**First Write Latency:**
- Additional overhead due to collection creation
- Qdrant: ~100-500ms
- Pinecone: ~10-30 seconds (serverless provisioning)
- Milvus: ~500-2000ms (includes indexing)
**Query Performance:**
- Existence check adds minimal overhead (~1-10ms)
- No performance impact once collection exists
- Each dimension collection is independently optimized
**Storage Overhead:**
- Minimal metadata per collection
- Main overhead is per-dimension storage
- Trade-off: Storage space vs. dimension flexibility
## Future Considerations
**Automatic Dimension Consolidation:**
- Could add background process to identify and merge unused dimension variants
- Would require re-embedding or dimension reduction
**Dimension Discovery:**
- Could expose API to list all dimensions in use for a collection
- Useful for administration and monitoring
**Default Dimension Preference:**
- Could track "primary" dimension per collection
- Use for queries when dimension context is unavailable
**Storage Quotas:**
- May need per-collection dimension limits
- Prevent proliferation of dimension variants
## Migration Notes
**From Pre-Dimension-Suffix System:**
- Old collections: `d_{user}_{collection}` (no dimension suffix)
- New collections: `d_{user}_{collection}_{dim}` (with dimension suffix)
- No automatic migration - old collections remain accessible
- Consider manual migration script if needed
- Can run both naming schemes simultaneously
## References
- Collection Management: `docs/tech-specs/collection-management.md`
- Storage Schema: `trustgraph-base/trustgraph/schema/services/storage.py`
- Librarian Service: `trustgraph-flow/trustgraph/librarian/service.py`

View file

@ -11,7 +11,7 @@ langchain-text-splitters
langchain-community
huggingface-hub
requests
cassandra-driver
scylla-driver
pulsar-client
pypdf
anthropic

View file

@ -0,0 +1,376 @@
"""
Unit tests for AgentStep arguments type conversion
Tests the fix for converting agent tool arguments to strings when creating
AgentStep records, ensuring compatibility with Pulsar schema that requires
Map(String()) for the arguments field.
"""
import pytest
from unittest.mock import Mock, AsyncMock
from trustgraph.schema import AgentStep
from trustgraph.agent.react.types import Action
class TestAgentStepArgumentsConversion:
"""Test cases for AgentStep arguments string conversion"""
def test_agent_step_with_integer_arguments(self):
    """Integer argument values must be stringified for the Pulsar Map(String()) schema."""
    # An action whose arguments mix int and str values
    action = Action(
        thought="Set volume to 10",
        name="set_volume",
        arguments={"volume_level": 10, "device": "speaker"},
        observation="Volume set successfully"
    )
    # Mirror the conversion performed in service.py: every value -> str
    stringified = {key: str(value) for key, value in action.arguments.items()}
    step = AgentStep(
        thought=action.thought,
        action=action.name,
        arguments=stringified,
        observation=action.observation
    )
    # Both values should now be plain strings
    assert step.arguments["volume_level"] == "10"
    assert isinstance(step.arguments["volume_level"], str)
    assert step.arguments["device"] == "speaker"
    assert isinstance(step.arguments["device"], str)
def test_agent_step_with_float_arguments(self):
"""Test that float arguments are converted to strings"""
# Arrange
action = Action(
thought="Set temperature",
name="set_temperature",
arguments={"temperature": 23.5, "unit": "celsius"},
observation="Temperature set"
)
# Act
agent_step = AgentStep(
thought=action.thought,
action=action.name,
arguments={k: str(v) for k, v in action.arguments.items()},
observation=action.observation
)
# Assert
assert agent_step.arguments["temperature"] == "23.5"
assert isinstance(agent_step.arguments["temperature"], str)
assert agent_step.arguments["unit"] == "celsius"
def test_agent_step_with_boolean_arguments(self):
"""Test that boolean arguments are converted to strings"""
# Arrange
action = Action(
thought="Enable feature",
name="toggle_feature",
arguments={"enabled": True, "feature_name": "dark_mode"},
observation="Feature toggled"
)
# Act
agent_step = AgentStep(
thought=action.thought,
action=action.name,
arguments={k: str(v) for k, v in action.arguments.items()},
observation=action.observation
)
# Assert
assert agent_step.arguments["enabled"] == "True"
assert isinstance(agent_step.arguments["enabled"], str)
assert agent_step.arguments["feature_name"] == "dark_mode"
def test_agent_step_with_none_arguments(self):
"""Test that None arguments are converted to strings"""
# Arrange
action = Action(
thought="Check status",
name="get_status",
arguments={"filter": None, "category": "all"},
observation="Status retrieved"
)
# Act
agent_step = AgentStep(
thought=action.thought,
action=action.name,
arguments={k: str(v) for k, v in action.arguments.items()},
observation=action.observation
)
# Assert
assert agent_step.arguments["filter"] == "None"
assert isinstance(agent_step.arguments["filter"], str)
assert agent_step.arguments["category"] == "all"
def test_agent_step_with_mixed_type_arguments(self):
"""Test that mixed type arguments are all converted to strings"""
# Arrange
action = Action(
thought="Configure device",
name="configure_device",
arguments={
"name": "Hifi",
"volume_level": 10,
"bass_boost": 1.5,
"enabled": True,
"preset": None
},
observation="Device configured"
)
# Act
agent_step = AgentStep(
thought=action.thought,
action=action.name,
arguments={k: str(v) for k, v in action.arguments.items()},
observation=action.observation
)
# Assert - all values should be strings
assert all(isinstance(v, str) for v in agent_step.arguments.values())
assert agent_step.arguments["name"] == "Hifi"
assert agent_step.arguments["volume_level"] == "10"
assert agent_step.arguments["bass_boost"] == "1.5"
assert agent_step.arguments["enabled"] == "True"
assert agent_step.arguments["preset"] == "None"
def test_agent_step_with_string_arguments(self):
"""Test that string arguments remain strings (no double conversion)"""
# Arrange
action = Action(
thought="Search for information",
name="search",
arguments={"query": "test query", "limit": "10"},
observation="Search completed"
)
# Act
agent_step = AgentStep(
thought=action.thought,
action=action.name,
arguments={k: str(v) for k, v in action.arguments.items()},
observation=action.observation
)
# Assert
assert agent_step.arguments["query"] == "test query"
assert agent_step.arguments["limit"] == "10"
assert all(isinstance(v, str) for v in agent_step.arguments.values())
def test_agent_step_with_empty_arguments(self):
"""Test that empty arguments dict works correctly"""
# Arrange
action = Action(
thought="Perform action",
name="do_something",
arguments={},
observation="Action completed"
)
# Act
agent_step = AgentStep(
thought=action.thought,
action=action.name,
arguments={k: str(v) for k, v in action.arguments.items()},
observation=action.observation
)
# Assert
assert agent_step.arguments == {}
assert isinstance(agent_step.arguments, dict)
def test_agent_step_with_numeric_string_values(self):
"""Test arguments that are already strings containing numbers"""
# Arrange
action = Action(
thought="Process order",
name="process_order",
arguments={"order_id": "12345", "quantity": 10},
observation="Order processed"
)
# Act
agent_step = AgentStep(
thought=action.thought,
action=action.name,
arguments={k: str(v) for k, v in action.arguments.items()},
observation=action.observation
)
# Assert
assert agent_step.arguments["order_id"] == "12345"
assert agent_step.arguments["quantity"] == "10"
assert all(isinstance(v, str) for v in agent_step.arguments.values())
def test_agent_step_conversion_preserves_keys(self):
"""Test that argument keys are preserved during conversion"""
# Arrange
action = Action(
thought="Test",
name="test_action",
arguments={
"param1": 1,
"param_two": 2,
"PARAM_THREE": 3,
"param-four": 4
},
observation="Done"
)
# Act
agent_step = AgentStep(
thought=action.thought,
action=action.name,
arguments={k: str(v) for k, v in action.arguments.items()},
observation=action.observation
)
# Assert - verify all keys are preserved
assert set(agent_step.arguments.keys()) == {
"param1", "param_two", "PARAM_THREE", "param-four"
}
# Verify values are converted
assert agent_step.arguments["param1"] == "1"
assert agent_step.arguments["param_two"] == "2"
assert agent_step.arguments["PARAM_THREE"] == "3"
assert agent_step.arguments["param-four"] == "4"
def test_real_world_home_assistant_example(self):
"""Test with real-world Home Assistant volume control example"""
# Arrange - this is the exact scenario from the bug report
action = Action(
thought='The user wants to set the volume of the Hifi. The `set_device_volume` tool can be used for this purpose. The device name is "Hifi" and the desired volume level is 10.',
name='set_device_volume',
arguments={'name': 'Hifi', 'volume_level': 10},
observation='{"speech": {}, "response_type": "action_done", "data": {"targets": [], "success": [{"name": "Hifi", "type": "entity", "id": "media_player.hifi"}], "failed": []}}'
)
# Act - this should not raise TypeError
agent_step = AgentStep(
thought=action.thought,
action=action.name,
arguments={k: str(v) for k, v in action.arguments.items()},
observation=action.observation
)
# Assert
assert agent_step.arguments["name"] == "Hifi"
assert agent_step.arguments["volume_level"] == "10"
assert isinstance(agent_step.arguments["volume_level"], str)
def test_multiple_actions_in_history(self):
"""Test converting multiple actions in history (as done in service.py)"""
# Arrange
history = [
Action(
thought="First action",
name="action1",
arguments={"count": 5},
observation="Done 1"
),
Action(
thought="Second action",
name="action2",
arguments={"enabled": True, "name": "test"},
observation="Done 2"
),
Action(
thought="Third action",
name="action3",
arguments={"value": 3.14},
observation="Done 3"
)
]
# Act - simulate the list comprehension in service.py
agent_steps = [
AgentStep(
thought=h.thought,
action=h.name,
arguments={k: str(v) for k, v in h.arguments.items()},
observation=h.observation
)
for h in history
]
# Assert
assert len(agent_steps) == 3
# First action
assert agent_steps[0].arguments["count"] == "5"
assert isinstance(agent_steps[0].arguments["count"], str)
# Second action
assert agent_steps[1].arguments["enabled"] == "True"
assert agent_steps[1].arguments["name"] == "test"
assert all(isinstance(v, str) for v in agent_steps[1].arguments.values())
# Third action
assert agent_steps[2].arguments["value"] == "3.14"
assert isinstance(agent_steps[2].arguments["value"], str)
def test_arguments_with_special_characters(self):
"""Test arguments containing special characters are properly converted"""
# Arrange
action = Action(
thought="Process data",
name="process",
arguments={
"text": "Hello, World!",
"path": "/home/user/file.txt",
"pattern": "test-*-pattern",
"count": 42
},
observation="Processed"
)
# Act
agent_step = AgentStep(
thought=action.thought,
action=action.name,
arguments={k: str(v) for k, v in action.arguments.items()},
observation=action.observation
)
# Assert
assert agent_step.arguments["text"] == "Hello, World!"
assert agent_step.arguments["path"] == "/home/user/file.txt"
assert agent_step.arguments["pattern"] == "test-*-pattern"
assert agent_step.arguments["count"] == "42"
assert all(isinstance(v, str) for v in agent_step.arguments.values())
def test_zero_and_negative_numbers(self):
"""Test that zero and negative numbers are converted correctly"""
# Arrange
action = Action(
thought="Test edge cases",
name="edge_test",
arguments={
"zero": 0,
"negative": -5,
"negative_float": -3.14,
"positive": 10
},
observation="Done"
)
# Act
agent_step = AgentStep(
thought=action.thought,
action=action.name,
arguments={k: str(v) for k, v in action.arguments.items()},
observation=action.observation
)
# Assert
assert agent_step.arguments["zero"] == "0"
assert agent_step.arguments["negative"] == "-5"
assert agent_step.arguments["negative_float"] == "-3.14"
assert agent_step.arguments["positive"] == "10"
assert all(isinstance(v, str) for v in agent_step.arguments.values())

View file

@ -0,0 +1,233 @@
"""
Unit tests for MCP tool bearer token authentication
Tests the authentication feature added to MCP tool service that allows
configuring optional bearer tokens for MCP server connections.
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch, MagicMock
import json
class TestMcpToolAuthentication:
"""Test cases for MCP tool bearer token authentication"""
def test_mcp_tool_with_auth_token_header_building(self):
"""Test that auth token is correctly formatted in headers"""
# Arrange
mcp_config = {
"url": "https://secure.example.com/mcp",
"remote-name": "secure-tool",
"auth-token": "test-token-12345"
}
# Act - simulate header building logic from service.py
headers = {}
if "auth-token" in mcp_config and mcp_config["auth-token"]:
token = mcp_config["auth-token"]
headers["Authorization"] = f"Bearer {token}"
# Assert
assert "Authorization" in headers
assert headers["Authorization"] == "Bearer test-token-12345"
assert headers["Authorization"].startswith("Bearer ")
def test_mcp_tool_without_auth_token_header_building(self):
"""Test that no auth header is added when token is not present (backward compatibility)"""
# Arrange
mcp_config = {
"url": "http://public.example.com/mcp",
"remote-name": "public-tool"
# No auth-token field
}
# Act - simulate header building logic from service.py
headers = {}
if "auth-token" in mcp_config and mcp_config["auth-token"]:
token = mcp_config["auth-token"]
headers["Authorization"] = f"Bearer {token}"
# Assert
assert headers == {}
assert "Authorization" not in headers
def test_mcp_config_with_auth_token(self):
"""Test MCP configuration parsing with auth-token"""
# Arrange
config = {
"mcp": {
"secure-tool": json.dumps({
"url": "https://secure.example.com/mcp",
"remote-name": "secure-tool",
"auth-token": "test-token-xyz"
}),
"public-tool": json.dumps({
"url": "http://public.example.com/mcp",
"remote-name": "public-tool"
})
}
}
# Act - simulate on_mcp_config
mcp_services = {
k: json.loads(v)
for k, v in config["mcp"].items()
}
# Assert
assert "secure-tool" in mcp_services
assert mcp_services["secure-tool"]["auth-token"] == "test-token-xyz"
assert mcp_services["secure-tool"]["url"] == "https://secure.example.com/mcp"
assert "public-tool" in mcp_services
assert "auth-token" not in mcp_services["public-tool"]
assert mcp_services["public-tool"]["url"] == "http://public.example.com/mcp"
def test_auth_token_with_empty_string(self):
"""Test that empty auth-token string is treated as no auth"""
# Arrange
config_data = {
"url": "https://example.com/mcp",
"remote-name": "test-tool",
"auth-token": ""
}
# Act - simulate header building logic
headers = {}
if "auth-token" in config_data and config_data["auth-token"]:
headers["Authorization"] = f"Bearer {config_data['auth-token']}"
# Assert
assert headers == {}, "Empty auth-token should not add Authorization header"
def test_auth_token_with_special_characters(self):
"""Test auth token with special characters (JWT-like)"""
# Arrange
jwt_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"
config_data = {
"url": "https://example.com/mcp",
"auth-token": jwt_token
}
# Act - simulate header building
headers = {}
if "auth-token" in config_data and config_data["auth-token"]:
token = config_data["auth-token"]
headers["Authorization"] = f"Bearer {token}"
# Assert
assert headers["Authorization"] == f"Bearer {jwt_token}"
assert "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9" in headers["Authorization"]
def test_multiple_tools_with_different_auth_configs(self):
"""Test handling multiple MCP tools with mixed auth configurations"""
# Arrange
mcp_services = {
"tool-a": {
"url": "https://a.example.com/mcp",
"auth-token": "token-a"
},
"tool-b": {
"url": "https://b.example.com/mcp",
"auth-token": "token-b"
},
"tool-c": {
"url": "http://c.example.com/mcp"
# No auth-token
}
}
# Act - simulate header building for each tool
headers_a = {}
if "auth-token" in mcp_services["tool-a"] and mcp_services["tool-a"]["auth-token"]:
headers_a["Authorization"] = f"Bearer {mcp_services['tool-a']['auth-token']}"
headers_b = {}
if "auth-token" in mcp_services["tool-b"] and mcp_services["tool-b"]["auth-token"]:
headers_b["Authorization"] = f"Bearer {mcp_services['tool-b']['auth-token']}"
headers_c = {}
if "auth-token" in mcp_services["tool-c"] and mcp_services["tool-c"]["auth-token"]:
headers_c["Authorization"] = f"Bearer {mcp_services['tool-c']['auth-token']}"
# Assert
assert headers_a == {"Authorization": "Bearer token-a"}
assert headers_b == {"Authorization": "Bearer token-b"}
assert headers_c == {}
def test_auth_token_not_logged(self):
"""Test that auth tokens are not exposed in logs"""
# This is more of a guideline test - in real implementation,
# we should ensure tokens are never logged
# Arrange
auth_token = "super-secret-token-123"
config = {
"url": "https://secure.example.com/mcp",
"auth-token": auth_token
}
# Act - simulate log-safe representation
def get_log_safe_config(cfg):
"""Return config with sensitive data masked"""
safe_config = cfg.copy()
if "auth-token" in safe_config and safe_config["auth-token"]:
safe_config["auth-token"] = "****"
return safe_config
log_safe = get_log_safe_config(config)
# Assert
assert log_safe["auth-token"] == "****"
assert auth_token not in str(log_safe)
assert "url" in log_safe
assert log_safe["url"] == "https://secure.example.com/mcp"
def test_auth_token_with_remote_name_configuration(self):
"""Test MCP tool configuration with both auth-token and remote-name"""
# Arrange
mcp_config = {
"url": "https://server.example.com/mcp",
"remote-name": "actual_tool_name",
"auth-token": "my-token-456"
}
# Act - simulate header building and remote name extraction
headers = {}
if "auth-token" in mcp_config and mcp_config["auth-token"]:
token = mcp_config["auth-token"]
headers["Authorization"] = f"Bearer {token}"
remote_name = mcp_config.get("remote-name", "default-name")
# Assert
assert headers["Authorization"] == "Bearer my-token-456"
assert remote_name == "actual_tool_name"
assert "url" in mcp_config
assert mcp_config["url"] == "https://server.example.com/mcp"
def test_bearer_token_format(self):
"""Test that Bearer token format is correct"""
# Arrange
tokens = [
"simple-token",
"token_with_underscore",
"token-with-dash",
"TokenWithMixedCase123",
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.payload.signature"
]
# Act & Assert
for token in tokens:
headers = {}
if token:
headers["Authorization"] = f"Bearer {token}"
# Verify format is "Bearer <token>" with single space
assert headers["Authorization"].startswith("Bearer ")
assert headers["Authorization"] == f"Bearer {token}"
# Verify no extra spaces
assert headers["Authorization"].count("Bearer") == 1
assert headers["Authorization"].split("Bearer ")[1] == token

View file

@ -30,14 +30,16 @@ class TestMilvusUserCollectionIntegration:
for user, collection, vector in test_cases:
doc_vectors.insert(vector, "test document", user, collection)
expected_collection_name = make_safe_collection_name(
user, collection, "doc"
)
# Verify collection was created with correct name
# Add dimension suffix to expected name
expected_collection_name_with_dim = f"{expected_collection_name}_{len(vector)}"
# Verify collection was created with correct name (including dimension)
assert (len(vector), user, collection) in doc_vectors.collections
assert doc_vectors.collections[(len(vector), user, collection)] == expected_collection_name
assert doc_vectors.collections[(len(vector), user, collection)] == expected_collection_name_with_dim
@patch('trustgraph.direct.milvus_graph_embeddings.MilvusClient')
def test_entity_vectors_collection_creation_with_user_collection(self, mock_milvus_client):
@ -56,14 +58,16 @@ class TestMilvusUserCollectionIntegration:
for user, collection, vector in test_cases:
entity_vectors.insert(vector, "test entity", user, collection)
expected_collection_name = make_safe_collection_name(
user, collection, "entity"
)
# Verify collection was created with correct name
# Add dimension suffix to expected name
expected_collection_name_with_dim = f"{expected_collection_name}_{len(vector)}"
# Verify collection was created with correct name (including dimension)
assert (len(vector), user, collection) in entity_vectors.collections
assert entity_vectors.collections[(len(vector), user, collection)] == expected_collection_name
assert entity_vectors.collections[(len(vector), user, collection)] == expected_collection_name_with_dim
@patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient')
def test_doc_vectors_search_uses_correct_collection(self, mock_milvus_client):
@ -88,11 +92,12 @@ class TestMilvusUserCollectionIntegration:
# Now search
result = doc_vectors.search(vector, user, collection, limit=5)
# Verify search was called with correct collection name
# Verify search was called with correct collection name (including dimension)
expected_collection_name = make_safe_collection_name(user, collection, "doc")
expected_collection_name_with_dim = f"{expected_collection_name}_{len(vector)}"
mock_client.search.assert_called_once()
search_call = mock_client.search.call_args
assert search_call[1]["collection_name"] == expected_collection_name
assert search_call[1]["collection_name"] == expected_collection_name_with_dim
@patch('trustgraph.direct.milvus_graph_embeddings.MilvusClient')
def test_entity_vectors_search_uses_correct_collection(self, mock_milvus_client):
@ -117,11 +122,12 @@ class TestMilvusUserCollectionIntegration:
# Now search
result = entity_vectors.search(vector, user, collection, limit=5)
# Verify search was called with correct collection name
# Verify search was called with correct collection name (including dimension)
expected_collection_name = make_safe_collection_name(user, collection, "entity")
expected_collection_name_with_dim = f"{expected_collection_name}_{len(vector)}"
mock_client.search.assert_called_once()
search_call = mock_client.search.call_args
assert search_call[1]["collection_name"] == expected_collection_name
assert search_call[1]["collection_name"] == expected_collection_name_with_dim
@patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient')
def test_doc_vectors_collection_isolation(self, mock_milvus_client):
@ -141,10 +147,11 @@ class TestMilvusUserCollectionIntegration:
assert len(doc_vectors.collections) == 3
collection_names = set(doc_vectors.collections.values())
# All vectors are 3-dimensional, so all names should have _3 suffix
expected_names = {
"doc_user1_collection1",
"doc_user2_collection2",
"doc_user1_collection2"
"doc_user1_collection1_3",
"doc_user2_collection2_3",
"doc_user1_collection2_3"
}
assert collection_names == expected_names
@ -166,10 +173,11 @@ class TestMilvusUserCollectionIntegration:
assert len(entity_vectors.collections) == 3
collection_names = set(entity_vectors.collections.values())
# All vectors are 3-dimensional, so all names should have _3 suffix
expected_names = {
"entity_user1_collection1",
"entity_user2_collection2",
"entity_user1_collection2"
"entity_user1_collection1_3",
"entity_user2_collection2_3",
"entity_user1_collection2_3"
}
assert collection_names == expected_names
@ -191,16 +199,16 @@ class TestMilvusUserCollectionIntegration:
# Verify three separate collections were created for different dimensions
assert len(doc_vectors.collections) == 3
collection_names = set(doc_vectors.collections.values())
# Different dimensions now create different collections with dimension suffixes
expected_names = {
"doc_test_user_test_collection", # Same name for all dimensions
"doc_test_user_test_collection", # now stored per dimension in key
"doc_test_user_test_collection" # but collection name is the same
"doc_test_user_test_collection_2", # 2D vector
"doc_test_user_test_collection_3", # 3D vector
"doc_test_user_test_collection_4" # 4D vector
}
# Note: Now all dimensions use the same collection name, they are differentiated by the key
assert len(collection_names) == 1 # Only one unique collection name
assert "doc_test_user_test_collection" in collection_names
# Each dimension gets its own collection
assert len(collection_names) == 3 # Three unique collection names
assert collection_names == expected_names
@patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient')
@ -222,8 +230,9 @@ class TestMilvusUserCollectionIntegration:
# Verify only one collection was created
assert len(doc_vectors.collections) == 1
expected_collection_name = "doc_test_user_test_collection"
# Collection name now includes dimension suffix
expected_collection_name = "doc_test_user_test_collection_3"
assert doc_vectors.collections[(3, user, collection)] == expected_collection_name
@patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient')
@ -235,19 +244,20 @@ class TestMilvusUserCollectionIntegration:
doc_vectors = DocVectors(uri="http://test:19530", prefix="doc")
# Test various special character combinations
# All expected names now include dimension suffix _3
test_cases = [
("user@domain.com", "test-collection.v1", "doc_user_domain_com_test_collection_v1"),
("user_123", "collection_456", "doc_user_123_collection_456"),
("user with spaces", "collection with spaces", "doc_user_with_spaces_collection_with_spaces"),
("user@@@test", "collection---test", "doc_user_test_collection_test"),
("user@domain.com", "test-collection.v1", "doc_user_domain_com_test_collection_v1_3"),
("user_123", "collection_456", "doc_user_123_collection_456_3"),
("user with spaces", "collection with spaces", "doc_user_with_spaces_collection_with_spaces_3"),
("user@@@test", "collection---test", "doc_user_test_collection_test_3"),
]
vector = [0.1, 0.2, 0.3]
for user, collection, expected_name in test_cases:
doc_vectors_instance = DocVectors(uri="http://test:19530", prefix="doc")
doc_vectors_instance.insert(vector, "test doc", user, collection)
assert doc_vectors_instance.collections[(3, user, collection)] == expected_name
def test_collection_name_backward_compatibility(self):

View file

@ -0,0 +1,157 @@
"""
Contract tests for EmbeddingsService base class
Tests the contract between the EmbeddingsService base class and its
implementations, ensuring proper integration of the model parameter handling.
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch
from unittest import IsolatedAsyncioTestCase
from trustgraph.base import EmbeddingsService
from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error
class ConcreteEmbeddingsService(EmbeddingsService):
    """Minimal concrete EmbeddingsService used to exercise the
    on_embeddings interface in isolation.

    Deliberately skips super().__init__ because the base class wires up a
    taskgroup, which these unit tests neither need nor provide.
    """

    def __init__(self, **params):
        # Chronological record of every on_embeddings invocation.
        self.on_embeddings_calls = []
        self.default_model = params.get("model", "default-test-model")

    async def on_embeddings(self, text, model=None):
        """Record the call's text/model pair and return a fixed embedding."""
        call_record = {"text": text, "model": model}
        self.on_embeddings_calls.append(call_record)
        return [[0.1, 0.2, 0.3]]
class TestEmbeddingsServiceModelParameterContract(IsolatedAsyncioTestCase):
    """Contract tests for the optional model parameter on on_embeddings."""

    async def test_on_embeddings_accepts_model_parameter(self):
        """on_embeddings accepts an optional model keyword argument."""
        svc = ConcreteEmbeddingsService(model="default-model")

        await svc.on_embeddings("test text")
        await svc.on_embeddings("test text", model="custom-model")
        await svc.on_embeddings("test text", model=None)

        recorded = svc.on_embeddings_calls
        assert len(recorded) == 3
        assert recorded[0]["model"] is None  # no model supplied
        assert recorded[1]["model"] == "custom-model"
        assert recorded[2]["model"] is None

    async def test_implementation_tracks_model_changes(self):
        """Implementations track which model each request asked for."""
        svc = ConcreteEmbeddingsService(model="default-model")

        await svc.on_embeddings("text1", model="model-a")
        await svc.on_embeddings("text2", model="model-b")
        await svc.on_embeddings("text3")  # default: None is recorded
        await svc.on_embeddings("text4", model="model-a")

        recorded = svc.on_embeddings_calls
        assert len(recorded) == 4
        assert recorded[0]["model"] == "model-a"
        assert recorded[1]["model"] == "model-b"
        assert recorded[2]["model"] is None
        assert recorded[3]["model"] == "model-a"

    async def test_model_parameter_with_various_text_inputs(self):
        """The model parameter works alongside varied text inputs."""
        svc = ConcreteEmbeddingsService(model="default-model")

        cases = [
            ("Simple text", "model-1"),
            ("", "model-2"),
            ("Unicode: 世界 🌍", "model-3"),
            ("Very " * 100 + "long text", None),
        ]

        for text, model in cases:
            await svc.on_embeddings(text, model=model)

        assert len(svc.on_embeddings_calls) == len(cases)
        for recorded, (text, model) in zip(svc.on_embeddings_calls, cases):
            assert recorded["text"] == text
            assert recorded["model"] == model

    async def test_embeddings_return_format(self):
        """Embeddings come back as a list of float vectors."""
        svc = ConcreteEmbeddingsService(model="default-model")

        vectors = await svc.on_embeddings("test text", model="test-model")

        assert isinstance(vectors, list)
        assert len(vectors) > 0
        first = vectors[0]
        assert isinstance(first, list)
        assert all(isinstance(component, float) for component in first)
class TestEmbeddingsResponseSchema:
    """Contract tests for the EmbeddingsResponse schema."""

    def test_success_response(self):
        """A success response carries vectors and no error."""
        response = EmbeddingsResponse(error=None, vectors=[[0.1, 0.2, 0.3]])

        assert response.error is None
        assert response.vectors == [[0.1, 0.2, 0.3]]

    def test_error_response(self):
        """An error response carries an Error object and no vectors."""
        failure = Error(type="test-error", message="Test message")
        response = EmbeddingsResponse(error=failure, vectors=None)

        assert response.error is not None
        assert response.error.type == "test-error"
        assert response.error.message == "Test message"
        assert response.vectors is None

    def test_response_with_multiple_vectors(self):
        """A response may carry several embedding vectors at once."""
        embedding_batch = [
            [0.1, 0.2, 0.3],
            [0.4, 0.5, 0.6],
            [0.7, 0.8, 0.9],
        ]
        response = EmbeddingsResponse(error=None, vectors=embedding_batch)

        assert len(response.vectors) == 3
        assert response.vectors[0] == [0.1, 0.2, 0.3]
# Allow running this test module directly (python <file>); delegates to
# pytest so behaviour matches the normal test-runner invocation.
if __name__ == '__main__':
    pytest.main([__file__])

View file

@ -0,0 +1,216 @@
"""
Unit tests for FastEmbed dynamic model loading
Tests the model caching and dynamic loading functionality for FastEmbed
embeddings service.
"""
import pytest
from unittest.mock import Mock, patch, AsyncMock
from unittest import IsolatedAsyncioTestCase
from trustgraph.embeddings.fastembed.processor import Processor
class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
"""Test FastEmbed dynamic model loading and caching"""
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_default_model_loaded_on_init(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
"""Test that default model is loaded during initialization"""
# Arrange
mock_fastembed_instance = Mock()
mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])]
mock_text_embedding_class.return_value = mock_fastembed_instance
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
base_params = {
"id": "test-embeddings",
"concurrency": 1,
"model": "test-model",
"taskgroup": AsyncMock()
}
# Act
processor = Processor(**base_params)
# Assert
mock_text_embedding_class.assert_called_once_with(model_name="test-model")
assert processor.default_model == "test-model"
assert processor.cached_model_name == "test-model"
assert processor.embeddings is not None
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_model_caching_avoids_reload(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
"""Test that using the same model doesn't reload it"""
# Arrange
mock_fastembed_instance = Mock()
mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])]
mock_text_embedding_class.return_value = mock_fastembed_instance
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
mock_text_embedding_class.reset_mock()
# Act - use same model multiple times
processor._load_model("test-model")
processor._load_model("test-model")
processor._load_model("test-model")
# Assert - model should not be reloaded
mock_text_embedding_class.assert_not_called()
assert processor.cached_model_name == "test-model"
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_model_reload_on_name_change(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
"""Test that changing model name triggers reload"""
# Arrange
mock_fastembed_instance = Mock()
mock_text_embedding_class.return_value = mock_fastembed_instance
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
mock_text_embedding_class.reset_mock()
# Act - switch to different model
processor._load_model("different-model")
# Assert
mock_text_embedding_class.assert_called_once_with(model_name="different-model")
assert processor.cached_model_name == "different-model"
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
"""Test that on_embeddings uses default model when no model specified"""
# Arrange
mock_fastembed_instance = Mock()
mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])]
mock_text_embedding_class.return_value = mock_fastembed_instance
mock_async_init.return_value = None
mock_embeddings_init.return_value = None
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
mock_text_embedding_class.reset_mock()
# Act
result = await processor.on_embeddings("test text")
# Assert
mock_fastembed_instance.embed.assert_called_once_with(["test text"])
assert processor.cached_model_name == "test-model" # Still using default
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
    """on_embeddings with an explicit model must load and cache that model."""
    # Stub base-class initialisers and the fastembed embedder.
    mock_async_init.return_value = None
    mock_embeddings_init.return_value = None
    embedder_stub = Mock()
    embedder_stub.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])]
    mock_text_embedding_class.return_value = embedder_stub

    processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
    # Discard the constructor call from __init__ so the reload is isolated.
    mock_text_embedding_class.reset_mock()

    await processor.on_embeddings("test text", model="custom-model")

    # The override forces one reload, updates the cache, and the text is
    # embedded with the newly loaded model.
    mock_text_embedding_class.assert_called_once_with(model_name="custom-model")
    assert processor.cached_model_name == "custom-model"
    embedder_stub.embed.assert_called_once_with(["test text"])
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
    """Test switching between multiple models"""
    # Arrange: stub the base-class initialisers and the fastembed embedder.
    mock_fastembed_instance = Mock()
    mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])]
    mock_text_embedding_class.return_value = mock_fastembed_instance
    mock_async_init.return_value = None
    mock_embeddings_init.return_value = None
    processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
    # Snapshot the constructor call count after the initial default load,
    # so each step below can be measured relative to it.
    initial_call_count = mock_text_embedding_class.call_count
    # Act - switch between models, sampling the call count after each request
    await processor.on_embeddings("text1", model="model-a")
    call_count_after_a = mock_text_embedding_class.call_count
    await processor.on_embeddings("text2", model="model-a")  # Same, no reload
    call_count_after_a_repeat = mock_text_embedding_class.call_count
    await processor.on_embeddings("text3", model="model-b")  # Different, reload
    call_count_after_b = mock_text_embedding_class.call_count
    await processor.on_embeddings("text4", model="model-a")  # Back to A, reload
    call_count_after_a_again = mock_text_embedding_class.call_count
    # Assert: only a genuine model change triggers a constructor call, and
    # switching back re-loads — i.e. only one model is cached at a time.
    assert call_count_after_a == initial_call_count + 1  # First load
    assert call_count_after_a_repeat == initial_call_count + 1  # No reload
    assert call_count_after_b == initial_call_count + 2  # Reload for model-b
    assert call_count_after_a_again == initial_call_count + 3  # Reload back to model-a
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
    """Test that None model parameter falls back to default"""
    # Arrange: stub the base-class initialisers and the fastembed embedder.
    mock_fastembed_instance = Mock()
    mock_fastembed_instance.embed.return_value = [Mock(tolist=lambda: [0.1, 0.2, 0.3, 0.4, 0.5])]
    mock_text_embedding_class.return_value = mock_fastembed_instance
    mock_async_init.return_value = None
    mock_embeddings_init.return_value = None
    processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
    initial_count = mock_text_embedding_class.call_count
    # Act
    result = await processor.on_embeddings("test text", model=None)
    # Assert
    # No reload, using cached default
    assert mock_text_embedding_class.call_count == initial_count
    assert processor.cached_model_name == "test-model"
    # Bug fix: the original test computed `result` but never checked it, so
    # a broken embedding path would still pass. Verify output is produced.
    assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_text_embedding_class):
    """Constructing without a model argument must fall back to the module default."""
    # Stub base-class initialisers and the fastembed embedder.
    mock_async_init.return_value = None
    mock_embeddings_init.return_value = None
    mock_text_embedding_class.return_value = Mock()

    processor = Processor(id="test-embeddings", concurrency=1, taskgroup=AsyncMock())

    # The module-level default model should be loaded once and cached.
    expected_default = "sentence-transformers/all-MiniLM-L6-v2"
    mock_text_embedding_class.assert_called_once_with(model_name=expected_default)
    assert processor.default_model == expected_default
    assert processor.cached_model_name == expected_default
# Allow running this test module directly (outside the pytest CLI).
if __name__ == '__main__':
    pytest.main([__file__])

View file

@ -0,0 +1,213 @@
"""
Unit tests for HuggingFace dynamic model loading
Tests the model caching and dynamic loading functionality for HuggingFace
embeddings service using LangChain's HuggingFaceEmbeddings.
"""
import pytest
from unittest.mock import Mock, patch, AsyncMock
from unittest import IsolatedAsyncioTestCase
# Skip all tests in this module if trustgraph.embeddings.hf is not installed
pytest.importorskip("trustgraph.embeddings.hf")
from trustgraph.embeddings.hf.hf import Processor
class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
    """Test HuggingFace dynamic model loading and caching.

    The Processor wraps LangChain's HuggingFaceEmbeddings. Every test stubs
    the embeddings class and the base-class initialisers, so no real model
    is ever downloaded or loaded.
    """

    @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_default_model_loaded_on_init(self, mock_embeddings_init, mock_async_init, mock_hf_class):
        """Test that default model is loaded during initialization"""
        # Arrange
        mock_hf_instance = Mock()
        mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]]
        mock_hf_class.return_value = mock_hf_instance
        mock_async_init.return_value = None
        mock_embeddings_init.return_value = None
        # Act
        processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
        # Assert: the constructor loads the configured model eagerly and
        # records it as both the default and the currently cached model.
        mock_hf_class.assert_called_once_with(model_name="test-model")
        assert processor.default_model == "test-model"
        assert processor.cached_model_name == "test-model"
        assert processor.embeddings is not None

    @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_model_caching_avoids_reload(self, mock_embeddings_init, mock_async_init, mock_hf_class):
        """Test that using the same model doesn't reload it"""
        # Arrange
        mock_hf_instance = Mock()
        mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]]
        mock_hf_class.return_value = mock_hf_instance
        mock_async_init.return_value = None
        mock_embeddings_init.return_value = None
        processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
        # Drop the constructor call made during __init__ so only
        # subsequent (re)loads are counted.
        mock_hf_class.reset_mock()
        # Act - use same model multiple times
        processor._load_model("test-model")
        processor._load_model("test-model")
        processor._load_model("test-model")
        # Assert - repeated requests for the cached model must not rebuild it
        mock_hf_class.assert_not_called()
        assert processor.cached_model_name == "test-model"

    @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_model_reload_on_name_change(self, mock_embeddings_init, mock_async_init, mock_hf_class):
        """Test that changing model name triggers reload"""
        # Arrange
        mock_hf_instance = Mock()
        mock_hf_class.return_value = mock_hf_instance
        mock_async_init.return_value = None
        mock_embeddings_init.return_value = None
        processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
        mock_hf_class.reset_mock()
        # Act - switch to different model
        processor._load_model("different-model")
        # Assert: exactly one reload with the new name, cache follows it
        mock_hf_class.assert_called_once_with(model_name="different-model")
        assert processor.cached_model_name == "different-model"

    @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_hf_class):
        """Test that on_embeddings uses default model when no model specified"""
        # Arrange
        mock_hf_instance = Mock()
        mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]]
        mock_hf_class.return_value = mock_hf_instance
        mock_async_init.return_value = None
        mock_embeddings_init.return_value = None
        processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
        mock_hf_class.reset_mock()
        # Act
        result = await processor.on_embeddings("test text")
        # Assert
        mock_hf_instance.embed_documents.assert_called_once_with(["test text"])
        assert processor.cached_model_name == "test-model"  # Still using default
        assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]

    @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_hf_class):
        """Test that on_embeddings uses specified model when provided"""
        # Arrange
        mock_hf_instance = Mock()
        mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]]
        mock_hf_class.return_value = mock_hf_instance
        mock_async_init.return_value = None
        mock_embeddings_init.return_value = None
        processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
        mock_hf_class.reset_mock()
        # Act
        result = await processor.on_embeddings("test text", model="custom-model")
        # Assert: the override forces one reload and is used for embedding
        mock_hf_class.assert_called_once_with(model_name="custom-model")
        assert processor.cached_model_name == "custom-model"
        mock_hf_instance.embed_documents.assert_called_once_with(["test text"])

    @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_hf_class):
        """Test switching between multiple models"""
        # Arrange
        mock_hf_instance = Mock()
        mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]]
        mock_hf_class.return_value = mock_hf_instance
        mock_async_init.return_value = None
        mock_embeddings_init.return_value = None
        processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
        # Baseline call count after the initial default load.
        initial_call_count = mock_hf_class.call_count
        # Act - switch between models, sampling the call count each time
        await processor.on_embeddings("text1", model="model-a")
        call_count_after_a = mock_hf_class.call_count
        await processor.on_embeddings("text2", model="model-a")  # Same, no reload
        call_count_after_a_repeat = mock_hf_class.call_count
        await processor.on_embeddings("text3", model="model-b")  # Different, reload
        call_count_after_b = mock_hf_class.call_count
        await processor.on_embeddings("text4", model="model-a")  # Back to A, reload
        call_count_after_a_again = mock_hf_class.call_count
        # Assert: only genuine model changes construct a new embedder;
        # switching back re-loads, i.e. a single model is cached at a time.
        assert call_count_after_a == initial_call_count + 1  # First load
        assert call_count_after_a_repeat == initial_call_count + 1  # No reload
        assert call_count_after_b == initial_call_count + 2  # Reload for model-b
        assert call_count_after_a_again == initial_call_count + 3  # Reload back to model-a

    @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_hf_class):
        """Test that None model parameter falls back to default"""
        # Arrange
        mock_hf_instance = Mock()
        mock_hf_instance.embed_documents.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]]
        mock_hf_class.return_value = mock_hf_instance
        mock_async_init.return_value = None
        mock_embeddings_init.return_value = None
        processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
        initial_count = mock_hf_class.call_count
        # Act
        result = await processor.on_embeddings("test text", model=None)
        # Assert
        # No reload, using cached default
        assert mock_hf_class.call_count == initial_count
        assert processor.cached_model_name == "test-model"
        # Bug fix: `result` was previously computed but never checked, so a
        # broken embedding path would still pass. Verify output is produced.
        assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]

    @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_hf_class):
        """Test initialization without model parameter uses module default"""
        # Arrange
        mock_hf_instance = Mock()
        mock_hf_class.return_value = mock_hf_instance
        mock_async_init.return_value = None
        mock_embeddings_init.return_value = None
        # Act
        processor = Processor(id="test-embeddings", concurrency=1, taskgroup=AsyncMock())
        # Assert
        # Should use default_model from module
        expected_default = "all-MiniLM-L6-v2"
        mock_hf_class.assert_called_once_with(model_name=expected_default)
        assert processor.default_model == expected_default
        assert processor.cached_model_name == expected_default
# Allow running this test module directly (outside the pytest CLI).
if __name__ == '__main__':
    pytest.main([__file__])

View file

@ -0,0 +1,167 @@
"""
Unit tests for Ollama dynamic model loading
Tests the dynamic model selection functionality for Ollama embeddings service.
Since Ollama is server-side, no model caching is needed on the client side.
"""
import pytest
from unittest.mock import Mock, patch, AsyncMock
from unittest import IsolatedAsyncioTestCase
from trustgraph.embeddings.ollama.processor import Processor
class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
    """Test Ollama dynamic model selection.

    Ollama hosts models server-side, so the client performs no local
    caching: each embed call simply names the model. Every test stubs the
    Ollama client and the base-class initialisers.
    """

    @patch('trustgraph.embeddings.ollama.processor.Client')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_client_initialized_with_host(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """Test that Ollama client is initialized with correct host"""
        # Arrange
        mock_ollama_client = Mock()
        mock_response = Mock()
        mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
        mock_ollama_client.embed.return_value = mock_response
        mock_client_class.return_value = mock_ollama_client
        mock_async_init.return_value = None
        mock_embeddings_init.return_value = None
        # Act
        processor = Processor(id="test", concurrency=1, model="test-model",
                              ollama="http://localhost:11434", taskgroup=AsyncMock())
        # Assert: the `ollama` argument is forwarded as the client host
        mock_client_class.assert_called_once_with(host="http://localhost:11434")
        assert processor.default_model == "test-model"

    @patch('trustgraph.embeddings.ollama.processor.Client')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """Test that on_embeddings uses default model when no model specified"""
        # Arrange
        mock_ollama_client = Mock()
        mock_response = Mock()
        mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
        mock_ollama_client.embed.return_value = mock_response
        mock_client_class.return_value = mock_ollama_client
        mock_async_init.return_value = None
        mock_embeddings_init.return_value = None
        processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
        # Act
        result = await processor.on_embeddings("test text")
        # Assert: the configured default model is named in the embed call
        mock_ollama_client.embed.assert_called_once_with(
            model="test-model",
            input="test text"
        )
        assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]

    @patch('trustgraph.embeddings.ollama.processor.Client')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """Test that on_embeddings uses specified model when provided"""
        # Arrange
        mock_ollama_client = Mock()
        mock_response = Mock()
        mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
        mock_ollama_client.embed.return_value = mock_response
        mock_client_class.return_value = mock_ollama_client
        mock_async_init.return_value = None
        mock_embeddings_init.return_value = None
        processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
        # Act
        result = await processor.on_embeddings("test text", model="custom-model")
        # Assert: the per-request override is named in the embed call
        mock_ollama_client.embed.assert_called_once_with(
            model="custom-model",
            input="test text"
        )
        assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]

    @patch('trustgraph.embeddings.ollama.processor.Client')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """Test switching between multiple models"""
        # Arrange
        mock_ollama_client = Mock()
        mock_response = Mock()
        mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
        mock_ollama_client.embed.return_value = mock_response
        mock_client_class.return_value = mock_ollama_client
        mock_async_init.return_value = None
        mock_embeddings_init.return_value = None
        processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
        # Act - switch between different models; server-side hosting means
        # every call simply names a model, with no client-side reload cost
        await processor.on_embeddings("text1", model="model-a")
        await processor.on_embeddings("text2", model="model-b")
        await processor.on_embeddings("text3", model="model-a")
        await processor.on_embeddings("text4")  # Use default
        # Assert: each call forwarded the model requested at the time
        calls = mock_ollama_client.embed.call_args_list
        assert len(calls) == 4
        assert calls[0][1]['model'] == "model-a"
        assert calls[1][1]['model'] == "model-b"
        assert calls[2][1]['model'] == "model-a"
        assert calls[3][1]['model'] == "test-model"  # Default

    @patch('trustgraph.embeddings.ollama.processor.Client')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """Test that None model parameter falls back to default"""
        # Arrange
        mock_ollama_client = Mock()
        mock_response = Mock()
        mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
        mock_ollama_client.embed.return_value = mock_response
        mock_client_class.return_value = mock_ollama_client
        mock_async_init.return_value = None
        mock_embeddings_init.return_value = None
        processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
        # Act
        result = await processor.on_embeddings("test text", model=None)
        # Assert
        mock_ollama_client.embed.assert_called_once_with(
            model="test-model",
            input="test text"
        )
        # Bug fix: `result` was previously computed but never checked, so a
        # broken embedding path would still pass. Verify output is produced.
        assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]

    @patch('trustgraph.embeddings.ollama.processor.Client')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
    async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class):
        """Test initialization without model parameter uses module default"""
        # Arrange
        mock_ollama_client = Mock()
        mock_client_class.return_value = mock_ollama_client
        mock_async_init.return_value = None
        mock_embeddings_init.return_value = None
        # Act
        processor = Processor(id="test-embeddings", concurrency=1, taskgroup=AsyncMock())
        # Assert
        # Should use default_model from module
        expected_default = "mxbai-embed-large"
        assert processor.default_model == expected_default
# Allow running this test module directly (outside the pytest CLI).
if __name__ == '__main__':
    pytest.main([__file__])

View file

@ -0,0 +1,148 @@
# Ontology Extractor Unit Tests
Comprehensive unit tests for the OntoRAG ontology extraction system.
## Test Coverage
### 1. `test_ontology_selector.py` - Auto-Include Properties Feature
Tests the critical dependency resolution that automatically includes all properties related to selected classes.
**Key Tests:**
- `test_auto_include_properties_for_recipe_class` - Verifies Recipe class auto-includes `ingredients`, `method`, `produces`, `serves`
- `test_auto_include_properties_for_ingredient_class` - Verifies Ingredient class auto-includes `food` property
- `test_auto_include_properties_for_range_class` - Tests properties are included when class appears in range
- `test_auto_include_adds_domain_and_range_classes` - Ensures related classes are added too
- `test_multiple_classes_get_all_related_properties` - Tests combining multiple class selections
- `test_no_duplicate_properties_added` - Ensures properties aren't duplicated
### 2. `test_uri_expansion.py` - URI Expansion
Tests that URIs are properly expanded using ontology definitions instead of constructed fallback URIs.
**Key Tests:**
- `test_expand_class_uri_from_ontology` - Class names expand to ontology URIs
- `test_expand_object_property_uri_from_ontology` - Object properties use ontology URIs
- `test_expand_datatype_property_uri_from_ontology` - Datatype properties use ontology URIs
- `test_expand_rdf_prefix` - Standard RDF prefixes expand correctly
- `test_expand_rdfs_prefix`, `test_expand_owl_prefix`, `test_expand_xsd_prefix` - Other standard prefixes
- `test_fallback_uri_for_instance` - Entity instances get constructed URIs
- `test_already_full_uri_unchanged` - Full URIs pass through
- `test_dict_access_not_object_attribute` - **Critical test** verifying dict access works (not object attributes)
### 3. `test_ontology_triples.py` - Ontology Triple Generation
Tests that ontology elements (classes and properties) are properly converted to RDF triples with labels, comments, domains, and ranges.
**Key Tests:**
- `test_generates_class_type_triples` - Classes get `rdf:type owl:Class` triples
- `test_generates_class_labels` - Classes get `rdfs:label` triples
- `test_generates_class_comments` - Classes get `rdfs:comment` triples
- `test_generates_object_property_type_triples` - Object properties get proper type triples
- `test_generates_object_property_labels` - Object properties get labels
- `test_generates_object_property_domain` - Object properties get `rdfs:domain` triples
- `test_generates_object_property_range` - Object properties get `rdfs:range` triples
- `test_generates_datatype_property_type_triples` - Datatype properties get proper type triples
- `test_generates_datatype_property_range` - Datatype properties get XSD type ranges
- `test_uses_dict_field_names_not_rdf_names` - **Critical test** verifying dict field names work
- `test_total_triple_count_is_reasonable` - Validates expected number of triples
### 4. `test_text_processing.py` - Text Processing and Segmentation
Tests that text is properly split into sentences for ontology matching, including NLTK tokenization and TextSegment creation.
**Key Tests:**
- `test_segment_single_sentence` - Single sentence produces one segment
- `test_segment_multiple_sentences` - Multiple sentences split correctly
- `test_segment_positions` - Segment start/end positions are correct
- `test_segment_complex_punctuation` - Handles abbreviations (Dr., U.S.A., Mr.)
- `test_segment_question_and_exclamation` - Different sentence terminators
- `test_segment_preserves_original_text` - Segments can reconstruct original
- `test_text_segment_non_overlapping` - Segments don't overlap
- `test_nltk_punkt_availability` - NLTK tokenizer is available
- `test_unicode_text` - Handles unicode characters
- `test_quoted_text` - Handles quoted text correctly
### 5. `test_prompt_and_extraction.py` - LLM Prompt Construction and Triple Extraction
Tests that the system correctly constructs prompts with ontology constraints and extracts/validates triples from LLM responses.
**Key Tests:**
- `test_build_extraction_variables_includes_text` - Prompt includes input text
- `test_build_extraction_variables_includes_classes` - Prompt includes ontology classes
- `test_build_extraction_variables_includes_properties` - Prompt includes properties
- `test_validates_rdf_type_triple_with_valid_class` - Validates rdf:type against ontology
- `test_rejects_rdf_type_triple_with_invalid_class` - Rejects invalid classes
- `test_validates_object_property_triple` - Validates object properties
- `test_rejects_unknown_property` - Rejects properties not in ontology
- `test_parse_simple_triple_dict` - Parses triple from dict format
- `test_filters_invalid_triples` - Filters out invalid triples
- `test_expands_uris_in_parsed_triples` - Expands URIs using ontology
- `test_creates_proper_triple_objects` - Creates Triple objects with Value subjects/predicates/objects
### 6. `test_embedding_and_similarity.py` - Ontology Embedding and Similarity Matching
Tests that ontology elements are properly embedded and matched against input text using vector similarity.
**Key Tests:**
- `test_create_text_from_class_with_id` - Text representation includes class ID
- `test_create_text_from_class_with_labels` - Includes labels in text
- `test_create_text_from_class_with_comment` - Includes comments in text
- `test_create_text_from_property_with_domain_range` - Includes domain/range in property text
- `test_normalizes_id_with_underscores` - Normalizes IDs (underscores to spaces)
- `test_includes_subclass_info_for_classes` - Includes subclass relationships
- `test_vector_store_api_structure` - Vector store has expected API
- `test_selector_handles_text_segments` - Selector processes text segments
- `test_merge_subsets_combines_elements` - Merging combines ontology elements
- `test_ontology_element_metadata_structure` - Metadata structure is correct
## Running the Tests
### Run all ontology extractor tests:
```bash
cd <path-to-your-trustgraph-checkout>
pytest tests/unit/test_extract/test_ontology/ -v
```
### Run specific test file:
```bash
pytest tests/unit/test_extract/test_ontology/test_ontology_selector.py -v
pytest tests/unit/test_extract/test_ontology/test_uri_expansion.py -v
pytest tests/unit/test_extract/test_ontology/test_ontology_triples.py -v
pytest tests/unit/test_extract/test_ontology/test_text_processing.py -v
pytest tests/unit/test_extract/test_ontology/test_prompt_and_extraction.py -v
pytest tests/unit/test_extract/test_ontology/test_embedding_and_similarity.py -v
```
### Run specific test:
```bash
pytest tests/unit/test_extract/test_ontology/test_ontology_selector.py::TestOntologySelector::test_auto_include_properties_for_recipe_class -v
```
### Run with coverage:
```bash
pytest tests/unit/test_extract/test_ontology/ --cov=trustgraph.extract.kg.ontology --cov-report=html
```
## Test Fixtures
- `sample_ontology` - Complete Food Ontology with Recipe, Ingredient, Food, Method classes
- `ontology_loader_with_sample` - Mock OntologyLoader with the sample ontology
- `ontology_embedder` - Mock embedder for testing
- `mock_embedding_service` - Mock service for generating deterministic embeddings
- `vector_store` - InMemoryVectorStore for testing
- `extractor` - Processor instance for URI expansion tests
- `ontology_subset_with_uris` - OntologySubset with proper URIs defined
- `sample_ontology_subset` - OntologySubset for testing triple generation
- `text_processor` - TextProcessor instance for text segmentation tests
- `sample_ontology_class` - Sample OntologyClass for testing
- `sample_ontology_property` - Sample OntologyProperty for testing
## Implementation Notes
These tests verify the fixes made to address:
1. **Disconnected graph problem** - Auto-include properties feature ensures all relevant relationships are available
2. **Wrong URIs problem** - URI expansion using ontology definitions instead of constructed fallbacks
3. **Dict vs object attribute problem** - URI expansion works with dicts (from `cls.__dict__`) not object attributes
4. **Ontology visibility in KG** - Ontology elements themselves appear in the knowledge graph with proper metadata
5. **Text segmentation** - Proper sentence splitting for ontology matching using NLTK

View file

@ -0,0 +1 @@
"""Tests for ontology-based extraction."""

View file

@ -0,0 +1,423 @@
"""
Unit tests for ontology embedding and similarity matching.
Tests that ontology elements are properly embedded and matched against
input text using vector similarity.
"""
import pytest
import numpy as np
from unittest.mock import AsyncMock, MagicMock
from trustgraph.extract.kg.ontology.ontology_embedder import (
OntologyEmbedder,
OntologyElementMetadata
)
from trustgraph.extract.kg.ontology.ontology_loader import (
Ontology,
OntologyClass,
OntologyProperty
)
from trustgraph.extract.kg.ontology.vector_store import InMemoryVectorStore, SearchResult
from trustgraph.extract.kg.ontology.text_processor import TextSegment
from trustgraph.extract.kg.ontology.ontology_selector import OntologySelector, OntologySubset
@pytest.fixture
def mock_embedding_service():
    """Build an embedding-service stand-in with a deterministic embed()."""
    async def fake_embed(text):
        # Derive a stable 2-d vector from the text's hash so repeated
        # calls with the same input agree within a test run.
        bucket = hash(text) % 1000
        return np.array([bucket / 1000.0, (1000 - bucket) / 1000.0])

    stub = AsyncMock()
    stub.embed = fake_embed
    return stub
@pytest.fixture
def vector_store():
    """Create an empty vector store."""
    # A fresh in-memory store per test, so no state leaks between tests.
    return InMemoryVectorStore()
@pytest.fixture
def ontology_embedder(mock_embedding_service, vector_store):
    """Create an ontology embedder with mock service."""
    # Wires the deterministic mock embedding service to an empty store.
    return OntologyEmbedder(
        embedding_service=mock_embedding_service,
        vector_store=vector_store
    )
@pytest.fixture
def sample_ontology_class():
    """Create a sample ontology class."""
    # A Recipe class drawn from the BBC Food Ontology; used by the text
    # representation tests to check IDs, labels and comments appear.
    return OntologyClass(
        uri="http://purl.org/ontology/fo/Recipe",
        type="owl:Class",
        labels=[{"value": "Recipe", "lang": "en-gb"}],
        comment="A Recipe is a combination of ingredients and a method.",
        subclass_of=None
    )
@pytest.fixture
def sample_ontology_property():
    """Create a sample ontology property."""
    # An object property with explicit domain/range so the tests can check
    # both ends are surfaced in the text representation.
    return OntologyProperty(
        uri="http://purl.org/ontology/fo/ingredients",
        type="owl:ObjectProperty",
        labels=[{"value": "ingredients", "lang": "en-gb"}],
        comment="The ingredients property relates a recipe to an ingredient list.",
        domain="Recipe",
        range="IngredientList"
    )
class TestTextRepresentation:
    """Test suite for creating text representations of ontology elements.

    _create_text_representation turns an ontology element (class/property)
    into a natural-language string used for embedding; these tests pin down
    which fields must appear in that string.
    """

    def test_create_text_from_class_with_id(self, ontology_embedder, sample_ontology_class):
        """Test that class ID is included in text representation."""
        text = ontology_embedder._create_text_representation(
            "Recipe",
            sample_ontology_class,
            "class"
        )
        assert "Recipe" in text, "Should include class ID"

    def test_create_text_from_class_with_labels(self, ontology_embedder, sample_ontology_class):
        """Test that class labels are included in text representation."""
        # NOTE(review): the fixture's ID and label are both "Recipe", so this
        # assertion cannot distinguish label inclusion from ID inclusion.
        text = ontology_embedder._create_text_representation(
            "Recipe",
            sample_ontology_class,
            "class"
        )
        assert "Recipe" in text, "Should include label value"

    def test_create_text_from_class_with_comment(self, ontology_embedder, sample_ontology_class):
        """Test that class comments are included in text representation."""
        text = ontology_embedder._create_text_representation(
            "Recipe",
            sample_ontology_class,
            "class"
        )
        assert "combination of ingredients" in text, "Should include comment"

    def test_create_text_from_property_with_domain_range(self, ontology_embedder, sample_ontology_property):
        """Test that property domain and range are included in text."""
        text = ontology_embedder._create_text_representation(
            "ingredients",
            sample_ontology_property,
            "objectProperty"
        )
        assert "domain: Recipe" in text, "Should include domain"
        assert "range: IngredientList" in text, "Should include range"

    def test_normalizes_id_with_underscores(self, ontology_embedder):
        """Test that IDs with underscores are normalized."""
        # Bare MagicMock element with no labels/comment isolates the
        # ID-normalisation behaviour.
        mock_element = MagicMock()
        mock_element.labels = []
        mock_element.comment = None
        text = ontology_embedder._create_text_representation(
            "some_property_name",
            mock_element,
            "objectProperty"
        )
        assert "some property name" in text, "Should replace underscores with spaces"

    def test_normalizes_id_with_hyphens(self, ontology_embedder):
        """Test that IDs with hyphens are normalized."""
        mock_element = MagicMock()
        mock_element.labels = []
        mock_element.comment = None
        text = ontology_embedder._create_text_representation(
            "some-property-name",
            mock_element,
            "objectProperty"
        )
        assert "some property name" in text, "Should replace hyphens with spaces"

    def test_handles_element_without_labels(self, ontology_embedder):
        """Test handling of elements without labels."""
        # labels=None (rather than []) checks the missing-labels path.
        mock_element = MagicMock()
        mock_element.labels = None
        mock_element.comment = "Test comment"
        text = ontology_embedder._create_text_representation(
            "TestElement",
            mock_element,
            "class"
        )
        assert "TestElement" in text, "Should still include ID"
        assert "Test comment" in text, "Should include comment"

    def test_includes_subclass_info_for_classes(self, ontology_embedder):
        """Test that subclass information is included for classes."""
        mock_class = MagicMock()
        mock_class.labels = []
        mock_class.comment = None
        mock_class.subclass_of = "ParentClass"
        text = ontology_embedder._create_text_representation(
            "ChildClass",
            mock_class,
            "class"
        )
        assert "subclass of ParentClass" in text, "Should include subclass relationship"
class TestVectorStoreOperations:
    """Test suite for vector store operations."""

    def test_vector_store_starts_empty(self, vector_store):
        """Test that vector store initializes empty."""
        assert vector_store.size() == 0, "New vector store should be empty"

    def test_vector_store_api_structure(self, vector_store):
        """Test that vector store has expected API methods."""
        expected_api = (
            ("add", "Should have add method"),
            ("add_batch", "Should have add_batch method"),
            ("search", "Should have search method"),
            ("size", "Should have size method"),
        )
        for method_name, message in expected_api:
            assert hasattr(vector_store, method_name), message

    def test_search_result_class_structure(self):
        """Test that SearchResult has expected structure."""
        # Create a sample SearchResult and verify both attribute presence
        # and the round-tripped values.
        sample = SearchResult(id="test-1", score=0.95, metadata={"element": "Test"})
        for attr_name, message in (
            ("id", "Should have id attribute"),
            ("score", "Should have score attribute"),
            ("metadata", "Should have metadata attribute"),
        ):
            assert hasattr(sample, attr_name), message
        assert sample.id == "test-1"
        assert sample.score == 0.95
        assert sample.metadata["element"] == "Test"
class TestOntologySelectorIntegration:
    """Test suite for ontology selector with embeddings.

    NOTE(review): the async tests below carry no explicit asyncio marker —
    presumably the project runs pytest-asyncio in auto mode; confirm.
    """

    @pytest.fixture
    def sample_ontology(self):
        """Create a sample ontology for testing."""
        # Minimal two-class / one-property ontology used by the loader mock.
        return Ontology(
            id="food",
            classes={
                "Recipe": OntologyClass(
                    uri="http://purl.org/ontology/fo/Recipe",
                    type="owl:Class",
                    labels=[{"value": "Recipe", "lang": "en-gb"}],
                    comment="A Recipe is a combination of ingredients and a method."
                ),
                "Ingredient": OntologyClass(
                    uri="http://purl.org/ontology/fo/Ingredient",
                    type="owl:Class",
                    labels=[{"value": "Ingredient", "lang": "en-gb"}],
                    comment="An Ingredient combines a quantity and a food."
                )
            },
            object_properties={
                "ingredients": OntologyProperty(
                    uri="http://purl.org/ontology/fo/ingredients",
                    type="owl:ObjectProperty",
                    labels=[{"value": "ingredients", "lang": "en-gb"}],
                    comment="Relates a recipe to its ingredients.",
                    domain="Recipe",
                    range="IngredientList"
                )
            },
            datatype_properties={},
            metadata={"name": "Food Ontology"}
        )

    @pytest.fixture
    def ontology_loader_mock(self, sample_ontology):
        """Create a mock ontology loader."""
        # Always returns the sample ontology regardless of the requested ID.
        loader = MagicMock()
        loader.get_ontology.return_value = sample_ontology
        loader.get_all_ontology_ids.return_value = ["food"]
        return loader

    async def test_selector_handles_text_segments(
        self, ontology_embedder, ontology_loader_mock
    ):
        """Test that selector can process text segments."""
        # Create selector
        selector = OntologySelector(
            ontology_embedder=ontology_embedder,
            ontology_loader=ontology_loader_mock,
            top_k=5,
            similarity_threshold=0.3
        )
        # Create text segments
        segments = [
            TextSegment(text="Recipe for cornish pasty", type="sentence", position=0),
            TextSegment(text="ingredients needed", type="sentence", position=1)
        ]
        # Select ontology subset (will be empty since we haven't embedded anything)
        subsets = await selector.select_ontology_subset(segments)
        # Should return a list (even if empty)
        assert isinstance(subsets, list), "Should return a list of subsets"

    async def test_selector_with_no_embedding_service(self, vector_store, ontology_loader_mock):
        """Test that selector handles missing embedding service gracefully."""
        # embedding_service=None exercises the selector's degraded path.
        embedder = OntologyEmbedder(embedding_service=None, vector_store=vector_store)
        selector = OntologySelector(
            ontology_embedder=embedder,
            ontology_loader=ontology_loader_mock,
            top_k=5,
            similarity_threshold=0.7
        )
        segments = [
            TextSegment(text="Test text", type="sentence", position=0)
        ]
        # Should return empty results without crashing
        subsets = await selector.select_ontology_subset(segments)
        assert isinstance(subsets, list), "Should return a list even without embeddings"

    def test_merge_subsets_combines_elements(self, ontology_loader_mock, ontology_embedder):
        """Test that merging subsets combines all elements."""
        selector = OntologySelector(
            ontology_embedder=ontology_embedder,
            ontology_loader=ontology_loader_mock,
            top_k=5,
            similarity_threshold=0.7
        )
        # Create two subsets from same ontology; merging should union
        # their classes and properties.
        subset1 = OntologySubset(
            ontology_id="food",
            classes={"Recipe": {"uri": "http://example.com/Recipe"}},
            object_properties={},
            datatype_properties={},
            metadata={},
            relevance_score=0.8
        )
        subset2 = OntologySubset(
            ontology_id="food",
            classes={"Ingredient": {"uri": "http://example.com/Ingredient"}},
            object_properties={"ingredients": {"uri": "http://example.com/ingredients"}},
            datatype_properties={},
            metadata={},
            relevance_score=0.9
        )
        merged = selector.merge_subsets([subset1, subset2])
        assert len(merged.classes) == 2, "Should combine classes"
        # Keys may be prefixed with ontology id
        assert any("Recipe" in key for key in merged.classes.keys())
        assert any("Ingredient" in key for key in merged.classes.keys())
        assert len(merged.object_properties) == 1, "Should include properties"
class TestEmbeddingEdgeCases:
    """Test suite for edge cases in embedding."""

    # BUG FIX: the first two tests were declared `async def` but contain no
    # awaits — _create_text_representation is called synchronously. Async
    # tests that never await can be silently skipped (or emit "coroutine was
    # never awaited" warnings) depending on the asyncio plugin mode, so they
    # are plain functions now. Behavior of the assertions is unchanged.

    def test_embed_element_with_no_labels(self, ontology_embedder):
        """Test embedding element without labels."""
        mock_element = MagicMock()
        mock_element.labels = None
        mock_element.comment = "Test element"
        text = ontology_embedder._create_text_representation(
            "TestElement",
            mock_element,
            "class"
        )
        # Should not crash and should include ID and comment
        assert "TestElement" in text
        assert "Test element" in text

    def test_embed_element_with_empty_comment(self, ontology_embedder):
        """Test embedding element with empty comment."""
        mock_element = MagicMock()
        mock_element.labels = [{"value": "Label"}]
        mock_element.comment = None
        text = ontology_embedder._create_text_representation(
            "TestElement",
            mock_element,
            "class"
        )
        # Should not crash
        assert "Label" in text

    def test_ontology_element_metadata_structure(self):
        """Test OntologyElementMetadata structure."""
        metadata = OntologyElementMetadata(
            type="class",
            ontology="food",
            element="Recipe",
            definition={"uri": "http://example.com/Recipe"},
            text="Recipe A combination of ingredients"
        )
        assert metadata.type == "class"
        assert metadata.ontology == "food"
        assert metadata.element == "Recipe"
        assert "uri" in metadata.definition

    def test_vector_store_search_on_empty_store(self):
        """Test searching empty vector store."""
        # Need a non-empty store for faiss to work
        # This test verifies the store can be created but searching requires dimension
        store = InMemoryVectorStore()
        assert store.size() == 0, "Empty store should have size 0"
class TestOntologySubsetStructure:
    """Test suite for OntologySubset structure."""

    def test_ontology_subset_creation(self):
        """Test creating an OntologySubset."""
        populated = OntologySubset(
            ontology_id="test",
            classes={"Recipe": {}},
            object_properties={"produces": {}},
            datatype_properties={"serves": {}},
            metadata={"name": "Test"},
            relevance_score=0.85
        )
        assert populated.ontology_id == "test"
        # One element in each section, score as given.
        assert len(populated.classes) == 1
        assert len(populated.object_properties) == 1
        assert len(populated.datatype_properties) == 1
        assert populated.relevance_score == 0.85

    def test_ontology_subset_default_score(self):
        """Test that OntologySubset has default score."""
        empty = OntologySubset(
            ontology_id="test",
            classes={},
            object_properties={},
            datatype_properties={},
            metadata={}
        )
        assert empty.relevance_score == 0.0, "Should have default score of 0.0"
# Allow running this test module directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View file

@ -0,0 +1,353 @@
"""
Unit tests for entity context building.
Tests that entity contexts are properly created from extracted triples,
collecting labels and definitions for entity embedding and retrieval.
"""
import pytest
from trustgraph.extract.kg.ontology.extract import Processor
from trustgraph.schema.core.primitives import Triple, Value
from trustgraph.schema.knowledge.graph import EntityContext
@pytest.fixture
def processor():
    """Create a Processor instance for testing."""
    # Bypass __init__ deliberately: only build_entity_contexts is exercised,
    # so no messaging/service wiring is needed.
    return object.__new__(Processor)
class TestEntityContextBuilding:
    """Test suite for entity context building from triples."""

    # Predicate URIs shared by the tests below.
    LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
    DEFINITION = "http://www.w3.org/2004/02/skos/core#definition"
    DESCRIPTION = "https://schema.org/description"
    RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"

    @staticmethod
    def _triple(subject, predicate, obj, obj_is_uri=False):
        """Build a Triple with URI subject/predicate and a configurable object."""
        return Triple(
            s=Value(value=subject, is_uri=True),
            p=Value(value=predicate, is_uri=True),
            o=Value(value=obj, is_uri=obj_is_uri),
        )

    def test_builds_context_from_label(self, processor):
        """Test that entity context is built from rdfs:label."""
        entity = "https://example.com/entity/cornish-pasty"
        contexts = processor.build_entity_contexts(
            [self._triple(entity, self.LABEL, "Cornish Pasty")]
        )
        assert len(contexts) == 1, "Should create one entity context"
        assert isinstance(contexts[0], EntityContext)
        assert contexts[0].entity.value == entity
        assert "Label: Cornish Pasty" in contexts[0].context

    def test_builds_context_from_definition(self, processor):
        """Test that entity context includes definitions."""
        contexts = processor.build_entity_contexts([
            self._triple(
                "https://example.com/entity/pasty",
                self.DEFINITION,
                "A baked pastry filled with savory ingredients",
            )
        ])
        assert len(contexts) == 1
        assert "A baked pastry filled with savory ingredients" in contexts[0].context

    def test_combines_label_and_definition(self, processor):
        """Test that label and definition are combined in context."""
        entity = "https://example.com/entity/recipe1"
        contexts = processor.build_entity_contexts([
            self._triple(entity, self.LABEL, "Pasty Recipe"),
            self._triple(entity, self.DEFINITION, "Traditional Cornish pastry recipe"),
        ])
        assert len(contexts) == 1
        context_text = contexts[0].context
        assert "Label: Pasty Recipe" in context_text
        assert "Traditional Cornish pastry recipe" in context_text
        assert ". " in context_text, "Should join parts with period and space"

    def test_uses_first_label_only(self, processor):
        """Test that only the first label is used in context."""
        entity = "https://example.com/entity/food1"
        contexts = processor.build_entity_contexts([
            self._triple(entity, self.LABEL, "First Label"),
            self._triple(entity, self.LABEL, "Second Label"),
        ])
        assert len(contexts) == 1
        assert "Label: First Label" in contexts[0].context
        assert "Second Label" not in contexts[0].context

    def test_includes_all_definitions(self, processor):
        """Test that all definitions are included in context."""
        entity = "https://example.com/entity/food1"
        contexts = processor.build_entity_contexts([
            self._triple(entity, self.DEFINITION, "First definition"),
            self._triple(entity, self.DEFINITION, "Second definition"),
        ])
        assert len(contexts) == 1
        context_text = contexts[0].context
        assert "First definition" in context_text
        assert "Second definition" in context_text

    def test_supports_schema_org_description(self, processor):
        """Test that schema.org description is treated as definition."""
        contexts = processor.build_entity_contexts([
            self._triple(
                "https://example.com/entity/food1",
                self.DESCRIPTION,
                "A delicious food item",
            )
        ])
        assert len(contexts) == 1
        assert "A delicious food item" in contexts[0].context

    def test_handles_multiple_entities(self, processor):
        """Test that contexts are created for multiple entities."""
        names = ("One", "Two", "Three")
        triples = [
            self._triple(
                f"https://example.com/entity/entity{i}",
                self.LABEL,
                f"Entity {name}",
            )
            for i, name in enumerate(names, start=1)
        ]
        contexts = processor.build_entity_contexts(triples)
        assert len(contexts) == 3, "Should create context for each entity"
        entity_uris = [ctx.entity.value for ctx in contexts]
        for i in (1, 2, 3):
            assert f"https://example.com/entity/entity{i}" in entity_uris

    def test_ignores_uri_literals(self, processor):
        """Test that URI objects are ignored (only literal labels/definitions)."""
        contexts = processor.build_entity_contexts([
            # Label object is a URI, not a literal, so it must be ignored.
            self._triple(
                "https://example.com/entity/food1",
                self.LABEL,
                "https://example.com/some/uri",
                obj_is_uri=True,
            )
        ])
        assert len(contexts) == 0, "Should not create context for URI labels"

    def test_ignores_non_label_non_definition_triples(self, processor):
        """Test that other predicates are ignored."""
        entity = "https://example.com/entity/food1"
        contexts = processor.build_entity_contexts([
            self._triple(entity, self.RDF_TYPE, "http://example.com/Food", obj_is_uri=True),
            self._triple(
                entity,
                "http://example.com/produces",
                "https://example.com/entity/food2",
                obj_is_uri=True,
            ),
        ])
        assert len(contexts) == 0, "Should not create context without labels/definitions"

    def test_handles_empty_triple_list(self, processor):
        """Test handling of empty triple list."""
        contexts = processor.build_entity_contexts([])
        assert len(contexts) == 0, "Empty triple list should return empty contexts"

    def test_entity_context_has_value_object(self, processor):
        """Test that EntityContext.entity is a Value object."""
        contexts = processor.build_entity_contexts([
            self._triple("https://example.com/entity/test", self.LABEL, "Test Entity")
        ])
        assert len(contexts) == 1
        assert isinstance(contexts[0].entity, Value), "Entity should be Value object"
        assert contexts[0].entity.is_uri, "Entity should be marked as URI"

    def test_entity_context_text_is_string(self, processor):
        """Test that EntityContext.context is a string."""
        contexts = processor.build_entity_contexts([
            self._triple("https://example.com/entity/test", self.LABEL, "Test Entity")
        ])
        assert len(contexts) == 1
        assert isinstance(contexts[0].context, str), "Context should be string"

    def test_only_creates_contexts_with_meaningful_info(self, processor):
        """Test that contexts are only created when there's meaningful information."""
        contexts = processor.build_entity_contexts([
            # Entity with label - should create context
            self._triple(
                "https://example.com/entity/entity1", self.LABEL, "Entity One"
            ),
            # Entity with only rdf:type - should NOT create context
            self._triple(
                "https://example.com/entity/entity2",
                self.RDF_TYPE,
                "http://example.com/Food",
                obj_is_uri=True,
            ),
        ])
        assert len(contexts) == 1, "Should only create context for entity with label/definition"
        assert contexts[0].entity.value == "https://example.com/entity/entity1"
class TestEntityContextEdgeCases:
    """Test suite for edge cases in entity context building."""

    def test_handles_unicode_in_labels(self, processor):
        """Test handling of unicode characters in labels."""
        triples = [
            Triple(
                s=Value(value="https://example.com/entity/café", is_uri=True),
                p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
                o=Value(value="Café Spécial", is_uri=False)
            )
        ]
        contexts = processor.build_entity_contexts(triples)
        assert len(contexts) == 1
        assert "Café Spécial" in contexts[0].context

    def test_handles_long_definitions(self, processor):
        """Test handling of very long definitions."""
        long_def = "This is a very long definition " * 50
        triples = [
            Triple(
                s=Value(value="https://example.com/entity/test", is_uri=True),
                p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
                o=Value(value=long_def, is_uri=False)
            )
        ]
        contexts = processor.build_entity_contexts(triples)
        assert len(contexts) == 1
        # Definition must be carried through untruncated.
        assert long_def in contexts[0].context

    def test_handles_special_characters_in_context(self, processor):
        """Test handling of special characters in context text."""
        triples = [
            Triple(
                s=Value(value="https://example.com/entity/test", is_uri=True),
                p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
                o=Value(value="Test & Entity <with> \"quotes\"", is_uri=False)
            )
        ]
        contexts = processor.build_entity_contexts(triples)
        assert len(contexts) == 1
        # No HTML-escaping or stripping should occur.
        assert "Test & Entity <with> \"quotes\"" in contexts[0].context

    def test_mixed_relevant_and_irrelevant_triples(self, processor):
        """Test extracting contexts from mixed triple types."""
        triples = [
            # Label - relevant
            Triple(
                s=Value(value="https://example.com/entity/recipe1", is_uri=True),
                p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
                o=Value(value="Cornish Pasty Recipe", is_uri=False)
            ),
            # Type - irrelevant
            Triple(
                s=Value(value="https://example.com/entity/recipe1", is_uri=True),
                p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
                o=Value(value="http://example.com/Recipe", is_uri=True)
            ),
            # Property - irrelevant
            Triple(
                s=Value(value="https://example.com/entity/recipe1", is_uri=True),
                p=Value(value="http://example.com/produces", is_uri=True),
                o=Value(value="https://example.com/entity/pasty", is_uri=True)
            ),
            # Definition - relevant
            Triple(
                s=Value(value="https://example.com/entity/recipe1", is_uri=True),
                p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
                o=Value(value="Traditional British pastry recipe", is_uri=False)
            )
        ]
        contexts = processor.build_entity_contexts(triples)
        assert len(contexts) == 1
        context_text = contexts[0].context
        # Should include label and definition
        assert "Label: Cornish Pasty Recipe" in context_text
        assert "Traditional British pastry recipe" in context_text
        # Should not include type or property info.
        # BUG FIX: the original assertion
        #   assert "Recipe" not in context_text or "Cornish Pasty Recipe" in context_text
        # was vacuous — its second disjunct is always true because the label
        # itself contains "Cornish Pasty Recipe". Assert on the exact URIs of
        # the irrelevant triples instead, which can never appear via the label.
        assert "http://example.com/Recipe" not in context_text
        assert "produces" not in context_text
        assert "https://example.com/entity/pasty" not in context_text
# Allow running this test module directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View file

@ -0,0 +1,518 @@
"""
Unit tests for ontology loading and configuration.
Tests that ontologies are properly loaded from configuration,
parsed, validated, and managed by the OntologyLoader.
"""
import pytest
from trustgraph.extract.kg.ontology.ontology_loader import (
OntologyLoader,
Ontology,
OntologyClass,
OntologyProperty
)
@pytest.fixture
def ontology_loader():
    """Create an OntologyLoader instance."""
    # A fresh, empty loader per test.
    loader = OntologyLoader()
    return loader
@pytest.fixture
def sample_ontology_config():
    """Create a sample ontology configuration.

    Mirrors the config shape OntologyLoader consumes: one top-level entry per
    ontology ID, each with "metadata", "classes", "objectProperties" and
    "datatypeProperties" sections keyed with RDF-style attribute names
    ("rdfs:label", "rdfs:comment", "rdfs:domain", ...).

    Note: "Food" declares rdfs:subClassOf "EdibleThing" and "ingredients"
    ranges over "IngredientList", neither of which is declared here — some
    validation tests rely on these dangling references.
    """
    return {
        "food": {
            "metadata": {
                "name": "Food Ontology",
                "namespace": "http://purl.org/ontology/fo/"
            },
            "classes": {
                "Recipe": {
                    "uri": "http://purl.org/ontology/fo/Recipe",
                    "type": "owl:Class",
                    "rdfs:label": [{"value": "Recipe", "lang": "en-gb"}],
                    "rdfs:comment": "A Recipe is a combination of ingredients and a method."
                },
                "Ingredient": {
                    "uri": "http://purl.org/ontology/fo/Ingredient",
                    "type": "owl:Class",
                    "rdfs:label": [{"value": "Ingredient", "lang": "en-gb"}],
                    "rdfs:comment": "An Ingredient combines a quantity and a food."
                },
                "Food": {
                    "uri": "http://purl.org/ontology/fo/Food",
                    "type": "owl:Class",
                    "rdfs:label": [{"value": "Food", "lang": "en-gb"}],
                    "rdfs:comment": "A Food is something that can be eaten.",
                    "rdfs:subClassOf": "EdibleThing"
                }
            },
            "objectProperties": {
                "ingredients": {
                    "uri": "http://purl.org/ontology/fo/ingredients",
                    "type": "owl:ObjectProperty",
                    "rdfs:label": [{"value": "ingredients", "lang": "en-gb"}],
                    "rdfs:domain": "Recipe",
                    "rdfs:range": "IngredientList"
                },
                "produces": {
                    "uri": "http://purl.org/ontology/fo/produces",
                    "type": "owl:ObjectProperty",
                    "rdfs:label": [{"value": "produces", "lang": "en-gb"}],
                    "rdfs:domain": "Recipe",
                    "rdfs:range": "Food"
                }
            },
            "datatypeProperties": {
                "serves": {
                    "uri": "http://purl.org/ontology/fo/serves",
                    "type": "owl:DatatypeProperty",
                    "rdfs:label": [{"value": "serves", "lang": "en-gb"}],
                    "rdfs:domain": "Recipe",
                    "rdfs:range": "xsd:string"
                }
            }
        }
    }
class TestOntologyLoaderInitialization:
    """Test suite for OntologyLoader initialization."""

    def test_loader_starts_empty(self, ontology_loader):
        """Test that loader initializes with no ontologies."""
        loaded = ontology_loader.get_all_ontologies()
        assert len(loaded) == 0

    def test_loader_get_nonexistent_ontology(self, ontology_loader):
        """Test getting non-existent ontology returns None."""
        assert ontology_loader.get_ontology("nonexistent") is None
class TestOntologyLoading:
    """Test suite for loading ontologies from configuration."""

    @staticmethod
    def _loaded(loader, config, ontology_id="food"):
        """Load config into the loader and return one ontology by ID."""
        loader.update_ontologies(config)
        return loader.get_ontology(ontology_id)

    def test_loads_single_ontology(self, ontology_loader, sample_ontology_config):
        """Test loading a single ontology."""
        ontology_loader.update_ontologies(sample_ontology_config)
        ontologies = ontology_loader.get_all_ontologies()
        assert len(ontologies) == 1
        assert "food" in ontologies

    def test_loaded_ontology_has_correct_id(self, ontology_loader, sample_ontology_config):
        """Test that loaded ontology has correct ID."""
        ontology = self._loaded(ontology_loader, sample_ontology_config)
        assert ontology is not None
        assert ontology.id == "food"

    def test_loaded_ontology_has_metadata(self, ontology_loader, sample_ontology_config):
        """Test that loaded ontology includes metadata."""
        ontology = self._loaded(ontology_loader, sample_ontology_config)
        assert ontology.metadata["name"] == "Food Ontology"
        assert ontology.metadata["namespace"] == "http://purl.org/ontology/fo/"

    def test_loaded_ontology_has_classes(self, ontology_loader, sample_ontology_config):
        """Test that loaded ontology includes classes."""
        ontology = self._loaded(ontology_loader, sample_ontology_config)
        assert len(ontology.classes) == 3
        for class_name in ("Recipe", "Ingredient", "Food"):
            assert class_name in ontology.classes

    def test_loaded_classes_have_correct_properties(self, ontology_loader, sample_ontology_config):
        """Test that loaded classes have correct properties."""
        ontology = self._loaded(ontology_loader, sample_ontology_config)
        recipe = ontology.get_class("Recipe")
        assert isinstance(recipe, OntologyClass)
        assert recipe.uri == "http://purl.org/ontology/fo/Recipe"
        assert recipe.type == "owl:Class"
        assert len(recipe.labels) == 1
        assert recipe.labels[0]["value"] == "Recipe"
        assert "combination of ingredients" in recipe.comment

    def test_loaded_ontology_has_object_properties(self, ontology_loader, sample_ontology_config):
        """Test that loaded ontology includes object properties."""
        ontology = self._loaded(ontology_loader, sample_ontology_config)
        assert len(ontology.object_properties) == 2
        assert "ingredients" in ontology.object_properties
        assert "produces" in ontology.object_properties

    def test_loaded_properties_have_domain_and_range(self, ontology_loader, sample_ontology_config):
        """Test that loaded properties have domain and range."""
        ontology = self._loaded(ontology_loader, sample_ontology_config)
        produces = ontology.get_property("produces")
        assert isinstance(produces, OntologyProperty)
        assert produces.domain == "Recipe"
        assert produces.range == "Food"

    def test_loaded_ontology_has_datatype_properties(self, ontology_loader, sample_ontology_config):
        """Test that loaded ontology includes datatype properties."""
        ontology = self._loaded(ontology_loader, sample_ontology_config)
        assert len(ontology.datatype_properties) == 1
        assert "serves" in ontology.datatype_properties

    def test_loads_multiple_ontologies(self, ontology_loader):
        """Test loading multiple ontologies."""
        config = {
            oid: {
                "metadata": {"name": display_name},
                "classes": {cls: {"uri": f"http://example.com/{cls}"}},
                "objectProperties": {},
                "datatypeProperties": {},
            }
            for oid, display_name, cls in (
                ("food", "Food Ontology", "Recipe"),
                ("music", "Music Ontology", "Song"),
            )
        }
        ontology_loader.update_ontologies(config)
        ontologies = ontology_loader.get_all_ontologies()
        assert len(ontologies) == 2
        assert "food" in ontologies
        assert "music" in ontologies

    def test_update_replaces_existing_ontologies(self, ontology_loader, sample_ontology_config):
        """Test that update replaces existing ontologies."""
        # Load initial ontologies
        ontology_loader.update_ontologies(sample_ontology_config)
        assert len(ontology_loader.get_all_ontologies()) == 1
        # Update with different config
        replacement = {
            "music": {
                "metadata": {"name": "Music Ontology"},
                "classes": {},
                "objectProperties": {},
                "datatypeProperties": {},
            }
        }
        ontology_loader.update_ontologies(replacement)
        # Old ontologies should be replaced
        ontologies = ontology_loader.get_all_ontologies()
        assert len(ontologies) == 1
        assert "music" in ontologies
        assert "food" not in ontologies
class TestOntologyRetrieval:
    """Test suite for retrieving ontologies."""

    @staticmethod
    def _empty_config(*ontology_ids):
        """Build a minimal (all-empty) ontology config for each given ID."""
        return {
            oid: {
                "metadata": {},
                "classes": {},
                "objectProperties": {},
                "datatypeProperties": {},
            }
            for oid in ontology_ids
        }

    def test_get_ontology_by_id(self, ontology_loader, sample_ontology_config):
        """Test retrieving ontology by ID."""
        ontology_loader.update_ontologies(sample_ontology_config)
        ontology = ontology_loader.get_ontology("food")
        assert ontology is not None
        assert isinstance(ontology, Ontology)

    def test_get_all_ontologies(self, ontology_loader):
        """Test retrieving all ontologies."""
        ontology_loader.update_ontologies(self._empty_config("food", "music"))
        ontologies = ontology_loader.get_all_ontologies()
        assert isinstance(ontologies, dict)
        assert len(ontologies) == 2

    def test_get_all_ontology_ids(self, ontology_loader):
        """Test retrieving all ontology IDs."""
        ontology_loader.update_ontologies(self._empty_config("food", "music"))
        ids = list(ontology_loader.get_all_ontologies().keys())
        assert len(ids) == 2
        assert "food" in ids
        assert "music" in ids
class TestOntologyClassMethods:
    """Test suite for Ontology helper methods."""

    @staticmethod
    def _food(loader, config):
        """Load the sample config and return the food ontology."""
        loader.update_ontologies(config)
        return loader.get_ontology("food")

    def test_get_class(self, ontology_loader, sample_ontology_config):
        """Test getting a class from ontology."""
        ontology = self._food(ontology_loader, sample_ontology_config)
        recipe = ontology.get_class("Recipe")
        assert recipe is not None
        assert recipe.uri == "http://purl.org/ontology/fo/Recipe"

    def test_get_nonexistent_class(self, ontology_loader, sample_ontology_config):
        """Test getting non-existent class returns None."""
        ontology = self._food(ontology_loader, sample_ontology_config)
        assert ontology.get_class("NonExistent") is None

    def test_get_property(self, ontology_loader, sample_ontology_config):
        """Test getting a property from ontology."""
        ontology = self._food(ontology_loader, sample_ontology_config)
        produces = ontology.get_property("produces")
        assert produces is not None
        assert produces.domain == "Recipe"

    def test_get_property_checks_both_types(self, ontology_loader, sample_ontology_config):
        """Test that get_property checks both object and datatype properties."""
        ontology = self._food(ontology_loader, sample_ontology_config)
        # "produces" is an object property, "serves" a datatype property;
        # both must be reachable through the same accessor.
        assert ontology.get_property("produces") is not None
        assert ontology.get_property("serves") is not None

    def test_get_parent_classes(self, ontology_loader, sample_ontology_config):
        """Test getting parent classes following subClassOf."""
        ontology = self._food(ontology_loader, sample_ontology_config)
        assert "EdibleThing" in ontology.get_parent_classes("Food")

    def test_get_parent_classes_empty_for_root(self, ontology_loader, sample_ontology_config):
        """Test that root classes have no parents."""
        ontology = self._food(ontology_loader, sample_ontology_config)
        assert len(ontology.get_parent_classes("Recipe")) == 0
class TestOntologyValidation:
    """Test suite for ontology validation."""

    def test_validates_property_domain_exists(self, ontology_loader):
        """Test validation of property domain."""
        config = {
            "test": {
                "metadata": {},
                "classes": {
                    "Recipe": {"uri": "http://example.com/Recipe"}
                },
                "objectProperties": {
                    "produces": {
                        "uri": "http://example.com/produces",
                        "type": "owl:ObjectProperty",
                        "rdfs:domain": "NonExistentClass",  # Invalid
                        "rdfs:range": "Food"
                    }
                },
                "datatypeProperties": {}
            }
        }
        ontology_loader.update_ontologies(config)
        ontology = ontology_loader.get_ontology("test")
        issues = ontology.validate_structure()
        assert len(issues) > 0
        assert any("unknown domain" in issue.lower() for issue in issues)

    def test_validates_object_property_range_exists(self, ontology_loader):
        """Test validation of object property range."""
        config = {
            "test": {
                "metadata": {},
                "classes": {
                    "Recipe": {"uri": "http://example.com/Recipe"}
                },
                "objectProperties": {
                    "produces": {
                        "uri": "http://example.com/produces",
                        "type": "owl:ObjectProperty",
                        "rdfs:domain": "Recipe",
                        "rdfs:range": "NonExistentClass"  # Invalid
                    }
                },
                "datatypeProperties": {}
            }
        }
        ontology_loader.update_ontologies(config)
        ontology = ontology_loader.get_ontology("test")
        issues = ontology.validate_structure()
        assert len(issues) > 0
        assert any("unknown range" in issue.lower() for issue in issues)

    def test_detects_circular_inheritance(self, ontology_loader):
        """Test detection of circular inheritance."""
        config = {
            "test": {
                "metadata": {},
                "classes": {
                    "A": {
                        "uri": "http://example.com/A",
                        "rdfs:subClassOf": "B"
                    },
                    "B": {
                        "uri": "http://example.com/B",
                        "rdfs:subClassOf": "C"
                    },
                    "C": {
                        "uri": "http://example.com/C",
                        "rdfs:subClassOf": "A"  # Circular!
                    }
                },
                "objectProperties": {},
                "datatypeProperties": {}
            }
        }
        ontology_loader.update_ontologies(config)
        ontology = ontology_loader.get_ontology("test")
        issues = ontology.validate_structure()
        assert len(issues) > 0
        assert any("circular" in issue.lower() for issue in issues)

    def test_valid_ontology_has_no_issues(self, ontology_loader, sample_ontology_config):
        """Test that valid ontology passes validation."""
        from copy import deepcopy

        # BUG FIX: the original used sample_ontology_config.copy() — a
        # *shallow* copy — and then mutated config["food"]["classes"],
        # which mutated the shared fixture dict in place. deepcopy keeps
        # the fixture pristine for any test that receives the same object.
        config = deepcopy(sample_ontology_config)
        # Add the two classes the sample config dangles references to,
        # so every domain/range/subClassOf target resolves.
        config["food"]["classes"]["EdibleThing"] = {
            "uri": "http://purl.org/ontology/fo/EdibleThing"
        }
        config["food"]["classes"]["IngredientList"] = {
            "uri": "http://purl.org/ontology/fo/IngredientList"
        }
        ontology_loader.update_ontologies(config)
        ontology = ontology_loader.get_ontology("food")
        issues = ontology.validate_structure()
        # Should have minimal or no issues for valid ontology
        assert isinstance(issues, list)
class TestEdgeCases:
    """Test suite for edge cases in ontology loading."""

    def test_handles_empty_config(self, ontology_loader):
        """Test handling of empty configuration."""
        ontology_loader.update_ontologies({})
        assert len(ontology_loader.get_all_ontologies()) == 0

    def test_handles_ontology_without_classes(self, ontology_loader):
        """Test handling of ontology with no classes."""
        minimal_config = {
            "minimal": {
                "metadata": {"name": "Minimal"},
                "classes": {},
                "objectProperties": {},
                "datatypeProperties": {},
            }
        }
        ontology_loader.update_ontologies(minimal_config)
        loaded = ontology_loader.get_ontology("minimal")
        assert loaded is not None
        assert len(loaded.classes) == 0

    def test_handles_ontology_without_properties(self, ontology_loader):
        """Test handling of ontology with no properties."""
        classes_only_config = {
            "test": {
                "metadata": {},
                "classes": {
                    "Recipe": {"uri": "http://example.com/Recipe"},
                },
                "objectProperties": {},
                "datatypeProperties": {},
            }
        }
        ontology_loader.update_ontologies(classes_only_config)
        loaded = ontology_loader.get_ontology("test")
        assert loaded is not None
        assert len(loaded.object_properties) == 0
        assert len(loaded.datatype_properties) == 0

    def test_handles_missing_optional_fields(self, ontology_loader):
        """Test handling of missing optional fields."""
        # "Simple" carries only a URI: no labels, comments, subclass, etc.
        bare_config = {
            "test": {
                "metadata": {},
                "classes": {
                    "Simple": {"uri": "http://example.com/Simple"},
                },
                "objectProperties": {},
                "datatypeProperties": {},
            }
        }
        ontology_loader.update_ontologies(bare_config)
        simple = ontology_loader.get_ontology("test").get_class("Simple")
        assert simple is not None
        assert simple.uri == "http://example.com/Simple"
# Allow running this test module directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View file

@ -0,0 +1,336 @@
"""
Unit tests for OntologySelector component.
Tests the critical auto-include properties feature that automatically pulls in
all properties related to selected classes.
"""
import pytest
from unittest.mock import Mock, AsyncMock
from trustgraph.extract.kg.ontology.ontology_selector import (
OntologySelector,
OntologySubset
)
from trustgraph.extract.kg.ontology.ontology_loader import (
Ontology,
OntologyClass,
OntologyProperty
)
from trustgraph.extract.kg.ontology.text_processor import TextSegment
@pytest.fixture
def sample_ontology():
    """Create a sample food ontology for testing."""
    ns = "http://purl.org/ontology/fo/"

    def make_class(name, comment):
        # Every test class shares the same shape; only name/comment vary.
        return OntologyClass(
            uri=ns + name,
            type="owl:Class",
            labels=[{"value": name, "lang": "en-gb"}],
            comment=comment,
        )

    def make_property(name, comment, domain, range_, kind="owl:ObjectProperty"):
        # Shared builder for both object and datatype properties.
        return OntologyProperty(
            uri=ns + name,
            type=kind,
            labels=[{"value": name, "lang": "en-gb"}],
            comment=comment,
            domain=domain,
            range=range_,
        )

    return Ontology(
        id="food",
        metadata={"name": "Food Ontology", "namespace": ns},
        classes={
            "Recipe": make_class(
                "Recipe",
                "A Recipe is a combination of ingredients and a method."),
            "Ingredient": make_class(
                "Ingredient",
                "An Ingredient is a combination of a quantity and a food."),
            "Food": make_class(
                "Food",
                "A Food is something that can be eaten."),
            "Method": make_class(
                "Method",
                "A Method is the way in which ingredients are combined."),
        },
        object_properties={
            "ingredients": make_property(
                "ingredients",
                "The ingredients property relates a recipe to an ingredient list.",
                "Recipe", "IngredientList"),
            "food": make_property(
                "food",
                "The food property relates an ingredient to the food that is required.",
                "Ingredient", "Food"),
            "method": make_property(
                "method",
                "The method property relates a recipe to the method used.",
                "Recipe", "Method"),
            "produces": make_property(
                "produces",
                "The produces property relates a recipe to the food it produces.",
                "Recipe", "Food"),
        },
        datatype_properties={
            "serves": make_property(
                "serves",
                "The serves property indicates what the recipe is intended to serve.",
                "Recipe", "xsd:string", kind="owl:DatatypeProperty"),
        },
    )
@pytest.fixture
def ontology_loader_with_sample(sample_ontology):
    """Create an OntologyLoader with the sample ontology."""
    mock_loader = Mock()
    mock_loader.configure_mock(
        get_ontology=Mock(return_value=sample_ontology),
        ontologies={"food": sample_ontology},
    )
    return mock_loader
@pytest.fixture
def ontology_embedder():
    """Create a mock OntologyEmbedder."""
    fake_vector_store = Mock()
    embedder = Mock()
    # Canned 3-dimensional embedding; the selector only needs *an* answer.
    embedder.embed_text = AsyncMock(return_value=[0.1, 0.2, 0.3])
    embedder.get_vector_store = Mock(return_value=fake_vector_store)
    return embedder
class TestOntologySelector:
    """Test suite for OntologySelector."""

    @staticmethod
    def _subset_with(ontology, class_names, **extra):
        """Build an OntologySubset seeded with only the named classes.

        Properties are left empty so that _resolve_dependencies() is the
        only thing that can pull them in.
        """
        return OntologySubset(
            ontology_id="food",
            classes={
                name: ontology.classes[name].__dict__ for name in class_names
            },
            object_properties={},
            datatype_properties={},
            metadata=ontology.metadata,
            **extra,
        )

    def test_auto_include_properties_for_recipe_class(
        self, ontology_loader_with_sample, ontology_embedder, sample_ontology
    ):
        """Test that selecting Recipe class automatically includes all related properties."""
        selector = OntologySelector(
            ontology_embedder=ontology_embedder,
            ontology_loader=ontology_loader_with_sample,
            top_k=10,
            similarity_threshold=0.3
        )
        subset = self._subset_with(
            sample_ontology, ["Recipe"], relevance_score=0.8
        )

        # Resolve dependencies (this is where auto-include happens)
        selector._resolve_dependencies(subset)

        # Every property with Recipe in its domain must now be present.
        assert "ingredients" in subset.object_properties, \
            "ingredients property should be auto-included (Recipe in domain)"
        assert "method" in subset.object_properties, \
            "method property should be auto-included (Recipe in domain)"
        assert "produces" in subset.object_properties, \
            "produces property should be auto-included (Recipe in domain)"
        assert "serves" in subset.datatype_properties, \
            "serves property should be auto-included (Recipe in domain)"
        # The unrelated 'food' property must stay out.
        assert "food" not in subset.object_properties, \
            "food property should NOT be included (Recipe not in domain/range)"

    def test_auto_include_properties_for_ingredient_class(
        self, ontology_loader_with_sample, ontology_embedder, sample_ontology
    ):
        """Test that selecting Ingredient class includes properties with Ingredient in domain."""
        selector = OntologySelector(
            ontology_embedder=ontology_embedder,
            ontology_loader=ontology_loader_with_sample
        )
        subset = self._subset_with(sample_ontology, ["Ingredient"])
        selector._resolve_dependencies(subset)

        # Ingredient is the domain of exactly one property: 'food'.
        assert "food" in subset.object_properties, \
            "food property should be auto-included (Ingredient in domain)"
        # Recipe-only properties must not appear.
        assert "ingredients" not in subset.object_properties
        assert "method" not in subset.object_properties

    def test_auto_include_properties_for_range_class(
        self, ontology_loader_with_sample, ontology_embedder, sample_ontology
    ):
        """Test that selecting a class includes properties with that class in range."""
        selector = OntologySelector(
            ontology_embedder=ontology_embedder,
            ontology_loader=ontology_loader_with_sample
        )
        subset = self._subset_with(sample_ontology, ["Food"])
        selector._resolve_dependencies(subset)

        # Food is the range of both 'food' and 'produces'.
        assert "food" in subset.object_properties, \
            "food property should be auto-included (Food in range)"
        assert "produces" in subset.object_properties, \
            "produces property should be auto-included (Food in range)"

    def test_auto_include_adds_domain_and_range_classes(
        self, ontology_loader_with_sample, ontology_embedder, sample_ontology
    ):
        """Test that auto-included properties also add their domain/range classes."""
        selector = OntologySelector(
            ontology_embedder=ontology_embedder,
            ontology_loader=ontology_loader_with_sample
        )
        subset = self._subset_with(sample_ontology, ["Recipe"])
        selector._resolve_dependencies(subset)

        # Auto-including 'produces' (Recipe -> Food) should drag Food in.
        assert "produces" in subset.object_properties
        assert "Food" in subset.classes, \
            "Food class should be added (range of auto-included produces property)"
        assert "Method" in subset.classes, \
            "Method class should be added (range of auto-included method property)"

    def test_multiple_classes_get_all_related_properties(
        self, ontology_loader_with_sample, ontology_embedder, sample_ontology
    ):
        """Test that selecting multiple classes includes all their related properties."""
        selector = OntologySelector(
            ontology_embedder=ontology_embedder,
            ontology_loader=ontology_loader_with_sample
        )
        subset = self._subset_with(sample_ontology, ["Recipe", "Ingredient"])
        selector._resolve_dependencies(subset)

        # Recipe-related properties...
        assert "ingredients" in subset.object_properties
        assert "method" in subset.object_properties
        assert "produces" in subset.object_properties
        assert "serves" in subset.datatype_properties
        # ...plus the Ingredient-related one.
        assert "food" in subset.object_properties

    def test_no_duplicate_properties_added(
        self, ontology_loader_with_sample, ontology_embedder, sample_ontology
    ):
        """Test that properties aren't added multiple times."""
        selector = OntologySelector(
            ontology_embedder=ontology_embedder,
            ontology_loader=ontology_loader_with_sample
        )
        # Recipe and Food are both related to 'produces'.
        subset = self._subset_with(sample_ontology, ["Recipe", "Food"])
        selector._resolve_dependencies(subset)

        # 'produces' should be included once (not duplicated); dict keys are
        # unique so this is guaranteed, but worth documenting the behavior.
        assert "produces" in subset.object_properties
# Allow running this test module directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View file

@ -0,0 +1,300 @@
"""
Unit tests for ontology triple generation.
Tests that ontology elements (classes and properties) are properly converted
to RDF triples with labels, comments, domains, and ranges so they appear in
the knowledge graph.
"""
import pytest
from trustgraph.extract.kg.ontology.extract import Processor
from trustgraph.extract.kg.ontology.ontology_selector import OntologySubset
from trustgraph.schema.core.primitives import Triple, Value
@pytest.fixture
def extractor():
    """Create a Processor instance for testing."""
    # object.__new__ bypasses Processor.__init__ — presumably it wires up
    # external services; the methods under test need no instance state.
    return object.__new__(Processor)
@pytest.fixture
def sample_ontology_subset():
    """Create a sample ontology subset with classes and properties."""
    fo = "http://purl.org/ontology/fo/"

    def class_def(name, comment):
        # All class entries share this dict shape (Python field names,
        # not RDF property names — that is deliberate, see the tests).
        return {
            "uri": fo + name,
            "type": "owl:Class",
            "labels": [{"value": name, "lang": "en-gb"}],
            "comment": comment,
            "subclass_of": None,
        }

    classes = {
        "Recipe": class_def(
            "Recipe", "A Recipe is a combination of ingredients and a method."),
        "Ingredient": class_def(
            "Ingredient", "An Ingredient combines a quantity and a food."),
        "Food": class_def(
            "Food", "A Food is something that can be eaten."),
    }

    object_properties = {
        "ingredients": {
            "uri": fo + "ingredients",
            "type": "owl:ObjectProperty",
            "labels": [{"value": "ingredients", "lang": "en-gb"}],
            "comment": "The ingredients property relates a recipe to an ingredient list.",
            "domain": "Recipe",
            "range": "IngredientList",
        },
        "produces": {
            "uri": fo + "produces",
            "type": "owl:ObjectProperty",
            "labels": [{"value": "produces", "lang": "en-gb"}],
            "comment": "The produces property relates a recipe to the food it produces.",
            "domain": "Recipe",
            "range": "Food",
        },
    }

    # NOTE(review): 'serves' uses the "rdfs:range" key while the object
    # properties use plain "range" — preserved as-is because the datatype
    # range test depends on this exact shape; confirm it is intentional.
    datatype_properties = {
        "serves": {
            "uri": fo + "serves",
            "type": "owl:DatatypeProperty",
            "labels": [{"value": "serves", "lang": "en-gb"}],
            "comment": "The serves property indicates serving size.",
            "domain": "Recipe",
            "rdfs:range": "xsd:string",
        },
    }

    return OntologySubset(
        ontology_id="food",
        classes=classes,
        object_properties=object_properties,
        datatype_properties=datatype_properties,
        metadata={"name": "Food Ontology", "namespace": fo},
    )
class TestOntologyTripleGeneration:
    """Test suite for ontology triple generation."""

    # Full URIs the generated triples are expected to use.
    RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
    RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
    RDFS_COMMENT = "http://www.w3.org/2000/01/rdf-schema#comment"
    RDFS_DOMAIN = "http://www.w3.org/2000/01/rdf-schema#domain"
    RDFS_RANGE = "http://www.w3.org/2000/01/rdf-schema#range"
    OWL_CLASS = "http://www.w3.org/2002/07/owl#Class"
    OWL_OBJECT_PROPERTY = "http://www.w3.org/2002/07/owl#ObjectProperty"
    OWL_DATATYPE_PROPERTY = "http://www.w3.org/2002/07/owl#DatatypeProperty"
    XSD_STRING = "http://www.w3.org/2001/XMLSchema#string"
    FO = "http://purl.org/ontology/fo/"

    @staticmethod
    def _with(triples, s=None, p=None):
        """Return triples matching the given subject/predicate URIs (None = any)."""
        return [
            t for t in triples
            if (s is None or t.s.value == s) and (p is None or t.p.value == p)
        ]

    def test_generates_class_type_triples(self, extractor, sample_ontology_subset):
        """Test that classes get rdf:type owl:Class triples."""
        triples = extractor.build_ontology_triples(sample_ontology_subset)
        found = self._with(triples, s=self.FO + "Recipe", p=self.RDF_TYPE)
        assert len(found) == 1, "Should generate exactly one type triple per class"
        assert found[0].o.value == self.OWL_CLASS, \
            "Class type should be owl:Class"

    def test_generates_class_labels(self, extractor, sample_ontology_subset):
        """Test that classes get rdfs:label triples."""
        triples = extractor.build_ontology_triples(sample_ontology_subset)
        found = self._with(triples, s=self.FO + "Recipe", p=self.RDFS_LABEL)
        assert len(found) == 1, "Should generate label triple for class"
        assert found[0].o.value == "Recipe", \
            "Label should match class label from ontology"
        assert not found[0].o.is_uri, \
            "Label should be a literal, not URI"

    def test_generates_class_comments(self, extractor, sample_ontology_subset):
        """Test that classes get rdfs:comment triples."""
        triples = extractor.build_ontology_triples(sample_ontology_subset)
        found = self._with(triples, s=self.FO + "Recipe", p=self.RDFS_COMMENT)
        assert len(found) == 1, "Should generate comment triple for class"
        assert "combination of ingredients and a method" in found[0].o.value, \
            "Comment should match class description from ontology"

    def test_generates_object_property_type_triples(self, extractor, sample_ontology_subset):
        """Test that object properties get rdf:type owl:ObjectProperty triples."""
        triples = extractor.build_ontology_triples(sample_ontology_subset)
        found = self._with(triples, s=self.FO + "ingredients", p=self.RDF_TYPE)
        assert len(found) == 1, \
            "Should generate exactly one type triple per object property"
        assert found[0].o.value == self.OWL_OBJECT_PROPERTY, \
            "Object property type should be owl:ObjectProperty"

    def test_generates_object_property_labels(self, extractor, sample_ontology_subset):
        """Test that object properties get rdfs:label triples."""
        triples = extractor.build_ontology_triples(sample_ontology_subset)
        found = self._with(triples, s=self.FO + "ingredients", p=self.RDFS_LABEL)
        assert len(found) == 1, \
            "Should generate label triple for object property"
        assert found[0].o.value == "ingredients", \
            "Label should match property label from ontology"

    def test_generates_object_property_domain(self, extractor, sample_ontology_subset):
        """Test that object properties get rdfs:domain triples."""
        triples = extractor.build_ontology_triples(sample_ontology_subset)
        found = self._with(triples, s=self.FO + "ingredients", p=self.RDFS_DOMAIN)
        assert len(found) == 1, \
            "Should generate domain triple for object property"
        assert found[0].o.value == self.FO + "Recipe", \
            "Domain should be Recipe class URI"
        assert found[0].o.is_uri, \
            "Domain should be a URI reference"

    def test_generates_object_property_range(self, extractor, sample_ontology_subset):
        """Test that object properties get rdfs:range triples."""
        triples = extractor.build_ontology_triples(sample_ontology_subset)
        found = self._with(triples, s=self.FO + "produces", p=self.RDFS_RANGE)
        assert len(found) == 1, \
            "Should generate range triple for object property"
        assert found[0].o.value == self.FO + "Food", \
            "Range should be Food class URI"

    def test_generates_datatype_property_type_triples(self, extractor, sample_ontology_subset):
        """Test that datatype properties get rdf:type owl:DatatypeProperty triples."""
        triples = extractor.build_ontology_triples(sample_ontology_subset)
        found = self._with(triples, s=self.FO + "serves", p=self.RDF_TYPE)
        assert len(found) == 1, \
            "Should generate exactly one type triple per datatype property"
        assert found[0].o.value == self.OWL_DATATYPE_PROPERTY, \
            "Datatype property type should be owl:DatatypeProperty"

    def test_generates_datatype_property_range(self, extractor, sample_ontology_subset):
        """Test that datatype properties get rdfs:range triples with XSD types."""
        triples = extractor.build_ontology_triples(sample_ontology_subset)
        found = self._with(triples, s=self.FO + "serves", p=self.RDFS_RANGE)
        assert len(found) == 1, \
            "Should generate range triple for datatype property"
        assert found[0].o.value == self.XSD_STRING, \
            "Range should be XSD type URI (xsd:string expanded)"

    def test_generates_triples_for_all_classes(self, extractor, sample_ontology_subset):
        """Test that triples are generated for all classes in the subset."""
        triples = extractor.build_ontology_triples(sample_ontology_subset)
        # Unique subjects typed as owl:Class.
        class_subjects = {
            t.s.value
            for t in self._with(triples, p=self.RDF_TYPE)
            if t.o.value == self.OWL_CLASS
        }
        assert len(class_subjects) == 3, \
            "Should generate triples for all 3 classes (Recipe, Ingredient, Food)"

    def test_generates_triples_for_all_properties(self, extractor, sample_ontology_subset):
        """Test that triples are generated for all properties in the subset."""
        triples = extractor.build_ontology_triples(sample_ontology_subset)
        # Unique subjects typed as either property kind.
        property_subjects = {
            t.s.value
            for t in self._with(triples, p=self.RDF_TYPE)
            if "ObjectProperty" in t.o.value or "DatatypeProperty" in t.o.value
        }
        assert len(property_subjects) == 3, \
            "Should generate triples for all 3 properties (ingredients, produces, serves)"

    def test_uses_dict_field_names_not_rdf_names(self, extractor, sample_ontology_subset):
        """Test that triple generation works with dict field names (labels, comment, domain, range).

        This is critical - the ontology subset has dicts with Python field names,
        not RDF property names.
        """
        recipe_def = sample_ontology_subset.classes["Recipe"]
        assert isinstance(recipe_def, dict), "Class definitions should be dicts"
        assert "labels" in recipe_def, "Should use 'labels' not 'rdfs:label'"
        assert "comment" in recipe_def, "Should use 'comment' not 'rdfs:comment'"
        # Proper RDF triples should be produced despite the dict field names.
        triples = extractor.build_ontology_triples(sample_ontology_subset)
        label_triples = self._with(triples, p=self.RDFS_LABEL)
        assert len(label_triples) > 0, \
            "Should generate rdfs:label triples from dict 'labels' field"

    def test_total_triple_count_is_reasonable(self, extractor, sample_ontology_subset):
        """Test that we generate a reasonable number of triples."""
        triples = extractor.build_ontology_triples(sample_ontology_subset)
        # Each class gets: type, label, comment (3 triples)
        # Each object property gets: type, label, comment, domain, range (5 triples)
        # Each datatype property gets: type, label, comment, domain, range (5 triples)
        # Expected: 3 classes * 3 + 2 object props * 5 + 1 datatype prop * 5 = 9 + 10 + 5 = 24
        assert len(triples) >= 20, \
            "Should generate substantial number of triples for ontology elements"
        assert len(triples) < 50, \
            "Should not generate excessive duplicate triples"
# Allow running this test module directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View file

@ -0,0 +1,414 @@
"""
Unit tests for LLM prompt construction and triple extraction.
Tests that the system correctly constructs prompts with ontology constraints
and extracts/validates triples from LLM responses.
"""
import pytest
from trustgraph.extract.kg.ontology.extract import Processor
from trustgraph.extract.kg.ontology.ontology_selector import OntologySubset
from trustgraph.schema.core.primitives import Triple, Value
@pytest.fixture
def extractor():
    """Create a Processor instance for testing."""
    # Bypass __init__ (presumably it needs live services); only the
    # URI_PREFIXES attribute is required by the methods under test.
    instance = object.__new__(Processor)
    instance.URI_PREFIXES = {
        prefix: base
        for prefix, base in [
            ("rdf:", "http://www.w3.org/1999/02/22-rdf-syntax-ns#"),
            ("rdfs:", "http://www.w3.org/2000/01/rdf-schema#"),
            ("owl:", "http://www.w3.org/2002/07/owl#"),
            ("xsd:", "http://www.w3.org/2001/XMLSchema#"),
        ]
    }
    return instance
@pytest.fixture
def sample_ontology_subset():
    """Create a sample ontology subset for extraction testing."""
    fo = "http://purl.org/ontology/fo/"

    def class_def(name, comment):
        # Uniform dict shape for all class entries.
        return {
            "uri": fo + name,
            "type": "owl:Class",
            "labels": [{"value": name, "lang": "en-gb"}],
            "comment": comment,
        }

    def object_prop(name, comment, domain, range_):
        return {
            "uri": fo + name,
            "type": "owl:ObjectProperty",
            "labels": [{"value": name, "lang": "en-gb"}],
            "comment": comment,
            "domain": domain,
            "range": range_,
        }

    classes = {
        "Recipe": class_def(
            "Recipe", "A Recipe is a combination of ingredients and a method."),
        "Ingredient": class_def(
            "Ingredient", "An Ingredient combines a quantity and a food."),
        "Food": class_def(
            "Food", "A Food is something that can be eaten."),
    }

    object_properties = {
        "ingredients": object_prop(
            "ingredients",
            "The ingredients property relates a recipe to an ingredient list.",
            "Recipe", "IngredientList"),
        "food": object_prop(
            "food",
            "The food property relates an ingredient to food.",
            "Ingredient", "Food"),
        "produces": object_prop(
            "produces",
            "The produces property relates a recipe to the food it produces.",
            "Recipe", "Food"),
    }

    # NOTE(review): 'serves' uses the "rdfs:range" key where object
    # properties use plain "range" — kept byte-identical; confirm intent.
    datatype_properties = {
        "serves": {
            "uri": fo + "serves",
            "type": "owl:DatatypeProperty",
            "labels": [{"value": "serves", "lang": "en-gb"}],
            "comment": "The serves property indicates serving size.",
            "domain": "Recipe",
            "rdfs:range": "xsd:string",
        },
    }

    return OntologySubset(
        ontology_id="food",
        classes=classes,
        object_properties=object_properties,
        datatype_properties=datatype_properties,
        metadata={"name": "Food Ontology", "namespace": fo},
    )
class TestPromptConstruction:
    """Test suite for LLM prompt construction."""

    def test_build_extraction_variables_includes_text(self, extractor, sample_ontology_subset):
        """Test that extraction variables include the input text."""
        chunk = "Cornish pasty is a traditional British pastry filled with meat and vegetables."
        prompt_vars = extractor.build_extraction_variables(chunk, sample_ontology_subset)
        assert "text" in prompt_vars, "Should include text key"
        assert prompt_vars["text"] == chunk, "Text should match input chunk"

    def test_build_extraction_variables_includes_classes(self, extractor, sample_ontology_subset):
        """Test that extraction variables include ontology classes."""
        prompt_vars = extractor.build_extraction_variables("Test text", sample_ontology_subset)
        assert "classes" in prompt_vars, "Should include classes key"
        assert len(prompt_vars["classes"]) == 3, "Should include all classes from subset"
        for expected_class in ("Recipe", "Ingredient", "Food"):
            assert expected_class in prompt_vars["classes"]

    def test_build_extraction_variables_includes_properties(self, extractor, sample_ontology_subset):
        """Test that extraction variables include ontology properties."""
        prompt_vars = extractor.build_extraction_variables("Test text", sample_ontology_subset)
        assert "object_properties" in prompt_vars, "Should include object_properties key"
        assert "datatype_properties" in prompt_vars, "Should include datatype_properties key"
        assert len(prompt_vars["object_properties"]) == 3
        assert len(prompt_vars["datatype_properties"]) == 1

    def test_build_extraction_variables_structure(self, extractor, sample_ontology_subset):
        """Test the overall structure of extraction variables."""
        prompt_vars = extractor.build_extraction_variables("Test text", sample_ontology_subset)
        # Exactly these four keys, nothing more.
        assert set(prompt_vars.keys()) == {"text", "classes", "object_properties", "datatype_properties"}

    def test_build_extraction_variables_with_empty_subset(self, extractor):
        """Test building variables with minimal ontology subset."""
        empty_subset = OntologySubset(
            ontology_id="minimal",
            classes={},
            object_properties={},
            datatype_properties={},
            metadata={}
        )
        chunk = "Test text"
        prompt_vars = extractor.build_extraction_variables(chunk, empty_subset)
        assert prompt_vars["text"] == chunk
        assert len(prompt_vars["classes"]) == 0
        assert len(prompt_vars["object_properties"]) == 0
        assert len(prompt_vars["datatype_properties"]) == 0
class TestTripleValidation:
    """Test suite for triple validation against ontology."""

    def test_validates_rdf_type_triple_with_valid_class(self, extractor, sample_ontology_subset):
        """Test that rdf:type triples are validated against ontology classes."""
        assert extractor.is_valid_triple(
            "cornish-pasty", "rdf:type", "Recipe", sample_ontology_subset
        ), "rdf:type with valid class should be valid"

    def test_rejects_rdf_type_triple_with_invalid_class(self, extractor, sample_ontology_subset):
        """Test that rdf:type triples with non-existent classes are rejected."""
        assert not extractor.is_valid_triple(
            "cornish-pasty", "rdf:type", "NonExistentClass", sample_ontology_subset
        ), "rdf:type with invalid class should be rejected"

    def test_validates_rdfs_label_triple(self, extractor, sample_ontology_subset):
        """Test that rdfs:label triples are always valid."""
        assert extractor.is_valid_triple(
            "cornish-pasty", "rdfs:label", "Cornish Pasty", sample_ontology_subset
        ), "rdfs:label should always be valid"

    def test_validates_object_property_triple(self, extractor, sample_ontology_subset):
        """Test that object property triples are validated."""
        assert extractor.is_valid_triple(
            "cornish-pasty-recipe", "produces", "cornish-pasty", sample_ontology_subset
        ), "Valid object property should be accepted"

    def test_validates_datatype_property_triple(self, extractor, sample_ontology_subset):
        """Test that datatype property triples are validated."""
        assert extractor.is_valid_triple(
            "cornish-pasty-recipe", "serves", "4-6 people", sample_ontology_subset
        ), "Valid datatype property should be accepted"

    def test_rejects_unknown_property(self, extractor, sample_ontology_subset):
        """Test that unknown properties are rejected."""
        assert not extractor.is_valid_triple(
            "cornish-pasty", "unknownProperty", "some value", sample_ontology_subset
        ), "Unknown property should be rejected"

    def test_validates_multiple_valid_properties(self, extractor, sample_ontology_subset):
        """Test validation of different property types."""
        # (subject, predicate, object, expected-validity) table.
        test_cases = [
            ("recipe1", "produces", "food1", True),
            ("ingredient1", "food", "food1", True),
            ("recipe1", "serves", "4", True),
            ("recipe1", "invalidProp", "value", False),
        ]
        for subject, predicate, object_val, expected in test_cases:
            is_valid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset)
            assert is_valid == expected, f"Validation of {predicate} should be {expected}"
class TestTripleParsing:
    """Test suite for parsing triples from LLM responses."""

    def test_parse_simple_triple_dict(self, extractor, sample_ontology_subset):
        """A single well-formed dict yields one fully expanded triple."""
        payload = [{
            "subject": "cornish-pasty",
            "predicate": "rdf:type",
            "object": "Recipe",
        }]
        parsed = extractor.parse_and_validate_triples(payload, sample_ontology_subset)
        assert len(parsed) == 1, "Should parse one valid triple"
        triple = parsed[0]
        assert triple.s.value == "https://trustgraph.ai/food/cornish-pasty"
        assert triple.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
        assert triple.o.value == "http://purl.org/ontology/fo/Recipe"

    def test_parse_multiple_triples(self, extractor, sample_ontology_subset):
        """Every valid triple in the response is parsed."""
        payload = [
            {"subject": "cornish-pasty", "predicate": "rdf:type", "object": "Recipe"},
            {"subject": "cornish-pasty", "predicate": "rdfs:label", "object": "Cornish Pasty"},
            {"subject": "cornish-pasty", "predicate": "serves", "object": "1-2 people"},
        ]
        parsed = extractor.parse_and_validate_triples(payload, sample_ontology_subset)
        assert len(parsed) == 3, "Should parse all valid triples"

    def test_filters_invalid_triples(self, extractor, sample_ontology_subset):
        """Triples with predicates unknown to the ontology are dropped."""
        payload = [
            {"subject": "cornish-pasty", "predicate": "rdf:type", "object": "Recipe"},    # valid
            {"subject": "cornish-pasty", "predicate": "invalidProp", "object": "value"},  # invalid
            {"subject": "cornish-pasty", "predicate": "produces", "object": "food1"},     # valid
        ]
        parsed = extractor.parse_and_validate_triples(payload, sample_ontology_subset)
        assert len(parsed) == 2, "Should filter out invalid triple"

    def test_handles_missing_fields(self, extractor, sample_ontology_subset):
        """Dicts lacking subject, predicate or object are skipped."""
        payload = [
            {"subject": "cornish-pasty", "predicate": "rdf:type"},                         # no object
            {"subject": "cornish-pasty", "object": "Recipe"},                              # no predicate
            {"predicate": "rdf:type", "object": "Recipe"},                                 # no subject
            {"subject": "cornish-pasty", "predicate": "rdf:type", "object": "Recipe"},     # complete
        ]
        parsed = extractor.parse_and_validate_triples(payload, sample_ontology_subset)
        assert len(parsed) == 1, "Should skip triples with missing fields"

    def test_handles_empty_response(self, extractor, sample_ontology_subset):
        """An empty LLM response produces an empty result."""
        parsed = extractor.parse_and_validate_triples([], sample_ontology_subset)
        assert len(parsed) == 0, "Empty response should return no triples"

    def test_expands_uris_in_parsed_triples(self, extractor, sample_ontology_subset):
        """Subject, predicate and class object are all expanded to URIs."""
        payload = [{"subject": "recipe1", "predicate": "produces", "object": "Food"}]
        parsed = extractor.parse_and_validate_triples(payload, sample_ontology_subset)
        assert len(parsed) == 1
        triple = parsed[0]
        # Subject should be expanded to entity URI
        assert triple.s.value.startswith("https://trustgraph.ai/food/")
        # Predicate should be expanded to ontology URI
        assert triple.p.value == "http://purl.org/ontology/fo/produces"
        # Object should be expanded to class URI
        assert triple.o.value == "http://purl.org/ontology/fo/Food"

    def test_creates_proper_triple_objects(self, extractor, sample_ontology_subset):
        """Parsed results are Triple objects built from Value instances."""
        payload = [{"subject": "cornish-pasty", "predicate": "rdfs:label", "object": "Cornish Pasty"}]
        parsed = extractor.parse_and_validate_triples(payload, sample_ontology_subset)
        assert len(parsed) == 1
        triple = parsed[0]
        assert isinstance(triple, Triple), "Should create Triple objects"
        assert isinstance(triple.s, Value), "Subject should be Value object"
        assert isinstance(triple.p, Value), "Predicate should be Value object"
        assert isinstance(triple.o, Value), "Object should be Value object"
        assert triple.s.is_uri, "Subject should be marked as URI"
        assert triple.p.is_uri, "Predicate should be marked as URI"
        assert not triple.o.is_uri, "Object literal should not be marked as URI"
class TestURIExpansionInExtraction:
    """Test suite for URI expansion during triple extraction."""

    def test_expands_class_names_in_objects(self, extractor, sample_ontology_subset):
        """Class names in object position expand to ontology URIs."""
        parsed = extractor.parse_and_validate_triples(
            [{"subject": "entity1", "predicate": "rdf:type", "object": "Recipe"}],
            sample_ontology_subset,
        )
        assert parsed[0].o.value == "http://purl.org/ontology/fo/Recipe"
        assert parsed[0].o.is_uri, "Class reference should be URI"

    def test_expands_property_names(self, extractor, sample_ontology_subset):
        """Property names expand to their full ontology URIs."""
        parsed = extractor.parse_and_validate_triples(
            [{"subject": "recipe1", "predicate": "produces", "object": "food1"}],
            sample_ontology_subset,
        )
        assert parsed[0].p.value == "http://purl.org/ontology/fo/produces"

    def test_expands_entity_instances(self, extractor, sample_ontology_subset):
        """Entity instances receive constructed URIs containing their name."""
        parsed = extractor.parse_and_validate_triples(
            [{"subject": "my-special-recipe", "predicate": "rdf:type", "object": "Recipe"}],
            sample_ontology_subset,
        )
        assert parsed[0].s.value.startswith("https://trustgraph.ai/food/")
        assert "my-special-recipe" in parsed[0].s.value
class TestEdgeCases:
    """Test suite for edge cases in extraction."""

    def test_handles_non_dict_response_items(self, extractor, sample_ontology_subset):
        """Non-dict items in the response must not crash the parser."""
        payload = [
            {"subject": "entity1", "predicate": "rdf:type", "object": "Recipe"},  # valid
            "invalid string item",                                                # invalid
            None,                                                                 # invalid
            {"subject": "entity2", "predicate": "rdf:type", "object": "Food"},    # valid
        ]
        parsed = extractor.parse_and_validate_triples(payload, sample_ontology_subset)
        # Should skip non-dict items gracefully
        assert len(parsed) >= 0, "Should handle non-dict items without crashing"

    def test_handles_empty_string_values(self, extractor, sample_ontology_subset):
        """Triples containing empty-string fields are skipped."""
        payload = [
            {"subject": "", "predicate": "rdf:type", "object": "Recipe"},
            {"subject": "entity1", "predicate": "", "object": "Recipe"},
            {"subject": "entity1", "predicate": "rdf:type", "object": ""},
            {"subject": "entity1", "predicate": "rdf:type", "object": "Recipe"},  # valid
        ]
        parsed = extractor.parse_and_validate_triples(payload, sample_ontology_subset)
        assert len(parsed) == 1, "Should skip triples with empty strings"

    def test_handles_unicode_in_literals(self, extractor, sample_ontology_subset):
        """Unicode characters in literal values survive parsing."""
        payload = [
            {"subject": "café-recipe", "predicate": "rdfs:label", "object": "Café Spécial"},
        ]
        parsed = extractor.parse_and_validate_triples(payload, sample_ontology_subset)
        assert len(parsed) == 1
        assert "Café Spécial" in parsed[0].o.value
# Allow running this test module directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View file

@ -0,0 +1,290 @@
"""
Unit tests for text processing and segmentation.
Tests that text is properly split into sentences for ontology matching,
including NLTK tokenization and TextSegment creation.
"""
import pytest
from trustgraph.extract.kg.ontology.text_processor import TextProcessor, TextSegment
@pytest.fixture
def text_processor():
    """Create a TextProcessor instance for testing."""
    # Default-constructed; pytest's default function scope gives each
    # test its own fresh instance.
    return TextProcessor()
class TestTextSegmentation:
    """Test suite for text segmentation functionality."""

    @staticmethod
    def _sentences(segments):
        # Keep only the sentence-level segments from a processed chunk.
        return [seg for seg in segments if seg.type == 'sentence']

    def test_segment_single_sentence(self, text_processor):
        """A lone sentence produces exactly one sentence segment."""
        text = "This is a simple sentence."
        sentences = self._sentences(
            text_processor.process_chunk(text, extract_phrases=False))
        assert len(sentences) == 1, "Single sentence should produce one sentence segment"
        assert text in sentences[0].text, "Segment text should contain input"

    def test_segment_multiple_sentences(self, text_processor):
        """Three sentences split into three ordered sentence segments."""
        sentences = self._sentences(text_processor.process_chunk(
            "First sentence. Second sentence. Third sentence.",
            extract_phrases=False))
        assert len(sentences) == 3, "Should create three sentence segments for three sentences"
        assert "First sentence" in sentences[0].text
        assert "Second sentence" in sentences[1].text
        assert "Third sentence" in sentences[2].text

    def test_segment_positions(self, text_processor):
        """Segments record their positions within the chunk."""
        sentences = self._sentences(text_processor.process_chunk(
            "First sentence. Second sentence.", extract_phrases=False))
        assert len(sentences) == 2
        assert sentences[0].position == 0
        assert sentences[1].position > 0

    def test_segment_empty_text(self, text_processor):
        """Empty input yields no segments."""
        segments = text_processor.process_chunk("", extract_phrases=False)
        assert len(segments) == 0, "Empty text should produce no segments"

    def test_segment_whitespace_only(self, text_processor):
        """Whitespace-only input yields at most one segment."""
        segments = text_processor.process_chunk(" \n\t ", extract_phrases=False)
        # May produce empty segments or no segments depending on implementation
        assert len(segments) <= 1, "Whitespace-only text should produce minimal segments"

    def test_segment_with_newlines(self, text_processor):
        """Newlines between sentences do not prevent correct splitting."""
        sentences = self._sentences(text_processor.process_chunk(
            "First sentence.\nSecond sentence.", extract_phrases=False))
        assert len(sentences) == 2
        assert "First sentence" in sentences[0].text
        assert "Second sentence" in sentences[1].text

    def test_segment_complex_punctuation(self, text_processor):
        """Abbreviations such as Dr., U.S.A. and Mr. do not split sentences."""
        sentences = self._sentences(text_processor.process_chunk(
            "Dr. Smith went to the U.S.A. yesterday. He met Mr. Jones.",
            extract_phrases=False))
        # NLTK should handle abbreviations correctly
        assert len(sentences) == 2, "Should recognize abbreviations and not split on them"
        assert "Dr. Smith" in sentences[0].text
        assert "Mr. Jones" in sentences[1].text

    def test_segment_question_and_exclamation(self, text_processor):
        """Question marks and exclamation points terminate sentences."""
        sentences = self._sentences(text_processor.process_chunk(
            "Is this working? Yes, it is! Great news.", extract_phrases=False))
        assert len(sentences) == 3
        assert "Is this working?" in sentences[0].text
        assert "Yes, it is!" in sentences[1].text
        assert "Great news" in sentences[2].text

    def test_segment_long_paragraph(self, text_processor):
        """A paragraph is split into its individual sentences."""
        paragraph = (
            "The recipe requires several ingredients. "
            "First, gather flour and sugar. "
            "Then, add eggs and milk. "
            "Finally, mix everything together."
        )
        sentences = self._sentences(
            text_processor.process_chunk(paragraph, extract_phrases=False))
        assert len(sentences) == 4, "Should split paragraph into individual sentences"
        assert all(isinstance(seg, TextSegment) for seg in sentences)

    def test_extract_phrases_option(self, text_processor):
        """Enabling phrase extraction never yields fewer segments."""
        text = "The recipe requires several ingredients."
        with_phrases = text_processor.process_chunk(text, extract_phrases=True)
        without_phrases = text_processor.process_chunk(text, extract_phrases=False)
        # With phrases should have more segments (sentences + phrases)
        assert len(with_phrases) >= len(without_phrases)
class TestTextSegmentCreation:
    """Test suite for TextSegment object creation."""

    def test_text_segment_attributes(self, text_processor):
        """Segments expose text, type and position attributes."""
        segments = text_processor.process_chunk(
            "This is a test sentence.", extract_phrases=False)
        assert len(segments) >= 1
        first = segments[0]
        assert hasattr(first, 'text'), "Segment should have text attribute"
        assert hasattr(first, 'type'), "Segment should have type attribute"
        assert hasattr(first, 'position'), "Segment should have position attribute"
        assert first.type in ['sentence', 'phrase', 'noun_phrase', 'verb_phrase']

    def test_text_segment_types(self, text_processor):
        """Sentence segments appear with and without phrase extraction."""
        text = "The recipe requires several ingredients."
        # Without phrases
        seen = set(s.type for s in text_processor.process_chunk(text, extract_phrases=False))
        assert 'sentence' in seen, "Should create sentence segments"
        # With phrases
        seen = set(s.type for s in text_processor.process_chunk(text, extract_phrases=True))
        assert 'sentence' in seen, "Should create sentence segments"
        # May also have phrase types

    def test_text_segment_sentence_tracking(self, text_processor):
        """Phrase segments carry a parent_sentence attribute."""
        segments = text_processor.process_chunk(
            "This is a test sentence.", extract_phrases=True)
        # Phrases (if any) should reference their parent sentence
        for phrase in [s for s in segments if s.type != 'sentence']:
            # parent_sentence may be set for phrases
            assert hasattr(phrase, 'parent_sentence')
class TestNLTKCompatibility:
    """Test suite for NLTK version compatibility."""

    def test_nltk_punkt_availability(self, text_processor):
        """Test that NLTK punkt tokenizer is available.

        If the punkt/punkt_tab data is missing, sent_tokenize raises
        LookupError and we fail with a clear message.
        """
        # Note: the previous bare `import nltk` was unused (only
        # sent_tokenize is referenced) and has been removed.
        text = "Test sentence. Another sentence."
        try:
            from nltk.tokenize import sent_tokenize
            result = sent_tokenize(text)
            assert len(result) > 0, "NLTK sentence tokenizer should work"
        except LookupError:
            pytest.fail("NLTK punkt tokenizer not available")

    def test_text_processor_uses_nltk(self, text_processor):
        """Test that TextProcessor successfully uses NLTK for segmentation."""
        # This verifies the integration works end to end.
        text = "First sentence. Second sentence."
        segments = text_processor.process_chunk(text, extract_phrases=False)
        # Should successfully segment using NLTK
        sentences = [s for s in segments if s.type == 'sentence']
        assert len(sentences) >= 1, "Should successfully segment text using NLTK"
class TestEdgeCases:
    """Test suite for edge cases in text processing."""

    def test_sentence_with_only_punctuation(self, text_processor):
        """Punctuation-only input is handled without error."""
        segments = text_processor.process_chunk("...!?!", extract_phrases=False)
        # Should handle gracefully (NLTK may split this oddly, that's ok)
        assert len(segments) <= 3, "Should handle punctuation-only text gracefully"

    def test_very_long_sentence(self, text_processor):
        """A single long sentence stays a single sentence segment."""
        long_text = (
            "This is a very long sentence with many clauses, "
            "including subordinate clauses, coordinate clauses, "
            "and various other grammatical structures that make it "
            "quite lengthy but still technically a single sentence."
        )
        sentences = [s for s in
                     text_processor.process_chunk(long_text, extract_phrases=False)
                     if s.type == 'sentence']
        assert len(sentences) == 1, "Long sentence should still be one sentence segment"
        assert len(sentences[0].text) > 100

    def test_unicode_text(self, text_processor):
        """Accented characters are preserved through segmentation."""
        segments = text_processor.process_chunk(
            "Café serves crêpes. The recipe is français.", extract_phrases=False)
        sentences = [s for s in segments if s.type == 'sentence']
        assert len(sentences) == 2
        assert "Café" in sentences[0].text
        assert "français" in sentences[1].text

    def test_numbers_and_dates(self, text_processor):
        """Dates and numeric ranges do not break sentence splitting."""
        segments = text_processor.process_chunk(
            "The recipe was created on Jan. 1, 2024. It serves 4-6 people.",
            extract_phrases=False)
        sentences = [s for s in segments if s.type == 'sentence']
        assert len(sentences) == 2
        assert "2024" in sentences[0].text
        assert "4-6" in sentences[1].text

    def test_ellipsis_handling(self, text_processor):
        """Ellipsis input yields at least one segment."""
        segments = text_processor.process_chunk(
            "First sentence... Second sentence.", extract_phrases=False)
        # NLTK may handle ellipsis differently
        assert len(segments) >= 1, "Should produce at least one segment"
        # The exact behavior depends on NLTK version

    def test_quoted_text(self, text_processor):
        """Quoted speech is split correctly into sentences."""
        segments = text_processor.process_chunk(
            'He said "Hello world." Then he left.', extract_phrases=False)
        sentences = [s for s in segments if s.type == 'sentence']
        assert len(sentences) == 2
        assert '"Hello world."' in sentences[0].text or "Hello world" in sentences[0].text
# Allow running this test module directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View file

@ -0,0 +1,258 @@
"""
Unit tests for URI expansion functionality.
Tests that URIs are properly expanded using ontology definitions instead of
constructed fallback URIs.
"""
import pytest
from trustgraph.extract.kg.ontology.extract import Processor
from trustgraph.extract.kg.ontology.ontology_selector import OntologySubset
class MockParams:
    """Mock parameters for Processor: behaves like an empty parameter set."""

    def get(self, key, default=None):
        # No parameters are ever defined, so every lookup falls back
        # to the caller-supplied default.
        return default
@pytest.fixture
def extractor():
    """Create a Processor instance for testing.

    These tests only exercise expand_uri, so we bypass __init__ via
    object.__new__ (minimal initialization) and set only the attribute
    expand_uri reads. The previously unused `params = MockParams()`
    local has been removed.
    """
    extractor = object.__new__(Processor)
    # Standard namespace prefixes used by expand_uri.
    extractor.URI_PREFIXES = {
        "rdf:": "http://www.w3.org/1999/02/22-rdf-syntax-ns#",
        "rdfs:": "http://www.w3.org/2000/01/rdf-schema#",
        "owl:": "http://www.w3.org/2002/07/owl#",
        "xsd:": "http://www.w3.org/2001/XMLSchema#",
    }
    return extractor
@pytest.fixture
def ontology_subset_with_uris():
    """Create an ontology subset with proper URIs defined."""
    # Every class/property definition carries an explicit 'uri' key so
    # expand_uri can resolve names directly rather than constructing
    # fallback URIs. Note: definitions are plain dicts, not objects.
    return OntologySubset(
        ontology_id="food",
        # Class name -> definition dict.
        classes={
            "Recipe": {
                "uri": "http://purl.org/ontology/fo/Recipe",
                "type": "owl:Class",
                "labels": [{"value": "Recipe", "lang": "en-gb"}],
                "comment": "A Recipe is a combination of ingredients and a method."
            },
            "Ingredient": {
                "uri": "http://purl.org/ontology/fo/Ingredient",
                "type": "owl:Class",
                "labels": [{"value": "Ingredient", "lang": "en-gb"}],
                "comment": "An Ingredient combines a quantity and a food."
            },
            "Food": {
                "uri": "http://purl.org/ontology/fo/Food",
                "type": "owl:Class",
                "labels": [{"value": "Food", "lang": "en-gb"}],
                "comment": "A Food is something that can be eaten."
            }
        },
        # Object properties link instances to instances (domain/range are
        # class names from `classes` above).
        object_properties={
            "ingredients": {
                "uri": "http://purl.org/ontology/fo/ingredients",
                "type": "owl:ObjectProperty",
                "labels": [{"value": "ingredients", "lang": "en-gb"}],
                "domain": "Recipe",
                "range": "IngredientList"
            },
            "food": {
                "uri": "http://purl.org/ontology/fo/food",
                "type": "owl:ObjectProperty",
                "labels": [{"value": "food", "lang": "en-gb"}],
                "domain": "Ingredient",
                "range": "Food"
            },
            "produces": {
                "uri": "http://purl.org/ontology/fo/produces",
                "type": "owl:ObjectProperty",
                "labels": [{"value": "produces", "lang": "en-gb"}],
                "domain": "Recipe",
                "range": "Food"
            }
        },
        # Datatype properties link instances to literal values.
        datatype_properties={
            "serves": {
                "uri": "http://purl.org/ontology/fo/serves",
                "type": "owl:DatatypeProperty",
                "labels": [{"value": "serves", "lang": "en-gb"}],
                "domain": "Recipe",
                "range": "xsd:string"
            }
        },
        metadata={
            "name": "Food Ontology",
            "namespace": "http://purl.org/ontology/fo/"
        }
    )
class TestURIExpansion:
    """Test suite for URI expansion functionality."""

    def test_expand_class_uri_from_ontology(self, extractor, ontology_subset_with_uris):
        """Class names resolve to the URI declared in the ontology."""
        expanded = extractor.expand_uri("Recipe", ontology_subset_with_uris, "food")
        assert expanded == "http://purl.org/ontology/fo/Recipe", \
            "Recipe should expand to its ontology URI"

    def test_expand_object_property_uri_from_ontology(self, extractor, ontology_subset_with_uris):
        """Object properties resolve to their declared ontology URIs."""
        expanded = extractor.expand_uri("ingredients", ontology_subset_with_uris, "food")
        assert expanded == "http://purl.org/ontology/fo/ingredients", \
            "ingredients property should expand to its ontology URI"

    def test_expand_datatype_property_uri_from_ontology(self, extractor, ontology_subset_with_uris):
        """Datatype properties resolve to their declared ontology URIs."""
        expanded = extractor.expand_uri("serves", ontology_subset_with_uris, "food")
        assert expanded == "http://purl.org/ontology/fo/serves", \
            "serves property should expand to its ontology URI"

    def test_expand_multiple_classes(self, extractor, ontology_subset_with_uris):
        """Each declared class expands to its own distinct URI."""
        namespace = "http://purl.org/ontology/fo/"
        for name in ("Recipe", "Ingredient", "Food"):
            expanded = extractor.expand_uri(name, ontology_subset_with_uris, "food")
            assert expanded == namespace + name

    def test_expand_rdf_prefix(self, extractor, ontology_subset_with_uris):
        """rdf:-prefixed names expand via the standard RDF namespace."""
        expanded = extractor.expand_uri("rdf:type", ontology_subset_with_uris, "food")
        assert expanded == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type", \
            "rdf:type should expand to full RDF namespace URI"

    def test_expand_rdfs_prefix(self, extractor, ontology_subset_with_uris):
        """rdfs:-prefixed names expand via the RDFS namespace."""
        expanded = extractor.expand_uri("rdfs:label", ontology_subset_with_uris, "food")
        assert expanded == "http://www.w3.org/2000/01/rdf-schema#label", \
            "rdfs:label should expand to full RDFS namespace URI"

    def test_expand_owl_prefix(self, extractor, ontology_subset_with_uris):
        """owl:-prefixed names expand via the OWL namespace."""
        expanded = extractor.expand_uri("owl:Class", ontology_subset_with_uris, "food")
        assert expanded == "http://www.w3.org/2002/07/owl#Class", \
            "owl:Class should expand to full OWL namespace URI"

    def test_expand_xsd_prefix(self, extractor, ontology_subset_with_uris):
        """xsd:-prefixed names expand via the XML Schema namespace."""
        expanded = extractor.expand_uri("xsd:string", ontology_subset_with_uris, "food")
        assert expanded == "http://www.w3.org/2001/XMLSchema#string", \
            "xsd:string should expand to full XSD namespace URI"

    def test_fallback_uri_for_instance(self, extractor, ontology_subset_with_uris):
        """Entity instances not in the ontology get constructed URIs."""
        expanded = extractor.expand_uri("recipe:cornish-pasty", ontology_subset_with_uris, "food")
        # Should construct a URI for the instance
        assert expanded.startswith("https://trustgraph.ai/food/"), \
            "Entity instance should get constructed URI under trustgraph.ai domain"
        assert "cornish-pasty" in expanded.lower(), \
            "Instance URI should include normalized entity name"

    def test_already_full_uri_unchanged(self, extractor, ontology_subset_with_uris):
        """An http:// URI passes through untouched."""
        full_uri = "http://example.com/custom/entity"
        assert extractor.expand_uri(full_uri, ontology_subset_with_uris, "food") == full_uri, \
            "Full URIs should be returned unchanged"

    def test_https_uri_unchanged(self, extractor, ontology_subset_with_uris):
        """An https:// URI passes through untouched."""
        full_uri = "https://example.com/custom/entity"
        assert extractor.expand_uri(full_uri, ontology_subset_with_uris, "food") == full_uri, \
            "HTTPS URIs should be returned unchanged"

    def test_class_without_uri_gets_fallback(self, extractor):
        """A class defined without a 'uri' key gets a constructed fallback."""
        # Create subset with class that has no URI
        bare_subset = OntologySubset(
            ontology_id="test",
            classes={
                "SomeClass": {
                    "type": "owl:Class",
                    "labels": [{"value": "Some Class"}],
                    # No 'uri' field
                }
            },
            object_properties={},
            datatype_properties={},
            metadata={}
        )
        expanded = extractor.expand_uri("SomeClass", bare_subset, "test")
        assert expanded == "https://trustgraph.ai/ontology/test#SomeClass", \
            "Class without URI should get fallback constructed URI"

    def test_property_without_uri_gets_fallback(self, extractor):
        """A property defined without a 'uri' key gets a constructed fallback."""
        bare_subset = OntologySubset(
            ontology_id="test",
            classes={},
            object_properties={
                "someProperty": {
                    "type": "owl:ObjectProperty",
                    # No 'uri' field
                }
            },
            datatype_properties={},
            metadata={}
        )
        expanded = extractor.expand_uri("someProperty", bare_subset, "test")
        assert expanded == "https://trustgraph.ai/ontology/test#someProperty", \
            "Property without URI should get fallback constructed URI"

    def test_entity_normalization_in_constructed_uri(self, extractor, ontology_subset_with_uris):
        """Mixed-case, spaced entity names are normalized in constructed URIs."""
        # Entity with spaces and mixed case
        expanded = extractor.expand_uri("Cornish Pasty Recipe", ontology_subset_with_uris, "food")
        # Should be normalized: lowercase, spaces to hyphens
        assert expanded == "https://trustgraph.ai/food/cornish-pasty-recipe", \
            "Entity names should be normalized (lowercase, spaces to hyphens)"

    def test_dict_access_not_object_attribute(self, extractor, ontology_subset_with_uris):
        """URI expansion must read definitions via dict keys, not attributes.

        This is the key fix - ontology_selector stores cls.__dict__, so
        definitions arrive as plain dicts and must be read with key access.
        """
        class_def = ontology_subset_with_uris.classes["Recipe"]
        # Verify it's a dict
        assert isinstance(class_def, dict), "Class definitions should be dicts"
        assert "uri" in class_def, "Dict should have 'uri' key"
        # Now test expansion works
        expanded = extractor.expand_uri("Recipe", ontology_subset_with_uris, "food")
        assert expanded == class_def["uri"], \
            "URI expansion must work with dict access (not object attributes)"
# Allow running this test module directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View file

@ -119,8 +119,8 @@ class TestPineconeDocEmbeddingsQueryProcessor:
chunks = await processor.query_document_embeddings(message)
# Verify index was accessed correctly
expected_index_name = "d-test_user-test_collection"
# Verify index was accessed correctly (with dimension suffix)
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
processor.pinecone.Index.assert_called_once_with(expected_index_name)
# Verify query parameters
@ -264,10 +264,12 @@ class TestPineconeDocEmbeddingsQueryProcessor:
chunks = await processor.query_document_embeddings(message)
# Verify same index used for both vectors
expected_index_name = "d-test_user-test_collection"
# Verify different indexes used for different dimensions
assert processor.pinecone.Index.call_count == 2
processor.pinecone.Index.assert_called_with(expected_index_name)
index_calls = processor.pinecone.Index.call_args_list
index_names = [call[0][0] for call in index_calls]
assert "d-test_user-test_collection-2" in index_names # 2D vector
assert "d-test_user-test_collection-4" in index_names # 4D vector
# Verify both queries were made
assert mock_index.query.call_count == 2

View file

@ -103,8 +103,8 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
result = await processor.query_document_embeddings(mock_message)
# Assert
# Verify query was called with correct parameters
expected_collection = 'd_test_user_test_collection'
# Verify query was called with correct parameters (with dimension suffix)
expected_collection = 'd_test_user_test_collection_3' # 3 dimensions
mock_qdrant_instance.query_points.assert_called_once_with(
collection_name=expected_collection,
query=[0.1, 0.2, 0.3],
@ -164,9 +164,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Assert
# Verify query was called twice
assert mock_qdrant_instance.query_points.call_count == 2
# Verify both collections were queried
expected_collection = 'd_multi_user_multi_collection'
# Verify both collections were queried (both 2-dimensional vectors)
expected_collection = 'd_multi_user_multi_collection_2' # 2 dimensions
calls = mock_qdrant_instance.query_points.call_args_list
assert calls[0][1]['collection_name'] == expected_collection
assert calls[1][1]['collection_name'] == expected_collection
@ -301,13 +301,13 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Verify query was called twice with different collections
assert mock_qdrant_instance.query_points.call_count == 2
calls = mock_qdrant_instance.query_points.call_args_list
# First call should use 2D collection
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection'
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2' # 2 dimensions
assert calls[0][1]['query'] == [0.1, 0.2]
# Second call should use 3D collection
assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection'
assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3' # 3 dimensions
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
# Verify results

View file

@ -147,8 +147,8 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
entities = await processor.query_graph_embeddings(message)
# Verify index was accessed correctly
expected_index_name = "t-test_user-test_collection"
# Verify index was accessed correctly (with dimension suffix)
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
processor.pinecone.Index.assert_called_once_with(expected_index_name)
# Verify query parameters
@ -290,10 +290,12 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
entities = await processor.query_graph_embeddings(message)
# Verify same index used for both vectors
expected_index_name = "t-test_user-test_collection"
# Verify different indexes used for different dimensions
assert processor.pinecone.Index.call_count == 2
processor.pinecone.Index.assert_called_with(expected_index_name)
index_calls = processor.pinecone.Index.call_args_list
index_names = [call[0][0] for call in index_calls]
assert "t-test_user-test_collection-2" in index_names # 2D vector
assert "t-test_user-test_collection-4" in index_names # 4D vector
# Verify both queries were made
assert mock_index.query.call_count == 2

View file

@ -175,8 +175,8 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
result = await processor.query_graph_embeddings(mock_message)
# Assert
# Verify query was called with correct parameters
expected_collection = 't_test_user_test_collection'
# Verify query was called with correct parameters (with dimension suffix)
expected_collection = 't_test_user_test_collection_3' # 3 dimensions
mock_qdrant_instance.query_points.assert_called_once_with(
collection_name=expected_collection,
query=[0.1, 0.2, 0.3],
@ -234,9 +234,9 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Assert
# Verify query was called twice
assert mock_qdrant_instance.query_points.call_count == 2
# Verify both collections were queried
expected_collection = 't_multi_user_multi_collection'
# Verify both collections were queried (both 2-dimensional vectors)
expected_collection = 't_multi_user_multi_collection_2' # 2 dimensions
calls = mock_qdrant_instance.query_points.call_args_list
assert calls[0][1]['collection_name'] == expected_collection
assert calls[1][1]['collection_name'] == expected_collection
@ -372,13 +372,13 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Verify query was called twice with different collections
assert mock_qdrant_instance.query_points.call_count == 2
calls = mock_qdrant_instance.query_points.call_args_list
# First call should use 2D collection
assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection'
assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2' # 2 dimensions
assert calls[0][1]['query'] == [0.1, 0.2]
# Second call should use 3D collection
assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection'
assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection_3' # 3 dimensions
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
# Verify results

View file

@ -134,8 +134,8 @@ class TestPineconeDocEmbeddingsStorageProcessor:
with patch('uuid.uuid4', side_effect=['id1', 'id2']):
await processor.store_document_embeddings(message)
# Verify index name and operations
expected_index_name = "d-test_user-test_collection"
# Verify index name and operations (with dimension suffix)
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
processor.pinecone.Index.assert_called_with(expected_index_name)
# Verify upsert was called for each vector
@ -179,7 +179,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
@pytest.mark.asyncio
async def test_store_document_embeddings_index_validation(self, processor):
"""Test that writing to non-existent index raises ValueError"""
"""Test that writing to non-existent index creates it lazily"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
@ -191,12 +191,24 @@ class TestPineconeDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
# Mock index doesn't exist
# Mock index doesn't exist initially
processor.pinecone.has_index.return_value = False
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
with pytest.raises(ValueError, match="Collection .* does not exist"):
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
# Verify index was created with correct dimension
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
processor.pinecone.create_index.assert_called_once()
create_call = processor.pinecone.create_index.call_args
assert create_call[1]['name'] == expected_index_name
assert create_call[1]['dimension'] == 3
# Verify upsert was still called
mock_index.upsert.assert_called_once()
@pytest.mark.asyncio
async def test_store_document_embeddings_empty_chunk(self, processor):
"""Test storing document embeddings with empty chunk (should be skipped)"""
@ -345,7 +357,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
@pytest.mark.asyncio
async def test_store_document_embeddings_validation_before_creation(self, processor):
"""Test that validation error occurs before creation attempts"""
"""Test that lazy creation happens when index doesn't exist"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
@ -359,13 +371,18 @@ class TestPineconeDocEmbeddingsStorageProcessor:
# Mock index doesn't exist
processor.pinecone.has_index.return_value = False
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
with pytest.raises(ValueError, match="Collection .* does not exist"):
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
# Verify index was created
processor.pinecone.create_index.assert_called_once()
@pytest.mark.asyncio
async def test_store_document_embeddings_validates_before_timeout(self, processor):
"""Test that validation error occurs before timeout checks"""
"""Test that lazy creation works correctly"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
@ -379,10 +396,16 @@ class TestPineconeDocEmbeddingsStorageProcessor:
# Mock index doesn't exist
processor.pinecone.has_index.return_value = False
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
with pytest.raises(ValueError, match="Collection .* does not exist"):
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
# Verify index was created and used
processor.pinecone.create_index.assert_called_once()
mock_index.upsert.assert_called_once()
@pytest.mark.asyncio
async def test_store_document_embeddings_unicode_content(self, processor):
"""Test storing document embeddings with Unicode content"""

View file

@ -103,8 +103,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
await processor.store_document_embeddings(mock_message)
# Assert
# Verify collection existence was checked
expected_collection = 'd_test_user_test_collection'
# Verify collection existence was checked (with dimension suffix)
expected_collection = 'd_test_user_test_collection_3' # 3 dimensions in vector [0.1, 0.2, 0.3]
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
# Verify upsert was called
@ -112,7 +112,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Verify upsert parameters
upsert_call_args = mock_qdrant_instance.upsert.call_args
assert upsert_call_args[1]['collection_name'] == expected_collection
assert upsert_call_args[1]['collection_name'] == 'd_test_user_test_collection_3'
assert len(upsert_call_args[1]['points']) == 1
point = upsert_call_args[1]['points'][0]
@ -272,18 +272,21 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Assert
# Should not call upsert for empty chunks
mock_qdrant_instance.upsert.assert_not_called()
# But collection_exists should be called for validation
mock_qdrant_instance.collection_exists.assert_called_once()
# collection_exists should NOT be called since we return early for empty chunks
mock_qdrant_instance.collection_exists.assert_not_called()
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
async def test_collection_creation_when_not_exists(self, mock_base_init, mock_qdrant_client):
"""Test that writing to non-existent collection raises ValueError"""
async def test_collection_creation_when_not_exists(self, mock_base_init, mock_uuid, mock_qdrant_client):
"""Test that writing to non-existent collection creates it lazily"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value = MagicMock()
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
config = {
'store_uri': 'http://localhost:6333',
@ -305,19 +308,36 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk]
# Act & Assert
with pytest.raises(ValueError, match="Collection .* does not exist"):
await processor.store_document_embeddings(mock_message)
# Act
await processor.store_document_embeddings(mock_message)
# Assert - collection should be lazily created
expected_collection = 'd_new_user_new_collection_5' # 5 dimensions
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
mock_qdrant_instance.create_collection.assert_called_once()
# Verify create_collection was called with correct parameters
create_call = mock_qdrant_instance.create_collection.call_args
assert create_call[1]['collection_name'] == expected_collection
assert create_call[1]['vectors_config'].size == 5
# Verify upsert was still called
mock_qdrant_instance.upsert.assert_called_once()
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
async def test_collection_creation_exception(self, mock_base_init, mock_qdrant_client):
"""Test that validation error occurs before connection errors"""
async def test_collection_creation_exception(self, mock_base_init, mock_uuid, mock_qdrant_client):
"""Test that collection creation errors are propagated"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist
# Simulate creation failure
mock_qdrant_instance.create_collection.side_effect = Exception("Connection error")
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value = MagicMock()
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
config = {
'store_uri': 'http://localhost:6333',
@ -339,8 +359,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk]
# Act & Assert
with pytest.raises(ValueError, match="Collection .* does not exist"):
# Act & Assert - should propagate the creation error
with pytest.raises(Exception, match="Connection error"):
await processor.store_document_embeddings(mock_message)
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@ -398,7 +418,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
await processor.store_document_embeddings(mock_message2)
# Assert
expected_collection = 'd_cache_user_cache_collection'
expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions
# Verify collection existence is checked on each write
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
@ -407,15 +427,18 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_qdrant_instance.upsert.assert_called_once()
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
async def test_different_dimensions_different_collections(self, mock_base_init, mock_qdrant_client):
async def test_different_dimensions_different_collections(self, mock_base_init, mock_uuid, mock_qdrant_client):
"""Test that different vector dimensions create different collections"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value = MagicMock()
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@ -424,35 +447,39 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# Create mock message with different dimension vectors
mock_message = MagicMock()
mock_message.metadata.user = 'dim_user'
mock_message.metadata.collection = 'dim_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'dimension test chunk'
mock_chunk.vectors = [
[0.1, 0.2], # 2 dimensions
[0.3, 0.4, 0.5] # 3 dimensions
]
mock_message.chunks = [mock_chunk]
# Act
await processor.store_document_embeddings(mock_message)
# Assert
# Should check existence of the same collection (dimensions no longer create separate collections)
expected_collection = 'd_dim_user_dim_collection'
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
# Should check existence of DIFFERENT collections for each dimension
assert mock_qdrant_instance.collection_exists.call_count == 2
# Should upsert to the same collection for both vectors
# Verify the two different collection names were checked
collection_exists_calls = [call[0][0] for call in mock_qdrant_instance.collection_exists.call_args_list]
assert 'd_dim_user_dim_collection_2' in collection_exists_calls # 2-dim vector
assert 'd_dim_user_dim_collection_3' in collection_exists_calls # 3-dim vector
# Should upsert to different collections for each vector
assert mock_qdrant_instance.upsert.call_count == 2
upsert_calls = mock_qdrant_instance.upsert.call_args_list
assert upsert_calls[0][1]['collection_name'] == expected_collection
assert upsert_calls[1][1]['collection_name'] == expected_collection
assert upsert_calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2'
assert upsert_calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3'
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')

View file

@ -134,8 +134,8 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
with patch('uuid.uuid4', side_effect=['id1', 'id2']):
await processor.store_graph_embeddings(message)
# Verify index name and operations
expected_index_name = "t-test_user-test_collection"
# Verify index name and operations (with dimension suffix)
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
processor.pinecone.Index.assert_called_with(expected_index_name)
# Verify upsert was called for each vector
@ -179,7 +179,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
@pytest.mark.asyncio
async def test_store_graph_embeddings_index_validation(self, processor):
"""Test that writing to non-existent index raises ValueError"""
"""Test that writing to non-existent index creates it lazily"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
@ -193,10 +193,22 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
# Mock index doesn't exist
processor.pinecone.has_index.return_value = False
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
with pytest.raises(ValueError, match="Collection .* does not exist"):
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_graph_embeddings(message)
# Verify index was created with correct dimension
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
processor.pinecone.create_index.assert_called_once()
create_call = processor.pinecone.create_index.call_args
assert create_call[1]['name'] == expected_index_name
assert create_call[1]['dimension'] == 3
# Verify upsert was still called
mock_index.upsert.assert_called_once()
@pytest.mark.asyncio
async def test_store_graph_embeddings_empty_entity_value(self, processor):
"""Test storing graph embeddings with empty entity value (should be skipped)"""
@ -267,11 +279,16 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
await processor.store_graph_embeddings(message)
# Verify same index was used for all dimensions
expected_index_name = 't-test_user-test_collection'
processor.pinecone.Index.assert_called_with(expected_index_name)
# Verify different indexes were used for different dimensions
index_calls = processor.pinecone.Index.call_args_list
assert len(index_calls) == 3
# Extract index names from calls
index_names = [call[0][0] for call in index_calls]
assert 't-test_user-test_collection-2' in index_names # 2D vector
assert 't-test_user-test_collection-4' in index_names # 4D vector
assert 't-test_user-test_collection-3' in index_names # 3D vector
# Verify all vectors were upserted to the same index
# Verify all vectors were upserted (to their respective indexes)
assert mock_index.upsert.call_count == 3
@pytest.mark.asyncio
@ -316,7 +333,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
@pytest.mark.asyncio
async def test_store_graph_embeddings_validation_before_creation(self, processor):
"""Test that validation error occurs before any creation attempts"""
"""Test that lazy creation happens when index doesn't exist"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
@ -330,13 +347,18 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
# Mock index doesn't exist
processor.pinecone.has_index.return_value = False
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
with pytest.raises(ValueError, match="Collection .* does not exist"):
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_graph_embeddings(message)
# Verify index was created
processor.pinecone.create_index.assert_called_once()
@pytest.mark.asyncio
async def test_store_graph_embeddings_validates_before_timeout(self, processor):
"""Test that validation error occurs before timeout checks"""
"""Test that lazy creation works correctly"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
@ -350,10 +372,16 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
# Mock index doesn't exist
processor.pinecone.has_index.return_value = False
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
with pytest.raises(ValueError, match="Collection .* does not exist"):
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_graph_embeddings(message)
# Verify index was created and used
processor.pinecone.create_index.assert_called_once()
mock_index.upsert.assert_called_once()
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser

View file

@ -44,29 +44,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
assert hasattr(processor, 'qdrant')
assert processor.qdrant == mock_qdrant_instance
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
async def test_get_collection_validates_existence(self, mock_base_init, mock_qdrant_client):
"""Test get_collection validates that collection exists"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = False
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-qdrant-processor'
}
processor = Processor(**config)
# Act & Assert
with pytest.raises(ValueError, match="Collection .* does not exist"):
processor.get_collection(user='test_user', collection='test_collection')
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
@ -103,114 +80,22 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
await processor.store_graph_embeddings(mock_message)
# Assert
# Verify collection existence was checked
expected_collection = 't_test_user_test_collection'
# Verify collection existence was checked (with dimension suffix)
expected_collection = 't_test_user_test_collection_3' # 3 dimensions in vector [0.1, 0.2, 0.3]
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
# Verify upsert was called
mock_qdrant_instance.upsert.assert_called_once()
# Verify upsert parameters
upsert_call_args = mock_qdrant_instance.upsert.call_args
assert upsert_call_args[1]['collection_name'] == expected_collection
assert upsert_call_args[1]['collection_name'] == 't_test_user_test_collection_3'
assert len(upsert_call_args[1]['points']) == 1
point = upsert_call_args[1]['points'][0]
assert point.vector == [0.1, 0.2, 0.3]
assert point.payload['entity'] == 'test_entity'
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
async def test_get_collection_uses_existing_collection(self, mock_base_init, mock_qdrant_client):
"""Test get_collection uses existing collection without creating new one"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True # Collection exists
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-qdrant-processor'
}
processor = Processor(**config)
# Act
collection_name = processor.get_collection(user='existing_user', collection='existing_collection')
# Assert
expected_name = 't_existing_user_existing_collection'
assert collection_name == expected_name
# Verify collection existence check was performed
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
# Verify create_collection was NOT called
mock_qdrant_instance.create_collection.assert_not_called()
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
async def test_get_collection_validates_on_each_call(self, mock_base_init, mock_qdrant_client):
"""Test get_collection validates collection existence on each call"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-qdrant-processor'
}
processor = Processor(**config)
# First call
collection_name1 = processor.get_collection(user='cache_user', collection='cache_collection')
# Reset mock to track second call
mock_qdrant_instance.reset_mock()
mock_qdrant_instance.collection_exists.return_value = True
# Act - Second call with same parameters
collection_name2 = processor.get_collection(user='cache_user', collection='cache_collection')
# Assert
expected_name = 't_cache_user_cache_collection'
assert collection_name1 == expected_name
assert collection_name2 == expected_name
# Verify collection existence check happens on each call
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
mock_qdrant_instance.create_collection.assert_not_called()
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
async def test_get_collection_creation_exception(self, mock_base_init, mock_qdrant_client):
"""Test get_collection raises ValueError when collection doesn't exist"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = False
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-qdrant-processor'
}
processor = Processor(**config)
# Act & Assert
with pytest.raises(ValueError, match="Collection .* does not exist"):
processor.get_collection(user='error_user', collection='error_collection')
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')

View file

@ -102,11 +102,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
}]
}],
temperature=0.0,
max_tokens=4096,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
response_format={"type": "text"}
max_tokens=4096
)
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@ -385,10 +381,6 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
assert call_args[1]['model'] == 'gpt-3.5-turbo'
assert call_args[1]['temperature'] == 0.5
assert call_args[1]['max_tokens'] == 1024
assert call_args[1]['top_p'] == 1
assert call_args[1]['frequency_penalty'] == 0
assert call_args[1]['presence_penalty'] == 0
assert call_args[1]['response_format'] == {"type": "text"}
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')

View file

@ -9,7 +9,7 @@ from prometheus_client import Histogram
from .. schema import EmbeddingsRequest, EmbeddingsResponse, Error
from .. exceptions import TooManyRequests
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec, ParameterSpec
# Module logger
logger = logging.getLogger(__name__)
@ -45,6 +45,12 @@ class EmbeddingsService(FlowProcessor):
)
)
self.register_specification(
ParameterSpec(
name = "model",
)
)
async def on_request(self, msg, consumer, flow):
try:
@ -57,7 +63,9 @@ class EmbeddingsService(FlowProcessor):
logger.debug(f"Handling embeddings request {id}...")
vectors = await self.on_embeddings(request.text)
# Pass model from request if specified (non-empty), otherwise use default
model = flow("model")
vectors = await self.on_embeddings(request.text, model=model)
await flow("response").send(
EmbeddingsResponse(

View file

@ -1,4 +1,5 @@
RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
RDF_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
DEFINITION = "http://www.w3.org/2004/02/skos/core#definition"
SUBJECT_OF = "https://schema.org/subjectOf"

View file

@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=1.4,<1.5",
"trustgraph-base>=1.5,<1.6",
"pulsar-client",
"prometheus-client",
"boto3",

View file

@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=1.4,<1.5",
"trustgraph-base>=1.5,<1.6",
"requests",
"pulsar-client",
"aiohttp",

View file

@ -7,6 +7,7 @@ specification. This script stores MCP tool configurations with:
- id: Unique identifier for the tool
- remote-name: Name used by the MCP server (defaults to id)
- url: MCP server endpoint URL
- auth-token: Optional bearer token for authentication
Configurations are stored in the 'mcp' configuration group and can be
referenced by agent tools using the 'mcp-tool' type.
@ -25,17 +26,24 @@ def set_mcp_tool(
id : str,
remote_name : str,
tool_url : str,
auth_token : str = None,
):
api = Api(url).config()
# Build the MCP tool configuration
config = {
"remote-name": remote_name,
"url": tool_url,
}
if auth_token:
config["auth-token"] = auth_token
# Store the MCP tool configuration in the 'mcp' group
values = api.put([
ConfigValue(
type="mcp", key=id, value=json.dumps({
"remote-name": remote_name,
"url": tool_url,
})
type="mcp", key=id, value=json.dumps(config)
)
])
@ -45,12 +53,15 @@ def main():
prog='tg-set-mcp-tool',
description=__doc__,
epilog=textwrap.dedent('''
MCP tools are configured with just a name and URL. The URL should point
MCP tools are configured with a name and URL. The URL should point
to the MCP server endpoint that provides the tool functionality.
Optionally, an auth-token can be provided for secured endpoints.
Examples:
%(prog)s --id weather --tool-url "http://localhost:3000/weather"
%(prog)s --id calculator --tool-url "http://mcp-tools.example.com/calc"
%(prog)s --id secure-tool --tool-url "https://api.example.com/mcp" \\
--auth-token "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
''').strip(),
formatter_class=argparse.RawDescriptionHelpFormatter
)
@ -79,6 +90,12 @@ def main():
help='MCP tool URL endpoint',
)
parser.add_argument(
'--auth-token',
required=False,
help='Bearer token for authentication (optional)',
)
args = parser.parse_args()
try:
@ -98,7 +115,8 @@ def main():
url=args.api_url,
id=args.id,
remote_name=remote_name,
tool_url=args.tool_url
tool_url=args.tool_url,
auth_token=args.auth_token
)
except Exception as e:

View file

@ -27,6 +27,12 @@ def show_config(url):
table.append(("remote-name", data["remote-name"]))
table.append(("url", data["url"]))
# Display auth status (masked for security)
if "auth-token" in data and data["auth-token"]:
table.append(("auth", "Yes (configured)"))
else:
table.append(("auth", "No"))
print()
print(tabulate.tabulate(

View file

@ -10,8 +10,8 @@ description = "HuggingFace embeddings support for TrustGraph."
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=1.4,<1.5",
"trustgraph-flow>=1.4,<1.5",
"trustgraph-base>=1.5,<1.6",
"trustgraph-flow>=1.5,<1.6",
"torch",
"urllib3",
"transformers",

View file

@ -26,10 +26,31 @@ class Processor(EmbeddingsService):
**params | { "model": model }
)
logger.info(f"Loading HuggingFace embeddings model: {model}")
self.embeddings = HuggingFaceEmbeddings(model_name=model)
self.default_model = model
async def on_embeddings(self, text):
# Cache for currently loaded model
self.cached_model_name = None
self.embeddings = None
# Load the default model
self._load_model(model)
def _load_model(self, model_name):
"""Load a model, caching it for reuse"""
if self.cached_model_name != model_name:
logger.info(f"Loading HuggingFace embeddings model: {model_name}")
self.embeddings = HuggingFaceEmbeddings(model_name=model_name)
self.cached_model_name = model_name
logger.info(f"HuggingFace model {model_name} loaded successfully")
else:
logger.debug(f"Using cached model: {model_name}")
async def on_embeddings(self, text, model=None):
use_model = model or self.default_model
# Reload model if it has changed
self._load_model(use_model)
embeds = self.embeddings.embed_documents([text])
logger.debug("Embeddings generation complete")

View file

@ -10,12 +10,13 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=1.4,<1.5",
"trustgraph-base>=1.5,<1.6",
"aiohttp",
"anthropic",
"cassandra-driver",
"scylla-driver",
"cohere",
"cryptography",
"faiss-cpu",
"falkordb",
"fastembed",
"google-genai",
@ -29,6 +30,7 @@ dependencies = [
"minio",
"mistralai",
"neo4j",
"nltk",
"ollama",
"openai",
"pinecone[grpc]",
@ -82,6 +84,7 @@ kg-extract-definitions = "trustgraph.extract.kg.definitions:run"
kg-extract-objects = "trustgraph.extract.kg.objects:run"
kg-extract-relationships = "trustgraph.extract.kg.relationships:run"
kg-extract-topics = "trustgraph.extract.kg.topics:run"
kg-extract-ontology = "trustgraph.extract.kg.ontology:run"
kg-manager = "trustgraph.cores:run"
kg-store = "trustgraph.storage.knowledge:run"
librarian = "trustgraph.librarian:run"

View file

@ -56,10 +56,16 @@ class Service(ToolService):
else:
remote_name = name
# Build headers with optional bearer token
headers = {}
if "auth-token" in self.mcp_services[name]:
token = self.mcp_services[name]["auth-token"]
headers["Authorization"] = f"Bearer {token}"
logger.info(f"Invoking {remote_name} at {url}")
# Connect to a streamable HTTP server
async with streamablehttp_client(url) as (
# Connect to a streamable HTTP server with headers
async with streamablehttp_client(url, headers=headers) as (
read_stream,
write_stream,
_,

View file

@ -317,7 +317,7 @@ class Processor(AgentService):
AgentStep(
thought=h.thought,
action=h.name,
arguments=h.arguments,
arguments={k: str(v) for k, v in h.arguments.items()},
observation=h.observation
)
for h in history

View file

@ -334,13 +334,14 @@ class KnowledgeGraph:
count += 1
# Execute batch every 100 triples to avoid oversized batches
if count % 100 == 0:
# Execute batch every 25 triples to avoid oversized batches
# (Each triple adds ~4 statements, so 25 triples = ~100 statements)
if count % 25 == 0:
self.session.execute(batch)
batch = BatchStatement()
# Execute remaining deletions
if count % 100 != 0:
if count % 25 != 0:
self.session.execute(batch)
# Step 3: Delete collection metadata

View file

@ -50,24 +50,26 @@ class DocVectors:
logger.debug(f"Reload at {self.next_reload}")
def collection_exists(self, user, collection):
"""Check if collection exists (dimension-independent check)"""
collection_name = make_safe_collection_name(user, collection, self.prefix)
return self.client.has_collection(collection_name)
"""
Check if any collection exists for this user/collection combination.
Since collections are dimension-specific, this checks if ANY dimension variant exists.
"""
base_name = make_safe_collection_name(user, collection, self.prefix)
prefix = f"{base_name}_"
all_collections = self.client.list_collections()
return any(coll.startswith(prefix) for coll in all_collections)
def create_collection(self, user, collection, dimension=384):
"""Create collection with default dimension"""
collection_name = make_safe_collection_name(user, collection, self.prefix)
if self.client.has_collection(collection_name):
logger.info(f"Collection {collection_name} already exists")
return
self.init_collection(dimension, user, collection)
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
"""
No-op for explicit collection creation.
Collections are created lazily on first insert with actual dimension.
"""
logger.info(f"Collection creation requested for {user}/{collection} - will be created lazily on first insert")
def init_collection(self, dimension, user, collection):
collection_name = make_safe_collection_name(user, collection, self.prefix)
base_name = make_safe_collection_name(user, collection, self.prefix)
collection_name = f"{base_name}_{dimension}"
pkey_field = FieldSchema(
name="id",
@ -115,6 +117,7 @@ class DocVectors:
)
self.collections[(dimension, user, collection)] = collection_name
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
def insert(self, embeds, doc, user, collection):
@ -139,8 +142,15 @@ class DocVectors:
dim = len(embeds)
# Check if collection exists - return empty if not
if (dim, user, collection) not in self.collections:
self.init_collection(dim, user, collection)
base_name = make_safe_collection_name(user, collection, self.prefix)
collection_name = f"{base_name}_{dim}"
if not self.client.has_collection(collection_name):
logger.info(f"Collection {collection_name} does not exist, returning empty results")
return []
# Collection exists but not in cache, add it
self.collections[(dim, user, collection)] = collection_name
coll = self.collections[(dim, user, collection)]
@ -172,19 +182,27 @@ class DocVectors:
return res
def delete_collection(self, user, collection):
"""Delete a collection for the given user and collection"""
collection_name = make_safe_collection_name(user, collection, self.prefix)
"""
Delete all dimension variants of the collection for the given user/collection.
Since collections are created with dimension suffixes, we need to find and delete all.
"""
base_name = make_safe_collection_name(user, collection, self.prefix)
prefix = f"{base_name}_"
# Check if collection exists
if self.client.has_collection(collection_name):
# Drop the collection
self.client.drop_collection(collection_name)
logger.info(f"Deleted Milvus collection: {collection_name}")
# Get all collections and filter for matches
all_collections = self.client.list_collections()
matching_collections = [coll for coll in all_collections if coll.startswith(prefix)]
# Remove from our local cache
keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection]
for key in keys_to_remove:
del self.collections[key]
if not matching_collections:
logger.info(f"No collections found matching prefix {prefix}")
else:
logger.info(f"Collection {collection_name} does not exist, nothing to delete")
for collection_name in matching_collections:
self.client.drop_collection(collection_name)
logger.info(f"Deleted Milvus collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}")
# Remove from our local cache
keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection]
for key in keys_to_remove:
del self.collections[key]

View file

@ -50,24 +50,26 @@ class EntityVectors:
logger.debug(f"Reload at {self.next_reload}")
def collection_exists(self, user, collection):
"""Check if collection exists (dimension-independent check)"""
collection_name = make_safe_collection_name(user, collection, self.prefix)
return self.client.has_collection(collection_name)
"""
Check if any collection exists for this user/collection combination.
Since collections are dimension-specific, this checks if ANY dimension variant exists.
"""
base_name = make_safe_collection_name(user, collection, self.prefix)
prefix = f"{base_name}_"
all_collections = self.client.list_collections()
return any(coll.startswith(prefix) for coll in all_collections)
def create_collection(self, user, collection, dimension=384):
"""Create collection with default dimension"""
collection_name = make_safe_collection_name(user, collection, self.prefix)
if self.client.has_collection(collection_name):
logger.info(f"Collection {collection_name} already exists")
return
self.init_collection(dimension, user, collection)
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
"""
No-op for explicit collection creation.
Collections are created lazily on first insert with actual dimension.
"""
logger.info(f"Collection creation requested for {user}/{collection} - will be created lazily on first insert")
def init_collection(self, dimension, user, collection):
collection_name = make_safe_collection_name(user, collection, self.prefix)
base_name = make_safe_collection_name(user, collection, self.prefix)
collection_name = f"{base_name}_{dimension}"
pkey_field = FieldSchema(
name="id",
@ -115,6 +117,7 @@ class EntityVectors:
)
self.collections[(dimension, user, collection)] = collection_name
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
def insert(self, embeds, entity, user, collection):
@ -139,8 +142,15 @@ class EntityVectors:
dim = len(embeds)
# Check if collection exists - return empty if not
if (dim, user, collection) not in self.collections:
self.init_collection(dim, user, collection)
base_name = make_safe_collection_name(user, collection, self.prefix)
collection_name = f"{base_name}_{dim}"
if not self.client.has_collection(collection_name):
logger.info(f"Collection {collection_name} does not exist, returning empty results")
return []
# Collection exists but not in cache, add it
self.collections[(dim, user, collection)] = collection_name
coll = self.collections[(dim, user, collection)]
@ -172,19 +182,27 @@ class EntityVectors:
return res
def delete_collection(self, user, collection):
"""Delete a collection for the given user and collection"""
collection_name = make_safe_collection_name(user, collection, self.prefix)
"""
Delete all dimension variants of the collection for the given user/collection.
Since collections are created with dimension suffixes, we need to find and delete all.
"""
base_name = make_safe_collection_name(user, collection, self.prefix)
prefix = f"{base_name}_"
# Check if collection exists
if self.client.has_collection(collection_name):
# Drop the collection
self.client.drop_collection(collection_name)
logger.info(f"Deleted Milvus collection: {collection_name}")
# Get all collections and filter for matches
all_collections = self.client.list_collections()
matching_collections = [coll for coll in all_collections if coll.startswith(prefix)]
# Remove from our local cache
keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection]
for key in keys_to_remove:
del self.collections[key]
if not matching_collections:
logger.info(f"No collections found matching prefix {prefix}")
else:
logger.info(f"Collection {collection_name} does not exist, nothing to delete")
for collection_name in matching_collections:
self.client.drop_collection(collection_name)
logger.info(f"Deleted Milvus collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}")
# Remove from our local cache
keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection]
for key in keys_to_remove:
del self.collections[key]

View file

@ -27,10 +27,31 @@ class Processor(EmbeddingsService):
**params | { "model": model }
)
logger.info("Loading FastEmbed model...")
self.embeddings = TextEmbedding(model_name = model)
self.default_model = model
async def on_embeddings(self, text):
# Cache for currently loaded model
self.cached_model_name = None
self.embeddings = None
# Load the default model
self._load_model(model)
def _load_model(self, model_name):
"""Load a model, caching it for reuse"""
if self.cached_model_name != model_name:
logger.info(f"Loading FastEmbed model: {model_name}")
self.embeddings = TextEmbedding(model_name=model_name)
self.cached_model_name = model_name
logger.info(f"FastEmbed model {model_name} loaded successfully")
else:
logger.debug(f"Using cached model: {model_name}")
async def on_embeddings(self, text, model=None):
use_model = model or self.default_model
# Reload model if it has changed
self._load_model(use_model)
vecs = self.embeddings.embed([text])

View file

@ -28,12 +28,14 @@ class Processor(EmbeddingsService):
)
self.client = Client(host=ollama)
self.model = model
self.default_model = model
async def on_embeddings(self, text):
async def on_embeddings(self, text, model=None):
use_model = model or self.default_model
embeds = self.client.embed(
model = self.model,
model = use_model,
input = text
)

View file

@ -0,0 +1 @@
from . extract import *

View file

@ -0,0 +1,848 @@
"""
OntoRAG: Ontology-based knowledge extraction service.
Extracts ontology-conformant triples from text chunks.
"""
import json
import logging
import asyncio
from typing import List, Dict, Any, Optional
from .... schema import Chunk, Triple, Triples, Metadata, Value
from .... schema import EntityContext, EntityContexts
from .... schema import PromptRequest, PromptResponse
from .... rdf import TRUSTGRAPH_ENTITIES, RDF_TYPE, RDF_LABEL, DEFINITION
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base import PromptClientSpec, EmbeddingsClientSpec
from .ontology_loader import OntologyLoader
from .ontology_embedder import OntologyEmbedder
from .vector_store import InMemoryVectorStore
from .text_processor import TextProcessor
from .ontology_selector import OntologySelector, OntologySubset
logger = logging.getLogger(__name__)
default_ident = "kg-extract-ontology"
default_concurrency = 1
# URI prefix mappings for common namespaces
URI_PREFIXES = {
"rdf:": "http://www.w3.org/1999/02/22-rdf-syntax-ns#",
"rdfs:": "http://www.w3.org/2000/01/rdf-schema#",
"owl:": "http://www.w3.org/2002/07/owl#",
"skos:": "http://www.w3.org/2004/02/skos/core#",
"schema:": "https://schema.org/",
"xsd:": "http://www.w3.org/2001/XMLSchema#",
}
class Processor(FlowProcessor):
"""Main OntoRAG extraction processor."""
def __init__(self, **params):
    """Configure the OntoRAG extraction processor.

    Registers the chunk consumer, prompt/embeddings client specs and the
    triples / entity-context producers, then sets up shared (ontology
    loader, text processor) and per-flow component state.

    Args:
        params: Processor keyword configuration; recognised keys are
            "id", "concurrency", "top_k" and "similarity_threshold",
            everything else is passed through to FlowProcessor.
    """
    id = params.get("id", default_ident)
    concurrency = params.get("concurrency", default_concurrency)
    super(Processor, self).__init__(
        **params | {
            "id": id,
            "concurrency": concurrency,
        }
    )
    # Register specifications
    # Incoming text chunks to extract from.
    self.register_specification(
        ConsumerSpec(
            name="input",
            schema=Chunk,
            handler=self.on_message,
            concurrency=concurrency,
        )
    )
    # LLM prompt service client (used for the extraction prompt).
    self.register_specification(
        PromptClientSpec(
            request_name="prompt-request",
            response_name="prompt-response",
        )
    )
    # Embeddings service client (used to embed ontology elements).
    self.register_specification(
        EmbeddingsClientSpec(
            request_name="embeddings-request",
            response_name="embeddings-response"
        )
    )
    # Output producers: extracted triples and entity contexts.
    self.register_specification(
        ProducerSpec(
            name="triples",
            schema=Triples
        )
    )
    self.register_specification(
        ProducerSpec(
            name="entity-contexts",
            schema=EntityContexts
        )
    )
    # Register config handler for ontology updates
    self.register_config_handler(self.on_ontology_config)
    # Shared components (not flow-specific)
    self.ontology_loader = OntologyLoader()
    self.text_processor = TextProcessor()
    # Per-flow components (each flow gets its own embedder/vector store/selector)
    self.flow_components = {}  # flow_id -> {embedder, vector_store, selector}
    # Configuration
    self.top_k = params.get("top_k", 10)
    self.similarity_threshold = params.get("similarity_threshold", 0.3)
    # Track loaded ontology version so config replays can be skipped.
    self.current_ontology_version = None
    self.loaded_ontology_ids = set()
async def initialize_flow_components(self, flow):
    """Initialize per-flow OntoRAG components.

    Each flow gets its own vector store and embedder to support
    different embedding models across flows. The vector store dimension
    is auto-detected from the embeddings service by embedding a probe
    string and measuring the returned vector.

    NOTE(review): the flow is keyed by id(flow); this assumes the flow
    object stays alive for the processor's lifetime (a garbage-collected
    flow could let a new object reuse the same id) — confirm lifetime.
    NOTE(review): with concurrency > 1 two messages for the same flow
    could race through the `flow_id in self.flow_components` check and
    both embed the ontologies — harmless but wasteful; confirm.

    Args:
        flow: Flow object for this processing context

    Returns:
        flow_id: Identifier for this flow's components

    Raises:
        Exception: re-raised after logging if any initialization step
            (probe embedding, ontology embedding) fails.
    """
    # Use flow object as identifier
    flow_id = id(flow)
    if flow_id in self.flow_components:
        return flow_id  # Already initialized for this flow
    try:
        logger.info(f"Initializing components for flow {flow_id}")
        # Use embeddings client directly (no wrapper needed)
        embeddings_client = flow("embeddings-request")
        # Detect embedding dimension by embedding a test string
        logger.info("Detecting embedding dimension from embeddings service...")
        test_embedding_response = await embeddings_client.embed("test")
        test_embedding = test_embedding_response[0]  # Extract from [[vector]]
        dimension = len(test_embedding)
        logger.info(f"Detected embedding dimension: {dimension}")
        # Initialize vector store with detected dimension
        vector_store = InMemoryVectorStore(
            dimension=dimension,
            index_type='flat'
        )
        ontology_embedder = OntologyEmbedder(
            embedding_service=embeddings_client,
            vector_store=vector_store
        )
        # Embed all loaded ontologies for this flow
        if self.ontology_loader.get_all_ontologies():
            logger.info(f"Embedding ontologies for flow {flow_id}")
            for ont_id, ontology in self.ontology_loader.get_all_ontologies().items():
                await ontology_embedder.embed_ontology(ontology)
            logger.info(f"Embedded {ontology_embedder.get_embedded_count()} ontology elements for flow {flow_id}")
        # Initialize ontology selector
        ontology_selector = OntologySelector(
            ontology_embedder=ontology_embedder,
            ontology_loader=self.ontology_loader,
            top_k=self.top_k,
            similarity_threshold=self.similarity_threshold
        )
        # Store flow-specific components
        self.flow_components[flow_id] = {
            'embedder': ontology_embedder,
            'vector_store': vector_store,
            'selector': ontology_selector,
            'dimension': dimension
        }
        logger.info(f"Flow {flow_id} components initialized successfully (dimension={dimension})")
        return flow_id
    except Exception as e:
        logger.error(f"Failed to initialize flow {flow_id} components: {e}", exc_info=True)
        raise
async def on_ontology_config(self, config, version):
    """
    Handle ontology configuration updates from ConfigPush queue.

    Parses and stores ontologies. Embedding happens per-flow on first message.

    Called automatically when:
    - Processor starts (gets full config history via start_of_messages=True)
    - Config service pushes updates (immediate event-driven notification)

    Args:
        config: Full configuration map - config[type][key] = value
        version: Config version number (monotonically increasing)
    """
    try:
        logger.info(f"Received ontology config update, version={version}")
        # Skip stale or duplicate updates.  Versions are monotonically
        # increasing, so any version <= the one we already hold carries
        # no new information.  (Previously only an exact match was
        # skipped, which let an out-of-order *older* config clobber a
        # newer one.)
        if (self.current_ontology_version is not None
                and version is not None
                and version <= self.current_ontology_version):
            logger.debug(f"Already at version {version}, skipping")
            return
        # Extract ontology configurations
        if "ontology" not in config:
            logger.warning("No 'ontology' section in config")
            return
        ontology_configs = config["ontology"]
        # Parse ontology definitions; a malformed entry is logged and
        # skipped rather than failing the whole update.
        ontologies = {}
        for ont_id, ont_json in ontology_configs.items():
            try:
                ontologies[ont_id] = json.loads(ont_json)
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse ontology '{ont_id}': {e}")
                continue
        logger.info(f"Loaded {len(ontologies)} ontology definitions")
        # Determine what changed (for incremental updates)
        new_ids = set(ontologies.keys())
        added_ids = new_ids - self.loaded_ontology_ids
        removed_ids = self.loaded_ontology_ids - new_ids
        # TODO: this treats every surviving ontology as potentially
        # changed (no content comparison), so any config push with
        # ontologies forces re-embedding on all flows.
        updated_ids = new_ids & self.loaded_ontology_ids  # May have changed content
        if added_ids:
            logger.info(f"New ontologies: {added_ids}")
        if removed_ids:
            logger.info(f"Removed ontologies: {removed_ids}")
        if updated_ids:
            logger.info(f"Updated ontologies: {updated_ids}")
        # Update ontology loader's internal state
        self.ontology_loader.update_ontologies(ontologies)
        # Clear all flow components to force re-embedding with new ontologies
        if added_ids or removed_ids or updated_ids:
            logger.info("Clearing flow components to trigger re-embedding")
            self.flow_components.clear()
        # Update tracking
        self.current_ontology_version = version
        self.loaded_ontology_ids = new_ids
        logger.info(f"Ontology config update complete, version={version}")
    except Exception as e:
        logger.error(f"Failed to process ontology config: {e}", exc_info=True)
async def on_message(self, msg, consumer, flow):
    """Process one incoming Chunk message end-to-end.

    Pipeline: decode chunk -> segment text -> select relevant ontology
    subset via vector similarity -> prompt the LLM for triples ->
    validate/expand them -> append source metadata and ontology
    definition triples -> emit Triples and EntityContexts.  On any
    failure, empty outputs are emitted so downstream consumers still
    see a result for the message.

    Args:
        msg: Pulsar-style message wrapper; msg.value() yields a Chunk.
        consumer: Consumer handle (unused here).
        flow: Flow context, callable to obtain named clients/producers.
    """
    v = msg.value()
    logger.info(f"Extracting ontology-based triples from {v.metadata.id}...")
    # Initialize flow-specific components if needed
    flow_id = await self.initialize_flow_components(flow)
    components = self.flow_components[flow_id]
    chunk = v.chunk.decode("utf-8")
    logger.debug(f"Processing chunk: {chunk[:200]}...")
    try:
        # Process text into segments
        segments = self.text_processor.process_chunk(chunk, extract_phrases=True)
        logger.debug(f"Split chunk into {len(segments)} segments")
        # Select relevant ontology subset (using flow-specific selector)
        ontology_subsets = await components['selector'].select_ontology_subset(segments)
        if not ontology_subsets:
            logger.warning("No relevant ontology elements found for chunk")
            # Emit empty outputs so downstream still sees this message.
            await self.emit_triples(
                flow("triples"),
                v.metadata,
                []
            )
            await self.emit_entity_contexts(
                flow("entity-contexts"),
                v.metadata,
                []
            )
            return
        # Merge subsets if multiple ontologies matched
        if len(ontology_subsets) > 1:
            ontology_subset = components['selector'].merge_subsets(ontology_subsets)
        else:
            ontology_subset = ontology_subsets[0]
        logger.debug(f"Selected ontology subset with {len(ontology_subset.classes)} classes, "
                     f"{len(ontology_subset.object_properties)} object properties, "
                     f"{len(ontology_subset.datatype_properties)} datatype properties")
        # Build extraction prompt variables
        prompt_variables = self.build_extraction_variables(chunk, ontology_subset)
        # Call prompt service for extraction
        try:
            # Use prompt() method with extract-with-ontologies prompt ID
            triples_response = await flow("prompt-request").prompt(
                id="extract-with-ontologies",
                variables=prompt_variables
            )
            logger.debug(f"Extraction response: {triples_response}")
            if not isinstance(triples_response, list):
                logger.error("Expected list of triples from prompt service")
                triples_response = []
        except Exception as e:
            # Prompt failures degrade to an empty extraction, not a crash.
            logger.error(f"Prompt service error: {e}", exc_info=True)
            triples_response = []
        # Parse and validate triples
        triples = self.parse_and_validate_triples(triples_response, ontology_subset)
        # Add metadata triples (provenance triples carried on the message)
        for t in v.metadata.metadata:
            triples.append(t)
        # Generate ontology definition triples
        ontology_triples = self.build_ontology_triples(ontology_subset)
        # Combine extracted triples with ontology triples
        all_triples = triples + ontology_triples
        # Build entity contexts from all triples (including ontology elements)
        entity_contexts = self.build_entity_contexts(all_triples)
        # Emit all triples (extracted + ontology definitions)
        await self.emit_triples(
            flow("triples"),
            v.metadata,
            all_triples
        )
        # Emit entity contexts
        await self.emit_entity_contexts(
            flow("entity-contexts"),
            v.metadata,
            entity_contexts
        )
        logger.info(f"Extracted {len(triples)} content triples + {len(ontology_triples)} ontology triples "
                    f"= {len(all_triples)} total triples and {len(entity_contexts)} entity contexts")
    except Exception as e:
        logger.error(f"OntoRAG extraction exception: {e}", exc_info=True)
        # Emit empty outputs on error
        await self.emit_triples(
            flow("triples"),
            v.metadata,
            []
        )
        await self.emit_entity_contexts(
            flow("entity-contexts"),
            v.metadata,
            []
        )
def build_extraction_variables(self, chunk: str, ontology_subset: OntologySubset) -> Dict[str, Any]:
    """Assemble template variables for the ontology-based extraction prompt.

    Args:
        chunk: Text chunk to extract from
        ontology_subset: Relevant ontology elements

    Returns:
        Dict with template variables: text, classes, object_properties,
        datatype_properties
    """
    variables: Dict[str, Any] = {"text": chunk}
    variables["classes"] = ontology_subset.classes
    variables["object_properties"] = ontology_subset.object_properties
    variables["datatype_properties"] = ontology_subset.datatype_properties
    return variables
def parse_and_validate_triples(self, triples_response: List[Any],
                               ontology_subset: OntologySubset) -> List[Triple]:
    """Parse and validate extracted triples against ontology.

    Each response item is expected to be a dict with 'subject',
    'predicate' and 'object' keys; anything else (or any item missing a
    field, or failing ontology validation) is silently dropped so one
    bad LLM output line cannot poison the batch.

    Args:
        triples_response: Raw list returned by the prompt service.
        ontology_subset: Ontology used for validation and URI expansion.

    Returns:
        List of Triple objects with subject/predicate expanded to full
        URIs; the object is expanded only when it looks like a URI
        reference, otherwise it is kept as a literal.
    """
    validated_triples = []
    ontology_id = ontology_subset.ontology_id
    for triple_data in triples_response:
        try:
            if isinstance(triple_data, dict):
                subject = triple_data.get('subject', '')
                predicate = triple_data.get('predicate', '')
                object_val = triple_data.get('object', '')
                # Drop incomplete triples early.
                if not subject or not predicate or not object_val:
                    continue
                # Validate against ontology
                if self.is_valid_triple(subject, predicate, object_val, ontology_subset):
                    # Expand URIs before creating Value objects
                    subject_uri = self.expand_uri(subject, ontology_subset, ontology_id)
                    predicate_uri = self.expand_uri(predicate, ontology_subset, ontology_id)
                    # Object might be URI or literal - check before expanding
                    if self.is_uri(object_val) or self.should_expand_as_uri(object_val, ontology_subset):
                        object_uri = self.expand_uri(object_val, ontology_subset, ontology_id)
                        is_object_uri = True
                    else:
                        object_uri = object_val
                        is_object_uri = False
                    # Create Triple object with expanded URIs
                    s_value = Value(value=subject_uri, is_uri=True)
                    p_value = Value(value=predicate_uri, is_uri=True)
                    o_value = Value(value=object_uri, is_uri=is_object_uri)
                    validated_triples.append(Triple(
                        s=s_value,
                        p=p_value,
                        o=o_value
                    ))
                else:
                    logger.debug(f"Invalid triple: ({subject}, {predicate}, {object_val})")
        except Exception as e:
            # Best-effort: log and continue with the remaining triples.
            logger.error(f"Error parsing triple: {e}")
    return validated_triples
def should_expand_as_uri(self, value: str, ontology_subset: OntologySubset) -> bool:
    """Decide whether a value should be treated as a URI rather than a literal.

    True when the value names an ontology class or property, uses one of
    the known standard prefixes (rdf:, rdfs:, owl:, skos:, schema:,
    xsd:), or looks like a CURIE-style entity reference such as
    "recipe:cornish-pasty".
    """
    # Ontology vocabulary terms are always URI references.
    if (value in ontology_subset.classes
            or value in ontology_subset.object_properties
            or value in ontology_subset.datatype_properties):
        return True
    # Known standard namespace prefixes.
    if any(value.startswith(prefix) for prefix in URI_PREFIXES):
        return True
    # CURIE-like entity reference: has a colon but is not already http(s).
    return ":" in value and not value.startswith("http")
def is_valid_triple(self, subject: str, predicate: str, object_val: str,
                    ontology_subset: OntologySubset) -> bool:
    """Validate a candidate triple against the ontology subset.

    rdf:type assertions must point at a class from the subset,
    rdfs:label is always accepted, and any other predicate must be a
    known object or datatype property.
    """
    # rdf:type - object must name a known class.
    if predicate in ("rdf:type", str(RDF_TYPE)):
        return object_val in ontology_subset.classes
    # rdfs:label - labels are always valid.
    if predicate in ("rdfs:label", str(RDF_LABEL)):
        return True
    # Everything else must use a property from the subset.
    # TODO: Add more sophisticated validation (domain/range checking)
    return (predicate in ontology_subset.object_properties
            or predicate in ontology_subset.datatype_properties)
def expand_uri(self, value: str, ontology_subset: OntologySubset, ontology_id: str = "unknown") -> str:
    """Expand prefix notation or short names to full URIs.

    Resolution order (first match wins):
      1. Already a full http(s) URI - returned unchanged.
      2. Standard prefix notation (rdf:, rdfs:, owl:, skos:, schema:, xsd:).
      3. Ontology class / object property / datatype property - use the
         definition's own 'uri' field if present, otherwise construct a
         URI in the ontology namespace.
      4. Anything else is treated as an entity instance and given a
         normalized instance URI.

    Args:
        value: Value to expand (e.g., "rdf:type", "Recipe", "has_ingredient")
        ontology_subset: Ontology subset for class/property lookup
        ontology_id: ID of the ontology for constructing instance URIs

    Returns:
        Full URI string
    """
    # Already a full URI
    if value.startswith("http://") or value.startswith("https://"):
        return value
    # Check standard prefixes (rdf:, rdfs:, etc.)
    for prefix, namespace in URI_PREFIXES.items():
        if value.startswith(prefix):
            return namespace + value[len(prefix):]
    # Classes and both kinds of properties share identical lookup logic
    # (the old code repeated it three times): prefer the definition's
    # explicit 'uri' field, otherwise construct one.  The definitions are
    # dicts (from OntologyClass/OntologyProperty __dict__).
    for vocabulary in (ontology_subset.classes,
                       ontology_subset.object_properties,
                       ontology_subset.datatype_properties):
        if value in vocabulary:
            definition = vocabulary[value]
            if isinstance(definition, dict) and definition.get('uri'):
                return definition['uri']
            return f"https://trustgraph.ai/ontology/{ontology_id}#{value}"
    # Otherwise, treat as entity instance - construct unique URI.
    # Normalize the value for URI (lowercase, replace spaces with hyphens).
    normalized = value.replace(" ", "-").lower()
    return f"https://trustgraph.ai/{ontology_id}/{normalized}"
def is_uri(self, value: str) -> bool:
    """Return True when the value is already a fully-qualified http(s) URI."""
    return value.startswith(("http://", "https://"))
async def emit_triples(self, pub, metadata: Metadata, triples: List[Triple]):
    """Publish a Triples message on the given producer.

    The outgoing envelope keeps id/user/collection from the source
    metadata but carries an empty metadata triple list.
    """
    envelope = Metadata(
        id=metadata.id,
        metadata=[],
        user=metadata.user,
        collection=metadata.collection,
    )
    message = Triples(metadata=envelope, triples=triples)
    await pub.send(message)
async def emit_entity_contexts(self, pub, metadata: Metadata, entities: List[EntityContext]):
    """Publish an EntityContexts message on the given producer.

    The outgoing envelope keeps id/user/collection from the source
    metadata but carries an empty metadata triple list.
    """
    envelope = Metadata(
        id=metadata.id,
        metadata=[],
        user=metadata.user,
        collection=metadata.collection,
    )
    message = EntityContexts(metadata=envelope, entities=entities)
    await pub.send(message)
def build_ontology_triples(self, ontology_subset: OntologySubset) -> List[Triple]:
    """Build triples describing the ontology elements themselves.

    Generates rdf:type, rdfs:label and rdfs:comment triples (plus
    rdfs:subClassOf for classes and rdfs:domain/rdfs:range for
    properties, where present) so the schema exists in the knowledge
    graph alongside the extracted instances.

    Args:
        ontology_subset: The ontology subset used for extraction

    Returns:
        List of Triple objects describing ontology elements
    """
    RDF_TYPE_URI = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
    RDFS = "http://www.w3.org/2000/01/rdf-schema#"
    OWL = "http://www.w3.org/2002/07/owl#"

    ontology_triples = []
    ont_id = ontology_subset.ontology_id

    def fallback_uri(element_id):
        # Constructed URI used when a definition carries no explicit 'uri'.
        return f"https://trustgraph.ai/ontology/{ont_id}#{element_id}"

    def element_uri(element_id, element_def):
        # Definitions are dicts (from OntologyClass/OntologyProperty
        # __dict__); prefer their explicit URI, else construct one.
        if isinstance(element_def, dict) and element_def.get('uri'):
            return element_def['uri']
        return fallback_uri(element_id)

    def class_ref_uri(class_id):
        # Resolve a class *reference* (subClassOf / domain / range target),
        # which may or may not be part of this subset.
        if class_id in ontology_subset.classes:
            return element_uri(class_id, ontology_subset.classes[class_id])
        return fallback_uri(class_id)

    def add(s, p, o, o_is_uri=True):
        ontology_triples.append(Triple(
            s=Value(value=s, is_uri=True),
            p=Value(value=p, is_uri=True),
            o=Value(value=o, is_uri=o_is_uri),
        ))

    def add_annotations(uri, element_id, element_def):
        # rdfs:label (stored as 'labels': list of {'value': ...} dicts or
        # plain values) and rdfs:comment (stored as 'comment').
        if not isinstance(element_def, dict):
            return
        labels = element_def.get('labels')
        if isinstance(labels, list) and labels:
            first = labels[0]
            label_val = first.get('value', element_id) if isinstance(first, dict) else str(first)
            add(uri, RDF_LABEL, label_val, o_is_uri=False)
        if element_def.get('comment'):
            add(uri, RDFS + "comment", element_def['comment'], o_is_uri=False)

    # --- Classes ------------------------------------------------------
    for class_id, class_def in ontology_subset.classes.items():
        class_uri = element_uri(class_id, class_def)
        add(class_uri, RDF_TYPE_URI, OWL + "Class")
        add_annotations(class_uri, class_id, class_def)
        # rdfs:subClassOf (stored as 'subclass_of')
        if isinstance(class_def, dict) and class_def.get('subclass_of'):
            add(class_uri, RDFS + "subClassOf", class_ref_uri(class_def['subclass_of']))

    # --- Object properties --------------------------------------------
    for prop_id, prop_def in ontology_subset.object_properties.items():
        prop_uri = element_uri(prop_id, prop_def)
        add(prop_uri, RDF_TYPE_URI, OWL + "ObjectProperty")
        add_annotations(prop_uri, prop_id, prop_def)
        if isinstance(prop_def, dict):
            # rdfs:domain / rdfs:range reference ontology classes.
            if prop_def.get('domain'):
                add(prop_uri, RDFS + "domain", class_ref_uri(prop_def['domain']))
            if prop_def.get('range'):
                add(prop_uri, RDFS + "range", class_ref_uri(prop_def['range']))

    # --- Datatype properties ------------------------------------------
    for prop_id, prop_def in ontology_subset.datatype_properties.items():
        prop_uri = element_uri(prop_id, prop_def)
        add(prop_uri, RDF_TYPE_URI, OWL + "DatatypeProperty")
        add_annotations(prop_uri, prop_id, prop_def)
        if isinstance(prop_def, dict):
            if prop_def.get('domain'):
                add(prop_uri, RDFS + "domain", class_ref_uri(prop_def['domain']))
            # rdfs:range of a datatype property is a datatype (e.g.
            # xsd:string), not a class.  BUG FIX: the range is stored
            # under 'range' like every other OntologyProperty field; the
            # old code looked up the literal key 'rdfs:range', which the
            # property dicts do not use.  Accept both, preferring 'range'.
            range_val = prop_def.get('range') or prop_def.get('rdfs:range')
            if range_val:
                if range_val.startswith('xsd:'):
                    range_uri = f"http://www.w3.org/2001/XMLSchema#{range_val[4:]}"
                else:
                    range_uri = range_val
                add(prop_uri, RDFS + "range", range_uri)

    logger.info(f"Generated {len(ontology_triples)} triples describing ontology elements")
    return ontology_triples
def build_entity_contexts(self, triples: List[Triple]) -> List[EntityContext]:
    """Derive per-entity context strings from extracted triples.

    Gathers rdfs:label and definition/description literals per subject and
    turns them into EntityContext objects suitable for embedding.

    Args:
        triples: List of extracted triples

    Returns:
        List of EntityContext objects, one per subject with usable context
    """
    # Accumulate label/definition literals keyed by subject URI.
    info_by_subject = {}
    for t in triples:
        subject = t.s.value
        bucket = info_by_subject.setdefault(
            subject, {'labels': [], 'definitions': []})
        if t.o.is_uri:
            continue  # only literal objects contribute context text
        if t.p.value == RDF_LABEL:
            bucket['labels'].append(t.o.value)
        elif t.p.value in (DEFINITION, "https://schema.org/description"):
            bucket['definitions'].append(t.o.value)
    # Assemble context strings: first label (if any) plus all definitions.
    entity_contexts = []
    for subject, bucket in info_by_subject.items():
        pieces = []
        if bucket['labels']:
            pieces.append(f"Label: {bucket['labels'][0]}")
        pieces.extend(bucket['definitions'])
        # Skip entities with no labels/definitions: nothing useful to embed.
        if pieces:
            entity_contexts.append(EntityContext(
                entity=Value(value=subject, is_uri=True),
                context=". ".join(pieces)
            ))
    logger.debug(f"Built {len(entity_contexts)} entity contexts from {len(triples)} triples")
    return entity_contexts
@staticmethod
def add_args(parser):
    """Register OntoRAG extraction options on an argparse parser."""
    # Data-driven option table: (flags, argparse keyword arguments).
    option_table = [
        (('-c', '--concurrency'), {
            'type': int,
            'default': default_concurrency,
            'help': f'Concurrent processing threads (default: {default_concurrency})',
        }),
        (('--top-k',), {
            'type': int,
            'default': 10,
            'help': 'Number of top ontology elements to retrieve (default: 10)',
        }),
        (('--similarity-threshold',), {
            'type': float,
            'default': 0.3,
            'help': 'Similarity threshold for ontology matching (default: 0.3, range: 0.0-1.0)',
        }),
    ]
    for flags, kwargs in option_table:
        parser.add_argument(*flags, **kwargs)
    # Inherit the standard flow-processor arguments as well.
    FlowProcessor.add_args(parser)
def run():
    """Launch the OntoRAG extraction service."""
    # Delegates to the processor framework; the module docstring supplies
    # the service's help/description text.
    Processor.launch(default_ident, __doc__)

View file

@ -0,0 +1,310 @@
"""
Ontology embedder component for OntoRAG system.
Generates and stores embeddings for ontology elements.
"""
import asyncio
import logging
import numpy as np
from typing import Dict, List, Any, Optional
from dataclasses import dataclass
from .ontology_loader import Ontology, OntologyClass, OntologyProperty
from .vector_store import InMemoryVectorStore
logger = logging.getLogger(__name__)
@dataclass
class OntologyElementMetadata:
    """Metadata stored alongside each embedded ontology element.

    Instances are converted to plain dicts (via __dict__) before being
    written into the vector store.
    """
    type: str  # 'class', 'objectProperty', 'datatypeProperty'
    ontology: str  # Ontology ID
    element: str  # Element ID
    definition: Dict[str, Any]  # Full element definition
    text: str  # Text used for embedding
class OntologyEmbedder:
    """Generates embeddings for ontology elements and stores them in vector store."""

    def __init__(self, embedding_service=None,
                 vector_store: Optional[InMemoryVectorStore] = None,
                 batch_size: int = 50):
        """Initialize the ontology embedder.

        Args:
            embedding_service: Service for generating embeddings
            vector_store: Vector store instance (InMemoryVectorStore)
            batch_size: Number of elements embedded per batch (default 50)
        """
        self.embedding_service = embedding_service
        self.vector_store = vector_store or InMemoryVectorStore()
        # IDs of ontologies whose elements were fully embedded and stored.
        self.embedded_ontologies = set()
        self.batch_size = batch_size

    def _create_text_representation(self, element_id: str, element: Any,
                                    element_type: str) -> str:
        """Create text representation of an ontology element for embedding.

        Args:
            element_id: ID of the element
            element: The element object (OntologyClass or OntologyProperty)
            element_type: 'class', 'objectProperty' or 'datatypeProperty'

        Returns:
            Text representation for embedding
        """
        parts = []
        # The element ID itself is often a meaningful name; de-kebab/snake it.
        parts.append(element_id.replace('-', ' ').replace('_', ' '))
        # Labels may be dicts ({'value': ...}) or plain strings.
        if hasattr(element, 'labels') and element.labels:
            for label in element.labels:
                if isinstance(label, dict):
                    parts.append(label.get('value', ''))
                else:
                    parts.append(str(label))
        # Comment/description, when present.
        if hasattr(element, 'comment') and element.comment:
            parts.append(element.comment)
        # Type-specific structural hints improve matching quality.
        if element_type == 'class':
            if hasattr(element, 'subclass_of') and element.subclass_of:
                parts.append(f"subclass of {element.subclass_of}")
        elif element_type in ['objectProperty', 'datatypeProperty']:
            if hasattr(element, 'domain') and element.domain:
                parts.append(f"domain: {element.domain}")
            if hasattr(element, 'range') and element.range:
                parts.append(f"range: {element.range}")
        # Drop empty fragments and join with spaces.
        text = ' '.join(filter(None, parts))
        return text

    def _element_entry(self, ontology_id: str, element_id: str, element: Any,
                       element_type: str) -> Dict[str, Any]:
        """Build one vector-store entry (id, text, metadata) for an element."""
        text = self._create_text_representation(element_id, element, element_type)
        return {
            'id': f"{ontology_id}:{element_type}:{element_id}",
            'text': text,
            'metadata': OntologyElementMetadata(
                type=element_type,
                ontology=ontology_id,
                element=element_id,
                definition=element.__dict__,
                text=text
            ).__dict__
        }

    async def embed_ontology(self, ontology: Ontology) -> int:
        """Generate and store embeddings for all elements in an ontology.

        Args:
            ontology: The ontology to embed

        Returns:
            Number of elements embedded
        """
        if not self.embedding_service:
            logger.warning("No embedding service available, skipping embedding")
            return 0
        # Collect classes, object properties and datatype properties uniformly.
        elements_to_embed = []
        for class_id, class_def in ontology.classes.items():
            elements_to_embed.append(
                self._element_entry(ontology.id, class_id, class_def, 'class'))
        for prop_id, prop_def in ontology.object_properties.items():
            elements_to_embed.append(
                self._element_entry(ontology.id, prop_id, prop_def, 'objectProperty'))
        for prop_id, prop_def in ontology.datatype_properties.items():
            elements_to_embed.append(
                self._element_entry(ontology.id, prop_id, prop_def, 'datatypeProperty'))
        embedded_count = 0
        had_failure = False
        # Process embeddings in batches to bound concurrent embed() calls.
        for i in range(0, len(elements_to_embed), self.batch_size):
            batch = elements_to_embed[i:i + self.batch_size]
            texts = [elem['text'] for elem in batch]
            try:
                # Call embedding service for each text.
                # Note: embed() returns 2D array [[vector]], so extract first element
                embedding_tasks = [self.embedding_service.embed(text) for text in texts]
                embeddings_responses = await asyncio.gather(*embedding_tasks)
                embeddings_list = [resp[0] for resp in embeddings_responses]
                embeddings = np.array(embeddings_list)
                logger.debug(f"Embeddings shape: {embeddings.shape}, expected: ({len(batch)}, {self.vector_store.dimension})")
                ids = [elem['id'] for elem in batch]
                metadata_list = [elem['metadata'] for elem in batch]
                self.vector_store.add_batch(ids, embeddings, metadata_list)
                embedded_count += len(batch)
                logger.debug(f"Embedded batch of {len(batch)} elements from ontology {ontology.id}")
            except Exception as e:
                had_failure = True
                logger.error(f"Failed to embed batch for ontology {ontology.id}: {e}", exc_info=True)
        # Bug fix: previously the ontology was marked as embedded even when
        # batches failed, which made embed_ontologies() skip it forever.
        # Only record full successes so failed ontologies can be retried.
        if not had_failure:
            self.embedded_ontologies.add(ontology.id)
        logger.info(f"Embedded {embedded_count} elements from ontology {ontology.id}")
        return embedded_count

    async def embed_ontologies(self, ontologies: Dict[str, Ontology]) -> int:
        """Generate and store embeddings for multiple ontologies.

        Args:
            ontologies: Dictionary of ontology ID to Ontology objects

        Returns:
            Total number of elements embedded
        """
        total_embedded = 0
        for ont_id, ontology in ontologies.items():
            if ont_id not in self.embedded_ontologies:
                total_embedded += await self.embed_ontology(ontology)
            else:
                logger.debug(f"Ontology {ont_id} already embedded, skipping")
        logger.info(f"Total embedded elements: {total_embedded} from {len(ontologies)} ontologies")
        return total_embedded

    async def embed_text(self, text: str) -> Optional[np.ndarray]:
        """Generate embedding for a single text.

        Args:
            text: Text to embed

        Returns:
            Embedding vector or None if failed
        """
        if not self.embedding_service:
            logger.warning("No embedding service available")
            return None
        try:
            # embed() returns 2D array [[vector]], extract first element
            embedding_response = await self.embedding_service.embed(text)
            return np.array(embedding_response[0])
        except Exception as e:
            logger.error(f"Failed to embed text: {e}")
            return None

    async def embed_texts(self, texts: List[str]) -> Optional[np.ndarray]:
        """Generate embeddings for multiple texts.

        Args:
            texts: List of texts to embed

        Returns:
            Array of embeddings or None if failed
        """
        if not self.embedding_service:
            logger.warning("No embedding service available")
            return None
        try:
            # Call embed() for each text (returns [[vector]] per call)
            embedding_tasks = [self.embedding_service.embed(text) for text in texts]
            embeddings_responses = await asyncio.gather(*embedding_tasks)
            # Extract first vector from each response
            embeddings_list = [resp[0] for resp in embeddings_responses]
            return np.array(embeddings_list)
        except Exception as e:
            logger.error(f"Failed to embed texts: {e}")
            return None

    def remove_ontology(self, ontology_id: str):
        """Remove all embeddings for a specific ontology.

        Note: FAISS doesn't support efficient deletion, so this currently
        requires rebuilding the entire index without the removed ontology.

        Args:
            ontology_id: ID of ontology to remove
        """
        if ontology_id not in self.embedded_ontologies:
            logger.debug(f"Ontology '{ontology_id}' not embedded, nothing to remove")
            return
        # FAISS doesn't support selective deletion, so we'd need to rebuild the index
        # For now, just remove from tracking set
        # TODO: Implement index rebuilding if selective removal is needed
        self.embedded_ontologies.discard(ontology_id)
        logger.info(f"Removed ontology '{ontology_id}' from embedded set (note: vectors still in store)")

    def clear_embeddings(self, ontology_id: Optional[str] = None):
        """Clear embeddings from vector store.

        Args:
            ontology_id: If provided, only clear embeddings for this ontology
                Otherwise, clear all embeddings
        """
        if ontology_id:
            self.remove_ontology(ontology_id)
        else:
            self.vector_store.clear()
            self.embedded_ontologies.clear()
            logger.info("Cleared all embeddings from vector store")

    def get_vector_store(self) -> InMemoryVectorStore:
        """Get the vector store instance.

        Returns:
            The vector store being used
        """
        return self.vector_store

    def get_embedded_count(self) -> int:
        """Get the number of embedded elements.

        Returns:
            Number of elements in the vector store
        """
        return self.vector_store.size()

    def is_ontology_embedded(self, ontology_id: str) -> bool:
        """Check if an ontology has been embedded.

        Args:
            ontology_id: ID of the ontology

        Returns:
            True if the ontology has been embedded
        """
        return ontology_id in self.embedded_ontologies

View file

@ -0,0 +1,247 @@
"""
Ontology loader component for OntoRAG system.
Loads and manages ontologies from configuration service.
"""
import json
import logging
from typing import Dict, Any, Optional, List
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)
@dataclass
class OntologyClass:
    """Represents an OWL-like class in the ontology."""
    uri: str  # Full URI of the class
    type: str = "owl:Class"  # RDF type; defaults to owl:Class
    labels: List[Dict[str, str]] = field(default_factory=list)  # rdfs:label entries
    comment: Optional[str] = None  # rdfs:comment text
    subclass_of: Optional[str] = None  # parent class ID (rdfs:subClassOf)
    equivalent_classes: List[str] = field(default_factory=list)  # owl:equivalentClass
    disjoint_with: List[str] = field(default_factory=list)  # owl:disjointWith
    identifier: Optional[str] = None  # dcterms:identifier

    @staticmethod
    def from_dict(class_id: str, data: Dict[str, Any]) -> 'OntologyClass':
        """Create OntologyClass from dictionary representation.

        Args:
            class_id: Identifier of the class (unused here; kept for
                interface compatibility with callers).
            data: Parsed dict with rdfs:/owl:/dcterms: keys.

        Returns:
            Populated OntologyClass instance.
        """
        labels = data.get('rdfs:label', [])
        # Normalize rdfs:label to a list: a single value becomes a
        # one-element list, falsy values become an empty list.
        # (Replaces a former no-op `labels = labels` branch.)
        if not isinstance(labels, list):
            labels = [labels] if labels else []
        return OntologyClass(
            uri=data.get('uri', ''),
            type=data.get('type', 'owl:Class'),
            labels=labels,
            comment=data.get('rdfs:comment'),
            subclass_of=data.get('rdfs:subClassOf'),
            equivalent_classes=data.get('owl:equivalentClass', []),
            disjoint_with=data.get('owl:disjointWith', []),
            identifier=data.get('dcterms:identifier')
        )
@dataclass
class OntologyProperty:
    """Represents a property (object or datatype) in the ontology."""
    uri: str  # Full URI of the property
    type: str  # e.g. 'owl:ObjectProperty' or 'owl:DatatypeProperty'
    labels: List[Dict[str, str]] = field(default_factory=list)  # rdfs:label entries
    comment: Optional[str] = None  # rdfs:comment text
    domain: Optional[str] = None  # rdfs:domain class ID
    range: Optional[str] = None  # rdfs:range class ID or xsd datatype
    inverse_of: Optional[str] = None  # owl:inverseOf property ID
    functional: bool = False  # owl:functionalProperty flag
    inverse_functional: bool = False  # owl:inverseFunctionalProperty flag
    min_cardinality: Optional[int] = None  # owl:minCardinality
    max_cardinality: Optional[int] = None  # owl:maxCardinality
    cardinality: Optional[int] = None  # owl:cardinality

    @staticmethod
    def from_dict(prop_id: str, data: Dict[str, Any]) -> 'OntologyProperty':
        """Create OntologyProperty from dictionary representation.

        Args:
            prop_id: Identifier of the property (unused here; kept for
                interface compatibility with callers).
            data: Parsed dict with rdfs:/owl: keys.

        Returns:
            Populated OntologyProperty instance.
        """
        labels = data.get('rdfs:label', [])
        # Normalize rdfs:label to a list: a single value becomes a
        # one-element list, falsy values become an empty list.
        # (Replaces a former no-op `labels = labels` branch.)
        if not isinstance(labels, list):
            labels = [labels] if labels else []
        return OntologyProperty(
            uri=data.get('uri', ''),
            type=data.get('type', ''),
            labels=labels,
            comment=data.get('rdfs:comment'),
            domain=data.get('rdfs:domain'),
            range=data.get('rdfs:range'),
            inverse_of=data.get('owl:inverseOf'),
            functional=data.get('owl:functionalProperty', False),
            inverse_functional=data.get('owl:inverseFunctionalProperty', False),
            min_cardinality=data.get('owl:minCardinality'),
            max_cardinality=data.get('owl:maxCardinality'),
            cardinality=data.get('owl:cardinality')
        )
@dataclass
class Ontology:
    """Represents a complete ontology with metadata, classes, and properties."""
    id: str  # Ontology identifier
    metadata: Dict[str, Any]  # Arbitrary metadata from config
    classes: Dict[str, OntologyClass]  # class_id -> OntologyClass
    object_properties: Dict[str, OntologyProperty]  # prop_id -> OntologyProperty
    datatype_properties: Dict[str, OntologyProperty]  # prop_id -> OntologyProperty

    def get_class(self, class_id: str) -> Optional[OntologyClass]:
        """Look up a class by ID, or None when absent."""
        return self.classes.get(class_id)

    def get_property(self, prop_id: str) -> Optional[OntologyProperty]:
        """Look up a property by ID, object properties taking precedence."""
        return self.object_properties.get(prop_id) \
            or self.datatype_properties.get(prop_id)

    def get_parent_classes(self, class_id: str) -> List[str]:
        """Collect ancestor class IDs by walking rdfs:subClassOf links.

        Cycle-safe: the walk stops as soon as a class is revisited.
        """
        chain = []
        seen = set()
        node = class_id
        while node and node not in seen:
            seen.add(node)
            cls = self.get_class(node)
            parent = cls.subclass_of if cls else None
            if not parent:
                break
            chain.append(parent)
            node = parent
        return chain

    def validate_structure(self) -> List[str]:
        """Validate ontology structure and return list of issues."""
        issues = []
        # Detect inheritance cycles by walking each class's ancestor chain.
        for class_id in self.classes:
            seen = set()
            node = class_id
            while node:
                if node in seen:
                    issues.append(f"Circular inheritance detected for class {class_id}")
                    break
                seen.add(node)
                cls = self.get_class(node)
                node = cls.subclass_of if cls else None
        # Every declared domain must be a known class; object-property
        # ranges must also resolve to known classes.
        merged_props = {**self.object_properties, **self.datatype_properties}
        for prop_id, prop in merged_props.items():
            if prop.domain and prop.domain not in self.classes:
                issues.append(f"Property {prop_id} has unknown domain {prop.domain}")
            if prop.type == "owl:ObjectProperty" and prop.range and prop.range not in self.classes:
                issues.append(f"Object property {prop_id} has unknown range class {prop.range}")
        # Disjointness declarations must reference known classes.
        for class_id, cls in self.classes.items():
            for disjoint_id in cls.disjoint_with:
                if disjoint_id not in self.classes:
                    issues.append(f"Class {class_id} disjoint with unknown class {disjoint_id}")
        return issues
class OntologyLoader:
    """Manages ontologies received via event-driven config updates.

    No direct database access - receives ontologies via config handler.
    """

    def __init__(self):
        """Start with an empty ontology store."""
        self.ontologies: Dict[str, Ontology] = {}

    @staticmethod
    def _parse_ontology(ont_id: str, ont_data: Dict[str, Any]) -> Ontology:
        """Build an Ontology object from its raw dict definition."""
        classes = {
            cid: OntologyClass.from_dict(cid, cdata)
            for cid, cdata in ont_data.get('classes', {}).items()
        }
        object_props = {
            pid: OntologyProperty.from_dict(pid, pdata)
            for pid, pdata in ont_data.get('objectProperties', {}).items()
        }
        datatype_props = {
            pid: OntologyProperty.from_dict(pid, pdata)
            for pid, pdata in ont_data.get('datatypeProperties', {}).items()
        }
        return Ontology(
            id=ont_id,
            metadata=ont_data.get('metadata', {}),
            classes=classes,
            object_properties=object_props,
            datatype_properties=datatype_props
        )

    def update_ontologies(self, ontology_configs: Dict[str, Any]):
        """Replace all loaded ontologies with definitions from config.

        Args:
            ontology_configs: Dict mapping ontology_id -> ontology_definition (parsed dicts)
        """
        self.ontologies.clear()
        for ont_id, ont_data in ontology_configs.items():
            try:
                ontology = self._parse_ontology(ont_id, ont_data)
                # Structural validation is advisory: issues are logged but
                # the ontology is still loaded.
                issues = ontology.validate_structure()
                if issues:
                    logger.warning(f"Ontology {ont_id} has validation issues: {issues}")
                self.ontologies[ont_id] = ontology
                logger.info(f"Loaded ontology {ont_id} with {len(ontology.classes)} classes, "
                            f"{len(ontology.object_properties)} object properties, "
                            f"{len(ontology.datatype_properties)} datatype properties")
            except Exception as e:
                logger.error(f"Failed to load ontology {ont_id}: {e}", exc_info=True)

    def get_ontology(self, ont_id: str) -> Optional[Ontology]:
        """Get a specific ontology by ID.

        Args:
            ont_id: Ontology identifier

        Returns:
            Ontology object or None if not found
        """
        return self.ontologies.get(ont_id)

    def get_all_ontologies(self) -> Dict[str, Ontology]:
        """Get all loaded ontologies.

        Returns:
            Dictionary of ontology ID to Ontology objects
        """
        return self.ontologies

    def list_ontology_ids(self) -> List[str]:
        """Get list of loaded ontology IDs.

        Returns:
            List of ontology IDs
        """
        return list(self.ontologies.keys())

    def clear(self):
        """Clear all loaded ontologies."""
        self.ontologies.clear()
        logger.info("Cleared all loaded ontologies")

View file

@ -0,0 +1,356 @@
"""
Ontology selection algorithm for OntoRAG system.
Selects relevant ontology subsets based on text similarity.
"""
import logging
from typing import List, Dict, Any, Set, Optional, Tuple
from dataclasses import dataclass
from collections import defaultdict
from .ontology_loader import Ontology, OntologyLoader
from .ontology_embedder import OntologyEmbedder
from .text_processor import TextSegment
from .vector_store import SearchResult
logger = logging.getLogger(__name__)
@dataclass
class OntologySubset:
    """Represents a subset of an ontology relevant to a text chunk."""
    ontology_id: str  # ID of the source ontology ("merged" for merged subsets)
    classes: Dict[str, Any]  # class_id -> class definition dict
    object_properties: Dict[str, Any]  # prop_id -> object-property definition dict
    datatype_properties: Dict[str, Any]  # prop_id -> datatype-property definition dict
    metadata: Dict[str, Any]  # metadata copied from the source ontology
    relevance_score: float = 0.0  # aggregate similarity score for the subset
class OntologySelector:
    """Selects relevant ontology elements for text segments using vector similarity."""

    def __init__(self, ontology_embedder: OntologyEmbedder,
                 ontology_loader: OntologyLoader,
                 top_k: int = 10,
                 similarity_threshold: float = 0.7):
        """Initialize the ontology selector.

        Args:
            ontology_embedder: Embedder with vector store
            ontology_loader: Loader with ontology definitions
            top_k: Number of top results to retrieve per segment
            similarity_threshold: Minimum similarity score
        """
        self.embedder = ontology_embedder
        self.loader = ontology_loader
        self.top_k = top_k
        self.similarity_threshold = similarity_threshold

    async def select_ontology_subset(self, segments: List[TextSegment]) -> List[OntologySubset]:
        """Select relevant ontology subsets for text segments.

        Args:
            segments: List of text segments to match

        Returns:
            List of ontology subsets with relevant elements
        """
        # Map (ontology, type, element) -> best similarity score observed.
        element_scores = await self._find_relevant_elements(segments)
        # Group by ontology and build subsets (with mean relevance scores).
        ontology_subsets = self._build_ontology_subsets(element_scores)
        # Pull in structurally-required neighbours (parents, domains, ranges).
        for subset in ontology_subsets:
            self._resolve_dependencies(subset)
        logger.info(f"Selected {len(ontology_subsets)} ontology subsets")
        return ontology_subsets

    async def _find_relevant_elements(
            self, segments: List[TextSegment]) -> Dict[Tuple[str, str, str], float]:
        """Find relevant ontology elements for text segments.

        Args:
            segments: Text segments to match

        Returns:
            Dict mapping (ontology_id, element_type, element_id) to the best
            similarity score observed across all segments.
        """
        element_scores: Dict[Tuple[str, str, str], float] = {}
        # Check if vector store has any elements
        vector_store = self.embedder.get_vector_store()
        store_size = vector_store.size()
        logger.info(f"Vector store size: {store_size} elements")
        if store_size == 0:
            logger.warning("Vector store is empty - no ontology elements embedded")
            return element_scores
        # Process each segment (log first few for debugging)
        for i, segment in enumerate(segments):
            embedding = await self.embedder.embed_text(segment.text)
            if embedding is None:
                logger.warning(f"Failed to embed segment: {segment.text[:50]}...")
                continue
            # Search with no threshold so raw scores can be logged below.
            all_results = vector_store.search(
                embedding=embedding,
                top_k=self.top_k,
                threshold=0.0  # Get all results to see scores
            )
            # Log top scores for first 3 segments to debug
            if i < 3 and all_results:
                top_scores = [r.score for r in all_results[:3]]
                top_elements = [r.metadata['element'] for r in all_results[:3]]
                logger.info(f"Segment {i}: '{segment.text[:60]}...'")
                logger.info(f"  Top 3 scores: {top_scores} (threshold={self.similarity_threshold})")
                logger.info(f"  Top 3 elements: {top_elements}")
            # Keep results above threshold; remember the best score per element.
            for result in all_results:
                if result.score < self.similarity_threshold:
                    continue
                key = (
                    result.metadata['ontology'],
                    result.metadata['type'],
                    result.metadata['element'],
                )
                element_scores[key] = max(element_scores.get(key, 0.0), result.score)
        logger.info(f"Found {len(element_scores)} relevant elements from {len(segments)} segments")
        return element_scores

    def _build_ontology_subsets(
            self,
            element_scores: Dict[Tuple[str, str, str], float]) -> List[OntologySubset]:
        """Build ontology subsets from scored relevant elements.

        Element definitions are looked up fresh from the loader (the former
        stringified-definition round-trip, which fell back to eval(), was
        dead code and has been removed).

        Args:
            element_scores: (ontology_id, element_type, element_id) -> score

        Returns:
            List of ontology subsets; relevance_score is the mean score of
            the subset's selected elements.
        """
        # Group elements by ontology
        ontology_groups = defaultdict(lambda: {
            'classes': {},
            'object_properties': {},
            'datatype_properties': {},
            'scores': []
        })
        for (ont_id, elem_type, elem_id), score in element_scores.items():
            ontology = self.loader.get_ontology(ont_id)
            if not ontology:
                logger.warning(f"Ontology {ont_id} not found in loader")
                continue
            group = ontology_groups[ont_id]
            # Add element to appropriate category; record its score so the
            # subset's relevance can be computed (bug fix: scores were
            # previously never recorded, so relevance_score was always 0.0).
            if elem_type == 'class':
                cls = ontology.get_class(elem_id)
                if cls:
                    group['classes'][elem_id] = cls.__dict__
                    group['scores'].append(score)
            elif elem_type == 'objectProperty':
                prop = ontology.object_properties.get(elem_id)
                if prop:
                    group['object_properties'][elem_id] = prop.__dict__
                    group['scores'].append(score)
            elif elem_type == 'datatypeProperty':
                prop = ontology.datatype_properties.get(elem_id)
                if prop:
                    group['datatype_properties'][elem_id] = prop.__dict__
                    group['scores'].append(score)
        # Create OntologySubset objects
        subsets = []
        for ont_id, elements in ontology_groups.items():
            ontology = self.loader.get_ontology(ont_id)
            if ontology:
                subsets.append(OntologySubset(
                    ontology_id=ont_id,
                    classes=elements['classes'],
                    object_properties=elements['object_properties'],
                    datatype_properties=elements['datatype_properties'],
                    metadata=ontology.metadata,
                    relevance_score=sum(elements['scores']) / len(elements['scores']) if elements['scores'] else 0.0
                ))
        return subsets

    def _resolve_dependencies(self, subset: OntologySubset):
        """Resolve dependencies for ontology subset elements.

        Adds parent classes, property domain/range classes, inverse
        properties, and all properties that reference a selected class.

        Args:
            subset: Ontology subset to resolve dependencies for
        """
        ontology = self.loader.get_ontology(subset.ontology_id)
        if not ontology:
            return
        # Classes discovered during resolution; added at the end so we don't
        # mutate subset.classes while iterating it.
        classes_to_add = set()
        # Resolve class hierarchies: include every ancestor of each class.
        for class_id in list(subset.classes.keys()):
            parents = ontology.get_parent_classes(class_id)
            for parent_id in parents:
                parent_class = ontology.get_class(parent_id)
                if parent_class and parent_id not in subset.classes:
                    classes_to_add.add(parent_id)
        # Resolve object-property domains and ranges.
        for prop_id, prop_def in subset.object_properties.items():
            if 'domain' in prop_def and prop_def['domain']:
                domain_id = prop_def['domain']
                if domain_id not in subset.classes:
                    domain_class = ontology.get_class(domain_id)
                    if domain_class:
                        classes_to_add.add(domain_id)
            if 'range' in prop_def and prop_def['range']:
                range_id = prop_def['range']
                if range_id not in subset.classes:
                    range_class = ontology.get_class(range_id)
                    if range_class:
                        classes_to_add.add(range_id)
        # Resolve datatype-property domains.
        for prop_id, prop_def in subset.datatype_properties.items():
            if 'domain' in prop_def and prop_def['domain']:
                domain_id = prop_def['domain']
                if domain_id not in subset.classes:
                    domain_class = ontology.get_class(domain_id)
                    if domain_class:
                        classes_to_add.add(domain_id)
        # Add inverse properties of selected object properties.
        for prop_id, prop_def in list(subset.object_properties.items()):
            if 'inverse_of' in prop_def and prop_def['inverse_of']:
                inverse_id = prop_def['inverse_of']
                if inverse_id not in subset.object_properties:
                    inverse_prop = ontology.object_properties.get(inverse_id)
                    if inverse_prop:
                        subset.object_properties[inverse_id] = inverse_prop.__dict__
        # Auto-include properties related to selected classes: for each
        # selected class, pull in every property referencing it in its
        # domain or range (plus the class at the other end).
        properties_added = 0
        datatype_properties_added = 0
        for class_id in list(subset.classes.keys()):
            for prop_id, prop_def in ontology.object_properties.items():
                if prop_id not in subset.object_properties:
                    prop_domain = getattr(prop_def, 'domain', None)
                    prop_range = getattr(prop_def, 'range', None)
                    if prop_domain == class_id or prop_range == class_id:
                        subset.object_properties[prop_id] = prop_def.__dict__
                        properties_added += 1
                        if prop_domain and prop_domain != class_id and prop_domain not in subset.classes:
                            other_class = ontology.get_class(prop_domain)
                            if other_class:
                                classes_to_add.add(prop_domain)
                        if prop_range and prop_range != class_id and prop_range not in subset.classes:
                            other_class = ontology.get_class(prop_range)
                            if other_class:
                                classes_to_add.add(prop_range)
            for prop_id, prop_def in ontology.datatype_properties.items():
                if prop_id not in subset.datatype_properties:
                    prop_domain = getattr(prop_def, 'domain', None)
                    if prop_domain == class_id:
                        subset.datatype_properties[prop_id] = prop_def.__dict__
                        datatype_properties_added += 1
        # Add all classes collected along the way.
        for class_id in classes_to_add:
            cls = ontology.get_class(class_id)
            if cls:
                subset.classes[class_id] = cls.__dict__
        logger.debug(f"Resolved dependencies for subset {subset.ontology_id}: "
                     f"added {len(classes_to_add)} classes, "
                     f"{properties_added} object properties, "
                     f"{datatype_properties_added} datatype properties")

    def merge_subsets(self, subsets: List[OntologySubset]) -> OntologySubset:
        """Merge multiple ontology subsets into one.

        Element keys are namespaced as "<ontology_id>:<element_id>" to avoid
        collisions between ontologies.

        Args:
            subsets: List of subsets to merge

        Returns:
            Merged ontology subset (None for an empty list; the single
            subset unchanged for a one-element list)
        """
        if not subsets:
            return None
        if len(subsets) == 1:
            return subsets[0]
        merged = OntologySubset(
            ontology_id="merged",
            classes={},
            object_properties={},
            datatype_properties={},
            metadata={},
            relevance_score=0.0
        )
        total_score = 0.0
        for subset in subsets:
            for class_id, class_def in subset.classes.items():
                merged.classes[f"{subset.ontology_id}:{class_id}"] = class_def
            for prop_id, prop_def in subset.object_properties.items():
                merged.object_properties[f"{subset.ontology_id}:{prop_id}"] = prop_def
            for prop_id, prop_def in subset.datatype_properties.items():
                merged.datatype_properties[f"{subset.ontology_id}:{prop_id}"] = prop_def
            total_score += subset.relevance_score
        # Average relevance score across the merged subsets.
        merged.relevance_score = total_score / len(subsets)
        logger.info(f"Merged {len(subsets)} subsets into one with "
                    f"{len(merged.classes)} classes, "
                    f"{len(merged.object_properties)} object properties, "
                    f"{len(merged.datatype_properties)} datatype properties")
        return merged

View file

@ -0,0 +1,10 @@
#!/usr/bin/env python3
"""
OntoRAG extraction service launcher.
"""
from .extract import run

if __name__ == "__main__":
    run()

View file

@ -0,0 +1,240 @@
"""
Text processing components for OntoRAG system.
Splits text into sentences and extracts phrases for granular matching.
"""
import logging
import re
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
import nltk
from nltk.corpus import stopwords
logger = logging.getLogger(__name__)
# Ensure required NLTK data is available at import time. Downloads are
# best-effort: failures are tolerated here and surface later when the
# tokenizers are actually used.
def _ensure_nltk_resource(path, names):
    """Check for an NLTK resource; if missing, try each download name in turn.

    Args:
        path: nltk.data.find() resource path (e.g. 'tokenizers/punkt_tab').
        names: Candidate download names, newest first (fallbacks for older
            NLTK releases).
    """
    try:
        nltk.data.find(path)
        return
    except LookupError:
        pass
    for name in names:
        try:
            nltk.download(name, quiet=True)
            return
        except Exception:
            # Deliberately best-effort (was a bare `except:`, which also
            # swallowed KeyboardInterrupt/SystemExit); try the next fallback.
            continue

_ensure_nltk_resource('tokenizers/punkt_tab', ('punkt_tab', 'punkt'))
_ensure_nltk_resource('taggers/averaged_perceptron_tagger_eng',
                      ('averaged_perceptron_tagger_eng', 'averaged_perceptron_tagger'))
_ensure_nltk_resource('corpora/stopwords', ('stopwords',))
@dataclass
class TextSegment:
    """Represents a segment of text (sentence or phrase)."""
    text: str  # The segment's text content
    type: str  # 'sentence', 'phrase', 'noun_phrase', 'verb_phrase'
    position: int  # Ordinal position of the segment within its chunk
    parent_sentence: Optional[str] = None  # Originating sentence for phrase segments
    metadata: Optional[Dict[str, Any]] = None  # Optional extra attributes (may be None)
class SentenceSplitter:
    """Splits text into sentences using NLTK."""

    def __init__(self):
        """Load an NLTK punkt sentence tokenizer, preferring punkt_tab."""
        try:
            # Try the newer punkt_tab resource layout first (NLTK >= 3.8.2).
            self.sent_detector = nltk.data.load('tokenizers/punkt_tab/english/')
            logger.info("Using NLTK sentence tokenizer (punkt_tab)")
        except Exception:
            # Was a bare `except:` (which also swallowed KeyboardInterrupt/
            # SystemExit); fall back to the older punkt pickle.
            self.sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')
            logger.info("Using NLTK sentence tokenizer (punkt)")

    def split(self, text: str) -> List[str]:
        """Split text into sentences.

        Args:
            text: Text to split

        Returns:
            List of sentences
        """
        return self.sent_detector.tokenize(text)
class PhraseExtractor:
    """Extracts meaningful phrases from sentences using NLTK."""

    def __init__(self):
        """Initialize phrase extractor."""
        logger.info("Using NLTK phrase extraction")

    @staticmethod
    def _collect(pos_tags, tag_prefixes, phrase_type):
        """Collect multi-word runs whose POS tags start with given prefixes.

        Deduplicates the formerly copy-pasted noun-phrase / verb-phrase
        extraction loops.

        Args:
            pos_tags: List of (word, pos) tuples.
            tag_prefixes: Tuple of POS-tag prefixes that continue a run
                (str.startswith accepts a tuple of prefixes).
            phrase_type: Type label attached to each extracted phrase.

        Returns:
            List of {'text', 'type'} dicts, one per run of length >= 2.
        """
        phrases = []
        run = []
        for word, pos in pos_tags:
            if pos.startswith(tag_prefixes):
                run.append(word)
            elif run:
                # Run ended: keep it only if it spans multiple words.
                if len(run) > 1:
                    phrases.append({'text': ' '.join(run), 'type': phrase_type})
                run = []
        # Flush a trailing run at end of sentence.
        if run and len(run) > 1:
            phrases.append({'text': ' '.join(run), 'type': phrase_type})
        return phrases

    def extract(self, sentence: str) -> List[Dict[str, str]]:
        """Extract phrases from a sentence.

        Args:
            sentence: Sentence to extract phrases from

        Returns:
            List of phrases with their types (noun phrases first, then
            verb phrases, matching the original output order)
        """
        tokens = nltk.word_tokenize(sentence)
        pos_tags = nltk.pos_tag(tokens)
        # Noun phrases: runs of nouns/adjectives; verb phrases: verbs/adverbs.
        phrases = self._collect(pos_tags, ('NN', 'JJ'), 'noun_phrase')
        phrases += self._collect(pos_tags, ('VB', 'RB'), 'verb_phrase')
        return phrases
class TextProcessor:
    """Coordinates sentence splitting and phrase extraction for text chunks."""

    def __init__(self):
        """Initialize text processor with its splitter and extractor."""
        self.sentence_splitter = SentenceSplitter()
        self.phrase_extractor = PhraseExtractor()

    def process_chunk(self, chunk_text: str, extract_phrases: bool = True) -> List[TextSegment]:
        """Process a text chunk into segments.

        Args:
            chunk_text: Text chunk to process
            extract_phrases: Whether to extract phrases from sentences

        Returns:
            List of TextSegment objects: one 'sentence' segment per sentence,
            followed by its phrase segments, with a running position counter.
        """
        segments = []
        position = 0
        for sentence in self.sentence_splitter.split(chunk_text):
            segments.append(TextSegment(
                text=sentence,
                type='sentence',
                position=position
            ))
            position += 1
            if extract_phrases:
                for phrase_data in self.phrase_extractor.extract(sentence):
                    segments.append(TextSegment(
                        text=phrase_data['text'],
                        type=phrase_data['type'],
                        position=position,
                        parent_sentence=sentence
                    ))
                    position += 1
        logger.debug(f"Processed chunk into {len(segments)} segments")
        return segments

    def extract_key_terms(self, text: str) -> List[str]:
        """Extract key terms from text for matching.

        Args:
            text: Text to extract terms from

        Returns:
            List of key terms: stopword-filtered unigrams longer than two
            characters, followed by stopword-free bigrams (bigram words are
            not length-filtered, unlike unigrams).
        """
        # Split on word boundaries, lower-cased.
        words = re.findall(r'\b\w+\b', text.lower())
        # Use NLTK stopwords
        stop_words = set(stopwords.words('english'))
        # Filter stopwords and short words
        terms = [w for w in words if w not in stop_words and len(w) > 2]
        # Also extract multi-word terms (bigrams)
        for i in range(len(words) - 1):
            if words[i] not in stop_words and words[i + 1] not in stop_words:
                terms.append(f"{words[i]} {words[i + 1]}")
        return terms

    def normalize_text(self, text: str) -> str:
        """Normalize text for consistent processing.

        Args:
            text: Text to normalize

        Returns:
            Normalized text with collapsed whitespace and ASCII quotes.
        """
        # Remove extra and leading/trailing whitespace.
        text = re.sub(r'\s+', ' ', text).strip()
        # Bug fix: the original quote-normalization literals were corrupted
        # by an encoding round-trip (e.g. replace('"', '"') — a no-op — and
        # an unbalanced ''' that is a syntax error). Written as escapes so
        # the source encoding cannot corrupt them again.
        text = text.replace('\u201c', '"').replace('\u201d', '"')
        text = text.replace('\u2018', "'").replace('\u2019', "'")
        return text

View file

@ -0,0 +1,130 @@
"""
Vector store implementation for OntoRAG system.
Provides FAISS-based vector storage for ontology embeddings.
"""
import logging
import numpy as np
from typing import List, Dict, Any
from dataclasses import dataclass
import faiss
logger = logging.getLogger(__name__)
@dataclass
class SearchResult:
    """Result from vector similarity search."""
    # External id of the stored vector (as passed to add/add_batch).
    id: str
    # Similarity score; when produced by InMemoryVectorStore this is the
    # inner product of unit vectors, i.e. cosine similarity.
    score: float
    # Metadata dict stored alongside the vector at insert time.
    metadata: Dict[str, Any]
class InMemoryVectorStore:
    """FAISS-based in-memory vector store for ontology embeddings.

    Vectors are L2-normalized on insert and query, so the inner-product
    index scores are cosine similarities.
    """

    def __init__(self, dimension: int = 1536, index_type: str = 'flat'):
        """Initialize FAISS vector store.

        Args:
            dimension: Embedding dimension (1536 for text-embedding-3-small)
            index_type: 'flat' for exact search, 'ivf' for larger datasets
        """
        self.dimension = dimension
        # Parallel lists: metadata[i] and ids[i] describe FAISS row i.
        self.metadata = []
        self.ids = []
        if index_type == 'flat':
            # Exact search - best for ontologies with <10k elements
            self.index = faiss.IndexFlatIP(dimension)
            logger.info(f"Created FAISS flat index with dimension {dimension}")
        else:
            # Approximate search (IVF, 100 cells) - for larger ontologies
            quantizer = faiss.IndexFlatIP(dimension)
            self.index = faiss.IndexIVFFlat(quantizer, dimension, 100)
            # IVF indexes must be trained before adding vectors; random unit
            # vectors are sufficient for initialization.
            training_data = np.random.randn(1000, dimension).astype('float32')
            training_data = training_data / np.linalg.norm(
                training_data, axis=1, keepdims=True
            )
            self.index.train(training_data)
            logger.info(f"Created FAISS IVF index with dimension {dimension}")

    def add(self, id: str, embedding: np.ndarray, metadata: Dict[str, Any]):
        """Add a single embedding with metadata.

        Args:
            id: External identifier for the vector
            embedding: 1-D embedding vector (assumed non-zero)
            metadata: Arbitrary metadata stored alongside the vector
        """
        # Normalize for cosine similarity
        embedding = embedding / np.linalg.norm(embedding)
        self.index.add(np.array([embedding], dtype=np.float32))
        self.metadata.append(metadata)
        self.ids.append(id)

    def add_batch(self, ids: List[str], embeddings: np.ndarray,
                  metadata_list: List[Dict[str, Any]]):
        """Batch add for initial ontology loading.

        Args:
            ids: External identifiers, one per row of embeddings
            embeddings: Array of shape (n, dimension), rows assumed non-zero
            metadata_list: Metadata dicts, one per row
        """
        # Normalize all embeddings
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        normalized = embeddings / norms
        self.index.add(normalized.astype(np.float32))
        self.metadata.extend(metadata_list)
        self.ids.extend(ids)
        logger.debug(f"Added batch of {len(ids)} embeddings to FAISS index")

    def search(self, embedding: np.ndarray, top_k: int = 10,
               threshold: float = 0.0) -> List[SearchResult]:
        """Search for similar vectors.

        Args:
            embedding: 1-D query vector (assumed non-zero)
            top_k: Maximum number of results to return
            threshold: Minimum similarity score to include

        Returns:
            SearchResult list in FAISS order (descending similarity).
        """
        # Bug fix: on an empty index min(top_k, ntotal) is 0 and FAISS
        # rejects k=0 searches; return no results instead of erroring.
        if self.index.ntotal == 0:
            return []
        # Normalize query
        embedding = embedding / np.linalg.norm(embedding)
        # Search
        scores, indices = self.index.search(
            np.array([embedding], dtype=np.float32),
            min(top_k, self.index.ntotal)
        )
        # Filter by threshold and format results
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx >= 0 and score >= threshold:  # FAISS returns -1 for empty slots
                results.append(SearchResult(
                    id=self.ids[idx],
                    score=float(score),
                    metadata=self.metadata[idx]
                ))
        return results

    def clear(self):
        """Reset the store, dropping all vectors, ids and metadata."""
        self.index.reset()
        self.metadata = []
        self.ids = []
        logger.info("Cleared FAISS vector store")

    def size(self) -> int:
        """Return number of stored vectors."""
        return self.index.ntotal
# Utility functions for vector operations
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of two 1-D vectors (assumed non-zero)."""
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / denom
def batch_cosine_similarity(queries: np.ndarray, targets: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between two sets of row vectors.

    Args:
        queries: Array of shape (n_queries, dimension)
        targets: Array of shape (n_targets, dimension)

    Returns:
        Array of shape (n_queries, n_targets) with similarity scores
    """
    # Unit-normalize each row, then a single matrix product yields every
    # pairwise cosine similarity.
    unit_queries = queries / np.linalg.norm(queries, axis=1, keepdims=True)
    unit_targets = targets / np.linalg.norm(targets, axis=1, keepdims=True)
    return unit_queries @ unit_targets.T

View file

@ -87,12 +87,6 @@ class Processor(LlmService):
],
temperature=effective_temperature,
max_tokens=self.max_output,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
response_format={
"type": "text"
}
)
inputtokens = resp.usage.prompt_tokens

View file

@ -47,39 +47,6 @@ class Processor(DocumentEmbeddingsQueryService):
}
)
self.last_index_name = None
def ensure_index_exists(self, index_name, dim):
"""Ensure index exists, create if it doesn't"""
if index_name != self.last_index_name:
if not self.pinecone.has_index(index_name):
try:
self.pinecone.create_index(
name=index_name,
dimension=dim,
metric="cosine",
spec=ServerlessSpec(
cloud="aws",
region="us-east-1",
)
)
logger.info(f"Created index: {index_name}")
# Wait for index to be ready
import time
for i in range(0, 1000):
if self.pinecone.describe_index(index_name).status["ready"]:
break
time.sleep(1)
if not self.pinecone.describe_index(index_name).status["ready"]:
raise RuntimeError("Gave up waiting for index creation")
except Exception as e:
logger.error(f"Pinecone index creation failed: {e}")
raise e
self.last_index_name = index_name
async def query_document_embeddings(self, msg):
try:
@ -94,11 +61,13 @@ class Processor(DocumentEmbeddingsQueryService):
dim = len(vec)
index_name = (
"d-" + msg.user + "-" + msg.collection
)
# Use dimension suffix in index name
index_name = f"d-{msg.user}-{msg.collection}-{dim}"
self.ensure_index_exists(index_name, dim)
# Check if index exists - skip if not
if not self.pinecone.has_index(index_name):
logger.info(f"Index {index_name} does not exist, skipping this vector")
continue
index = self.pinecone.Index(index_name)

View file

@ -71,16 +71,17 @@ class Processor(DocumentEmbeddingsQueryService):
chunks = []
collection = (
"d_" + msg.user + "_" + msg.collection
)
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist, returning empty results")
return []
for vec in msg.vectors:
# Use dimension suffix in collection name
dim = len(vec)
collection = f"d_{msg.user}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist, returning empty results")
continue
search_result = self.qdrant.query_points(
collection_name=collection,
query=vec,

View file

@ -49,39 +49,6 @@ class Processor(GraphEmbeddingsQueryService):
}
)
self.last_index_name = None
def ensure_index_exists(self, index_name, dim):
"""Ensure index exists, create if it doesn't"""
if index_name != self.last_index_name:
if not self.pinecone.has_index(index_name):
try:
self.pinecone.create_index(
name=index_name,
dimension=dim,
metric="cosine",
spec=ServerlessSpec(
cloud="aws",
region="us-east-1",
)
)
logger.info(f"Created index: {index_name}")
# Wait for index to be ready
import time
for i in range(0, 1000):
if self.pinecone.describe_index(index_name).status["ready"]:
break
time.sleep(1)
if not self.pinecone.describe_index(index_name).status["ready"]:
raise RuntimeError("Gave up waiting for index creation")
except Exception as e:
logger.error(f"Pinecone index creation failed: {e}")
raise e
self.last_index_name = index_name
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
return Value(value=ent, is_uri=True)
@ -103,11 +70,13 @@ class Processor(GraphEmbeddingsQueryService):
dim = len(vec)
index_name = (
"t-" + msg.user + "-" + msg.collection
)
# Use dimension suffix in index name
index_name = f"t-{msg.user}-{msg.collection}-{dim}"
self.ensure_index_exists(index_name, dim)
# Check if index exists - skip if not
if not self.pinecone.has_index(index_name):
logger.info(f"Index {index_name} does not exist, skipping this vector")
continue
index = self.pinecone.Index(index_name)

View file

@ -78,17 +78,17 @@ class Processor(GraphEmbeddingsQueryService):
entity_set = set()
entities = []
collection = (
"t_" + msg.user + "_" + msg.collection
)
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist, returning empty results")
return []
for vec in msg.vectors:
# Use dimension suffix in collection name
dim = len(vec)
collection = f"t_{msg.user}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist, skipping this vector")
continue
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) entities
search_result = self.qdrant.query_points(

View file

@ -0,0 +1,54 @@
"""
OntoRAG Query System.
Ontology-driven natural language query processing with multi-backend support.
Provides semantic query understanding, ontology matching, and answer generation.
"""
from .query_service import OntoRAGQueryService, QueryRequest, QueryResponse
from .question_analyzer import QuestionAnalyzer, QuestionComponents, QuestionType
from .ontology_matcher import OntologyMatcher, QueryOntologySubset
from .backend_router import BackendRouter, BackendType, QueryRoute
from .sparql_generator import SPARQLGenerator, SPARQLQuery
from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult
from .cypher_generator import CypherGenerator, CypherQuery
from .cypher_executor import CypherExecutor, CypherResult
from .answer_generator import AnswerGenerator, GeneratedAnswer, AnswerMetadata
# Public API of this package; keep in sync with the re-exports above.
__all__ = [
    # Main service
    'OntoRAGQueryService',
    'QueryRequest',
    'QueryResponse',
    # Question analysis
    'QuestionAnalyzer',
    'QuestionComponents',
    'QuestionType',
    # Ontology matching
    'OntologyMatcher',
    'QueryOntologySubset',
    # Backend routing
    'BackendRouter',
    'BackendType',
    'QueryRoute',
    # SPARQL components
    'SPARQLGenerator',
    'SPARQLQuery',
    'SPARQLCassandraEngine',
    'SPARQLResult',
    # Cypher components
    'CypherGenerator',
    'CypherQuery',
    'CypherExecutor',
    'CypherResult',
    # Answer generation
    'AnswerGenerator',
    'GeneratedAnswer',
    'AnswerMetadata',
]

View file

@ -0,0 +1,521 @@
"""
Answer generator for natural language responses.
Converts query results into natural language answers using LLM assistance.
"""
import logging
from typing import Dict, Any, List, Optional, Union
from dataclasses import dataclass
from datetime import datetime
from .question_analyzer import QuestionComponents, QuestionType
from .ontology_matcher import QueryOntologySubset
from .sparql_cassandra import SPARQLResult
from .cypher_executor import CypherResult
logger = logging.getLogger(__name__)
@dataclass
class AnswerMetadata:
    """Metadata about answer generation."""
    # Question type value, e.g. 'factual' or 'aggregation'.
    query_type: str
    # Name of the backend that executed the query.
    backend_used: str
    # Answer-generation wall time in seconds.
    execution_time: float
    # Number of result rows/bindings the answer was built from.
    result_count: int
    # Heuristic confidence; AnswerGenerator uses 0.8 by default, 0.0 on error.
    confidence: float
    # Human-readable note on how the answer was produced.
    explanation: str
    # Data-source identifiers (currently left empty by AnswerGenerator).
    sources: List[str]
@dataclass
class GeneratedAnswer:
    """Generated natural language answer."""
    # Final natural-language answer text.
    answer: str
    # Generation metadata (backend, timing, confidence, ...).
    metadata: AnswerMetadata
    # Formatted facts supporting the answer (AnswerGenerator caps this at 5).
    supporting_facts: List[str]
    # The unmodified backend result the answer was derived from.
    raw_results: Union[SPARQLResult, CypherResult]
    # Seconds spent generating the answer (AnswerGenerator sets this to the
    # same value as metadata.execution_time).
    generation_time: float
class AnswerGenerator:
    """Generates natural language answers from query results.

    Prefers LLM-based generation via the injected prompt service and falls
    back to fixed templates when no service is configured or the LLM call
    fails or returns nothing.
    """

    def __init__(self, prompt_service=None):
        """Initialize answer generator.

        Args:
            prompt_service: Service for LLM-based answer generation.
                NOTE(review): assumed to expose an async
                ``generate_answer(prompt=...)`` returning a dict with an
                'answer' key or a plain string — confirm against the
                service implementation.
        """
        self.prompt_service = prompt_service
        # Answer templates for different question types
        self.templates = {
            'count': "There are {count} {entity_type}.",
            'boolean_true': "Yes, {statement} is true.",
            'boolean_false': "No, {statement} is not true.",
            'list': "The {entity_type} are: {items}.",
            'single': "The {property} of {entity} is {value}.",
            'none': "No results were found for your query.",
            'error': "I encountered an error processing your query: {error}"
        }

    async def generate_answer(self,
                              question_components: QuestionComponents,
                              query_results: Union[SPARQLResult, CypherResult],
                              ontology_subset: QueryOntologySubset,
                              backend_used: str) -> GeneratedAnswer:
        """Generate natural language answer from query results.

        Args:
            question_components: Original question analysis
            query_results: Results from query execution
            ontology_subset: Ontology subset used
            backend_used: Backend that executed the query

        Returns:
            Generated answer with metadata; on any exception an error
            answer with confidence 0.0 is returned instead of raising.
        """
        start_time = datetime.now()
        try:
            # Try LLM-based generation first
            if self.prompt_service:
                llm_answer = await self._generate_with_llm(
                    question_components, query_results, ontology_subset
                )
                # Empty/None LLM output falls through to templates.
                if llm_answer:
                    execution_time = (datetime.now() - start_time).total_seconds()
                    return self._build_answer_response(
                        llm_answer, question_components, query_results,
                        backend_used, execution_time
                    )
            # Fall back to template-based generation
            template_answer = self._generate_with_template(
                question_components, query_results, ontology_subset
            )
            execution_time = (datetime.now() - start_time).total_seconds()
            return self._build_answer_response(
                template_answer, question_components, query_results,
                backend_used, execution_time
            )
        except Exception as e:
            logger.error(f"Answer generation failed: {e}")
            execution_time = (datetime.now() - start_time).total_seconds()
            error_answer = self.templates['error'].format(error=str(e))
            # Error answers carry zero confidence.
            return self._build_answer_response(
                error_answer, question_components, query_results,
                backend_used, execution_time, confidence=0.0
            )

    async def _generate_with_llm(self,
                                 question_components: QuestionComponents,
                                 query_results: Union[SPARQLResult, CypherResult],
                                 ontology_subset: QueryOntologySubset) -> Optional[str]:
        """Generate answer using LLM.

        Args:
            question_components: Question analysis
            query_results: Query results
            ontology_subset: Ontology subset

        Returns:
            Generated answer, or None on failure or unexpected response type
            (callers treat None/empty as "fall back to templates").
        """
        try:
            prompt = self._build_answer_prompt(
                question_components, query_results, ontology_subset
            )
            response = await self.prompt_service.generate_answer(prompt=prompt)
            # The service may return a dict payload or a bare string.
            if response and isinstance(response, dict):
                return response.get('answer', '').strip()
            elif isinstance(response, str):
                return response.strip()
        except Exception as e:
            logger.error(f"LLM answer generation failed: {e}")
        return None

    def _generate_with_template(self,
                                question_components: QuestionComponents,
                                query_results: Union[SPARQLResult, CypherResult],
                                ontology_subset: QueryOntologySubset) -> str:
        """Generate answer using templates.

        Dispatches on the question type (boolean, aggregation, retrieval,
        factual) with a generic "Found: ..." fallback.

        Args:
            question_components: Question analysis
            query_results: Query results
            ontology_subset: Ontology subset

        Returns:
            Template-based answer
        """
        # Handle empty results
        if not self._has_results(query_results):
            return self.templates['none']
        # Handle boolean queries
        if question_components.question_type == QuestionType.BOOLEAN:
            if hasattr(query_results, 'ask_result'):
                # SPARQL ASK result
                statement = self._extract_boolean_statement(question_components)
                if query_results.ask_result:
                    return self.templates['boolean_true'].format(statement=statement)
                else:
                    return self.templates['boolean_false'].format(statement=statement)
            else:
                # Cypher boolean (check if any results)
                has_results = len(query_results.records) > 0
                statement = self._extract_boolean_statement(question_components)
                if has_results:
                    return self.templates['boolean_true'].format(statement=statement)
                else:
                    return self.templates['boolean_false'].format(statement=statement)
        # Handle count queries
        if question_components.question_type == QuestionType.AGGREGATION:
            count = self._extract_count(query_results)
            entity_type = self._infer_entity_type(question_components, ontology_subset)
            return self.templates['count'].format(count=count, entity_type=entity_type)
        # Handle retrieval queries
        if question_components.question_type == QuestionType.RETRIEVAL:
            items = self._extract_items(query_results)
            if len(items) == 1:
                # Single result
                entity = question_components.entities[0] if question_components.entities else "entity"
                property_name = "value"
                return self.templates['single'].format(
                    property=property_name, entity=entity, value=items[0]
                )
            else:
                # Multiple results
                entity_type = self._infer_entity_type(question_components, ontology_subset)
                items_str = ", ".join(items)
                return self.templates['list'].format(entity_type=entity_type, items=items_str)
        # Handle factual queries
        if question_components.question_type == QuestionType.FACTUAL:
            facts = self._extract_facts(query_results)
            return ". ".join(facts) if facts else self.templates['none']
        # Default fallback
        items = self._extract_items(query_results)
        if items:
            return f"Found: {', '.join(items[:5])}" + ("..." if len(items) > 5 else "")
        else:
            return self.templates['none']

    def _build_answer_prompt(self,
                             question_components: QuestionComponents,
                             query_results: Union[SPARQLResult, CypherResult],
                             ontology_subset: QueryOntologySubset) -> str:
        """Build prompt for LLM answer generation.

        Args:
            question_components: Question analysis
            query_results: Query results
            ontology_subset: Ontology subset

        Returns:
            Formatted prompt string
        """
        # Format results for prompt
        results_str = self._format_results_for_prompt(query_results)
        # Extract ontology context (first 5 classes/properties only, to
        # keep the prompt small)
        context_classes = list(ontology_subset.classes.keys())[:5]
        context_properties = list(ontology_subset.object_properties.keys())[:5]
        prompt = f"""Generate a natural language answer for the following question based on the query results.
ORIGINAL QUESTION: {question_components.original_question}
QUESTION TYPE: {question_components.question_type.value}
EXPECTED ANSWER: {question_components.expected_answer_type}
ONTOLOGY CONTEXT:
- Classes: {', '.join(context_classes)}
- Properties: {', '.join(context_properties)}
QUERY RESULTS:
{results_str}
INSTRUCTIONS:
- Provide a clear, concise answer in natural language
- Use the original question's tone and style
- Include specific facts from the results
- If no results, explain that no information was found
- Be accurate and don't make assumptions beyond the data
- Limit response to 2-3 sentences unless the question requires more detail
ANSWER:"""
        return prompt

    def _format_results_for_prompt(self, query_results: Union[SPARQLResult, CypherResult]) -> str:
        """Format query results for prompt inclusion.

        Caps output at the first 10 rows and appends an "... and N more"
        line for the remainder.

        Args:
            query_results: Query results to format

        Returns:
            Formatted results string
        """
        if isinstance(query_results, SPARQLResult):
            if hasattr(query_results, 'ask_result') and query_results.ask_result is not None:
                return f"Boolean result: {query_results.ask_result}"
            if not query_results.bindings:
                return "No results found"
            # Format SPARQL bindings
            lines = []
            for binding in query_results.bindings[:10]:  # Limit to first 10
                formatted = []
                for var, value in binding.items():
                    # Bindings may be plain values or SPARQL-JSON style
                    # {'value': ...} dicts.
                    if isinstance(value, dict):
                        formatted.append(f"{var}: {value.get('value', value)}")
                    else:
                        formatted.append(f"{var}: {value}")
                lines.append("- " + ", ".join(formatted))
            if len(query_results.bindings) > 10:
                lines.append(f"... and {len(query_results.bindings) - 10} more results")
            return "\n".join(lines)
        else:  # CypherResult
            if not query_results.records:
                return "No results found"
            # Format Cypher records
            lines = []
            for record in query_results.records[:10]:  # Limit to first 10
                if isinstance(record, dict):
                    formatted = [f"{k}: {v}" for k, v in record.items()]
                    lines.append("- " + ", ".join(formatted))
                else:
                    lines.append(f"- {record}")
            if len(query_results.records) > 10:
                lines.append(f"... and {len(query_results.records) - 10} more results")
            return "\n".join(lines)

    def _has_results(self, query_results: Union[SPARQLResult, CypherResult]) -> bool:
        """Check if query results contain data.

        Args:
            query_results: Query results to check

        Returns:
            True if results contain data (any bindings/records, or a
            non-None SPARQL ASK result)
        """
        if isinstance(query_results, SPARQLResult):
            return bool(query_results.bindings) or query_results.ask_result is not None
        else:  # CypherResult
            return bool(query_results.records)

    def _extract_count(self, query_results: Union[SPARQLResult, CypherResult]) -> int:
        """Extract count from aggregation query results.

        Looks for a variable/key containing 'count'; otherwise falls back
        to the number of rows.

        Args:
            query_results: Query results

        Returns:
            Count value
        """
        if isinstance(query_results, SPARQLResult):
            if query_results.bindings:
                binding = query_results.bindings[0]
                # Look for count variable
                for var, value in binding.items():
                    if 'count' in var.lower():
                        if isinstance(value, dict):
                            return int(value.get('value', 0))
                        return int(value)
            return len(query_results.bindings)
        else:  # CypherResult
            if query_results.records:
                record = query_results.records[0]
                if isinstance(record, dict):
                    # Look for count key
                    for key, value in record.items():
                        if 'count' in key.lower():
                            return int(value)
                elif isinstance(record, (int, float)):
                    return int(record)
            return len(query_results.records)

    def _extract_items(self, query_results: Union[SPARQLResult, CypherResult]) -> List[str]:
        """Extract items from query results.

        Takes the first value of each row and shortens URIs to their final
        path/fragment segment.

        Args:
            query_results: Query results

        Returns:
            List of extracted items
        """
        items = []
        if isinstance(query_results, SPARQLResult):
            for binding in query_results.bindings:
                for var, value in binding.items():
                    if isinstance(value, dict):
                        item_value = value.get('value', str(value))
                    else:
                        item_value = str(value)
                    # Clean up URIs
                    if item_value.startswith('http'):
                        item_value = item_value.split('/')[-1].split('#')[-1]
                    items.append(item_value)
                    break  # Take first value per binding
        else:  # CypherResult
            for record in query_results.records:
                if isinstance(record, dict):
                    # Take first value from record
                    for key, value in record.items():
                        items.append(str(value))
                        break
                else:
                    items.append(str(record))
        return items

    def _extract_facts(self, query_results: Union[SPARQLResult, CypherResult]) -> List[str]:
        """Extract facts from query results.

        Unlike _extract_items, every variable/key of a row contributes to
        the fact string.

        Args:
            query_results: Query results

        Returns:
            List of facts, one "var: value, ..." string per row
        """
        facts = []
        if isinstance(query_results, SPARQLResult):
            for binding in query_results.bindings:
                fact_parts = []
                for var, value in binding.items():
                    if isinstance(value, dict):
                        val_str = value.get('value', str(value))
                    else:
                        val_str = str(value)
                    # Clean up URIs
                    if val_str.startswith('http'):
                        val_str = val_str.split('/')[-1].split('#')[-1]
                    fact_parts.append(f"{var}: {val_str}")
                facts.append(", ".join(fact_parts))
        else:  # CypherResult
            for record in query_results.records:
                if isinstance(record, dict):
                    fact_parts = [f"{k}: {v}" for k, v in record.items()]
                    facts.append(", ".join(fact_parts))
                else:
                    facts.append(str(record))
        return facts

    def _extract_boolean_statement(self, question_components: QuestionComponents) -> str:
        """Extract statement for boolean answer.

        Crude heuristic: strips leading question verbs and the question
        mark from the lower-cased question.

        Args:
            question_components: Question analysis

        Returns:
            Statement string
        """
        # Extract the key assertion from the question
        question = question_components.original_question.lower()
        # Remove question words
        statement = question.replace('is ', '').replace('are ', '').replace('does ', '')
        statement = statement.replace('?', '').strip()
        return statement

    def _infer_entity_type(self,
                           question_components: QuestionComponents,
                           ontology_subset: QueryOntologySubset) -> str:
        """Infer entity type from question and ontology.

        Args:
            question_components: Question analysis
            ontology_subset: Ontology subset

        Returns:
            Entity type string (matched ontology class id, first question
            entity, or the literal "entities")
        """
        # Try to match entities to ontology classes
        for entity in question_components.entities:
            entity_lower = entity.lower()
            for class_id in ontology_subset.classes:
                if class_id.lower() == entity_lower or entity_lower in class_id.lower():
                    return class_id
        # Fallback to first entity or generic term
        if question_components.entities:
            return question_components.entities[0]
        else:
            return "entities"

    def _build_answer_response(self,
                               answer: str,
                               question_components: QuestionComponents,
                               query_results: Union[SPARQLResult, CypherResult],
                               backend_used: str,
                               execution_time: float,
                               confidence: float = 0.8) -> GeneratedAnswer:
        """Build final answer response.

        Args:
            answer: Generated answer text
            question_components: Question analysis
            query_results: Query results
            backend_used: Backend used for query
            execution_time: Answer generation time
            confidence: Confidence score

        Returns:
            Complete answer response
        """
        # Extract supporting facts
        supporting_facts = self._extract_facts(query_results)
        # Build metadata
        result_count = 0
        if isinstance(query_results, SPARQLResult):
            result_count = len(query_results.bindings)
        else:  # CypherResult
            result_count = len(query_results.records)
        metadata = AnswerMetadata(
            query_type=question_components.question_type.value,
            backend_used=backend_used,
            execution_time=execution_time,
            result_count=result_count,
            confidence=confidence,
            explanation=f"Generated answer using {backend_used} backend",
            sources=[]  # Could be populated with data source information
        )
        return GeneratedAnswer(
            answer=answer,
            metadata=metadata,
            supporting_facts=supporting_facts[:5],  # Limit to top 5
            raw_results=query_results,
            generation_time=execution_time
        )

View file

@ -0,0 +1,350 @@
"""
Backend router for ontology query system.
Routes queries to appropriate backend based on configuration.
"""
import logging
from typing import Dict, Any, Optional, List
from dataclasses import dataclass
from enum import Enum
from .question_analyzer import QuestionComponents
from .ontology_matcher import QueryOntologySubset
logger = logging.getLogger(__name__)
class BackendType(Enum):
    """Supported backend types."""
    # Triple store queried via SPARQL (routing maps CASSANDRA -> 'sparql').
    CASSANDRA = "cassandra"
    # Property-graph backends; routing maps all of these to 'cypher'.
    NEO4J = "neo4j"
    MEMGRAPH = "memgraph"
    FALKORDB = "falkordb"
@dataclass
class BackendConfig:
    """Configuration for a backend."""
    # Which backend this configuration applies to.
    type: BackendType
    # Higher priority wins (router assigns primary=100, fallbacks 50, 40, ...).
    priority: int = 0
    # Disabled backends are skipped by every routing strategy.
    enabled: bool = True
    # Backend-specific settings; annotation corrected to Optional since the
    # default is None, not a dict.
    config: Optional[Dict[str, Any]] = None
@dataclass
class QueryRoute:
    """Routing decision for a query."""
    # Backend chosen to execute the query.
    backend_type: BackendType
    query_language: str  # 'sparql' or 'cypher'
    # Routing confidence in [0, 1].
    confidence: float
    # Human-readable explanation of the routing decision.
    reasoning: str
class BackendRouter:
"""Routes queries to appropriate backends based on configuration and heuristics."""
def __init__(self, config: Dict[str, Any]):
"""Initialize backend router.
Args:
config: Router configuration
"""
self.config = config
self.backends = self._parse_backend_config(config)
self.routing_strategy = config.get('routing_strategy', 'priority')
self.enable_fallback = config.get('enable_fallback', True)
def _parse_backend_config(self, config: Dict[str, Any]) -> Dict[BackendType, BackendConfig]:
"""Parse backend configuration.
Args:
config: Configuration dictionary
Returns:
Dictionary of backend type to configuration
"""
backends = {}
# Parse primary backend
primary = config.get('primary', 'cassandra')
if primary:
try:
backend_type = BackendType(primary)
backends[backend_type] = BackendConfig(
type=backend_type,
priority=100,
enabled=True,
config=config.get(primary, {})
)
except ValueError:
logger.warning(f"Unknown primary backend type: {primary}")
# Parse fallback backends
fallbacks = config.get('fallback', [])
for i, fallback in enumerate(fallbacks):
try:
backend_type = BackendType(fallback)
backends[backend_type] = BackendConfig(
type=backend_type,
priority=50 - i * 10, # Decreasing priority
enabled=True,
config=config.get(fallback, {})
)
except ValueError:
logger.warning(f"Unknown fallback backend type: {fallback}")
return backends
def route_query(self,
question_components: QuestionComponents,
ontology_subsets: List[QueryOntologySubset]) -> QueryRoute:
"""Route a query to the best backend.
Args:
question_components: Analyzed question
ontology_subsets: Relevant ontology subsets
Returns:
QueryRoute with routing decision
"""
if self.routing_strategy == 'priority':
return self._route_by_priority()
elif self.routing_strategy == 'adaptive':
return self._route_adaptive(question_components, ontology_subsets)
elif self.routing_strategy == 'round_robin':
return self._route_round_robin()
else:
return self._route_by_priority()
def _route_by_priority(self) -> QueryRoute:
"""Route based on backend priority.
Returns:
QueryRoute to highest priority backend
"""
# Find highest priority enabled backend
best_backend = None
best_priority = -1
for backend_type, backend_config in self.backends.items():
if backend_config.enabled and backend_config.priority > best_priority:
best_backend = backend_type
best_priority = backend_config.priority
if best_backend is None:
raise RuntimeError("No enabled backends available")
# Determine query language
query_language = 'sparql' if best_backend == BackendType.CASSANDRA else 'cypher'
return QueryRoute(
backend_type=best_backend,
query_language=query_language,
confidence=1.0,
reasoning=f"Priority routing to {best_backend.value}"
)
def _route_adaptive(self,
question_components: QuestionComponents,
ontology_subsets: List[QueryOntologySubset]) -> QueryRoute:
"""Route based on question characteristics and ontology complexity.
Args:
question_components: Analyzed question
ontology_subsets: Relevant ontology subsets
Returns:
QueryRoute with adaptive decision
"""
scores = {}
for backend_type, backend_config in self.backends.items():
if not backend_config.enabled:
continue
score = self._calculate_backend_score(
backend_type, question_components, ontology_subsets
)
scores[backend_type] = score
if not scores:
raise RuntimeError("No enabled backends available")
# Select backend with highest score
best_backend = max(scores.keys(), key=lambda k: scores[k])
best_score = scores[best_backend]
# Determine query language
query_language = 'sparql' if best_backend == BackendType.CASSANDRA else 'cypher'
return QueryRoute(
backend_type=best_backend,
query_language=query_language,
confidence=best_score,
reasoning=f"Adaptive routing: {best_backend.value} scored {best_score:.2f}"
)
    def _calculate_backend_score(self,
                                 backend_type: BackendType,
                                 question_components: QuestionComponents,
                                 ontology_subsets: List[QueryOntologySubset]) -> float:
        """Calculate score for a backend based on query characteristics.

        The score is an additive heuristic: normalized priority (0-1) plus
        fixed bonuses for question type, ontology complexity and
        aggregation support, capped at 1.0.

        Args:
            backend_type: Backend to score
            question_components: Question analysis
            ontology_subsets: Ontology subsets

        Returns:
            Score (0.0 to 1.0)
        """
        score = 0.0
        # Base priority score (primary backends have priority 100 -> 1.0)
        backend_config = self.backends[backend_type]
        score += backend_config.priority / 100.0
        # Question type preferences
        if backend_type == BackendType.CASSANDRA:
            # SPARQL is good for hierarchical and complex reasoning
            if question_components.question_type.value in ['factual', 'aggregation']:
                score += 0.3
            # Good for ontology-heavy queries
            if len(ontology_subsets) > 1:
                score += 0.2
        else:
            # Cypher is good for graph traversal and relationships
            if question_components.question_type.value in ['relationship', 'retrieval']:
                score += 0.3
            # Good for simple graph patterns
            if len(question_components.relationships) > 0:
                score += 0.2
        # Complexity considerations: total count of ontology elements in play
        total_elements = sum(
            len(subset.classes) + len(subset.object_properties) + len(subset.datatype_properties)
            for subset in ontology_subsets
        )
        if backend_type == BackendType.CASSANDRA:
            # SPARQL handles complex ontologies well
            if total_elements > 20:
                score += 0.2
        else:
            # Cypher is efficient for simpler queries
            if total_elements <= 10:
                score += 0.2
        # Aggregation considerations
        if question_components.aggregations:
            if backend_type == BackendType.CASSANDRA:
                score += 0.1  # SPARQL has built-in aggregation
            else:
                score += 0.2  # Cypher has excellent aggregation
        return min(score, 1.0)
def _route_round_robin(self) -> QueryRoute:
    """Route using a true round-robin strategy.

    Cycles through the enabled backends on successive calls. The rotation
    position is kept in ``self._round_robin_index``, created lazily with
    ``getattr`` so instances built before this attribute existed still work.
    (The original implementation always returned the first enabled backend.)

    Returns:
        QueryRoute for the next backend in the rotation

    Raises:
        RuntimeError: If no backends are enabled
    """
    enabled_backends = [
        bt for bt, bc in self.backends.items() if bc.enabled
    ]
    if not enabled_backends:
        raise RuntimeError("No enabled backends available")

    # Lazily-initialised rotation counter; advances on every call so
    # requests are spread evenly across the enabled backends.
    index = getattr(self, '_round_robin_index', 0)
    backend_type = enabled_backends[index % len(enabled_backends)]
    self._round_robin_index = index + 1

    query_language = 'sparql' if backend_type == BackendType.CASSANDRA else 'cypher'
    return QueryRoute(
        backend_type=backend_type,
        query_language=query_language,
        confidence=0.8,
        reasoning=f"Round-robin routing to {backend_type.value}"
    )
def get_fallback_route(self, failed_backend: BackendType) -> Optional[QueryRoute]:
    """Pick an alternative backend after a failure.

    Args:
        failed_backend: Backend that just failed

    Returns:
        Route to the highest-priority remaining backend, or None when
        fallback is disabled or no other backend is enabled
    """
    if not self.enable_fallback:
        return None

    # Candidates: every enabled backend except the one that failed.
    candidates = [
        (bt, bc) for bt, bc in self.backends.items()
        if bc.enabled and bt != failed_backend
    ]
    if not candidates:
        return None

    # Highest-priority candidate wins (first encountered on ties).
    fallback_type, _ = max(candidates, key=lambda item: item[1].priority)

    query_language = 'sparql' if fallback_type == BackendType.CASSANDRA else 'cypher'
    return QueryRoute(
        backend_type=fallback_type,
        query_language=query_language,
        confidence=0.7,
        reasoning=f"Fallback from {failed_backend.value} to {fallback_type.value}"
    )
def get_available_backends(self) -> List[BackendType]:
    """List every backend that is currently enabled.

    Returns:
        Backend types whose configuration has ``enabled`` set
    """
    enabled = []
    for backend_type, backend_config in self.backends.items():
        if backend_config.enabled:
            enabled.append(backend_type)
    return enabled
def is_backend_enabled(self, backend_type: BackendType) -> bool:
    """Report whether a backend is configured and enabled.

    Args:
        backend_type: Backend to check

    Returns:
        True when the backend exists and is enabled
    """
    backend_config = self.backends.get(backend_type)
    if backend_config is None:
        return False
    return backend_config.enabled
def update_backend_status(self, backend_type: BackendType, enabled: bool):
    """Enable or disable a backend at runtime.

    Args:
        backend_type: Backend to update
        enabled: New enabled status
    """
    if backend_type not in self.backends:
        logger.warning(f"Unknown backend type: {backend_type}")
        return
    self.backends[backend_type].enabled = enabled
    logger.info(f"Backend {backend_type.value} {'enabled' if enabled else 'disabled'}")
def get_backend_config(self, backend_type: BackendType) -> Optional[Dict[str, Any]]:
    """Fetch the raw configuration for a backend.

    Args:
        backend_type: Backend type

    Returns:
        The backend's configuration dictionary, or None if unknown
    """
    entry = self.backends.get(backend_type)
    if entry is None:
        return None
    return entry.config

View file

@ -0,0 +1,651 @@
"""
Caching system for OntoRAG query results and computations.
Provides multiple cache backends and intelligent cache management.
"""
import logging
import time
import json
import pickle
from typing import Dict, Any, List, Optional, Union, Tuple
from dataclasses import dataclass, asdict
from datetime import datetime, timedelta
from abc import ABC, abstractmethod
from pathlib import Path
import threading
logger = logging.getLogger(__name__)
@dataclass
class CacheEntry:
    """A single cached value plus the bookkeeping metadata backends need."""
    key: str                            # cache key
    value: Any                          # cached payload
    created_at: datetime                # creation timestamp (TTL anchor)
    accessed_at: datetime               # last access timestamp (used for LRU)
    access_count: int                   # number of times the entry was read
    ttl_seconds: Optional[int] = None   # time-to-live; None means never expires
    tags: List[str] = None              # tags for group invalidation
    size_bytes: int = 0                 # approximate serialized size

    def __post_init__(self):
        # BUG FIX: normalise tags so tag matching (``tag in entry.tags``)
        # can never hit None; a fresh list per instance avoids the shared
        # mutable-default pitfall.
        if self.tags is None:
            self.tags = []

    def is_expired(self) -> bool:
        """Return True when the entry's TTL has elapsed (False if no TTL)."""
        if self.ttl_seconds is None:
            return False
        return (datetime.now() - self.created_at).total_seconds() > self.ttl_seconds

    def touch(self):
        """Record a read: refresh the access timestamp and bump the count."""
        self.accessed_at = datetime.now()
        self.access_count += 1
@dataclass
class CacheStats:
    """Aggregate performance counters for a cache backend."""
    hits: int = 0              # successful lookups
    misses: int = 0            # failed lookups (including expired entries)
    evictions: int = 0         # entries removed (expiry, capacity, delete)
    total_entries: int = 0     # live entry count
    total_size_bytes: int = 0  # approximate total payload size
    hit_rate: float = 0.0      # hits / (hits + misses); refreshed on demand

    def update_hit_rate(self):
        """Recompute the hit rate; 0.0 when no requests have been seen."""
        requests = self.hits + self.misses
        if requests > 0:
            self.hit_rate = self.hits / requests
        else:
            self.hit_rate = 0.0
class CacheBackend(ABC):
    """Interface every cache backend must implement."""

    @abstractmethod
    def get(self, key: str) -> Optional[CacheEntry]:
        """Return the entry stored under ``key``, or None on a miss."""
        ...

    @abstractmethod
    def set(self, key: str, value: Any, ttl_seconds: Optional[int] = None, tags: List[str] = None):
        """Store ``value`` under ``key`` with an optional TTL and tags."""
        ...

    @abstractmethod
    def delete(self, key: str) -> bool:
        """Remove the entry under ``key``; return True if it existed."""
        ...

    @abstractmethod
    def clear(self, tags: Optional[List[str]] = None):
        """Remove all entries, or only those carrying any of ``tags``."""
        ...

    @abstractmethod
    def get_stats(self) -> CacheStats:
        """Return a snapshot of this backend's performance counters."""
        ...

    @abstractmethod
    def cleanup_expired(self):
        """Drop every entry whose TTL has elapsed."""
        ...
class InMemoryCache(CacheBackend):
    """Thread-safe in-memory cache with LRU eviction and lazy TTL expiry."""

    def __init__(self, max_size: int = 1000, max_size_bytes: int = 100 * 1024 * 1024):
        """Initialize in-memory cache.

        Args:
            max_size: Maximum number of entries
            max_size_bytes: Maximum total size in bytes
        """
        self.max_size = max_size
        self.max_size_bytes = max_size_bytes
        self.entries: Dict[str, CacheEntry] = {}
        self.stats = CacheStats()
        # RLock because public methods call each other (clear -> delete).
        self._lock = threading.RLock()

    def get(self, key: str) -> Optional[CacheEntry]:
        """Return the entry for ``key``, or None on a miss or expiry."""
        with self._lock:
            entry = self.entries.get(key)
            if entry is None:
                self.stats.misses += 1
                self.stats.update_hit_rate()
                return None
            if entry.is_expired():
                # Lazy expiry: drop the stale entry and record a miss.
                del self.entries[key]
                self.stats.misses += 1
                self.stats.evictions += 1
                self.stats.total_entries -= 1
                self.stats.total_size_bytes -= entry.size_bytes
                self.stats.update_hit_rate()
                return None
            entry.touch()
            self.stats.hits += 1
            self.stats.update_hit_rate()
            return entry

    def set(self, key: str, value: Any, ttl_seconds: Optional[int] = None, tags: List[str] = None):
        """Store ``value`` under ``key``, evicting LRU entries if needed."""
        with self._lock:
            # Estimate the entry size; fall back to the string form when
            # the value cannot be pickled.
            try:
                size_bytes = len(pickle.dumps(value))
            except Exception:
                size_bytes = len(str(value).encode('utf-8'))

            now = datetime.now()
            entry = CacheEntry(
                key=key,
                value=value,
                created_at=now,
                accessed_at=now,
                access_count=1,
                ttl_seconds=ttl_seconds,
                tags=tags or [],
                size_bytes=size_bytes
            )

            # BUG FIX: release any entry being replaced *before* enforcing
            # capacity, so its size is not double-counted and the capacity
            # check cannot needlessly evict unrelated entries (or the very
            # key being overwritten).
            old_entry = self.entries.pop(key, None)
            if old_entry:
                self.stats.total_entries -= 1
                self.stats.total_size_bytes -= old_entry.size_bytes

            self._ensure_capacity(size_bytes)

            self.entries[key] = entry
            self.stats.total_entries += 1
            self.stats.total_size_bytes += size_bytes

    def delete(self, key: str) -> bool:
        """Remove ``key``; return True if it was present."""
        with self._lock:
            entry = self.entries.pop(key, None)
            if entry is None:
                return False
            self.stats.total_entries -= 1
            self.stats.total_size_bytes -= entry.size_bytes
            self.stats.evictions += 1
            return True

    def clear(self, tags: Optional[List[str]] = None):
        """Remove every entry, or only entries carrying any of ``tags``."""
        with self._lock:
            if tags is None:
                self.stats.evictions += len(self.entries)
                self.entries.clear()
                self.stats.total_entries = 0
                self.stats.total_size_bytes = 0
                return
            matching = [
                key for key, entry in self.entries.items()
                if any(tag in entry.tags for tag in tags)
            ]
            for key in matching:
                self.delete(key)

    def get_stats(self) -> CacheStats:
        """Return a snapshot copy of the statistics."""
        with self._lock:
            return CacheStats(
                hits=self.stats.hits,
                misses=self.stats.misses,
                evictions=self.stats.evictions,
                total_entries=self.stats.total_entries,
                total_size_bytes=self.stats.total_size_bytes,
                hit_rate=self.stats.hit_rate
            )

    def cleanup_expired(self):
        """Drop every entry whose TTL has elapsed."""
        with self._lock:
            expired = [
                key for key, entry in self.entries.items() if entry.is_expired()
            ]
            for key in expired:
                self.delete(key)

    def _ensure_capacity(self, new_size_bytes: int):
        """Evict entries until a new entry of ``new_size_bytes`` fits.

        BUG FIX: evict by the actual byte overage (current usage plus the
        new entry, minus the limit) rather than just the new entry's size,
        so the cache cannot creep past ``max_size_bytes``.
        """
        overage = self.stats.total_size_bytes + new_size_bytes - self.max_size_bytes
        if overage > 0:
            self._evict_by_size(overage)
        while len(self.entries) >= self.max_size:
            self._evict_by_count()

    def _evict_by_size(self, needed_bytes: int):
        """Evict least-recently-used entries until ``needed_bytes`` are freed."""
        lru_order = sorted(
            self.entries.items(),
            key=lambda item: (item[1].accessed_at, item[1].access_count)
        )
        freed_bytes = 0
        for key, entry in lru_order:
            if freed_bytes >= needed_bytes:
                break
            freed_bytes += entry.size_bytes
            del self.entries[key]
            self.stats.total_entries -= 1
            self.stats.total_size_bytes -= entry.size_bytes
            self.stats.evictions += 1

    def _evict_by_count(self):
        """Evict the single least-recently-used entry."""
        if not self.entries:
            return
        lru_key = min(
            self.entries.keys(),
            key=lambda k: (self.entries[k].accessed_at, self.entries[k].access_count)
        )
        self.delete(lru_key)
class FileCache(CacheBackend):
    """File-based cache backend; each entry is a pickle file on disk."""

    def __init__(self, cache_dir: str, max_files: int = 10000):
        """Initialize file cache.

        Args:
            cache_dir: Directory to store cache files
            max_files: Maximum number of cache files
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.max_files = max_files
        self.stats = CacheStats()
        self._lock = threading.RLock()
        # Restore counters persisted by a previous run, if present.
        self._load_stats()

    def get(self, key: str) -> Optional[CacheEntry]:
        """Return the entry for ``key`` (None on miss, expiry or corruption)."""
        with self._lock:
            cache_file = self.cache_dir / f"{self._safe_key(key)}.cache"
            if not cache_file.exists():
                self.stats.misses += 1
                self.stats.update_hit_rate()
                return None
            try:
                with open(cache_file, 'rb') as f:
                    entry = pickle.load(f)
                if entry.is_expired():
                    # BUG FIX: account for the removed file's size as well;
                    # the original leaked it from the size total.
                    size = cache_file.stat().st_size
                    cache_file.unlink()
                    self.stats.misses += 1
                    self.stats.evictions += 1
                    self.stats.total_entries -= 1
                    self.stats.total_size_bytes -= size
                    self.stats.update_hit_rate()
                    return None
                entry.touch()
                # Refresh mtime so LRU file eviction sees this entry as recent.
                cache_file.touch()
                self.stats.hits += 1
                self.stats.update_hit_rate()
                return entry
            except Exception as e:
                logger.error(f"Error reading cache file {cache_file}: {e}")
                cache_file.unlink(missing_ok=True)
                self.stats.misses += 1
                self.stats.update_hit_rate()
                return None

    def set(self, key: str, value: Any, ttl_seconds: Optional[int] = None, tags: List[str] = None):
        """Store ``value`` under ``key`` as a pickle file."""
        with self._lock:
            cache_file = self.cache_dir / f"{self._safe_key(key)}.cache"

            now = datetime.now()
            entry = CacheEntry(
                key=key,
                value=value,
                created_at=now,
                accessed_at=now,
                access_count=1,
                ttl_seconds=ttl_seconds,
                tags=tags or []
            )

            try:
                self._ensure_capacity()

                # BUG FIX: record whether the file exists *before* writing.
                # The original checked afterwards, so new entries were never
                # counted and overwrites inflated the size total.
                existed = cache_file.exists()
                old_size = cache_file.stat().st_size if existed else 0

                with open(cache_file, 'wb') as f:
                    pickle.dump(entry, f)

                entry.size_bytes = cache_file.stat().st_size
                if not existed:
                    self.stats.total_entries += 1
                self.stats.total_size_bytes += entry.size_bytes - old_size
                self._save_stats()
            except Exception as e:
                logger.error(f"Error writing cache file {cache_file}: {e}")

    def delete(self, key: str) -> bool:
        """Remove the file for ``key``; return True if it existed."""
        with self._lock:
            cache_file = self.cache_dir / f"{self._safe_key(key)}.cache"
            if not cache_file.exists():
                return False
            size = cache_file.stat().st_size
            cache_file.unlink()
            self.stats.total_entries -= 1
            self.stats.total_size_bytes -= size
            self.stats.evictions += 1
            self._save_stats()
            return True

    def clear(self, tags: Optional[List[str]] = None):
        """Remove every cache file, or only those carrying any of ``tags``."""
        with self._lock:
            if tags is None:
                for cache_file in self.cache_dir.glob("*.cache"):
                    cache_file.unlink()
                self.stats.evictions += self.stats.total_entries
                self.stats.total_entries = 0
                self.stats.total_size_bytes = 0
            else:
                for cache_file in self.cache_dir.glob("*.cache"):
                    try:
                        with open(cache_file, 'rb') as f:
                            entry = pickle.load(f)
                        if any(tag in entry.tags for tag in tags):
                            size = cache_file.stat().st_size
                            cache_file.unlink()
                            self.stats.total_entries -= 1
                            self.stats.total_size_bytes -= size
                            self.stats.evictions += 1
                    except Exception:
                        # Unreadable file: skip; cleanup_expired handles removal.
                        continue
            self._save_stats()

    def get_stats(self) -> CacheStats:
        """Return a snapshot copy of the statistics."""
        with self._lock:
            return CacheStats(
                hits=self.stats.hits,
                misses=self.stats.misses,
                evictions=self.stats.evictions,
                total_entries=self.stats.total_entries,
                total_size_bytes=self.stats.total_size_bytes,
                hit_rate=self.stats.hit_rate
            )

    def cleanup_expired(self):
        """Drop every expired cache file; corrupted files are removed too."""
        with self._lock:
            for cache_file in self.cache_dir.glob("*.cache"):
                try:
                    with open(cache_file, 'rb') as f:
                        entry = pickle.load(f)
                    if entry.is_expired():
                        size = cache_file.stat().st_size
                        cache_file.unlink()
                        self.stats.total_entries -= 1
                        self.stats.total_size_bytes -= size
                        self.stats.evictions += 1
                except Exception:
                    # Remove corrupted files. NOTE(review): stats are not
                    # adjusted here because the file may never have been
                    # counted — confirm whether counters should shrink too.
                    cache_file.unlink()
            self._save_stats()

    def _safe_key(self, key: str) -> str:
        """Convert an arbitrary key into a safe, fixed-length filename."""
        import hashlib
        return hashlib.md5(key.encode()).hexdigest()

    def _ensure_capacity(self):
        """Evict the oldest file when the file-count limit is reached."""
        cache_files = list(self.cache_dir.glob("*.cache"))
        if len(cache_files) >= self.max_files:
            oldest_file = min(cache_files, key=lambda f: f.stat().st_mtime)
            size = oldest_file.stat().st_size
            oldest_file.unlink()
            self.stats.total_entries -= 1
            self.stats.total_size_bytes -= size
            self.stats.evictions += 1

    def _load_stats(self):
        """Load persisted statistics; best-effort, ignores errors."""
        stats_file = self.cache_dir / "stats.json"
        if stats_file.exists():
            try:
                with open(stats_file, 'r') as f:
                    data = json.load(f)
                self.stats = CacheStats(**data)
            except Exception:
                pass

    def _save_stats(self):
        """Persist statistics to disk; best-effort, ignores errors."""
        stats_file = self.cache_dir / "stats.json"
        try:
            with open(stats_file, 'w') as f:
                json.dump(asdict(self.stats), f, default=str)
        except Exception:
            pass
class CacheManager:
    """Coordinates one or more cache backends and a periodic expiry sweep."""

    def __init__(self, config: Dict[str, Any]):
        """Initialize cache manager.

        Args:
            config: Cache configuration (``backends``, ``default_backend``,
                ``default_ttl_seconds``, ``cleanup_interval_seconds``)
        """
        self.config = config
        self.backends: Dict[str, CacheBackend] = {}
        self.default_backend = config.get('default_backend', 'memory')
        # TTL applied to entries that do not specify one (1 hour).
        self.default_ttl = config.get('default_ttl_seconds', 3600)

        self._init_backends()

        # Background expiry sweep, every 5 minutes by default.
        self.cleanup_interval = config.get('cleanup_interval_seconds', 300)
        self._start_cleanup_task()

    def _init_backends(self):
        """Instantiate each configured (or defaulted) cache backend."""
        backends_config = self.config.get('backends', {})

        if 'memory' in backends_config or self.default_backend == 'memory':
            opts = backends_config.get('memory', {})
            self.backends['memory'] = InMemoryCache(
                max_size=opts.get('max_size', 1000),
                max_size_bytes=opts.get('max_size_bytes', 100 * 1024 * 1024)
            )

        if 'file' in backends_config or self.default_backend == 'file':
            opts = backends_config.get('file', {})
            self.backends['file'] = FileCache(
                cache_dir=opts.get('cache_dir', './cache'),
                max_files=opts.get('max_files', 10000)
            )

    def get(self, key: str, backend: Optional[str] = None) -> Optional[Any]:
        """Fetch a cached value.

        Args:
            key: Cache key
            backend: Backend name (defaults to the configured default)

        Returns:
            The cached value, or None on a miss or unknown backend
        """
        backend_name = backend or self.default_backend
        cache_backend = self.backends.get(backend_name)
        if cache_backend is None:
            logger.warning(f"Cache backend '{backend_name}' not found")
            return None
        entry = cache_backend.get(key)
        if entry is None:
            return None
        return entry.value

    def set(self,
            key: str,
            value: Any,
            ttl_seconds: Optional[int] = None,
            tags: Optional[List[str]] = None,
            backend: Optional[str] = None):
        """Store a value in the cache.

        Args:
            key: Cache key
            value: Value to cache
            ttl_seconds: Time to live (falls back to the manager default)
            tags: Cache tags
            backend: Backend name (defaults to the configured default)
        """
        backend_name = backend or self.default_backend
        cache_backend = self.backends.get(backend_name)
        if cache_backend is None:
            logger.warning(f"Cache backend '{backend_name}' not found")
            return
        effective_ttl = self.default_ttl if ttl_seconds is None else ttl_seconds
        cache_backend.set(key, value, effective_ttl, tags)

    def delete(self, key: str, backend: Optional[str] = None) -> bool:
        """Delete a cached value.

        Args:
            key: Cache key
            backend: Backend name (defaults to the configured default)

        Returns:
            True if an entry was removed
        """
        cache_backend = self.backends.get(backend or self.default_backend)
        if cache_backend is None:
            return False
        return cache_backend.delete(key)

    def clear(self, tags: Optional[List[str]] = None, backend: Optional[str] = None):
        """Clear entries in one backend, or in every backend when none given.

        Args:
            tags: Tags to clear (optional)
            backend: Backend name (optional)
        """
        if backend:
            cache_backend = self.backends.get(backend)
            if cache_backend:
                cache_backend.clear(tags)
            return
        for cache_backend in self.backends.values():
            cache_backend.clear(tags)

    def get_stats(self) -> Dict[str, CacheStats]:
        """Return per-backend statistics snapshots keyed by backend name."""
        stats = {}
        for name, cache_backend in self.backends.items():
            stats[name] = cache_backend.get_stats()
        return stats

    def cleanup_expired(self):
        """Purge expired entries in every backend, logging any failures."""
        for cache_backend in self.backends.values():
            try:
                cache_backend.cleanup_expired()
            except Exception as e:
                logger.error(f"Error cleaning up cache backend: {e}")

    def _start_cleanup_task(self):
        """Spawn the daemon thread that periodically purges expired entries."""
        def cleanup_worker():
            while True:
                try:
                    time.sleep(self.cleanup_interval)
                    self.cleanup_expired()
                except Exception as e:
                    logger.error(f"Cache cleanup error: {e}")

        threading.Thread(target=cleanup_worker, daemon=True).start()
# Cache decorators and utilities
def cache_result(cache_manager: CacheManager,
                 key_func: Optional[callable] = None,
                 ttl_seconds: Optional[int] = None,
                 tags: Optional[List[str]] = None,
                 backend: Optional[str] = None):
    """Decorator that caches a function's results through ``cache_manager``.

    Args:
        cache_manager: Cache manager instance
        key_func: Optional callable that builds the cache key from the call
            arguments; the default hashes the positional and keyword args,
            so all arguments must be hashable
        ttl_seconds: Time to live for cached results
        tags: Cache tags attached to stored results
        backend: Backend name

    Note:
        A result equal to None is indistinguishable from a cache miss, so
        calls returning None are re-executed every time.
    """
    from functools import wraps  # local import keeps module-level deps unchanged

    def decorator(func):
        # BUG FIX: preserve the wrapped function's metadata (__name__,
        # __doc__, __wrapped__, ...), which the original wrapper discarded.
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Generate cache key
            if key_func:
                cache_key = key_func(*args, **kwargs)
            else:
                cache_key = f"{func.__name__}:{hash((args, tuple(sorted(kwargs.items()))))}"

            # Try to get from cache
            cached_result = cache_manager.get(cache_key, backend)
            if cached_result is not None:
                return cached_result

            # Execute function and cache the result
            result = func(*args, **kwargs)
            cache_manager.set(cache_key, result, ttl_seconds, tags, backend)
            return result
        return wrapper
    return decorator

View file

@ -0,0 +1,610 @@
"""
Cypher executor for multiple graph databases.
Executes Cypher queries against Neo4j, Memgraph, and FalkorDB.
"""
import logging
import asyncio
from typing import Dict, Any, List, Optional, Union
from dataclasses import dataclass
from abc import ABC, abstractmethod
from .cypher_generator import CypherQuery
logger = logging.getLogger(__name__)
# Try to import various database drivers
try:
from neo4j import GraphDatabase, Driver as Neo4jDriver
NEO4J_AVAILABLE = True
except ImportError:
NEO4J_AVAILABLE = False
Neo4jDriver = None
try:
import redis
REDIS_AVAILABLE = True
except ImportError:
REDIS_AVAILABLE = False
@dataclass
class CypherResult:
    """Outcome of executing a Cypher query against one database."""
    records: List[Dict[str, Any]]                # result rows as plain dicts
    summary: Dict[str, Any]                      # counters and/or error info
    execution_time: float                        # wall-clock seconds for the call
    database_type: str                           # e.g. 'neo4j', 'memgraph', 'falkordb'
    query_plan: Optional[Dict[str, Any]] = None  # plan details, when available
class CypherExecutorBase(ABC):
    """Interface for database-specific Cypher executors."""

    @abstractmethod
    async def connect(self):
        """Establish the database connection.

        Declared abstract because ``CypherExecutor.execute_cypher`` calls
        ``connect()`` on every executor; all concrete executors in this
        module already implement it, so this only formalises the existing
        contract.
        """
        ...

    @abstractmethod
    async def execute(self, cypher_query: CypherQuery) -> CypherResult:
        """Execute a Cypher query and return its results."""
        ...

    @abstractmethod
    async def close(self):
        """Close the database connection."""
        ...

    @abstractmethod
    def is_connected(self) -> bool:
        """Report whether a connection is currently held."""
        ...
class Neo4jExecutor(CypherExecutorBase):
    """Cypher executor for Neo4j database."""

    def __init__(self, config: Dict[str, Any]):
        """Initialize Neo4j executor.

        Args:
            config: Neo4j configuration (uri, username, password, timeouts)

        Raises:
            RuntimeError: If the neo4j driver is not installed
        """
        if not NEO4J_AVAILABLE:
            raise RuntimeError("Neo4j driver not available")
        self.config = config
        self.driver: Optional[Neo4jDriver] = None
        self._connection_pool_size = config.get('connection_pool_size', 10)

    async def connect(self):
        """Connect to Neo4j and verify connectivity.

        Raises:
            Exception: Propagates any driver/connection error after logging
        """
        try:
            uri = self.config.get('uri', 'bolt://localhost:7687')
            username = self.config.get('username')
            password = self.config.get('password')
            auth = (username, password) if username and password else None

            # Create driver with connection pool
            self.driver = GraphDatabase.driver(
                uri,
                auth=auth,
                max_connection_pool_size=self._connection_pool_size,
                connection_timeout=self.config.get('connection_timeout', 30),
                max_retry_time=self.config.get('max_retry_time', 15)
            )

            # The driver is synchronous, so verify connectivity off the
            # event loop. BUG FIX: asyncio.get_running_loop() replaces the
            # get_event_loop() call, which is deprecated inside coroutines
            # since Python 3.10.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self.driver.verify_connectivity)

            logger.info(f"Connected to Neo4j at {uri}")
        except Exception as e:
            logger.error(f"Failed to connect to Neo4j: {e}")
            raise

    async def execute(self, cypher_query: CypherQuery) -> CypherResult:
        """Execute Cypher query against Neo4j.

        Args:
            cypher_query: Cypher query to execute

        Returns:
            CypherResult; on failure ``records`` is empty and ``summary``
            carries the error message (no exception is raised)
        """
        if not self.driver:
            await self.connect()

        import time
        start_time = time.time()

        try:
            # Run the blocking session work in a thread so the event loop
            # is not stalled.
            loop = asyncio.get_running_loop()
            records = await loop.run_in_executor(
                None, self._execute_sync, cypher_query
            )
            execution_time = time.time() - start_time
            return CypherResult(
                records=records,
                summary={'record_count': len(records)},
                execution_time=execution_time,
                database_type='neo4j'
            )
        except Exception as e:
            logger.error(f"Neo4j query execution error: {e}")
            execution_time = time.time() - start_time
            return CypherResult(
                records=[],
                summary={'error': str(e)},
                execution_time=execution_time,
                database_type='neo4j'
            )

    def _execute_sync(self, cypher_query: CypherQuery) -> List[Dict[str, Any]]:
        """Run the query in a session (called from a worker thread).

        Args:
            cypher_query: Cypher query to execute

        Returns:
            List of record dictionaries with driver values converted to
            JSON-serializable forms
        """
        with self.driver.session() as session:
            result = session.run(cypher_query.query, cypher_query.parameters)
            records = []
            for record in result:
                record_dict = {}
                for key in record.keys():
                    record_dict[key] = self._format_neo4j_value(record[key])
                records.append(record_dict)
            return records

    def _format_neo4j_value(self, value):
        """Convert a Neo4j driver value into a JSON-serializable form.

        Uses duck typing on driver object attributes (labels/type/nodes)
        to recognise nodes, relationships and paths; anything else is
        returned unchanged.

        Args:
            value: Neo4j value

        Returns:
            JSON-serializable value
        """
        # Handle Neo4j node objects
        if hasattr(value, 'labels') and hasattr(value, 'items'):
            return {
                'labels': list(value.labels),
                'properties': dict(value.items())
            }
        # Handle Neo4j relationship objects
        elif hasattr(value, 'type') and hasattr(value, 'items'):
            return {
                'type': value.type,
                'properties': dict(value.items())
            }
        # Handle Neo4j path objects
        elif hasattr(value, 'nodes') and hasattr(value, 'relationships'):
            return {
                'nodes': [self._format_neo4j_value(n) for n in value.nodes],
                'relationships': [self._format_neo4j_value(r) for r in value.relationships]
            }
        else:
            return value

    async def close(self):
        """Close the Neo4j connection (no-op when not connected)."""
        if self.driver:
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self.driver.close)
            self.driver = None
            logger.info("Neo4j connection closed")

    def is_connected(self) -> bool:
        """Report whether a driver instance is currently held."""
        return self.driver is not None
class MemgraphExecutor(CypherExecutorBase):
    """Cypher executor for Memgraph (Bolt protocol via the Neo4j driver)."""

    def __init__(self, config: Dict[str, Any]):
        """Initialize Memgraph executor.

        Args:
            config: Memgraph configuration

        Raises:
            RuntimeError: If the neo4j driver is not installed
        """
        if not NEO4J_AVAILABLE:  # Memgraph uses Neo4j driver
            raise RuntimeError("Neo4j driver required for Memgraph")
        self.config = config
        self.driver: Optional[Neo4jDriver] = None

    async def connect(self):
        """Connect to Memgraph and verify connectivity.

        Raises:
            Exception: Propagates any driver/connection error after logging
        """
        try:
            uri = self.config.get('uri', 'bolt://localhost:7688')
            username = self.config.get('username')
            password = self.config.get('password')
            auth = (username, password) if username and password else None

            # Memgraph uses the Neo4j driver but with smaller pool defaults.
            self.driver = GraphDatabase.driver(
                uri,
                auth=auth,
                max_connection_pool_size=self.config.get('connection_pool_size', 5),
                connection_timeout=self.config.get('connection_timeout', 10)
            )

            # BUG FIX: asyncio.get_running_loop() replaces get_event_loop(),
            # which is deprecated inside coroutines since Python 3.10.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self.driver.verify_connectivity)

            logger.info(f"Connected to Memgraph at {uri}")
        except Exception as e:
            logger.error(f"Failed to connect to Memgraph: {e}")
            raise

    async def execute(self, cypher_query: CypherQuery) -> CypherResult:
        """Execute Cypher query against Memgraph.

        Args:
            cypher_query: Cypher query to execute

        Returns:
            CypherResult; on failure ``records`` is empty and ``summary``
            carries the error message (no exception is raised)
        """
        if not self.driver:
            await self.connect()

        import time
        start_time = time.time()

        try:
            loop = asyncio.get_running_loop()
            records = await loop.run_in_executor(
                None, self._execute_memgraph_sync, cypher_query
            )
            execution_time = time.time() - start_time
            return CypherResult(
                records=records,
                summary={
                    'record_count': len(records),
                    'engine': 'memgraph'
                },
                execution_time=execution_time,
                database_type='memgraph'
            )
        except Exception as e:
            logger.error(f"Memgraph query execution error: {e}")
            execution_time = time.time() - start_time
            return CypherResult(
                records=[],
                summary={'error': str(e)},
                execution_time=execution_time,
                database_type='memgraph'
            )

    def _execute_memgraph_sync(self, cypher_query: CypherQuery) -> List[Dict[str, Any]]:
        """Run the query in a session (called from a worker thread).

        Args:
            cypher_query: Cypher query to execute

        Returns:
            List of record dictionaries
        """
        with self.driver.session() as session:
            query = cypher_query.query
            # Surface an optional memory-limit hint as a leading comment.
            if cypher_query.database_hints and cypher_query.database_hints.get('memory_limit'):
                query = f"// Memory limit: {cypher_query.database_hints['memory_limit']}\n{query}"

            result = session.run(query, cypher_query.parameters)
            return [
                {key: record[key] for key in record.keys()}
                for record in result
            ]

    async def close(self):
        """Close the Memgraph connection (no-op when not connected)."""
        if self.driver:
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self.driver.close)
            self.driver = None
            logger.info("Memgraph connection closed")

    def is_connected(self) -> bool:
        """Report whether a driver instance is currently held."""
        return self.driver is not None
class FalkorDBExecutor(CypherExecutorBase):
    """Cypher executor for FalkorDB (Redis-based graph database)."""

    def __init__(self, config: Dict[str, Any]):
        """Initialize FalkorDB executor.

        Args:
            config: FalkorDB configuration

        Raises:
            RuntimeError: If the redis driver is not installed
        """
        if not REDIS_AVAILABLE:
            raise RuntimeError("Redis driver required for FalkorDB")
        self.config = config
        self.redis_client: Optional[redis.Redis] = None
        self.graph_name = config.get('graph_name', 'knowledge_graph')

    async def connect(self):
        """Connect to FalkorDB (Redis) and verify with a ping.

        Raises:
            Exception: Propagates any connection error after logging
        """
        try:
            self.redis_client = redis.Redis(
                host=self.config.get('host', 'localhost'),
                port=self.config.get('port', 6379),
                password=self.config.get('password'),
                db=self.config.get('db', 0),
                decode_responses=True,
                socket_connect_timeout=self.config.get('connection_timeout', 10),
                socket_timeout=self.config.get('socket_timeout', 10)
            )
            # BUG FIX: asyncio.get_running_loop() replaces get_event_loop(),
            # which is deprecated inside coroutines since Python 3.10.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self.redis_client.ping)
            logger.info(f"Connected to FalkorDB at {self.config.get('host', 'localhost')}")
        except Exception as e:
            logger.error(f"Failed to connect to FalkorDB: {e}")
            raise

    async def execute(self, cypher_query: CypherQuery) -> CypherResult:
        """Execute Cypher query against FalkorDB.

        Args:
            cypher_query: Cypher query to execute

        Returns:
            CypherResult; on failure ``records`` is empty and ``summary``
            carries the error message (no exception is raised)
        """
        if not self.redis_client:
            await self.connect()

        import time
        start_time = time.time()

        try:
            loop = asyncio.get_running_loop()
            records = await loop.run_in_executor(
                None, self._execute_falkordb_sync, cypher_query
            )
            execution_time = time.time() - start_time
            return CypherResult(
                records=records,
                summary={
                    'record_count': len(records),
                    'engine': 'falkordb'
                },
                execution_time=execution_time,
                database_type='falkordb'
            )
        except Exception as e:
            logger.error(f"FalkorDB query execution error: {e}")
            execution_time = time.time() - start_time
            return CypherResult(
                records=[],
                summary={'error': str(e)},
                execution_time=execution_time,
                database_type='falkordb'
            )

    def _execute_falkordb_sync(self, cypher_query: CypherQuery) -> List[Dict[str, Any]]:
        """Run the query via GRAPH.QUERY (called from a worker thread).

        Parameters are substituted textually into the query string.

        Args:
            cypher_query: Cypher query to execute

        Returns:
            List of record dictionaries
        """
        query = cypher_query.query
        # BUG FIX: substitute longer parameter names first so a name that
        # is a prefix of another (e.g. $id vs $id_list) cannot corrupt the
        # query, and escape backslashes/quotes in string values so the
        # inlined literal stays well-formed.
        for param in sorted(cypher_query.parameters, key=len, reverse=True):
            value = cypher_query.parameters[param]
            if isinstance(value, str):
                escaped = value.replace('\\', '\\\\').replace('"', '\\"')
                replacement = f'"{escaped}"'
            else:
                replacement = str(value)
            query = query.replace(f'${param}', replacement)

        # Execute using FalkorDB GRAPH.QUERY command
        result = self.redis_client.execute_command(
            'GRAPH.QUERY', self.graph_name, query
        )

        # FalkorDB replies with [header, data rows, statistics].
        records = []
        if result and len(result) > 1:
            headers = result[0] if result[0] else []
            data_rows = result[1] if len(result) > 1 else []
            for row in data_rows:
                record = {}
                for i, header in enumerate(headers):
                    if i < len(row):
                        record[header] = self._format_falkordb_value(row[i])
                records.append(record)
        return records

    def _format_falkordb_value(self, value):
        """Convert a raw FalkorDB value into a JSON-serializable form.

        Args:
            value: FalkorDB value

        Returns:
            JSON-serializable value
        """
        # NOTE(review): the 3-element-list node/relationship encoding below
        # mirrors the original heuristic — confirm it against the FalkorDB
        # protocol for the deployed server version.
        if isinstance(value, list) and len(value) == 3:
            if value[0] == 1:  # Node
                return {
                    'type': 'node',
                    'labels': value[1],
                    'properties': value[2]
                }
            elif value[0] == 2:  # Relationship
                return {
                    'type': 'relationship',
                    'rel_type': value[1],
                    'properties': value[2]
                }
        return value

    async def close(self):
        """Close the FalkorDB connection (no-op when not connected)."""
        if self.redis_client:
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self.redis_client.close)
            self.redis_client = None
            logger.info("FalkorDB connection closed")

    def is_connected(self) -> bool:
        """Report whether a Redis client is currently held."""
        return self.redis_client is not None
class CypherExecutor:
"""Multi-database Cypher executor with automatic routing."""
def __init__(self, config: Dict[str, Any]):
"""Initialize multi-database executor.
Args:
config: Configuration for all database types
"""
self.config = config
self.executors: Dict[str, CypherExecutorBase] = {}
# Initialize available executors
self._initialize_executors()
def _initialize_executors(self):
"""Initialize database executors based on configuration."""
# Neo4j executor
if 'neo4j' in self.config and NEO4J_AVAILABLE:
try:
self.executors['neo4j'] = Neo4jExecutor(self.config['neo4j'])
logger.info("Neo4j executor initialized")
except Exception as e:
logger.error(f"Failed to initialize Neo4j executor: {e}")
# Memgraph executor
if 'memgraph' in self.config and NEO4J_AVAILABLE:
try:
self.executors['memgraph'] = MemgraphExecutor(self.config['memgraph'])
logger.info("Memgraph executor initialized")
except Exception as e:
logger.error(f"Failed to initialize Memgraph executor: {e}")
# FalkorDB executor
if 'falkordb' in self.config and REDIS_AVAILABLE:
try:
self.executors['falkordb'] = FalkorDBExecutor(self.config['falkordb'])
logger.info("FalkorDB executor initialized")
except Exception as e:
logger.error(f"Failed to initialize FalkorDB executor: {e}")
if not self.executors:
raise RuntimeError("No database executors could be initialized")
async def execute_cypher(self, cypher_query: CypherQuery,
database_type: str) -> CypherResult:
"""Execute Cypher query on specified database.
Args:
cypher_query: Cypher query to execute
database_type: Target database type
Returns:
Query results
"""
if database_type not in self.executors:
raise ValueError(f"Database type {database_type} not available. "
f"Available: {list(self.executors.keys())}")
executor = self.executors[database_type]
# Ensure connection
if not executor.is_connected():
await executor.connect()
# Execute query
return await executor.execute(cypher_query)
async def execute_on_all(self, cypher_query: CypherQuery) -> Dict[str, CypherResult]:
"""Execute query on all available databases.
Args:
cypher_query: Cypher query to execute
Returns:
Results from all databases
"""
results = {}
tasks = []
for db_type, executor in self.executors.items():
task = asyncio.create_task(
self.execute_cypher(cypher_query, db_type),
name=f"cypher_query_{db_type}"
)
tasks.append((db_type, task))
# Wait for all tasks to complete
for db_type, task in tasks:
try:
results[db_type] = await task
except Exception as e:
logger.error(f"Query failed on {db_type}: {e}")
results[db_type] = CypherResult(
records=[],
summary={'error': str(e)},
execution_time=0.0,
database_type=db_type
)
return results
def get_available_databases(self) -> List[str]:
    """Return the database types that have an initialized executor.

    Returns:
        Names of every registered database backend.
    """
    return [db_name for db_name in self.executors]
async def close_all(self):
    """Close every registered executor's database connection."""
    for backend in self.executors.values():
        await backend.close()
    logger.info("All Cypher executor connections closed")

View file

@ -0,0 +1,628 @@
"""
Cypher query generator for ontology-sensitive queries.
Converts natural language questions to Cypher queries for graph databases.
"""
import logging
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
from .question_analyzer import QuestionComponents, QuestionType
from .ontology_matcher import QueryOntologySubset
logger = logging.getLogger(__name__)
@dataclass
class CypherQuery:
    """Generated Cypher query plus metadata describing it.

    Bundles the query text, its bound parameters, the variables it
    returns, a human-readable explanation, a heuristic complexity score,
    and optional backend-specific optimization hints.
    """
    query: str                  # Cypher query text
    parameters: Dict[str, Any]  # bound query parameters ($name -> value)
    variables: List[str]        # variables the query returns
    explanation: str            # human-readable description of the query
    complexity_score: float     # heuristic complexity estimate (0.0-1.0)
    # Optional[...] rather than a bare Dict annotation: the default really
    # is None, and consumers must handle the missing-hints case.
    database_hints: Optional[Dict[str, Any]] = None  # database-specific optimization hints
class CypherGenerator:
"""Generates Cypher queries from natural language questions using LLM assistance."""
def __init__(self, prompt_service=None):
    """Initialize Cypher generator.

    Args:
        prompt_service: Service for LLM-based query generation
    """
    self.prompt_service = prompt_service
    # Cypher query templates for common patterns.  Curly-brace fields are
    # filled in with str.format(); $-prefixed names stay as parameters.
    self.templates = {
        # All nodes of one label with a chosen property.
        'simple_node_query': """
MATCH (n:{node_label})
RETURN n.name AS name, n.{property} AS {property}
LIMIT {limit}""",
        # Neighbors of a named node via one relationship type.
        'relationship_query': """
MATCH (a:{source_label})-[r:{relationship}]->(b:{target_label})
WHERE a.name = $source_name
RETURN b.name AS name, r.{rel_property} AS property""",
        # Paths between two labels, bounded by max_depth.
        'path_query': """
MATCH path = (start:{start_label})-[*1..{max_depth}]->(end:{end_label})
WHERE start.name = $start_name
RETURN path, length(path) AS path_length
ORDER BY path_length""",
        # Node count with an optional pre-built WHERE clause.
        'count_query': """
MATCH (n:{node_label})
{where_clause}
RETURN count(n) AS count""",
        # Count/avg/sum aggregates over one numeric property.
        'aggregation_query': """
MATCH (n:{node_label})
{where_clause}
RETURN
count(n) AS count,
avg(n.{numeric_property}) AS average,
sum(n.{numeric_property}) AS total""",
        # Existence check for a relationship between two named nodes.
        'boolean_query': """
MATCH (a:{source_label})-[:{relationship}]->(b:{target_label})
WHERE a.name = $source_name AND b.name = $target_name
RETURN count(*) > 0 AS exists""",
        # Transitive subclass expansion below a named parent.
        'hierarchy_query': """
MATCH (child:{child_label})-[:SUBCLASS_OF*]->(parent:{parent_label})
WHERE parent.name = $parent_name
RETURN child.name AS child_name, parent.name AS parent_name""",
        # Nodes filtered and ordered by one property comparison.
        'property_filter_query': """
MATCH (n:{node_label})
WHERE n.{property} {operator} ${property}_value
RETURN n.name AS name, n.{property} AS {property}
ORDER BY n.{property}"""
    }
async def generate_cypher(self,
                          question_components: QuestionComponents,
                          ontology_subset: QueryOntologySubset,
                          database_type: str = "neo4j") -> CypherQuery:
    """Generate a Cypher query for a question.

    Generation is attempted in decreasing order of confidence:
    template matching, LLM generation (when a prompt service exists),
    then a generic keyword-match fallback.

    Args:
        question_components: Analyzed question components
        ontology_subset: Relevant ontology subset
        database_type: Target database (neo4j, memgraph, falkordb)

    Returns:
        Generated Cypher query
    """
    # 1) Cheap path: a hand-written template matching the question shape.
    from_template = self._try_template_generation(
        question_components, ontology_subset, database_type
    )
    if from_template is not None:
        logger.debug("Generated Cypher using template")
        return from_template

    # 2) LLM path, only when a prompt service was supplied.
    if self.prompt_service:
        from_llm = await self._generate_with_llm(
            question_components, ontology_subset, database_type
        )
        if from_llm is not None:
            logger.debug("Generated Cypher using LLM")
            return from_llm

    # 3) Last resort: generic pattern-match query.
    logger.warning("Falling back to simple Cypher pattern")
    return self._generate_fallback_query(question_components, ontology_subset)
def _try_template_generation(self,
                             question_components: QuestionComponents,
                             ontology_subset: QueryOntologySubset,
                             database_type: str) -> Optional[CypherQuery]:
    """Try to generate query using templates.

    Branches are checked in a fixed order; the first matching pattern
    wins, so retrieval with one entity is tried before relationship
    retrieval with two or more entities.

    Args:
        question_components: Question analysis
        ontology_subset: Ontology subset
        database_type: Target database type

    Returns:
        Generated query or None if no template matches
    """
    # Simple node query (What are the animals?)
    if (question_components.question_type == QuestionType.RETRIEVAL and
            len(question_components.entities) == 1):
        node_label = self._find_matching_node_label(
            question_components.entities[0], ontology_subset
        )
        if node_label:
            query = self.templates['simple_node_query'].format(
                node_label=node_label,
                property='name',
                limit=100
            )
            return CypherQuery(
                query=query,
                parameters={},
                variables=['name'],
                explanation=f"Retrieve all nodes of type {node_label}",
                complexity_score=0.2,
                database_hints=self._get_database_hints(database_type, 'simple')
            )
    # Count query (How many animals are there?)
    if (question_components.question_type == QuestionType.AGGREGATION and
            'count' in question_components.aggregations):
        # Default to a generic 'Entity' label when no entity was detected.
        node_label = self._find_matching_node_label(
            question_components.entities[0] if question_components.entities else 'Entity',
            ontology_subset
        )
        if node_label:
            where_clause = self._build_where_clause(question_components)
            query = self.templates['count_query'].format(
                node_label=node_label,
                where_clause=where_clause
            )
            return CypherQuery(
                query=query,
                parameters=self._extract_parameters(question_components),
                variables=['count'],
                explanation=f"Count nodes of type {node_label}",
                complexity_score=0.3,
                database_hints=self._get_database_hints(database_type, 'aggregation')
            )
    # Relationship query (Which documents were authored by John Smith?)
    if (question_components.question_type == QuestionType.RETRIEVAL and
            len(question_components.entities) >= 2):
        # NOTE(review): entities[1] is used as the source and entities[0]
        # as the target -- confirm this matches the analyzer's ordering.
        source_label = self._find_matching_node_label(
            question_components.entities[1], ontology_subset
        )
        target_label = self._find_matching_node_label(
            question_components.entities[0], ontology_subset
        )
        relationship = self._find_matching_relationship(
            question_components, ontology_subset
        )
        if source_label and target_label and relationship:
            query = self.templates['relationship_query'].format(
                source_label=source_label,
                target_label=target_label,
                relationship=relationship,
                rel_property='name'
            )
            return CypherQuery(
                query=query,
                parameters={'source_name': question_components.entities[1]},
                variables=['name'],
                explanation=f"Find {target_label} related to {source_label} via {relationship}",
                complexity_score=0.4,
                database_hints=self._get_database_hints(database_type, 'relationship')
            )
    # Boolean query (Is X related to Y?)
    if question_components.question_type == QuestionType.BOOLEAN:
        if len(question_components.entities) >= 2:
            source_label = self._find_matching_node_label(
                question_components.entities[0], ontology_subset
            )
            target_label = self._find_matching_node_label(
                question_components.entities[1], ontology_subset
            )
            relationship = self._find_matching_relationship(
                question_components, ontology_subset
            )
            if source_label and target_label and relationship:
                query = self.templates['boolean_query'].format(
                    source_label=source_label,
                    target_label=target_label,
                    relationship=relationship
                )
                return CypherQuery(
                    query=query,
                    parameters={
                        'source_name': question_components.entities[0],
                        'target_name': question_components.entities[1]
                    },
                    variables=['exists'],
                    explanation="Boolean check for relationship existence",
                    complexity_score=0.3,
                    database_hints=self._get_database_hints(database_type, 'boolean')
                )
    # No template matched; the caller falls through to LLM generation.
    return None
async def _generate_with_llm(self,
                             question_components: QuestionComponents,
                             ontology_subset: QueryOntologySubset,
                             database_type: str) -> Optional[CypherQuery]:
    """Generate Cypher using LLM.

    Args:
        question_components: Question analysis
        ontology_subset: Ontology subset
        database_type: Target database type

    Returns:
        Generated query or None if failed
    """
    try:
        prompt = self._build_cypher_prompt(
            question_components, ontology_subset, database_type
        )
        response = await self.prompt_service.generate_cypher(prompt=prompt)
        if response and isinstance(response, dict):
            query = response.get('query', '').strip()
            # Accept only responses that start with a Cypher keyword;
            # this rejects prose or otherwise malformed LLM output.
            if query.upper().startswith(('MATCH', 'CREATE', 'MERGE', 'DELETE', 'RETURN')):
                return CypherQuery(
                    query=query,
                    parameters=response.get('parameters', {}),
                    variables=self._extract_variables(query),
                    explanation=response.get('explanation', 'Generated by LLM'),
                    complexity_score=self._calculate_complexity(query),
                    database_hints=self._get_database_hints(database_type, 'complex')
                )
    except Exception as e:
        # Non-fatal: the caller falls through to the simple fallback query.
        logger.error(f"LLM Cypher generation failed: {e}")
    return None
def _build_cypher_prompt(self,
                         question_components: QuestionComponents,
                         ontology_subset: QueryOntologySubset,
                         database_type: str) -> str:
    """Build prompt for LLM Cypher generation.

    Embeds the formatted ontology subset, question analysis hints, and
    per-database optimization guidance into one instruction block.

    Args:
        question_components: Question analysis
        ontology_subset: Ontology subset
        database_type: Target database type

    Returns:
        Formatted prompt string
    """
    # Format ontology elements as node labels and relationships
    node_labels = self._format_node_labels(ontology_subset.classes)
    relationships = self._format_relationships(
        ontology_subset.object_properties,
        ontology_subset.datatype_properties
    )
    prompt = f"""Generate a Cypher query for the following question using the provided ontology.
QUESTION: {question_components.original_question}
TARGET DATABASE: {database_type}
AVAILABLE NODE LABELS (from classes):
{node_labels}
AVAILABLE RELATIONSHIP TYPES (from properties):
{relationships}
RULES:
- Use MATCH patterns for graph traversal
- Include WHERE clauses for filters
- Use aggregation functions when needed (COUNT, SUM, AVG)
- Optimize for {database_type} performance
- Consider index hints for large datasets
- Use parameters for values (e.g., $name)
QUERY TYPE HINTS:
- Question type: {question_components.question_type.value}
- Expected answer: {question_components.expected_answer_type}
- Entities mentioned: {', '.join(question_components.entities)}
- Aggregations: {', '.join(question_components.aggregations)}
DATABASE-SPECIFIC OPTIMIZATIONS:
{self._get_database_specific_hints(database_type)}
Generate a complete Cypher query with parameters:"""
    return prompt
def _generate_fallback_query(self,
                             question_components: QuestionComponents,
                             ontology_subset: QueryOntologySubset) -> CypherQuery:
    """Generate simple fallback query.

    Used when neither template nor LLM generation produced a query.

    Args:
        question_components: Question analysis
        ontology_subset: Ontology subset

    Returns:
        Basic Cypher query
    """
    # Very basic MATCH query
    # Use the first available ontology class as the label, or a generic
    # 'Entity' label when the subset is empty.
    first_class = list(ontology_subset.classes.keys())[0] if ontology_subset.classes else 'Entity'
    query = f"""MATCH (n:{first_class})
WHERE n.name CONTAINS $keyword
RETURN n.name AS name, labels(n) AS types
LIMIT 10"""
    return CypherQuery(
        query=query,
        parameters={'keyword': question_components.keywords[0] if question_components.keywords else 'entity'},
        variables=['name', 'types'],
        explanation="Fallback query for basic pattern matching",
        complexity_score=0.1
    )
def _find_matching_node_label(self, entity: str, ontology_subset: QueryOntologySubset) -> Optional[str]:
    """Resolve an entity mention to a class id in the ontology subset.

    Matching is attempted in decreasing order of confidence: exact id
    match, exact label match, then substring overlap in either direction.

    Args:
        entity: Entity string to match
        ontology_subset: Ontology subset

    Returns:
        Matching node label or None
    """
    needle = entity.lower()

    # Pass 1: class id equals the entity (case-insensitive).
    for class_id in ontology_subset.classes:
        if class_id.lower() == needle:
            return class_id

    # Pass 2: one of the class's labels equals the entity.  Labels may be
    # structured dicts ({'value': ...}); anything else is skipped.
    for class_id, class_def in ontology_subset.classes.items():
        for label in class_def.get('labels', []):
            if isinstance(label, dict) and label.get('value', '').lower() == needle:
                return class_id

    # Pass 3: loosest match -- substring containment either way.
    for class_id in ontology_subset.classes:
        lowered = class_id.lower()
        if needle in lowered or lowered in needle:
            return class_id

    return None
def _find_matching_relationship(self,
                                question_components: QuestionComponents,
                                ontology_subset: QueryOntologySubset) -> Optional[str]:
    """Pick a relationship type for the question.

    Prefers an ontology object property whose id overlaps a question
    keyword; falls back to a table of common English relationship words,
    then to a generic RELATED_TO.

    Args:
        question_components: Question analysis
        ontology_subset: Ontology subset

    Returns:
        Matching relationship type or None
    """
    # Ontology-driven match: keyword overlaps an object-property id.
    for keyword in question_components.keywords:
        kw = keyword.lower()
        for prop_id in ontology_subset.object_properties:
            pid = prop_id.lower()
            if kw in pid or pid in kw:
                # Convert the property id into Cypher relationship style.
                return prop_id.upper().replace('-', '_')

    # Heuristic match: common English relationship words.
    relationship_mappings = {
        'author': 'AUTHORED_BY',
        'created': 'CREATED_BY',
        'owns': 'OWNS',
        'has': 'HAS',
        'contains': 'CONTAINS',
        'parent': 'PARENT_OF',
        'child': 'CHILD_OF',
        'related': 'RELATED_TO'
    }
    for keyword in question_components.keywords:
        mapped = relationship_mappings.get(keyword.lower())
        if mapped is not None:
            return mapped

    # Nothing matched -- use a generic relationship.
    return 'RELATED_TO'
def _build_where_clause(self, question_components: QuestionComponents) -> str:
    """Build a WHERE clause from the question's numeric constraints.

    Recognizes "greater than N" / "less than N" phrases and turns them
    into comparisons against n.value.

    Args:
        question_components: Question analysis

    Returns:
        WHERE clause string, or "" when no constraint applies.
    """
    # Import at function scope (matching the file's style).  Previously
    # this import lived inside the 'greater than' branch, so a question
    # containing only a 'less than' constraint raised NameError on
    # re.findall.
    import re

    conditions = []
    for constraint in question_components.constraints:
        text = constraint.lower()
        numbers = re.findall(r'\d+', constraint)
        if not numbers:
            continue
        if 'greater than' in text:
            conditions.append(f"n.value > {numbers[0]}")
        elif 'less than' in text:
            conditions.append(f"n.value < {numbers[0]}")
    if conditions:
        return f"WHERE {' AND '.join(conditions)}"
    return ""
def _extract_parameters(self, question_components: QuestionComponents) -> Dict[str, Any]:
    """Extract numeric parameters mentioned in the question's constraints.

    Args:
        question_components: Question analysis

    Returns:
        Parameters dictionary mapping value_<i> to each number found,
        numbered across all constraints.
    """
    import re

    parameters: Dict[str, Any] = {}
    # Number the values across ALL constraints.  The previous
    # per-constraint enumerate() restarted at 0 for every constraint and
    # silently overwrote values extracted from earlier constraints.
    index = 0
    for constraint in question_components.constraints:
        for number in re.findall(r'\d+', constraint):
            parameters[f'value_{index}'] = int(number)
            index += 1
    return parameters
def _format_node_labels(self, classes: Dict[str, Any]) -> str:
    """Render ontology classes as a node-label list for the LLM prompt.

    Args:
        classes: Classes dictionary

    Returns:
        Formatted node labels string ("None" when empty).
    """
    if not classes:
        return "None"
    # One "- :Label - comment" bullet per class.
    return '\n'.join(
        f"- :{class_id} - {definition.get('comment', '')}"
        for class_id, definition in classes.items()
    )
def _format_relationships(self,
                          object_props: Dict[str, Any],
                          datatype_props: Dict[str, Any]) -> str:
    """Render object properties as relationship types for the LLM prompt.

    NOTE(review): datatype_props is accepted but currently unused; only
    object properties appear in the output -- confirm whether datatype
    properties should be listed as well.

    Args:
        object_props: Object properties
        datatype_props: Datatype properties

    Returns:
        Formatted relationships string ("None" when empty).
    """
    rows = []
    for prop_id, definition in object_props.items():
        # Property ids are upper-cased into Cypher relationship style.
        rel_type = prop_id.upper().replace('-', '_')
        domain = definition.get('domain', 'Any')
        range_val = definition.get('range', 'Any')
        comment = definition.get('comment', '')
        rows.append(f"- :{rel_type} ({domain} -> {range_val}) - {comment}")
    return '\n'.join(rows) if rows else "None"
def _extract_variables(self, query: str) -> List[str]:
    """Pull the output variable names from a query's RETURN clause.

    Args:
        query: Cypher query string

    Returns:
        List of variable names; when an AS alias is present the alias wins.
    """
    import re

    # Grab everything between RETURN and ORDER/LIMIT/end of query.
    return_match = re.search(r'RETURN\s+(.+?)(?:ORDER|LIMIT|$)', query, re.IGNORECASE | re.DOTALL)
    if not return_match:
        return []
    return_clause = return_match.group(1)
    # Each hit is (name, alias); prefer the alias when one was captured.
    # NOTE(review): this simple regex also picks up incidental tokens
    # (e.g. function names inside aggregate expressions) -- confirm that
    # downstream consumers tolerate the extra entries.
    matches = re.findall(r'(\w+)(?:\s+AS\s+(\w+))?', return_clause)
    return [alias if alias else name for name, alias in matches]
def _calculate_complexity(self, query: str) -> float:
    """Score a Cypher query's complexity on a 0.0-1.0 scale.

    Adds fixed weights for expensive features (unions, filters, optional
    matches, ordering, variable-length paths, aggregates), plus a
    per-hop weight for bounded variable-length path patterns.

    Args:
        query: Cypher query string

    Returns:
        Complexity score (0.0 to 1.0)
    """
    # 're' was referenced below without being imported anywhere in this
    # module, so every call raised NameError; import it locally to match
    # the file's function-scope import style.
    import re

    complexity = 0.0
    query_upper = query.upper()
    # Count different Cypher features
    if 'JOIN' in query_upper or 'UNION' in query_upper:
        complexity += 0.3
    if 'WHERE' in query_upper:
        complexity += 0.2
    if 'OPTIONAL' in query_upper:
        complexity += 0.1
    if 'ORDER BY' in query_upper:
        complexity += 0.1
    if '*' in query:  # Variable length paths
        complexity += 0.2
    if any(agg in query_upper for agg in ['COUNT', 'SUM', 'AVG', 'MAX', 'MIN']):
        complexity += 0.2
    # Bounded paths like [*1..4] add 0.05 per extra hop.
    path_matches = re.findall(r'\[.*?\*(\d+)\.\.(\d+).*?\]', query)
    for start, end in path_matches:
        complexity += (int(end) - int(start)) * 0.05
    return min(complexity, 1.0)
def _get_database_hints(self, database_type: str, query_category: str) -> Dict[str, Any]:
    """Return backend-specific optimization hints.

    Args:
        database_type: Target database
        query_category: Category of query (currently not used to vary hints)

    Returns:
        Optimization hints; empty for unrecognized backends.
    """
    # Static hint tables, one per supported backend.
    per_backend = {
        "neo4j": {
            'use_index': True,
            'explain_plan': 'EXPLAIN',
            'profile_query': 'PROFILE'
        },
        "memgraph": {
            'use_index': True,
            'explain_plan': 'EXPLAIN',
            'memory_limit': '1GB'
        },
        "falkordb": {
            'use_index': False,  # Redis-based, different indexing
            'cache_result': True
        },
    }
    # Copy so callers can mutate their hints without affecting the table.
    return dict(per_backend.get(database_type, {}))
def _get_database_specific_hints(self, database_type: str) -> str:
    """Get database-specific optimization hints as text.

    The returned string is embedded verbatim into the LLM prompt built
    by _build_cypher_prompt.

    Args:
        database_type: Target database

    Returns:
        Hints as formatted string
    """
    if database_type == "neo4j":
        return """- Use USING INDEX hints for large datasets
- Consider PROFILE for query optimization
- Prefer MERGE over CREATE when appropriate"""
    elif database_type == "memgraph":
        return """- Leverage in-memory processing advantages
- Use streaming for large result sets
- Consider query parallelization"""
    elif database_type == "falkordb":
        return """- Optimize for Redis memory constraints
- Use simple patterns for best performance
- Leverage Redis data structures when possible"""
    else:
        # Unknown backend: generic guidance only.
        return "- Use standard Cypher optimization patterns"

View file

@ -0,0 +1,557 @@
"""
Error handling and recovery mechanisms for OntoRAG.
Provides comprehensive error handling, retry logic, and graceful degradation.
"""
import logging
import time
import asyncio
from typing import Dict, Any, List, Optional, Callable, Union, Type
from dataclasses import dataclass
from enum import Enum
from functools import wraps
import traceback
logger = logging.getLogger(__name__)
class ErrorSeverity(Enum):
    """Error severity levels, ordered from least to most serious.

    Used by ErrorReporter to choose the logging level for an error.
    """
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"
class ErrorCategory(Enum):
    """Error categories for better handling.

    Each category maps onto a pipeline stage or subsystem so retry and
    fallback policy can be selected per category.
    """
    ONTOLOGY_LOADING = "ontology_loading"
    QUESTION_ANALYSIS = "question_analysis"
    QUERY_GENERATION = "query_generation"
    QUERY_EXECUTION = "query_execution"
    ANSWER_GENERATION = "answer_generation"
    BACKEND_CONNECTION = "backend_connection"
    CACHE_ERROR = "cache_error"
    VALIDATION_ERROR = "validation_error"
    TIMEOUT_ERROR = "timeout_error"
    AUTHENTICATION_ERROR = "authentication_error"
@dataclass
class ErrorContext:
    """Context information for an error.

    Carried by OntoRAGError and consumed by the recovery strategy
    (category/retry budget) and the reporter (severity/labels).
    """
    category: ErrorCategory   # functional area where the error arose
    severity: ErrorSeverity   # how serious the error is
    component: str            # component name reporting the error
    operation: str            # operation that was being performed
    user_message: Optional[str] = None        # message suitable for end users
    technical_details: Optional[str] = None   # developer-facing details
    suggestion: Optional[str] = None          # suggested remediation
    retry_count: int = 0      # retries attempted so far
    max_retries: int = 3      # retry budget for this operation
    # Optional[...] rather than a bare Dict annotation: the default really
    # is None, so consumers must guard before treating it as a mapping.
    metadata: Optional[Dict[str, Any]] = None  # extra free-form context
class OntoRAGError(Exception):
    """Base exception for OntoRAG system.

    Carries an ErrorContext describing where and how the failure
    happened, the original causing exception (if any), and a creation
    timestamp.
    """

    def __init__(self,
                 message: str,
                 context: Optional[ErrorContext] = None,
                 cause: Optional[Exception] = None):
        """Initialize OntoRAG error.

        Args:
            message: Error message
            context: Error context
            cause: Original exception that caused this error
        """
        super().__init__(message)
        self.message = message
        # When no context is supplied, substitute a generic one so
        # downstream reporting code can rely on the attribute being set.
        if context is None:
            context = ErrorContext(
                category=ErrorCategory.VALIDATION_ERROR,
                severity=ErrorSeverity.MEDIUM,
                component="unknown",
                operation="unknown"
            )
        self.context = context
        self.cause = cause
        self.timestamp = time.time()
class OntologyLoadingError(OntoRAGError):
    """Error loading ontology."""
    pass


class QuestionAnalysisError(OntoRAGError):
    """Error analyzing question."""
    pass


class QueryGenerationError(OntoRAGError):
    """Error generating query."""
    pass


class QueryExecutionError(OntoRAGError):
    """Error executing query."""
    pass


class AnswerGenerationError(OntoRAGError):
    """Error generating answer."""
    pass


class BackendConnectionError(OntoRAGError):
    """Error connecting to backend."""
    pass


class TimeoutError(OntoRAGError):
    """Operation timeout error.

    NOTE(review): this shadows the builtin TimeoutError inside this
    module.  Retry configs that list TimeoutError will therefore match
    this class, not the builtin (e.g. asyncio timeouts) -- confirm that
    is intended.
    """
    pass
@dataclass
class RetryConfig:
    """Configuration for retry logic."""
    max_retries: int = 3              # maximum retry attempts per operation
    base_delay: float = 1.0           # initial delay between attempts (seconds)
    max_delay: float = 60.0           # upper bound on any computed delay
    exponential_backoff: bool = True  # double the delay on each attempt
    jitter: bool = True               # randomize delays to de-synchronize retries
    # Optional[...] rather than a bare List annotation: the default really
    # is None, meaning "no exception types are retryable".
    retry_on_exceptions: Optional[List[Type[Exception]]] = None
class ErrorRecoveryStrategy:
"""Strategy for handling and recovering from errors."""
def __init__(self, config: Dict[str, Any] = None):
    """Initialize error recovery strategy.

    Args:
        config: Recovery configuration (e.g. circuit_breaker_threshold,
            circuit_breaker_window)
    """
    self.config = config or {}
    # Per-category retry policies and fallback handlers.
    self.retry_configs = self._build_retry_configs()
    self.fallback_strategies = self._build_fallback_strategies()
    # Rolling error counts and circuit-breaker windows, keyed by
    # "<category>:<component>" (see handle_error).
    self.error_counters: Dict[str, int] = {}
    self.circuit_breakers: Dict[str, Dict[str, Any]] = {}
def _build_retry_configs(self) -> Dict[ErrorCategory, RetryConfig]:
    """Build retry configurations for different error categories.

    Backend connectivity gets the largest retry budget; analysis and
    generation steps get fewer attempts with shorter delays.
    """
    return {
        ErrorCategory.BACKEND_CONNECTION: RetryConfig(
            max_retries=5,
            base_delay=2.0,
            # NOTE: TimeoutError here is this module's subclass, which
            # shadows the builtin of the same name.
            retry_on_exceptions=[BackendConnectionError, ConnectionError, TimeoutError]
        ),
        ErrorCategory.QUERY_EXECUTION: RetryConfig(
            max_retries=3,
            base_delay=1.0,
            retry_on_exceptions=[QueryExecutionError, TimeoutError]
        ),
        ErrorCategory.ONTOLOGY_LOADING: RetryConfig(
            max_retries=2,
            base_delay=0.5,
            retry_on_exceptions=[OntologyLoadingError, IOError]
        ),
        ErrorCategory.QUESTION_ANALYSIS: RetryConfig(
            max_retries=2,
            base_delay=1.0,
            retry_on_exceptions=[QuestionAnalysisError, TimeoutError]
        ),
        ErrorCategory.ANSWER_GENERATION: RetryConfig(
            max_retries=2,
            base_delay=1.0,
            retry_on_exceptions=[AnswerGenerationError, TimeoutError]
        )
    }
def _build_fallback_strategies(self) -> Dict[ErrorCategory, Callable]:
    """Build fallback strategies for different error categories.

    Each entry maps a category to a degraded-but-safe handler invoked
    when retries are exhausted or the circuit breaker is open.
    """
    return {
        ErrorCategory.QUESTION_ANALYSIS: self._fallback_question_analysis,
        ErrorCategory.QUERY_GENERATION: self._fallback_query_generation,
        ErrorCategory.QUERY_EXECUTION: self._fallback_query_execution,
        ErrorCategory.ANSWER_GENERATION: self._fallback_answer_generation,
        ErrorCategory.BACKEND_CONNECTION: self._fallback_backend_connection
    }
async def handle_error(self,
                       error: Exception,
                       context: ErrorContext,
                       operation: Callable,
                       *args,
                       **kwargs) -> Any:
    """Handle error with recovery strategies.

    Order of precedence: open circuit breaker (go straight to the
    fallback), retry (when the category has a config, budget remains,
    and the exception type is retryable), then the category fallback.

    Args:
        error: The exception that occurred
        context: Error context
        operation: Function to retry
        *args: Operation arguments
        **kwargs: Operation keyword arguments

    Returns:
        Result of successful operation or fallback
    """
    logger.error(f"Handling error in {context.component}.{context.operation}: {error}")
    # Update error counters
    error_key = f"{context.category.value}:{context.component}"
    self.error_counters[error_key] = self.error_counters.get(error_key, 0) + 1
    # Check circuit breaker
    if self._is_circuit_open(error_key):
        return await self._execute_fallback(context, *args, **kwargs)
    # Try retry if configured
    retry_config = self.retry_configs.get(context.category)
    if retry_config and context.retry_count < retry_config.max_retries:
        # "or []" guards retry_on_exceptions defaulting to None.
        if any(isinstance(error, exc_type) for exc_type in retry_config.retry_on_exceptions or []):
            return await self._retry_operation(
                operation, context, retry_config, *args, **kwargs
            )
    # Execute fallback strategy
    return await self._execute_fallback(context, *args, **kwargs)
async def _retry_operation(self,
                           operation: Callable,
                           context: ErrorContext,
                           retry_config: RetryConfig,
                           *args,
                           **kwargs) -> Any:
    """Retry operation with backoff.

    Increments context.retry_count, sleeps for the computed delay, then
    re-invokes the operation.  Failures recurse into handle_error, which
    enforces the overall retry budget.
    """
    context.retry_count += 1
    # Calculate delay
    delay = retry_config.base_delay
    if retry_config.exponential_backoff:
        delay *= (2 ** (context.retry_count - 1))
    delay = min(delay, retry_config.max_delay)
    # Add jitter
    if retry_config.jitter:
        import random
        # Scale by a factor in [0.5, 1.5) to avoid synchronized retries.
        delay *= (0.5 + random.random())
    logger.info(f"Retrying {context.component}.{context.operation} "
                f"(attempt {context.retry_count}) after {delay:.2f}s")
    await asyncio.sleep(delay)
    try:
        # Support both sync and async operations.
        if asyncio.iscoroutinefunction(operation):
            return await operation(*args, **kwargs)
        else:
            return operation(*args, **kwargs)
    except Exception as e:
        return await self.handle_error(e, context, operation, *args, **kwargs)
async def _execute_fallback(self,
                            context: ErrorContext,
                            *args,
                            **kwargs) -> Any:
    """Execute fallback strategy.

    Runs the category-specific fallback when one exists (sync or async);
    if it is missing or itself fails, falls through to _default_fallback.
    """
    fallback_func = self.fallback_strategies.get(context.category)
    if fallback_func:
        logger.info(f"Executing fallback for {context.category.value}")
        try:
            if asyncio.iscoroutinefunction(fallback_func):
                return await fallback_func(context, *args, **kwargs)
            else:
                return fallback_func(context, *args, **kwargs)
        except Exception as e:
            # A broken fallback must not raise; use the default instead.
            logger.error(f"Fallback strategy failed: {e}")
    # Default fallback
    return self._default_fallback(context)
def _is_circuit_open(self, error_key: str) -> bool:
    """Check whether the circuit breaker for error_key is open.

    The breaker opens once the error count reaches the configured
    threshold within a rolling time window; an expired window resets
    both the window start and the error counter.
    """
    threshold = self.config.get('circuit_breaker_threshold', 10)
    window_seconds = self.config.get('circuit_breaker_window', 300)  # 5 minutes

    now = time.time()
    # A key with no recorded window starts its window right now.
    window_start = self.circuit_breakers.get(error_key, {}).get('window_start', now)

    # Window expired: open a fresh window, zero the counter, stay closed.
    if now - window_start > window_seconds:
        self.circuit_breakers[error_key] = {'window_start': now}
        self.error_counters[error_key] = 0
        return False

    return self.error_counters.get(error_key, 0) >= threshold
def _default_fallback(self, context: ErrorContext) -> Any:
    """Last-resort fallback values, chosen by error category."""
    category = context.category
    if category == ErrorCategory.ANSWER_GENERATION:
        # A user-facing apology rather than None, since this reaches users.
        return "I'm sorry, I encountered an error while processing your question. Please try again."
    if category == ErrorCategory.QUERY_EXECUTION:
        return {"error": "Query execution failed", "results": []}
    return None
# Fallback strategy implementations
async def _fallback_question_analysis(self, context: ErrorContext, question: str, **kwargs):
    """Fallback for question analysis.

    Builds a rough QuestionComponents using keyword heuristics when the
    real analyzer fails: the question type comes from marker words, and
    "entities" are simply the longer non-stopword tokens.
    """
    from .question_analyzer import QuestionComponents, QuestionType
    # Simple keyword-based analysis
    question_lower = question.lower()
    # Determine question type
    if any(word in question_lower for word in ['how many', 'count', 'number']):
        question_type = QuestionType.AGGREGATION
    elif question_lower.startswith(('is', 'are', 'does', 'can')):
        question_type = QuestionType.BOOLEAN
    elif any(word in question_lower for word in ['what', 'which', 'who', 'where']):
        question_type = QuestionType.RETRIEVAL
    else:
        question_type = QuestionType.FACTUAL
    # Extract simple entities (nouns)
    import re
    words = re.findall(r'\b[a-zA-Z]+\b', question)
    # Words longer than 3 characters that are not question stopwords are
    # treated as candidate entities -- crude, but dependency-free.
    entities = [word for word in words if len(word) > 3 and word.lower() not in
                {'what', 'which', 'where', 'when', 'who', 'how', 'does', 'are', 'the'}]
    return QuestionComponents(
        original_question=question,
        normalized_question=question.lower(),
        question_type=question_type,
        entities=entities[:3],  # Limit to 3 entities
        keywords=words[:5],  # Limit to 5 keywords
        relationships=[],
        constraints=[],
        aggregations=['count'] if question_type == QuestionType.AGGREGATION else [],
        expected_answer_type='text'
    )
async def _fallback_query_generation(self, context: ErrorContext, **kwargs):
    """Fallback for query generation: return a trivial catch-all query.

    Chooses SPARQL or Cypher based on context.metadata['query_language'],
    defaulting to Cypher when no language hint is present.
    """
    # context.metadata defaults to None; guard before calling .get() so
    # the fallback path cannot itself raise AttributeError.
    metadata = context.metadata or {}
    if 'sparql' in metadata.get('query_language', '').lower():
        query = """
        PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
        SELECT ?subject ?predicate ?object WHERE {
            ?subject ?predicate ?object .
        }
        LIMIT 10
        """
        from .sparql_generator import SPARQLQuery
        return SPARQLQuery(
            query=query,
            variables=['subject', 'predicate', 'object'],
            query_type='SELECT',
            explanation='Fallback SPARQL query',
            complexity_score=0.1
        )
    else:
        query = "MATCH (n) RETURN n LIMIT 10"
        from .cypher_generator import CypherQuery
        # CypherQuery has no query_type field and requires parameters; the
        # previous call passed query_type= and omitted parameters, raising
        # TypeError whenever this fallback actually ran.
        return CypherQuery(
            query=query,
            parameters={},
            variables=['n'],
            explanation='Fallback Cypher query',
            complexity_score=0.1
        )
async def _fallback_query_execution(self, context: ErrorContext, **kwargs):
    """Fallback for query execution: return an empty result object."""
    # context.metadata defaults to None; guard before calling .get().
    metadata = context.metadata or {}
    if 'sparql' in metadata.get('query_language', '').lower():
        from .sparql_cassandra import SPARQLResult
        return SPARQLResult(
            bindings=[],
            variables=[],
            execution_time=0.0
        )
    else:
        from .cypher_executor import CypherResult
        # Match the field set CypherResult is constructed with in the
        # executor's own error path (records/summary/execution_time/
        # database_type); the previous metadata= keyword does not exist
        # there and raised TypeError whenever this fallback ran.
        return CypherResult(
            records=[],
            summary={'type': 'fallback'},
            execution_time=0.0,
            database_type='fallback'
        )
async def _fallback_answer_generation(self, context: ErrorContext, question: str = None, **kwargs):
    """Fallback for answer generation: reply with a canned apology."""
    import random

    # Several phrasings so repeated failures read less robotic.
    canned = (
        "I'm experiencing some technical difficulties. Please try rephrasing your question.",
        "I couldn't process your question at the moment. Could you try asking it differently?",
        "There seems to be an issue with my analysis. Please try again in a moment.",
        "I'm having trouble understanding your question right now. Please try again."
    )
    return random.choice(canned)
async def _fallback_backend_connection(self, context: ErrorContext, **kwargs):
    """Fallback for backend connection failures.

    Currently only logs and returns None; callers must tolerate a None
    result when the backend is unreachable.
    """
    logger.warning(f"Backend connection failed for {context.component}")
    # Could switch to alternative backend here
    return None
def with_error_handling(category: ErrorCategory,
                        component: str,
                        operation: str,
                        severity: ErrorSeverity = ErrorSeverity.MEDIUM):
    """Decorator for automatic error handling.

    Wraps the target in try/except; on failure it builds an ErrorContext
    and either delegates to the instance's ErrorRecoveryStrategy (when
    the first positional argument exposes ``_error_strategy``) or
    re-raises the exception wrapped as OntoRAGError.

    Args:
        category: Error category
        component: Component name
        operation: Operation name
        severity: Error severity
    """
    def decorator(func):
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            try:
                if asyncio.iscoroutinefunction(func):
                    return await func(*args, **kwargs)
                else:
                    return func(*args, **kwargs)
            except Exception as e:
                context = ErrorContext(
                    category=category,
                    severity=severity,
                    component=component,
                    operation=operation,
                    technical_details=str(e),
                    metadata={'args': str(args), 'kwargs': str(kwargs)}
                )
                # Get error recovery strategy from first argument if it's available
                error_strategy = None
                if args and hasattr(args[0], '_error_strategy'):
                    error_strategy = args[0]._error_strategy
                if error_strategy:
                    return await error_strategy.handle_error(e, context, func, *args, **kwargs)
                else:
                    # Re-raise as OntoRAG error
                    raise OntoRAGError(
                        f"Error in {component}.{operation}: {str(e)}",
                        context=context,
                        cause=e
                    )
        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            # NOTE(review): the sync path never consults _error_strategy,
            # so synchronous functions get no retry/fallback treatment --
            # confirm this asymmetry is intended.
            try:
                return func(*args, **kwargs)
            except Exception as e:
                context = ErrorContext(
                    category=category,
                    severity=severity,
                    component=component,
                    operation=operation,
                    technical_details=str(e),
                    metadata={'args': str(args), 'kwargs': str(kwargs)}
                )
                raise OntoRAGError(
                    f"Error in {component}.{operation}: {str(e)}",
                    context=context,
                    cause=e
                )
        # Pick the wrapper that matches the wrapped function's kind.
        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        else:
            return sync_wrapper
    return decorator
class ErrorReporter:
    """Reports and tracks errors for monitoring and debugging."""

    def __init__(self, config: Dict[str, Any] = None):
        """Initialize error reporter.

        Args:
            config: Reporter configuration (supports 'max_log_size')
        """
        self.config = config or {}
        # Bounded in-memory log of structured error entries; oldest
        # entries are dropped once max_log_size is exceeded.
        self.error_log: List[Dict[str, Any]] = []
        self.max_log_size = self.config.get('max_log_size', 1000)

    def report_error(self, error: OntoRAGError):
        """Report an error for tracking.

        Records a structured entry and logs at a level derived from the
        error's severity.

        Args:
            error: The error to report
        """
        error_entry = {
            'timestamp': error.timestamp,
            'message': error.message,
            'category': error.context.category.value,
            'severity': error.context.severity.value,
            'component': error.context.component,
            'operation': error.context.operation,
            'retry_count': error.context.retry_count,
            'technical_details': error.context.technical_details,
            # NOTE(review): format_exc() reflects the exception being
            # handled *at report time*; if report_error is called outside
            # the original except block this yields no useful trace --
            # confirm call sites.
            'stack_trace': traceback.format_exc() if error.cause else None
        }
        self.error_log.append(error_entry)
        # Trim log if too large
        if len(self.error_log) > self.max_log_size:
            self.error_log = self.error_log[-self.max_log_size:]
        # Log based on severity
        if error.context.severity == ErrorSeverity.CRITICAL:
            logger.critical(f"CRITICAL ERROR: {error.message}")
        elif error.context.severity == ErrorSeverity.HIGH:
            logger.error(f"HIGH SEVERITY: {error.message}")
        elif error.context.severity == ErrorSeverity.MEDIUM:
            logger.warning(f"MEDIUM SEVERITY: {error.message}")
        else:
            logger.info(f"LOW SEVERITY: {error.message}")

    def get_error_summary(self) -> Dict[str, Any]:
        """Get summary of recent errors.

        Returns:
            Error summary statistics, with per-category / per-severity /
            per-component breakdowns limited to the last hour.
        """
        if not self.error_log:
            return {'total_errors': 0}
        recent_errors = [
            e for e in self.error_log
            if time.time() - e['timestamp'] < 3600  # Last hour
        ]
        category_counts = {}
        severity_counts = {}
        component_counts = {}
        for error in recent_errors:
            category_counts[error['category']] = category_counts.get(error['category'], 0) + 1
            severity_counts[error['severity']] = severity_counts.get(error['severity'], 0) + 1
            component_counts[error['component']] = component_counts.get(error['component'], 0) + 1
        return {
            'total_errors': len(self.error_log),
            'recent_errors': len(recent_errors),
            'category_breakdown': category_counts,
            'severity_breakdown': severity_counts,
            'component_breakdown': component_counts,
            'most_recent_error': self.error_log[-1] if self.error_log else None
        }

View file

@ -0,0 +1,737 @@
"""
Performance monitoring and metrics collection for OntoRAG.
Provides comprehensive monitoring of system performance, query patterns, and resource usage.
"""
import logging
import time
import asyncio
import threading
from typing import Dict, Any, List, Optional, Callable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from collections import defaultdict, deque
import statistics
from enum import Enum
logger = logging.getLogger(__name__)
class MetricType(Enum):
    """Types of metrics to collect."""
    COUNTER = "counter"      # monotonically accumulated value (see MetricsCollector.increment)
    GAUGE = "gauge"          # point-in-time value, set directly (see MetricsCollector.set_gauge)
    HISTOGRAM = "histogram"  # declared but not emitted by MetricsCollector in this module
    TIMER = "timer"          # duration sample in seconds (see MetricsCollector.record_timer)
@dataclass
class Metric:
    """Individual metric data point."""
    name: str            # metric name, e.g. 'requests_total'
    value: float         # recorded sample value
    timestamp: datetime  # when the sample was recorded (set by MetricsCollector._add_metric)
    labels: Dict[str, str] = field(default_factory=dict)  # key/value dimensions
    metric_type: MetricType = MetricType.GAUGE            # what kind of metric this sample is
@dataclass
class TimerMetric:
    """In-flight timer for measuring how long an operation takes."""
    name: str
    start_time: float
    labels: Dict[str, str] = field(default_factory=dict)

    def stop(self) -> float:
        """Return the seconds elapsed since ``start_time``."""
        now = time.time()
        elapsed = now - self.start_time
        return elapsed
@dataclass
class PerformanceStats:
    """Performance statistics for a component (maintained by PerformanceMonitor)."""
    total_requests: int = 0       # all requests recorded for the component
    successful_requests: int = 0  # requests recorded with success=True
    failed_requests: int = 0      # requests recorded with success=False
    avg_response_time: float = 0.0  # mean request duration, seconds
    min_response_time: float = float('inf')  # inf until the first sample arrives
    max_response_time: float = 0.0
    p95_response_time: float = 0.0  # 95th-percentile duration, seconds
    p99_response_time: float = 0.0  # 99th-percentile duration, seconds
    throughput_per_second: float = 0.0  # requests/sec over the last minute
    error_rate: float = 0.0  # failed_requests / total_requests
@dataclass
class SystemHealth:
    """Overall system health metrics."""
    status: str = "healthy"  # healthy, degraded, unhealthy
    uptime_seconds: float = 0.0
    cpu_usage_percent: float = 0.0     # not populated by get_system_health in this module
    memory_usage_percent: float = 0.0  # not populated by get_system_health in this module
    active_connections: int = 0        # not populated by get_system_health in this module
    queue_size: int = 0                # not populated by get_system_health in this module
    cache_hit_rate: float = 0.0  # cache_hits_total / cache_requests_total
    error_rate: float = 0.0      # requests_failed / requests_total
class MetricsCollector:
    """Collects and stores metrics data.

    Writers (increment / set_gauge / record_timer) and the scan methods
    (get_metrics / cleanup_old_metrics) serialise on an RLock.
    NOTE(review): the point readers (get_counter / get_gauge /
    get_timer_stats) access the dicts without taking the lock — confirm this
    is intentional (relies on CPython dict-read atomicity).
    """
    def __init__(self, max_metrics: int = 10000, retention_hours: int = 24):
        """Initialize metrics collector.

        Args:
            max_metrics: Maximum number of metrics to retain
            retention_hours: Hours to retain metrics
        """
        self.max_metrics = max_metrics
        self.retention_hours = retention_hours
        # Per-name sample history; deque(maxlen=...) silently drops the oldest.
        self.metrics: Dict[str, deque] = defaultdict(lambda: deque(maxlen=max_metrics))
        # Aggregates keyed by name+labels (see _build_key).
        self.counters: Dict[str, float] = defaultdict(float)
        self.gauges: Dict[str, float] = defaultdict(float)
        self.timers: Dict[str, List[float]] = defaultdict(list)
        self._lock = threading.RLock()

    def increment(self, name: str, value: float = 1.0, labels: Optional[Dict[str, str]] = None):
        """Increment a counter metric.

        Args:
            name: Metric name
            value: Value to increment by
            labels: Metric labels
        """
        with self._lock:
            metric_key = self._build_key(name, labels)
            self.counters[metric_key] += value
            self._add_metric(name, value, MetricType.COUNTER, labels)

    def set_gauge(self, name: str, value: float, labels: Optional[Dict[str, str]] = None):
        """Set a gauge metric value.

        Args:
            name: Metric name
            value: Gauge value
            labels: Metric labels
        """
        with self._lock:
            metric_key = self._build_key(name, labels)
            self.gauges[metric_key] = value
            self._add_metric(name, value, MetricType.GAUGE, labels)

    def record_timer(self, name: str, duration: float, labels: Optional[Dict[str, str]] = None):
        """Record a timer measurement.

        Args:
            name: Metric name
            duration: Duration in seconds
            labels: Metric labels
        """
        with self._lock:
            metric_key = self._build_key(name, labels)
            self.timers[metric_key].append(duration)
            # Keep only recent measurements (bounded memory per timer key).
            max_timer_values = 1000
            if len(self.timers[metric_key]) > max_timer_values:
                self.timers[metric_key] = self.timers[metric_key][-max_timer_values:]
            self._add_metric(name, duration, MetricType.TIMER, labels)

    def start_timer(self, name: str, labels: Optional[Dict[str, str]] = None) -> TimerMetric:
        """Start a timer.

        Args:
            name: Metric name
            labels: Metric labels
        Returns:
            Timer metric object
        """
        return TimerMetric(name=name, start_time=time.time(), labels=labels or {})

    def stop_timer(self, timer: TimerMetric) -> float:
        """Stop a timer and record the measurement.

        Args:
            timer: Timer metric to stop
        Returns:
            Elapsed duration in seconds (also recorded under the timer's name/labels)
        """
        duration = timer.stop()
        self.record_timer(timer.name, duration, timer.labels)
        return duration

    def get_counter(self, name: str, labels: Optional[Dict[str, str]] = None) -> float:
        """Get counter value.

        Args:
            name: Metric name
            labels: Metric labels
        Returns:
            Counter value (0.0 if the counter was never incremented)
        """
        metric_key = self._build_key(name, labels)
        return self.counters.get(metric_key, 0.0)

    def get_gauge(self, name: str, labels: Optional[Dict[str, str]] = None) -> float:
        """Get gauge value.

        Args:
            name: Metric name
            labels: Metric labels
        Returns:
            Gauge value (0.0 if the gauge was never set)
        """
        metric_key = self._build_key(name, labels)
        return self.gauges.get(metric_key, 0.0)

    def get_timer_stats(self, name: str, labels: Optional[Dict[str, str]] = None) -> Dict[str, float]:
        """Get timer statistics.

        Args:
            name: Metric name
            labels: Metric labels
        Returns:
            Timer statistics ({} when no samples exist). Percentiles use the
            0-based nearest-rank index int(n * q) into the sorted samples.
        """
        metric_key = self._build_key(name, labels)
        values = self.timers.get(metric_key, [])
        if not values:
            return {}
        sorted_values = sorted(values)
        return {
            'count': len(values),
            'sum': sum(values),
            'avg': statistics.mean(values),
            'min': min(values),
            'max': max(values),
            'p50': sorted_values[int(len(sorted_values) * 0.5)],
            'p95': sorted_values[int(len(sorted_values) * 0.95)],
            'p99': sorted_values[int(len(sorted_values) * 0.99)]
        }

    def get_metrics(self,
                    name_pattern: Optional[str] = None,
                    since: Optional[datetime] = None) -> List[Metric]:
        """Get metrics matching pattern and time range.

        Args:
            name_pattern: Pattern to match metric names (substring match)
            since: Only return metrics since this time; defaults to the
                retention window
        Returns:
            List of matching metrics, sorted by timestamp ascending
        """
        with self._lock:
            results = []
            # Precedence note: `since or (now - retention)` — the subtraction
            # binds tighter than `or`.
            cutoff_time = since or datetime.now() - timedelta(hours=self.retention_hours)
            for metric_name, metric_queue in self.metrics.items():
                if name_pattern and name_pattern not in metric_name:
                    continue
                for metric in metric_queue:
                    if metric.timestamp >= cutoff_time:
                        results.append(metric)
            return sorted(results, key=lambda m: m.timestamp)

    def cleanup_old_metrics(self):
        """Remove old metrics beyond retention period."""
        with self._lock:
            cutoff_time = datetime.now() - timedelta(hours=self.retention_hours)
            # Iterate over a snapshot of keys: entries may be deleted below.
            for metric_name in list(self.metrics.keys()):
                metric_queue = self.metrics[metric_name]
                # Remove old metrics (queues are append-ordered, so the oldest
                # samples are always at the front).
                while metric_queue and metric_queue[0].timestamp < cutoff_time:
                    metric_queue.popleft()
                # Remove empty queues
                if not metric_queue:
                    del self.metrics[metric_name]

    def _add_metric(self, name: str, value: float, metric_type: MetricType, labels: Optional[Dict[str, str]]):
        """Append a sample to the per-name history (callers hold the lock)."""
        metric = Metric(
            name=name,
            value=value,
            timestamp=datetime.now(),
            labels=labels or {},
            metric_type=metric_type
        )
        # Keyed by bare name (no labels), unlike counters/gauges/timers.
        self.metrics[name].append(metric)

    def _build_key(self, name: str, labels: Optional[Dict[str, str]]) -> str:
        """Build metric key from name and labels, e.g. name{a=1,b=2}."""
        if not labels:
            return name
        # Sort labels so the key is independent of insertion order.
        label_str = ','.join(f"{k}={v}" for k, v in sorted(labels.items()))
        return f"{name}{{{label_str}}}"
class PerformanceMonitor:
    """Monitors system performance and component health."""

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Initialize performance monitor.

        Args:
            config: Monitor configuration
        """
        self.config = config or {}
        self.metrics_collector = MetricsCollector(
            max_metrics=self.config.get('max_metrics', 10000),
            retention_hours=self.config.get('retention_hours', 24)
        )
        self.component_stats: Dict[str, PerformanceStats] = {}
        # Recent raw durations per component, used to derive avg/p95/p99.
        self._component_durations: Dict[str, List[float]] = {}
        self.start_time = time.time()
        self.monitoring_enabled = self.config.get('enabled', True)
        # Start background monitoring tasks
        if self.monitoring_enabled:
            self._start_background_tasks()

    def record_request(self,
                       component: str,
                       operation: str,
                       duration: float,
                       success: bool = True,
                       labels: Optional[Dict[str, str]] = None):
        """Record a request completion.

        Args:
            component: Component name
            operation: Operation name
            duration: Request duration in seconds
            success: Whether request was successful
            labels: Additional labels
        """
        if not self.monitoring_enabled:
            return
        base_labels = {'component': component, 'operation': operation}
        if labels:
            base_labels.update(labels)
        # Record metrics
        self.metrics_collector.increment('requests_total', labels=base_labels)
        self.metrics_collector.record_timer('request_duration', duration, base_labels)
        if success:
            self.metrics_collector.increment('requests_successful', labels=base_labels)
        else:
            self.metrics_collector.increment('requests_failed', labels=base_labels)
        # Update component stats
        self._update_component_stats(component, duration, success)

    def record_query_complexity(self,
                                complexity_score: float,
                                query_type: str,
                                backend: str):
        """Record query complexity metrics.

        Args:
            complexity_score: Query complexity score (0.0 to 1.0)
            query_type: Type of query (SPARQL, Cypher)
            backend: Backend used
        """
        if not self.monitoring_enabled:
            return
        labels = {'query_type': query_type, 'backend': backend}
        self.metrics_collector.set_gauge('query_complexity', complexity_score, labels)

    def record_cache_access(self, hit: bool, cache_type: str = 'default'):
        """Record cache access.

        Args:
            hit: Whether it was a cache hit
            cache_type: Type of cache
        """
        if not self.monitoring_enabled:
            return
        labels = {'cache_type': cache_type}
        self.metrics_collector.increment('cache_requests_total', labels=labels)
        if hit:
            self.metrics_collector.increment('cache_hits_total', labels=labels)
        else:
            self.metrics_collector.increment('cache_misses_total', labels=labels)

    def record_ontology_selection(self,
                                  selected_elements: int,
                                  total_elements: int,
                                  ontology_id: str):
        """Record ontology selection metrics.

        Args:
            selected_elements: Number of selected ontology elements
            total_elements: Total ontology elements
            ontology_id: Ontology identifier
        """
        if not self.monitoring_enabled:
            return
        labels = {'ontology_id': ontology_id}
        self.metrics_collector.set_gauge('ontology_elements_selected', selected_elements, labels)
        self.metrics_collector.set_gauge('ontology_elements_total', total_elements, labels)
        # Guard against empty ontologies to avoid division by zero.
        selection_ratio = selected_elements / total_elements if total_elements > 0 else 0
        self.metrics_collector.set_gauge('ontology_selection_ratio', selection_ratio, labels)

    def get_component_stats(self, component: str) -> Optional[PerformanceStats]:
        """Get performance statistics for a component.

        Args:
            component: Component name
        Returns:
            Performance statistics or None
        """
        return self.component_stats.get(component)

    def get_system_health(self) -> SystemHealth:
        """Get overall system health status.

        Returns:
            System health metrics (status derived from the aggregate error
            rate: >10% => degraded, >30% => unhealthy)
        """
        # Calculate uptime
        uptime = time.time() - self.start_time
        # Get error rate
        total_requests = self.metrics_collector.get_counter('requests_total')
        failed_requests = self.metrics_collector.get_counter('requests_failed')
        error_rate = failed_requests / total_requests if total_requests > 0 else 0.0
        # Get cache hit rate
        cache_hits = self.metrics_collector.get_counter('cache_hits_total')
        cache_requests = self.metrics_collector.get_counter('cache_requests_total')
        cache_hit_rate = cache_hits / cache_requests if cache_requests > 0 else 0.0
        # Determine status
        status = "healthy"
        if error_rate > 0.1:  # More than 10% error rate
            status = "degraded"
        if error_rate > 0.3:  # More than 30% error rate
            status = "unhealthy"
        return SystemHealth(
            status=status,
            uptime_seconds=uptime,
            error_rate=error_rate,
            cache_hit_rate=cache_hit_rate
        )

    def get_performance_report(self) -> Dict[str, Any]:
        """Get comprehensive performance report.

        Returns:
            Performance report. Note: 'error_patterns' and 'ontology_usage'
            are placeholders and are not populated here.
        """
        report = {
            'system_health': self.get_system_health(),
            'component_stats': {},
            'top_slow_operations': [],
            'error_patterns': {},
            'cache_performance': {},
            'ontology_usage': {}
        }
        # Component statistics
        for component, stats in self.component_stats.items():
            report['component_stats'][component] = stats
        # Top slow operations (timer keys already embed their labels, so we
        # look them up with labels=None).
        timer_stats = {}
        for metric_name in self.metrics_collector.timers.keys():
            if 'request_duration' in metric_name:
                stats = self.metrics_collector.get_timer_stats(metric_name)
                if stats:
                    timer_stats[metric_name] = stats
        # Sort by p95 latency
        slow_ops = sorted(
            timer_stats.items(),
            key=lambda x: x[1].get('p95', 0),
            reverse=True
        )[:10]
        report['top_slow_operations'] = [
            {'operation': op, 'stats': stats} for op, stats in slow_ops
        ]
        # Cache performance: recover cache_type values from the label-encoded
        # counter keys built by MetricsCollector._build_key.
        cache_types = set()
        for metric_name in self.metrics_collector.counters.keys():
            if 'cache_type=' in metric_name:
                cache_type = metric_name.split('cache_type=')[1].split(',')[0].split('}')[0]
                cache_types.add(cache_type)
        for cache_type in cache_types:
            labels = {'cache_type': cache_type}
            hits = self.metrics_collector.get_counter('cache_hits_total', labels)
            requests = self.metrics_collector.get_counter('cache_requests_total', labels)
            hit_rate = hits / requests if requests > 0 else 0.0
            report['cache_performance'][cache_type] = {
                'hit_rate': hit_rate,
                'total_requests': requests,
                'total_hits': hits
            }
        return report

    def _update_component_stats(self, component: str, duration: float, success: bool):
        """Update component performance statistics for one completed request."""
        if component not in self.component_stats:
            self.component_stats[component] = PerformanceStats()
        stats = self.component_stats[component]
        stats.total_requests += 1
        if success:
            stats.successful_requests += 1
        else:
            stats.failed_requests += 1
        # Update response time stats
        stats.min_response_time = min(stats.min_response_time, duration)
        stats.max_response_time = max(stats.max_response_time, duration)
        # BUG FIX: the previous implementation asked the metrics collector for
        # timer stats keyed only by {'component': ...}, but record_request
        # stores timers keyed by {'component', 'operation'} — the lookup never
        # matched, so avg/p95/p99 stayed at 0. Track the durations directly.
        durations = self._component_durations.setdefault(component, [])
        durations.append(duration)
        max_samples = 1000  # same bound MetricsCollector uses per timer key
        if len(durations) > max_samples:
            del durations[:-max_samples]
        sorted_durations = sorted(durations)
        stats.avg_response_time = statistics.mean(durations)
        stats.p95_response_time = sorted_durations[int(len(sorted_durations) * 0.95)]
        stats.p99_response_time = sorted_durations[int(len(sorted_durations) * 0.99)]
        # Calculate rates
        stats.error_rate = stats.failed_requests / stats.total_requests
        # Calculate throughput (requests per second over last minute)
        recent_requests = len([
            m for m in self.metrics_collector.get_metrics('requests_total')
            if m.labels.get('component') == component and
            m.timestamp > datetime.now() - timedelta(minutes=1)
        ])
        stats.throughput_per_second = recent_requests / 60.0

    def _start_background_tasks(self):
        """Start background monitoring tasks."""
        def cleanup_worker():
            """Worker to clean up old metrics."""
            while True:
                try:
                    time.sleep(300)  # 5 minutes
                    self.metrics_collector.cleanup_old_metrics()
                except Exception as e:
                    logger.error(f"Metrics cleanup error: {e}")
        # Daemon thread: does not block interpreter shutdown.
        cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
        cleanup_thread.start()
# Monitoring decorators
def monitor_performance(component: str,
                        operation: str,
                        monitor: Optional["PerformanceMonitor"] = None):
    """Decorator to monitor function performance.

    Works for both sync and async callables: each call is timed via the
    monitor's metrics collector and reported through ``record_request`` with
    its success/failure status. When ``monitor`` is None or disabled, the
    wrapped function is called with no monitoring overhead.

    Args:
        component: Component name
        operation: Operation name
        monitor: Performance monitor instance
    """
    import functools  # local import keeps this decorator self-contained

    def decorator(func):
        # Decide sync vs async once, at decoration time.
        is_async = asyncio.iscoroutinefunction(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not monitor or not monitor.monitoring_enabled:
                return func(*args, **kwargs)
            timer = monitor.metrics_collector.start_timer(
                'request_duration',
                {'component': component, 'operation': operation}
            )
            success = True
            try:
                return func(*args, **kwargs)
            except Exception:
                success = False
                raise
            finally:
                # Record even when the call raised, tagged as a failure.
                duration = monitor.metrics_collector.stop_timer(timer)
                monitor.record_request(component, operation, duration, success)

        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            if not monitor or not monitor.monitoring_enabled:
                return await func(*args, **kwargs)
            timer = monitor.metrics_collector.start_timer(
                'request_duration',
                {'component': component, 'operation': operation}
            )
            success = True
            try:
                return await func(*args, **kwargs)
            except Exception:
                success = False
                raise
            finally:
                duration = monitor.metrics_collector.stop_timer(timer)
                monitor.record_request(component, operation, duration, success)

        return async_wrapper if is_async else wrapper

    return decorator
class QueryPatternAnalyzer:
    """Analyzes query patterns for optimization insights."""

    def __init__(self, monitor: PerformanceMonitor):
        """Initialize query pattern analyzer.

        Args:
            monitor: Performance monitor instance
        """
        self.monitor = monitor
        # Buckets keyed by "<question_type>:<entity_count>".
        self.query_patterns: Dict[str, List[Dict[str, Any]]] = defaultdict(list)

    def record_query_pattern(self,
                             question_type: str,
                             entities: List[str],
                             complexity: float,
                             backend: str,
                             duration: float,
                             success: bool):
        """Record a query pattern for analysis.

        Args:
            question_type: Type of question
            entities: Entities in question
            complexity: Query complexity score
            backend: Backend used
            duration: Query duration
            success: Whether query succeeded
        """
        now = datetime.now()
        entry = {
            'timestamp': now,
            'question_type': question_type,
            'entity_count': len(entities),
            'entities': entities,
            'complexity': complexity,
            'backend': backend,
            'duration': duration,
            'success': success
        }
        bucket_key = f"{question_type}:{len(entities)}"
        self.query_patterns[bucket_key].append(entry)
        # Drop anything recorded more than 24 hours ago.
        horizon = now - timedelta(hours=24)
        self.query_patterns[bucket_key] = [
            e for e in self.query_patterns[bucket_key]
            if e['timestamp'] > horizon
        ]

    def get_optimization_insights(self) -> Dict[str, Any]:
        """Get insights for query optimization.

        Returns:
            Optimization insights and recommendations
        """
        slow_patterns: List[Dict[str, Any]] = []
        common_failures: List[Dict[str, Any]] = []
        backend_durations: Dict[str, List[float]] = defaultdict(list)

        # Single pass over every bucket: flag slow / failure-prone patterns
        # and accumulate per-backend durations.
        for bucket_key, entries in self.query_patterns.items():
            if not entries:
                continue
            mean_duration = statistics.mean(e['duration'] for e in entries)
            ok = sum(1 for e in entries if e['success'])
            success_rate = ok / len(entries)
            if mean_duration > 5.0:  # slow: > 5 seconds on average
                slow_patterns.append({
                    'pattern': bucket_key,
                    'avg_duration': mean_duration,
                    'count': len(entries),
                    'success_rate': success_rate
                })
            if success_rate < 0.8:  # failure-prone: < 80% success
                common_failures.append({
                    'pattern': bucket_key,
                    'success_rate': success_rate,
                    'count': len(entries)
                })
            for e in entries:
                backend_durations[e['backend']].append(e['duration'])

        backend_performance: Dict[str, Any] = {}
        for backend, durations in backend_durations.items():
            ordered = sorted(durations)
            backend_performance[backend] = {
                'avg_duration': statistics.mean(durations),
                'p95_duration': ordered[int(len(ordered) * 0.95)],
                'query_count': len(durations)
            }

        # Build human-readable recommendations.
        recommendations = [
            f"Consider optimizing {sp['pattern']} queries - "
            f"average duration {sp['avg_duration']:.2f}s"
            for sp in slow_patterns
        ]
        if len(backend_performance) > 1:
            fastest_backend = min(
                backend_performance,
                key=lambda b: backend_performance[b]['avg_duration']
            )
            recommendations.append(
                f"Consider routing more queries to {fastest_backend} "
                f"for better performance"
            )

        return {
            'slow_patterns': slow_patterns,
            'common_failures': common_failures,
            'backend_performance': backend_performance,
            'complexity_analysis': {},
            'recommendations': recommendations
        }

View file

@ -0,0 +1,656 @@
"""
Multi-language support for OntoRAG.
Provides language detection, translation, and multilingual query processing.
"""
import logging
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
logger = logging.getLogger(__name__)
class Language(Enum):
    """Supported languages.

    Values are two-letter ISO 639-1 codes, matching the codes returned by
    the detection/translation backends (see LanguageDetector/TextTranslator).
    """
    ENGLISH = "en"
    SPANISH = "es"
    FRENCH = "fr"
    GERMAN = "de"
    ITALIAN = "it"
    PORTUGUESE = "pt"
    CHINESE = "zh"
    JAPANESE = "ja"
    KOREAN = "ko"
    ARABIC = "ar"
    RUSSIAN = "ru"
    DUTCH = "nl"
@dataclass
class LanguageDetectionResult:
    """Language detection result."""
    language: Language   # best-guess language (detector's default on failure)
    confidence: float    # backend-dependent score; 0.0 for empty input / failure
    detected_text: str   # the text that was analysed, unchanged
    # Runner-up (language, probability) pairs when the backend provides them;
    # None otherwise.
    alternative_languages: Optional[List[Tuple[Language, float]]] = None
@dataclass
class TranslationResult:
    """Translation result."""
    original_text: str    # input text, unchanged
    translated_text: str  # equals original_text when translation was skipped or failed
    source_language: Language
    target_language: Language
    confidence: float     # heuristic confidence; 0.0 means no translation happened
class LanguageDetector:
    """Detects language of input text.

    Uses langdetect if installed, falling back to TextBlob, then to a simple
    keyword-based heuristic.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Initialize language detector.

        Args:
            config: Detector configuration (keys: 'default_language',
                'confidence_threshold')
        """
        self.config = config or {}
        self.default_language = Language(self.config.get('default_language', 'en'))
        self.confidence_threshold = self.config.get('confidence_threshold', 0.7)
        # Try to import language detection libraries
        self.detector = None
        self._init_detector()

    def _init_detector(self):
        """Initialize language detection backend (langdetect > TextBlob > rules)."""
        try:
            # Try langdetect first
            import langdetect
            self.detector = 'langdetect'
            logger.info("Using langdetect for language detection")
        except ImportError:
            try:
                # Try textblob as fallback
                from textblob import TextBlob
                self.detector = 'textblob'
                logger.info("Using TextBlob for language detection")
            except ImportError:
                logger.warning("No language detection library available, using rule-based detection")
                self.detector = 'rule_based'

    def detect_language(self, text: str) -> LanguageDetectionResult:
        """Detect language of input text.

        Args:
            text: Text to analyze

        Returns:
            Language detection result; the default language with confidence
            0.0 on empty input or backend failure.
        """
        if not text or not text.strip():
            return LanguageDetectionResult(
                language=self.default_language,
                confidence=0.0,
                detected_text=text
            )
        try:
            if self.detector == 'langdetect':
                return self._detect_with_langdetect(text)
            elif self.detector == 'textblob':
                return self._detect_with_textblob(text)
            else:
                return self._detect_with_rules(text)
        except Exception as e:
            logger.error(f"Language detection failed: {e}")
            return LanguageDetectionResult(
                language=self.default_language,
                confidence=0.0,
                detected_text=text
            )

    def _detect_with_langdetect(self, text: str) -> LanguageDetectionResult:
        """Detect language using langdetect library."""
        import langdetect
        from langdetect.lang_detect_exception import LangDetectException
        try:
            # Get detailed detection results
            probabilities = langdetect.detect_langs(text)
            if not probabilities:
                return LanguageDetectionResult(
                    language=self.default_language,
                    confidence=0.0,
                    detected_text=text
                )
            best_match = probabilities[0]
            detected_lang_code = best_match.lang
            confidence = best_match.prob
            # Map to our Language enum
            try:
                detected_language = Language(detected_lang_code)
            except ValueError:
                # Map common variations to the closest supported language.
                lang_mapping = {
                    'ca': Language.SPANISH,  # Catalan -> Spanish
                    'eu': Language.SPANISH,  # Basque -> Spanish
                    'gl': Language.SPANISH,  # Galician -> Spanish
                    'zh-cn': Language.CHINESE,
                    'zh-tw': Language.CHINESE,
                }
                detected_language = lang_mapping.get(detected_lang_code, self.default_language)
            # Get alternatives (next two candidates, where representable)
            alternatives = []
            for lang_prob in probabilities[1:3]:
                try:
                    alt_lang = Language(lang_prob.lang)
                    alternatives.append((alt_lang, lang_prob.prob))
                except ValueError:
                    continue
            return LanguageDetectionResult(
                language=detected_language,
                confidence=confidence,
                detected_text=text,
                alternative_languages=alternatives
            )
        except LangDetectException:
            return LanguageDetectionResult(
                language=self.default_language,
                confidence=0.0,
                detected_text=text
            )

    def _detect_with_textblob(self, text: str) -> LanguageDetectionResult:
        """Detect language using TextBlob."""
        from textblob import TextBlob
        try:
            blob = TextBlob(text)
            detected_lang_code = blob.detect_language()
            try:
                detected_language = Language(detected_lang_code)
            except ValueError:
                detected_language = self.default_language
            # TextBlob doesn't provide confidence, so estimate based on text length
            confidence = min(0.8, len(text) / 100.0) if len(text) > 10 else 0.5
            return LanguageDetectionResult(
                language=detected_language,
                confidence=confidence,
                detected_text=text
            )
        except Exception:
            return LanguageDetectionResult(
                language=self.default_language,
                confidence=0.0,
                detected_text=text
            )

    def _detect_with_rules(self, text: str) -> LanguageDetectionResult:
        """Rule-based language detection fallback (question-word matching)."""
        text_lower = text.lower()
        # Simple keyword-based detection.
        # BUG FIX: the French list previously contained an empty string
        # (mojibake of 'où'); since '' is a substring of every string, French
        # always scored at least 1 and won for any text with no other hits.
        language_keywords = {
            Language.SPANISH: ['qué', 'cuál', 'cuándo', 'dónde', 'cómo', 'por qué', 'cuántos'],
            Language.FRENCH: ['que', 'quel', 'quand', 'où', 'comment', 'pourquoi', 'combien'],
            Language.GERMAN: ['was', 'welche', 'wann', 'wo', 'wie', 'warum', 'wieviele'],
            Language.ITALIAN: ['che', 'quale', 'quando', 'dove', 'come', 'perché', 'quanti'],
            Language.PORTUGUESE: ['que', 'qual', 'quando', 'onde', 'como', 'por que', 'quantos'],
            Language.DUTCH: ['wat', 'welke', 'wanneer', 'waar', 'hoe', 'waarom', 'hoeveel']
        }
        best_match = self.default_language
        best_score = 0
        for language, keywords in language_keywords.items():
            score = sum(1 for keyword in keywords if keyword in text_lower)
            if score > best_score:
                best_score = score
                best_match = language
        # Three keyword hits are treated as strong evidence (capped at 0.8).
        confidence = min(0.8, best_score / 3.0) if best_score > 0 else 0.1
        return LanguageDetectionResult(
            language=best_match,
            confidence=confidence,
            detected_text=text
        )
class TextTranslator:
    """Translates text between languages.

    Backend selection is best-effort: googletrans if installed, then
    TextBlob, otherwise translation is a no-op that returns the input with
    confidence 0.0.
    NOTE(review): googletrans calls a remote service at translate time —
    confirm network access is acceptable for callers.
    """
    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Initialize text translator.

        Args:
            config: Translator configuration
        """
        self.config = config or {}
        # None until _init_translator finds a googletrans backend.
        self.translator = None
        self._init_translator()

    def _init_translator(self):
        """Initialize translation backend; sets self.backend (and, for
        googletrans, self.translator)."""
        try:
            # Try Google Translate first
            from googletrans import Translator
            self.translator = Translator()
            self.backend = 'googletrans'
            logger.info("Using Google Translate for translation")
        except ImportError:
            try:
                # Try TextBlob as fallback
                from textblob import TextBlob
                self.backend = 'textblob'
                logger.info("Using TextBlob for translation")
            except ImportError:
                logger.warning("No translation library available")
                self.backend = None

    def translate(self,
                  text: str,
                  target_language: Language,
                  source_language: Optional[Language] = None) -> TranslationResult:
        """Translate text to target language.

        Args:
            text: Text to translate
            target_language: Target language
            source_language: Source language (auto-detect if None)
        Returns:
            Translation result; on empty input, missing backend, or backend
            error, the original text is returned with confidence 0.0.
        """
        if not text or not text.strip():
            return TranslationResult(
                original_text=text,
                translated_text=text,
                source_language=source_language or Language.ENGLISH,
                target_language=target_language,
                confidence=0.0
            )
        try:
            if self.backend == 'googletrans':
                return self._translate_with_googletrans(text, target_language, source_language)
            elif self.backend == 'textblob':
                return self._translate_with_textblob(text, target_language, source_language)
            else:
                # No translation available
                return TranslationResult(
                    original_text=text,
                    translated_text=text,
                    source_language=source_language or Language.ENGLISH,
                    target_language=target_language,
                    confidence=0.0
                )
        except Exception as e:
            # Backend helpers re-raise; degrade to a pass-through result here.
            logger.error(f"Translation failed: {e}")
            return TranslationResult(
                original_text=text,
                translated_text=text,
                source_language=source_language or Language.ENGLISH,
                target_language=target_language,
                confidence=0.0
            )

    def _translate_with_googletrans(self,
                                    text: str,
                                    target_language: Language,
                                    source_language: Optional[Language]) -> TranslationResult:
        """Translate using Google Translate; re-raises on backend errors."""
        try:
            # 'auto' lets the service detect the source language itself.
            src_code = source_language.value if source_language else 'auto'
            dest_code = target_language.value
            result = self.translator.translate(text, src=src_code, dest=dest_code)
            detected_source = Language(result.src) if result.src != 'auto' else Language.ENGLISH
            confidence = 0.9  # Google Translate is generally reliable
            return TranslationResult(
                original_text=text,
                translated_text=result.text,
                source_language=detected_source,
                target_language=target_language,
                confidence=confidence
            )
        except Exception as e:
            logger.error(f"Google Translate error: {e}")
            raise

    def _translate_with_textblob(self,
                                 text: str,
                                 target_language: Language,
                                 source_language: Optional[Language]) -> TranslationResult:
        """Translate using TextBlob; re-raises on backend errors."""
        from textblob import TextBlob
        try:
            blob = TextBlob(text)
            if not source_language:
                # Auto-detect source language
                detected_lang = blob.detect_language()
                try:
                    source_language = Language(detected_lang)
                except ValueError:
                    source_language = Language.ENGLISH
            translated_blob = blob.translate(to=target_language.value)
            translated_text = str(translated_blob)
            # TextBlob confidence estimation (no score provided by the library)
            confidence = 0.7 if len(text) > 10 else 0.5
            return TranslationResult(
                original_text=text,
                translated_text=translated_text,
                source_language=source_language,
                target_language=target_language,
                confidence=confidence
            )
        except Exception as e:
            logger.error(f"TextBlob translation error: {e}")
            raise
class MultiLanguageQueryProcessor:
    """Processes queries in multiple languages."""

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Initialize multi-language query processor.

        Args:
            config: Processor configuration (keys: 'language_detection',
                'translation', 'supported_languages', 'primary_language')
        """
        self.config = config or {}
        # BUG FIX: settings are read from self.config — the previous code
        # called config.get(...) on the raw argument, which raised
        # AttributeError when the processor was constructed with config=None
        # despite None being the documented default.
        self.language_detector = LanguageDetector(self.config.get('language_detection', {}))
        self.translator = TextTranslator(self.config.get('translation', {}))
        self.supported_languages = [Language(lang) for lang in self.config.get('supported_languages', ['en'])]
        self.primary_language = Language(self.config.get('primary_language', 'en'))

    async def process_multilingual_query(self, question: str) -> Dict[str, Any]:
        """Process a query in any supported language.

        Args:
            question: Question in any language
        Returns:
            Processing result with language information
        """
        # Step 1: Detect language
        detection_result = self.language_detector.detect_language(question)
        detected_language = detection_result.language
        logger.info(f"Detected language: {detected_language.value} "
                   f"(confidence: {detection_result.confidence:.2f})")
        # Step 2: Translate to primary language if needed (only when the
        # detection is confident enough to trust).
        translated_question = question
        translation_result = None
        if detected_language != self.primary_language:
            if detection_result.confidence >= self.language_detector.confidence_threshold:
                translation_result = self.translator.translate(
                    question, self.primary_language, detected_language
                )
                translated_question = translation_result.translated_text
                logger.info(f"Translated question: {translated_question}")
            else:
                logger.warning(f"Low confidence language detection, processing in {self.primary_language.value}")
        # Step 3: Return processing information
        return {
            'original_question': question,
            'translated_question': translated_question,
            'detected_language': detected_language,
            'detection_confidence': detection_result.confidence,
            'translation_result': translation_result,
            'processing_language': self.primary_language,
            'alternative_languages': detection_result.alternative_languages
        }

    async def translate_answer(self,
                               answer: str,
                               target_language: Language) -> TranslationResult:
        """Translate answer back to target language.

        Args:
            answer: Answer in primary language
            target_language: Target language for answer
        Returns:
            Translation result (identity result with confidence 1.0 when the
            target is already the primary language)
        """
        if target_language == self.primary_language:
            # No translation needed
            return TranslationResult(
                original_text=answer,
                translated_text=answer,
                source_language=self.primary_language,
                target_language=target_language,
                confidence=1.0
            )
        return self.translator.translate(answer, target_language, self.primary_language)

    def get_language_specific_ontology_terms(self,
                                             ontology_subset: Dict[str, Any],
                                             language: Language) -> Dict[str, Any]:
        """Get language-specific terms from ontology.

        Args:
            ontology_subset: Ontology subset
            language: Target language
        Returns:
            Ontology subset with a 'language_labels' list added to each
            class/property definition. Plain-string labels are treated as
            language-agnostic and always included.
        """
        # Extract language-specific labels and descriptions
        lang_code = language.value
        result = {}

        def _labels_for(definition: Dict[str, Any]) -> List[str]:
            """Collect labels matching lang_code (or untagged string labels)."""
            matched = []
            for label in definition.get('labels', []):
                if isinstance(label, dict) and label.get('language') == lang_code:
                    matched.append(label['value'])
                elif isinstance(label, str):
                    matched.append(label)
            return matched

        # Process classes
        if 'classes' in ontology_subset:
            result['classes'] = {}
            for class_id, class_def in ontology_subset['classes'].items():
                result['classes'][class_id] = {
                    **class_def,
                    'language_labels': _labels_for(class_def)
                }
        # Process properties
        for prop_type in ['object_properties', 'datatype_properties']:
            if prop_type in ontology_subset:
                result[prop_type] = {}
                for prop_id, prop_def in ontology_subset[prop_type].items():
                    result[prop_type][prop_id] = {
                        **prop_def,
                        'language_labels': _labels_for(prop_def)
                    }
        return result

    def is_language_supported(self, language: Language) -> bool:
        """Check if language is supported.

        Args:
            language: Language to check
        Returns:
            True if language is supported
        """
        return language in self.supported_languages

    def get_supported_languages(self) -> List[Language]:
        """Get list of supported languages.

        Returns:
            Copy of the list of supported languages
        """
        return self.supported_languages.copy()

    def add_language_support(self, language: Language):
        """Add support for a new language (idempotent).

        Args:
            language: Language to add support for
        """
        if language not in self.supported_languages:
            self.supported_languages.append(language)
            logger.info(f"Added support for language: {language.value}")

    def remove_language_support(self, language: Language):
        """Remove support for a language (the primary language cannot be removed).

        Args:
            language: Language to remove support for
        """
        if language in self.supported_languages and language != self.primary_language:
            self.supported_languages.remove(language)
            logger.info(f"Removed support for language: {language.value}")
        else:
            logger.warning(f"Cannot remove primary language or unsupported language: {language.value}")
class LanguageSpecificTemplates:
    """Manages language-specific query and answer templates.

    Holds two lookup tables keyed by Language:
      * question patterns - trigger phrases used to classify a question
        as 'count', 'boolean', 'retrieval' or 'factual';
      * answer templates - localized format strings for rendering
        answers ('count', 'boolean_true', 'boolean_false', 'not_found',
        'error').
    English is the fallback for any unsupported language.
    """
    def __init__(self):
        """Initialize built-in templates for English, Spanish, French and German."""
        # Trigger phrases used to classify questions, per language.
        self.question_templates = {
            Language.ENGLISH: {
                'count': ['how many', 'count of', 'number of'],
                'boolean': ['is', 'are', 'does', 'can', 'will'],
                'retrieval': ['what', 'which', 'who', 'where'],
                'factual': ['tell me about', 'describe', 'explain']
            },
            Language.SPANISH: {
                'count': ['cuántos', 'cuántas', 'número de', 'cantidad de'],
                'boolean': ['es', 'son', 'está', 'están', 'puede', 'pueden'],
                'retrieval': ['qué', 'cuál', 'cuáles', 'quién', 'dónde'],
                'factual': ['dime sobre', 'describe', 'explica']
            },
            Language.FRENCH: {
                'count': ['combien', 'nombre de', 'quantité de'],
                'boolean': ['est', 'sont', 'peut', 'peuvent'],
                # Fix: the 'where' pattern was an empty string, which is a
                # substring of every question and would trivially match;
                # restored to 'où' to mirror 'where' / 'dónde' / 'wo' in
                # the other languages.
                'retrieval': ['que', 'quel', 'quelle', 'qui', 'où'],
                'factual': ['dis-moi sur', 'décris', 'explique']
            },
            Language.GERMAN: {
                'count': ['wie viele', 'anzahl der', 'zahl der'],
                'boolean': ['ist', 'sind', 'kann', 'können'],
                'retrieval': ['was', 'welche', 'wer', 'wo'],
                'factual': ['erzähl mir über', 'beschreibe', 'erkläre']
            }
        }
        # Localized str.format templates for rendering answers.
        self.answer_templates = {
            Language.ENGLISH: {
                'count': 'There are {count} {entity}.',
                'boolean_true': 'Yes, {statement}.',
                'boolean_false': 'No, {statement}.',
                'not_found': 'No information found.',
                'error': 'Sorry, I encountered an error.'
            },
            Language.SPANISH: {
                'count': 'Hay {count} {entity}.',
                'boolean_true': 'Sí, {statement}.',
                'boolean_false': 'No, {statement}.',
                'not_found': 'No se encontró información.',
                'error': 'Lo siento, encontré un error.'
            },
            Language.FRENCH: {
                'count': 'Il y a {count} {entity}.',
                'boolean_true': 'Oui, {statement}.',
                'boolean_false': 'Non, {statement}.',
                'not_found': 'Aucune information trouvée.',
                'error': 'Désolé, j\'ai rencontré une erreur.'
            },
            Language.GERMAN: {
                'count': 'Es gibt {count} {entity}.',
                'boolean_true': 'Ja, {statement}.',
                'boolean_false': 'Nein, {statement}.',
                'not_found': 'Keine Informationen gefunden.',
                'error': 'Entschuldigung, ich bin auf einen Fehler gestoßen.'
            }
        }
    def get_question_patterns(self, language: Language) -> Dict[str, List[str]]:
        """Get question patterns for a language.

        Args:
            language: Target language

        Returns:
            Dictionary of question patterns (English fallback when the
            language has no pattern table)
        """
        return self.question_templates.get(language, self.question_templates[Language.ENGLISH])
    def get_answer_template(self, language: Language, template_type: str) -> str:
        """Get answer template for a language and type.

        Args:
            language: Target language
            template_type: Template type

        Returns:
            Answer template string; falls back to English, then to the
            table's 'error' template, then to the literal 'Error'
        """
        templates = self.answer_templates.get(language, self.answer_templates[Language.ENGLISH])
        return templates.get(template_type, templates.get('error', 'Error'))
    def format_answer(self,
                     language: Language,
                     template_type: str,
                     **kwargs) -> str:
        """Format answer using language-specific template.

        Args:
            language: Target language
            template_type: Template type
            **kwargs: Template variables

        Returns:
            Formatted answer; if a required template variable is missing
            the localized 'error' template is returned instead.
        """
        template = self.get_answer_template(language, template_type)
        try:
            return template.format(**kwargs)
        except KeyError as e:
            logger.error(f"Missing template variable: {e}")
            return self.get_answer_template(language, 'error')

View file

@ -0,0 +1,256 @@
"""
Ontology matcher for query system.
Identifies relevant ontology subsets for answering questions.
"""
import logging
from typing import List, Dict, Any, Set, Optional
from dataclasses import dataclass
from ...extract.kg.ontology.ontology_loader import Ontology, OntologyLoader
from ...extract.kg.ontology.ontology_embedder import OntologyEmbedder
from ...extract.kg.ontology.text_processor import TextSegment
from ...extract.kg.ontology.ontology_selector import OntologySelector, OntologySubset
from .question_analyzer import QuestionComponents, QuestionType
logger = logging.getLogger(__name__)
@dataclass
class QueryOntologySubset(OntologySubset):
    """Extended ontology subset for query processing.

    Adds traversal properties and inference rules on top of the base
    similarity-selected subset.
    """
    # NOTE(review): both fields default to None, but the matcher's
    # _enhance_for_query always passes a dict/list; constructing this
    # class directly without them would break in-place updates — confirm
    # before relying on the defaults.
    traversal_properties: Optional[Dict[str, Any]] = None # Additional properties for graph traversal
    inference_rules: Optional[List[Dict[str, Any]]] = None # Inference rules for reasoning
class OntologyMatcherForQueries(OntologySelector):
    """
    Specialized ontology matcher for question answering.
    Extends OntologySelector with query-specific logic.

    On top of the parent's similarity-based subset selection, this class
    enriches each subset with traversal properties, inverse/sibling
    properties, and lightweight inference-rule descriptors tailored to
    the analyzed question type.
    """
    def __init__(self, ontology_embedder: OntologyEmbedder,
                ontology_loader: OntologyLoader,
                top_k: int = 15, # Higher k for queries
                similarity_threshold: float = 0.6): # Lower threshold for broader coverage
        """Initialize query-specific ontology matcher.
        Args:
            ontology_embedder: Embedder with vector store
            ontology_loader: Loader with ontology definitions
            top_k: Number of top results to retrieve
            similarity_threshold: Minimum similarity score
        """
        super().__init__(ontology_embedder, ontology_loader, top_k, similarity_threshold)
    async def match_question_to_ontology(self,
                                        question_components: QuestionComponents,
                                        question_segments: List[str]) -> List[QueryOntologySubset]:
        """Match question components to relevant ontology elements.
        Args:
            question_components: Analyzed question components
            question_segments: Text segments from question
        Returns:
            List of query-optimized ontology subsets
        """
        # Convert question segments to TextSegment objects
        # (position is the segment's index within the question)
        text_segments = [
            TextSegment(text=seg, type='question', position=i)
            for i, seg in enumerate(question_segments)
        ]
        # Get base ontology subsets using parent class method
        base_subsets = await self.select_ontology_subset(text_segments)
        # Enhance subsets for query processing
        query_subsets = []
        for subset in base_subsets:
            query_subset = self._enhance_for_query(subset, question_components)
            query_subsets.append(query_subset)
        return query_subsets
    def _enhance_for_query(self, subset: OntologySubset,
                          question_components: QuestionComponents) -> QueryOntologySubset:
        """Enhance ontology subset with query-specific elements.
        Args:
            subset: Base ontology subset
            question_components: Analyzed question components
        Returns:
            Enhanced query ontology subset
        """
        # Create query subset
        # (shallow copies: the class/property maps are new dicts, but the
        # definitions they hold are shared with the base subset)
        query_subset = QueryOntologySubset(
            ontology_id=subset.ontology_id,
            classes=dict(subset.classes),
            object_properties=dict(subset.object_properties),
            datatype_properties=dict(subset.datatype_properties),
            metadata=subset.metadata,
            relevance_score=subset.relevance_score,
            traversal_properties={},
            inference_rules=[]
        )
        # Add traversal properties based on question type
        self._add_traversal_properties(query_subset, question_components)
        # Add related properties for exploration
        self._add_related_properties(query_subset)
        # Add inference rules if needed
        self._add_inference_rules(query_subset, question_components)
        return query_subset
    def _add_traversal_properties(self, subset: QueryOntologySubset,
                                 question_components: QuestionComponents):
        """Add properties useful for graph traversal.

        Properties are stored as plain dicts (``prop_def.__dict__``) in
        ``subset.traversal_properties``, kept separate from the
        similarity-selected property maps.
        Args:
            subset: Query ontology subset to enhance
            question_components: Question analysis
        """
        ontology = self.loader.get_ontology(subset.ontology_id)
        if not ontology:
            # Ontology no longer loadable; nothing to add.
            return
        # For relationship questions, add all properties connecting mentioned classes
        if question_components.question_type == QuestionType.RELATIONSHIP:
            for prop_id, prop_def in ontology.object_properties.items():
                domain = prop_def.domain
                range_val = prop_def.range
                # Check if property connects relevant classes
                if domain in subset.classes or range_val in subset.classes:
                    if prop_id not in subset.object_properties:
                        subset.traversal_properties[prop_id] = prop_def.__dict__
                        logger.debug(f"Added traversal property: {prop_id}")
        # For retrieval questions, add properties that might filter results
        elif question_components.question_type == QuestionType.RETRIEVAL:
            # Add all properties with domains in our classes
            for prop_id, prop_def in ontology.object_properties.items():
                if prop_def.domain in subset.classes:
                    if prop_id not in subset.object_properties:
                        subset.traversal_properties[prop_id] = prop_def.__dict__
            for prop_id, prop_def in ontology.datatype_properties.items():
                if prop_def.domain in subset.classes:
                    if prop_id not in subset.datatype_properties:
                        subset.traversal_properties[prop_id] = prop_def.__dict__
        # For aggregation questions, ensure we have counting properties
        elif question_components.question_type == QuestionType.AGGREGATION:
            # Add properties that might be counted
            # (heuristic: property id contains 'count' or 'number')
            for prop_id, prop_def in ontology.datatype_properties.items():
                if 'count' in prop_id.lower() or 'number' in prop_id.lower():
                    if prop_id not in subset.datatype_properties:
                        subset.traversal_properties[prop_id] = prop_def.__dict__
    def _add_related_properties(self, subset: QueryOntologySubset):
        """Add properties related to already selected ones.

        Inverse properties go directly into ``subset.object_properties``;
        sibling properties (same domain) go into
        ``subset.traversal_properties``.
        Args:
            subset: Query ontology subset to enhance
        """
        ontology = self.loader.get_ontology(subset.ontology_id)
        if not ontology:
            return
        # Add inverse properties
        # (iterate over a snapshot of the keys since we mutate the dict)
        for prop_id in list(subset.object_properties.keys()):
            prop = ontology.object_properties.get(prop_id)
            if prop and prop.inverse_of:
                inverse_prop = ontology.object_properties.get(prop.inverse_of)
                if inverse_prop and prop.inverse_of not in subset.object_properties:
                    subset.object_properties[prop.inverse_of] = inverse_prop.__dict__
                    logger.debug(f"Added inverse property: {prop.inverse_of}")
        # Add sibling properties (same domain)
        # NOTE(review): the loop below reads subset.object_properties values
        # as dicts ('domain' in prop_def), while the similarity-selected
        # entries may be property objects rather than __dict__ copies —
        # confirm the value type stored by OntologySelector.
        domains_in_subset = set()
        for prop_def in subset.object_properties.values():
            if 'domain' in prop_def and prop_def['domain']:
                domains_in_subset.add(prop_def['domain'])
        for domain in domains_in_subset:
            for prop_id, prop_def in ontology.object_properties.items():
                if prop_def.domain == domain and prop_id not in subset.object_properties:
                    # Add up to 3 sibling properties
                    # (cap is on the total size of traversal_properties,
                    # not per domain)
                    if len(subset.traversal_properties) < 3:
                        subset.traversal_properties[prop_id] = prop_def.__dict__
    def _add_inference_rules(self, subset: QueryOntologySubset,
                            question_components: QuestionComponents):
        """Add inference rules for reasoning.

        Appends rule descriptors (dicts with 'type', 'property',
        'description', ...) to ``subset.inference_rules``; no reasoning
        is performed here.
        Args:
            subset: Query ontology subset to enhance
            question_components: Question analysis
        """
        # Add transitivity rules for subclass relationships
        # NOTE(review): assumes class/property definitions are dicts
        # (uses .get and item access) — see note in _add_related_properties.
        if any(cls.get('subclass_of') for cls in subset.classes.values()):
            subset.inference_rules.append({
                'type': 'transitivity',
                'property': 'rdfs:subClassOf',
                'description': 'Subclass relationships are transitive'
            })
        # Add symmetry rules for equivalent classes
        if any(cls.get('equivalent_classes') for cls in subset.classes.values()):
            subset.inference_rules.append({
                'type': 'symmetry',
                'property': 'owl:equivalentClass',
                'description': 'Equivalent class relationships are symmetric'
            })
        # Add inverse property rules
        for prop_id, prop_def in subset.object_properties.items():
            if 'inverse_of' in prop_def and prop_def['inverse_of']:
                subset.inference_rules.append({
                    'type': 'inverse',
                    'property': prop_id,
                    'inverse': prop_def['inverse_of'],
                    'description': f'{prop_id} is inverse of {prop_def["inverse_of"]}'
                })
    def expand_for_hierarchical_queries(self, subset: QueryOntologySubset) -> QueryOntologySubset:
        """Expand subset to include full class hierarchies.

        Mutates ``subset`` in place (and also returns it), adding all
        parents and direct children of the classes already selected.
        Args:
            subset: Query ontology subset
        Returns:
            Expanded subset with complete hierarchies
        """
        ontology = self.loader.get_ontology(subset.ontology_id)
        if not ontology:
            return subset
        # Add all parent and child classes
        classes_to_add = set()
        for class_id in list(subset.classes.keys()):
            # Add all parents
            parents = ontology.get_parent_classes(class_id)
            for parent_id in parents:
                if parent_id not in subset.classes:
                    parent_class = ontology.get_class(parent_id)
                    if parent_class:
                        classes_to_add.add(parent_id)
            # Add all children
            # (linear scan: any class whose direct parent is class_id)
            for other_class_id, other_class in ontology.classes.items():
                if other_class.subclass_of == class_id and other_class_id not in subset.classes:
                    classes_to_add.add(other_class_id)
        # Add collected classes
        for class_id in classes_to_add:
            cls = ontology.get_class(class_id)
            if cls:
                subset.classes[class_id] = cls.__dict__
        logger.debug(f"Expanded hierarchy: added {len(classes_to_add)} classes")
        return subset

View file

@ -0,0 +1,640 @@
"""
Query explanation system for OntoRAG.
Provides detailed explanations of how queries are processed and results are derived.
"""
import logging
from typing import Dict, Any, List, Optional, Union
from dataclasses import dataclass, field
from datetime import datetime
from .question_analyzer import QuestionComponents, QuestionType
from .ontology_matcher import QueryOntologySubset
from .sparql_generator import SPARQLQuery
from .cypher_generator import CypherQuery
from .sparql_cassandra import SPARQLResult
from .cypher_executor import CypherResult
logger = logging.getLogger(__name__)
@dataclass
class ExplanationStep:
    """Individual step in query explanation."""
    # 1-based position of this step in the processing pipeline.
    step_number: int
    # Name of the component that performed the step (e.g. "question_analyzer").
    component: str
    # Operation the component executed (e.g. "analyze_question").
    operation: str
    # Inputs the step received, summarized as plain values.
    input_data: Dict[str, Any]
    # Outputs the step produced, summarized as plain values.
    output_data: Dict[str, Any]
    # Human-readable description at the configured explanation level.
    explanation: str
    # Wall-clock duration in milliseconds (0.0 where not yet tracked).
    duration_ms: float
    # Whether the step completed successfully.
    success: bool
    # Optional extras (e.g. the generated query text).
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class QueryExplanation:
    """Complete explanation of query processing."""
    # Unique id for this query run (falls back to a timestamp-based id).
    query_id: str
    # The user's original question text.
    original_question: str
    # Ordered steps: analysis -> matching -> generation -> execution -> answer.
    processing_steps: List[ExplanationStep]
    # Final natural-language answer shown to the user.
    final_answer: str
    # Overall confidence in [0, 1].
    confidence_score: float
    # Total processing time in milliseconds.
    total_duration_ms: float
    # Ids of the ontologies consulted.
    ontologies_used: List[str]
    # Database backend that executed the query.
    backend_used: str
    # Natural-language statements explaining key decisions.
    reasoning_chain: List[str]
    # Raw technical metadata for debugging/optimization.
    technical_details: Dict[str, Any]
    # Short plain-English summary of the whole process.
    user_friendly_explanation: str
class QueryExplainer:
    """Generates explanations for query processing.

    Builds a step-by-step QueryExplanation for a processed question at
    one of three verbosity levels ('basic', 'detailed', 'technical'),
    assembles a reasoning chain and a confidence estimate, and renders
    the result as HTML, Markdown or plain text.
    """
    def __init__(self, config: Dict[str, Any] = None):
        """Initialize query explainer.
        Args:
            config: Explainer configuration; recognized keys are
                'explanation_level', 'include_technical_details' and
                'max_reasoning_steps'.
        """
        self.config = config or {}
        self.explanation_level = self.config.get('explanation_level', 'detailed') # basic, detailed, technical
        self.include_technical_details = self.config.get('include_technical_details', True)
        # Cap on the reasoning-chain length returned by _build_reasoning_chain.
        self.max_reasoning_steps = self.config.get('max_reasoning_steps', 10)
        # Templates for different explanation types
        # (one str.format template per pipeline step, per verbosity level)
        self.step_templates = {
            'question_analysis': {
                'basic': "I analyzed your question to understand what you're asking.",
                'detailed': "I analyzed your question '{question}' and identified it as a {question_type} query about {entities}.",
                'technical': "Question analysis: Type={question_type}, Entities={entities}, Keywords={keywords}, Expected answer={answer_type}"
            },
            'ontology_matching': {
                'basic': "I found relevant knowledge about {entities} in the available ontologies.",
                'detailed': "I searched through {ontology_count} ontologies and found {selected_elements} relevant concepts related to your question.",
                'technical': "Ontology matching: Selected {classes} classes, {properties} properties from {ontologies}"
            },
            'query_generation': {
                'basic': "I generated a query to search for the information.",
                'detailed': "I created a {query_type} query using {query_language} to search the {backend} database.",
                'technical': "Query generation: {query_language} query with {variables} variables, complexity score {complexity}"
            },
            'query_execution': {
                'basic': "I searched the database and found {result_count} results.",
                'detailed': "I executed the query against the {backend} database and retrieved {result_count} results in {duration}ms.",
                'technical': "Query execution: {backend} backend, {result_count} results, execution time {duration}ms"
            },
            'answer_generation': {
                'basic': "I generated a natural language answer from the results.",
                'detailed': "I processed {result_count} results and generated an answer with {confidence}% confidence.",
                'technical': "Answer generation: {result_count} input results, {generation_method} method, confidence {confidence}"
            }
        }
        # Sentence templates used to build the reasoning chain.
        self.reasoning_templates = {
            'entity_identification': "I identified '{entity}' as a key concept in your question.",
            'ontology_selection': "I selected the '{ontology}' ontology because it contains relevant information about {concepts}.",
            'query_strategy': "I chose a {strategy} query approach because {reason}.",
            'result_filtering': "I filtered the results to show only the most relevant {count} items.",
            'confidence_assessment': "I'm {confidence}% confident in this answer because {reasoning}."
        }
    def explain_query_processing(self,
                                question: str,
                                question_components: QuestionComponents,
                                ontology_subsets: List[QueryOntologySubset],
                                generated_query: Union[SPARQLQuery, CypherQuery],
                                query_results: Union[SPARQLResult, CypherResult],
                                final_answer: str,
                                processing_metadata: Dict[str, Any]) -> QueryExplanation:
        """Generate comprehensive explanation of query processing.
        Args:
            question: Original question
            question_components: Analyzed question components
            ontology_subsets: Selected ontology subsets
            generated_query: Generated query
            query_results: Query execution results
            final_answer: Final generated answer
            processing_metadata: Processing metadata
        Returns:
            Complete query explanation
        """
        query_id = processing_metadata.get('query_id', f"query_{datetime.now().timestamp()}")
        start_time = processing_metadata.get('start_time', datetime.now())
        # Build explanation steps
        steps = []
        step_number = 1
        # Step 1: Question Analysis
        steps.append(self._explain_question_analysis(
            step_number, question, question_components
        ))
        step_number += 1
        # Step 2: Ontology Matching
        steps.append(self._explain_ontology_matching(
            step_number, question_components, ontology_subsets
        ))
        step_number += 1
        # Step 3: Query Generation
        steps.append(self._explain_query_generation(
            step_number, generated_query, processing_metadata
        ))
        step_number += 1
        # Step 4: Query Execution
        steps.append(self._explain_query_execution(
            step_number, generated_query, query_results, processing_metadata
        ))
        step_number += 1
        # Step 5: Answer Generation
        steps.append(self._explain_answer_generation(
            step_number, query_results, final_answer, processing_metadata
        ))
        # Build reasoning chain
        reasoning_chain = self._build_reasoning_chain(
            question_components, ontology_subsets, generated_query, processing_metadata
        )
        # Calculate overall confidence
        confidence_score = self._calculate_explanation_confidence(
            question_components, query_results, processing_metadata
        )
        # Generate user-friendly explanation
        user_friendly_explanation = self._generate_user_friendly_explanation(
            question, question_components, ontology_subsets, final_answer
        )
        # Calculate total duration
        total_duration = processing_metadata.get('total_duration_ms', 0)
        # NOTE(review): ontologies_used reads the id from subset.metadata,
        # while QueryOntologySubset also carries an `ontology_id` attribute
        # directly — confirm metadata actually contains 'ontology_id',
        # otherwise this always reports 'unknown'.
        return QueryExplanation(
            query_id=query_id,
            original_question=question,
            processing_steps=steps,
            final_answer=final_answer,
            confidence_score=confidence_score,
            total_duration_ms=total_duration,
            ontologies_used=[subset.metadata.get('ontology_id', 'unknown') for subset in ontology_subsets],
            backend_used=processing_metadata.get('backend_used', 'unknown'),
            reasoning_chain=reasoning_chain,
            technical_details=self._extract_technical_details(processing_metadata),
            user_friendly_explanation=user_friendly_explanation
        )
    def _explain_question_analysis(self,
                                  step_number: int,
                                  question: str,
                                  question_components: QuestionComponents) -> ExplanationStep:
        """Explain question analysis step.

        Renders the step text at the configured verbosity level.
        """
        template = self.step_templates['question_analysis'][self.explanation_level]
        if self.explanation_level == 'basic':
            explanation = template
        elif self.explanation_level == 'detailed':
            explanation = template.format(
                question=question,
                question_type=question_components.question_type.value.replace('_', ' '),
                entities=', '.join(question_components.entities[:3])
            )
        else: # technical
            explanation = template.format(
                question_type=question_components.question_type.value,
                entities=question_components.entities,
                keywords=question_components.keywords,
                answer_type=question_components.expected_answer_type
            )
        return ExplanationStep(
            step_number=step_number,
            component="question_analyzer",
            operation="analyze_question",
            input_data={"question": question},
            output_data={
                "question_type": question_components.question_type.value,
                "entities": question_components.entities,
                "keywords": question_components.keywords
            },
            explanation=explanation,
            duration_ms=0.0, # Would be tracked in actual implementation
            success=True
        )
    def _explain_ontology_matching(self,
                                  step_number: int,
                                  question_components: QuestionComponents,
                                  ontology_subsets: List[QueryOntologySubset]) -> ExplanationStep:
        """Explain ontology matching step.

        Summarizes how many ontology elements (classes + properties)
        were selected across all subsets.
        """
        template = self.step_templates['ontology_matching'][self.explanation_level]
        total_elements = sum(
            len(subset.classes) + len(subset.object_properties) + len(subset.datatype_properties)
            for subset in ontology_subsets
        )
        if self.explanation_level == 'basic':
            explanation = template.format(
                entities=', '.join(question_components.entities[:3])
            )
        elif self.explanation_level == 'detailed':
            explanation = template.format(
                ontology_count=len(ontology_subsets),
                selected_elements=total_elements
            )
        else: # technical
            total_classes = sum(len(subset.classes) for subset in ontology_subsets)
            total_properties = sum(
                len(subset.object_properties) + len(subset.datatype_properties)
                for subset in ontology_subsets
            )
            # NOTE(review): same metadata.get('ontology_id') caveat as in
            # explain_query_processing.
            ontology_names = [subset.metadata.get('ontology_id', 'unknown') for subset in ontology_subsets]
            explanation = template.format(
                classes=total_classes,
                properties=total_properties,
                ontologies=', '.join(ontology_names)
            )
        return ExplanationStep(
            step_number=step_number,
            component="ontology_matcher",
            operation="select_relevant_subset",
            input_data={"entities": question_components.entities},
            output_data={
                "ontology_count": len(ontology_subsets),
                "total_elements": total_elements
            },
            explanation=explanation,
            duration_ms=0.0,
            success=True
        )
    def _explain_query_generation(self,
                                 step_number: int,
                                 generated_query: Union[SPARQLQuery, CypherQuery],
                                 metadata: Dict[str, Any]) -> ExplanationStep:
        """Explain query generation step.

        The query language is inferred from the query's concrete type.
        """
        template = self.step_templates['query_generation'][self.explanation_level]
        query_language = "SPARQL" if isinstance(generated_query, SPARQLQuery) else "Cypher"
        backend = metadata.get('backend_used', 'unknown')
        if self.explanation_level == 'basic':
            explanation = template
        elif self.explanation_level == 'detailed':
            explanation = template.format(
                query_type=generated_query.query_type,
                query_language=query_language,
                backend=backend
            )
        else: # technical
            explanation = template.format(
                query_language=query_language,
                variables=len(generated_query.variables),
                complexity=f"{generated_query.complexity_score:.2f}"
            )
        return ExplanationStep(
            step_number=step_number,
            component="query_generator",
            operation="generate_query",
            input_data={"query_type": generated_query.query_type},
            output_data={
                "query_language": query_language,
                "variables": generated_query.variables,
                "complexity": generated_query.complexity_score
            },
            explanation=explanation,
            duration_ms=0.0,
            success=True,
            metadata={"generated_query": generated_query.query}
        )
    def _explain_query_execution(self,
                                step_number: int,
                                generated_query: Union[SPARQLQuery, CypherQuery],
                                query_results: Union[SPARQLResult, CypherResult],
                                metadata: Dict[str, Any]) -> ExplanationStep:
        """Explain query execution step.

        Result count comes from SPARQL bindings or Cypher records
        depending on the result's concrete type.
        """
        template = self.step_templates['query_execution'][self.explanation_level]
        backend = metadata.get('backend_used', 'unknown')
        duration = getattr(query_results, 'execution_time', 0) * 1000 # Convert to ms
        if isinstance(query_results, SPARQLResult):
            result_count = len(query_results.bindings)
        else: # CypherResult
            result_count = len(query_results.records)
        if self.explanation_level == 'basic':
            explanation = template.format(result_count=result_count)
        elif self.explanation_level == 'detailed':
            explanation = template.format(
                backend=backend,
                result_count=result_count,
                duration=f"{duration:.1f}"
            )
        else: # technical
            explanation = template.format(
                backend=backend,
                result_count=result_count,
                duration=f"{duration:.1f}"
            )
        return ExplanationStep(
            step_number=step_number,
            component="query_executor",
            operation="execute_query",
            input_data={"query": generated_query.query},
            output_data={
                "result_count": result_count,
                "execution_time_ms": duration
            },
            explanation=explanation,
            duration_ms=duration,
            # NOTE(review): len() is never negative, so this is always True;
            # presumably intended to flag execution failures — confirm.
            success=result_count >= 0
        )
    def _explain_answer_generation(self,
                                  step_number: int,
                                  query_results: Union[SPARQLResult, CypherResult],
                                  final_answer: str,
                                  metadata: Dict[str, Any]) -> ExplanationStep:
        """Explain answer generation step.

        Confidence is read from metadata (default 0.8) and rendered as
        a percentage in the step text.
        """
        template = self.step_templates['answer_generation'][self.explanation_level]
        if isinstance(query_results, SPARQLResult):
            result_count = len(query_results.bindings)
        else: # CypherResult
            result_count = len(query_results.records)
        confidence = metadata.get('answer_confidence', 0.8) * 100 # Convert to percentage
        if self.explanation_level == 'basic':
            explanation = template
        elif self.explanation_level == 'detailed':
            explanation = template.format(
                result_count=result_count,
                confidence=f"{confidence:.0f}"
            )
        else: # technical
            generation_method = metadata.get('generation_method', 'template_based')
            explanation = template.format(
                result_count=result_count,
                generation_method=generation_method,
                confidence=f"{confidence:.1f}"
            )
        return ExplanationStep(
            step_number=step_number,
            component="answer_generator",
            operation="generate_answer",
            input_data={"result_count": result_count},
            output_data={
                "answer": final_answer,
                # Stored back on the 0-1 scale used elsewhere.
                "confidence": confidence / 100
            },
            explanation=explanation,
            duration_ms=0.0,
            success=bool(final_answer)
        )
    def _build_reasoning_chain(self,
                              question_components: QuestionComponents,
                              ontology_subsets: List[QueryOntologySubset],
                              generated_query: Union[SPARQLQuery, CypherQuery],
                              metadata: Dict[str, Any]) -> List[str]:
        """Build reasoning chain explaining the decision process.

        Returns at most ``self.max_reasoning_steps`` sentences covering
        entity identification, ontology selection, query strategy and a
        confidence assessment.
        """
        reasoning = []
        # Entity identification reasoning (first three entities only)
        if question_components.entities:
            for entity in question_components.entities[:3]:
                reasoning.append(
                    self.reasoning_templates['entity_identification'].format(entity=entity)
                )
        # Ontology selection reasoning (based on the first/primary subset)
        if ontology_subsets:
            primary_ontology = ontology_subsets[0]
            ontology_id = primary_ontology.metadata.get('ontology_id', 'primary')
            concepts = list(primary_ontology.classes.keys())[:3]
            reasoning.append(
                self.reasoning_templates['ontology_selection'].format(
                    ontology=ontology_id,
                    concepts=', '.join(concepts)
                )
            )
        # Query strategy reasoning
        # NOTE(review): query_language is computed but not used below —
        # possibly intended for the strategy sentence; confirm.
        query_language = "SPARQL" if isinstance(generated_query, SPARQLQuery) else "Cypher"
        if question_components.question_type == QuestionType.AGGREGATION:
            strategy = "aggregation"
            reason = "you asked for a count or sum"
        elif question_components.question_type == QuestionType.BOOLEAN:
            strategy = "boolean"
            reason = "you asked a yes/no question"
        else:
            strategy = "retrieval"
            reason = "you asked for specific information"
        reasoning.append(
            self.reasoning_templates['query_strategy'].format(
                strategy=strategy,
                reason=reason
            )
        )
        # Confidence assessment
        confidence = metadata.get('answer_confidence', 0.8) * 100
        if confidence > 90:
            confidence_reason = "the query matched well with available data"
        elif confidence > 70:
            confidence_reason = "the query found relevant information with some uncertainty"
        else:
            confidence_reason = "the available data partially matches your question"
        reasoning.append(
            self.reasoning_templates['confidence_assessment'].format(
                confidence=f"{confidence:.0f}",
                reasoning=confidence_reason
            )
        )
        return reasoning[:self.max_reasoning_steps]
    def _calculate_explanation_confidence(self,
                                         question_components: QuestionComponents,
                                         query_results: Union[SPARQLResult, CypherResult],
                                         metadata: Dict[str, Any]) -> float:
        """Calculate confidence score for the explanation.

        Starts from a 0.8 base and adds small bonuses for non-empty
        results, richer results, identified entities and overall
        processing success; capped at 1.0.
        """
        confidence = 0.8 # Base confidence
        # Adjust based on result count
        if isinstance(query_results, SPARQLResult):
            result_count = len(query_results.bindings)
        else:
            result_count = len(query_results.records)
        if result_count > 0:
            confidence += 0.1
        if result_count > 5:
            confidence += 0.05
        # Adjust based on question complexity
        if len(question_components.entities) > 0:
            confidence += 0.05
        # Adjust based on processing success
        if metadata.get('all_steps_successful', True):
            confidence += 0.05
        return min(confidence, 1.0)
    def _generate_user_friendly_explanation(self,
                                           question: str,
                                           question_components: QuestionComponents,
                                           ontology_subsets: List[QueryOntologySubset],
                                           final_answer: str) -> str:
        """Generate user-friendly explanation of the process.

        Produces a short numbered narrative ending with the answer.
        """
        explanation_parts = []
        # Introduction
        explanation_parts.append(f"To answer your question '{question}', I followed these steps:")
        # Process summary
        if question_components.question_type == QuestionType.AGGREGATION:
            explanation_parts.append("1. I recognized this as a counting or aggregation question")
        elif question_components.question_type == QuestionType.BOOLEAN:
            explanation_parts.append("1. I recognized this as a yes/no question")
        else:
            explanation_parts.append("1. I analyzed your question to understand what information you need")
        # Ontology usage
        if ontology_subsets:
            ontology_count = len(ontology_subsets)
            if ontology_count == 1:
                explanation_parts.append("2. I searched through the relevant knowledge base")
            else:
                explanation_parts.append(f"2. I searched through {ontology_count} knowledge bases")
        # Result processing
        explanation_parts.append("3. I found the relevant information and generated your answer")
        # Conclusion
        explanation_parts.append(f"The answer is: {final_answer}")
        return " ".join(explanation_parts)
    def _extract_technical_details(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
        """Extract technical details for debugging and optimization.

        Each key defaults to an empty dict when absent from metadata.
        """
        return {
            'query_optimization': metadata.get('query_optimization', {}),
            'backend_performance': metadata.get('backend_performance', {}),
            'cache_usage': metadata.get('cache_usage', {}),
            'error_handling': metadata.get('error_handling', {}),
            'routing_decision': metadata.get('routing_decision', {})
        }
    def format_explanation_for_display(self,
                                      explanation: QueryExplanation,
                                      format_type: str = 'html') -> str:
        """Format explanation for display.
        Args:
            explanation: Query explanation
            format_type: Output format ('html', 'markdown', 'text');
                anything else falls back to plain text
        Returns:
            Formatted explanation
        """
        if format_type == 'html':
            return self._format_html_explanation(explanation)
        elif format_type == 'markdown':
            return self._format_markdown_explanation(explanation)
        else:
            return self._format_text_explanation(explanation)
    def _format_html_explanation(self, explanation: QueryExplanation) -> str:
        """Format explanation as HTML.

        NOTE(review): interpolated values are not HTML-escaped; if the
        question/answer can contain markup, escape before rendering.
        """
        html_parts = [
            f"<h2>Query Explanation: {explanation.query_id}</h2>",
            f"<p><strong>Question:</strong> {explanation.original_question}</p>",
            f"<p><strong>Answer:</strong> {explanation.final_answer}</p>",
            f"<p><strong>Confidence:</strong> {explanation.confidence_score:.1%}</p>",
            "<h3>Processing Steps:</h3>",
            "<ol>"
        ]
        for step in explanation.processing_steps:
            html_parts.append(f"<li><strong>{step.component}</strong>: {step.explanation}</li>")
        html_parts.extend([
            "</ol>",
            "<h3>Reasoning:</h3>",
            "<ul>"
        ])
        for reasoning in explanation.reasoning_chain:
            html_parts.append(f"<li>{reasoning}</li>")
        html_parts.append("</ul>")
        return "".join(html_parts)
    def _format_markdown_explanation(self, explanation: QueryExplanation) -> str:
        """Format explanation as Markdown."""
        md_parts = [
            f"## Query Explanation: {explanation.query_id}",
            f"**Question:** {explanation.original_question}",
            f"**Answer:** {explanation.final_answer}",
            f"**Confidence:** {explanation.confidence_score:.1%}",
            "",
            "### Processing Steps:",
            ""
        ]
        for i, step in enumerate(explanation.processing_steps, 1):
            md_parts.append(f"{i}. **{step.component}**: {step.explanation}")
        md_parts.extend([
            "",
            "### Reasoning:",
            ""
        ])
        for reasoning in explanation.reasoning_chain:
            md_parts.append(f"- {reasoning}")
        return "\n".join(md_parts)
    def _format_text_explanation(self, explanation: QueryExplanation) -> str:
        """Format explanation as plain text."""
        text_parts = [
            f"Query Explanation: {explanation.query_id}",
            f"Question: {explanation.original_question}",
            f"Answer: {explanation.final_answer}",
            f"Confidence: {explanation.confidence_score:.1%}",
            "",
            "Processing Steps:",
        ]
        for i, step in enumerate(explanation.processing_steps, 1):
            text_parts.append(f"  {i}. {step.component}: {step.explanation}")
        text_parts.extend([
            "",
            "Reasoning:",
        ])
        for reasoning in explanation.reasoning_chain:
            text_parts.append(f"  - {reasoning}")
        return "\n".join(text_parts)

View file

@ -0,0 +1,519 @@
"""
Query optimization module for OntoRAG.
Optimizes SPARQL and Cypher queries for better performance and accuracy.
"""
import logging
from typing import Dict, Any, List, Optional, Union, Tuple
from dataclasses import dataclass
from enum import Enum
import re
from .question_analyzer import QuestionComponents, QuestionType
from .ontology_matcher import QueryOntologySubset
from .sparql_generator import SPARQLQuery
from .cypher_generator import CypherQuery
logger = logging.getLogger(__name__)
class OptimizationStrategy(Enum):
    """Query optimization strategies."""
    PERFORMANCE = "performance"  # favor fast execution (LIMITs, clause reordering)
    ACCURACY = "accuracy"        # favor precise results (DISTINCT, type constraints)
    BALANCED = "balanced"        # apply both the performance and accuracy passes
@dataclass
class OptimizationHint:
    """Optimization hint for query processing."""
    # Which optimization passes to run (see OptimizationStrategy).
    strategy: OptimizationStrategy
    # Cap on result rows; when set, a LIMIT clause is appended if missing.
    max_results: Optional[int] = None
    # Per-query timeout in seconds; None leaves the backend default.
    timeout_seconds: Optional[int] = None
    # Whether index hints may be suggested for the backend.
    use_indices: bool = True
    # Opt-in parallel execution (off by default).
    enable_parallel: bool = False
    # Whether results are eligible for caching (consulted by should_use_cache).
    cache_results: bool = True
@dataclass
class QueryPlan:
    """Query execution plan with optimization metadata."""
    # Query text before optimization.
    original_query: str
    # Query text after the optimization passes ran.
    optimized_query: str
    # Heuristic cost estimate in [0.0, 1.0] (see _estimate_*_cost).
    estimated_cost: float
    # Human-readable descriptions of the rewrites that were applied.
    optimization_notes: List[str]
    # Suggested index names for the backend (advisory only).
    index_hints: List[str]
    # Planned execution order; currently always left empty by the optimizer.
    execution_order: List[str]
class QueryOptimizer:
"""Optimizes SPARQL and Cypher queries for performance and accuracy."""
def __init__(self, config: Optional[Dict[str, Any]] = None):
    """Initialize query optimizer.

    Args:
        config: Optimizer configuration; all keys optional. Recognized keys:
            'default_strategy', 'max_query_complexity',
            'enable_query_rewriting', 'large_result_threshold',
            'complex_join_threshold'.
    """
    self.config = config or {}
    # Raises ValueError if 'default_strategy' is not a valid enum value.
    self.default_strategy = OptimizationStrategy(
        self.config.get('default_strategy', 'balanced')
    )
    self.max_query_complexity = self.config.get('max_query_complexity', 10)
    self.enable_query_rewriting = self.config.get('enable_query_rewriting', True)
    # Performance thresholds
    self.large_result_threshold = self.config.get('large_result_threshold', 1000)
    self.complex_join_threshold = self.config.get('complex_join_threshold', 3)
def optimize_sparql(self,
                    sparql_query: SPARQLQuery,
                    question_components: QuestionComponents,
                    ontology_subset: QueryOntologySubset,
                    optimization_hint: Optional[OptimizationHint] = None) -> Tuple[SPARQLQuery, QueryPlan]:
    """Optimize SPARQL query.

    Runs the performance and/or accuracy passes selected by the hint's
    strategy, estimates a heuristic cost, and returns both the optimized
    query and a QueryPlan describing what was done.

    Args:
        sparql_query: Original SPARQL query
        question_components: Question analysis
        ontology_subset: Ontology subset
        optimization_hint: Optimization hints; defaults to the optimizer's
            configured default strategy when omitted.

    Returns:
        Optimized SPARQL query and execution plan
    """
    hint = optimization_hint or OptimizationHint(strategy=self.default_strategy)
    optimized_query = sparql_query.query
    optimization_notes = []
    index_hints = []
    execution_order = []  # currently never populated; kept for the QueryPlan contract
    # Apply optimizations based on strategy (BALANCED runs both passes).
    if hint.strategy in [OptimizationStrategy.PERFORMANCE, OptimizationStrategy.BALANCED]:
        optimized_query, perf_notes, perf_hints = self._optimize_sparql_performance(
            optimized_query, question_components, ontology_subset, hint
        )
        optimization_notes.extend(perf_notes)
        index_hints.extend(perf_hints)
    if hint.strategy in [OptimizationStrategy.ACCURACY, OptimizationStrategy.BALANCED]:
        optimized_query, acc_notes = self._optimize_sparql_accuracy(
            optimized_query, question_components, ontology_subset
        )
        optimization_notes.extend(acc_notes)
    # Estimate query cost
    estimated_cost = self._estimate_sparql_cost(optimized_query, ontology_subset)
    # Build execution plan
    query_plan = QueryPlan(
        original_query=sparql_query.query,
        optimized_query=optimized_query,
        estimated_cost=estimated_cost,
        optimization_notes=optimization_notes,
        index_hints=index_hints,
        execution_order=execution_order
    )
    # Create optimized query object
    optimized_sparql = SPARQLQuery(
        query=optimized_query,
        variables=sparql_query.variables,
        query_type=sparql_query.query_type,
        explanation=f"Optimized: {sparql_query.explanation}",
        complexity_score=min(sparql_query.complexity_score * 0.8, 1.0) # Assume optimization reduces complexity
    )
    return optimized_sparql, query_plan
def optimize_cypher(self,
                    cypher_query: CypherQuery,
                    question_components: QuestionComponents,
                    ontology_subset: QueryOntologySubset,
                    optimization_hint: Optional[OptimizationHint] = None) -> Tuple[CypherQuery, QueryPlan]:
    """Optimize Cypher query.

    Mirrors optimize_sparql: runs the strategy-selected passes, estimates
    cost, and returns the optimized query plus its QueryPlan.

    Args:
        cypher_query: Original Cypher query
        question_components: Question analysis
        ontology_subset: Ontology subset
        optimization_hint: Optimization hints; defaults to the optimizer's
            configured default strategy when omitted.

    Returns:
        Optimized Cypher query and execution plan
    """
    hint = optimization_hint or OptimizationHint(strategy=self.default_strategy)
    optimized_query = cypher_query.query
    optimization_notes = []
    index_hints = []
    execution_order = []  # currently never populated; kept for the QueryPlan contract
    # Apply optimizations based on strategy (BALANCED runs both passes).
    if hint.strategy in [OptimizationStrategy.PERFORMANCE, OptimizationStrategy.BALANCED]:
        optimized_query, perf_notes, perf_hints = self._optimize_cypher_performance(
            optimized_query, question_components, ontology_subset, hint
        )
        optimization_notes.extend(perf_notes)
        index_hints.extend(perf_hints)
    if hint.strategy in [OptimizationStrategy.ACCURACY, OptimizationStrategy.BALANCED]:
        optimized_query, acc_notes = self._optimize_cypher_accuracy(
            optimized_query, question_components, ontology_subset
        )
        optimization_notes.extend(acc_notes)
    # Estimate query cost
    estimated_cost = self._estimate_cypher_cost(optimized_query, ontology_subset)
    # Build execution plan
    query_plan = QueryPlan(
        original_query=cypher_query.query,
        optimized_query=optimized_query,
        estimated_cost=estimated_cost,
        optimization_notes=optimization_notes,
        index_hints=index_hints,
        execution_order=execution_order
    )
    # Create optimized query object
    optimized_cypher = CypherQuery(
        query=optimized_query,
        variables=cypher_query.variables,
        query_type=cypher_query.query_type,
        explanation=f"Optimized: {cypher_query.explanation}",
        complexity_score=min(cypher_query.complexity_score * 0.8, 1.0)  # assume optimization reduces complexity
    )
    return optimized_cypher, query_plan
def _optimize_sparql_performance(self,
                                 query: str,
                                 question_components: QuestionComponents,
                                 ontology_subset: QueryOntologySubset,
                                 hint: OptimizationHint) -> Tuple[str, List[str], List[str]]:
    """Apply performance optimizations to SPARQL query.

    Args:
        query: SPARQL query string
        question_components: Question analysis
        ontology_subset: Ontology subset
        hint: Optimization hints

    Returns:
        Optimized query, optimization notes, and index hints
    """
    optimized = query
    notes = []
    index_hints = []
    # Add LIMIT if not present and large results expected
    if hint.max_results and 'LIMIT' not in optimized.upper():
        optimized = f"{optimized.rstrip()}\nLIMIT {hint.max_results}"
        notes.append(f"Added LIMIT {hint.max_results} to prevent large result sets")
    # Optimize OPTIONAL clauses (move to end)
    # NOTE(review): [^}]+ stops at the first '}', so OPTIONAL groups with
    # nested braces are only partially matched — confirm inputs are flat.
    optional_pattern = re.compile(r'OPTIONAL\s*\{[^}]+\}', re.IGNORECASE | re.DOTALL)
    optionals = optional_pattern.findall(optimized)
    if optionals:
        # Remove optionals from current position
        for optional in optionals:
            optimized = optimized.replace(optional, '')
        # Add them at the end (before ORDER BY/LIMIT)
        insert_point = optimized.rfind('ORDER BY')
        if insert_point == -1:
            insert_point = optimized.rfind('LIMIT')
        if insert_point == -1:
            insert_point = len(optimized.rstrip())
        # NOTE(review): each clause is inserted at the same fixed index, so the
        # re-attached OPTIONALs end up in reverse source order — confirm that
        # ordering is acceptable.
        for optional in optionals:
            optimized = optimized[:insert_point] + f"\n {optional}" + optimized[insert_point:]
        notes.append("Moved OPTIONAL clauses to end for better performance")
    # Add index hints for Cassandra
    if 'WHERE' in optimized.upper():
        # Suggest indices for common patterns
        if '?subject rdf:type' in optimized:
            index_hints.append("type_index")
        if 'rdfs:subClassOf' in optimized:
            index_hints.append("hierarchy_index")
    # Optimize FILTER clauses (move closer to variable bindings)
    filter_pattern = re.compile(r'FILTER\s*\([^)]+\)', re.IGNORECASE)
    filters = filter_pattern.findall(optimized)
    if filters:
        # Advisory only — no rewrite is performed for FILTERs.
        notes.append("FILTER clauses present - ensure they're positioned optimally")
    return optimized, notes, index_hints
def _optimize_sparql_accuracy(self,
                              query: str,
                              question_components: QuestionComponents,
                              ontology_subset: QueryOntologySubset) -> Tuple[str, List[str]]:
    """Apply accuracy optimizations to SPARQL query.

    Args:
        query: SPARQL query string
        question_components: Question analysis
        ontology_subset: Ontology subset

    Returns:
        Optimized query and optimization notes
    """
    optimized = query
    notes = []
    # Add missing namespace checks
    if question_components.question_type == QuestionType.RETRIEVAL:
        # Ensure we're not mixing namespaces inappropriately
        if 'http://' in optimized and '?' in optimized:
            notes.append("Verified namespace consistency for accuracy")
    # Add type constraints for better precision
    if '?entity' in optimized and 'rdf:type' not in optimized:
        # Find a good insertion point
        where_clause = re.search(r'WHERE\s*\{(.+)\}', optimized, re.DOTALL | re.IGNORECASE)
        if where_clause and ontology_subset.classes:
            # Constrain ?entity to the most relevant class in the subset.
            main_class = list(ontology_subset.classes.keys())[0]
            type_constraint = f"\n ?entity rdf:type :{main_class} ."
            # Insert after the WHERE {
            where_start = where_clause.start(1)
            optimized = optimized[:where_start] + type_constraint + optimized[where_start:]
            notes.append(f"Added type constraint for {main_class} to improve accuracy")
    # Add DISTINCT if not present for retrieval queries.
    # BUG FIX: the presence check is case-insensitive ('SELECT' in upper()),
    # but the rewrite previously used the case-sensitive
    # str.replace('SELECT ', ...), so a lower-case 'select' was never
    # rewritten. re.sub with IGNORECASE matches the Cypher counterpart.
    if (question_components.question_type == QuestionType.RETRIEVAL and
            'DISTINCT' not in optimized.upper() and
            'SELECT' in optimized.upper()):
        optimized = re.sub(r'SELECT\s+', 'SELECT DISTINCT ', optimized, count=1, flags=re.IGNORECASE)
        notes.append("Added DISTINCT to eliminate duplicate results")
    return optimized, notes
def _optimize_cypher_performance(self,
                                 query: str,
                                 question_components: QuestionComponents,
                                 ontology_subset: QueryOntologySubset,
                                 hint: OptimizationHint) -> Tuple[str, List[str], List[str]]:
    """Apply performance optimizations to Cypher query.

    Only the LIMIT rewrite changes the query text; everything else emits
    advisory notes or index-hint suggestions.

    Args:
        query: Cypher query string
        question_components: Question analysis
        ontology_subset: Ontology subset
        hint: Optimization hints

    Returns:
        Optimized query, optimization notes, and index hints
    """
    optimized = query
    notes = []
    index_hints = []
    # Add LIMIT if not present
    if hint.max_results and 'LIMIT' not in optimized.upper():
        optimized = f"{optimized.rstrip()}\nLIMIT {hint.max_results}"
        notes.append(f"Added LIMIT {hint.max_results} to prevent large result sets")
    # Use parameters for literals to enable query plan caching
    if "'" in optimized or '"' in optimized:
        notes.append("Consider using parameters for literal values to enable query plan caching")
    # Suggest indices based on query patterns
    if 'MATCH (n:' in optimized:
        # Only the first `MATCH (n:Label)` is inspected for a label hint.
        label_match = re.search(r'MATCH \(n:(\w+)\)', optimized)
        if label_match:
            label = label_match.group(1)
            index_hints.append(f"node_label_index:{label}")
    if 'WHERE' in optimized.upper() and '.' in optimized:
        # Property access patterns
        property_pattern = re.compile(r'\.(\w+)', re.IGNORECASE)
        properties = property_pattern.findall(optimized)
        for prop in set(properties):
            index_hints.append(f"property_index:{prop}")
    # Optimize relationship traversals
    if '-[' in optimized and '*' in optimized:
        notes.append("Variable length path detected - consider adding relationship type filters")
    # Early filtering optimization
    if 'WHERE' in optimized.upper():
        # Move WHERE clauses closer to MATCH clauses
        notes.append("WHERE clauses present - ensure early filtering for performance")
    return optimized, notes, index_hints
def _optimize_cypher_accuracy(self,
                              query: str,
                              question_components: QuestionComponents,
                              ontology_subset: QueryOntologySubset) -> Tuple[str, List[str]]:
    """Apply accuracy-oriented rewrites to a Cypher query.

    Only the DISTINCT insertion modifies the query; the remaining checks
    just record advisory notes.

    Args:
        query: Cypher query string
        question_components: Question analysis
        ontology_subset: Ontology subset

    Returns:
        Rewritten query and list of optimization notes
    """
    notes = []
    rewritten = query
    is_retrieval = question_components.question_type == QuestionType.RETRIEVAL
    # Deduplicate retrieval results unless DISTINCT is already present.
    if is_retrieval and 'DISTINCT' not in rewritten.upper() and 'RETURN' in rewritten.upper():
        rewritten = re.sub(r'RETURN\s+', 'RETURN DISTINCT ', rewritten, count=1, flags=re.IGNORECASE)
        notes.append("Added DISTINCT to eliminate duplicate results")
    # Relationship patterns combined with extracted relationships: advisory only.
    if '-[' in rewritten and question_components.relationships:
        notes.append("Verified relationship directions for semantic accuracy")
    # Optional/possibly-null property access: advisory only.
    if '?' in rewritten or 'OPTIONAL' in rewritten.upper():
        notes.append("Consider adding null checks for optional properties")
    return rewritten, notes
def _estimate_sparql_cost(self, query: str, ontology_subset: QueryOntologySubset) -> float:
"""Estimate execution cost for SPARQL query.
Args:
query: SPARQL query string
ontology_subset: Ontology subset
Returns:
Estimated cost (0.0 to 1.0)
"""
cost = 0.0
# Basic query complexity
cost += len(query.split('\n')) * 0.01
# Join complexity
triple_patterns = len(re.findall(r'\?\w+\s+\?\w+\s+\?\w+', query))
cost += triple_patterns * 0.1
# OPTIONAL clauses
optional_count = len(re.findall(r'OPTIONAL', query, re.IGNORECASE))
cost += optional_count * 0.15
# FILTER clauses
filter_count = len(re.findall(r'FILTER', query, re.IGNORECASE))
cost += filter_count * 0.1
# Property paths
path_count = len(re.findall(r'\*|\+', query))
cost += path_count * 0.2
# Ontology subset size impact
total_elements = (len(ontology_subset.classes) +
len(ontology_subset.object_properties) +
len(ontology_subset.datatype_properties))
cost += (total_elements / 100.0) * 0.1
return min(cost, 1.0)
def _estimate_cypher_cost(self, query: str, ontology_subset: QueryOntologySubset) -> float:
"""Estimate execution cost for Cypher query.
Args:
query: Cypher query string
ontology_subset: Ontology subset
Returns:
Estimated cost (0.0 to 1.0)
"""
cost = 0.0
# Basic query complexity
cost += len(query.split('\n')) * 0.01
# Pattern complexity
match_count = len(re.findall(r'MATCH', query, re.IGNORECASE))
cost += match_count * 0.1
# Relationship traversals
rel_count = len(re.findall(r'-\[.*?\]-', query))
cost += rel_count * 0.1
# Variable length paths
var_path_count = len(re.findall(r'\*\d*\.\.', query))
cost += var_path_count * 0.3
# WHERE clauses
where_count = len(re.findall(r'WHERE', query, re.IGNORECASE))
cost += where_count * 0.05
# Aggregation functions
agg_count = len(re.findall(r'COUNT|SUM|AVG|MIN|MAX', query, re.IGNORECASE))
cost += agg_count * 0.1
# Ontology subset size impact
total_elements = (len(ontology_subset.classes) +
len(ontology_subset.object_properties) +
len(ontology_subset.datatype_properties))
cost += (total_elements / 100.0) * 0.1
return min(cost, 1.0)
def should_use_cache(self,
                     query: str,
                     question_components: QuestionComponents,
                     optimization_hint: OptimizationHint) -> bool:
    """Determine if query results should be cached.

    Args:
        query: Query string
        question_components: Question analysis
        optimization_hint: Optimization hints

    Returns:
        True if results should be cached
    """
    if not optimization_hint.cache_results:
        return False
    # Don't cache real-time or time-sensitive queries.
    # BUG FIX: this guard used to sit after the early `return True` branches
    # and immediately before an unconditional `return False`, which made it
    # dead code — e.g. a retrieval question containing "latest" was cached
    # anyway. Checking freshness first restores the documented intent.
    if any(keyword in question_components.original_question.lower()
           for keyword in ['now', 'current', 'latest', 'recent']):
        return False
    # Cache simple retrieval and factual queries
    if question_components.question_type in [QuestionType.RETRIEVAL, QuestionType.FACTUAL]:
        return True
    # Cache expensive aggregation queries
    if (question_components.question_type == QuestionType.AGGREGATION and
            ('COUNT' in query.upper() or 'SUM' in query.upper())):
        return True
    return False
def get_cache_key(self,
                  query: str,
                  ontology_subset: QueryOntologySubset) -> str:
    """Generate cache key for query.

    Args:
        query: Query string
        ontology_subset: Ontology subset

    Returns:
        Cache key string (MD5 hex digest)
    """
    import hashlib
    # Sort class/property names so equal subsets always hash identically,
    # regardless of dict insertion order.
    class_names = sorted(ontology_subset.classes.keys())
    property_names = sorted(ontology_subset.object_properties.keys())
    ontology_repr = f"{class_names}-{property_names}"
    combined = f"{query.strip()}-{ontology_repr}"
    return hashlib.md5(combined.encode()).hexdigest()

View file

@ -0,0 +1,438 @@
"""
Main OntoRAG query service.
Orchestrates question analysis, ontology matching, query generation, execution, and answer generation.
"""
import logging
from typing import Dict, Any, List, Optional, Union
from dataclasses import dataclass
from datetime import datetime
from ....flow.flow_processor import FlowProcessor
from ....tables.config import ConfigTableStore
from ...extract.kg.ontology.ontology_loader import OntologyLoader
from ...extract.kg.ontology.vector_store import InMemoryVectorStore
from .question_analyzer import QuestionAnalyzer, QuestionComponents
from .ontology_matcher import OntologyMatcher, QueryOntologySubset
from .backend_router import BackendRouter, QueryRoute, BackendType
from .sparql_generator import SPARQLGenerator, SPARQLQuery
from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult
from .cypher_generator import CypherGenerator, CypherQuery
from .cypher_executor import CypherExecutor, CypherResult
from .answer_generator import AnswerGenerator, GeneratedAnswer
logger = logging.getLogger(__name__)
@dataclass
class QueryRequest:
    """Query request from user."""
    # Natural-language question to answer.
    question: str
    # Optional conversational context passed to question analysis.
    context: Optional[str] = None
    # When set, restrict ontology loading to this single ontology id.
    ontology_hint: Optional[str] = None
    # Requested cap on result rows.
    max_results: int = 10
    # Minimum acceptable answer confidence.
    # NOTE(review): not referenced anywhere in the visible service code —
    # confirm it is actually enforced downstream.
    confidence_threshold: float = 0.7
@dataclass
class QueryResponse:
    """Complete query response."""
    # Natural-language answer text.
    answer: str
    # Combined confidence (min of routing and answer-generation confidence).
    confidence: float
    # Wall-clock processing time in seconds.
    execution_time: float
    # Analysis of the input question.
    question_analysis: QuestionComponents
    # Ontology subsets judged relevant to the question.
    ontology_subsets: List[QueryOntologySubset]
    # Backend routing decision.
    query_route: QueryRoute
    # The SPARQL or Cypher query that was executed.
    generated_query: Union[SPARQLQuery, CypherQuery]
    # Raw backend results before answer generation.
    raw_results: Union[SPARQLResult, CypherResult]
    # Facts cited in support of the answer.
    supporting_facts: List[str]
    # Miscellaneous processing metadata (backend, timings, routing reasoning).
    metadata: Dict[str, Any]
class OntoRAGQueryService(FlowProcessor):
"""Main OntoRAG query service orchestrating all components."""
def __init__(self, config: Dict[str, Any]):
    """Initialize OntoRAG query service.

    Only stores configuration and placeholders here; the actual component
    construction happens in the async init() method.

    Args:
        config: Service configuration
    """
    super().__init__(config)
    self.config = config
    # Initialize components (all wired up in init()).
    self.config_store = None
    self.ontology_loader = None
    self.vector_store = None
    self.question_analyzer = None
    self.ontology_matcher = None
    self.backend_router = None
    self.sparql_generator = None
    self.sparql_engine = None
    self.cypher_generator = None
    self.cypher_executor = None
    self.answer_generator = None
    # Cache for loaded ontologies
    # NOTE(review): declared but never read or written in the visible code —
    # confirm whether ontology caching was intended to use this.
    self.ontology_cache = {}
async def init(self):
    """Initialize all components.

    Builds the full pipeline: config store, ontology loader, vector store,
    question analyzer, ontology matcher, backend router, query generators,
    backend executors (only for enabled backends), and answer generator.
    """
    await super().init()
    # Initialize configuration store
    self.config_store = ConfigTableStore(self.config.get('config_store', {}))
    # Initialize ontology components
    self.ontology_loader = OntologyLoader(self.config_store)
    # Initialize vector store
    vector_config = self.config.get('vector_store', {})
    self.vector_store = InMemoryVectorStore.create(
        store_type=vector_config.get('type', 'numpy'),
        dimension=vector_config.get('dimension', 384),
        similarity_threshold=vector_config.get('similarity_threshold', 0.7)
    )
    # Initialize question analyzer
    # NOTE(review): the QuestionAnalyzer defined in question_analyzer.py takes
    # no constructor arguments — confirm which analyzer implementation this
    # prompt_service-based call targets.
    analyzer_config = self.config.get('question_analyzer', {})
    self.question_analyzer = QuestionAnalyzer(
        prompt_service=self.prompt_service,
        config=analyzer_config
    )
    # Initialize ontology matcher
    matcher_config = self.config.get('ontology_matcher', {})
    self.ontology_matcher = OntologyMatcher(
        vector_store=self.vector_store,
        embedding_service=self.embedding_service,
        config=matcher_config
    )
    # Initialize backend router
    router_config = self.config.get('backend_router', {})
    self.backend_router = BackendRouter(router_config)
    # Initialize query generators
    self.sparql_generator = SPARQLGenerator(prompt_service=self.prompt_service)
    self.cypher_generator = CypherGenerator(prompt_service=self.prompt_service)
    # Initialize executors
    sparql_config = self.config.get('sparql_executor', {})
    if self.backend_router.is_backend_enabled(BackendType.CASSANDRA):
        cassandra_config = self.backend_router.get_backend_config(BackendType.CASSANDRA)
        if cassandra_config:
            self.sparql_engine = SPARQLCassandraEngine(cassandra_config)
            await self.sparql_engine.initialize()
    cypher_config = self.config.get('cypher_executor', {})
    # A single CypherExecutor serves every enabled graph backend.
    enabled_graph_backends = [
        bt for bt in [BackendType.NEO4J, BackendType.MEMGRAPH, BackendType.FALKORDB]
        if self.backend_router.is_backend_enabled(bt)
    ]
    if enabled_graph_backends:
        self.cypher_executor = CypherExecutor(cypher_config)
        await self.cypher_executor.initialize()
    # Initialize answer generator
    self.answer_generator = AnswerGenerator(prompt_service=self.prompt_service)
    logger.info("OntoRAG query service initialized")
async def process(self, request: QueryRequest) -> QueryResponse:
    """Process a natural language query.

    Pipeline: analyze question -> load/match ontologies -> route to a
    backend -> generate+execute a SPARQL or Cypher query -> generate a
    natural-language answer. Any exception is converted into an error
    QueryResponse with confidence 0.0 rather than propagated.

    Args:
        request: Query request

    Returns:
        Complete query response
    """
    start_time = datetime.now()
    try:
        logger.info(f"Processing query: {request.question}")
        # Step 1: Analyze question
        # NOTE(review): the QuestionAnalyzer in question_analyzer.py exposes
        # a sync analyze(question) method, not analyze_question(...) —
        # confirm this matches the analyzer actually injected.
        question_components = await self.question_analyzer.analyze_question(
            request.question, context=request.context
        )
        logger.debug(f"Question analysis: {question_components.question_type}")
        # Step 2: Load and match ontologies
        ontology_subsets = await self._load_and_match_ontologies(
            question_components, request.ontology_hint
        )
        logger.debug(f"Found {len(ontology_subsets)} relevant ontology subsets")
        # Step 3: Route to appropriate backend
        query_route = self.backend_router.route_query(
            question_components, ontology_subsets
        )
        logger.debug(f"Routed to {query_route.backend_type.value} backend")
        # Step 4: Generate and execute query
        if query_route.query_language == 'sparql':
            query_results = await self._execute_sparql_path(
                question_components, ontology_subsets, query_route
            )
        else: # cypher
            query_results = await self._execute_cypher_path(
                question_components, ontology_subsets, query_route
            )
        # Step 5: Generate natural language answer
        generated_answer = await self.answer_generator.generate_answer(
            question_components,
            query_results['raw_results'],
            ontology_subsets[0] if ontology_subsets else None,
            query_route.backend_type.value
        )
        # Build response
        execution_time = (datetime.now() - start_time).total_seconds()
        response = QueryResponse(
            answer=generated_answer.answer,
            # Overall confidence is the weaker of routing and generation.
            confidence=min(query_route.confidence, generated_answer.metadata.confidence),
            execution_time=execution_time,
            question_analysis=question_components,
            ontology_subsets=ontology_subsets,
            query_route=query_route,
            generated_query=query_results['generated_query'],
            raw_results=query_results['raw_results'],
            supporting_facts=generated_answer.supporting_facts,
            metadata={
                'backend_used': query_route.backend_type.value,
                'query_language': query_route.query_language,
                'ontology_count': len(ontology_subsets),
                'result_count': generated_answer.metadata.result_count,
                'routing_reasoning': query_route.reasoning,
                'generation_time': generated_answer.generation_time
            }
        )
        logger.info(f"Query processed successfully in {execution_time:.2f}s")
        return response
    except Exception as e:
        logger.error(f"Query processing failed: {e}")
        execution_time = (datetime.now() - start_time).total_seconds()
        # Return error response
        # NOTE(review): the QuestionComponents dataclass in
        # question_analyzer.py has no `normalized_question` field, so this
        # constructor call would raise TypeError inside the error handler —
        # verify against the imported definition.
        return QueryResponse(
            answer=f"I encountered an error processing your query: {str(e)}",
            confidence=0.0,
            execution_time=execution_time,
            question_analysis=QuestionComponents(
                original_question=request.question,
                normalized_question=request.question,
                question_type=None,
                entities=[], keywords=[], relationships=[], constraints=[],
                aggregations=[], expected_answer_type="unknown"
            ),
            ontology_subsets=[],
            query_route=None,
            generated_query=None,
            raw_results=None,
            supporting_facts=[],
            metadata={'error': str(e), 'execution_time': execution_time}
        )
async def _load_and_match_ontologies(self,
                                     question_components: QuestionComponents,
                                     ontology_hint: Optional[str] = None) -> List[QueryOntologySubset]:
    """Load ontologies and find relevant subsets.

    Failures are logged and degrade to an empty result rather than raising,
    so the caller can continue with no ontology context.

    Args:
        question_components: Analyzed question
        ontology_hint: Optional ontology hint

    Returns:
        List of relevant ontology subsets (possibly empty)
    """
    try:
        # Load available ontologies
        if ontology_hint:
            # Load specific ontology
            ontologies = [await self.ontology_loader.load_ontology(ontology_hint)]
        else:
            # Load all available ontologies
            available_ontologies = await self.ontology_loader.list_available_ontologies()
            ontologies = []
            for ontology_id in available_ontologies[:5]: # Limit to 5 for performance
                try:
                    ontology = await self.ontology_loader.load_ontology(ontology_id)
                    ontologies.append(ontology)
                except Exception as e:
                    # Skip individual ontologies that fail to load.
                    logger.warning(f"Failed to load ontology {ontology_id}: {e}")
        if not ontologies:
            logger.warning("No ontologies loaded")
            return []
        # Extract relevant subsets
        ontology_subsets = []
        for ontology in ontologies:
            subset = await self.ontology_matcher.select_relevant_subset(
                question_components, ontology
            )
            # Keep only subsets that actually matched something.
            if subset and (subset.classes or subset.object_properties or subset.datatype_properties):
                ontology_subsets.append(subset)
        return ontology_subsets
    except Exception as e:
        logger.error(f"Failed to load and match ontologies: {e}")
        return []
async def _execute_sparql_path(self,
                               question_components: QuestionComponents,
                               ontology_subsets: List[QueryOntologySubset],
                               query_route: QueryRoute) -> Dict[str, Any]:
    """Execute SPARQL query path.

    Args:
        question_components: Question analysis
        ontology_subsets: Ontology subsets
        query_route: Query route

    Returns:
        Dict with 'generated_query' (SPARQLQuery) and 'raw_results'
        (SPARQLResult).

    Raises:
        RuntimeError: If the SPARQL engine was never initialized (Cassandra
            backend disabled or misconfigured).
    """
    if not self.sparql_engine:
        raise RuntimeError("SPARQL engine not initialized")
    # Generate SPARQL query from the most relevant ontology subset only.
    primary_subset = ontology_subsets[0] if ontology_subsets else None
    sparql_query = await self.sparql_generator.generate_sparql(
        question_components, primary_subset
    )
    logger.debug(f"Generated SPARQL: {sparql_query.query}")
    # Execute query (synchronous engine call).
    sparql_results = self.sparql_engine.execute_sparql(sparql_query.query)
    return {
        'generated_query': sparql_query,
        'raw_results': sparql_results
    }
async def _execute_cypher_path(self,
                               question_components: QuestionComponents,
                               ontology_subsets: List[QueryOntologySubset],
                               query_route: QueryRoute) -> Dict[str, Any]:
    """Execute Cypher query path.

    Args:
        question_components: Question analysis
        ontology_subsets: Ontology subsets
        query_route: Query route

    Returns:
        Dict with 'generated_query' (CypherQuery) and 'raw_results'
        (CypherResult).

    Raises:
        RuntimeError: If the Cypher executor was never initialized (no graph
            backend enabled).
    """
    if not self.cypher_executor:
        raise RuntimeError("Cypher executor not initialized")
    # Generate Cypher query from the most relevant ontology subset only.
    primary_subset = ontology_subsets[0] if ontology_subsets else None
    cypher_query = await self.cypher_generator.generate_cypher(
        question_components, primary_subset
    )
    logger.debug(f"Generated Cypher: {cypher_query.query}")
    # Execute query against the specific graph backend chosen by routing.
    database_type = query_route.backend_type.value
    cypher_results = await self.cypher_executor.execute_query(
        cypher_query.query, database_type=database_type
    )
    return {
        'generated_query': cypher_query,
        'raw_results': cypher_results
    }
async def get_supported_backends(self) -> List[str]:
    """Get list of supported and enabled backends.

    Returns:
        List of backend names (the string values of the enabled
        BackendType members).
    """
    return [bt.value for bt in self.backend_router.get_available_backends()]
async def get_available_ontologies(self) -> List[str]:
    """Get list of available ontologies.

    Returns:
        List of ontology identifiers; empty if the loader has not been
        initialized yet (init() not called).
    """
    if self.ontology_loader:
        return await self.ontology_loader.list_available_ontologies()
    return []
async def health_check(self) -> Dict[str, Any]:
    """Perform health check on all components.

    Never raises: any failure downgrades 'service' to 'degraded' and
    records the error message.

    Returns:
        Health status of all components ('components', 'backends',
        'ontologies' sections plus an overall 'service' status).
    """
    health = {
        'service': 'healthy',
        'components': {},
        'backends': {},
        'ontologies': {}
    }
    try:
        # Check ontology loader
        if self.ontology_loader:
            ontologies = await self.ontology_loader.list_available_ontologies()
            health['components']['ontology_loader'] = 'healthy'
            health['ontologies']['count'] = len(ontologies)
        else:
            health['components']['ontology_loader'] = 'not_initialized'
        # Check vector store
        if self.vector_store:
            health['components']['vector_store'] = 'healthy'
            health['components']['vector_store_type'] = type(self.vector_store).__name__
        else:
            health['components']['vector_store'] = 'not_initialized'
        # Check backends: a backend is 'healthy' only if its executor exists.
        for backend_type in self.backend_router.get_available_backends():
            if backend_type == BackendType.CASSANDRA and self.sparql_engine:
                health['backends']['cassandra'] = 'healthy'
            elif backend_type in [BackendType.NEO4J, BackendType.MEMGRAPH, BackendType.FALKORDB] and self.cypher_executor:
                health['backends'][backend_type.value] = 'healthy'
            else:
                health['backends'][backend_type.value] = 'configured_but_not_initialized'
    except Exception as e:
        health['service'] = 'degraded'
        health['error'] = str(e)
    return health
async def close(self):
    """Close all connections and cleanup resources.

    Errors during shutdown are logged, not raised.
    """
    try:
        if self.sparql_engine:
            # Synchronous close on the Cassandra-backed SPARQL engine.
            self.sparql_engine.close()
        if self.cypher_executor:
            await self.cypher_executor.close()
        if self.config_store:
            # ConfigTableStore cleanup if needed
            pass
        logger.info("OntoRAG query service closed")
    except Exception as e:
        logger.error(f"Error closing OntoRAG query service: {e}")

View file

@ -0,0 +1,364 @@
"""
Question analyzer for ontology-sensitive query system.
Decomposes user questions into semantic components.
"""
import logging
import re
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
logger = logging.getLogger(__name__)
class QuestionType(Enum):
    """Types of questions that can be asked.

    Each member's trailing comment shows a canonical example phrasing.
    """
    FACTUAL = "factual" # What is X?
    RETRIEVAL = "retrieval" # Find all X
    AGGREGATION = "aggregation" # How many X?
    COMPARISON = "comparison" # Is X better than Y?
    RELATIONSHIP = "relationship" # How is X related to Y?
    BOOLEAN = "boolean" # Yes/no questions
    PROCESS = "process" # How to do X?
    TEMPORAL = "temporal" # When did X happen?
    SPATIAL = "spatial" # Where is X?
@dataclass
class QuestionComponents:
    """Components extracted from a question."""
    # Raw question text exactly as supplied by the user.
    original_question: str
    # Classified question category.
    question_type: QuestionType
    # Candidate entity mentions (capitalized runs, quoted spans).
    entities: List[str]
    # Relationship indicator words/phrases.
    relationships: List[str]
    # Filtering constraints parsed from the text.
    constraints: List[str]
    # Aggregation keywords found (count, sum, ...).
    aggregations: List[str]
    # Coarse label for the expected answer type.
    expected_answer_type: str
    # Remaining significant keywords.
    keywords: List[str]
    # NOTE(review): ontorag_service.py constructs this type with a
    # `normalized_question` keyword that has no matching field here —
    # confirm which definition is current.
class QuestionAnalyzer:
"""Analyzes natural language questions to extract semantic components."""
def __init__(self):
    """Initialize question analyzer.

    Builds the static regex tables used by analyze(): per-type question
    patterns, aggregation keywords, and constraint patterns. All patterns
    are written for lowercase input (see analyze()).
    """
    # Question word patterns; dict order defines classification priority in
    # _identify_question_type (first matching family wins).
    self.question_patterns = {
        QuestionType.FACTUAL: [
            r'^what\s+(?:is|are)',
            r'^who\s+(?:is|are)',
            r'^which\s+',
        ],
        QuestionType.RETRIEVAL: [
            r'^find\s+',
            r'^list\s+',
            r'^show\s+',
            r'^get\s+',
            r'^retrieve\s+',
        ],
        QuestionType.AGGREGATION: [
            r'^how\s+many',
            r'^count\s+',
            r'^what\s+(?:is|are)\s+the\s+(?:number|total|sum)',
        ],
        QuestionType.COMPARISON: [
            # Unanchored: comparison cues may appear mid-sentence.
            r'(?:better|worse|more|less|greater|smaller)\s+than',
            r'compare\s+',
            r'difference\s+between',
        ],
        QuestionType.RELATIONSHIP: [
            r'^how\s+(?:is|are).*related',
            r'relationship\s+between',
            r'connection\s+between',
        ],
        QuestionType.BOOLEAN: [
            r'^(?:is|are|was|were|do|does|did|can|could|will|would|should)',
            r'^has\s+',
            r'^have\s+',
        ],
        QuestionType.PROCESS: [
            r'^how\s+(?:to|do)',
            r'^explain\s+how',
        ],
        QuestionType.TEMPORAL: [
            r'^when\s+',
            r'what\s+time',
            r'what\s+date',
        ],
        QuestionType.SPATIAL: [
            r'^where\s+',
            r'location\s+of',
        ],
    }
    # Aggregation keywords
    self.aggregation_keywords = [
        'count', 'sum', 'total', 'average', 'mean', 'median',
        'maximum', 'minimum', 'max', 'min', 'number of'
    ]
    # Constraint patterns
    self.constraint_patterns = [
        r'(?:with|having|where)\s+(.+?)(?:\s+and|\s+or|$)',
        r'(?:greater|less|more|fewer)\s+than\s+(\d+)',
        r'(?:between|from)\s+(.+?)\s+(?:and|to)\s+(.+)',
        r'(?:before|after|since|until)\s+(.+)',
    ]
def analyze(self, question: str) -> QuestionComponents:
    """Analyze a question to extract components.

    Args:
        question: Natural language question

    Returns:
        QuestionComponents with extracted information
    """
    # Normalize question
    question_lower = question.lower().strip()
    # Determine question type
    question_type = self._identify_question_type(question_lower)
    # Extract entities — uses the ORIGINAL casing, since entity detection
    # relies on capitalized proper nouns.
    entities = self._extract_entities(question)
    # Extract relationships
    relationships = self._extract_relationships(question_lower)
    # Extract constraints
    constraints = self._extract_constraints(question_lower)
    # Extract aggregations
    aggregations = self._extract_aggregations(question_lower)
    # Determine expected answer type
    answer_type = self._determine_answer_type(question_type, aggregations)
    # Extract keywords
    keywords = self._extract_keywords(question_lower)
    return QuestionComponents(
        original_question=question,
        question_type=question_type,
        entities=entities,
        relationships=relationships,
        constraints=constraints,
        aggregations=aggregations,
        expected_answer_type=answer_type,
        keywords=keywords
    )
def _identify_question_type(self, question: str) -> QuestionType:
"""Identify the type of question.
Args:
question: Lowercase question text
Returns:
QuestionType enum value
"""
for q_type, patterns in self.question_patterns.items():
for pattern in patterns:
if re.search(pattern, question):
return q_type
# Default to factual
return QuestionType.FACTUAL
def _extract_entities(self, question: str) -> List[str]:
"""Extract potential entities from question.
Args:
question: Original question text
Returns:
List of entity strings
"""
entities = []
# Extract capitalized words/phrases (potential proper nouns)
# Pattern for consecutive capitalized words
pattern = r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b'
matches = re.findall(pattern, question)
entities.extend(matches)
# Extract quoted strings
quoted = re.findall(r'"([^"]+)"', question)
entities.extend(quoted)
quoted = re.findall(r"'([^']+)'", question)
entities.extend(quoted)
# Remove duplicates while preserving order
seen = set()
unique_entities = []
for entity in entities:
if entity not in seen:
seen.add(entity)
unique_entities.append(entity)
return unique_entities
def _extract_relationships(self, question: str) -> List[str]:
"""Extract relationship indicators from question.
Args:
question: Lowercase question text
Returns:
List of relationship strings
"""
relationships = []
# Common relationship patterns
rel_patterns = [
r'(\w+)\s+(?:of|by|from|to|with|for)\s+',
r'has\s+(\w+)',
r'belongs?\s+to',
r'(?:created|written|authored|owned)\s+by',
r'related\s+to',
r'connected\s+to',
r'associated\s+with',
]
for pattern in rel_patterns:
matches = re.findall(pattern, question)
relationships.extend(matches)
# Clean up
relationships = [r for r in relationships if len(r) > 2]
return list(set(relationships))
def _extract_constraints(self, question: str) -> List[str]:
"""Extract constraints from question.
Args:
question: Lowercase question text
Returns:
List of constraint strings
"""
constraints = []
for pattern in self.constraint_patterns:
matches = re.findall(pattern, question)
if matches:
if isinstance(matches[0], tuple):
constraints.extend(list(matches[0]))
else:
constraints.extend(matches)
# Clean up
constraints = [c.strip() for c in constraints if c and len(c.strip()) > 0]
return constraints
def _extract_aggregations(self, question: str) -> List[str]:
"""Extract aggregation operations from question.
Args:
question: Lowercase question text
Returns:
List of aggregation operations
"""
aggregations = []
for keyword in self.aggregation_keywords:
if keyword in question:
aggregations.append(keyword)
return aggregations
def _determine_answer_type(self, question_type: QuestionType,
aggregations: List[str]) -> str:
"""Determine expected answer type.
Args:
question_type: Type of question
aggregations: Aggregation operations found
Returns:
Expected answer type string
"""
if aggregations:
if any(a in ['count', 'number of', 'total'] for a in aggregations):
return 'number'
elif any(a in ['average', 'mean', 'median'] for a in aggregations):
return 'number'
elif any(a in ['sum'] for a in aggregations):
return 'number'
if question_type == QuestionType.BOOLEAN:
return 'boolean'
elif question_type == QuestionType.TEMPORAL:
return 'datetime'
elif question_type == QuestionType.SPATIAL:
return 'location'
elif question_type == QuestionType.RETRIEVAL:
return 'list'
elif question_type == QuestionType.COMPARISON:
return 'comparison'
else:
return 'text'
def _extract_keywords(self, question: str) -> List[str]:
"""Extract important keywords from question.
Args:
question: Lowercase question text
Returns:
List of keywords
"""
# Remove common stop words
stop_words = {
'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to',
'for', 'of', 'with', 'by', 'from', 'as', 'is', 'was', 'are',
'were', 'be', 'been', 'being', 'have', 'has', 'had', 'do',
'does', 'did', 'will', 'would', 'could', 'should', 'may',
'might', 'must', 'can', 'shall', 'what', 'which', 'who',
'when', 'where', 'why', 'how'
}
# Extract words
words = re.findall(r'\b\w+\b', question)
# Filter stop words and short words
keywords = [w for w in words if w not in stop_words and len(w) > 2]
# Remove duplicates while preserving order
seen = set()
unique_keywords = []
for kw in keywords:
if kw not in seen:
seen.add(kw)
unique_keywords.append(kw)
return unique_keywords
def get_question_segments(self, question: str) -> List[str]:
"""Split question into segments for embedding.
Args:
question: Question text
Returns:
List of question segments
"""
segments = []
# Add full question
segments.append(question)
# Split by clauses
clauses = re.split(r'[,;]', question)
segments.extend([c.strip() for c in clauses if len(c.strip()) > 3])
# Extract key phrases
components = self.analyze(question)
segments.extend(components.entities)
segments.extend(components.keywords)
# Remove duplicates
return list(dict.fromkeys(segments))

View file

@ -0,0 +1,481 @@
"""
SPARQL-Cassandra engine using Python rdflib.
Executes SPARQL queries against Cassandra using a custom Store implementation.
"""
import logging
from typing import Dict, Any, List, Optional, Iterator, Tuple
from dataclasses import dataclass
import json
# Try to import rdflib
try:
from rdflib import Graph, Namespace, URIRef, Literal, BNode
from rdflib.store import Store
from rdflib.plugins.sparql.processor import SPARQLResult
from rdflib.plugins.sparql import prepareQuery
from rdflib.term import Node
RDFLIB_AVAILABLE = True
except ImportError:
RDFLIB_AVAILABLE = False
# Try to import Cassandra driver
try:
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.policies import DCAwareRoundRobinPolicy
CASSANDRA_AVAILABLE = True
except ImportError:
CASSANDRA_AVAILABLE = False
from ....tables.config import ConfigTableStore
logger = logging.getLogger(__name__)
@dataclass
class SPARQLResult:
    """Result from SPARQL query execution.

    NOTE(review): this class shadows the ``SPARQLResult`` imported from
    ``rdflib.plugins.sparql.processor`` at the top of the module — consider
    renaming one of them to avoid confusion.
    """
    # Variable bindings, one dict per result row.
    bindings: List[Dict[str, Any]]
    # Names of the variables projected by the query.
    variables: List[str]
    ask_result: Optional[bool] = None  # For ASK queries
    # Wall-clock execution time in seconds.
    execution_time: float = 0.0
    # Optional textual description of the query plan.
    query_plan: Optional[str] = None
class CassandraTripleStore(Store if RDFLIB_AVAILABLE else object):
    """Custom rdflib Store implementation backed by Cassandra.

    Triples live in one table keyed by subject, with secondary indexes on
    predicate and object for the other access patterns.

    Fixes relative to the original implementation:
    - CQL statements now use ``%s`` placeholders: the DataStax Python
      driver requires ``%s``-style parameters for simple (non-prepared)
      statements; ``?`` is only valid in prepared statements, so every
      parameterized query previously failed.
    - The metadata table no longer mixes a ``counter`` column with
      ``timestamp`` columns — Cassandra rejects tables that mix counter
      and non-counter regular columns, so schema creation previously
      errored.
    """

    def __init__(self, cassandra_config: Dict[str, Any]):
        """Initialize Cassandra triple store.

        Args:
            cassandra_config: Cassandra connection configuration
                (host, port, username, password, keyspace)

        Raises:
            RuntimeError: if rdflib or the Cassandra driver is missing
        """
        if not CASSANDRA_AVAILABLE:
            raise RuntimeError("Cassandra driver not available")
        if not RDFLIB_AVAILABLE:
            raise RuntimeError("rdflib not available")
        super().__init__()
        self.cassandra_config = cassandra_config
        self.cluster = None
        self.session = None
        self.keyspace = cassandra_config.get('keyspace', 'trustgraph')
        # Fully-qualified table names so queries work before set_keyspace().
        self.triple_table = f"{self.keyspace}.triples"
        self.metadata_table = f"{self.keyspace}.triple_metadata"

    def open(self, configuration=None, create=False):
        """Open connection to Cassandra.

        Args:
            configuration: Unused (rdflib Store API compatibility)
            create: Whether to create the keyspace/tables if missing

        Returns:
            True on success, False on connection failure
        """
        try:
            # Plain-text auth only when credentials are configured.
            auth_provider = None
            if 'username' in self.cassandra_config and 'password' in self.cassandra_config:
                auth_provider = PlainTextAuthProvider(
                    username=self.cassandra_config['username'],
                    password=self.cassandra_config['password']
                )
            self.cluster = Cluster(
                [self.cassandra_config.get('host', 'localhost')],
                port=self.cassandra_config.get('port', 9042),
                auth_provider=auth_provider,
                load_balancing_policy=DCAwareRoundRobinPolicy()
            )
            self.session = self.cluster.connect()
            # Schema creation must happen before set_keyspace(), which
            # fails if the keyspace does not exist yet.
            if create:
                self._create_schema()
            self.session.set_keyspace(self.keyspace)
            logger.info(f"Connected to Cassandra cluster: {self.cassandra_config.get('host')}")
            return True
        except Exception as e:
            logger.error(f"Failed to connect to Cassandra: {e}")
            return False

    def close(self, commit_pending_transaction=True):
        """Close session and cluster; safe to call when never opened."""
        if self.session:
            self.session.shutdown()
        if self.cluster:
            self.cluster.shutdown()

    def _create_schema(self):
        """Create Cassandra keyspace, triple table, indexes and metadata."""
        self.session.execute(f"""
            CREATE KEYSPACE IF NOT EXISTS {self.keyspace}
            WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}
        """)
        # Triples table optimized for subject-first SPARQL access patterns.
        self.session.execute(f"""
            CREATE TABLE IF NOT EXISTS {self.triple_table} (
                subject text,
                predicate text,
                object text,
                object_datatype text,
                object_language text,
                is_literal boolean,
                graph_id text,
                PRIMARY KEY ((subject), predicate, object)
            )
        """)
        # Secondary indexes for predicate-bound / object-bound lookups.
        self.session.execute(f"""
            CREATE INDEX IF NOT EXISTS ON {self.triple_table} (predicate)
        """)
        self.session.execute(f"""
            CREATE INDEX IF NOT EXISTS ON {self.triple_table} (object)
        """)
        # Per-graph bookkeeping. Counter columns cannot coexist with
        # regular (timestamp) columns in one table, so triple counts, if
        # needed, belong in a dedicated counter-only table.
        self.session.execute(f"""
            CREATE TABLE IF NOT EXISTS {self.metadata_table} (
                graph_id text PRIMARY KEY,
                created timestamp,
                modified timestamp
            )
        """)

    def triples(self, triple_pattern, context=None):
        """Retrieve triples matching the given pattern.

        Args:
            triple_pattern: (subject, predicate, object) with None for
                unbound positions
            context: Graph context (unused)

        Yields:
            Matching (subject, predicate, object) tuples of rdflib nodes
        """
        if not self.session:
            return
        subject, predicate, object_val = triple_pattern
        for cql, params in self._pattern_to_cql(subject, predicate, object_val):
            try:
                for row in self.session.execute(cql, params):
                    yield self._row_to_triple(row)
            except Exception as e:
                logger.error(f"Error executing CQL query: {e}")

    def _pattern_to_cql(self, subject, predicate, object_val) -> List[Tuple[str, List]]:
        """Convert a triple pattern to parameterized CQL queries.

        Uses %s placeholders because the driver's simple statements use
        %s-style parameters ('?' is reserved for prepared statements).

        Args:
            subject: Subject node or None
            predicate: Predicate node or None
            object_val: Object node or None

        Returns:
            List of (CQL query, parameters) tuples
        """
        queries = []
        # None stays a wildcard; bound positions are stringified.
        s_str = str(subject) if subject else None
        p_str = str(predicate) if predicate else None
        o_str = str(object_val) if object_val else None
        if s_str and p_str and o_str:
            # Exact triple lookup - served by the primary key.
            cql = f"SELECT * FROM {self.triple_table} WHERE subject = %s AND predicate = %s AND object = %s"
            queries.append((cql, [s_str, p_str, o_str]))
        elif s_str and p_str:
            cql = f"SELECT * FROM {self.triple_table} WHERE subject = %s AND predicate = %s"
            queries.append((cql, [s_str, p_str]))
        elif s_str:
            cql = f"SELECT * FROM {self.triple_table} WHERE subject = %s"
            queries.append((cql, [s_str]))
        elif p_str:
            # Secondary-index scan; ALLOW FILTERING keeps Cassandra happy.
            cql = f"SELECT * FROM {self.triple_table} WHERE predicate = %s ALLOW FILTERING"
            queries.append((cql, [p_str]))
        elif o_str:
            cql = f"SELECT * FROM {self.triple_table} WHERE object = %s ALLOW FILTERING"
            queries.append((cql, [o_str]))
        else:
            # Full table scan - should be avoided in production.
            cql = f"SELECT * FROM {self.triple_table}"
            queries.append((cql, []))
        return queries

    def _row_to_triple(self, row):
        """Convert a Cassandra row to an RDF triple.

        Args:
            row: Cassandra row with subject/predicate/object columns

        Returns:
            (subject, predicate, object) tuple of rdflib nodes
        """
        # Heuristic: http-prefixed identifiers are URIs, others blank nodes.
        subject = URIRef(row.subject) if row.subject.startswith('http') else BNode(row.subject)
        predicate = URIRef(row.predicate)
        if row.is_literal:
            # Rebuild the literal with its datatype or language tag.
            if row.object_datatype:
                object_node = Literal(row.object, datatype=URIRef(row.object_datatype))
            elif row.object_language:
                object_node = Literal(row.object, lang=row.object_language)
            else:
                object_node = Literal(row.object)
        else:
            object_node = URIRef(row.object) if row.object.startswith('http') else BNode(row.object)
        return (subject, predicate, object_node)

    def add(self, triple, context=None, quoted=False):
        """Add a triple to the store.

        Args:
            triple: (subject, predicate, object) tuple of rdflib nodes
            context: Graph context; stored as graph_id ('default' if None)
            quoted: Unused (rdflib Store API compatibility)
        """
        if not self.session:
            return
        subject, predicate, object_val = triple
        # Flatten rdflib nodes into the storage columns.
        s_str = str(subject)
        p_str = str(predicate)
        is_literal = isinstance(object_val, Literal)
        o_str = str(object_val)
        o_datatype = str(object_val.datatype) if is_literal and object_val.datatype else None
        o_language = object_val.language if is_literal and object_val.language else None
        cql = f"""
            INSERT INTO {self.triple_table}
            (subject, predicate, object, object_datatype, object_language, is_literal, graph_id)
            VALUES (%s, %s, %s, %s, %s, %s, %s)
        """
        try:
            self.session.execute(cql, [
                s_str, p_str, o_str, o_datatype, o_language, is_literal,
                str(context) if context else 'default'
            ])
        except Exception as e:
            logger.error(f"Error adding triple: {e}")

    def remove(self, triple, context=None):
        """Remove a triple from the store.

        Args:
            triple: (subject, predicate, object) tuple
            context: Graph context (unused)
        """
        if not self.session:
            return
        subject, predicate, object_val = triple
        cql = f"""
            DELETE FROM {self.triple_table}
            WHERE subject = %s AND predicate = %s AND object = %s
        """
        try:
            self.session.execute(cql, [str(subject), str(predicate), str(object_val)])
        except Exception as e:
            logger.error(f"Error removing triple: {e}")

    def __len__(self, context=None):
        """Return the number of triples in the store (0 on error).

        Args:
            context: Graph context (unused)
        """
        if not self.session:
            return 0
        try:
            cql = f"SELECT COUNT(*) FROM {self.triple_table}"
            result = self.session.execute(cql)
            return result.one().count
        except Exception as e:
            logger.error(f"Error counting triples: {e}")
            return 0
class SPARQLCassandraEngine:
    """SPARQL processor using Cassandra backend.

    Wraps a ``CassandraTripleStore`` in an rdflib ``Graph`` so rdflib's
    SPARQL engine can evaluate queries by iterating the store's triples.
    """

    def __init__(self, cassandra_config: Dict[str, Any]):
        """Initialize SPARQL-Cassandra engine.

        Args:
            cassandra_config: Cassandra configuration

        Raises:
            RuntimeError: if rdflib or the Cassandra driver is unavailable
        """
        if not RDFLIB_AVAILABLE:
            raise RuntimeError("rdflib is required for SPARQL processing")
        if not CASSANDRA_AVAILABLE:
            raise RuntimeError("Cassandra driver is required")
        self.cassandra_config = cassandra_config
        self.store = CassandraTripleStore(cassandra_config)
        self.graph = Graph(store=self.store)
        # Common namespaces
        self.namespaces = {
            'rdf': Namespace('http://www.w3.org/1999/02/22-rdf-syntax-ns#'),
            'rdfs': Namespace('http://www.w3.org/2000/01/rdf-schema#'),
            'owl': Namespace('http://www.w3.org/2002/07/owl#'),
            'xsd': Namespace('http://www.w3.org/2001/XMLSchema#'),
        }
        # Bind namespaces to graph
        for prefix, namespace in self.namespaces.items():
            self.graph.bind(prefix, namespace)

    async def initialize(self, create_schema=False):
        """Initialize the engine.

        NOTE(review): ``store.open()`` is synchronous; this coroutine only
        exists so callers can await it uniformly.

        Args:
            create_schema: Whether to create Cassandra schema

        Raises:
            RuntimeError: if the Cassandra connection cannot be opened
        """
        success = self.store.open(create=create_schema)
        if not success:
            raise RuntimeError("Failed to connect to Cassandra")
        logger.info("SPARQL-Cassandra engine initialized")

    def execute_sparql(self, sparql_query: str) -> SPARQLResult:
        """Execute SPARQL query against Cassandra.

        Args:
            sparql_query: SPARQL query string

        Returns:
            SPARQLResult (the local dataclass, not rdflib's).
            NOTE(review): execution errors are logged and reported as an
            empty result; callers cannot distinguish "no matches" from
            "query failed".
        """
        import time
        start_time = time.time()
        try:
            # Prepare and execute query
            prepared_query = prepareQuery(sparql_query)
            result = self.graph.query(prepared_query)
            execution_time = time.time() - start_time
            # Format results based on query type. ASK detection is textual,
            # so it assumes the query starts with the ASK keyword (no
            # leading PREFIX block) - TODO confirm callers respect this.
            if sparql_query.strip().upper().startswith('ASK'):
                return SPARQLResult(
                    bindings=[],
                    variables=[],
                    ask_result=bool(result),
                    execution_time=execution_time
                )
            else:
                # SELECT query: convert each row to a {var: value} dict.
                bindings = []
                variables = result.vars if hasattr(result, 'vars') else []
                for row in result:
                    binding = {}
                    for i, var in enumerate(variables):
                        if i < len(row):
                            value = row[i]
                            binding[str(var)] = self._format_result_value(value)
                    bindings.append(binding)
                return SPARQLResult(
                    bindings=bindings,
                    variables=[str(v) for v in variables],
                    execution_time=execution_time
                )
        except Exception as e:
            logger.error(f"SPARQL execution error: {e}")
            return SPARQLResult(
                bindings=[],
                variables=[],
                execution_time=time.time() - start_time
            )

    def _format_result_value(self, value):
        """Format result value for output.

        Args:
            value: RDF value (URIRef, Literal, BNode)

        Returns:
            Dict with 'type' and 'value' keys (plus 'datatype'/'language'
            for typed or language-tagged literals), mirroring the SPARQL
            JSON results layout.
        """
        if isinstance(value, URIRef):
            return {'type': 'uri', 'value': str(value)}
        elif isinstance(value, Literal):
            result = {'type': 'literal', 'value': str(value)}
            if value.datatype:
                result['datatype'] = str(value.datatype)
            if value.language:
                result['language'] = value.language
            return result
        elif isinstance(value, BNode):
            return {'type': 'bnode', 'value': str(value)}
        else:
            return {'type': 'unknown', 'value': str(value)}

    def load_triples_from_store(self, config_store: ConfigTableStore):
        """Load triples from TrustGraph's storage into the RDF graph.

        Currently a stub: it only logs and does no loading.

        Args:
            config_store: Configuration store with triples
        """
        # This would need to be implemented based on how triples are stored
        # in TrustGraph's Cassandra tables
        logger.info("Loading triples from TrustGraph store...")
        # Example implementation - would need to be adapted
        # to actual TrustGraph storage format
        try:
            # Get all triple data
            # This is a placeholder - actual implementation would need
            # to query the appropriate TrustGraph tables
            pass
        except Exception as e:
            logger.error(f"Error loading triples: {e}")

    def close(self):
        """Close the engine and underlying Cassandra connections."""
        if self.store:
            self.store.close()
        logger.info("SPARQL-Cassandra engine closed")

View file

@ -0,0 +1,487 @@
"""
SPARQL query generator for ontology-sensitive queries.
Converts natural language questions to SPARQL queries for Cassandra execution.
"""
import logging
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
from .question_analyzer import QuestionComponents, QuestionType
from .ontology_matcher import QueryOntologySubset
logger = logging.getLogger(__name__)
@dataclass
class SPARQLQuery:
    """Generated SPARQL query with metadata."""
    # Full SPARQL query text, including PREFIX declarations.
    query: str
    # Variable names projected by the query (without leading '?').
    variables: List[str]
    query_type: str  # SELECT, ASK, CONSTRUCT, DESCRIBE
    # Human-readable description of what the query retrieves.
    explanation: str
    # Heuristic complexity estimate in [0.0, 1.0].
    complexity_score: float
class SPARQLGenerator:
    """Generates SPARQL queries from natural language questions using LLM assistance.

    Generation strategy (see generate_sparql): templates first, then the
    LLM prompt service if configured, then a generic fallback query.
    """

    def __init__(self, prompt_service=None):
        """Initialize SPARQL generator.

        Args:
            prompt_service: Service for LLM-based query generation; when
                None, only template and fallback generation are used.
        """
        self.prompt_service = prompt_service
        # SPARQL query templates for common patterns
        self.templates = {
            'simple_class_query': """
PREFIX : <{namespace}>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>

SELECT ?entity ?label WHERE {{
    ?entity rdf:type :{class_name} .
    OPTIONAL {{ ?entity rdfs:label ?label }}
}}""",
            'property_query': """
PREFIX : <{namespace}>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>

SELECT ?subject ?object WHERE {{
    ?subject :{property} ?object .
    ?subject rdf:type :{subject_class} .
}}""",
            'hierarchy_query': """
PREFIX : <{namespace}>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>

SELECT ?subclass ?superclass WHERE {{
    ?subclass rdfs:subClassOf* ?superclass .
    ?superclass rdf:type :{root_class} .
}}""",
            'count_query': """
PREFIX : <{namespace}>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>

SELECT (COUNT(?entity) AS ?count) WHERE {{
    ?entity rdf:type :{class_name} .
    {additional_constraints}
}}""",
            'boolean_query': """
PREFIX : <{namespace}>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>

ASK {{
    {triple_pattern}
}}"""
        }

    async def generate_sparql(self,
                              question_components: QuestionComponents,
                              ontology_subset: QueryOntologySubset) -> SPARQLQuery:
        """Generate SPARQL query for a question.

        Tries template generation, then the LLM service, then a simple
        fallback - always returns some query.

        Args:
            question_components: Analyzed question components
            ontology_subset: Relevant ontology subset

        Returns:
            Generated SPARQL query
        """
        # Try template-based generation first
        template_query = self._try_template_generation(question_components, ontology_subset)
        if template_query:
            logger.debug("Generated SPARQL using template")
            return template_query
        # Fall back to LLM-based generation
        if self.prompt_service:
            llm_query = await self._generate_with_llm(question_components, ontology_subset)
            if llm_query:
                logger.debug("Generated SPARQL using LLM")
                return llm_query
        # Final fallback to simple pattern
        logger.warning("Falling back to simple SPARQL pattern")
        return self._generate_fallback_query(question_components, ontology_subset)

    def _try_template_generation(self,
                                 question_components: QuestionComponents,
                                 ontology_subset: QueryOntologySubset) -> Optional[SPARQLQuery]:
        """Try to generate query using templates.

        Handles three shapes: single-class retrieval, counting, and
        boolean fact checks.

        Args:
            question_components: Question analysis
            ontology_subset: Ontology subset

        Returns:
            Generated query or None if no template matches
        """
        namespace = ontology_subset.metadata.get('namespace', 'http://example.org/')
        # Simple class query (What are the animals?)
        if (question_components.question_type == QuestionType.RETRIEVAL and
                len(question_components.entities) == 1 and
                question_components.entities[0].lower() in [c.lower() for c in ontology_subset.classes]):
            class_name = self._find_matching_class(question_components.entities[0], ontology_subset)
            if class_name:
                query = self.templates['simple_class_query'].format(
                    namespace=namespace,
                    class_name=class_name
                )
                return SPARQLQuery(
                    query=query,
                    variables=['entity', 'label'],
                    query_type='SELECT',
                    explanation=f"Retrieve all instances of {class_name}",
                    complexity_score=0.3
                )
        # Count query (How many animals are there?)
        if (question_components.question_type == QuestionType.AGGREGATION and
                'count' in question_components.aggregations and
                len(question_components.entities) >= 1):
            class_name = self._find_matching_class(question_components.entities[0], ontology_subset)
            if class_name:
                query = self.templates['count_query'].format(
                    namespace=namespace,
                    class_name=class_name,
                    additional_constraints=self._build_constraints(question_components, ontology_subset)
                )
                return SPARQLQuery(
                    query=query,
                    variables=['count'],
                    query_type='SELECT',
                    explanation=f"Count instances of {class_name}",
                    complexity_score=0.4
                )
        # Boolean query (Is X a Y?)
        if question_components.question_type == QuestionType.BOOLEAN:
            triple_pattern = self._build_boolean_pattern(question_components, ontology_subset)
            if triple_pattern:
                query = self.templates['boolean_query'].format(
                    namespace=namespace,
                    triple_pattern=triple_pattern
                )
                return SPARQLQuery(
                    query=query,
                    variables=[],
                    query_type='ASK',
                    explanation="Boolean query for fact checking",
                    complexity_score=0.2
                )
        return None

    async def _generate_with_llm(self,
                                 question_components: QuestionComponents,
                                 ontology_subset: QueryOntologySubset) -> Optional[SPARQLQuery]:
        """Generate SPARQL using LLM.

        Args:
            question_components: Question analysis
            ontology_subset: Ontology subset

        Returns:
            Generated query or None if the service fails or returns
            something that does not start with a SPARQL query keyword
        """
        try:
            prompt = self._build_sparql_prompt(question_components, ontology_subset)
            response = await self.prompt_service.generate_sparql(prompt=prompt)
            if response and isinstance(response, dict):
                query = response.get('query', '').strip()
                # Accept only responses that look like a SPARQL query.
                if query.upper().startswith(('SELECT', 'ASK', 'CONSTRUCT', 'DESCRIBE')):
                    return SPARQLQuery(
                        query=query,
                        variables=self._extract_variables(query),
                        query_type=query.split()[0].upper(),
                        explanation=response.get('explanation', 'Generated by LLM'),
                        complexity_score=self._calculate_complexity(query)
                    )
        except Exception as e:
            logger.error(f"LLM SPARQL generation failed: {e}")
        return None

    def _build_sparql_prompt(self,
                             question_components: QuestionComponents,
                             ontology_subset: QueryOntologySubset) -> str:
        """Build prompt for LLM SPARQL generation.

        Args:
            question_components: Question analysis
            ontology_subset: Ontology subset

        Returns:
            Formatted prompt string
        """
        namespace = ontology_subset.metadata.get('namespace', 'http://example.org/')
        # Format ontology elements
        classes_str = self._format_classes_for_prompt(ontology_subset.classes, namespace)
        props_str = self._format_properties_for_prompt(
            ontology_subset.object_properties,
            ontology_subset.datatype_properties,
            namespace
        )
        prompt = f"""Generate a SPARQL query for the following question using the provided ontology.

QUESTION: {question_components.original_question}

ONTOLOGY NAMESPACE: {namespace}

AVAILABLE CLASSES:
{classes_str}

AVAILABLE PROPERTIES:
{props_str}

RULES:
- Use proper SPARQL syntax
- Include appropriate prefixes
- Use property paths for hierarchical queries (rdfs:subClassOf*)
- Add FILTER clauses for constraints
- Optimize for Cassandra backend
- Return both query and explanation

QUERY TYPE HINTS:
- Question type: {question_components.question_type.value}
- Expected answer: {question_components.expected_answer_type}
- Entities mentioned: {', '.join(question_components.entities)}
- Aggregations: {', '.join(question_components.aggregations)}

Generate a complete SPARQL query:"""
        return prompt

    def _generate_fallback_query(self,
                                 question_components: QuestionComponents,
                                 ontology_subset: QueryOntologySubset) -> SPARQLQuery:
        """Generate simple fallback query.

        Matches any triple whose subject URI contains the first keyword
        (or the literal 'entity' when no keywords were extracted).

        Args:
            question_components: Question analysis
            ontology_subset: Ontology subset

        Returns:
            Basic SPARQL query
        """
        namespace = ontology_subset.metadata.get('namespace', 'http://example.org/')
        # Very basic SELECT query
        query = f"""PREFIX : <{namespace}>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>

SELECT ?subject ?predicate ?object WHERE {{
    ?subject ?predicate ?object .
    FILTER(CONTAINS(STR(?subject), "{question_components.keywords[0] if question_components.keywords else 'entity'}"))
}}
LIMIT 10"""
        return SPARQLQuery(
            query=query,
            variables=['subject', 'predicate', 'object'],
            query_type='SELECT',
            explanation="Fallback query for basic pattern matching",
            complexity_score=0.1
        )

    def _find_matching_class(self, entity: str, ontology_subset: QueryOntologySubset) -> Optional[str]:
        """Find matching class in ontology subset.

        Tries exact id match, then label match, then substring match.

        Args:
            entity: Entity string to match
            ontology_subset: Ontology subset

        Returns:
            Matching class name or None
        """
        entity_lower = entity.lower()
        # Direct match
        for class_id in ontology_subset.classes:
            if class_id.lower() == entity_lower:
                return class_id
        # Label match
        for class_id, class_def in ontology_subset.classes.items():
            labels = class_def.get('labels', [])
            for label in labels:
                if isinstance(label, dict):
                    label_value = label.get('value', '').lower()
                    if label_value == entity_lower:
                        return class_id
        # Partial match (either direction) - loose on purpose.
        for class_id in ontology_subset.classes:
            if entity_lower in class_id.lower() or class_id.lower() in entity_lower:
                return class_id
        return None

    def _build_constraints(self,
                           question_components: QuestionComponents,
                           ontology_subset: QueryOntologySubset) -> str:
        """Build constraint clauses for SPARQL.

        NOTE(review): the generated FILTERs reference ?value, but the
        count_query template this feeds into never binds ?value - confirm
        the intended variable before relying on these constraints.

        Args:
            question_components: Question analysis
            ontology_subset: Ontology subset

        Returns:
            SPARQL constraint string (possibly empty)
        """
        constraints = []
        for constraint in question_components.constraints:
            # Simple constraint patterns
            if 'greater than' in constraint.lower():
                # Extract number
                import re
                numbers = re.findall(r'\d+', constraint)
                if numbers:
                    constraints.append(f"FILTER(?value > {numbers[0]})")
            elif 'less than' in constraint.lower():
                numbers = re.findall(r'\d+', constraint)
                if numbers:
                    constraints.append(f"FILTER(?value < {numbers[0]})")
        return '\n    '.join(constraints)

    def _build_boolean_pattern(self,
                               question_components: QuestionComponents,
                               ontology_subset: QueryOntologySubset) -> Optional[str]:
        """Build triple pattern for boolean queries.

        NOTE(review): the property loop returns on its first iteration, so
        the *first* object property in the subset is used regardless of
        whether it actually connects the two entities - confirm this is
        intended.

        Args:
            question_components: Question analysis
            ontology_subset: Ontology subset

        Returns:
            SPARQL triple pattern or None when fewer than two entities
        """
        if len(question_components.entities) >= 2:
            subject = question_components.entities[0]
            object_val = question_components.entities[1]
            # Try to find connecting property
            for prop_id in ontology_subset.object_properties:
                return f":{subject} :{prop_id} :{object_val} ."
            # Fallback to type check
            return f":{subject} rdf:type :{object_val} ."
        return None

    def _format_classes_for_prompt(self, classes: Dict[str, Any], namespace: str) -> str:
        """Format classes for prompt.

        Args:
            classes: Classes dictionary
            namespace: Ontology namespace (unused here; kept for symmetry)

        Returns:
            Formatted classes string, one bullet per class
        """
        if not classes:
            return "None"
        lines = []
        for class_id, definition in classes.items():
            comment = definition.get('comment', '')
            parent = definition.get('subclass_of', 'Thing')
            lines.append(f"- :{class_id} (subclass of :{parent}) - {comment}")
        return '\n'.join(lines)

    def _format_properties_for_prompt(self,
                                      object_props: Dict[str, Any],
                                      datatype_props: Dict[str, Any],
                                      namespace: str) -> str:
        """Format properties for prompt.

        Args:
            object_props: Object properties (ranges are ontology classes)
            datatype_props: Datatype properties (ranges are literal types)
            namespace: Ontology namespace (unused here; kept for symmetry)

        Returns:
            Formatted properties string, or "None" when both are empty
        """
        lines = []
        for prop_id, definition in object_props.items():
            domain = definition.get('domain', 'Any')
            range_val = definition.get('range', 'Any')
            comment = definition.get('comment', '')
            lines.append(f"- :{prop_id} (:{domain} -> :{range_val}) - {comment}")
        for prop_id, definition in datatype_props.items():
            domain = definition.get('domain', 'Any')
            range_val = definition.get('range', 'xsd:string')
            comment = definition.get('comment', '')
            lines.append(f"- :{prop_id} (:{domain} -> {range_val}) - {comment}")
        return '\n'.join(lines) if lines else "None"

    def _extract_variables(self, query: str) -> List[str]:
        """Extract variables from SPARQL query.

        Args:
            query: SPARQL query string

        Returns:
            List of unique variable names (order unspecified)
        """
        import re
        variables = re.findall(r'\?(\w+)', query)
        return list(set(variables))

    def _calculate_complexity(self, query: str) -> float:
        """Calculate complexity score for SPARQL query.

        Heuristic: weighted count of SPARQL features plus a small
        per-variable contribution, capped at 1.0.

        Args:
            query: SPARQL query string

        Returns:
            Complexity score (0.0 to 1.0)
        """
        complexity = 0.0
        # Count different SPARQL features
        query_upper = query.upper()
        if 'JOIN' in query_upper or 'UNION' in query_upper:
            complexity += 0.3
        if 'FILTER' in query_upper:
            complexity += 0.2
        if 'OPTIONAL' in query_upper:
            complexity += 0.1
        if 'GROUP BY' in query_upper:
            complexity += 0.2
        if 'ORDER BY' in query_upper:
            complexity += 0.1
        if '*' in query:  # Property paths
            complexity += 0.1
        # Count variables
        variables = self._extract_variables(query)
        complexity += len(variables) * 0.05
        return min(complexity, 1.0)

View file

@ -132,20 +132,20 @@ class Processor(DocumentEmbeddingsStoreService):
await self.storage_response_producer.send(response)
async def handle_create_collection(self, request):
"""Create a Milvus collection for document embeddings"""
"""
No-op for collection creation - collections are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
if self.vecstore.collection_exists(request.user, request.collection):
logger.info(f"Collection {request.user}/{request.collection} already exists")
else:
self.vecstore.create_collection(request.user, request.collection)
logger.info(f"Created collection {request.user}/{request.collection}")
logger.info(f"Collection create request for {request.user}/{request.collection} - will be created lazily on first write")
self.vecstore.create_collection(request.user, request.collection)
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
logger.error(f"Failed to handle create collection request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",

View file

@ -123,19 +123,6 @@ class Processor(DocumentEmbeddingsStoreService):
async def store_document_embeddings(self, message):
index_name = (
"d-" + message.metadata.user + "-" + message.metadata.collection
)
# Validate collection exists before accepting writes
if not self.pinecone.has_index(index_name):
error_msg = (
f"Collection {message.metadata.collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
for emb in message.chunks:
if emb.chunk is None or emb.chunk == b"": continue
@ -145,6 +132,17 @@ class Processor(DocumentEmbeddingsStoreService):
for vec in emb.vectors:
# Create index name with dimension suffix for lazy creation
dim = len(vec)
index_name = (
f"d-{message.metadata.user}-{message.metadata.collection}-{dim}"
)
# Lazily create index if it doesn't exist
if not self.pinecone.has_index(index_name):
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
self.create_index(index_name, dim)
index = self.pinecone.Index(index_name)
# Generate unique ID for each vector
@ -220,23 +218,19 @@ class Processor(DocumentEmbeddingsStoreService):
await self.storage_response_producer.send(response)
async def handle_create_collection(self, request):
"""Create a Pinecone index for document embeddings"""
"""
No-op for collection creation - indexes are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
index_name = f"d-{request.user}-{request.collection}"
if self.pinecone.has_index(index_name):
logger.info(f"Pinecone index {index_name} already exists")
else:
# Create with default dimension - will need to be recreated if dimension doesn't match
self.create_index(index_name, dim=384)
logger.info(f"Created Pinecone index: {index_name}")
logger.info(f"Collection create request for {request.user}/{request.collection} - will be created lazily on first write")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
logger.error(f"Failed to handle create collection request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
@ -246,22 +240,34 @@ class Processor(DocumentEmbeddingsStoreService):
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete the collection for document embeddings"""
"""
Delete all dimension variants of the index for document embeddings.
Since indexes are created with dimension suffixes (e.g., d-user-coll-384),
we need to find and delete all matching indexes.
"""
try:
index_name = f"d-{request.user}-{request.collection}"
prefix = f"d-{request.user}-{request.collection}-"
if self.pinecone.has_index(index_name):
self.pinecone.delete_index(index_name)
logger.info(f"Deleted Pinecone index: {index_name}")
# Get all indexes and filter for matches
all_indexes = self.pinecone.list_indexes()
matching_indexes = [
idx.name for idx in all_indexes
if idx.name.startswith(prefix)
]
if not matching_indexes:
logger.info(f"No indexes found matching prefix {prefix}")
else:
logger.info(f"Index {index_name} does not exist, nothing to delete")
for index_name in matching_indexes:
self.pinecone.delete_index(index_name)
logger.info(f"Deleted Pinecone index: {index_name}")
logger.info(f"Deleted {len(matching_indexes)} index(es) for {request.user}/{request.collection}")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View file

@ -79,20 +79,6 @@ class Processor(DocumentEmbeddingsStoreService):
async def store_document_embeddings(self, message):
# Validate collection exists before accepting writes
collection = (
"d_" + message.metadata.user + "_" +
message.metadata.collection
)
if not self.qdrant.collection_exists(collection):
error_msg = (
f"Collection {message.metadata.collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
for emb in message.chunks:
chunk = emb.chunk.decode("utf-8")
@ -100,6 +86,23 @@ class Processor(DocumentEmbeddingsStoreService):
for vec in emb.vectors:
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"d_{message.metadata.user}_{message.metadata.collection}_{dim}"
)
# Lazily create collection if it doesn't exist
if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
self.qdrant.upsert(
collection_name=collection,
points=[
@ -160,30 +163,19 @@ class Processor(DocumentEmbeddingsStoreService):
await self.storage_response_producer.send(response)
async def handle_create_collection(self, request):
"""Create a Qdrant collection for document embeddings"""
"""
No-op for collection creation - collections are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
collection_name = f"d_{request.user}_{request.collection}"
if self.qdrant.collection_exists(collection_name):
logger.info(f"Qdrant collection {collection_name} already exists")
else:
# Create collection with default dimension (will be recreated with correct dim on first write if needed)
# Using a placeholder dimension - actual dimension determined by first embedding
self.qdrant.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(
size=384, # Default dimension, common for many models
distance=Distance.COSINE
)
)
logger.info(f"Created Qdrant collection: {collection_name}")
logger.info(f"Collection create request for {request.user}/{request.collection} - will be created lazily on first write")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
logger.error(f"Failed to handle create collection request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
@ -193,22 +185,34 @@ class Processor(DocumentEmbeddingsStoreService):
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete the collection for document embeddings"""
"""
Delete all dimension variants of the collection for document embeddings.
Since collections are created with dimension suffixes (e.g., d_user_coll_384),
we need to find and delete all matching collections.
"""
try:
collection_name = f"d_{request.user}_{request.collection}"
prefix = f"d_{request.user}_{request.collection}_"
if self.qdrant.collection_exists(collection_name):
self.qdrant.delete_collection(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
matching_collections = [
coll.name for coll in all_collections
if coll.name.startswith(prefix)
]
if not matching_collections:
logger.info(f"No collections found matching prefix {prefix}")
else:
logger.info(f"Collection {collection_name} does not exist, nothing to delete")
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {request.user}/{request.collection}")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View file

@ -128,20 +128,20 @@ class Processor(GraphEmbeddingsStoreService):
await self.storage_response_producer.send(response)
async def handle_create_collection(self, request):
"""Create a Milvus collection for graph embeddings"""
"""
No-op for collection creation - collections are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
if self.vecstore.collection_exists(request.user, request.collection):
logger.info(f"Collection {request.user}/{request.collection} already exists")
else:
self.vecstore.create_collection(request.user, request.collection)
logger.info(f"Created collection {request.user}/{request.collection}")
logger.info(f"Collection create request for {request.user}/{request.collection} - will be created lazily on first write")
self.vecstore.create_collection(request.user, request.collection)
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
logger.error(f"Failed to handle create collection request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",

View file

@ -123,19 +123,6 @@ class Processor(GraphEmbeddingsStoreService):
async def store_graph_embeddings(self, message):
index_name = (
"t-" + message.metadata.user + "-" + message.metadata.collection
)
# Validate collection exists before accepting writes
if not self.pinecone.has_index(index_name):
error_msg = (
f"Collection {message.metadata.collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
for entity in message.entities:
if entity.entity.value == "" or entity.entity.value is None:
@ -143,6 +130,17 @@ class Processor(GraphEmbeddingsStoreService):
for vec in entity.vectors:
# Create index name with dimension suffix for lazy creation
dim = len(vec)
index_name = (
f"t-{message.metadata.user}-{message.metadata.collection}-{dim}"
)
# Lazily create index if it doesn't exist
if not self.pinecone.has_index(index_name):
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
self.create_index(index_name, dim)
index = self.pinecone.Index(index_name)
# Generate unique ID for each vector
@ -218,23 +216,19 @@ class Processor(GraphEmbeddingsStoreService):
await self.storage_response_producer.send(response)
async def handle_create_collection(self, request):
"""Create a Pinecone index for graph embeddings"""
"""
No-op for collection creation - indexes are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
index_name = f"t-{request.user}-{request.collection}"
if self.pinecone.has_index(index_name):
logger.info(f"Pinecone index {index_name} already exists")
else:
# Create with default dimension - will need to be recreated if dimension doesn't match
self.create_index(index_name, dim=384)
logger.info(f"Created Pinecone index: {index_name}")
logger.info(f"Collection create request for {request.user}/{request.collection} - will be created lazily on first write")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
logger.error(f"Failed to handle create collection request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
@ -244,22 +238,34 @@ class Processor(GraphEmbeddingsStoreService):
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete the collection for graph embeddings"""
"""
Delete all dimension variants of the index for graph embeddings.
Since indexes are created with dimension suffixes (e.g., t-user-coll-384),
we need to find and delete all matching indexes.
"""
try:
index_name = f"t-{request.user}-{request.collection}"
prefix = f"t-{request.user}-{request.collection}-"
if self.pinecone.has_index(index_name):
self.pinecone.delete_index(index_name)
logger.info(f"Deleted Pinecone index: {index_name}")
# Get all indexes and filter for matches
all_indexes = self.pinecone.list_indexes()
matching_indexes = [
idx.name for idx in all_indexes
if idx.name.startswith(prefix)
]
if not matching_indexes:
logger.info(f"No indexes found matching prefix {prefix}")
else:
logger.info(f"Index {index_name} does not exist, nothing to delete")
for index_name in matching_indexes:
self.pinecone.delete_index(index_name)
logger.info(f"Deleted Pinecone index: {index_name}")
logger.info(f"Deleted {len(matching_indexes)} index(es) for {request.user}/{request.collection}")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View file

@ -69,22 +69,6 @@ class Processor(GraphEmbeddingsStoreService):
metrics=storage_response_metrics,
)
def get_collection(self, user, collection):
"""Get collection name and validate it exists"""
cname = (
"t_" + user + "_" + collection
)
if not self.qdrant.collection_exists(cname):
error_msg = (
f"Collection {collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
return cname
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
@ -101,10 +85,23 @@ class Processor(GraphEmbeddingsStoreService):
for vec in entity.vectors:
collection = self.get_collection(
message.metadata.user, message.metadata.collection
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"t_{message.metadata.user}_{message.metadata.collection}_{dim}"
)
# Lazily create collection if it doesn't exist
if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
self.qdrant.upsert(
collection_name=collection,
points=[
@ -165,30 +162,19 @@ class Processor(GraphEmbeddingsStoreService):
await self.storage_response_producer.send(response)
async def handle_create_collection(self, request):
"""Create a Qdrant collection for graph embeddings"""
"""
No-op for collection creation - collections are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
collection_name = f"t_{request.user}_{request.collection}"
if self.qdrant.collection_exists(collection_name):
logger.info(f"Qdrant collection {collection_name} already exists")
else:
# Create collection with default dimension (will be recreated with correct dim on first write if needed)
# Using a placeholder dimension - actual dimension determined by first embedding
self.qdrant.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(
size=384, # Default dimension, common for many models
distance=Distance.COSINE
)
)
logger.info(f"Created Qdrant collection: {collection_name}")
logger.info(f"Collection create request for {request.user}/{request.collection} - will be created lazily on first write")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
logger.error(f"Failed to handle create collection request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
@ -198,22 +184,34 @@ class Processor(GraphEmbeddingsStoreService):
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete the collection for graph embeddings"""
"""
Delete all dimension variants of the collection for graph embeddings.
Since collections are created with dimension suffixes (e.g., t_user_coll_384),
we need to find and delete all matching collections.
"""
try:
collection_name = f"t_{request.user}_{request.collection}"
prefix = f"t_{request.user}_{request.collection}_"
if self.qdrant.collection_exists(collection_name):
self.qdrant.delete_collection(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
matching_collections = [
coll.name for coll in all_collections
if coll.name.startswith(prefix)
]
if not matching_collections:
logger.info(f"No collections found matching prefix {prefix}")
else:
logger.info(f"Collection {collection_name} does not exist, nothing to delete")
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {request.user}/{request.collection}")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View file

@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=1.4,<1.5",
"trustgraph-base>=1.5,<1.6",
"pulsar-client",
"prometheus-client",
"boto3",

View file

@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=1.4,<1.5",
"trustgraph-base>=1.5,<1.6",
"pulsar-client",
"google-cloud-aiplatform",
"prometheus-client",

View file

@ -10,12 +10,12 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=1.4,<1.5",
"trustgraph-bedrock>=1.4,<1.5",
"trustgraph-cli>=1.4,<1.5",
"trustgraph-embeddings-hf>=1.4,<1.5",
"trustgraph-flow>=1.4,<1.5",
"trustgraph-vertexai>=1.4,<1.5",
"trustgraph-base>=1.5,<1.6",
"trustgraph-bedrock>=1.5,<1.6",
"trustgraph-cli>=1.5,<1.6",
"trustgraph-embeddings-hf>=1.5,<1.6",
"trustgraph-flow>=1.5,<1.6",
"trustgraph-vertexai>=1.5,<1.6",
]
classifiers = [
"Programming Language :: Python :: 3",